# AIEC-RAG/retriver/langgraph/langchain_hipporag_retriever.py
"""
基于LangChain Elasticsearch集成的HippoRAG2检索器
严格按照原始HippoRAG2逻辑实现
1. query2edge: 用户查询 ES边索引向量检索 LLM过滤 节点分数
2. query2passage: 用户查询 ES段落索引向量检索 段落分数
3. 合并个性化字典 PageRank传播 返回排序的文本节点
"""
import os
import sys
import json
import numpy as np
import networkx as nx
from typing import Dict, List, Tuple, Any, Optional
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_elasticsearch import ElasticsearchStore
from elasticsearch import Elasticsearch
from pydantic import Field
import json_repair
# LangSmith tracing support
from langsmith import traceable
# Add the project root to sys.path
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.append(project_root)
from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel
from retriver.langgraph.dashscope_llm import DashScopeLLM
from elasticsearch_vectorization.config import ElasticsearchConfig
def min_max_normalize(x: np.ndarray) -> np.ndarray:
    """Min-max normalization; returns all ones when the input is constant."""
    if np.max(x) == np.min(x):
        return np.ones_like(x)
    return (x - np.min(x)) / (np.max(x) - np.min(x))
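# Example for min_max_normalize above: an input of np.array([2.0, 4.0, 6.0]) maps to
# array([0. , 0.5, 1. ]), while a constant input such as np.array([3.0, 3.0]) maps to all ones.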
class InferenceConfig:
    """Inference configuration."""
    def __init__(self,
                 topk_edges: int = 10,
                 weight_adjust: float = 0.05,
                 ppr_alpha: float = 0.95,
                 ppr_max_iter: int = 100,
                 ppr_tol: float = 1e-6,
                 top_k_events: int = 10,     # number of event nodes to return
                 top_k_passages: int = 3,    # number of passage nodes to return
                 global_topk: int = 20):     # number of passages in the global top-K retrieval
        self.topk_edges = topk_edges
        self.weight_adjust = weight_adjust
        self.ppr_alpha = ppr_alpha
        self.ppr_max_iter = ppr_max_iter
        self.ppr_tol = ppr_tol
        self.top_k_events = top_k_events
        self.top_k_passages = top_k_passages
        self.global_topk = global_topk
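# Example (illustrative) of customizing the InferenceConfig above:
#     config = InferenceConfig(topk_edges=20, top_k_passages=5, global_topk=50)
#     retriever = create_langchain_hipporag_retriever(keyword="test", inference_config=config)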
class LangChainHippoRAGRetriever(BaseRetriever):
    """
    HippoRAG retriever built on the LangChain Elasticsearch integration.
    Uses the native Elasticsearch client instead of hand-rolled HTTP requests.
    """
    # Pydantic fields
    top_k: int = Field(default=13, description="Number of documents to return (10 events + 3 passages)")
    keyword: str = Field(description="Keyword, used to derive index names")
    llm_generator: Any = Field(description="LLM generator", exclude=True)
    embedding_model: Any = Field(description="Embedding model", exclude=True)
    inference_config: Any = Field(description="Inference configuration", exclude=True)
    graph_data: Any = Field(default=None, description="Graph data", exclude=True)
    es_client: Any = Field(default=None, description="Elasticsearch client", exclude=True)
    edges_index: str = Field(default="", description="Edge index name")
    passages_index: str = Field(default="", description="Passage index name")
def __init__(
self,
llm_generator: DashScopeLLM,
embedding_model: DashScopeEmbeddingModel,
keyword: str,
top_k: int = 13,
inference_config: Optional[InferenceConfig] = None,
**kwargs
):
"""初始化LangChain HippoRAG检索器"""
# 初始化配置
if inference_config is None:
inference_config = InferenceConfig()
# 初始化Elasticsearch客户端
es_client = Elasticsearch(
ElasticsearchConfig.ES_HOST,
basic_auth=(ElasticsearchConfig.ES_USERNAME, ElasticsearchConfig.ES_PASSWORD),
timeout=30,
max_retries=3,
retry_on_timeout=True
)
# 生成索引名称
edges_index = ElasticsearchConfig.get_edges_index_name(keyword)
passages_index = ElasticsearchConfig.get_passages_index_name(keyword)
# 调用父类初始化
super().__init__(
top_k=top_k,
keyword=keyword,
llm_generator=llm_generator,
embedding_model=embedding_model,
inference_config=inference_config,
edges_index=edges_index,
passages_index=passages_index,
graph_data=None, # 将在_load_graph_data中设置
es_client=es_client,
**kwargs
)
# 加载图数据
self._load_graph_data()
print(f"LangChain HippoRAG检索器初始化完成")
print(f"边索引: {edges_index}")
print(f"段落索引: {passages_index}")
    def _load_graph_data(self):
        """Load the graph data that includes concept nodes."""
        import pickle
        graph_path = os.path.join(
            os.path.dirname(__file__), '..', '..',
            f'{self.keyword}_with_concept.pkl'
        )
        if os.path.exists(graph_path):
            with open(graph_path, 'rb') as f:
                self.graph_data = pickle.load(f)
            print(f"[OK] Loaded concept-augmented graph: {len(self.graph_data.nodes())} nodes")
            # Count nodes per type
            node_types = {}
            for node_id, attrs in self.graph_data.nodes(data=True):
                node_type = attrs.get('type', 'unknown')
                node_types[node_type] = node_types.get(node_type, 0) + 1
            print(f"    Node type counts: {node_types}")
        else:
            print(f"[ERROR] Concept-augmented graph file not found: {graph_path}")
            self.graph_data = None
    @traceable(name="Query2Edge_Retrieval")
    def _query2edge(self, query: str, topN: int = 10) -> Dict[str, float]:
        """Query-to-edge retrieval using the native ES client, following the original logic."""
        # Log input parameters (picked up by LangSmith)
        print(f"[SEARCH] Query2Edge started: {query}, topN={topN}")
        # 1. Embed the query
        query_emb = self.embedding_model.encode([query], query_type="edge")[0]
        print(f"[SEARCH] DEBUG: query vector dimension: {query_emb.shape}")
        # 2. Retrieve similar edges with the native ES client
        try:
            # Make sure the index exists
            if not self.es_client.indices.exists(index=self.edges_index):
                print(f"[WARNING] Edge index does not exist: {self.edges_index}")
                return {}
            print(f"[OK] DEBUG: edge index exists: {self.edges_index}")
            # ES script_score query
query_body = {
"size": topN * 2,
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {
"query_vector": query_emb.tolist()
}
}
}
}
}
result = self.es_client.search(index=self.edges_index, body=query_body)
hits = result.get("hits", {}).get("hits", [])
print(f"[SEARCH] DEBUG: ES返回边数量: {len(hits)}")
# 记录ES检索详情到LangSmith
es_search_info = {
"query": query,
"edges_index": self.edges_index,
"topN_requested": topN,
"es_hits_returned": len(hits),
"query_vector_dimension": query_emb.shape[0],
"top_edges_preview": []
}
            # 3. Build the edge list for LLM filtering
before_filter_edge_json = {"fact": []}
valid_edges = []
scores = []
print(f"[SEARCH] DEBUG: 图中节点数量: {len(self.graph_data.nodes()) if self.graph_data else 0}")
for i, hit in enumerate(hits):
source = hit["_source"]
try:
head_node_id = source['head_node_id']
tail_node_id = source['tail_node_id']
score = hit.get('_score', 0.5)
if i < 3: # 只打印前3个边的详细信息
print(f"[SEARCH] DEBUG: 边{i+1} (分数:{score:.3f}): {head_node_id} -> {tail_node_id}")
head_exists = head_node_id in self.graph_data.nodes if self.graph_data else False
tail_exists = tail_node_id in self.graph_data.nodes if self.graph_data else False
print(f" 头节点存在: {head_exists}, 尾节点存在: {tail_exists}")
# 添加到LangSmith监控信息
es_search_info["top_edges_preview"].append({
"rank": i + 1,
"score": float(score),
"head_entity": source.get('head_entity', ''),
"relation": source.get('relation', ''),
"tail_entity": source.get('tail_entity', ''),
"head_node_id": head_node_id,
"tail_node_id": tail_node_id,
"head_exists_in_graph": head_exists,
"tail_exists_in_graph": tail_exists
})
                    # Check that both edge endpoints exist in the graph
if self.graph_data and head_node_id in self.graph_data.nodes and tail_node_id in self.graph_data.nodes:
edge_str = [source['head_entity'], source['relation'], source['tail_entity']]
before_filter_edge_json['fact'].append(edge_str)
valid_edges.append(((head_node_id, tail_node_id), source.get('edge_index', 0)))
scores.append(hit.get('_score', 0.5))
except KeyError as e:
if i < 3:
print(f"[SEARCH] DEBUG: 边{i+1} KeyError: {e}")
continue
print(f"[SEARCH] DEBUG: 有效边数量: {len(valid_edges)}")
# 4. 如果没有有效边,返回空字典
if len(before_filter_edge_json['fact']) == 0:
print("[WARNING] DEBUG: 没有有效边,返回空字典")
return {}
# 5. 使用LLM过滤边提取为独立的可追踪方法
filtered_facts = self._llm_filter_edges(query, before_filter_edge_json)
if len(filtered_facts) == 0:
return {}
            # 6. Build the node score dictionary
node_score_dict = {}
for edge_fact in filtered_facts:
for i, (edge, idx) in enumerate(valid_edges):
try:
head_id = self.graph_data.nodes[edge[0]].get('id', edge[0]) if self.graph_data else edge[0]
tail_id = self.graph_data.nodes[edge[1]].get('id', edge[1]) if self.graph_data else edge[1]
relation = before_filter_edge_json['fact'][i][1] if i < len(before_filter_edge_json['fact']) else 'related_to'
if head_id == edge_fact[0] and relation == edge_fact[1] and tail_id == edge_fact[2]:
head, tail = edge[0], edge[1]
sim_score = scores[i] if i < len(scores) else 0.5
                            # Update the node scores
if head not in node_score_dict:
node_score_dict[head] = sim_score
else:
node_score_dict[head] = max(node_score_dict[head], sim_score)
if tail not in node_score_dict:
node_score_dict[tail] = sim_score
else:
node_score_dict[tail] = max(node_score_dict[tail], sim_score)
break
except (KeyError, IndexError):
continue
            # Record the final result for LangSmith
final_result = {
**es_search_info,
"llm_filter_applied": True,
"edges_before_llm_filter": len(before_filter_edge_json['fact']),
"edges_after_llm_filter": len(filtered_facts),
"llm_filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0,
"final_node_scores": len(node_score_dict),
"node_score_preview": dict(list(node_score_dict.items())[:5]) if node_score_dict else {},
"edge_to_node_mapping_success": len(node_score_dict) > 0,
"valid_edges_found": len(valid_edges),
"filtered_facts_sample": filtered_facts[:3] if filtered_facts else []
}
print(f"[SEARCH] Query2Edge结果记录到LangSmith: {len(node_score_dict)}个节点获得分数")
return node_score_dict
except Exception as e:
print(f"[ERROR] 边检索失败: {e}")
import traceback
traceback.print_exc()
return {}
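    # Note (assumed index layout, inferred from the field accesses above): each document in the
    # edges index is expected to look roughly like
    #     {"head_node_id": ..., "tail_node_id": ..., "head_entity": ..., "relation": ...,
    #      "tail_entity": ..., "edge_index": ..., "embedding": [...]},
    # which is what lets _query2edge map ES hits back onto nodes of graph_data.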
@traceable(name="LLM_Edge_Filter")
def _llm_filter_edges(self, query: str, before_filter_edge_json: Dict) -> List:
"""
使用LLM过滤边的三元组并提供详细监控
Args:
query: 用户查询
before_filter_edge_json: 过滤前的边JSON数据
Returns:
过滤后的事实列表
"""
print(f"[?] LLM边过滤开始: {len(before_filter_edge_json['fact'])}个边待过滤")
# 检查是否跳过边过滤
from prompt_loader import get_prompt_loader
prompt_loader = get_prompt_loader()
if prompt_loader.should_skip_llm('edge_filter'):
print(f"[NEXT] 跳过边过滤LLM调用使用所有边")
# 不过滤,返回所有边
return before_filter_edge_json['fact']
# 记录过滤前的详细信息
filter_input_info = {
"query": query,
"edges_before_filter": len(before_filter_edge_json['fact']),
"sample_edges_before_filter": before_filter_edge_json['fact'][:5], # 前5个边的预览
"all_edges_before_filter": before_filter_edge_json['fact'] # 完整的边列表
}
try:
            # Call the LLM to do the filtering
llm_response = self.llm_generator.filter_triples_with_entity_event(
query,
json.dumps(before_filter_edge_json, ensure_ascii=False)
)
            # Log the raw LLM response
            print(f"[?] DEBUG: raw LLM response: {llm_response[:200]}...")
            # Parse the LLM response
try:
filtered_facts = json_repair.loads(llm_response)['fact']
parse_success = True
parse_error = None
except Exception as parse_e:
print(f"[?] DEBUG: JSON解析失败: {parse_e}")
filtered_facts = []
parse_success = False
parse_error = str(parse_e)
            # Record detailed information about the filter result
filter_result_info = {
**filter_input_info,
"llm_raw_response": llm_response,
"llm_response_length": len(llm_response),
"json_parse_success": parse_success,
"json_parse_error": parse_error,
"edges_after_filter": len(filtered_facts),
"filtered_facts": filtered_facts,
"filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0,
"filtered_out_count": len(before_filter_edge_json['fact']) - len(filtered_facts),
"sample_filtered_facts": filtered_facts[:5] if filtered_facts else []
}
print(f"[?] LLM过滤完成: {len(before_filter_edge_json['fact'])} -> {len(filtered_facts)} 个边")
print(f"[?] 过滤比例: {filter_result_info['filter_ratio']:.2%}")
return filtered_facts
except Exception as e:
print(f"[ERROR] LLM边过滤失败: {e}")
# 记录错误信息到LangSmith
error_info = {
**filter_input_info,
"llm_filter_error": str(e),
"edges_after_filter": 0,
"filtered_facts": [],
"filter_success": False
}
import traceback
traceback.print_exc()
return []
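    # Note (assumed contract, based on the calls above): filter_triples_with_entity_event takes the
    # query plus a JSON string shaped like {"fact": [["head", "relation", "tail"], ...]} and should
    # return JSON of the same shape; json_repair tolerates minor formatting noise in the model
    # output before ['fact'] is read back out.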
@traceable(name="Query2Passage_Retrieval")
def _query2passage(self, query: str, weight_adjust: float = 0.05) -> Dict[str, float]:
"""查询到段落的检索使用原生ES客户端按原始逻辑"""
# 记录输入参数到LangSmith
print(f"[SEARCH] Query2Passage开始: {query}, weight_adjust={weight_adjust}")
# 1. 查询向量化
query_emb = self.embedding_model.encode([query], query_type="passage")[0]
print(f"[SEARCH] DEBUG: 段落查询向量维度: {query_emb.shape}")
# 2. 使用原生ES查询相似段落
try:
# 检查索引是否存在
if not self.es_client.indices.exists(index=self.passages_index):
print(f"[WARNING] 段落索引不存在: {self.passages_index}")
return {}
# ES script_score查询
query_body = {
"size": 1000, # 获取足够多的段落
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {
"query_vector": query_emb.tolist()
}
}
}
}
}
result = self.es_client.search(index=self.passages_index, body=query_body)
hits = result.get("hits", {}).get("hits", [])
print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)}")
# 记录ES检索详情到LangSmith
es_search_info = {
"query": query,
"passages_index": self.passages_index,
"weight_adjust": weight_adjust,
"es_hits_returned": len(hits),
"query_vector_dimension": query_emb.shape[0],
"top_passages_preview": []
}
            # 3. Build the passage score dictionary
passage_scores = {}
for i, hit in enumerate(hits):
source = hit["_source"]
passage_id = source.get('passage_id') or source.get('node_id')
if passage_id:
original_score = hit.get('_score', 0.5)
                    adjusted_score = original_score * weight_adjust  # apply the weight adjustment
passage_scores[passage_id] = adjusted_score
                    # Collect preview info for the first few passages
                    if i < 5:  # keep details for the first 5 passages
es_search_info["top_passages_preview"].append({
"rank": i + 1,
"passage_id": passage_id,
"original_score": float(original_score),
"adjusted_score": float(adjusted_score),
"content_preview": source.get("content", source.get("text", ""))[:100] + "..." if source.get("content") or source.get("text") else "内容为空"
})
            # Record the final result for LangSmith
final_result = {
**es_search_info,
"final_passage_scores": len(passage_scores),
"passage_score_preview": dict(list(passage_scores.items())[:5]) if passage_scores else {},
"total_score_sum": sum(passage_scores.values()) if passage_scores else 0
}
print(f"[SEARCH] Query2Passage结果记录到LangSmith: {len(passage_scores)}个段落获得分数")
return passage_scores
except Exception as e:
print(f"[ERROR] 段落检索失败: {e}")
import traceback
traceback.print_exc()
return {}
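    # Note: ES script_score returns cosine similarity + 1.0, i.e. a value in [0, 2]. With the
    # default weight_adjust of 0.05, a hit scoring 1.85 contributes 1.85 * 0.05 = 0.0925 to the
    # personalization vector, which keeps passage seeds small relative to the unscaled
    # edge-derived node scores.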
    def _find_connected_text_nodes(self, entity_node_ids: List[str]) -> List[str]:
        """
        Find the Text nodes connected to the given entity/event/concept node IDs.
        Args:
            entity_node_ids: list of entity node IDs
        Returns:
            list of connected Text node IDs
        """
        text_node_ids = set()
        if not self.graph_data or not entity_node_ids:
            return []
        print(f"[SEARCH] Looking up Text nodes connected to {len(entity_node_ids)} nodes")
for entity_id in entity_node_ids:
if entity_id in self.graph_data.nodes:
neighbors = list(self.graph_data.neighbors(entity_id))
for neighbor_id in neighbors:
neighbor_data = self.graph_data.nodes[neighbor_id]
if neighbor_data.get('type') == 'text':
text_node_ids.add(neighbor_id)
result = list(text_node_ids)
print(f"[SEARCH] 找到 {len(result)} 个相关的Text节点")
return result
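    # Example (illustrative, hypothetical IDs): if entity node "e1" has neighbors "t1" (type="text")
    # and "e2" (type="event"), only "t1" is returned, so the filtered passage search below is
    # restricted to passages the graph already links to the kept entities.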
    def _query2passage_filtered(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05) -> Dict[str, float]:
        """
        Passage retrieval restricted to the LLM-filtered entity nodes.
        Vector similarity is only computed against Text nodes connected to those entities.
        Args:
            query: query text
            entity_node_ids: list of LLM-filtered entity node IDs
            weight_adjust: weight adjustment factor
        Returns:
            passage score dictionary
        """
        print(f"[SEARCH] Query2Passage(Filtered) started: {query}, based on {len(entity_node_ids)} nodes")
        print(f"[PARAM] weight_adjust = {weight_adjust}")
        # 1. Find the related Text nodes
        text_node_ids = self._find_connected_text_nodes(entity_node_ids)
        if not text_node_ids:
            print("[WARNING] No related Text nodes found")
            return {}
        # 2. Embed the query
        query_emb = self.embedding_model.encode([query], query_type="passage")[0]
        print(f"[SEARCH] DEBUG: passage query vector dimension: {query_emb.shape}")
        # 3. Query the passages of the related Text nodes with the native ES client
        try:
            # Make sure the index exists
            if not self.es_client.indices.exists(index=self.passages_index):
                print(f"[WARNING] Passage index does not exist: {self.passages_index}")
                return {}
            # Filtered query: only search the specified text IDs
            query_body = {
                "size": min(1000, len(text_node_ids)),  # no more than the number of text nodes
"query": {
"script_score": {
"query": {
"terms": {
"passage_id": text_node_ids # 只查询指定的text节点
}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {
"query_vector": query_emb.tolist()
}
}
}
}
}
result = self.es_client.search(index=self.passages_index, body=query_body)
hits = result.get("hits", {}).get("hits", [])
print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)} (从{len(text_node_ids)}个候选Text节点中筛选)")
# 4. 构建段落分数字典
passage_scores = {}
for i, hit in enumerate(hits):
source = hit["_source"]
passage_id = source.get('passage_id') or source.get('node_id')
if passage_id:
original_score = hit.get('_score', 0.5)
adjusted_score = original_score * weight_adjust
passage_scores[passage_id] = adjusted_score
if i < 3: # 打印前3个段落的分数调整细节
print(f" [SCORE-F] 段落{i+1}: 原始分数={original_score:.4f} * weight_adjust={weight_adjust:.4f} = 调整后分数={adjusted_score:.4f}")
print(f"[SEARCH] Query2Passage(Filtered)完成: {len(passage_scores)}个段落获得分数")
return passage_scores
except Exception as e:
print(f"[ERROR] 过滤段落检索失败: {e}")
import traceback
traceback.print_exc()
return {}
    def _query2passage_global_topk(self, query: str, topk: int = 20, weight_adjust: float = 0.05) -> Dict[str, float]:
        """
        Global top-K passage retrieval.
        Returns the top-K most similar passages across all passages.
        Args:
            query: query text
            topk: number of passages to return
            weight_adjust: weight adjustment factor
        Returns:
            passage score dictionary
        """
        print(f"[SEARCH] Query2Passage(Global-Top{topk}) started: {query}")
        print(f"[PARAM] weight_adjust = {weight_adjust}")
        # Embed the query
        query_emb = self.embedding_model.encode([query], query_type="passage")[0]
        print(f"[SEARCH] DEBUG: passage query vector dimension: {query_emb.shape}")
        try:
            # Make sure the index exists
            if not self.es_client.indices.exists(index=self.passages_index):
                print(f"[WARNING] Passage index does not exist: {self.passages_index}")
                return {}
            # ES script_score query with a limited result size
            query_body = {
                "size": topk,  # only return the top-K results
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
"params": {
"query_vector": query_emb.tolist()
}
}
}
}
}
result = self.es_client.search(index=self.passages_index, body=query_body)
hits = result.get("hits", {}).get("hits", [])
print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)} (Global Top-{topk})")
# 构建段落分数字典
passage_scores = {}
for i, hit in enumerate(hits):
source = hit["_source"]
passage_id = source.get('passage_id') or source.get('node_id')
if passage_id:
original_score = hit.get('_score', 0.5)
adjusted_score = original_score * weight_adjust
passage_scores[passage_id] = adjusted_score
if i < 3: # 打印前3个段落的分数调整细节
print(f" [SCORE-G] 段落{i+1}: 原始分数={original_score:.4f} * weight_adjust={weight_adjust:.4f} = 调整后分数={adjusted_score:.4f}")
print(f"[SEARCH] Query2Passage(Global-Top{topk})完成: {len(passage_scores)}个段落获得分数")
return passage_scores
except Exception as e:
print(f"[ERROR] 全局Top-K段落检索失败: {e}")
import traceback
traceback.print_exc()
return {}
    def _query2passage_hybrid(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05, global_topk: int = 20) -> Dict[str, float]:
        """
        Hybrid retrieval strategy: filtered retrieval + global top-K.
        Keeps retrieval efficient without losing high-quality passages.
        Args:
            query: query text
            entity_node_ids: list of LLM-filtered entity node IDs
            weight_adjust: weight adjustment factor
            global_topk: top-K size for the global retrieval
        Returns:
            merged passage score dictionary
        """
        print(f"[SEARCH] Query2Passage(Hybrid) started: filtered + global Top-{global_topk}")
        print(f"[PARAM] weight_adjust = {weight_adjust} (scales the passage similarity scores)")
        # 1. Filtered retrieval
        filtered_scores = self._query2passage_filtered(query, entity_node_ids, weight_adjust)
        print(f"    [TARGET] filtered retrieval: {len(filtered_scores)} passages")
        # 2. Global top-K retrieval
        global_scores = self._query2passage_global_topk(query, topk=global_topk, weight_adjust=weight_adjust)
        print(f"    [?] global retrieval: {len(global_scores)} passages")
        # 3. Merge; the filtered results take precedence (on duplicates, keep the filtered score)
        combined_scores = dict(filtered_scores)  # start from the filtered results
        # Add global results that are not already present
        added_count = 0
        for passage_id, score in global_scores.items():
            if passage_id not in combined_scores:
                combined_scores[passage_id] = score
                added_count += 1
        print(f"    [?] merged total: {len(combined_scores)} passages ({added_count} newly added)")
        print("[SEARCH] Query2Passage(Hybrid) finished")
        return combined_scores
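    # Example (illustrative, hypothetical IDs): with filtered_scores = {"p1": 0.09, "p2": 0.08}
    # and global_scores = {"p2": 0.10, "p3": 0.07}, the merge keeps the filtered value for "p2"
    # and only adds "p3", yielding {"p1": 0.09, "p2": 0.08, "p3": 0.07}.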
@traceable(name="Personalized_PageRank")
def _personalized_pagerank(self, personalization_dict: Dict[str, float], alpha: float = 0.85) -> Dict[str, float]:
"""个性化PageRank计算按原始逻辑"""
# 记录输入参数到LangSmith
non_zero_nodes = sum(1 for v in personalization_dict.values() if v > 0)
print(f"[SEARCH] PageRank开始: {non_zero_nodes}/{len(personalization_dict)}个节点有非零分数, alpha={alpha}")
if not personalization_dict or not self.graph_data.nodes():
print("[WARNING] PageRank输入为空返回空结果")
return {}
# 记录PageRank输入详情到LangSmith
pagerank_input_info = {
"total_nodes_in_graph": len(self.graph_data.nodes()),
"total_edges_in_graph": len(self.graph_data.edges()),
"personalization_nodes": len(personalization_dict),
"non_zero_personalization_nodes": non_zero_nodes,
"alpha": alpha,
"max_iter": self.inference_config.ppr_max_iter,
"tolerance": self.inference_config.ppr_tol,
"top_personalization_scores": dict(sorted(personalization_dict.items(), key=lambda x: x[1], reverse=True)[:10])
}
try:
            # Run PageRank with the original personalization dict
ppr_scores = nx.pagerank(
self.graph_data,
alpha=alpha,
personalization=personalization_dict,
max_iter=self.inference_config.ppr_max_iter,
tol=self.inference_config.ppr_tol
)
            # Record the PageRank result for LangSmith
pagerank_result_info = {
**pagerank_input_info,
"pagerank_successful": True,
"total_nodes_with_ppr_scores": len(ppr_scores),
"top_ppr_scores": dict(sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)[:10]),
"ppr_score_statistics": {
"max_score": max(ppr_scores.values()) if ppr_scores else 0,
"min_score": min(ppr_scores.values()) if ppr_scores else 0,
"mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0,
"total_score_sum": sum(ppr_scores.values()) if ppr_scores else 0
}
}
print(f"[SEARCH] PageRank结果记录到LangSmith: {len(ppr_scores)}个节点获得PPR分数")
return ppr_scores
except Exception as e:
print(f"[ERROR] PageRank计算失败: {e}")
return {}
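    # Background: personalized PageRank iterates roughly p <- alpha * (p @ M) + (1 - alpha) * v,
    # where M is the row-normalized adjacency matrix and v is the personalization vector
    # normalized to sum to 1, so the nodes seeded above (edge hits and retrieved passages)
    # pull probability mass toward their graph neighborhoods.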
@traceable(name="HippoRAG_Complete_Retrieval")
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""检索相关文档严格按原始HippoRAG2逻辑"""
if not self.graph_data:
print("[WARNING] 图数据未加载")
return []
print(f"[SEARCH] HippoRAG2检索开始: {query}")
# 1. query2edge - 获取节点分数字典
node_dict = self._query2edge(query, topN=self.inference_config.topk_edges)
print(f"[INFO] 边检索得到 {len(node_dict)} 个节点")
# 2. query2passage - 混合检索策略(过滤版 + 全局Top-K
entity_node_ids = list(node_dict.keys()) # 获取LLM过滤后的节点ID
print(f"[CONFIG] 使用配置参数: weight_adjust={self.inference_config.weight_adjust}, global_topk={self.inference_config.global_topk}")
text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk)
print(f"[?] 段落检索得到 {len(text_dict)} 个段落")
# 3. 合并个性化字典
personalization_dict = {}
# 将所有节点初始化为0
for node in self.graph_data.nodes():
personalization_dict[node] = 0.0
# 添加边检索的节点分数
for node, score in node_dict.items():
if node in personalization_dict:
personalization_dict[node] = score
# 添加段落检索的文本节点分数
for text_id, score in text_dict.items():
if text_id in personalization_dict:
personalization_dict[text_id] = score
print(f"[TARGET] 个性化字典包含 {sum(1 for v in personalization_dict.values() if v > 0)} 个非零节点")
        # 4. Personalized PageRank propagation
        ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha)
        if not ppr_scores:
            print("[WARNING] PageRank computation failed")
            return []
        # 5. Rank the nodes and split them into event nodes and passage nodes
        event_node_scores = []
        text_node_scores = []
        for node_id, ppr_score in ppr_scores.items():
            if node_id in self.graph_data.nodes:
                node_type = self.graph_data.nodes[node_id].get('type', '')
                if node_type == 'event':
                    event_node_scores.append((node_id, ppr_score))
                elif node_id in text_dict:  # text nodes must also appear in the retrieval results
                    text_node_scores.append((node_id, ppr_score))
        # Sort by PageRank score
        event_node_scores.sort(key=lambda x: x[1], reverse=True)
        text_node_scores.sort(key=lambda x: x[1], reverse=True)
        # 6. Build Document objects: top-K events and top-K passages
        documents = []
        # 6.1 Event nodes (top-K)
        for i, (node_id, ppr_score) in enumerate(event_node_scores[:self.inference_config.top_k_events]):
            # Event content is usually stored in the graph data
            try:
                if node_id in self.graph_data.nodes:
                    node_data = self.graph_data.nodes[node_id]
                    # Possible content fields include content, text, description, label, etc.
                    content = (node_data.get('content') or
                               node_data.get('text') or
                               node_data.get('description') or
                               node_data.get('label') or
                               node_data.get('id') or
                               f"Event node {node_id}")
                else:
                    content = f"Event node {node_id}: data not found"
            except Exception as e:
                content = f"Event node {node_id}: failed to fetch content: {str(e)}"
doc = Document(
page_content=content,
metadata={
"node_id": node_id,
"node_type": "event",
"ppr_score": ppr_score,
"edge_score": node_dict.get(node_id, 0.0),
"passage_score": text_dict.get(node_id, 0.0),
"rank": i + 1,
"source": "hipporag2_langchain_event"
}
)
documents.append(doc)
        # 6.2 Passage nodes (top-K)
        for i, (node_id, ppr_score) in enumerate(text_node_scores[:self.inference_config.top_k_passages]):
            # Fetch the passage content from ES
            try:
                query_body = {
                    "query": {
                        "term": {
                            "passage_id": node_id  # or node_id, depending on the storage format
                        }
                    },
                    "size": 1
                }
                result = self.es_client.search(index=self.passages_index, body=query_body)
                if result.get("hits", {}).get("hits"):
                    source = result["hits"]["hits"][0]["_source"]
                    content = source.get("content") or source.get("text") or f"Passage {node_id}: content not found"
                else:
                    # Retry the lookup using node_id
                    query_body["query"]["term"] = {"node_id": node_id}
                    result = self.es_client.search(index=self.passages_index, body=query_body)
                    if result.get("hits", {}).get("hits"):
                        source = result["hits"]["hits"][0]["_source"]
                        content = source.get("content") or source.get("text") or f"Passage {node_id}: content not found"
                    else:
                        content = f"Passage {node_id}: content not found"
            except Exception:
                content = f"Passage {node_id}: content not found"
doc = Document(
page_content=content,
metadata={
"node_id": node_id,
"node_type": "text",
"ppr_score": ppr_score,
"edge_score": node_dict.get(node_id, 0.0),
"passage_score": text_dict.get(node_id, 0.0),
"rank": len(documents) + 1, # 继续排序
"source": "hipporag2_langchain_text"
}
)
documents.append(doc)
        event_count = len(event_node_scores[:self.inference_config.top_k_events])
        text_count = len(text_node_scores[:self.inference_config.top_k_passages])
        print(f"[OK] HippoRAG2 retrieval finished: {event_count} event nodes + {text_count} passage nodes, {len(documents)} documents in total")
        # Do not store the full PageRank scores in metadata, to avoid sending large payloads to LangSmith
        # doc.metadata["complete_ppr_scores"] = ppr_scores  # intentionally disabled to avoid LangSmith transfer
        for doc in documents:
            doc.metadata["query"] = query
            # Store only a summary flag instead of the full PageRank data
            doc.metadata["pagerank_available"] = len(ppr_scores) > 0
return documents
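    # The returned Documents carry the metadata assigned above: node_id, node_type ("event" or
    # "text"), ppr_score, edge_score, passage_score, rank, source, query, and pagerank_available.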
    def get_complete_pagerank_scores(self, query: str) -> Dict[str, Any]:
        """
        Return the complete PageRank score information (all non-zero nodes).
        Args:
            query: query string
        Returns:
            dictionary containing the full PageRank information
        """
        if not self.graph_data:
            print("[WARNING] Graph data is not loaded")
            return {}
        print(f"[SEARCH] Computing complete PageRank scores: {query}")
        # 1. query2edge: get the node score dictionary
        node_dict = self._query2edge(query, topN=self.inference_config.topk_edges)
        # 2. query2passage: hybrid strategy (filtered + global top-K)
        entity_node_ids = list(node_dict.keys())  # node IDs kept by the LLM filter
        text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk)
        # 3. Merge into the personalization dictionary
        personalization_dict = {}
        # Initialize every node to 0
for node in self.graph_data.nodes():
personalization_dict[node] = 0.0
        # Add node scores from edge retrieval
for node, score in node_dict.items():
if node in personalization_dict:
personalization_dict[node] = score
        # Add text-node scores from passage retrieval
for text_id, score in text_dict.items():
if text_id in personalization_dict:
personalization_dict[text_id] = score
        # 4. Personalized PageRank propagation
        ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha)
        if not ppr_scores:
            print("[WARNING] PageRank computation failed")
            return {}
        # 5. Keep nodes with non-zero scores, sorted in descending order
non_zero_scores = [(node_id, score) for node_id, score in ppr_scores.items() if score > 0]
non_zero_scores.sort(key=lambda x: x[1], reverse=True)
        # 6. Assemble the complete score information
return {
"query": query,
"total_nodes_in_graph": len(self.graph_data.nodes()),
"total_edges_in_graph": len(self.graph_data.edges()),
"personalization_input": {
"node_dict_count": len(node_dict),
"text_dict_count": len(text_dict),
"non_zero_personalization_count": sum(1 for v in personalization_dict.values() if v > 0)
},
"pagerank_results": {
"total_nodes_with_scores": len(ppr_scores),
"non_zero_nodes_count": len(non_zero_scores),
"max_score": max(ppr_scores.values()) if ppr_scores else 0,
"min_score": min(ppr_scores.values()) if ppr_scores else 0,
"mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0,
"all_scores": dict(ppr_scores), # 完整的所有节点分数
"non_zero_scores_sorted": non_zero_scores # 非零分数,降序排列
},
"input_scores": {
"node_dict": node_dict,
"text_dict": text_dict
}
}
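# Example (illustrative) usage of get_complete_pagerank_scores above, e.g. to inspect propagation
# without building Document objects:
#     info = retriever.get_complete_pagerank_scores("your question here")
#     top10 = info["pagerank_results"]["non_zero_scores_sorted"][:10]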
def create_langchain_hipporag_retriever(
keyword: str,
top_k: int = 13,
inference_config: Optional[InferenceConfig] = None,
oneapi_key: Optional[str] = None,
oneapi_base_url: Optional[str] = None,
oneapi_model_gen: Optional[str] = None,
oneapi_model_embed: Optional[str] = None
) -> LangChainHippoRAGRetriever:
"""
创建LangChain HippoRAG检索器的便捷函数
Args:
keyword: 关键词
top_k: 返回文档数量
inference_config: 推理配置
oneapi_key: OneAPI密钥
oneapi_base_url: OneAPI基础URL
oneapi_model_gen: 生成模型名称
oneapi_model_embed: 嵌入模型名称
Returns:
LangChainHippoRAGRetriever实例
"""
# 初始化模型
llm_generator = DashScopeLLM(
api_key=oneapi_key,
model_name=oneapi_model_gen
)
embedding_model = DashScopeEmbeddingModel(
api_key=oneapi_key,
model_name=oneapi_model_embed
)
return LangChainHippoRAGRetriever(
llm_generator=llm_generator,
embedding_model=embedding_model,
keyword=keyword,
top_k=top_k,
inference_config=inference_config
)
if __name__ == "__main__":
    # Test code
    from dotenv import load_dotenv
    load_dotenv()
    retriever = create_langchain_hipporag_retriever(
        keyword="test",
        top_k=13
    )
    # Test retrieval
    results = retriever.get_relevant_documents("什么是供应链风险管理?")
    print("\n[TARGET] Retrieval results:")
    for i, doc in enumerate(results, 1):
        print(f"{i}. {doc.metadata['node_id']} (PPR: {doc.metadata['ppr_score']:.4f})")
        print(f"   {doc.page_content[:100]}...")