AIEC-RAG---/AIEC-RAG/retriver/langsmith/langsmith_retriever.py
"""
集成LangSmith监控的迭代检索器
Token监控通过OneAPILLM自动处理
"""
import os
import time
import json
from datetime import datetime
from typing import Dict, Any, Optional
# LangSmith tracing support
from langsmith import traceable
from retriver.langgraph.iterative_retriever import IterativeRetriever
from retriver.langgraph.graph_state import create_initial_state
class LangSmithIterativeRetriever(IterativeRetriever):
"""
集成LangSmith监控的迭代检索器
Token消耗通过OneAPILLM的_generate方法自动记录
"""
def __init__(
self,
keyword: str,
top_k: int = 2,
max_iterations: int = 2,
max_parallel_retrievals: int = 2,
oneapi_key: Optional[str] = None,
oneapi_base_url: Optional[str] = None,
model_name: Optional[str] = None,
embed_model_name: Optional[str] = None,
complexity_model_name: Optional[str] = None,
sufficiency_model_name: Optional[str] = None,
langsmith_project: Optional[str] = None,
skip_llm_generation: bool = False
):
        # Configure LangSmith
        self._setup_langsmith(langsmith_project)
        # Initialize the parent class
super().__init__(
keyword=keyword,
top_k=top_k,
max_iterations=max_iterations,
max_parallel_retrievals=max_parallel_retrievals,
oneapi_key=oneapi_key,
oneapi_base_url=oneapi_base_url,
model_name=model_name,
embed_model_name=embed_model_name,
complexity_model_name=complexity_model_name,
sufficiency_model_name=sufficiency_model_name,
skip_llm_generation=skip_llm_generation
)
print("[SEARCH] LangSmith监控已启用")
    def _setup_langsmith(self, project_name: Optional[str] = None):
        """Configure the LangSmith environment."""
        # Set environment variables
        if not os.getenv("LANGCHAIN_TRACING_V2"):
            os.environ["LANGCHAIN_TRACING_V2"] = "true"
        if not os.getenv("LANGCHAIN_ENDPOINT"):
            os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
        # Set the project name
        if project_name:
            os.environ["LANGCHAIN_PROJECT"] = project_name
        elif not os.getenv("LANGCHAIN_PROJECT"):
            os.environ["LANGCHAIN_PROJECT"] = "hipporag-retriever"
        # Check the API key
        if not os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGCHAIN_API_KEY") == "your_langsmith_api_key_here":
            print("[WARNING] Please set a valid LANGCHAIN_API_KEY environment variable")
            print("Visit https://smith.langchain.com to obtain an API key")
@traceable(name="Complete_Retrieval_Process")
def retrieve(self, query: str, mode: str = "0") -> Dict[str, Any]:
"""
带LangSmith追踪的完整检索过程
"""
print(f"[STARTING] 开始检索过程 (LangSmith追踪)")
print(f"[POINT] 项目: {os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever')}")
print(f"[SEARCH] 查询: {query}")
print(f"[BUG] 调试模式: {mode}")
start_time = time.time()
try:
            # Create the initial state
initial_state = create_initial_state(
original_query=query,
max_iterations=self.max_iterations,
debug_mode=mode
)
            # Execute the workflow - LangChain automatically traces all LLM calls
            # Configure LangSmith to avoid tracing large data structures
            config = {
                "recursion_limit": 50,
                "metadata": {
                    "note": "Large PageRank data is saved locally only, not uploaded to LangSmith"
                },
                "callbacks": [],  # disable extra callbacks to avoid data leakage
                "run_name": f"Retrieval-{query[:20]}..."
            }
result_state = self.workflow.invoke(initial_state, config=config)
            # Build the result
            end_time = time.time()
            total_time = end_time - start_time
            # Build the full result (for local use)
            full_result = self._build_final_result(result_state, total_time)
            # Strip large data structures from the result state to avoid sending them to LangSmith
            cleaned_state = self._clean_state_for_langsmith(result_state)
            # Build a lightweight version for LangSmith with large data structures removed
            langsmith_result = self._build_langsmith_safe_result(cleaned_state, total_time)
            print(f"[OK] Retrieval finished (elapsed: {total_time:.2f}s)")
            print("[LINK] See LangSmith for the detailed trace")
            # Return the lightweight version to LangSmith, but save the full result to a file
            self._save_full_result_to_file(full_result)
            # For compatibility with langsmith_example.py, return a result that includes
            # the real document data but omits PageRank and other large structures
            # to avoid LangSmith transfer problems
result_with_docs = {
"query": langsmith_result.get("query", ""),
"answer": langsmith_result.get("answer", ""),
"total_passages": langsmith_result.get("total_passages", 0),
"retrieval_path": langsmith_result.get("retrieval_path", "unknown"),
"iterations": langsmith_result.get("iterations", 0),
"is_sufficient": langsmith_result.get("is_sufficient", False),
"sub_queries": langsmith_result.get("sub_queries", [])[:3],
                # Include the real document data for use by langsmith_example.py
                "all_documents": result_state.get('all_documents', []),
                "all_passages": result_state.get('all_passages', []),
                "passage_sources": result_state.get('passage_sources', []),
                # Other full details
                "query_complexity": result_state.get('query_complexity', {}),
                "decomposed_sub_queries": result_state.get('decomposed_sub_queries', []),
                "sufficiency_check": result_state.get('sufficiency_check', {}),
                "debug_info": {
                    "total_time": total_time,
                    "langsmith_project": langsmith_result.get("debug_info", {}).get("langsmith_project", ""),
                    "note": "Full version that includes the real document data"
                }
}
return result_with_docs
        except KeyboardInterrupt:
            print("\n[WARNING] Retrieval interrupted by user")
            end_time = time.time()
            return {
                "query": query,
                "answer": "Retrieval was interrupted by the user",
                "error": "KeyboardInterrupt",
                "total_time": end_time - start_time,
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": "KeyboardInterrupt", "total_time": end_time - start_time}
            }
        except Exception as e:
            print(f"[ERROR] Error during retrieval: {e}")
            end_time = time.time()
            return {
                "query": query,
                "answer": f"Sorry, an error occurred during retrieval: {str(e)}",
                "error": str(e),
                "total_time": end_time - start_time,
                "iterations": 0,
                "total_passages": 0,
                "sub_queries": [],
                "debug_info": {"error": str(e), "total_time": end_time - start_time}
            }
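    # A minimal usage sketch (assumes a constructed retriever; the question
    # string is hypothetical):
    #
    #   result = retriever.retrieve("What is HippoRAG?", mode="0")
    #   print(result["answer"])
    #   print(result["retrieval_path"], result["iterations"], result["total_passages"])
    #
    # On success the dict also carries "all_documents", "all_passages" and
    # "sufficiency_check"; on interruption or error it degrades to the
    # fallback dicts returned above.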
    def _build_final_result(self, final_state, total_time: float) -> Dict[str, Any]:
        """Build the final result."""
        # Get the final token usage statistics
        final_token_info = self._get_token_usage_info()
        return {
            "query": final_state.get('original_query', ''),
            "answer": final_state.get('final_answer', '') or "Failed to generate an answer",
            # Query complexity information
"query_complexity": final_state.get('query_complexity', {}),
"is_complex_query": final_state.get('is_complex_query', False),
"retrieval_path": "complex_hipporag" if final_state.get('is_complex_query', False) else "simple_vector",
"iterations": final_state.get('current_iteration', 0),
"total_passages": len(final_state.get('all_passages', [])),
"sub_queries": final_state.get('sub_queries', []),
"decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
"initial_retrieval_details": final_state.get('initial_retrieval_details', {}),
"sufficiency_check": final_state.get('sufficiency_check', {}),
"current_sub_queries": final_state.get('current_sub_queries', []),
"is_sufficient": final_state.get('is_sufficient', False),
            # Full document and passage data (for the local file only)
"all_documents": final_state.get('all_documents', []),
"all_passages": final_state.get('all_passages', []),
"passage_sources": final_state.get('passage_sources', []),
            # PageRank data has been removed from state storage; it is kept in a local file instead, to avoid LangSmith transfer
"pagerank_data_available": final_state.get('pagerank_data_available', False),
"pagerank_summary": final_state.get('pagerank_summary', {}),
"concept_exploration_results": final_state.get('concept_exploration_results', {}),
"exploration_round": final_state.get('exploration_round', 0),
"debug_info": {
"total_time": total_time,
"retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
"llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
"langsmith_project": os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever'),
                # Token usage statistics
                "token_usage_summary": final_token_info,
                # Path statistics
"complexity_analysis": {
"is_complex": final_state.get('is_complex_query', False),
"complexity_level": final_state.get('query_complexity', {}).get('complexity_level', 'unknown'),
"confidence": final_state.get('query_complexity', {}).get('confidence', 0),
"reason": final_state.get('query_complexity', {}).get('reason', '')
},
                # Debug mode information
"debug_mode_analysis": {
"debug_mode": final_state.get('debug_mode', '0'),
"debug_override": final_state.get('query_complexity', {}).get('debug_override', {}),
"path_override_applied": bool(final_state.get('query_complexity', {}).get('debug_override', {}))
},
                # Sufficiency check history
"sufficiency_analysis": {
"final_sufficiency": final_state.get('is_sufficient', False),
"sufficiency_check_details": final_state.get('sufficiency_check', {}),
"iteration_sufficiency_history": [
{
"iteration": item.get('iteration', 0),
"is_sufficient": item.get('is_sufficient', False),
"confidence": item.get('sufficiency_check', {}).get('confidence', 0),
"reason": item.get('sufficiency_check', {}).get('reason', '')
}
for item in final_state.get('iteration_history', [])
if 'sufficiency_check' in item
],
"sufficiency_progression": self._analyze_sufficiency_progression(final_state)
},
                # Routing decision history
"routing_analysis": {
"total_routing_decisions": len([
item for item in final_state.get('iteration_history', [])
if item.get('action') in ['sufficiency_check', 'sub_query_generation', 'parallel_retrieval', 'collect_pagerank_scores']
]),
"sub_query_generation_count": len([
item for item in final_state.get('iteration_history', [])
if item.get('action') == 'sub_query_generation'
]),
"parallel_retrieval_count": len([
item for item in final_state.get('iteration_history', [])
if item.get('action') == 'parallel_retrieval'
]),
"pagerank_collection_count": len([
item for item in final_state.get('iteration_history', [])
if item.get('action') == 'collect_pagerank_scores'
])
},
                # Concept exploration analysis (new)
"concept_exploration_analysis": {
"exploration_enabled": final_state.get('exploration_round', 0) > 0,
"exploration_rounds": final_state.get('exploration_round', 0),
"pagerank_nodes_analyzed": len(final_state.get('pagerank_summary', {}).get('all_nodes_sorted', [])),
"successful_branches_total": sum([
round_data.get('successful_branches', 0)
for round_key, round_data in final_state.get('concept_exploration_results', {}).items()
if round_key.startswith('round_')
]),
"total_branches_attempted": sum([
round_data.get('total_branches', 0)
for round_key, round_data in final_state.get('concept_exploration_results', {}).items()
if round_key.startswith('round_')
])
}
},
"all_passages": final_state.get('all_passages', []),
"all_documents": final_state.get('all_documents', []), # 添加文档列表
"iteration_history": final_state.get('iteration_history', [])
}
    def _clean_state_for_langsmith(self, state: Dict[str, Any]) -> Dict[str, Any]:
        """Strip large data structures from the state to avoid sending them to LangSmith."""
        cleaned_state = state.copy()
        # PageRank data has already been removed from the state, so no cleanup is needed
        if cleaned_state.get('pagerank_data_available', False):
            print("[?] PageRank data is stored locally and not included in the state")
        # Clean large data out of the concept exploration results
if 'concept_exploration_results' in cleaned_state:
cleaned_exploration = {}
for key, value in cleaned_state['concept_exploration_results'].items():
if isinstance(value, dict):
                    # Keep only statistics; drop the concrete node data
cleaned_exploration[key] = {
'total_branches': value.get('total_branches', 0),
'successful_branches': value.get('successful_branches', 0),
'exploration_type': value.get('exploration_type', 'unknown')
}
else:
cleaned_exploration[key] = value
cleaned_state['concept_exploration_results'] = cleaned_exploration
print(f"[?] 已清理概念探索详细结果,只保留统计信息")
# 清理PageRank汇总中的详细节点数据
if 'pagerank_summary' in cleaned_state:
summary = cleaned_state['pagerank_summary']
if isinstance(summary, dict) and 'all_nodes_sorted' in summary:
nodes_count = len(summary.get('all_nodes_sorted', []))
cleaned_summary = {k: v for k, v in summary.items() if k != 'all_nodes_sorted'}
cleaned_summary['nodes_count'] = nodes_count
cleaned_state['pagerank_summary'] = cleaned_summary
print(f"[?] 已清理 {nodes_count} 个PageRank节点详情只保留统计")
return cleaned_state
    def _analyze_sufficiency_progression(self, final_state) -> Dict[str, Any]:
        """Analyze the progression of sufficiency checks."""
iteration_history = final_state.get('iteration_history', [])
sufficiency_checks = [
item for item in iteration_history
if 'sufficiency_check' in item
]
if not sufficiency_checks:
return {"status": "no_sufficiency_checks"}
        # Analyze the progression pattern
confidences = [sc.get('sufficiency_check', {}).get('confidence', 0) for sc in sufficiency_checks]
sufficiencies = [sc.get('is_sufficient', False) for sc in sufficiency_checks]
progression_pattern = "unknown"
if len(sufficiencies) >= 2:
if not sufficiencies[0] and sufficiencies[-1]:
progression_pattern = "improved_to_sufficient"
elif all(sufficiencies):
progression_pattern = "consistently_sufficient"
elif not any(sufficiencies):
progression_pattern = "consistently_insufficient"
else:
progression_pattern = "mixed"
return {
"total_checks": len(sufficiency_checks),
"confidence_progression": confidences,
"sufficiency_progression": sufficiencies,
"pattern": progression_pattern,
"final_confidence": confidences[-1] if confidences else 0,
"confidence_improvement": confidences[-1] - confidences[0] if len(confidences) >= 2 else 0
}
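    # Example of the classification above: sufficiency flags [False, False, True]
    # across iterations yield "improved_to_sufficient"; [True, True] yields
    # "consistently_sufficient"; [False, False] yields "consistently_insufficient".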
    def _get_token_usage_info(self) -> Dict[str, Any]:
        """
        Get the current token usage information.
        """
        try:
            # Try several attribute paths for the token information
            debug_info = {}
            # Check whether self.llm exists
if hasattr(self, 'llm'):
debug_info["has_llm"] = True
                # Check whether oneapi_generator exists
if hasattr(self.llm, 'oneapi_generator'):
debug_info["has_generator"] = True
generator = self.llm.oneapi_generator
                    # Get token statistics
last_usage = getattr(generator, 'last_token_usage', {})
total_usage = getattr(generator, 'total_token_usage', {})
model_name = getattr(generator, 'model_name', 'unknown')
debug_info.update({
"last_call": last_usage,
"total_usage": total_usage,
"model_name": model_name,
"has_last_usage": bool(last_usage),
"has_total_usage": bool(total_usage)
})
return debug_info
else:
debug_info["has_generator"] = False
else:
debug_info["has_llm"] = False
debug_info["error"] = "无法找到Token信息"
return debug_info
except Exception as e:
return {
"error": f"获取Token信息失败: {str(e)}",
"exception_type": type(e).__name__
}
@traceable(name="Simple_Retrieve")
def retrieve_simple(self, query: str, mode: str = "0") -> str:
"""简单检索接口"""
result = self.retrieve(query, mode)
return result.get('answer', '')
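    # Sketch: retriever.retrieve_simple("some question") runs the full traced
    # retrieve() and returns only the answer string (the question is hypothetical).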
    def _build_langsmith_safe_result(self, final_state, total_time: float) -> Dict[str, Any]:
        """Build a LangSmith-safe result with large data structures removed to avoid transfer limits."""
final_token_info = self._get_token_usage_info()
return {
"query": final_state.get('original_query', ''),
"answer": final_state.get('final_answer', '') or "未能生成答案",
# 查询复杂度信息
"query_complexity": final_state.get('query_complexity', {}),
"is_complex_query": final_state.get('is_complex_query', False),
"retrieval_path": "complex_hipporag" if final_state.get('is_complex_query', False) else "simple_vector",
"iterations": final_state.get('current_iteration', 0),
"total_passages": len(final_state.get('all_passages', [])),
"sub_queries": final_state.get('sub_queries', []),
"decomposed_sub_queries": final_state.get('decomposed_sub_queries', []),
"sufficiency_check": final_state.get('sufficiency_check', {}),
"is_sufficient": final_state.get('is_sufficient', False),
            # Include document and passage data; langsmith_example.py needs it
            # Note: this data is only for local scripts and is not sent to the LangSmith web UI
"all_documents": final_state.get('all_documents', []),
"all_passages": final_state.get('all_passages', []),
            # Simplified statistics without large data structures
"pagerank_summary_stats": {
"data_available": final_state.get('pagerank_data_available', False),
"exploration_rounds": final_state.get('exploration_round', 0),
"has_concept_exploration": bool(final_state.get('concept_exploration_results', {}))
},
"debug_info": {
"total_time": total_time,
"retrieval_calls": final_state.get('debug_info', {}).get('retrieval_calls', 0),
"llm_calls": final_state.get('debug_info', {}).get('llm_calls', 0),
"langsmith_project": os.getenv('LANGCHAIN_PROJECT', 'hipporag-retriever'),
"token_usage_summary": final_token_info,
"complexity_analysis": {
"is_complex": final_state.get('is_complex_query', False),
"complexity_level": final_state.get('query_complexity', {}).get('complexity_level', 'unknown'),
"confidence": final_state.get('query_complexity', {}).get('confidence', 0),
"reason": final_state.get('query_complexity', {}).get('reason', '')
},
"final_sufficiency": final_state.get('is_sufficient', False),
"note": "完整结果已保存到本地文件此为LangSmith优化版本"
}
}
    def _save_full_result_to_file(self, full_result: Dict[str, Any]):
        """Save the full result to a local file."""
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"langsmith_full_{timestamp}.json"
            # Create the json_langsmith directory if it does not exist
            json_dir = os.path.join(os.path.dirname(__file__), "json_langsmith")
            os.makedirs(json_dir, exist_ok=True)
            filepath = os.path.join(json_dir, filename)
            # Serialization needs special handling for numpy arrays, Document objects, etc.
            def json_serializer(obj):
                # Handle langchain Document objects
if hasattr(obj, 'page_content') and hasattr(obj, 'metadata'):
return {
'page_content': obj.page_content,
'metadata': obj.metadata
}
                # Handle numpy arrays
elif hasattr(obj, 'tolist'):
return obj.tolist()
                # Handle other custom objects
                elif hasattr(obj, '__dict__'):
                    try:
                        return obj.__dict__
                    except Exception:
                        return str(obj)
else:
return str(obj)
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(full_result, f, ensure_ascii=False, indent=2, default=json_serializer)
print(f"[FOLDER] 完整结果已保存: {filename}")
except Exception as e:
print(f"[WARNING] 保存完整结果失败: {e}")
def create_langsmith_retriever(
keyword: str,
top_k: int = 2,
max_iterations: int = 2,
max_parallel_retrievals: int = 2,
langsmith_project: Optional[str] = None,
**kwargs
) -> LangSmithIterativeRetriever:
"""创建带LangSmith监控的迭代检索器"""
return LangSmithIterativeRetriever(
keyword=keyword,
top_k=top_k,
max_iterations=max_iterations,
max_parallel_retrievals=max_parallel_retrievals,
langsmith_project=langsmith_project,
**kwargs
)
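# A minimal usage sketch (assumes a reachable OneAPI backend and a valid
# LANGCHAIN_API_KEY; the keyword and question below are hypothetical):
#
#   retriever = create_langsmith_retriever(
#       keyword="my_knowledge_base",
#       top_k=2,
#       langsmith_project="hipporag-retriever",
#   )
#   answer = retriever.retrieve_simple("What is iterative retrieval?")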
# Utility function for checking the LangSmith connection status
def check_langsmith_connection() -> bool:
    """Check the LangSmith connection status (with a retry mechanism)."""
    max_retries = 3
    retry_delay = 2  # seconds
    for attempt in range(max_retries):
        try:
            from langsmith import Client
            api_key = os.getenv("LANGCHAIN_API_KEY")
            if not api_key or api_key == "your_langsmith_api_key_here":
                print("[ERROR] LangSmith API key is not set or invalid")
                return False
            # Create the client with default settings
            client = Client()
            # Try listing projects to verify the connection
            projects = list(client.list_projects(limit=1))
            print("[OK] LangSmith connection is working")
            return True
        except Exception as e:
            error_msg = str(e)
            # Check for known service-unavailable errors
            if any(keyword in error_msg.lower() for keyword in ['503', 'service unavailable', 'server error', 's3', 'timeout', 'deadline exceeded']):
                print(f"[WARNING] LangSmith service temporarily unavailable (attempt {attempt + 1}/{max_retries}): {error_msg[:100]}...")
                if attempt < max_retries - 1:  # retries remaining
                    print(f"Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 1.5  # back off between retries
                    continue
                else:
                    print("[ERROR] LangSmith service is persistently unavailable; running in local mode")
                    return False
            else:
                print(f"[ERROR] LangSmith connection failed: {e}")
                return False
    return False
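# A minimal smoke test, kept as a sketch: it only verifies the LangSmith
# connection and assumes the environment variables are already configured.
# The sample keyword is hypothetical.
if __name__ == "__main__":
    if check_langsmith_connection():
        print("LangSmith is reachable; a retriever can be created, e.g.:")
        print('  create_langsmith_retriever(keyword="my_knowledge_base")')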