"""
|
||
数据提取脚本
|
||
从langsmith检索结果中提取关键数据并生成JSON文件
|
||
"""
|
||
|
||
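# The JSON written by this script has the following shape (values below are
# purely illustrative; the real content depends on the retrieval results, and
# gold_supporting_facts is currently just a copy of pred_supporting_facts):
#
# {
#   "query": "...",
#   "final_answer": "...",
#   "pred_supporting_facts": [["first 30 chars of a passage", "doc_id"], ...],
#   "gold_supporting_facts": [["first 30 chars of a passage", "doc_id"], ...],
#   "extraction_timestamp": "20250101_120000",
#   "langsmith_project": "20250101_120000-extraction"
# }
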
import os
import sys
import json
from typing import Dict, Any, List
from datetime import datetime

# Make the project root (two directories up) importable so the `retriver` package resolves
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)

from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection


def extract_supporting_facts(result: Dict[str, Any]) -> List[List[str]]:
    """
    Extract the supporting passages.

    Return format: [["first 30 characters", "passage ID"], ...]
    """
    supporting_facts = []

    # Extract documents from all_documents
    all_documents = result.get('all_documents', [])

    for doc in all_documents:
        if hasattr(doc, 'page_content') and hasattr(doc, 'metadata'):
            # Take the first 30 characters as a preview
            content_preview = doc.page_content[:30] if doc.page_content else ""

            # Extract a document ID (try several metadata fields, falling back to a hash-based ID)
            doc_id = (doc.metadata.get('id') or
                      doc.metadata.get('document_id') or
                      doc.metadata.get('chunk_id') or
                      doc.metadata.get('source') or
                      f"doc_{hash(doc.page_content) % 100000}")

            supporting_facts.append([content_preview, str(doc_id)])

    # If nothing was found in the documents, fall back to the raw passages
    if not supporting_facts:
        all_passages = result.get('all_passages', [])
        for i, passage in enumerate(all_passages):
            content_preview = passage[:30] if passage else ""
            doc_id = f"passage_{i}"
            supporting_facts.append([content_preview, doc_id])

    return supporting_facts

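# Minimal illustration of extract_supporting_facts, assuming LangChain-style
# Document objects (with `page_content` and `metadata` attributes) in
# `all_documents`:
#
#   >>> from langchain_core.documents import Document
#   >>> docs = [Document(page_content="The DATAOPS white paper was released.",
#   ...                  metadata={"id": "doc_1"})]
#   >>> extract_supporting_facts({"all_documents": docs})
#   [['The DATAOPS white paper was re', 'doc_1']]
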
def extract_final_answer(result: Dict[str, Any]) -> str:
    """
    Extract the final generated answer.

    Prefers the 'answer' field and falls back to 'final_answer'.
    """
    return result.get('answer', result.get('final_answer', ''))


def run_extraction(query: str = "DATAOPS是如何与中国通信标准化协会的大会建立联系的?",
                   debug_mode: str = "complex") -> Dict[str, Any]:
    """
    Run retrieval for the query and extract the data.

    The default query (in Chinese) asks how DATAOPS established a connection
    with the China Communications Standards Association (CCSA) conference;
    it is left untranslated because it is passed verbatim to the retriever.
    """
    try:
        # Generate a timestamped project name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        project_name = f"{timestamp}-extraction"

        print("[CONFIG] Creating retriever...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=3,
            max_iterations=2,
            max_parallel_retrievals=2,
            langsmith_project=project_name
        )

        print(f"[TARGET] Running retrieval: {query}")
        result = retriever.retrieve(query, debug_mode)

        # Extract the data
        final_answer = extract_final_answer(result)
        supporting_facts = extract_supporting_facts(result)

        # Build the output payload
        output_data = {
            "query": query,
            "final_answer": final_answer,
            "pred_supporting_facts": supporting_facts,
            "gold_supporting_facts": supporting_facts,  # same as the predictions, duplicated as originally requested
            "extraction_timestamp": timestamp,
            "langsmith_project": project_name
        }

        return output_data

    except Exception as e:
        print(f"[ERROR] Extraction failed: {e}")
        import traceback
        traceback.print_exc()
        return {}

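# Typical programmatic use, assuming the retriever backend is reachable
# (the query string here is hypothetical):
#
#   data = run_extraction("your question here", debug_mode="complex")
#   if data:
#       print(data["final_answer"])
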
def main():
    """Entry point."""
    print("[STARTING] Data extraction script started")
    print("=" * 50)

    # Check the LangSmith connection
    print("[INFO] Checking LangSmith connection...")
    langsmith_ok = check_langsmith_connection()
    if not langsmith_ok:
        print("[WARNING] LangSmith connection failed, but extraction can still continue")
    else:
        print("[OK] LangSmith connection OK")

    # Run the data extraction
    query = "DATAOPS是如何与中国通信标准化协会的大会建立联系的?"
    debug_mode = "complex"

    print(f"\n[INFO] Extraction query: {query}")
    print(f"[BUG] Debug mode: {debug_mode}")

    extracted_data = run_extraction(query, debug_mode)

    if extracted_data:
        # Save as a JSON file next to this script
        output_dir = os.path.dirname(__file__)
        output_file = os.path.join(output_dir, "output.json")

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(extracted_data, f, ensure_ascii=False, indent=2)

        print("\n[OK] Data extraction finished!")
        print(f"[FILE] Output file: {output_file}")
        print(f"[INFO] Extracted {len(extracted_data.get('pred_supporting_facts', []))} supporting passages")
        print(f"[NOTE] Final answer length: {len(extracted_data.get('final_answer', ''))} characters")

        # Show a preview of the extracted results
        print("\n[INFO] Extraction preview:")
        print(f"  Query: {extracted_data.get('query', '')}")
        answer = extracted_data.get('final_answer', '')
        print(f"  Answer: {answer[:100]}{'...' if len(answer) > 100 else ''}")

        supporting_facts = extracted_data.get('pred_supporting_facts', [])
        print(f"  Supporting passages ({len(supporting_facts)}):")
        for i, fact in enumerate(supporting_facts[:3]):  # show only the first 3
            if len(fact) >= 2:
                print(f"    {i+1}. {fact[0]}... (ID: {fact[1]})")

        if len(supporting_facts) > 3:
            print(f"    ... {len(supporting_facts) - 3} more passages")

    else:
        print("[ERROR] Data extraction failed")


if __name__ == "__main__":
    main()