""" 基于LangChain Elasticsearch集成的HippoRAG2检索器 严格按照原始HippoRAG2逻辑实现: 1. query2edge: 用户查询 → ES边索引向量检索 → LLM过滤 → 节点分数 2. query2passage: 用户查询 → ES段落索引向量检索 → 段落分数 3. 合并个性化字典 → PageRank传播 → 返回排序的文本节点 """ import os import sys import json import numpy as np import networkx as nx from typing import Dict, List, Tuple, Any, Optional from langchain_core.retrievers import BaseRetriever from langchain_core.documents import Document from langchain_core.callbacks import CallbackManagerForRetrieverRun from langchain_elasticsearch import ElasticsearchStore from elasticsearch import Elasticsearch from pydantic import Field import json_repair # LangSmith追踪支持 from langsmith import traceable # 添加项目路径 project_root = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.append(project_root) from retriver.langgraph.dashscope_embedding import DashScopeEmbeddingModel from retriver.langgraph.dashscope_llm import DashScopeLLM from elasticsearch_vectorization.config import ElasticsearchConfig def min_max_normalize(x: np.ndarray) -> np.ndarray: """最小-最大标准化""" if np.max(x) == np.min(x): return np.ones_like(x) return (x - np.min(x)) / (np.max(x) - np.min(x)) class InferenceConfig: """推理配置类""" def __init__(self, topk_edges: int = 10, weight_adjust: float = 0.05, ppr_alpha: float = 0.95, ppr_max_iter: int = 100, ppr_tol: float = 1e-6, top_k_events: int = 10, # 新增:事件节点数量 top_k_passages: int = 3, # 新增:段落节点数量 global_topk: int = 20): # 新增:全局段落检索数量 self.topk_edges = topk_edges self.weight_adjust = weight_adjust self.ppr_alpha = ppr_alpha self.ppr_max_iter = ppr_max_iter self.ppr_tol = ppr_tol self.top_k_events = top_k_events # 新增 self.top_k_passages = top_k_passages # 新增 self.global_topk = global_topk # 新增 class LangChainHippoRAGRetriever(BaseRetriever): """ 基于LangChain ES集成的HippoRAG检索器 使用原生ElasticsearchStore替代自定义HTTP请求 """ # Pydantic字段 top_k: int = Field(default=13, description="返回的文档数量(10个事件+3个段落)") keyword: str = Field(description="关键词,用于索引名称") llm_generator: Any = Field(description="LLM生成器", exclude=True) embedding_model: Any = Field(description="嵌入模型", exclude=True) inference_config: Any = Field(description="推理配置", exclude=True) graph_data: Any = Field(default=None, description="图数据", exclude=True) es_client: Any = Field(default=None, description="ES客户端", exclude=True) edges_index: str = Field(default="", description="边索引名称") passages_index: str = Field(default="", description="段落索引名称") def __init__( self, llm_generator: DashScopeLLM, embedding_model: DashScopeEmbeddingModel, keyword: str, top_k: int = 13, inference_config: Optional[InferenceConfig] = None, **kwargs ): """初始化LangChain HippoRAG检索器""" # 初始化配置 if inference_config is None: inference_config = InferenceConfig() # 初始化Elasticsearch客户端 es_client = Elasticsearch( ElasticsearchConfig.ES_HOST, basic_auth=(ElasticsearchConfig.ES_USERNAME, ElasticsearchConfig.ES_PASSWORD), timeout=30, max_retries=3, retry_on_timeout=True ) # 生成索引名称 edges_index = ElasticsearchConfig.get_edges_index_name(keyword) passages_index = ElasticsearchConfig.get_passages_index_name(keyword) # 调用父类初始化 super().__init__( top_k=top_k, keyword=keyword, llm_generator=llm_generator, embedding_model=embedding_model, inference_config=inference_config, edges_index=edges_index, passages_index=passages_index, graph_data=None, # 将在_load_graph_data中设置 es_client=es_client, **kwargs ) # 加载图数据 self._load_graph_data() print(f"LangChain HippoRAG检索器初始化完成") print(f"边索引: {edges_index}") print(f"段落索引: {passages_index}") def _load_graph_data(self): """加载包含概念节点的图数据""" import pickle graph_path = os.path.join( os.path.dirname(__file__), '..', '..', f'{self.keyword}_with_concept.pkl' ) if os.path.exists(graph_path): with open(graph_path, 'rb') as f: self.graph_data = pickle.load(f) print(f"[OK] 加载包含概念的图数据: {len(self.graph_data.nodes())} 个节点") # 统计节点类型 node_types = {} for node_id, attrs in self.graph_data.nodes(data=True): node_type = attrs.get('type', 'unknown') node_types[node_type] = node_types.get(node_type, 0) + 1 print(f" 节点类型统计: {node_types}") else: print(f"[ERROR] 包含概念的图数据文件不存在: {graph_path}") self.graph_data = None @traceable(name="Query2Edge_Retrieval") def _query2edge(self, query: str, topN: int = 10) -> Dict[str, float]: """查询到边的检索(使用原生ES客户端,按原始逻辑)""" # 记录输入参数到LangSmith print(f"[SEARCH] Query2Edge开始: {query}, topN={topN}") # 1. 查询向量化 query_emb = self.embedding_model.encode([query], query_type="edge")[0] print(f"[SEARCH] DEBUG: 查询向量维度: {query_emb.shape}") # 2. 使用原生ES查询相似边 try: # 检查索引是否存在 if not self.es_client.indices.exists(index=self.edges_index): print(f"[WARNING] 边索引不存在: {self.edges_index}") return {} print(f"[OK] DEBUG: 边索引存在: {self.edges_index}") # ES script_score查询 query_body = { "size": topN * 2, "query": { "script_score": { "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", "params": { "query_vector": query_emb.tolist() } } } } } result = self.es_client.search(index=self.edges_index, body=query_body) hits = result.get("hits", {}).get("hits", []) print(f"[SEARCH] DEBUG: ES返回边数量: {len(hits)}") # 记录ES检索详情到LangSmith es_search_info = { "query": query, "edges_index": self.edges_index, "topN_requested": topN, "es_hits_returned": len(hits), "query_vector_dimension": query_emb.shape[0], "top_edges_preview": [] } # 3. 构建边列表用于LLM过滤 before_filter_edge_json = {"fact": []} valid_edges = [] scores = [] print(f"[SEARCH] DEBUG: 图中节点数量: {len(self.graph_data.nodes()) if self.graph_data else 0}") for i, hit in enumerate(hits): source = hit["_source"] try: head_node_id = source['head_node_id'] tail_node_id = source['tail_node_id'] score = hit.get('_score', 0.5) if i < 3: # 只打印前3个边的详细信息 print(f"[SEARCH] DEBUG: 边{i+1} (分数:{score:.3f}): {head_node_id} -> {tail_node_id}") head_exists = head_node_id in self.graph_data.nodes if self.graph_data else False tail_exists = tail_node_id in self.graph_data.nodes if self.graph_data else False print(f" 头节点存在: {head_exists}, 尾节点存在: {tail_exists}") # 添加到LangSmith监控信息 es_search_info["top_edges_preview"].append({ "rank": i + 1, "score": float(score), "head_entity": source.get('head_entity', ''), "relation": source.get('relation', ''), "tail_entity": source.get('tail_entity', ''), "head_node_id": head_node_id, "tail_node_id": tail_node_id, "head_exists_in_graph": head_exists, "tail_exists_in_graph": tail_exists }) # 检查边的节点是否存在于图中 if self.graph_data and head_node_id in self.graph_data.nodes and tail_node_id in self.graph_data.nodes: edge_str = [source['head_entity'], source['relation'], source['tail_entity']] before_filter_edge_json['fact'].append(edge_str) valid_edges.append(((head_node_id, tail_node_id), source.get('edge_index', 0))) scores.append(hit.get('_score', 0.5)) except KeyError as e: if i < 3: print(f"[SEARCH] DEBUG: 边{i+1} KeyError: {e}") continue print(f"[SEARCH] DEBUG: 有效边数量: {len(valid_edges)}") # 4. 如果没有有效边,返回空字典 if len(before_filter_edge_json['fact']) == 0: print("[WARNING] DEBUG: 没有有效边,返回空字典") return {} # 5. 使用LLM过滤边(提取为独立的可追踪方法) filtered_facts = self._llm_filter_edges(query, before_filter_edge_json) if len(filtered_facts) == 0: return {} # 6. 构建节点分数字典 node_score_dict = {} for edge_fact in filtered_facts: for i, (edge, idx) in enumerate(valid_edges): try: head_id = self.graph_data.nodes[edge[0]].get('id', edge[0]) if self.graph_data else edge[0] tail_id = self.graph_data.nodes[edge[1]].get('id', edge[1]) if self.graph_data else edge[1] relation = before_filter_edge_json['fact'][i][1] if i < len(before_filter_edge_json['fact']) else 'related_to' if head_id == edge_fact[0] and relation == edge_fact[1] and tail_id == edge_fact[2]: head, tail = edge[0], edge[1] sim_score = scores[i] if i < len(scores) else 0.5 # 更新节点分数 if head not in node_score_dict: node_score_dict[head] = sim_score else: node_score_dict[head] = max(node_score_dict[head], sim_score) if tail not in node_score_dict: node_score_dict[tail] = sim_score else: node_score_dict[tail] = max(node_score_dict[tail], sim_score) break except (KeyError, IndexError): continue # 记录最终结果到LangSmith final_result = { **es_search_info, "llm_filter_applied": True, "edges_before_llm_filter": len(before_filter_edge_json['fact']), "edges_after_llm_filter": len(filtered_facts), "llm_filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0, "final_node_scores": len(node_score_dict), "node_score_preview": dict(list(node_score_dict.items())[:5]) if node_score_dict else {}, "edge_to_node_mapping_success": len(node_score_dict) > 0, "valid_edges_found": len(valid_edges), "filtered_facts_sample": filtered_facts[:3] if filtered_facts else [] } print(f"[SEARCH] Query2Edge结果记录到LangSmith: {len(node_score_dict)}个节点获得分数") return node_score_dict except Exception as e: print(f"[ERROR] 边检索失败: {e}") import traceback traceback.print_exc() return {} @traceable(name="LLM_Edge_Filter") def _llm_filter_edges(self, query: str, before_filter_edge_json: Dict) -> List: """ 使用LLM过滤边的三元组,并提供详细监控 Args: query: 用户查询 before_filter_edge_json: 过滤前的边JSON数据 Returns: 过滤后的事实列表 """ print(f"[?] LLM边过滤开始: {len(before_filter_edge_json['fact'])}个边待过滤") # 检查是否跳过边过滤 from prompt_loader import get_prompt_loader prompt_loader = get_prompt_loader() if prompt_loader.should_skip_llm('edge_filter'): print(f"[NEXT] 跳过边过滤LLM调用,使用所有边") # 不过滤,返回所有边 return before_filter_edge_json['fact'] # 记录过滤前的详细信息 filter_input_info = { "query": query, "edges_before_filter": len(before_filter_edge_json['fact']), "sample_edges_before_filter": before_filter_edge_json['fact'][:5], # 前5个边的预览 "all_edges_before_filter": before_filter_edge_json['fact'] # 完整的边列表 } try: # 调用LLM进行过滤 llm_response = self.llm_generator.filter_triples_with_entity_event( query, json.dumps(before_filter_edge_json, ensure_ascii=False) ) # 记录LLM的原始响应 print(f"[?] DEBUG: LLM原始响应: {llm_response[:200]}...") # 解析LLM响应 try: filtered_facts = json_repair.loads(llm_response)['fact'] parse_success = True parse_error = None except Exception as parse_e: print(f"[?] DEBUG: JSON解析失败: {parse_e}") filtered_facts = [] parse_success = False parse_error = str(parse_e) # 记录过滤结果的详细信息 filter_result_info = { **filter_input_info, "llm_raw_response": llm_response, "llm_response_length": len(llm_response), "json_parse_success": parse_success, "json_parse_error": parse_error, "edges_after_filter": len(filtered_facts), "filtered_facts": filtered_facts, "filter_ratio": len(filtered_facts) / len(before_filter_edge_json['fact']) if before_filter_edge_json['fact'] else 0, "filtered_out_count": len(before_filter_edge_json['fact']) - len(filtered_facts), "sample_filtered_facts": filtered_facts[:5] if filtered_facts else [] } print(f"[?] LLM过滤完成: {len(before_filter_edge_json['fact'])} -> {len(filtered_facts)} 个边") print(f"[?] 过滤比例: {filter_result_info['filter_ratio']:.2%}") return filtered_facts except Exception as e: print(f"[ERROR] LLM边过滤失败: {e}") # 记录错误信息到LangSmith error_info = { **filter_input_info, "llm_filter_error": str(e), "edges_after_filter": 0, "filtered_facts": [], "filter_success": False } import traceback traceback.print_exc() return [] @traceable(name="Query2Passage_Retrieval") def _query2passage(self, query: str, weight_adjust: float = 0.05) -> Dict[str, float]: """查询到段落的检索(使用原生ES客户端,按原始逻辑)""" # 记录输入参数到LangSmith print(f"[SEARCH] Query2Passage开始: {query}, weight_adjust={weight_adjust}") # 1. 查询向量化 query_emb = self.embedding_model.encode([query], query_type="passage")[0] print(f"[SEARCH] DEBUG: 段落查询向量维度: {query_emb.shape}") # 2. 使用原生ES查询相似段落 try: # 检查索引是否存在 if not self.es_client.indices.exists(index=self.passages_index): print(f"[WARNING] 段落索引不存在: {self.passages_index}") return {} # ES script_score查询 query_body = { "size": 1000, # 获取足够多的段落 "query": { "script_score": { "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", "params": { "query_vector": query_emb.tolist() } } } } } result = self.es_client.search(index=self.passages_index, body=query_body) hits = result.get("hits", {}).get("hits", []) print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)}") # 记录ES检索详情到LangSmith es_search_info = { "query": query, "passages_index": self.passages_index, "weight_adjust": weight_adjust, "es_hits_returned": len(hits), "query_vector_dimension": query_emb.shape[0], "top_passages_preview": [] } # 3. 构建段落分数字典 passage_scores = {} for i, hit in enumerate(hits): source = hit["_source"] passage_id = source.get('passage_id') or source.get('node_id') if passage_id: original_score = hit.get('_score', 0.5) adjusted_score = original_score * weight_adjust # 应用权重调整 passage_scores[passage_id] = adjusted_score # 收集前几个段落的预览信息 if i < 5: # 收集前5个段落的详细信息 es_search_info["top_passages_preview"].append({ "rank": i + 1, "passage_id": passage_id, "original_score": float(original_score), "adjusted_score": float(adjusted_score), "content_preview": source.get("content", source.get("text", ""))[:100] + "..." if source.get("content") or source.get("text") else "内容为空" }) # 记录最终结果到LangSmith final_result = { **es_search_info, "final_passage_scores": len(passage_scores), "passage_score_preview": dict(list(passage_scores.items())[:5]) if passage_scores else {}, "total_score_sum": sum(passage_scores.values()) if passage_scores else 0 } print(f"[SEARCH] Query2Passage结果记录到LangSmith: {len(passage_scores)}个段落获得分数") return passage_scores except Exception as e: print(f"[ERROR] 段落检索失败: {e}") import traceback traceback.print_exc() return {} def _find_connected_text_nodes(self, entity_node_ids: List[str]) -> List[str]: """ 根据实体/事件/概念节点ID查找连接的Text节点 Args: entity_node_ids: 实体节点ID列表 Returns: 连接的Text节点ID列表 """ text_node_ids = set() if not self.graph_data or not entity_node_ids: return [] print(f"[SEARCH] 查找 {len(entity_node_ids)} 个节点的连接Text节点") for entity_id in entity_node_ids: if entity_id in self.graph_data.nodes: neighbors = list(self.graph_data.neighbors(entity_id)) for neighbor_id in neighbors: neighbor_data = self.graph_data.nodes[neighbor_id] if neighbor_data.get('type') == 'text': text_node_ids.add(neighbor_id) result = list(text_node_ids) print(f"[SEARCH] 找到 {len(result)} 个相关的Text节点") return result def _query2passage_filtered(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05) -> Dict[str, float]: """ 基于LLM过滤后的实体节点进行段落检索 只对与这些实体节点相关的Text节点进行向量相似度计算 Args: query: 查询文本 entity_node_ids: LLM过滤后的实体节点ID列表 weight_adjust: 权重调整因子 Returns: 段落分数字典 """ print(f"[SEARCH] Query2Passage(Filtered)开始: {query}, 基于{len(entity_node_ids)}个节点") print(f"[PARAM] weight_adjust = {weight_adjust}") # 1. 查找相关的Text节点 text_node_ids = self._find_connected_text_nodes(entity_node_ids) if not text_node_ids: print("[WARNING] 没有找到相关的Text节点") return {} # 2. 查询向量化 query_emb = self.embedding_model.encode([query], query_type="passage")[0] print(f"[SEARCH] DEBUG: 段落查询向量维度: {query_emb.shape}") # 3. 使用原生ES查询相关Text节点的段落 try: # 检查索引是否存在 if not self.es_client.indices.exists(index=self.passages_index): print(f"[WARNING] 段落索引不存在: {self.passages_index}") return {} # 构建过滤查询:只搜索指定的text_id query_body = { "size": min(1000, len(text_node_ids)), # 不超过text节点数量 "query": { "script_score": { "query": { "terms": { "passage_id": text_node_ids # 只查询指定的text节点 } }, "script": { "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", "params": { "query_vector": query_emb.tolist() } } } } } result = self.es_client.search(index=self.passages_index, body=query_body) hits = result.get("hits", {}).get("hits", []) print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)} (从{len(text_node_ids)}个候选Text节点中筛选)") # 4. 构建段落分数字典 passage_scores = {} for i, hit in enumerate(hits): source = hit["_source"] passage_id = source.get('passage_id') or source.get('node_id') if passage_id: original_score = hit.get('_score', 0.5) adjusted_score = original_score * weight_adjust passage_scores[passage_id] = adjusted_score if i < 3: # 打印前3个段落的分数调整细节 print(f" [SCORE-F] 段落{i+1}: 原始分数={original_score:.4f} * weight_adjust={weight_adjust:.4f} = 调整后分数={adjusted_score:.4f}") print(f"[SEARCH] Query2Passage(Filtered)完成: {len(passage_scores)}个段落获得分数") return passage_scores except Exception as e: print(f"[ERROR] 过滤段落检索失败: {e}") import traceback traceback.print_exc() return {} def _query2passage_global_topk(self, query: str, topk: int = 20, weight_adjust: float = 0.05) -> Dict[str, float]: """ 全局Top-K段落检索 从所有段落中返回最相似的Top-K个段落 Args: query: 查询文本 topk: 返回的段落数量 weight_adjust: 权重调整因子 Returns: 段落分数字典 """ print(f"[SEARCH] Query2Passage(Global-Top{topk})开始: {query}") print(f"[PARAM] weight_adjust = {weight_adjust}") # 查询向量化 query_emb = self.embedding_model.encode([query], query_type="passage")[0] print(f"[SEARCH] DEBUG: 段落查询向量维度: {query_emb.shape}") try: # 检查索引是否存在 if not self.es_client.indices.exists(index=self.passages_index): print(f"[WARNING] 段落索引不存在: {self.passages_index}") return {} # ES script_score查询,限制返回数量 query_body = { "size": topk, # 只返回Top-K个结果 "query": { "script_score": { "query": {"match_all": {}}, "script": { "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0", "params": { "query_vector": query_emb.tolist() } } } } } result = self.es_client.search(index=self.passages_index, body=query_body) hits = result.get("hits", {}).get("hits", []) print(f"[SEARCH] DEBUG: ES返回段落数量: {len(hits)} (Global Top-{topk})") # 构建段落分数字典 passage_scores = {} for i, hit in enumerate(hits): source = hit["_source"] passage_id = source.get('passage_id') or source.get('node_id') if passage_id: original_score = hit.get('_score', 0.5) adjusted_score = original_score * weight_adjust passage_scores[passage_id] = adjusted_score if i < 3: # 打印前3个段落的分数调整细节 print(f" [SCORE-G] 段落{i+1}: 原始分数={original_score:.4f} * weight_adjust={weight_adjust:.4f} = 调整后分数={adjusted_score:.4f}") print(f"[SEARCH] Query2Passage(Global-Top{topk})完成: {len(passage_scores)}个段落获得分数") return passage_scores except Exception as e: print(f"[ERROR] 全局Top-K段落检索失败: {e}") import traceback traceback.print_exc() return {} def _query2passage_hybrid(self, query: str, entity_node_ids: List[str], weight_adjust: float = 0.05, global_topk: int = 20) -> Dict[str, float]: """ 混合检索策略:过滤版 + 全局Top-K 确保既有高效性又不遗失高质量段落 Args: query: 查询文本 entity_node_ids: LLM过滤后的实体节点ID列表 weight_adjust: 权重调整因子 global_topk: 全局检索的Top-K数量 Returns: 合并后的段落分数字典 """ print(f"[SEARCH] Query2Passage(Hybrid)开始: 过滤版+全局Top-{global_topk}") print(f"[PARAM] weight_adjust = {weight_adjust} (用于调整段落相似度分数)") # 1. 过滤版检索 filtered_scores = self._query2passage_filtered(query, entity_node_ids, weight_adjust) print(f" [TARGET] 过滤版获得: {len(filtered_scores)}个段落") # 2. 全局Top-K检索 global_scores = self._query2passage_global_topk(query, topk=global_topk, weight_adjust=weight_adjust) print(f" [?] 全局检索获得: {len(global_scores)}个段落") # 3. 合并结果,过滤版优先(如果有重复,保持过滤版的分数) combined_scores = dict(filtered_scores) # 先添加过滤版结果 # 添加全局检索中不重复的结果 added_count = 0 for passage_id, score in global_scores.items(): if passage_id not in combined_scores: combined_scores[passage_id] = score added_count += 1 print(f" [?] 合并后总计: {len(combined_scores)}个段落 (新增{added_count}个)") print(f"[SEARCH] Query2Passage(Hybrid)完成") return combined_scores @traceable(name="Personalized_PageRank") def _personalized_pagerank(self, personalization_dict: Dict[str, float], alpha: float = 0.85) -> Dict[str, float]: """个性化PageRank计算(按原始逻辑)""" # 记录输入参数到LangSmith non_zero_nodes = sum(1 for v in personalization_dict.values() if v > 0) print(f"[SEARCH] PageRank开始: {non_zero_nodes}/{len(personalization_dict)}个节点有非零分数, alpha={alpha}") if not personalization_dict or not self.graph_data.nodes(): print("[WARNING] PageRank输入为空,返回空结果") return {} # 记录PageRank输入详情到LangSmith pagerank_input_info = { "total_nodes_in_graph": len(self.graph_data.nodes()), "total_edges_in_graph": len(self.graph_data.edges()), "personalization_nodes": len(personalization_dict), "non_zero_personalization_nodes": non_zero_nodes, "alpha": alpha, "max_iter": self.inference_config.ppr_max_iter, "tolerance": self.inference_config.ppr_tol, "top_personalization_scores": dict(sorted(personalization_dict.items(), key=lambda x: x[1], reverse=True)[:10]) } try: # 使用原始的个性化字典进行PageRank计算 ppr_scores = nx.pagerank( self.graph_data, alpha=alpha, personalization=personalization_dict, max_iter=self.inference_config.ppr_max_iter, tol=self.inference_config.ppr_tol ) # 记录PageRank结果到LangSmith pagerank_result_info = { **pagerank_input_info, "pagerank_successful": True, "total_nodes_with_ppr_scores": len(ppr_scores), "top_ppr_scores": dict(sorted(ppr_scores.items(), key=lambda x: x[1], reverse=True)[:10]), "ppr_score_statistics": { "max_score": max(ppr_scores.values()) if ppr_scores else 0, "min_score": min(ppr_scores.values()) if ppr_scores else 0, "mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0, "total_score_sum": sum(ppr_scores.values()) if ppr_scores else 0 } } print(f"[SEARCH] PageRank结果记录到LangSmith: {len(ppr_scores)}个节点获得PPR分数") return ppr_scores except Exception as e: print(f"[ERROR] PageRank计算失败: {e}") return {} @traceable(name="HippoRAG_Complete_Retrieval") def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: """检索相关文档(严格按原始HippoRAG2逻辑)""" if not self.graph_data: print("[WARNING] 图数据未加载") return [] print(f"[SEARCH] HippoRAG2检索开始: {query}") # 1. query2edge - 获取节点分数字典 node_dict = self._query2edge(query, topN=self.inference_config.topk_edges) print(f"[INFO] 边检索得到 {len(node_dict)} 个节点") # 2. query2passage - 混合检索策略(过滤版 + 全局Top-K) entity_node_ids = list(node_dict.keys()) # 获取LLM过滤后的节点ID print(f"[CONFIG] 使用配置参数: weight_adjust={self.inference_config.weight_adjust}, global_topk={self.inference_config.global_topk}") text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk) print(f"[?] 段落检索得到 {len(text_dict)} 个段落") # 3. 合并个性化字典 personalization_dict = {} # 将所有节点初始化为0 for node in self.graph_data.nodes(): personalization_dict[node] = 0.0 # 添加边检索的节点分数 for node, score in node_dict.items(): if node in personalization_dict: personalization_dict[node] = score # 添加段落检索的文本节点分数 for text_id, score in text_dict.items(): if text_id in personalization_dict: personalization_dict[text_id] = score print(f"[TARGET] 个性化字典包含 {sum(1 for v in personalization_dict.values() if v > 0)} 个非零节点") # 4. 个性化PageRank传播 ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha) if not ppr_scores: print("[WARNING] PageRank计算失败") return [] # 5. 节点排序并分别筛选事件节点和段落节点 event_node_scores = [] text_node_scores = [] for node_id, ppr_score in ppr_scores.items(): if node_id in self.graph_data.nodes: node_type = self.graph_data.nodes[node_id].get('type', '') if node_type == 'event': event_node_scores.append((node_id, ppr_score)) elif node_id in text_dict: # 确保是文本节点且在检索结果中 text_node_scores.append((node_id, ppr_score)) # 按PageRank分数排序 event_node_scores.sort(key=lambda x: x[1], reverse=True) text_node_scores.sort(key=lambda x: x[1], reverse=True) # 6. 构建Document对象:选择TOP-K事件和TOP-K段落 documents = [] print(f"[DEBUG] 准备构建文档: {len(event_node_scores)}个事件节点候选, {len(text_node_scores)}个段落节点候选") print(f"[DEBUG] 将选择TOP-{self.inference_config.top_k_events}事件 + TOP-{self.inference_config.top_k_passages}段落") # 6.1 处理事件节点 (TOP-K) for i, (node_id, ppr_score) in enumerate(event_node_scores[:self.inference_config.top_k_events]): # 事件节点的内容通常存储在图数据中 try: if node_id in self.graph_data.nodes: node_data = self.graph_data.nodes[node_id] # 获取事件内容,可能的字段包括:content, text, description, label等 content = (node_data.get('content') or node_data.get('text') or node_data.get('description') or node_data.get('label') or node_data.get('id') or f"事件节点 {node_id}") # 获取事件节点的source_text_id属性 source_text_id = node_data.get('source_text_id', '') # 获取事件节点的node_id属性(图中存储的节点ID) graph_node_id = node_data.get('node_id', node_id) # 获取事件节点的evidence属性 evidence = node_data.get('evidence', '') else: content = f"事件节点 {node_id} 数据未找到" source_text_id = '' graph_node_id = node_id evidence = '' except Exception as e: content = f"事件节点 {node_id} 内容获取失败: {str(e)}" source_text_id = '' graph_node_id = node_id evidence = '' doc = Document( page_content=content, metadata={ "node_id": node_id, # 检索时使用的节点ID "graph_node_id": graph_node_id, # 图中存储的节点ID属性 "node_type": "event", "ppr_score": ppr_score, "edge_score": node_dict.get(node_id, 0.0), "passage_score": text_dict.get(node_id, 0.0), "rank": i + 1, "source": "hipporag2_langchain_event", "source_text_id": source_text_id, # 新增:事件节点的source_text_id属性 "evidence": evidence # 新增:事件节点的evidence属性 } ) documents.append(doc) # 6.2 处理段落节点 (TOP-K) for i, (node_id, ppr_score) in enumerate(text_node_scores[:self.inference_config.top_k_passages]): # 直接从图数据中获取段落内容 try: if node_id in self.graph_data.nodes: node_data = self.graph_data.nodes[node_id] # 强化调试:打印节点的所有字段,帮助找到正确的文本字段 print(f"[DEBUG] ===== 段落节点调试信息 =====") print(f"[DEBUG] 节点ID: {node_id}") print(f"[DEBUG] 节点类型: {node_data.get('type', 'unknown')}") print(f"[DEBUG] 所有字段名: {sorted(list(node_data.keys()))}") # 打印可能包含文本的字段值(截取前100字符) text_fields = ['original_text', 'text', 'content', 'passage_text', 'full_text', 'raw_text', 'description', 'label'] for field in text_fields: if field in node_data: field_value = str(node_data[field]) preview = field_value[:100] + "..." if len(field_value) > 100 else field_value print(f"[DEBUG] {field}: {preview}") print(f"[DEBUG] ===== 结束调试信息 =====") # 尝试多个可能的文本字段 content = (node_data.get('original_text') or node_data.get('text') or node_data.get('content') or node_data.get('passage_text') or node_data.get('full_text') or node_data.get('raw_text') or node_data.get('description') or node_data.get('label') or f"段落 {node_id} 内容未找到") # 如果仍然没找到内容,打印完整的节点数据 if content == f"段落 {node_id} 内容未找到": print(f"[DEBUG] !!! 段落节点 {node_id} 完整数据: {dict(node_data)}") print(f"[DEBUG] !!! 节点数据类型: {type(node_data)}") # 获取段落节点的evidence属性 evidence = node_data.get('evidence', '') else: content = f"段落 {node_id} 数据未找到" evidence = '' except Exception as e: content = f"段落 {node_id} 内容获取失败: {str(e)}" evidence = '' print(f"[DEBUG] 段落节点 {node_id} 获取异常: {str(e)}") doc = Document( page_content=content, metadata={ "node_id": node_id, "node_type": "text", "ppr_score": ppr_score, "edge_score": node_dict.get(node_id, 0.0), "passage_score": text_dict.get(node_id, 0.0), "rank": len(documents) + 1, # 继续排序 "source": "hipporag2_langchain_text", "evidence": evidence # 新增:段落节点的evidence属性 } ) documents.append(doc) event_count = len(event_node_scores[:self.inference_config.top_k_events]) text_count = len(text_node_scores[:self.inference_config.top_k_passages]) print(f"[OK] HippoRAG2检索完成,返回 {event_count} 个事件节点 + {text_count} 个段落节点,共 {len(documents)} 个文档") # 不在metadata中存储PageRank分数以避免LangSmith追踪大量数据 # doc.metadata["complete_ppr_scores"] = ppr_scores # 注释掉避免LangSmith传输 for doc in documents: doc.metadata["query"] = query # 只存储统计信息而不是完整的PageRank数据 doc.metadata["pagerank_available"] = len(ppr_scores) > 0 return documents def get_complete_pagerank_scores(self, query: str) -> Dict[str, Any]: """ 获取完整的PageRank分数信息(所有非零节点) Args: query: 查询字符串 Returns: 包含完整PageRank信息的字典 """ if not self.graph_data: print("[WARNING] 图数据未加载") return {} print(f"[SEARCH] 获取完整PageRank分数: {query}") # 1. query2edge - 获取节点分数字典 node_dict = self._query2edge(query, topN=self.inference_config.topk_edges) # 2. query2passage - 混合检索策略(过滤版 + 全局Top-K) entity_node_ids = list(node_dict.keys()) # 获取LLM过滤后的节点ID text_dict = self._query2passage_hybrid(query, entity_node_ids, weight_adjust=self.inference_config.weight_adjust, global_topk=self.inference_config.global_topk) # 3. 合并个性化字典 personalization_dict = {} # 将所有节点初始化为0 for node in self.graph_data.nodes(): personalization_dict[node] = 0.0 # 添加边检索的节点分数 for node, score in node_dict.items(): if node in personalization_dict: personalization_dict[node] = score # 添加段落检索的文本节点分数 for text_id, score in text_dict.items(): if text_id in personalization_dict: personalization_dict[text_id] = score # 4. 个性化PageRank传播 ppr_scores = self._personalized_pagerank(personalization_dict, self.inference_config.ppr_alpha) if not ppr_scores: print("[WARNING] PageRank计算失败") return {} # 5. 筛选非零分数的节点,按分数降序排列 non_zero_scores = [(node_id, score) for node_id, score in ppr_scores.items() if score > 0] non_zero_scores.sort(key=lambda x: x[1], reverse=True) # 6. 构建完整的分数信息 return { "query": query, "total_nodes_in_graph": len(self.graph_data.nodes()), "total_edges_in_graph": len(self.graph_data.edges()), "personalization_input": { "node_dict_count": len(node_dict), "text_dict_count": len(text_dict), "non_zero_personalization_count": sum(1 for v in personalization_dict.values() if v > 0) }, "pagerank_results": { "total_nodes_with_scores": len(ppr_scores), "non_zero_nodes_count": len(non_zero_scores), "max_score": max(ppr_scores.values()) if ppr_scores else 0, "min_score": min(ppr_scores.values()) if ppr_scores else 0, "mean_score": sum(ppr_scores.values()) / len(ppr_scores) if ppr_scores else 0, "all_scores": dict(ppr_scores), # 完整的所有节点分数 "non_zero_scores_sorted": non_zero_scores # 非零分数,降序排列 }, "input_scores": { "node_dict": node_dict, "text_dict": text_dict } } def create_langchain_hipporag_retriever( keyword: str, top_k: int = 13, inference_config: Optional[InferenceConfig] = None, oneapi_key: Optional[str] = None, oneapi_base_url: Optional[str] = None, oneapi_model_gen: Optional[str] = None, oneapi_model_embed: Optional[str] = None ) -> LangChainHippoRAGRetriever: """ 创建LangChain HippoRAG检索器的便捷函数 Args: keyword: 关键词 top_k: 返回文档数量 inference_config: 推理配置 oneapi_key: OneAPI密钥 oneapi_base_url: OneAPI基础URL oneapi_model_gen: 生成模型名称 oneapi_model_embed: 嵌入模型名称 Returns: LangChainHippoRAGRetriever实例 """ # 初始化模型 llm_generator = DashScopeLLM( api_key=oneapi_key, model_name=oneapi_model_gen ) embedding_model = DashScopeEmbeddingModel( api_key=oneapi_key, model_name=oneapi_model_embed ) return LangChainHippoRAGRetriever( llm_generator=llm_generator, embedding_model=embedding_model, keyword=keyword, top_k=top_k, inference_config=inference_config ) if __name__ == "__main__": # 测试代码 from dotenv import load_dotenv load_dotenv() retriever = create_langchain_hipporag_retriever( keyword="test", top_k=13 ) # 测试检索 results = retriever.get_relevant_documents("什么是供应链风险管理?") print(f"\n[TARGET] 检索结果:") for i, doc in enumerate(results, 1): print(f"{i}. {doc.metadata['node_id']} (PPR: {doc.metadata['ppr_score']:.4f})") print(f" {doc.page_content[:100]}...")