AIEC-RAG---/AIEC-RAG/retriver/langsmith/langsmith_example.py
"""
LangSmith监控使用示例 - 交互式版本
手动输入查询问题,执行完整检索流程并监控
"""
import os
import sys
import json
import pickle
from typing import Dict, Any, List
from datetime import datetime
# 添加路径
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
def load_graph(pkl_path: str):
"""加载图数据"""
try:
with open(pkl_path, 'rb') as f:
graph = pickle.load(f)
return graph
except Exception as e:
print(f"[WARNING] 加载图数据失败: {e}")
return None
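
# Note (assumption, not enforced here): the pickle is expected to hold a NetworkX-style
# graph whose nodes carry 'type' ('text' / 'entity' / 'event'), 'name' and 'id' attributes,
# and whose edges may carry a 'relation' or 'label' attribute. extract_evidences_path
# below reads exactly these fields and falls back to defaults when they are missing.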

def extract_evidences_path(passage_ids: List[str], pkl_path: str = "test_with_concept.pkl") -> List[List[List[str]]]:
    """
    Extract the relation triplets associated with the given passage IDs.

    For each passage this collects triplets between the entities and events the passage
    is connected to, rather than the direct passage-entity relations.

    Args:
        passage_ids: list of passage IDs
        pkl_path: path to the graph pickle file

    Returns:
        One triplet list per passage, formatted as [[[head, relation, tail], ...], ...]
    """
    graph = load_graph(pkl_path)
    if graph is None:
        return []

    all_evidences = []
    for passage_id in passage_ids:
        passage_evidences = []

        # Find the text node that corresponds to this passage ID
        text_node_id = None
        # Option 1: use passage_id directly as the node ID
        if passage_id in graph.nodes:
            text_node_id = passage_id
        else:
            # Option 2: scan all text nodes for a matching ID
            for node_id, node_data in graph.nodes(data=True):
                if isinstance(node_data, dict):
                    node_type = node_data.get('type', '').lower()
                    if 'text' in node_type:
                        # Check the possible ID fields
                        if (node_data.get('id') == passage_id or
                                node_data.get('name') == passage_id or
                                node_id == passage_id):
                            text_node_id = node_id
                            break

        if text_node_id is None:
            print(f"[WARNING] No text node found for passage ID: {passage_id}")
            passage_evidences = []
        else:
            # Step 1: collect all entity and event nodes connected to this text node
            connected_entities_events = []
            neighbors = list(graph.neighbors(text_node_id))
            for neighbor_id in neighbors:
                neighbor_data = graph.nodes.get(neighbor_id, {})
                if isinstance(neighbor_data, dict):
                    neighbor_type = neighbor_data.get('type', '').lower()
                    if 'entity' in neighbor_type or 'event' in neighbor_type:
                        connected_entities_events.append(neighbor_id)
            print(f"[OK] Passage {passage_id[:20]}... is connected to {len(connected_entities_events)} entity/event nodes")

            # Step 2: look for relations between these entity/event nodes
            seen_triplets = set()  # used for deduplication
            for i, entity1_id in enumerate(connected_entities_events):
                entity1_data = graph.nodes.get(entity1_id, {})
                entity1_name = entity1_data.get('name', entity1_data.get('id', str(entity1_id)))
                entity1_type = entity1_data.get('type', '').lower()
                # Check connections between entity1 and the other entities/events
                for j, entity2_id in enumerate(connected_entities_events):
                    if i >= j:  # skip duplicate pairs and self-connections
                        continue
                    entity2_data = graph.nodes.get(entity2_id, {})
                    entity2_name = entity2_data.get('name', entity2_data.get('id', str(entity2_id)))
                    entity2_type = entity2_data.get('type', '').lower()
                    # Check whether the two nodes are connected by an edge
                    edge_data = graph.get_edge_data(entity1_id, entity2_id)
                    reverse_edge_data = graph.get_edge_data(entity2_id, entity1_id)
                    if edge_data:
                        # entity1 -> entity2
                        relation = "连接"  # default relation label ("connected")
                        if isinstance(edge_data, dict):
                            relation = edge_data.get('relation', edge_data.get('label', '连接'))
                        triplet = [entity1_name, relation, entity2_name]
                        triplet_key = (entity1_name, relation, entity2_name)
                        if triplet_key not in seen_triplets:
                            seen_triplets.add(triplet_key)
                            passage_evidences.append(triplet)
                    if reverse_edge_data and edge_data != reverse_edge_data:
                        # entity2 -> entity1 (only if it differs from the forward edge)
                        relation = "连接"  # default relation label ("connected")
                        if isinstance(reverse_edge_data, dict):
                            relation = reverse_edge_data.get('relation', reverse_edge_data.get('label', '连接'))
                        triplet = [entity2_name, relation, entity1_name]
                        triplet_key = (entity2_name, relation, entity1_name)
                        if triplet_key not in seen_triplets:
                            seen_triplets.add(triplet_key)
                            passage_evidences.append(triplet)
            print(f"   Extracted {len(passage_evidences)} entity-level triplets")

        all_evidences.append(passage_evidences)
    return all_evidences
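
# A minimal usage sketch (the passage ID below is hypothetical; real IDs depend on how the
# graph in the pickle was built):
#
#     evidences = extract_evidences_path(["passage_0001"], "test_with_concept.pkl")
#     # -> [[["EntityA", "连接", "EntityB"], ...]]  one triplet list per passage ID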

def main():
    """Main entry point: run a hard-coded test query with monitoring."""
    # ================================
    # Edit the test query and debug mode here
    # ================================
    query = "混沌工程的定义是什么DataOps是什么"  # simplified query for quick testing ("What is the definition of Chaos Engineering? What is DataOps?")
    debug_mode = "complex"  # one of: "0" (auto-detect), "simple" (force the simple path), "complex" (force the complex path)

    print("[STARTING] LangSmith-monitored retrieval system (powered by Alibaba Cloud DashScope)")
    print("=" * 50)
    print("[TIP] This system will automatically:")
    print("   • Analyze query complexity (Alibaba Cloud Tongyi Qianwen)")
    print("   • Select the optimal retrieval path")
    print("   • Run hybrid retrieval (event nodes + passage nodes)")
    print("   • Perform iterative reasoning retrieval and sufficiency checks")
    print("   • Generate and retrieve multi-round sub-queries")
    print("   • Provide detailed execution monitoring")
    print("   • Record the full process in LangSmith")

    # Check the LangSmith connection
    print("\n[INFO] Checking LangSmith connection...")
    langsmith_ok = check_langsmith_connection()
    if not langsmith_ok:
        print("[WARNING] LangSmith connection failed, but the retriever still works")
    else:
        print("[OK] LangSmith connection OK")

    print(f"\n[INFO] Test query: {query}")
    print(f"[BUG] Debug mode: {debug_mode}")
    print(f"   {'automatic complexity detection' if debug_mode == '0' else 'forced simple path' if debug_mode == 'simple' else 'forced complex path' if debug_mode == 'complex' else 'unknown mode'}")

    try:
        # Build a timestamped project name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        project_name = f"{timestamp}-{query}"
        print(f"\n[PACKAGE] Creating a new LangSmith project: {project_name}")

        # Create the retriever with LangSmith monitoring
        print("[CONFIG] Initializing retriever...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=13,                    # matches the retriever setting: 10 event + 3 passage nodes
            max_iterations=1,            # keep the number of iterations low
            max_parallel_retrievals=1,   # keep the number of parallel retrievals low
            langsmith_project=project_name
        )
        print("[OK] Retriever created")

        # Run retrieval (full pipeline tracing)
        print(f"\n[TARGET] Starting retrieval...")
        print("[TIP] The full process will be traced in LangSmith, including:")
        print("   • Query complexity detection")
        print("   • Path selection (simple vector retrieval vs. complex reasoning retrieval)")
        print("   • Hybrid retrieval (TOP-10 event nodes + TOP-3 passage nodes)")
        print("   • Iterative retrieval and sufficiency checks")
        print("   • Sub-query generation and parallel retrieval")
        print("   • Final answer generation")
        result = retriever.retrieve(query, debug_mode)
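
        # The retriever returns a plain dict; the keys read below ('answer', 'query_complexity',
        # 'retrieval_path', 'debug_info', 'iterations', 'is_sufficient', 'sub_queries',
        # 'all_documents', 'total_passages', ...) are all accessed with .get() so that a
        # missing field degrades to a default instead of raising KeyError.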
        # Print the detailed results
        print("\n" + "=" * 60)
        print("[INFO] Retrieval complete - detailed results")
        print("=" * 60)

        # 1. Query path analysis
        complexity = result.get('query_complexity', {})
        debug_override = complexity.get('debug_override', {})
        print(f"[?] Query path analysis:")
        print(f"   Query complexity: {complexity.get('complexity_level', 'unknown').upper()}")
        print(f"   Confidence: {complexity.get('confidence', 0):.2f}")
        print(f"   Execution path: {result.get('retrieval_path', 'unknown')}")
        print(f"   Rationale: {complexity.get('reason', '')[:120]}{'...' if len(complexity.get('reason', '')) > 120 else ''}")
        if debug_override:
            print(f"   [BUG] Debug override: {debug_override.get('original_complexity')} -> {result.get('is_complex_query')}")

        # 2. Execution overview
        debug_info = result.get('debug_info', {})
        iterations = result.get('iterations', 0)
        is_sufficient = result.get('is_sufficient', False)
        print(f"\n[CHART] Execution overview:")
        print(f"   Total time: {debug_info.get('total_time', 0):.2f}s")
        print(f"   Iterations: {iterations}")
        print(f"   LLM calls: {debug_info.get('llm_calls', 0)}")
        print(f"   Retrieved items: {result.get('total_passages', 0)} (events + passages)")
        print(f"   Final status: {'[OK] information sufficient' if is_sufficient else '[WARNING] insufficient, but the iteration limit was reached'}")

        # 3. Query decomposition and sub-query generation
        decomposed_queries = result.get('decomposed_sub_queries', [])
        all_sub_queries = result.get('sub_queries', [])
        if decomposed_queries or all_sub_queries:
            print(f"\n[TARGET] Query decomposition and sub-queries:")
            if decomposed_queries:
                print(f"   [INFO] Initial decomposition ({len(decomposed_queries)}):")
                for i, sub_q in enumerate(decomposed_queries, 1):
                    print(f"      {i}. {sub_q}")
            additional_queries = [q for q in all_sub_queries if q not in decomposed_queries]
            if additional_queries:
                print(f"   [RELOAD] Generated during iteration ({len(additional_queries)}):")
                for i, sub_q in enumerate(additional_queries, 1):
                    print(f"      {i}. {sub_q}")

        # 4. Sufficiency check progression
        sufficiency_analysis = debug_info.get('sufficiency_analysis', {})
        sufficiency_history = sufficiency_analysis.get('iteration_sufficiency_history', [])
        final_sufficiency = result.get('sufficiency_check', {})
        if sufficiency_history or final_sufficiency:
            print(f"\n[THINK] Sufficiency check progression:")
            if sufficiency_history:
                for hist in sufficiency_history:
                    iter_num = hist.get('iteration', 0)
                    is_suff = hist.get('is_sufficient', False)
                    conf = hist.get('confidence', 0)
                    status = '[OK] sufficient' if is_suff else '[ERROR] insufficient'
                    print(f"   Iteration {iter_num}: {status} (confidence: {conf:.2f})")
            if final_sufficiency:
                final_reason = final_sufficiency.get('reason', '')
                print(f"   Final reason: {final_reason[:150]}{'...' if len(final_reason) > 150 else ''}")

        # 5. Token usage statistics
        token_info = debug_info.get('token_usage_summary', {})
        if token_info and not token_info.get('error'):
            total_usage = token_info.get('total_usage', {})
            print(f"\n[NUMBER] Resource usage:")
            print(f"   Model: {token_info.get('model_name', 'unknown')}")
            print(f"   Prompt tokens: {total_usage.get('prompt_tokens', 0):,}")
            print(f"   Completion tokens: {total_usage.get('completion_tokens', 0):,}")
            print(f"   Total tokens: {total_usage.get('total_tokens', 0):,}")
            print(f"   Calls: {total_usage.get('call_count', 0)}")

        # 6. Retrieval state summary
        state_summary = result.get('state_summary', {})
        if state_summary:
            print(f"\n[FILE] Retrieval state summary:")
            retrieval_path = state_summary.get('retrieval_path', '')
            final_state = state_summary.get('final_state', '')
            total_queries_processed = state_summary.get('total_queries_processed', 0)
            if retrieval_path:
                print(f"   Execution path: {retrieval_path}")
            if final_state:
                print(f"   Final state: {final_state}")
            if total_queries_processed > 0:
                print(f"   Queries processed: {total_queries_processed}")
            # Show the iteration history
            iteration_history = result.get('iteration_history', [])
            if iteration_history:
                print(f"   Iteration history:")
                for hist in iteration_history[-3:]:  # show only the last 3
                    iter_num = hist.get('iteration', 0)
                    action = hist.get('action', 'unknown')
                    print(f"      Iteration {iter_num}: {action}")

        # 7. Final answer
        print(f"\n[NOTE] Final answer:")
        print("-" * 60)
        answer = result.get('answer', 'No answer generated')
        print(answer)
        print("-" * 60)

        # 8. Retrieval result statistics (shown before the routing analysis)
        print(f"\n[INFO] Retrieval result statistics:")
        # Counts come from the lightweight result payload
        total_passages = result.get('total_passages', 0)
        print(f"   Items actually retrieved: {total_passages} (event nodes + passage nodes)")
        if total_passages > 0:
            print(f"   Retrieval mode: LangSmith lightweight statistics")
            print(f"   Composition: TOP-10 event nodes + TOP-3 passage nodes")
            print(f"   Full details are saved to local files")
            print(f"   [FOLDER] Full data location: json_langsmith/ directory")
        else:
            print(f"   [WARNING] No content retrieved")

        # 9. Routing decision analysis (advanced)
        routing_analysis = debug_info.get('routing_analysis', {})
        sufficiency_progression = debug_info.get('sufficiency_analysis', {}).get('sufficiency_progression', {})
        if routing_analysis or sufficiency_progression:
            print(f"\n[?] Routing decision analysis:")
            if routing_analysis:
                print(f"   Total routing decisions: {routing_analysis.get('total_routing_decisions', 0)}")
                print(f"   Sub-query generations: {routing_analysis.get('sub_query_generation_count', 0)}")
                print(f"   Parallel retrievals: {routing_analysis.get('parallel_retrieval_count', 0)}")
            if sufficiency_progression.get('pattern'):
                pattern = sufficiency_progression['pattern']
                pattern_desc = {
                    'improved_to_sufficient': 'improved to sufficient',
                    'consistently_sufficient': 'consistently sufficient',
                    'consistently_insufficient': 'consistently insufficient',
                    'mixed': 'mixed results'
                }.get(pattern, pattern)
                print(f"   Sufficiency pattern: {pattern_desc}")
                conf_improvement = sufficiency_progression.get('confidence_improvement', 0)
                if conf_improvement > 0:
                    print(f"   Confidence gain: +{conf_improvement:.2f}")
                elif conf_improvement < 0:
                    print(f"   Confidence drop: {conf_improvement:.2f}")

        # 10. LangSmith tracing information
        print(f"\n[LINK] LangSmith detailed tracing:")
        print(f"   Project name: {project_name}")
        print("   URL: https://smith.langchain.com")
        print("   Detailed monitoring covers:")
        print("      • Node-level execution time and data flow")
        print("      • Per-step token usage details")
        print("      • Detailed reasoning behind the complexity decision")
        print("      • Node filtering during hybrid retrieval (events + passages)")
        print("      • Iterative sufficiency checks and their progression")
        print("      • Context and feedback for sub-query generation")
        print("      • Full traces of routing decisions")
        print("      • Performance comparison between the retrieval paths")
        print("      • Parallel retrieval execution")
        print("      • Full traces of errors and exceptions")

        print(f"\n[OK] Query processing complete!")
        # ================================
        # Data extraction and JSON generation
        # ================================
        print(f"\n[FILE] Extracting key data...")
        try:
            # Extract the final answer
            final_answer = result.get('answer', '')

            # Extract supporting passage and event information
            supporting_facts = []   # passages (original field)
            supporting_events = []  # events (additional field)

            # Pull the real retrieval data directly from the result
            print(f"[INFO] Extracting document data from the retrieval result...")
            all_documents = result.get('all_documents', [])
            print(f"   Documents in the retrieval result: {len(all_documents)}")

            seen_docs = set()
            if all_documents:
                for doc in all_documents:
                    if hasattr(doc, 'page_content') and hasattr(doc, 'metadata'):
                        # Use the first 30 characters as a preview
                        content_preview = doc.page_content[:30] if doc.page_content else ""
                        # Resolve a document ID, preferring node_id as the text_id
                        doc_id = (doc.metadata.get('node_id') or      # HippoRAG node ID, maps to text_id in the graph
                                  doc.metadata.get('passage_id') or   # passage ID from the ES retriever
                                  doc.metadata.get('id') or
                                  doc.metadata.get('document_id') or
                                  doc.metadata.get('chunk_id') or
                                  doc.metadata.get('source') or
                                  f"doc_{hash(doc.page_content) % 100000}")
                        # Node type
                        node_type = doc.metadata.get('node_type', 'unknown')
                        # Deduplicate on the (doc_id, preview) pair
                        doc_key = (str(doc_id), content_preview)
                        if doc_key not in seen_docs:
                            seen_docs.add(doc_key)
                            # Route to the appropriate list by node type
                            if node_type == 'event':
                                supporting_events.append([content_preview, str(doc_id)])
                                print(f"   [OK] Added event: {content_preview}... (ID: {doc_id})")
                            else:
                                # Passage nodes and any other types go into supporting_facts
                                supporting_facts.append([content_preview, str(doc_id)])
                                print(f"   [OK] Added passage: {content_preview}... (ID: {doc_id})")
                        else:
                            print(f"   [RELOAD] Skipped duplicate document: {content_preview}...")

            # If there is no real data, fall back to lightweight placeholders
            if not all_documents:
                total_passages = result.get('total_passages', 0)
                print(f"[INFO] Falling back to lightweight placeholders: {total_passages} items")
                # Placeholder event data (TOP-10)
                expected_events = min(10, total_passages)
                for i in range(expected_events):
                    content_preview = f"[event {i+1} content - lightweight]"
                    event_id = f"langsmith_event_{i}"
                    supporting_events.append([content_preview, event_id])
                    print(f"   [OK] Placeholder event: {content_preview} (ID: {event_id})")
                # Placeholder passage data (TOP-3)
                expected_passages = min(3, max(0, total_passages - expected_events))
                for i in range(expected_passages):
                    content_preview = f"[passage {i+1} content - lightweight]"
                    passage_id = f"langsmith_passage_{i}"
                    supporting_facts.append([content_preview, passage_id])
                    print(f"   [OK] Placeholder passage: {content_preview} (ID: {passage_id})")
                if total_passages > 13:
                    print(f"   [INFO] {total_passages - 13} additional items not listed")

            # Skip triplet extraction for now; keep only passage and event information
            print(f"\n[INFO] Skipping triplet extraction; saving passage and event information")
            pred_evidences_path = []  # left empty: triplet extraction is not performed

            # Build the JSON payload
            json_data = {
                "query": query,
                "pred_answer": final_answer,
                "pred_supporting_facts": supporting_facts,    # original passage field
                "pred_supporting_events": supporting_events,  # additional event field
                "pred_evidences_path": pred_evidences_path,
                "extraction_timestamp": timestamp,
                "langsmith_project": project_name,
                "total_passages": len(supporting_facts),      # number of passages
                "total_events": len(supporting_events),       # number of events
                "total_content": len(supporting_facts) + len(supporting_events),  # total number of items
                "total_triplets": sum(len(evidence_list) for evidence_list in pred_evidences_path),
                "answer_length": len(final_answer)
            }

            # Save the JSON file
            # Create the json_output directory if it does not exist
            json_output_dir = os.path.join(os.path.dirname(__file__), "json_output")
            os.makedirs(json_output_dir, exist_ok=True)
            output_file = os.path.join(json_output_dir, f"output_{timestamp}.json")
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(json_data, f, ensure_ascii=False, indent=2)
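
            # For reference, the saved file has roughly this shape (values are illustrative only):
            # {
            #   "query": "...",
            #   "pred_answer": "...",
            #   "pred_supporting_facts": [["<first 30 chars of passage>", "<node_id>"], ...],
            #   "pred_supporting_events": [["<first 30 chars of event>", "<node_id>"], ...],
            #   "pred_evidences_path": [],
            #   "extraction_timestamp": "YYYYMMDD_HHMMSS",
            #   "langsmith_project": "YYYYMMDD_HHMMSS-<query>",
            #   ...
            # }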
print(f"[OK] 数据提取完成!")
print(f"[FILE] 输出文件: {output_file}")
print(f"[INFO] 最终统计:")
print(f" 支撑段落: {len(supporting_facts)}")
print(f" 支撑事件: {len(supporting_events)}")
print(f" 总内容数: {len(supporting_facts) + len(supporting_events)}")
print(f" 答案长度: {len(final_answer)} 字符")
print(f" 三元组提取: 已跳过")
# 检查数据一致性
if len(supporting_facts) == 0 and len(supporting_events) == 0:
print(f"[WARNING] 警告: 未提取到任何支撑信息(段落或事件),请检查检索结果格式")
print(f" 检索结果中的字段:")
print(f" - all_documents: {len(result.get('all_documents', []))}")
print(f" - all_passages: {len(result.get('all_passages', []))}")
if result.get('all_documents'):
first_doc = result.get('all_documents')[0]
print(f" - 第一个文档类型: {type(first_doc)}")
if hasattr(first_doc, 'metadata'):
print(f" - 第一个文档metadata: {first_doc.metadata}")
else:
print(f"[OK] 成功提取到信息,无需提取三元组")
if len(supporting_facts) > 0:
print(f" - 段落信息: {len(supporting_facts)}")
if len(supporting_events) > 0:
print(f" - 事件信息: {len(supporting_events)}")
# 显示提取结果预览
print(f"\n[INFO] 提取数据预览:")
print(f" 答案: {final_answer[:100]}{'...' if len(final_answer) > 100 else ''}")
# 显示支撑段落
if supporting_facts:
print(f" 支撑段落 ({len(supporting_facts)}个):")
for i, fact in enumerate(supporting_facts[:3]): # 只显示前3个
if len(fact) >= 2:
print(f" {i+1}. '{fact[0]}...' (ID: {fact[1]})")
if len(supporting_facts) > 3:
print(f" ... 还有 {len(supporting_facts) - 3} 个段落")
# 显示支撑事件
if supporting_events:
print(f" 支撑事件 ({len(supporting_events)}个):")
for i, event in enumerate(supporting_events[:3]): # 只显示前3个
if len(event) >= 2:
print(f" {i+1}. '{event[0]}...' (ID: {event[1]})")
if len(supporting_events) > 3:
print(f" ... 还有 {len(supporting_events) - 3} 个事件")
print(f"\n 三元组提取: 已跳过(保存段落和事件信息)")
except Exception as extract_error:
print(f"[WARNING] 数据提取失败: {extract_error}")
import traceback
traceback.print_exc()
except KeyboardInterrupt:
print(f"\n[WARNING] 用户中断检索过程")
print(f"[TIP] 检索已安全停止")
except Exception as e:
print(f"[ERROR] 检索失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
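
# To run this example directly (assuming LangSmith and DashScope credentials are already
# configured in the environment, e.g. a LangSmith API key and DASHSCOPE_API_KEY):
#
#     python retriver/langsmith/langsmith_example.py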