Files
AIEC-new/AIEC-RAG/prompt_loader.py
2025-10-17 09:31:28 +08:00

576 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
提示词加载器
支持从配置文件动态加载和管理提示词模板
"""
import yaml
import os
from pathlib import Path
from typing import Dict, Any, Optional
from langchain.prompts import PromptTemplate
class PromptLoader:
"""提示词加载器"""
# 默认提示词模板
DEFAULT_PROMPTS = {
"query_complexity_check": """
作为一个智能查询分析助手,请分析用户查询的复杂度,判断该查询是否需要生成多方面的多个子查询来回答。
用户查询:{query}
请根据以下标准判断查询复杂度:
【复杂查询Complex特征 - 优先判断】:
1. **多问句查询**:包含多个问号(?),涉及多个不同的问题或主题
2. **跨领域查询**:涉及多个不同领域或行业的知识(如金融+技术+风险管理等)
3. **复合问题**:一个查询中包含多个子查询,即使每个子查询本身较简单
4. **关系型查询**:询问实体间的关系、比较、关联等
5. **因果推理**:询问原因、结果、影响等
6. **综合分析**:需要综合多个信息源进行分析
7. **推理链查询**:需要通过知识图谱的路径推理才能回答
8. **列表型查询**:要求列举多个项目或要素的问题
【简单查询Simple特征】
1. **单一问句**:只包含一个问号,聚焦于单一主题
2. **单领域查询**:仅涉及一个明确的领域或概念
3. **直接定义查询**:询问单一概念的定义、特点、属性等
4. **单一实体信息查询**:询问某个具体事物的基本信息
5. **可以通过文档中的连续文本段落直接回答的单一问题**
请以JSON格式返回判断结果
{{
"is_complex": false, // true表示复杂查询false表示简单查询
"complexity_level": "simple", // "simple""complex"
"confidence": 0.9, // 置信度0-1之间
"reason": "这是一个复杂查询,需要生成多方面的个子查询来回答..."
}}
请确保返回有效的JSON格式。
""",
"sufficiency_check": """
作为一个智能问答助手,请判断仅从给定的信息是否已经足够回答用户的查询。
用户查询:{query}
已生成的子查询:{decomposed_sub_queries}
检索到的信息(包括原始查询和子查询的结果):
{passages}
事件之间的关系:
{event_triples}
请分析这些信息是否包含足够的内容来完整回答用户的查询。注意检索结果包含三部分:
1. 【事件信息】- 来自知识图谱的事件节点,包含事件描述和上下文
2. 【段落信息】- 来自文档的段落内容,包含详细的文本描述
3. 【事件关系】- 检索到的事件之间的直接关联关系,有助于理解事件间的逻辑联系
如果信息充分请返回JSON格式
{{
"is_sufficient": true,
"confidence": 0.9,
"reason": "事件信息、段落信息和事件关系包含了回答查询所需的关键内容..."
}}
如果信息不充分请返回JSON格式并详细说明缺失的信息
{{
"is_sufficient": false,
"confidence": 0.5,
"reason": "检索信息缺少某些关键内容具体包括1) 缺少XXX的详细描述2) 缺少XXX的具体实例3) 缺少XXX的应用场景等..."
}}
请确保返回有效的JSON格式。
""",
"query_decomposition": """
你是一个专业的查询分解助手。你的任务是将用户的复合查询分解为{num_sub_queries}个独立的、适合向量检索的子查询。
用户原始查询:{original_query}
【重要】向量检索特点:
- 向量检索通过语义相似度匹配,找到与查询最相关的文档段落
- 每个子查询应该聚焦一个明确的主题或概念
- 子查询应该是完整的、可独立检索的问题
【分解策略】:
1. **识别查询结构**
- 仔细查看查询中是否有多个问号()分隔的不同问题
- 查找连接词:"""以及""还有""另外"
- 查找标点符号:句号(。)、分号()等分隔符
2. **按主题分解**
- 如果查询包含多个独立主题,将其分解为独立的问题
- 每个子查询保持完整性,包含必要的上下文信息
【要求】:
1. [NOTE] 必须使用自然语言,类似人类会问的问题
2. [ERROR] 禁止使用SQL语句、代码或技术查询语法
3. [SEARCH] 严格按照查询的自然分割点进行分解
4. [?] 每个子查询必须是完整的自然语言问题
5. [BLOCKED] 绝不要添加"详细信息""相关内容""补充信息"等后缀
6. [INFO] 保持原始查询中的所有关键信息(时间、地点、对象等)
7. [TARGET] 确保子查询可以独立进行向量检索
请严格按照JSON格式返回自然语言查询
{{
"sub_queries": [自然语言子查询列表]
}}
""",
"sub_query_generation": """
基于用户的原始查询、已有的检索结果和充分性检查反馈,生成{num_sub_queries}个相关的子查询来获取缺失信息。
原始查询:{original_query}
之前生成的子查询:{previous_sub_queries}
已有检索结果(包含事件信息和段落信息):
{existing_passages}
已发现的事件关系:
{event_triples}
充分性检查反馈(信息不充分的原因):
{insufficiency_reason}
请根据充分性检查反馈中指出的缺失信息,生成{num_sub_queries}个具体的自然语言子查询来补充这些缺失内容。注意已有检索结果包含三部分:
1. 【事件信息】- 来自知识图谱的事件节点
2. 【段落信息】- 来自文档的段落内容
3. 【事件关系】- 已发现的事件之间的直接关联关系
【重要要求】:
1. [NOTE] 必须使用自然语言表达,类似人类会问的问题
2. [ERROR] 禁止使用SQL语句、代码或技术查询语法
3. [OK] 使用疑问句形式,如"什么是...""如何...""有哪些..."
4. [TARGET] 直接针对充分性检查中指出的缺失信息
5. [LINK] 与原始查询高度相关
6. [INFO] 查询要具体明确,能够获取到具体的信息
7. [BLOCKED] 避免与之前已生成的子查询重复
8. [TARGET] 确保每个子查询都能独立检索到有价值的信息
请以JSON格式返回自然语言查询
{{
"sub_queries": [自然语言子查询列表]
}}
""",
"simple_answer": """
基于检索到的信息,请为用户的查询提供一个{answer_style}的答案。
用户查询:{query}
检索到的相关信息:
{passages}
请基于这些信息回答用户的查询。注意检索结果包含:
1. 【事件信息】- 来自知识图谱的事件节点,提供事件相关的上下文
2. 【段落信息】- 来自文档的段落内容,提供详细的文本描述
要求:
1. 直接回答用户的查询
2. 严格基于提供的信息,不要编造内容
3. 综合利用事件信息和段落信息
4. 如果信息不足以完整回答,请明确说明
5. 答案风格:{answer_style_description}
答案:
""",
"edge_filter": """
你是一个知识图谱事实过滤专家,擅长根据问题相关性筛选事实。
关键要求:
1. 你只能从提供的输入列表中选择事实 - 绝对不能创建或生成新的事实
2. 你的输出必须是输入事实的严格子集
3. 只包含与回答问题直接相关的事实
4. 如果不确定,宁可选择更少的事实,也不要选择更多
5. 保持每个事实的准确格式:[主语, 关系, 宾语]
过滤规则:
- 只选择包含与问题直接相关的实体或关系的事实
- 不能修改、改写或创建输入事实的变体
- 不能添加看起来相关但不在输入中的事实
- 输出事实数量必须 ≤ 输入事实数量
返回格式为包含"fact"键的JSON对象值为选中的事实数组。
示例:
输入事实:[["A", "关系1", "B"], ["B", "关系2", "C"], ["D", "关系3", "E"]]
问题A和B是什么关系
正确输出:{{"fact": [["A", "关系1", "B"]]}}
错误做法:添加新事实或修改现有事实
问题:{question}
待筛选的输入事实:
{triples}
从上述输入中仅选择最相关的事实来回答问题。记住:只能是严格的子集!""",
"final_answer": """
基于所有检索到的信息,请为用户的查询提供一个{answer_style}的答案。
原始查询:{original_query}
生成的子查询:{sub_queries}
所有检索到的信息(包括原始查询和所有子查询的结果):
{all_passages}
发现的事件关系:
{event_triples}
请综合所有信息,提供一个完整的答案。注意检索结果包含三部分:
1. 【事件信息】- 来自知识图谱的事件节点,提供结构化的知识
2. 【段落信息】- 来自文档的段落内容,提供详细的描述
3. 【事件关系】- 发现的事件之间的直接关联关系,有助于理解逻辑联系
要求:
1. 全面回答用户的所有问题
2. 严格基于检索到的信息,不要编造内容
3. 综合利用所有事件信息、段落信息和事件关系
4. 逻辑清晰,条理分明
5. 答案风格:{answer_style_description}
{source_requirement}
答案:
"""
}
def __init__(self, config_path: str = "rag_config.yaml"):
"""
初始化提示词加载器
Args:
config_path: 配置文件路径
"""
# 优先使用环境变量指定的配置文件
env_config_path = os.environ.get('RAG_CONFIG_PATH')
if env_config_path:
config_path = env_config_path
# 总是从项目根目录prompt_loader.py所在目录查找配置文件
if not os.path.isabs(config_path):
# 获取本文件所在的目录(项目根目录)
project_root = Path(__file__).parent
self.config_path = project_root / config_path
else:
self.config_path = Path(config_path)
self.config = self._load_config()
self.prompts_cache = {}
def _load_config(self) -> Dict[str, Any]:
"""加载配置文件"""
if not self.config_path.exists():
print(f"[WARNING] 配置文件不存在: {self.config_path},使用默认提示词")
return {}
try:
with open(self.config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
except Exception as e:
print(f"[ERROR] 加载配置文件失败: {e},使用默认提示词")
return {}
def get_prompt_template(self, prompt_name: str, **kwargs) -> PromptTemplate:
"""
获取提示词模板
Args:
prompt_name: 提示词名称
**kwargs: 额外的格式化参数
Returns:
PromptTemplate对象
"""
# 检查缓存
cache_key = f"{prompt_name}_{hash(frozenset(kwargs.items()))}"
if cache_key in self.prompts_cache:
return self.prompts_cache[cache_key]
# 获取提示词文本
prompt_text = self._get_prompt_text(prompt_name, **kwargs)
# 提取输入变量
input_variables = self._extract_input_variables(prompt_text)
# 创建PromptTemplate
prompt_template = PromptTemplate(
input_variables=input_variables,
template=prompt_text
)
# 缓存
self.prompts_cache[cache_key] = prompt_template
return prompt_template
def should_skip_llm(self, prompt_name: str) -> bool:
"""
检查是否应该跳过LLM调用
Args:
prompt_name: 提示词名称
Returns:
True: 跳过LLM调用使用默认响应
False: 正常调用LLM
"""
if 'prompt_templates' in self.config:
prompt_config = self.config['prompt_templates'].get(prompt_name, {})
# 检查skip_llm配置
if prompt_config.get('skip_llm', False):
print(f"[WARNING] {prompt_name} 配置skip_llm=true跳过LLM调用")
return True
return False
def get_default_response(self, prompt_name: str) -> dict:
"""
获取跳过LLM时的默认响应
Args:
prompt_name: 提示词名称
Returns:
默认响应字典
"""
default_responses = {
'query_complexity_check': {
"is_complex": False, # 默认简单查询
"complexity_level": "simple",
"confidence": 1.0,
"reason": "LLM调用已跳过默认为简单查询"
},
'sufficiency_check': {
"is_sufficient": True, # 默认充分
"confidence": 1.0,
"reason": "LLM调用已跳过默认为信息充分"
},
'query_decomposition': {
"sub_queries": [] # 不分解,使用原查询
},
'sub_query_generation': {
"sub_queries": [] # 不生成新查询
},
'simple_answer': {
"answer": "基于检索到的信息直接返回"
},
'final_answer': {
"answer": "基于检索到的信息直接返回"
},
'edge_filter': {
"fact": [] # 跳过边过滤时,返回空列表(不过滤)
}
}
return default_responses.get(prompt_name, {})
def _get_prompt_text(self, prompt_name: str, **kwargs) -> str:
"""
获取提示词文本
Args:
prompt_name: 提示词名称
**kwargs: 格式化参数
Returns:
格式化后的提示词文本
"""
# 尝试从配置文件获取自定义提示词
if 'prompt_templates' in self.config:
prompt_config = self.config['prompt_templates'].get(prompt_name, {})
# 如果配置中有完整的模板文本,使用它
if 'template' in prompt_config:
prompt_text = prompt_config['template']
# 否则使用默认模板
else:
prompt_text = self.DEFAULT_PROMPTS.get(prompt_name, "")
# 应用配置参数
prompt_text = self._apply_config_params(prompt_name, prompt_text, prompt_config, **kwargs)
else:
# 使用默认提示词
prompt_text = self.DEFAULT_PROMPTS.get(prompt_name, "")
return prompt_text
def _apply_config_params(self, prompt_name: str, prompt_text: str,
prompt_config: Dict[str, Any], **kwargs) -> str:
"""
应用配置参数到提示词
Args:
prompt_name: 提示词名称
prompt_text: 原始提示词文本
prompt_config: 提示词配置
**kwargs: 额外参数
Returns:
应用参数后的提示词文本
"""
# 根据不同的提示词类型,应用不同的配置
if prompt_name == "query_decomposition":
# 应用子查询数量配置
num_sub_queries = prompt_config.get('default_sub_queries', 2)
prompt_text = prompt_text.replace("{num_sub_queries}", str(num_sub_queries))
elif prompt_name == "sub_query_generation":
# 应用子查询数量配置
queries_per_iteration = prompt_config.get('queries_per_iteration', 2)
prompt_text = prompt_text.replace("{num_sub_queries}", str(queries_per_iteration))
elif prompt_name == "simple_answer":
# 应用答案风格配置
answer_style = prompt_config.get('answer_style', 'concise')
answer_style_desc = self._get_answer_style_description(answer_style)
prompt_text = prompt_text.replace("{answer_style}", answer_style)
prompt_text = prompt_text.replace("{answer_style_description}", answer_style_desc)
elif prompt_name == "final_answer":
# 应用答案风格和来源引用配置
answer_style = prompt_config.get('answer_style', 'comprehensive')
answer_style_desc = self._get_answer_style_description(answer_style)
include_sources = prompt_config.get('include_sources', False)
prompt_text = prompt_text.replace("{answer_style}", answer_style)
prompt_text = prompt_text.replace("{answer_style_description}", answer_style_desc)
if include_sources:
source_req = "6. 在答案末尾列出信息来源段落ID或事件ID"
else:
source_req = ""
prompt_text = prompt_text.replace("{source_requirement}", source_req)
# 应用额外的kwargs参数
for key, value in kwargs.items():
placeholder = "{" + key + "}"
if placeholder in prompt_text:
prompt_text = prompt_text.replace(placeholder, str(value))
return prompt_text
def _get_answer_style_description(self, style: str) -> str:
"""
获取答案风格描述
Args:
style: 答案风格
Returns:
风格描述文本
"""
style_descriptions = {
"concise": "简洁明了,重点突出,避免冗余",
"detailed": "详细完整,包含所有相关信息",
"balanced": "平衡简洁与详细,保持适中的信息量",
"comprehensive": "全面综合,涵盖所有方面,结构清晰",
"structured": "结构化呈现,使用编号、分点等形式",
"summary": "摘要形式,提炼关键信息"
}
return style_descriptions.get(style, "清晰准确")
def _extract_input_variables(self, prompt_text: str) -> list:
"""
从提示词文本中提取输入变量
Args:
prompt_text: 提示词文本
Returns:
输入变量列表
"""
import re
# 查找所有 {variable_name} 格式的变量
pattern = r'\{([^}]+)\}'
matches = re.findall(pattern, prompt_text)
# 过滤掉JSON格式的大括号和已知的配置参数
config_params = ['num_sub_queries', 'answer_style', 'answer_style_description', 'source_requirement']
input_vars = []
for match in matches:
# 跳过包含空格、引号或JSON格式的内容
if ' ' in match or '"' in match or ':' in match or match in config_params:
continue
if match not in input_vars:
input_vars.append(match)
return input_vars
def update_prompt(self, prompt_name: str, prompt_text: str):
"""
动态更新提示词(运行时修改)
Args:
prompt_name: 提示词名称
prompt_text: 新的提示词文本
"""
if 'prompt_templates' not in self.config:
self.config['prompt_templates'] = {}
if prompt_name not in self.config['prompt_templates']:
self.config['prompt_templates'][prompt_name] = {}
self.config['prompt_templates'][prompt_name]['template'] = prompt_text
# 清除缓存
self.prompts_cache = {}
print(f"[OK] 已更新提示词: {prompt_name}")
def reload_config(self):
"""重新加载配置文件"""
self.config = self._load_config()
self.prompts_cache = {}
print("[OK] 已重新加载提示词配置")
def get_all_prompt_names(self) -> list:
"""获取所有可用的提示词名称"""
return list(self.DEFAULT_PROMPTS.keys())
def print_prompt_summary(self):
"""打印提示词配置摘要"""
print("\n" + "="*50)
print("[NOTE] 提示词配置摘要")
print("="*50)
if 'prompt_templates' in self.config:
for prompt_name, prompt_config in self.config['prompt_templates'].items():
if prompt_config.get('enabled', True):
print(f"\n[OK] {prompt_name}:")
for key, value in prompt_config.items():
if key not in ['enabled', 'template']:
print(f" - {key}: {value}")
else:
print(f"\n[ERROR] {prompt_name}: 已禁用")
else:
print("[WARNING] 使用默认提示词配置")
print("="*50)
# 全局提示词加载器实例
_global_prompt_loader = None
def get_prompt_loader() -> PromptLoader:
"""获取全局提示词加载器实例"""
global _global_prompt_loader
if _global_prompt_loader is None:
_global_prompt_loader = PromptLoader()
return _global_prompt_loader
def reload_prompts(config_path: str = "rag_config.yaml"):
"""重新加载提示词配置"""
global _global_prompt_loader
_global_prompt_loader = PromptLoader(config_path)
return _global_prompt_loader