# AIEC-new/AIEC-RAG/retriver/langsmith/extract_data.py
"""
数据提取脚本
从langsmith检索结果中提取关键数据并生成JSON文件
"""
import os
import sys
import json
from typing import Dict, Any, List
from datetime import datetime
# Make the project root importable
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection

def extract_supporting_facts(result: Dict[str, Any]) -> List[List[str]]:
    """
    Extract supporting-passage information.

    Return format: [["first 30 characters", "passage ID"], ...]
    """
    supporting_facts = []

    # Extract documents from all_documents
    all_documents = result.get('all_documents', [])
    for doc in all_documents:
        if hasattr(doc, 'page_content') and hasattr(doc, 'metadata'):
            # Take the first 30 characters as a preview
            content_preview = doc.page_content[:30] if doc.page_content else ""
            # Resolve the document ID, trying several metadata fields in turn
            doc_id = (doc.metadata.get('id') or
                      doc.metadata.get('document_id') or
                      doc.metadata.get('chunk_id') or
                      doc.metadata.get('source') or
                      f"doc_{hash(doc.page_content) % 100000}")
            supporting_facts.append([content_preview, str(doc_id)])

    # If nothing came back as documents, fall back to raw passages
    if not supporting_facts:
        all_passages = result.get('all_passages', [])
        for i, passage in enumerate(all_passages):
            content_preview = passage[:30] if passage else ""
            doc_id = f"passage_{i}"
            supporting_facts.append([content_preview, doc_id])

    return supporting_facts
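
# A minimal usage sketch (illustrative only): the retriever appears to return
# langchain-style Document objects carrying `page_content` and `metadata`, so
# a hypothetical result would be reduced like this:
#
#   from langchain_core.documents import Document
#   sample = {"all_documents": [
#       Document(page_content="DATAOPS sample passage text", metadata={"id": "doc-42"}),
#   ]}
#   extract_supporting_facts(sample)
#   # -> [["DATAOPS sample passage text", "doc-42"]]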

def extract_final_answer(result: Dict[str, Any]) -> str:
    """
    Extract the final generated answer.
    """
    return result.get('answer', result.get('final_answer', ''))

def run_extraction(query: str = "DATAOPS是如何与中国通信标准化协会的大会建立联系的",
                   debug_mode: str = "complex") -> Dict[str, Any]:
    """
    Run retrieval and extract the data.

    The default query asks how DATAOPS established a connection with the
    China Communications Standards Association (CCSA) conference.
    """
    try:
        # Generate a timestamped project name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        project_name = f"{timestamp}-extraction"

        print("[CONFIG] Creating retriever...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=3,
            max_iterations=2,
            max_parallel_retrievals=2,
            langsmith_project=project_name
        )

        print(f"[TARGET] Running retrieval: {query}")
        result = retriever.retrieve(query, debug_mode)

        # Extract the data
        final_answer = extract_final_answer(result)
        supporting_facts = extract_supporting_facts(result)

        # Assemble the output payload
        output_data = {
            "query": query,
            "final_answer": final_answer,
            "pred_supporting_facts": supporting_facts,
            "gold_supporting_facts": supporting_facts,  # intentionally identical, as requested
            "extraction_timestamp": timestamp,
            "langsmith_project": project_name
        }
        return output_data
    except Exception as e:
        print(f"[ERROR] Extraction failed: {e}")
        import traceback
        traceback.print_exc()
        return {}
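
# Hedged usage sketch for calling run_extraction from other code (the query
# string below is a hypothetical placeholder, not from the original script):
#
#   data = run_extraction(query="<your question here>", debug_mode="complex")
#   if data:
#       print(data["final_answer"])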

def main():
    """Entry point."""
    print("[STARTING] Data extraction script starting")
    print("=" * 50)

    # Check the LangSmith connection
    print("[INFO] Checking LangSmith connection...")
    langsmith_ok = check_langsmith_connection()
    if not langsmith_ok:
        print("[WARNING] LangSmith connection failed, but extraction can continue")
    else:
        print("[OK] LangSmith connection is healthy")

    # Run the data extraction
    query = "DATAOPS是如何与中国通信标准化协会的大会建立联系的"
    debug_mode = "complex"
    print(f"\n[INFO] Extraction query: {query}")
    print(f"[BUG] Debug mode: {debug_mode}")
    extracted_data = run_extraction(query, debug_mode)

    if extracted_data:
        # Save as a JSON file
        output_dir = os.path.dirname(__file__)
        output_file = os.path.join(output_dir, "output.json")
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(extracted_data, f, ensure_ascii=False, indent=2)

        print("\n[OK] Data extraction complete!")
        print(f"[FILE] Output file: {output_file}")
        print(f"[INFO] Extracted {len(extracted_data.get('pred_supporting_facts', []))} supporting passages")
        print(f"[NOTE] Final answer length: {len(extracted_data.get('final_answer', ''))} characters")

        # Show a preview of the extracted result
        print("\n[INFO] Extraction preview:")
        print(f"  Query: {extracted_data.get('query', '')}")
        answer = extracted_data.get('final_answer', '')
        print(f"  Answer: {answer[:100]}{'...' if len(answer) > 100 else ''}")
        supporting_facts = extracted_data.get('pred_supporting_facts', [])
        print(f"  Supporting passages ({len(supporting_facts)}):")
        for i, fact in enumerate(supporting_facts[:3]):  # show only the first three
            if len(fact) >= 2:
                print(f"    {i + 1}. {fact[0]}... (ID: {fact[1]})")
        if len(supporting_facts) > 3:
            print(f"    ... plus {len(supporting_facts) - 3} more passages")
    else:
        print("[ERROR] Data extraction failed")


if __name__ == "__main__":
    main()
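
# To run as a script (a sketch; the exact invocation and any LangSmith
# credentials/environment variables depend on your deployment):
#
#   python retriver/langsmith/extract_data.py
#
# The extracted result is written next to this file as output.json.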