AIEC-RAG/retriver/langsmith/extract_data.py

"""
数据提取脚本
从langsmith检索结果中提取关键数据并生成JSON文件
"""
import os
import sys
import json
from typing import Dict, Any, List, Tuple
from datetime import datetime
# Add the project root to the import path
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection


def extract_supporting_facts(result: Dict[str, Any]) -> List[List[str]]:
    """
    Extract supporting-passage information.
    Return format: [["first 30 characters", "passage ID"], ...]
    """
    supporting_facts = []

    # Pull documents out of all_documents
    all_documents = result.get('all_documents', [])
    for doc in all_documents:
        if hasattr(doc, 'page_content') and hasattr(doc, 'metadata'):
            # Take the first 30 characters as a preview
            content_preview = doc.page_content[:30] if doc.page_content else ""
            # Try several metadata fields for the document ID
            doc_id = (doc.metadata.get('id') or
                      doc.metadata.get('document_id') or
                      doc.metadata.get('chunk_id') or
                      doc.metadata.get('source') or
                      f"doc_{hash(doc.page_content) % 100000}")
            supporting_facts.append([content_preview, str(doc_id)])

    # If nothing came from the documents, fall back to the raw passages
    if not supporting_facts:
        all_passages = result.get('all_passages', [])
        for i, passage in enumerate(all_passages):
            content_preview = passage[:30] if passage else ""
            doc_id = f"passage_{i}"
            supporting_facts.append([content_preview, doc_id])

    return supporting_facts


def extract_final_answer(result: Dict[str, Any]) -> str:
    """
    Extract the final generated answer.
    """
    return result.get('answer', result.get('final_answer', ''))


def run_extraction(query: str = "DATAOPS是如何与中国通信标准化协会的大会建立联系的",
                   debug_mode: str = "complex") -> Dict[str, Any]:
    """
    Run retrieval and extract the data.
    """
    try:
        # Generate a timestamped project name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        project_name = f"{timestamp}-extraction"

        print("[CONFIG] Creating retriever...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=3,
            max_iterations=2,
            max_parallel_retrievals=2,
            langsmith_project=project_name
        )

        print(f"[TARGET] Running retrieval: {query}")
        result = retriever.retrieve(query, debug_mode)

        # Extract the data
        final_answer = extract_final_answer(result)
        supporting_facts = extract_supporting_facts(result)

        # Build the output payload
        output_data = {
            "query": query,
            "final_answer": final_answer,
            "pred_supporting_facts": supporting_facts,
            "gold_supporting_facts": supporting_facts,  # duplicated on purpose, as requested
            "extraction_timestamp": timestamp,
            "langsmith_project": project_name
        }
        return output_data

    except Exception as e:
        print(f"[ERROR] Extraction failed: {e}")
        import traceback
        traceback.print_exc()
        return {}
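
# Programmatic usage sketch (illustrative only; it assumes the LangSmith
# credentials and retrieval backend expected by create_langsmith_retriever
# are already configured in the environment):
#
#     from retriver.langsmith.extract_data import run_extraction
#     data = run_extraction("your query here", debug_mode="complex")
#     print(data.get("final_answer", ""))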


def main():
    """Entry point."""
    print("[STARTING] Data extraction script starting")
    print("=" * 50)

    # Check the LangSmith connection
    print("[INFO] Checking LangSmith connection...")
    langsmith_ok = check_langsmith_connection()
    if not langsmith_ok:
        print("[WARNING] LangSmith connection failed, but extraction can still continue")
    else:
        print("[OK] LangSmith connection is working")

    # Run the data extraction
    query = "DATAOPS是如何与中国通信标准化协会的大会建立联系的"
    debug_mode = "complex"
    print(f"\n[INFO] Extraction query: {query}")
    print(f"[BUG] Debug mode: {debug_mode}")

    extracted_data = run_extraction(query, debug_mode)

    if extracted_data:
        # Save as a JSON file next to this script
        output_dir = os.path.dirname(__file__)
        output_file = os.path.join(output_dir, "output.json")
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(extracted_data, f, ensure_ascii=False, indent=2)

        print("\n[OK] Data extraction finished!")
        print(f"[FILE] Output file: {output_file}")
        print(f"[INFO] Extracted {len(extracted_data.get('pred_supporting_facts', []))} supporting passages")
        print(f"[NOTE] Final answer length: {len(extracted_data.get('final_answer', ''))} characters")

        # Preview the extracted results
        print("\n[INFO] Extraction result preview:")
        print(f"  Query: {extracted_data.get('query', '')}")
        answer = extracted_data.get('final_answer', '')
        print(f"  Answer: {answer[:100]}{'...' if len(answer) > 100 else ''}")

        supporting_facts = extracted_data.get('pred_supporting_facts', [])
        print(f"  Supporting passages ({len(supporting_facts)}):")
        for i, fact in enumerate(supporting_facts[:3]):  # show only the first 3
            if len(fact) >= 2:
                print(f"    {i+1}. {fact[0]}... (ID: {fact[1]})")
        if len(supporting_facts) > 3:
            print(f"    ... {len(supporting_facts) - 3} more passages")
    else:
        print("[ERROR] Data extraction failed")


if __name__ == "__main__":
    main()
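
# Output shape sketch (illustrative; the field values below are placeholders,
# not taken from a real run). Based on the dict assembled in run_extraction(),
# output.json is expected to look roughly like:
#
#     {
#       "query": "...",
#       "final_answer": "...",
#       "pred_supporting_facts": [["first 30 chars of passage", "passage_0"], ...],
#       "gold_supporting_facts": [["first 30 chars of passage", "passage_0"], ...],
#       "extraction_timestamp": "YYYYMMDD_HHMMSS",
#       "langsmith_project": "YYYYMMDD_HHMMSS-extraction"
#     }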