"""
集成LangSmith监控的迭代检索器
Token监控通过OneAPILLM自动处理
"""
import os
import time
import json
from datetime import datetime
from typing import Dict, Any, Optional
# LangSmith tracing support
from langsmith import traceable
from retriver.langgraph.iterative_retriever import IterativeRetriever
from retriver.langgraph.graph_state import create_initial_state
class LangSmithIterativeRetriever(IterativeRetriever):
"""
集成LangSmith监控的迭代检索器
Token消耗通过OneAPILLM的_generate方法自动记录
"""
    def __init__(
        self,
        keyword: str,
        top_k: int = 2,
        max_iterations: int = 2,
        max_parallel_retrievals: int = 2,
        oneapi_key: Optional[str] = None,
        oneapi_base_url: Optional[str] = None,
        model_name: Optional[str] = None,
        embed_model_name: Optional[str] = None,
        complexity_model_name: Optional[str] = None,
        sufficiency_model_name: Optional[str] = None,
        langsmith_project: Optional[str] = None,
        skip_llm_generation: bool = False
    ):
        # Configure LangSmith
        self._setup_langsmith(langsmith_project)
        # Initialize the parent class
        super().__init__(
            keyword=keyword,
            top_k=top_k,
            max_iterations=max_iterations,
            max_parallel_retrievals=max_parallel_retrievals,
            oneapi_key=oneapi_key,
            oneapi_base_url=oneapi_base_url,
            model_name=model_name,
            embed_model_name=embed_model_name,
            complexity_model_name=complexity_model_name,
            sufficiency_model_name=sufficiency_model_name,
            skip_llm_generation=skip_llm_generation
        )
        print("[SEARCH] LangSmith monitoring enabled")
    def _setup_langsmith(self, project_name: Optional[str] = None):
        """Configure the LangSmith environment."""
        # Set environment variables
        if not os.getenv("LANGCHAIN_TRACING_V2"):
            os.environ["LANGCHAIN_TRACING_V2"] = "true"
        if not os.getenv("LANGCHAIN_ENDPOINT"):
            os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
        # Set the project name
        if project_name:
            os.environ["LANGCHAIN_PROJECT"] = project_name
        elif not os.getenv("LANGCHAIN_PROJECT"):
            os.environ["LANGCHAIN_PROJECT"] = "hipporag-retriever"
        # Check the API key
        if not os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGCHAIN_API_KEY") == "your_langsmith_api_key_here":
            print("[WARNING] Please set a valid LANGCHAIN_API_KEY environment variable")
            print("Visit https://smith.langchain.com to obtain an API key")
@traceable(name="Complete_Retrieval_Process")
def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
"""
带LangSmith追踪的完整检索过程
"""
print(f"[STARTING] 开始检索过程 (LangSmith追踪)")
print(f"[POINT] 项目: {os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever')}")
print(f"[SEARCH] 查询: {query}")
print(f"[BUG] 调试模式: {mode}")
start_time = time.time()
try:
# 创建初始状态
initial_state = create_initial_state(
original_query=query,
max_iterations=self.max_iterations,
debug_mode=mode
)
# 执行工作流 - LangChain会自动追踪所有LLM调用
# 配置LangSmith避免追踪大型数据结构
config = {
"recursion_limit": 50,
"metadata": {
"note": "大型PageRank数据仅保存本地不上传LangSmith"
},
"callbacks": [], # 禁用额外的回调避免数据泄漏
"run_name": f"Retrieval-{query[:20]}..."
}
result_state = self.workflow.invoke(initial_state, config=config)
# 构建结果
end_time = time.time()
total_time = end_time - start_time
# 构建完整结果(本地使用)
full_result = self._build_final_result(result_state, total_time)
# 清理结果状态中的大型数据结构以避免LangSmith传输
cleaned_state = self._clean_state_for_langsmith(result_state)
# 为LangSmith创建轻量级版本移除大型数据结构
langsmith_result = self._build_langsmith_safe_result(cleaned_state, total_time)
print(f"[OK] 检索完成 (耗时: {total_time:.2f}秒)")
print(f"[LINK] 在LangSmith中查看详细追踪信息")
# 返回轻量级版本给LangSmith但将完整结果保存到文件
self._save_full_result_to_file(full_result)
# 为了兼容langsmith_example.py返回包含真实文档数据的结果
# 但移除PageRank等大型数据以避免LangSmith传输问题
result_with_docs = {
"query": langsmith_result.get("query", ""),
"answer": langsmith_result.get("answer", ""),
"total_passages": langsmith_result.get("total_passages", 0),
"retrieval_path": langsmith_result.get("retrieval_path", "unknown"),
"iterations": langsmith_result.get("iterations", 0),
"is_sufficient": langsmith_result.get("is_sufficient", False),
"sub_queries": langsmith_result.get("sub_queries", [])[:3],
# 包含真实的文档数据为langsmith_example.py使用
"all_documents": result_state.get('all_documents', []),
"all_passages": result_state.get('all_passages', []),
"passage_sources": result_state.get('passage_sources', []),
# 其他完整信息
"query_complexity": result_state.get('query_complexity', {}),
"decomposed_sub_queries": result_state.get('decomposed_sub_queries', []),
"sufficiency_check": result_state.get('sufficiency_check', {}),
"debug_info": {
"total_time": total_time,
"langsmith_project": langsmith_result.get("debug_info", {}).get("langsmith_project", ""),
"note": "包含真实文档数据的完整版本"
}
}
return result_with_docs
except KeyboardInterrupt:
print(f"\n[WARNING] 检索被用户中断")
end_time = time.time()
return {
"query": query,
"answer": "检索被用户中断",
"error": "KeyboardInterrupt",
"total_time": end_time - start_time,
"iterations": 0,
"total_passages": 0,
"sub_queries": [],
"debug_info": {"error": "KeyboardInterrupt", "total_time": end_time - start_time}
}
except Exception as e:
print(f"[ERROR] 检索过程出错: {e}")
end_time = time.time()
return {
"query": query,
"answer": f"抱歉,检索过程中遇到错误: {str(e)}",
"error": str(e),
"total_time": end_time - start_time,
"iterations": 0,
"total_passages": 0,
"sub_queries": [],
"debug_info": {"error": str(e), "total_time": end_time - start_time}
}
    def _build_final_result(self, final_state, total_time: float) -> Dict[str, Any]:
        """Build the full result (for local use)."""
        # Get the final token usage statistics
        final_token_info = self._get_token_usage_info()
        return {
            "query": final_state.get('original_query', ''),
            "answer": final_state.get('final_answer', '') or "Failed to generate an answer",
            # Query complexity information
            "query_complexity": final_state.get('query_complexity', {}),
            "is_complex_query": final_state.get('is_complex_query', False),
            "retrieval_path": "complex_hipporag" if final_state.get('is_complex_query', False) else "simple_vector",
            "iterations": final_state.get('current_iteration', 0),
            "total_passages": len(final_state.get('all_passages', [])),
            "sub_queries": final_state.get('sub_queries', []),
            "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
            "initial_retrieval_details": final_state.get('initial_retrieval_details', {}),
            "sufficiency_check": final_state.get('sufficiency_check', {}),
            "current_sub_queries": final_state.get('current_sub_queries', []),
            "is_sufficient": final_state.get('is_sufficient', False),
            # Full document and passage data (local files only)
            "all_documents": final_state.get('all_documents', []),
            "all_passages": final_state.get('all_passages', []),
            "passage_sources": final_state.get('passage_sources', []),
            # PageRank data has been removed from state storage and is kept in
            # local files instead, to avoid LangSmith transfers
            "pagerank_data_available": final_state.get('pagerank_data_available', False),
            "pagerank_summary": final_state.get('pagerank_summary', {}),
            "concept_exploration_results": final_state.get('concept_exploration_results', {}),
            "exploration_round": final_state.get('exploration_round', 0),
            "debug_info": {
                "total_time": total_time,
                "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
                "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
                "langsmith_project": os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever'),
                # Token usage statistics
                "token_usage_summary": final_token_info,
                # Path statistics
                "complexity_analysis": {
                    "is_complex": final_state.get('is_complex_query', False),
                    "complexity_level": final_state.get('query_complexity', {}).get('complexity_level', 'unknown'),
                    "confidence": final_state.get('query_complexity', {}).get('confidence', 0),
                    "reason": final_state.get('query_complexity', {}).get('reason', '')
                },
                # Debug mode information
                "debug_mode_analysis": {
                    "debug_mode": final_state.get('debug_mode', '0'),
                    "debug_override": final_state.get('query_complexity', {}).get('debug_override', {}),
                    "path_override_applied": bool(final_state.get('query_complexity', {}).get('debug_override', {}))
                },
                # Sufficiency check history
                "sufficiency_analysis": {
                    "final_sufficiency": final_state.get('is_sufficient', False),
                    "sufficiency_check_details": final_state.get('sufficiency_check', {}),
                    "iteration_sufficiency_history": [
                        {
                            "iteration": item.get('iteration', 0),
                            "is_sufficient": item.get('is_sufficient', False),
                            "confidence": item.get('sufficiency_check', {}).get('confidence', 0),
                            "reason": item.get('sufficiency_check', {}).get('reason', '')
                        }
                        for item in final_state.get('iteration_history', [])
                        if 'sufficiency_check' in item
                    ],
                    "sufficiency_progression": self._analyze_sufficiency_progression(final_state)
                },
                # Routing decision history
                "routing_analysis": {
                    "total_routing_decisions": len([
                        item for item in final_state.get('iteration_history', [])
                        if item.get('action') in ['sufficiency_check', 'sub_query_generation', 'parallel_retrieval', 'collect_pagerank_scores']
                    ]),
                    "sub_query_generation_count": len([
                        item for item in final_state.get('iteration_history', [])
                        if item.get('action') == 'sub_query_generation'
                    ]),
                    "parallel_retrieval_count": len([
                        item for item in final_state.get('iteration_history', [])
                        if item.get('action') == 'parallel_retrieval'
                    ]),
                    "pagerank_collection_count": len([
                        item for item in final_state.get('iteration_history', [])
                        if item.get('action') == 'collect_pagerank_scores'
                    ])
                },
                # Concept exploration analysis (new)
                "concept_exploration_analysis": {
                    "exploration_enabled": final_state.get('exploration_round', 0) > 0,
                    "exploration_rounds": final_state.get('exploration_round', 0),
                    "pagerank_nodes_analyzed": len(final_state.get('pagerank_summary', {}).get('all_nodes_sorted', [])),
                    "successful_branches_total": sum([
                        round_data.get('successful_branches', 0)
                        for round_key, round_data in final_state.get('concept_exploration_results', {}).items()
                        if round_key.startswith('round_')
                    ]),
                    "total_branches_attempted": sum([
                        round_data.get('total_branches', 0)
                        for round_key, round_data in final_state.get('concept_exploration_results', {}).items()
                        if round_key.startswith('round_')
                    ])
                }
            },
            "iteration_history": final_state.get('iteration_history', [])
        }
    def _clean_state_for_langsmith(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Strip large data structures from the state so they are not sent to LangSmith."""
        cleaned_state = state.copy()
        # PageRank data has already been removed from the state, so no cleanup is needed
        if cleaned_state.get('pagerank_data_available', False):
            print("[?] PageRank data is stored locally and not included in the state")
        # Strip large data from the concept exploration results
        if 'concept_exploration_results' in cleaned_state:
            cleaned_exploration = {}
            for key, value in cleaned_state['concept_exploration_results'].items():
                if isinstance(value, dict):
                    # Keep only the statistics; drop the concrete node data
                    cleaned_exploration[key] = {
                        'total_branches': value.get('total_branches', 0),
                        'successful_branches': value.get('successful_branches', 0),
                        'exploration_type': value.get('exploration_type', 'unknown')
                    }
                else:
                    cleaned_exploration[key] = value
            cleaned_state['concept_exploration_results'] = cleaned_exploration
            print("[?] Detailed concept exploration results stripped; only statistics kept")
        # Strip detailed node data from the PageRank summary
        if 'pagerank_summary' in cleaned_state:
            summary = cleaned_state['pagerank_summary']
            if isinstance(summary, dict) and 'all_nodes_sorted' in summary:
                nodes_count = len(summary.get('all_nodes_sorted', []))
                cleaned_summary = {k: v for k, v in summary.items() if k != 'all_nodes_sorted'}
                cleaned_summary['nodes_count'] = nodes_count
                cleaned_state['pagerank_summary'] = cleaned_summary
                print(f"[?] Stripped details of {nodes_count} PageRank nodes; only statistics kept")
        return cleaned_state
    def _analyze_sufficiency_progression(self, final_state) -> Dict[str, Any]:
        """Analyze how the sufficiency checks progressed across iterations."""
        iteration_history = final_state.get('iteration_history', [])
        sufficiency_checks = [
            item for item in iteration_history
            if 'sufficiency_check' in item
        ]
        if not sufficiency_checks:
            return {"status": "no_sufficiency_checks"}
        # Analyze the progression pattern
        confidences = [sc.get('sufficiency_check', {}).get('confidence', 0) for sc in sufficiency_checks]
        sufficiencies = [sc.get('is_sufficient', False) for sc in sufficiency_checks]
        progression_pattern = "unknown"
        if len(sufficiencies) >= 2:
            if not sufficiencies[0] and sufficiencies[-1]:
                progression_pattern = "improved_to_sufficient"
            elif all(sufficiencies):
                progression_pattern = "consistently_sufficient"
            elif not any(sufficiencies):
                progression_pattern = "consistently_insufficient"
            else:
                progression_pattern = "mixed"
        return {
            "total_checks": len(sufficiency_checks),
            "confidence_progression": confidences,
            "sufficiency_progression": sufficiencies,
            "pattern": progression_pattern,
            "final_confidence": confidences[-1] if confidences else 0,
            "confidence_improvement": confidences[-1] - confidences[0] if len(confidences) >= 2 else 0
        }
    def _get_token_usage_info(self) -> Dict[str, Any]:
        """
        Get the current token usage information.
        """
        try:
            # Try to read token information from the expected attribute path
            debug_info = {}
            # Check whether self.llm exists
            if hasattr(self, 'llm'):
                debug_info["has_llm"] = True
                # Check whether oneapi_generator exists
                if hasattr(self.llm, 'oneapi_generator'):
                    debug_info["has_generator"] = True
                    generator = self.llm.oneapi_generator
                    # Collect the token statistics
                    last_usage = getattr(generator, 'last_token_usage', {})
                    total_usage = getattr(generator, 'total_token_usage', {})
                    model_name = getattr(generator, 'model_name', 'unknown')
                    debug_info.update({
                        "last_call": last_usage,
                        "total_usage": total_usage,
                        "model_name": model_name,
                        "has_last_usage": bool(last_usage),
                        "has_total_usage": bool(total_usage)
                    })
                    return debug_info
                else:
                    debug_info["has_generator"] = False
            else:
                debug_info["has_llm"] = False
            debug_info["error"] = "Token information not found"
            return debug_info
        except Exception as e:
            return {
                "error": f"Failed to get token information: {str(e)}",
                "exception_type": type(e).__name__
            }
@traceable(name="Simple_Retrieve")
def retrieve_simple(self, query: str, mode: str = "0") -> str:
"""简单检索接口"""
result = self.retrieve(query, mode)
return result.get('answer', '')
    def _build_langsmith_safe_result(self, final_state, total_time: float) -> Dict[str, Any]:
        """Build a LangSmith-safe result, with large data structures removed to stay under transfer limits."""
        final_token_info = self._get_token_usage_info()
        return {
            "query": final_state.get('original_query', ''),
            "answer": final_state.get('final_answer', '') or "Failed to generate an answer",
            # Query complexity information
            "query_complexity": final_state.get('query_complexity', {}),
            "is_complex_query": final_state.get('is_complex_query', False),
            "retrieval_path": "complex_hipporag" if final_state.get('is_complex_query', False) else "simple_vector",
            "iterations": final_state.get('current_iteration', 0),
            "total_passages": len(final_state.get('all_passages', [])),
            "sub_queries": final_state.get('sub_queries', []),
            "decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
            "sufficiency_check": final_state.get('sufficiency_check', {}),
            "is_sufficient": final_state.get('is_sufficient', False),
            # Document and passage data (langsmith_example.py needs these).
            # Note: this data is for local scripts only and is not sent to the LangSmith web UI.
            "all_documents": final_state.get('all_documents', []),
            "all_passages": final_state.get('all_passages', []),
            # Simplified statistics, without large data structures
            "pagerank_summary_stats": {
                "data_available": final_state.get('pagerank_data_available', False),
                "exploration_rounds": final_state.get('exploration_round', 0),
                "has_concept_exploration": bool(final_state.get('concept_exploration_results', {}))
            },
            "debug_info": {
                "total_time": total_time,
                "retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
                "llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
                "langsmith_project": os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever'),
                "token_usage_summary": final_token_info,
                "complexity_analysis": {
                    "is_complex": final_state.get('is_complex_query', False),
                    "complexity_level": final_state.get('query_complexity', {}).get('complexity_level', 'unknown'),
                    "confidence": final_state.get('query_complexity', {}).get('confidence', 0),
                    "reason": final_state.get('query_complexity', {}).get('reason', '')
                },
                "final_sufficiency": final_state.get('is_sufficient', False),
                "note": "Full result saved to a local file; this is the LangSmith-optimized version"
            }
        }
    def _save_full_result_to_file(self, full_result: Dict[str, Any]):
        """Save the full result to a local JSON file."""
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"langsmith_full_{timestamp}.json"
            # Create the json_langsmith directory if it does not exist
            json_dir = os.path.join(os.path.dirname(__file__), "json_langsmith")
            os.makedirs(json_dir, exist_ok=True)
            filepath = os.path.join(json_dir, filename)
            # Serialization needs special handling for numpy arrays, Document objects, etc.
            def json_serializer(obj):
                # Handle langchain Document objects
                if hasattr(obj, 'page_content') and hasattr(obj, 'metadata'):
                    return {
                        'page_content': obj.page_content,
                        'metadata': obj.metadata
                    }
                # Handle numpy arrays
                elif hasattr(obj, 'tolist'):
                    return obj.tolist()
                # Fall back for other custom objects
                elif hasattr(obj, '__dict__'):
                    try:
                        return obj.__dict__
                    except Exception:
                        return str(obj)
                else:
                    return str(obj)
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(full_result, f, ensure_ascii=False, indent=2, default=json_serializer)
            print(f"[FOLDER] Full result saved: {filename}")
        except Exception as e:
            print(f"[WARNING] Failed to save the full result: {e}")
def create_langsmith_retriever(
    keyword: str,
    top_k: int = 2,
    max_iterations: int = 2,
    max_parallel_retrievals: int = 2,
    langsmith_project: Optional[str] = None,
    **kwargs
) -> LangSmithIterativeRetriever:
    """Create an iterative retriever with LangSmith monitoring."""
    return LangSmithIterativeRetriever(
        keyword=keyword,
        top_k=top_k,
        max_iterations=max_iterations,
        max_parallel_retrievals=max_parallel_retrievals,
        langsmith_project=langsmith_project,
        **kwargs
    )
# Utility function for checking the LangSmith connection status
def check_langsmith_connection() -> bool:
    """Check the LangSmith connection status, with retries."""
    max_retries = 3
    retry_delay = 2  # seconds
    for attempt in range(max_retries):
        try:
            from langsmith import Client
            api_key = os.getenv("LANGCHAIN_API_KEY")
            if not api_key or api_key == "your_langsmith_api_key_here":
                print("[ERROR] LangSmith API key is not set or is invalid")
                return False
            client = Client()
            # Verify the connection by listing projects
            list(client.list_projects(limit=1))
            print("[OK] LangSmith connection OK")
            return True
        except Exception as e:
            error_msg = str(e)
            # Check for known service-unavailable errors
            if any(keyword in error_msg.lower() for keyword in ['503', 'service unavailable', 'server error', 's3', 'timeout', 'deadline exceeded']):
                print(f"[WARNING] LangSmith service temporarily unavailable (attempt {attempt + 1}/{max_retries}): {error_msg[:100]}...")
                if attempt < max_retries - 1:  # Retries remaining
                    print(f"Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 1.5  # Back off progressively
                    continue
                else:
                    print("[ERROR] LangSmith service is persistently unavailable; running in local mode")
                    return False
            else:
                print(f"[ERROR] LangSmith connection failed: {e}")
                return False
    return False
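# --- Usage sketch (illustrative only, not part of the library) ---
# A minimal example of how this module might be driven end to end, assuming
# OneAPI credentials are supplied via environment variables or kwargs. The
# keyword "my_index" and the sample query are hypothetical placeholders.
if __name__ == "__main__":
    # Verify the LangSmith connection first; the retriever still works without
    # it, but traces will not be uploaded.
    check_langsmith_connection()
    retriever = create_langsmith_retriever(
        keyword="my_index",  # hypothetical keyword/index name
        langsmith_project="hipporag-retriever"
    )
    result = retriever.retrieve("What does HippoRAG do?", mode="0")
    print(result.get("answer", ""))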