"""
Data extraction script.

Runs a LangSmith-traced retrieval for a query, extracts the key fields
(final answer and supporting passages) from the retrieval result, and writes
them to ``output.json`` next to this script.
"""

import hashlib
import json
import os
import sys
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

# Add the project root (two directories above this file) to sys.path so the
# `retriver` package below can be imported when this script is run directly.
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)

from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection

# Default query / debug mode, shared by run_extraction() and main().
DEFAULT_QUERY = "DATAOPS是如何与中国通信标准化协会的大会建立联系的?"
DEFAULT_DEBUG_MODE = "complex"

# Number of leading characters kept as the preview of each supporting passage.
PREVIEW_CHARS = 30

# Metadata keys tried, in order, when looking for a document identifier.
_DOC_ID_KEYS = ('id', 'document_id', 'chunk_id', 'source')


def _stable_doc_id(content: Any) -> str:
    """Return a fallback document ID derived from *content*, stable across runs.

    The built-in ``hash()`` of a ``str`` is salted per process (PYTHONHASHSEED),
    so IDs built from it change from run to run. md5 is used here purely as a
    deterministic fingerprint, not for security. The historical
    ``doc_<0..99999>`` format is preserved.
    """
    text = "" if content is None else str(content)
    digest = hashlib.md5(text.encode('utf-8')).hexdigest()
    return f"doc_{int(digest, 16) % 100000}"


def _metadata_doc_id(metadata: Any) -> Optional[Any]:
    """Return the first usable ID found in *metadata*, or None if there is none.

    Unlike an ``a or b or c`` chain, falsy-but-valid IDs such as ``0`` are kept;
    only ``None`` and the empty string are treated as missing.
    """
    for key in _DOC_ID_KEYS:
        value = metadata.get(key)
        if value is not None and value != "":
            return value
    return None


def extract_supporting_facts(result: Dict[str, Any]) -> List[List[str]]:
    """Extract supporting-passage information from a retrieval result.

    Documents are read from ``result['all_documents']``; only objects exposing
    both ``page_content`` and ``metadata`` are used. If that yields nothing,
    plain strings from ``result['all_passages']`` are used instead.

    Returns:
        A list of ``[preview, doc_id]`` pairs, where ``preview`` is the first
        ``PREVIEW_CHARS`` characters of the passage and ``doc_id`` is a string.
    """
    supporting_facts: List[List[str]] = []

    # `or []` also covers a key that is present but set to None.
    for doc in result.get('all_documents') or []:
        if not (hasattr(doc, 'page_content') and hasattr(doc, 'metadata')):
            continue
        content = doc.page_content
        content_preview = content[:PREVIEW_CHARS] if content else ""
        doc_id = _metadata_doc_id(doc.metadata)
        if doc_id is None:
            doc_id = _stable_doc_id(content)
        supporting_facts.append([content_preview, str(doc_id)])

    # Fall back to raw passages when no usable documents were found.
    if not supporting_facts:
        for i, passage in enumerate(result.get('all_passages') or []):
            content_preview = passage[:PREVIEW_CHARS] if passage else ""
            supporting_facts.append([content_preview, f"passage_{i}"])

    return supporting_facts


def extract_final_answer(result: Dict[str, Any]) -> str:
    """Return the generated answer from a retrieval result.

    Checks ``answer`` first, then ``final_answer``, skipping keys whose value
    is None; returns ``''`` if neither holds a value. Non-string values are
    converted with ``str()`` so callers can always take ``len()`` of the result.
    """
    for key in ('answer', 'final_answer'):
        value = result.get(key)
        if value is not None:
            return value if isinstance(value, str) else str(value)
    return ''


def run_extraction(query: str = DEFAULT_QUERY,
                   debug_mode: str = DEFAULT_DEBUG_MODE) -> Dict[str, Any]:
    """Run one retrieval for *query* and extract the fields to be saved.

    A LangSmith project named ``<YYYYmmdd_HHMMSS>-extraction`` is created for
    the run. Any exception is printed with its traceback and an empty dict is
    returned, so the caller treats ``{}`` as "extraction failed".
    """
    try:
        # Timestamped project name so each run gets its own LangSmith project.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        project_name = f"{timestamp}-extraction"

        print("[CONFIG] 创建检索器...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=3,
            max_iterations=2,
            max_parallel_retrievals=2,
            langsmith_project=project_name,
        )

        print(f"[TARGET] 执行检索: {query}")
        result = retriever.retrieve(query, debug_mode)

        final_answer = extract_final_answer(result)
        supporting_facts = extract_supporting_facts(result)

        return {
            "query": query,
            "final_answer": final_answer,
            "pred_supporting_facts": supporting_facts,
            # Gold facts intentionally mirror the predicted ones (as requested);
            # an independent copy keeps the two lists from aliasing each other.
            "gold_supporting_facts": [list(fact) for fact in supporting_facts],
            "extraction_timestamp": timestamp,
            "langsmith_project": project_name,
        }

    except Exception as e:
        # Top-level boundary for this run: report and signal failure with {}.
        print(f"[ERROR] 提取失败: {e}")
        traceback.print_exc()
        return {}


def main():
    """Check the LangSmith connection, run the extraction, and save/preview it."""
    print("[STARTING] 数据提取脚本启动")
    print("=" * 50)

    # A failed connection check is reported but does not abort the extraction.
    print("[INFO] 检查LangSmith连接...")
    langsmith_ok = check_langsmith_connection()
    if not langsmith_ok:
        print("[WARNING] LangSmith连接失败,但提取仍可继续")
    else:
        print("[OK] LangSmith连接正常")

    query = DEFAULT_QUERY
    debug_mode = DEFAULT_DEBUG_MODE

    print(f"\n[INFO] 提取查询: {query}")
    print(f"[BUG] 调试模式: {debug_mode}")

    extracted_data = run_extraction(query, debug_mode)

    if not extracted_data:
        print("[ERROR] 数据提取失败")
        return

    # Write output.json next to this script regardless of the current directory.
    output_dir = os.path.dirname(os.path.abspath(__file__))
    output_file = os.path.join(output_dir, "output.json")

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(extracted_data, f, ensure_ascii=False, indent=2)

    print("\n[OK] 数据提取完成!")
    print(f"[FILE] 输出文件: {output_file}")
    print(f"[INFO] 提取到 {len(extracted_data.get('pred_supporting_facts', []))} 个支撑段落")
    print(f"[NOTE] 最终答案长度: {len(extracted_data.get('final_answer', ''))} 字符")

    # Short preview of the extracted result.
    print("\n[INFO] 提取结果预览:")
    print(f"   查询: {extracted_data.get('query', '')}")
    answer = extracted_data.get('final_answer', '')
    print(f"   答案: {answer[:100]}{'...' if len(answer) > 100 else ''}")

    supporting_facts = extracted_data.get('pred_supporting_facts', [])
    print(f"   支撑段落 ({len(supporting_facts)}个):")
    for i, fact in enumerate(supporting_facts[:3]):  # only the first 3
        if len(fact) >= 2:
            print(f"     {i+1}. {fact[0]}... (ID: {fact[1]})")

    if len(supporting_facts) > 3:
        print(f"     ... 还有 {len(supporting_facts) - 3} 个段落")


if __name__ == "__main__":
    main()