first commit
This commit is contained in:
161
AIEC-RAG/retriver/langsmith/extract_data.py
Normal file
161
AIEC-RAG/retriver/langsmith/extract_data.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""
|
||||
数据提取脚本
|
||||
从langsmith检索结果中提取关键数据并生成JSON文件
|
||||
"""
|
||||
|
||||
import hashlib
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Tuple
# 添加路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
|
||||
sys.path.append(project_root)
|
||||
|
||||
from retriver.langsmith.langsmith_retriever import create_langsmith_retriever, check_langsmith_connection
|
||||
|
||||
|
||||
def extract_supporting_facts(result: Dict[str, Any]) -> List[List[str]]:
    """Extract supporting-passage info from a retrieval result.

    Prefers Document-like objects under ``result['all_documents']``; falls
    back to raw strings under ``result['all_passages']`` when no documents
    yielded any facts.

    Args:
        result: Retrieval result dict. May contain 'all_documents' (objects
            with ``page_content``/``metadata`` attributes) and/or
            'all_passages' (plain strings).

    Returns:
        List of ``[first-30-chars-of-content, passage-id]`` string pairs.
    """
    supporting_facts: List[List[str]] = []

    for doc in result.get('all_documents', []):
        # Only objects that look like langchain Documents are usable.
        if not (hasattr(doc, 'page_content') and hasattr(doc, 'metadata')):
            continue

        content_preview = doc.page_content[:30] if doc.page_content else ""

        # Try the common metadata id fields in priority order (falsy values
        # are treated as missing, as before).
        doc_id = (doc.metadata.get('id') or
                  doc.metadata.get('document_id') or
                  doc.metadata.get('chunk_id') or
                  doc.metadata.get('source'))
        if not doc_id:
            # FIX: the previous fallback used builtin hash(), which is salted
            # per process (PYTHONHASHSEED), so ids in the emitted JSON were
            # not reproducible across runs. Use a stable content digest.
            content = doc.page_content or ""
            doc_id = "doc_" + hashlib.md5(content.encode("utf-8")).hexdigest()[:8]

        supporting_facts.append([content_preview, str(doc_id)])

    # Fallback: plain-text passages when no documents produced facts.
    if not supporting_facts:
        for i, passage in enumerate(result.get('all_passages', [])):
            content_preview = passage[:30] if passage else ""
            supporting_facts.append([content_preview, f"passage_{i}"])

    return supporting_facts
def extract_final_answer(result: Dict[str, Any]) -> str:
    """Return the generated answer from a retrieval result.

    Looks up 'answer' first, then 'final_answer'; empty string if neither
    key is present.
    """
    if 'answer' in result:
        return result['answer']
    return result.get('final_answer', '')
def run_extraction(query: str = "DATAOPS是如何与中国通信标准化协会的大会建立联系的?",
                   debug_mode: str = "complex") -> Dict[str, Any]:
    """Run a retrieval for *query* and package the extracted data.

    Args:
        query: Question to retrieve an answer for.
        debug_mode: Debug mode string passed through to the retriever.

    Returns:
        Dict with the query, final answer, supporting facts and run
        metadata; an empty dict if the retrieval or extraction fails.
    """
    try:
        # Timestamped project name keeps each run separate in LangSmith.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        project = f"{ts}-extraction"

        print(f"[CONFIG] 创建检索器...")
        retriever = create_langsmith_retriever(
            keyword="test",
            top_k=3,
            max_iterations=2,
            max_parallel_retrievals=2,
            langsmith_project=project
        )

        print(f"[TARGET] 执行检索: {query}")
        result = retriever.retrieve(query, debug_mode)

        answer = extract_final_answer(result)
        facts = extract_supporting_facts(result)

        # NOTE: pred and gold supporting facts are intentionally identical
        # copies, as originally specified.
        return {
            "query": query,
            "final_answer": answer,
            "pred_supporting_facts": facts,
            "gold_supporting_facts": facts,
            "extraction_timestamp": ts,
            "langsmith_project": project
        }

    except Exception as exc:
        # Best-effort script boundary: report and return an empty result.
        print(f"[ERROR] 提取失败: {exc}")
        import traceback
        traceback.print_exc()
        return {}
def main():
    """Entry point: check LangSmith connectivity, run the extraction,
    and write the result to output.json next to this script."""
    print("[STARTING] 数据提取脚本启动")
    print("=" * 50)

    # Connectivity check is informational only; extraction proceeds either way.
    print("[INFO] 检查LangSmith连接...")
    if check_langsmith_connection():
        print("[OK] LangSmith连接正常")
    else:
        print("[WARNING] LangSmith连接失败,但提取仍可继续")

    query = "DATAOPS是如何与中国通信标准化协会的大会建立联系的?"
    debug_mode = "complex"

    print(f"\n[INFO] 提取查询: {query}")
    print(f"[BUG] 调试模式: {debug_mode}")

    extracted_data = run_extraction(query, debug_mode)

    # Guard clause: bail out early on failure instead of nesting everything.
    if not extracted_data:
        print("[ERROR] 数据提取失败")
        return

    # Persist the extracted data as pretty-printed UTF-8 JSON.
    output_file = os.path.join(os.path.dirname(__file__), "output.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(extracted_data, f, ensure_ascii=False, indent=2)

    facts = extracted_data.get('pred_supporting_facts', [])
    answer = extracted_data.get('final_answer', '')

    print(f"\n[OK] 数据提取完成!")
    print(f"[FILE] 输出文件: {output_file}")
    print(f"[INFO] 提取到 {len(facts)} 个支撑段落")
    print(f"[NOTE] 最终答案长度: {len(answer)} 字符")

    # Preview of the extracted result (first 3 facts only).
    print(f"\n[INFO] 提取结果预览:")
    print(f" 查询: {extracted_data.get('query', '')}")
    print(f" 答案: {answer[:100]}{'...' if len(answer) > 100 else ''}")

    print(f" 支撑段落 ({len(facts)}个):")
    for i, fact in enumerate(facts[:3]):
        if len(fact) >= 2:
            print(f" {i+1}. {fact[0]}... (ID: {fact[1]})")
    if len(facts) > 3:
        print(f" ... 还有 {len(facts) - 3} 个段落")
Reference in New Issue
Block a user