"""
HippoRAG2 retriever built on the LangChain Elasticsearch integration.

Follows the original HippoRAG2 logic closely:

1. query2edge: user query -> vector search over the ES edge index -> LLM filtering -> node scores
2. query2passage: user query -> vector search over the ES passage index -> passage scores
3. Merge into a personalization dict -> PageRank propagation -> return ranked text nodes
"""

import os
import sys
import json
import numpy as np
import networkx as nx
from typing import Dict, List, Tuple, Any, Optional
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_elasticsearch import ElasticsearchStore
from elasticsearch import Elasticsearch
from pydantic import Field
import json_repair

# LangSmith tracing support
from langsmith import traceable

# Add the project root to the import path
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)

from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from retriver.langgraph.dashscope_llm import DashScopeLLM
from elasticsearch_vectorization.config import ElasticsearchConfig


def min_max_normalize(x: np.ndarray) -> np.ndarray:
    """Min-max normalization."""
    if np.max(x) == np.min(x):
        return np.ones_like(x)
    return (x - np.min(x)) / (np.max(x) - np.min(x))

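# Illustrative example (hypothetical values, not used by the pipeline below):
#   min_max_normalize(np.array([2.0, 4.0, 6.0])) -> array([0. , 0.5, 1. ])
#   A constant array maps to all ones, since max == min.
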
class InferenceConfig:
    """Inference configuration."""
    def __init__(self,
                 topk_edges: int = 10,
                 weight_adjust: float = 0.05,
                 ppr_alpha: float = 0.95,
                 ppr_max_iter: int = 100,
                 ppr_tol: float = 1e-6,
                 top_k_events: int = 10,    # new: number of event nodes to return
                 top_k_passages: int = 3,   # new: number of passage nodes to return
                 global_topk: int = 20):    # new: number of passages for global retrieval
        self.topk_edges = topk_edges
        self.weight_adjust = weight_adjust
        self.ppr_alpha = ppr_alpha
        self.ppr_max_iter = ppr_max_iter
        self.ppr_tol = ppr_tol
        self.top_k_events = top_k_events
        self.top_k_passages = top_k_passages
        self.global_topk = global_topk

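# Usage sketch (hypothetical values): widen edge retrieval and return more passages.
#   config = InferenceConfig(topk_edges=15, top_k_passages=5, global_topk=30)
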
class LangChainHippoRAGRetriever(BaseRetriever):
    """
    HippoRAG retriever built on the LangChain Elasticsearch integration.
    Uses the native Elasticsearch client instead of hand-rolled HTTP requests.
    """

    # Pydantic fields
    top_k: int = Field(default=13, description="Number of documents to return (10 events + 3 passages)")
    keyword: str = Field(description="Keyword used to derive index names")
    llm_generator: Any = Field(description="LLM generator", exclude=True)
    embedding_model: Any = Field(description="Embedding model", exclude=True)
    inference_config: Any = Field(description="Inference configuration", exclude=True)
    graph_data: Any = Field(default=None, description="Graph data", exclude=True)
    es_client: Any = Field(default=None, description="Elasticsearch client", exclude=True)
    edges_index: str = Field(default="", description="Edge index name")
    passages_index: str = Field(default="", description="Passage index name")

    def __init__(
        self,
        llm_generator: DashScopeLLM,
        embedding_model: DashScopeEmbeddingModel,
        keyword: str,
        top_k: int = 13,
        inference_config: Optional[InferenceConfig] = None,
        **kwargs
    ):
        """Initialize the LangChain HippoRAG retriever."""

        # Default configuration
        if inference_config is None:
            inference_config = InferenceConfig()

        # Initialize the Elasticsearch client
        es_client = Elasticsearch(
            ElasticsearchConfig.ES_HOST,
            basic_auth=(ElasticsearchConfig.ES_USERNAME, ElasticsearchConfig.ES_PASSWORD),
            timeout=30,
            max_retries=3,
            retry_on_timeout=True
        )

        # Derive index names from the keyword
        edges_index = ElasticsearchConfig.get_edges_index_name(keyword)
        passages_index = ElasticsearchConfig.get_passages_index_name(keyword)

        # Initialize the parent class
        super().__init__(
            top_k=top_k,
            keyword=keyword,
            llm_generator=llm_generator,
            embedding_model=embedding_model,
            inference_config=inference_config,
            edges_index=edges_index,
            passages_index=passages_index,
            graph_data=None,  # set later in _load_graph_data
            es_client=es_client,
            **kwargs
        )

        # Load the graph data
        self._load_graph_data()

        print("LangChain HippoRAG retriever initialized")
        print(f"Edge index: {edges_index}")
        print(f"Passage index: {passages_index}")

    def _load_graph_data(self):
        """Load the graph data that includes concept nodes."""
        import pickle
        graph_path = os.path.join(
            os.path.dirname(__file__), '..', '..',
            f'{self.keyword}_with_concept.pkl'
        )

        if os.path.exists(graph_path):
            with open(graph_path, 'rb') as f:
                self.graph_data = pickle.load(f)
            print(f"[OK] Loaded concept-augmented graph: {len(self.graph_data.nodes())} nodes")

            # Count nodes by type
            node_types = {}
            for node_id, attrs in self.graph_data.nodes(data=True):
                node_type = attrs.get('type', 'unknown')
                node_types[node_type] = node_types.get(node_type, 0) + 1
            print(f"  Node type counts: {node_types}")

        else:
            print(f"[ERROR] Concept-augmented graph file not found: {graph_path}")
            self.graph_data = None

    @traceable(name="Query2Edge_Retrieval")
    def _query2edge(self, query: str, topN: int = 10) -> Dict[str, float]:
        """Query-to-edge retrieval (native ES client, following the original logic)."""
        # Log input parameters for LangSmith
        print(f"[SEARCH] Query2Edge started: {query}, topN={topN}")

        # 1. Embed the query
        query_emb = self.embedding_model.encode([query], query_type="edge")[0]
        print(f"[SEARCH] DEBUG: query vector dimension: {query_emb.shape}")

        # 2. Retrieve similar edges with the native ES client
        try:
            # Check that the index exists
            if not self.es_client.indices.exists(index=self.edges_index):
                print(f"[WARNING] Edge index does not exist: {self.edges_index}")
                return {}
            print(f"[OK] DEBUG: edge index exists: {self.edges_index}")

            # ES script_score query (the +1.0 shift keeps cosine scores non-negative)
            query_body = {
                "size": topN * 2,
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {
                                "query_vector": query_emb.tolist()
                            }
                        }
                    }
                }
            }

            result = self.es_client.search(index=self.edges_index, body=query_body)
            hits = result.get("hits", {}).get("hits", [])
            print(f"[SEARCH] DEBUG: ES returned {len(hits)} edges")

            # Record ES retrieval details for LangSmith
            es_search_info = {
                "query": query,
                "edges_index": self.edges_index,
                "topN_requested": topN,
                "es_hits_returned": len(hits),
                "query_vector_dimension": query_emb.shape[0],
                "top_edges_preview": []
            }

            # 3. Build the edge list for LLM filtering
            before_filter_edge_json = {"fact": []}
            valid_edges = []
            scores = []

            print(f"[SEARCH] DEBUG: nodes in graph: {len(self.graph_data.nodes()) if self.graph_data else 0}")

            for i, hit in enumerate(hits):
                source = hit["_source"]
                try:
                    head_node_id = source['head_node_id']
                    tail_node_id = source['tail_node_id']
                    score = hit.get('_score', 0.5)
                    # Compute these for every hit so the preview below never reuses stale values
                    head_exists = head_node_id in self.graph_data.nodes if self.graph_data else False
                    tail_exists = tail_node_id in self.graph_data.nodes if self.graph_data else False

                    if i < 3:  # only print details for the first 3 edges
                        print(f"[SEARCH] DEBUG: edge {i+1} (score: {score:.3f}): {head_node_id} -> {tail_node_id}")
                        print(f"    head node exists: {head_exists}, tail node exists: {tail_exists}")

                    # Add to the LangSmith monitoring info
                    es_search_info["top_edges_preview"].append({
                        "rank": i + 1,
                        "score": float(score),
                        "head_entity": source.get('head_entity', ''),
                        "relation": source.get('relation', ''),
                        "tail_entity": source.get('tail_entity', ''),
                        "head_node_id": head_node_id,
                        "tail_node_id": tail_node_id,
                        "head_exists_in_graph": head_exists,
                        "tail_exists_in_graph": tail_exists
                    })

                    # Keep only edges whose endpoints exist in the graph
                    if self.graph_data and head_node_id in self.graph_data.nodes and tail_node_id in self.graph_data.nodes:
                        edge_str = [source['head_entity'], source['relation'], source['tail_entity']]
                        before_filter_edge_json['fact'].append(edge_str)
                        valid_edges.append(((head_node_id, tail_node_id), source.get('edge_index', 0)))
                        scores.append(hit.get('_score', 0.5))
                except KeyError as e:
                    if i < 3:
                        print(f"[SEARCH] DEBUG: edge {i+1} KeyError: {e}")
                    continue

            print(f"[SEARCH] DEBUG: valid edges: {len(valid_edges)}")

            # 4. No valid edges: return an empty dict
            if len(before_filter_edge_json['fact']) == 0:
                print("[WARNING] DEBUG: no valid edges, returning empty dict")
                return {}

            # 5. Filter edges with the LLM (extracted into a separately traceable method)
            filtered_facts = self._llm_filter_edges(query, before_filter_edge_json)

            if len(filtered_facts) == 0:
                return {}

            # 6. Build the node score dict
            node_score_dict = {}
            for edge_fact in filtered_facts:
                for i, (edge, idx) in enumerate(valid_edges):
                    try:
                        head_id = self.graph_data.nodes[edge[0]].get('id', edge[0]) if self.graph_data else edge[0]
                        tail_id = self.graph_data.nodes[edge[1]].get('id', edge[1]) if self.graph_data else edge[1]
                        relation = before_filter_edge_json['fact'][i][1] if i < len(before_filter_edge_json['fact']) else 'related_to'

                        if head_id == edge_fact[0] and relation == edge_fact[1] and tail_id == edge_fact[2]:
                            head, tail = edge[0], edge[1]
                            sim_score = scores[i] if i < len(scores) else 0.5

                            # Update node scores, keeping the maximum similarity per node
                            if head not in node_score_dict:
                                node_score_dict[head] = sim_score
                            else:
                                node_score_dict[head] = max(node_score_dict[head], sim_score)

                            if tail not in node_score_dict:
                                node_score_dict[tail] = sim_score
                            else:
                                node_score_dict[tail] = max(node_score_dict[tail], sim_score)
                            break
                    except (KeyError, IndexError):
                        continue

            # Record the final result for LangSmith
            final_result = {
                **es_search_info,
                "llm_filter_applied": True,
                "edges_before_llm_filter": len(before_filter_edge_json['fact']),
                "edges_after_llm_filter": len(filtered_facts),
                "llm_filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0,
                "final_node_scores": len(node_score_dict),
                "node_score_preview": dict(list(node_score_dict.items())[:5]) if node_score_dict else {},
                "edge_to_node_mapping_success": len(node_score_dict) > 0,
                "valid_edges_found": len(valid_edges),
                "filtered_facts_sample": filtered_facts[:3] if filtered_facts else []
            }
            print(f"[SEARCH] Query2Edge result recorded for LangSmith: {len(node_score_dict)} nodes scored")

            return node_score_dict

        except Exception as e:
            print(f"[ERROR] Edge retrieval failed: {e}")
            import traceback
            traceback.print_exc()
            return {}

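    # Fields _query2edge reads from each document in the edges index:
    # 'head_entity', 'relation', 'tail_entity', 'head_node_id', 'tail_node_id',
    # 'embedding' (dense vector used by the script_score query) and, optionally,
    # 'edge_index'. This is inferred from the reads above, not from an explicit
    # mapping definition in this file.
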
    @traceable(name="LLM_Edge_Filter")
    def _llm_filter_edges(self, query: str, before_filter_edge_json: Dict) -> List:
        """
        Filter edge triples with the LLM, with detailed monitoring.

        Args:
            query: the user query
            before_filter_edge_json: edge JSON data before filtering

        Returns:
            The filtered list of facts.
        """
        print(f"[?] LLM edge filtering started: {len(before_filter_edge_json['fact'])} edges to filter")

        # Check whether edge filtering should be skipped
        from prompt_loader import get_prompt_loader
        prompt_loader = get_prompt_loader()
        if prompt_loader.should_skip_llm('edge_filter'):
            print("[NEXT] Skipping the edge-filter LLM call, using all edges")
            # No filtering: return all edges
            return before_filter_edge_json['fact']

        # Record pre-filter details
        filter_input_info = {
            "query": query,
            "edges_before_filter": len(before_filter_edge_json['fact']),
            "sample_edges_before_filter": before_filter_edge_json['fact'][:5],  # preview of the first 5 edges
            "all_edges_before_filter": before_filter_edge_json['fact']  # the complete edge list
        }

        try:
            # Call the LLM to filter
            llm_response = self.llm_generator.filter_triples_with_entity_event(
                query,
                json.dumps(before_filter_edge_json, ensure_ascii=False)
            )

            # Log the raw LLM response
            print(f"[?] DEBUG: raw LLM response: {llm_response[:200]}...")

            # Parse the LLM response
            try:
                filtered_facts = json_repair.loads(llm_response)['fact']
                parse_success = True
                parse_error = None
            except Exception as parse_e:
                print(f"[?] DEBUG: JSON parsing failed: {parse_e}")
                filtered_facts = []
                parse_success = False
                parse_error = str(parse_e)

            # Record detailed filtering results
            filter_result_info = {
                **filter_input_info,
                "llm_raw_response": llm_response,
                "llm_response_length": len(llm_response),
                "json_parse_success": parse_success,
                "json_parse_error": parse_error,
                "edges_after_filter": len(filtered_facts),
                "filtered_facts": filtered_facts,
                "filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0,
                "filtered_out_count": len(before_filter_edge_json['fact']) - len(filtered_facts),
                "sample_filtered_facts": filtered_facts[:5] if filtered_facts else []
            }

            print(f"[?] LLM filtering finished: {len(before_filter_edge_json['fact'])} -> {len(filtered_facts)} edges")
            print(f"[?] Filter ratio: {filter_result_info['filter_ratio']:.2%}")

            return filtered_facts

        except Exception as e:
            print(f"[ERROR] LLM edge filtering failed: {e}")
            # Record the error for LangSmith
            error_info = {
                **filter_input_info,
                "llm_filter_error": str(e),
                "edges_after_filter": 0,
                "filtered_facts": [],
                "filter_success": False
            }
            import traceback
            traceback.print_exc()
            return []

    @traceable(name="Query2Passage_Retrieval")
    def _query2passage(self, query: str, weight_adjust: float = 0.05) -> Dict[str, float]:
        """Query-to-passage retrieval (native ES client, following the original logic)."""
        # Log input parameters for LangSmith
        print(f"[SEARCH] Query2Passage started: {query}, weight_adjust={weight_adjust}")

        # 1. Embed the query
        query_emb = self.embedding_model.encode([query], query_type="passage")[0]
        print(f"[SEARCH] DEBUG: passage query vector dimension: {query_emb.shape}")

        # 2. Retrieve similar passages with the native ES client
        try:
            # Check that the index exists
            if not self.es_client.indices.exists(index=self.passages_index):
                print(f"[WARNING] Passage index does not exist: {self.passages_index}")
                return {}

            # ES script_score query
            query_body = {
                "size": 1000,  # fetch a large pool of passages
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {
                                "query_vector": query_emb.tolist()
                            }
                        }
                    }
                }
            }

            result = self.es_client.search(index=self.passages_index, body=query_body)
            hits = result.get("hits", {}).get("hits", [])
            print(f"[SEARCH] DEBUG: ES returned {len(hits)} passages")

            # Record ES retrieval details for LangSmith
            es_search_info = {
                "query": query,
                "passages_index": self.passages_index,
                "weight_adjust": weight_adjust,
                "es_hits_returned": len(hits),
                "query_vector_dimension": query_emb.shape[0],
                "top_passages_preview": []
            }

            # 3. Build the passage score dict
            passage_scores = {}
            for i, hit in enumerate(hits):
                source = hit["_source"]
                passage_id = source.get('passage_id') or source.get('node_id')
                if passage_id:
                    original_score = hit.get('_score', 0.5)
                    adjusted_score = original_score * weight_adjust  # apply the weight adjustment
                    passage_scores[passage_id] = adjusted_score

                    # Collect previews for the first few passages
                    if i < 5:  # detailed info for the first 5 passages
                        es_search_info["top_passages_preview"].append({
                            "rank": i + 1,
                            "passage_id": passage_id,
                            "original_score": float(original_score),
                            "adjusted_score": float(adjusted_score),
                            "content_preview": source.get("content", source.get("text", ""))[:100] + "..." if source.get("content") or source.get("text") else "empty content"
                        })

            # Record the final result for LangSmith
            final_result = {
                **es_search_info,
                "final_passage_scores": len(passage_scores),
                "passage_score_preview": dict(list(passage_scores.items())[:5]) if passage_scores else {},
                "total_score_sum": sum(passage_scores.values()) if passage_scores else 0
            }
            print(f"[SEARCH] Query2Passage result recorded for LangSmith: {len(passage_scores)} passages scored")

            return passage_scores

        except Exception as e:
            print(f"[ERROR] Passage retrieval failed: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def _find_connected_text_nodes(self, entity_node_ids: List[str]) -> List[str]:
        """
        Find the Text nodes connected to the given entity/event/concept node IDs.

        Args:
            entity_node_ids: list of entity node IDs

        Returns:
            List of connected Text node IDs.
        """
        text_node_ids = set()

        if not self.graph_data or not entity_node_ids:
            return []

        print(f"[SEARCH] Looking up Text nodes connected to {len(entity_node_ids)} nodes")

        for entity_id in entity_node_ids:
            if entity_id in self.graph_data.nodes:
                neighbors = list(self.graph_data.neighbors(entity_id))
                for neighbor_id in neighbors:
                    neighbor_data = self.graph_data.nodes[neighbor_id]
                    if neighbor_data.get('type') == 'text':
                        text_node_ids.add(neighbor_id)

        result = list(text_node_ids)
        print(f"[SEARCH] Found {len(result)} related Text nodes")
        return result

    def _query2passage_filtered(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05) -> Dict[str, float]:
        """
        Passage retrieval restricted to the LLM-filtered entity nodes.
        Vector similarity is only computed for Text nodes connected to those entities.

        Args:
            query: query text
            entity_node_ids: entity node IDs kept by the LLM filter
            weight_adjust: weight adjustment factor

        Returns:
            Passage score dict.
        """
        print(f"[SEARCH] Query2Passage (filtered) started: {query}, based on {len(entity_node_ids)} nodes")
        print(f"[PARAM] weight_adjust = {weight_adjust}")

        # 1. Find the related Text nodes
        text_node_ids = self._find_connected_text_nodes(entity_node_ids)

        if not text_node_ids:
            print("[WARNING] No related Text nodes found")
            return {}

        # 2. Embed the query
        query_emb = self.embedding_model.encode([query], query_type="passage")[0]
        print(f"[SEARCH] DEBUG: passage query vector dimension: {query_emb.shape}")

        # 3. Query the passages of the related Text nodes with the native ES client
        try:
            # Check that the index exists
            if not self.es_client.indices.exists(index=self.passages_index):
                print(f"[WARNING] Passage index does not exist: {self.passages_index}")
                return {}

            # Filtered query: only search the specified text IDs
            query_body = {
                "size": min(1000, len(text_node_ids)),  # no more than the number of text nodes
                "query": {
                    "script_score": {
                        "query": {
                            "terms": {
                                "passage_id": text_node_ids  # restrict to the specified text nodes
                            }
                        },
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {
                                "query_vector": query_emb.tolist()
                            }
                        }
                    }
                }
            }

            result = self.es_client.search(index=self.passages_index, body=query_body)
            hits = result.get("hits", {}).get("hits", [])
            print(f"[SEARCH] DEBUG: ES returned {len(hits)} passages (filtered from {len(text_node_ids)} candidate Text nodes)")

            # 4. Build the passage score dict
            passage_scores = {}
            for i, hit in enumerate(hits):
                source = hit["_source"]
                passage_id = source.get('passage_id') or source.get('node_id')
                if passage_id:
                    original_score = hit.get('_score', 0.5)
                    adjusted_score = original_score * weight_adjust
                    passage_scores[passage_id] = adjusted_score
                    if i < 3:  # print score adjustment details for the first 3 passages
                        print(f"  [SCORE-F] passage {i+1}: original={original_score:.4f} * weight_adjust={weight_adjust:.4f} = adjusted={adjusted_score:.4f}")

            print(f"[SEARCH] Query2Passage (filtered) finished: {len(passage_scores)} passages scored")
            return passage_scores

        except Exception as e:
            print(f"[ERROR] Filtered passage retrieval failed: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def _query2passage_global_topk(self, query: str, topk: int = 20, weight_adjust: float = 0.05) -> Dict[str, float]:
        """
        Global top-K passage retrieval.
        Returns the topk most similar passages from the whole index.

        Args:
            query: query text
            topk: number of passages to return
            weight_adjust: weight adjustment factor

        Returns:
            Passage score dict.
        """
        print(f"[SEARCH] Query2Passage (global top-{topk}) started: {query}")
        print(f"[PARAM] weight_adjust = {weight_adjust}")

        # Embed the query
        query_emb = self.embedding_model.encode([query], query_type="passage")[0]
        print(f"[SEARCH] DEBUG: passage query vector dimension: {query_emb.shape}")

        try:
            # Check that the index exists
            if not self.es_client.indices.exists(index=self.passages_index):
                print(f"[WARNING] Passage index does not exist: {self.passages_index}")
                return {}

            # ES script_score query with a limited result size
            query_body = {
                "size": topk,  # only return the top-K results
                "query": {
                    "script_score": {
                        "query": {"match_all": {}},
                        "script": {
                            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                            "params": {
                                "query_vector": query_emb.tolist()
                            }
                        }
                    }
                }
            }

            result = self.es_client.search(index=self.passages_index, body=query_body)
            hits = result.get("hits", {}).get("hits", [])
            print(f"[SEARCH] DEBUG: ES returned {len(hits)} passages (global top-{topk})")

            # Build the passage score dict
            passage_scores = {}
            for i, hit in enumerate(hits):
                source = hit["_source"]
                passage_id = source.get('passage_id') or source.get('node_id')
                if passage_id:
                    original_score = hit.get('_score', 0.5)
                    adjusted_score = original_score * weight_adjust
                    passage_scores[passage_id] = adjusted_score
                    if i < 3:  # print score adjustment details for the first 3 passages
                        print(f"  [SCORE-G] passage {i+1}: original={original_score:.4f} * weight_adjust={weight_adjust:.4f} = adjusted={adjusted_score:.4f}")

            print(f"[SEARCH] Query2Passage (global top-{topk}) finished: {len(passage_scores)} passages scored")
            return passage_scores

        except Exception as e:
            print(f"[ERROR] Global top-K passage retrieval failed: {e}")
            import traceback
            traceback.print_exc()
            return {}

    def _query2passage_hybrid(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05, global_topk: int = 20) -> Dict[str, float]:
        """
        Hybrid retrieval strategy: filtered retrieval + global top-K.
        Keeps retrieval focused while not missing high-quality passages.

        Args:
            query: query text
            entity_node_ids: entity node IDs kept by the LLM filter
            weight_adjust: weight adjustment factor
            global_topk: top-K size for the global retrieval

        Returns:
            Merged passage score dict.
        """
        print(f"[SEARCH] Query2Passage (hybrid) started: filtered + global top-{global_topk}")
        print(f"[PARAM] weight_adjust = {weight_adjust} (scales the passage similarity scores)")

        # 1. Filtered retrieval
        filtered_scores = self._query2passage_filtered(query, entity_node_ids, weight_adjust)
        print(f"  [TARGET] filtered retrieval returned: {len(filtered_scores)} passages")

        # 2. Global top-K retrieval
        global_scores = self._query2passage_global_topk(query, topk=global_topk, weight_adjust=weight_adjust)
        print(f"  [?] global retrieval returned: {len(global_scores)} passages")

        # 3. Merge the results; on duplicates the filtered score is kept
        combined_scores = dict(filtered_scores)  # start from the filtered results

        # Add the non-duplicate results from the global retrieval
        added_count = 0
        for passage_id, score in global_scores.items():
            if passage_id not in combined_scores:
                combined_scores[passage_id] = score
                added_count += 1

        print(f"  [?] merged total: {len(combined_scores)} passages ({added_count} newly added)")
        print("[SEARCH] Query2Passage (hybrid) finished")

        return combined_scores

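    # Merge semantics sketch (hypothetical values): filtered scores win on overlap.
    #   filtered = {"t1": 0.040, "t2": 0.038}
    #   global_  = {"t2": 0.041, "t3": 0.036}
    #   merged   = {"t1": 0.040, "t2": 0.038, "t3": 0.036}
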
    @traceable(name="Personalized_PageRank")
    def _personalized_pagerank(self, personalization_dict: Dict[str, float], alpha: float = 0.85) -> Dict[str, float]:
        """Personalized PageRank computation (following the original logic)."""
        # Log input parameters for LangSmith
        non_zero_nodes = sum(1 for v in personalization_dict.values() if v > 0)
        print(f"[SEARCH] PageRank started: {non_zero_nodes}/{len(personalization_dict)} nodes have non-zero scores, alpha={alpha}")

        if not personalization_dict or not self.graph_data.nodes():
            print("[WARNING] PageRank input is empty, returning empty result")
            return {}

        # Record PageRank input details for LangSmith
        pagerank_input_info = {
            "total_nodes_in_graph": len(self.graph_data.nodes()),
            "total_edges_in_graph": len(self.graph_data.edges()),
            "personalization_nodes": len(personalization_dict),
            "non_zero_personalization_nodes": non_zero_nodes,
            "alpha": alpha,
            "max_iter": self.inference_config.ppr_max_iter,
            "tolerance": self.inference_config.ppr_tol,
            "top_personalization_scores": dict(sorted(personalization_dict.items(), key=lambda x: x[1], reverse=True)[:10])
        }

        try:
            # Run PageRank with the raw personalization dict
            ppr_scores = nx.pagerank(
                self.graph_data,
                alpha=alpha,
                personalization=personalization_dict,
                max_iter=self.inference_config.ppr_max_iter,
                tol=self.inference_config.ppr_tol
            )

            # Record the PageRank result for LangSmith
            pagerank_result_info = {
                **pagerank_input_info,
                "pagerank_successful": True,
                "total_nodes_with_ppr_scores": len(ppr_scores),
                "top_ppr_scores": dict(sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)[:10]),
                "ppr_score_statistics": {
                    "max_score": max(ppr_scores.values()) if ppr_scores else 0,
                    "min_score": min(ppr_scores.values()) if ppr_scores else 0,
                    "mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0,
                    "total_score_sum": sum(ppr_scores.values()) if ppr_scores else 0
                }
            }
            print(f"[SEARCH] PageRank result recorded for LangSmith: {len(ppr_scores)} nodes received PPR scores")

            return ppr_scores
        except Exception as e:
            print(f"[ERROR] PageRank computation failed: {e}")
            return {}

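    # Minimal propagation sketch on a toy graph (illustrative only, not part of the pipeline):
    #   g = nx.Graph([("event_1", "text_1"), ("event_1", "text_2")])
    #   nx.pagerank(g, alpha=0.95, personalization={"event_1": 1.0, "text_1": 0.0, "text_2": 0.0})
    # Seed mass placed on "event_1" spreads to its neighbouring text nodes through the
    # restart-biased random walk, which is what ranks connected passages here.
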
    @traceable(name="HippoRAG_Complete_Retrieval")
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents (strictly following the original HippoRAG2 logic)."""

        if not self.graph_data:
            print("[WARNING] Graph data not loaded")
            return []

        print(f"[SEARCH] HippoRAG2 retrieval started: {query}")

        # 1. query2edge - get the node score dict
        node_dict = self._query2edge(query, topN=self.inference_config.topk_edges)
        print(f"[INFO] Edge retrieval produced {len(node_dict)} nodes")

        # 2. query2passage - hybrid strategy (filtered + global top-K)
        entity_node_ids = list(node_dict.keys())  # node IDs kept by the LLM filter
        print(f"[CONFIG] Using config parameters: weight_adjust={self.inference_config.weight_adjust}, global_topk={self.inference_config.global_topk}")
        text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk)
        print(f"[?] Passage retrieval produced {len(text_dict)} passages")

        # 3. Merge into the personalization dict
        personalization_dict = {}

        # Initialize every node to 0
        for node in self.graph_data.nodes():
            personalization_dict[node] = 0.0

        # Add the node scores from edge retrieval
        for node, score in node_dict.items():
            if node in personalization_dict:
                personalization_dict[node] = score

        # Add the text node scores from passage retrieval
        for text_id, score in text_dict.items():
            if text_id in personalization_dict:
                personalization_dict[text_id] = score

        print(f"[TARGET] Personalization dict contains {sum(1 for v in personalization_dict.values() if v > 0)} non-zero nodes")

        # 4. Personalized PageRank propagation
        ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha)

        if not ppr_scores:
            print("[WARNING] PageRank computation failed")
            return []

        # 5. Rank nodes and split them into event nodes and passage nodes
        event_node_scores = []
        text_node_scores = []

        for node_id, ppr_score in ppr_scores.items():
            if node_id in self.graph_data.nodes:
                node_type = self.graph_data.nodes[node_id].get('type', '')
                if node_type == 'event':
                    event_node_scores.append((node_id, ppr_score))
                elif node_id in text_dict:  # make sure it is a text node that was actually retrieved
                    text_node_scores.append((node_id, ppr_score))

        # Sort by PageRank score
        event_node_scores.sort(key=lambda x: x[1], reverse=True)
        text_node_scores.sort(key=lambda x: x[1], reverse=True)

        # 6. Build Document objects: take the top-K events and the top-K passages
        documents = []

        print(f"[DEBUG] Building documents: {len(event_node_scores)} candidate event nodes, {len(text_node_scores)} candidate passage nodes")
        print(f"[DEBUG] Selecting top-{self.inference_config.top_k_events} events + top-{self.inference_config.top_k_passages} passages")

        # 6.1 Event nodes (top-K)
        for i, (node_id, ppr_score) in enumerate(event_node_scores[:self.inference_config.top_k_events]):
            # Event node content is usually stored in the graph data
            try:
                if node_id in self.graph_data.nodes:
                    node_data = self.graph_data.nodes[node_id]
                    # Possible content fields include: content, text, description, label, id
                    content = (node_data.get('content') or
                               node_data.get('text') or
                               node_data.get('description') or
                               node_data.get('label') or
                               node_data.get('id') or
                               f"Event node {node_id}")
                    # source_text_id attribute of the event node
                    source_text_id = node_data.get('source_text_id', '')
                    # node_id attribute stored in the graph
                    graph_node_id = node_data.get('node_id', node_id)
                    # evidence attribute of the event node
                    evidence = node_data.get('evidence', '')
                else:
                    content = f"Event node {node_id}: data not found"
                    source_text_id = ''
                    graph_node_id = node_id
                    evidence = ''

            except Exception as e:
                content = f"Event node {node_id}: failed to fetch content: {str(e)}"
                source_text_id = ''
                graph_node_id = node_id
                evidence = ''

            doc = Document(
                page_content=content,
                metadata={
                    "node_id": node_id,  # node ID used during retrieval
                    "graph_node_id": graph_node_id,  # node ID attribute stored in the graph
                    "node_type": "event",
                    "ppr_score": ppr_score,
                    "edge_score": node_dict.get(node_id, 0.0),
                    "passage_score": text_dict.get(node_id, 0.0),
                    "rank": i + 1,
                    "source": "hipporag2_langchain_event",
                    "source_text_id": source_text_id,  # new: source_text_id attribute of the event node
                    "evidence": evidence  # new: evidence attribute of the event node
                }
            )
            documents.append(doc)

        # 6.2 Passage nodes (top-K)
        for i, (node_id, ppr_score) in enumerate(text_node_scores[:self.inference_config.top_k_passages]):
            # Read the passage content directly from the graph data
            try:
                if node_id in self.graph_data.nodes:
                    node_data = self.graph_data.nodes[node_id]

                    # Verbose debugging: print every field to help locate the correct text field
                    print("[DEBUG] ===== Passage node debug info =====")
                    print(f"[DEBUG] Node ID: {node_id}")
                    print(f"[DEBUG] Node type: {node_data.get('type', 'unknown')}")
                    print(f"[DEBUG] All field names: {sorted(list(node_data.keys()))}")

                    # Print the values of fields that may hold text (first 100 characters)
                    text_fields = ['original_text', 'text', 'content', 'passage_text', 'full_text', 'raw_text', 'description', 'label']
                    for field in text_fields:
                        if field in node_data:
                            field_value = str(node_data[field])
                            preview = field_value[:100] + "..." if len(field_value) > 100 else field_value
                            print(f"[DEBUG] {field}: {preview}")

                    print("[DEBUG] ===== End of debug info =====")

                    # Try the possible text fields in order
                    content = (node_data.get('original_text') or
                               node_data.get('text') or
                               node_data.get('content') or
                               node_data.get('passage_text') or
                               node_data.get('full_text') or
                               node_data.get('raw_text') or
                               node_data.get('description') or
                               node_data.get('label') or
                               f"Passage {node_id}: content not found")

                    # If nothing was found, dump the whole node for inspection
                    if content == f"Passage {node_id}: content not found":
                        print(f"[DEBUG] !!! Passage node {node_id} full data: {dict(node_data)}")
                        print(f"[DEBUG] !!! Node data type: {type(node_data)}")

                    # evidence attribute of the passage node
                    evidence = node_data.get('evidence', '')
                else:
                    content = f"Passage {node_id}: data not found"
                    evidence = ''

            except Exception as e:
                content = f"Passage {node_id}: failed to fetch content: {str(e)}"
                evidence = ''
                print(f"[DEBUG] Passage node {node_id} raised an exception: {str(e)}")

            doc = Document(
                page_content=content,
                metadata={
                    "node_id": node_id,
                    "node_type": "text",
                    "ppr_score": ppr_score,
                    "edge_score": node_dict.get(node_id, 0.0),
                    "passage_score": text_dict.get(node_id, 0.0),
                    "rank": len(documents) + 1,  # continue the ranking after the event docs
                    "source": "hipporag2_langchain_text",
                    "evidence": evidence  # new: evidence attribute of the passage node
                }
            )
            documents.append(doc)

        event_count = len(event_node_scores[:self.inference_config.top_k_events])
        text_count = len(text_node_scores[:self.inference_config.top_k_passages])
        print(f"[OK] HippoRAG2 retrieval finished, returning {event_count} event nodes + {text_count} passage nodes, {len(documents)} documents in total")

        # Do not store the full PageRank scores in the metadata, to avoid pushing large payloads to LangSmith
        # doc.metadata["complete_ppr_scores"] = ppr_scores  # disabled to keep LangSmith traces small
        for doc in documents:
            doc.metadata["query"] = query
            # Only store a flag instead of the complete PageRank data
            doc.metadata["pagerank_available"] = len(ppr_scores) > 0

        return documents

    def get_complete_pagerank_scores(self, query: str) -> Dict[str, Any]:
        """
        Return the complete PageRank score information (all non-zero nodes).

        Args:
            query: query string

        Returns:
            Dict with the complete PageRank information.
        """
        if not self.graph_data:
            print("[WARNING] Graph data not loaded")
            return {}

        print(f"[SEARCH] Fetching complete PageRank scores: {query}")

        # 1. query2edge - get the node score dict
        node_dict = self._query2edge(query, topN=self.inference_config.topk_edges)

        # 2. query2passage - hybrid strategy (filtered + global top-K)
        entity_node_ids = list(node_dict.keys())  # node IDs kept by the LLM filter
        text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk)

        # 3. Merge into the personalization dict
        personalization_dict = {}

        # Initialize every node to 0
        for node in self.graph_data.nodes():
            personalization_dict[node] = 0.0

        # Add the node scores from edge retrieval
        for node, score in node_dict.items():
            if node in personalization_dict:
                personalization_dict[node] = score

        # Add the text node scores from passage retrieval
        for text_id, score in text_dict.items():
            if text_id in personalization_dict:
                personalization_dict[text_id] = score

        # 4. Personalized PageRank propagation
        ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha)

        if not ppr_scores:
            print("[WARNING] PageRank computation failed")
            return {}

        # 5. Keep the nodes with non-zero scores, sorted in descending order
        non_zero_scores = [(node_id, score) for node_id, score in ppr_scores.items() if score > 0]
        non_zero_scores.sort(key=lambda x: x[1], reverse=True)

        # 6. Build the complete score information
        return {
            "query": query,
            "total_nodes_in_graph": len(self.graph_data.nodes()),
            "total_edges_in_graph": len(self.graph_data.edges()),
            "personalization_input": {
                "node_dict_count": len(node_dict),
                "text_dict_count": len(text_dict),
                "non_zero_personalization_count": sum(1 for v in personalization_dict.values() if v > 0)
            },
            "pagerank_results": {
                "total_nodes_with_scores": len(ppr_scores),
                "non_zero_nodes_count": len(non_zero_scores),
                "max_score": max(ppr_scores.values()) if ppr_scores else 0,
                "min_score": min(ppr_scores.values()) if ppr_scores else 0,
                "mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0,
                "all_scores": dict(ppr_scores),  # complete score dict for all nodes
                "non_zero_scores_sorted": non_zero_scores  # non-zero scores, descending
            },
            "input_scores": {
                "node_dict": node_dict,
                "text_dict": text_dict
            }
        }


def create_langchain_hipporag_retriever(
    keyword: str,
    top_k: int = 13,
    inference_config: Optional[InferenceConfig] = None,
    oneapi_key: Optional[str] = None,
    oneapi_base_url: Optional[str] = None,
    oneapi_model_gen: Optional[str] = None,
    oneapi_model_embed: Optional[str] = None
) -> LangChainHippoRAGRetriever:
    """
    Convenience factory for the LangChain HippoRAG retriever.

    Args:
        keyword: keyword used to derive index names
        top_k: number of documents to return
        inference_config: inference configuration
        oneapi_key: OneAPI key
        oneapi_base_url: OneAPI base URL
        oneapi_model_gen: generation model name
        oneapi_model_embed: embedding model name

    Returns:
        A LangChainHippoRAGRetriever instance.
    """

    # Initialize the models
    llm_generator = DashScopeLLM(
        api_key=oneapi_key,
        model_name=oneapi_model_gen
    )

    embedding_model = DashScopeEmbeddingModel(
        api_key=oneapi_key,
        model_name=oneapi_model_embed
    )

    return LangChainHippoRAGRetriever(
        llm_generator=llm_generator,
        embedding_model=embedding_model,
        keyword=keyword,
        top_k=top_k,
        inference_config=inference_config
    )


if __name__ == "__main__":
    # Test code
    from dotenv import load_dotenv
    load_dotenv()

    retriever = create_langchain_hipporag_retriever(
        keyword="test",
        top_k=13
    )

    # Test retrieval
    results = retriever.get_relevant_documents("What is supply chain risk management?")

    print("\n[TARGET] Retrieval results:")
    for i, doc in enumerate(results, 1):
        print(f"{i}. {doc.metadata['node_id']} (PPR: {doc.metadata['ppr_score']:.4f})")
        print(f"   {doc.page_content[:100]}...")