import json
from logging import Logger
from typing import Optional

import networkx as nx
import numpy as np
import json_repair
from tqdm import tqdm

from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


def min_max_normalize(x):
    """Scale the values of x into [0, 1]."""
    min_val = np.min(x)
    max_val = np.max(x)
    range_val = max_val - min_val
    # If all values are identical the range is zero; return an array of ones
    # with the same shape as x so every candidate keeps equal weight.
    if range_val == 0:
        return np.ones_like(x)
    return (x - min_val) / range_val


class HippoRAG2Retriever(BasePassageRetriever):
    def __init__(self,
                 llm_generator: LLMGenerator,
                 sentence_encoder: BaseEmbeddingModel,
                 data: dict,
                 inference_config: Optional[InferenceConfig] = None,
                 logger: Optional[Logger] = None,
                 **kwargs):
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        # Precomputed artifacts: embeddings, id lists, the FAISS index over
        # edges, the passage store, and the knowledge graph itself.
        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.edge_embeddings = data["edge_embeddings"]
        self.text_embeddings = data["text_embeddings"]
        self.edge_faiss_index = data["edge_faiss_index"]
        self.passage_dict = data["text_dict"]
        self.text_id_list = list(self.passage_dict.keys())
        self.KG = data["KG"]
        # Restrict the graph to entity nodes plus passage nodes.
        self.KG = self.KG.subgraph(self.node_list + self.text_id_list)

        self.logger = logger
        self.logging = logger is not None

        hipporag2mode = "query2edge"
        if hipporag2mode == "query2edge":
            self.retrieve_node_fn = self.query2edge
        elif hipporag2mode == "query2node":
            self.retrieve_node_fn = self.query2node
        elif hipporag2mode == "ner2node":
            self.retrieve_node_fn = self.ner2node
        else:
            raise ValueError(f"Invalid mode: {hipporag2mode}. "
                             "Choose from 'query2edge', 'query2node', or 'ner2node'.")
        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

        # Map every KG node to the file id of the document it came from, so
        # that retrieve() can report source ids alongside passage contents.
        node_id_to_file_id = {}
        for node_id in tqdm(list(self.KG.nodes)):
            if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type'] == "passage":
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
            else:
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
        self.node_id_to_file_id = node_id_to_file_id

    def ner(self, text):
        return self.llm_generator.ner(text)

    def ner2node(self, query, topN=10):
        # Extract entities from the query, then match each entity against the
        # node embeddings. Falls back to the raw query if NER finds nothing.
        entities = self.ner(query)
        entities = [e for e in entities.split(", ") if e]
        if len(entities) == 0:
            entities = [query]

        # Retrieve the top node for each extracted entity.
        topk_nodes = []
        node_score_dict = {}
        for entity in entities:
            topk_for_this_entity = 1
            entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
            scores = min_max_normalize(self.node_embeddings @ entity_embedding[0].T)
            index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
            similarity_matrix = [scores[i] for i in index_matrix]
            for index, sim_score in zip(index_matrix, similarity_matrix):
                node = self.node_list[index]
                if node not in topk_nodes:
                    topk_nodes.append(node)
                    node_score_dict[node] = sim_score

        # topk_nodes is already duplicate-free; cap it at 2 * topN candidates.
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        result_node_score_dict = {}
        for node in topk_nodes:
            if node in node_score_dict:
                result_node_score_dict[node] = node_score_dict[node]
        return result_node_score_dict

    def query2node(self, query, topN=10):
        # Match the whole query against the node embeddings directly.
        query_emb = self.sentence_encoder.encode([query], query_type="entity")
        scores = min_max_normalize(self.node_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        similarity_matrix = [scores[i] for i in index_matrix]

        result_node_score_dict = {}
        for index, sim_score in zip(index_matrix, similarity_matrix):
            node = self.node_list[index]
            result_node_score_dict[node] = sim_score
        return result_node_score_dict

    def query2edge(self, query, topN=10):
        # Rank KG edges against the query, then let the LLM filter out facts
        # that are irrelevant to it.
        query_emb = self.sentence_encoder.encode([query], query_type="edge")
        scores = min_max_normalize(self.edge_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]

        # Construct the candidate fact list as [head, relation, tail] triples.
        before_filter_edge_json = {'fact': []}
        for index in index_matrix:
            edge = self.edge_list[index]
            edge_str = [self.KG.nodes[edge[0]]['id'],
                        self.KG.edges[edge]['relation'],
                        self.KG.nodes[edge[1]]['id']]
            before_filter_edge_json['fact'].append(edge_str)
        if self.logging:
            self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")

        filtered_facts = self.llm_generator.filter_triples_with_entity_event(
            query, json.dumps(before_filter_edge_json, ensure_ascii=False))
        filtered_facts = json_repair.loads(filtered_facts)['fact']
        if len(filtered_facts) == 0:
            return {}

        # Use the filtered facts to look up edge ids via the FAISS index and
        # recover each edge's original similarity score.
        node_score_dict = {}
        log_edge_list = []
        for edge in filtered_facts:
            edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
            search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
            D, I = self.edge_faiss_index.search(search_emb, 1)
            filtered_index = I[0][0]

            # Recover the edge and its original similarity score.
            edge = self.edge_list[filtered_index]
            log_edge_list.append([self.KG.nodes[edge[0]]['id'],
                                  self.KG.edges[edge]['relation'],
                                  self.KG.nodes[edge[1]]['id']])
            head, tail = edge[0], edge[1]
            sim_score = scores[filtered_index]

            # Both endpoints of a matched edge collect the edge's score.
            if head not in node_score_dict:
                node_score_dict[head] = [sim_score]
            else:
                node_score_dict[head].append(sim_score)
            if tail not in node_score_dict:
                node_score_dict[tail] = [sim_score]
            else:
                node_score_dict[tail].append(sim_score)

        if self.logging:
            self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")

        # Average the scores of nodes that appear in several filtered edges.
        for node in node_score_dict:
            node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])
        return node_score_dict

    def query2passage(self, query, weight_adjust=0.05):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        # Normalize, then down-weight passage scores relative to node scores
        # so that graph evidence dominates the personalization vector.
        sim_scores = min_max_normalize(sim_scores) * weight_adjust
        return dict(zip(self.text_id_list, sim_scores))

    def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
        node_dict = self.retrieve_node_fn(query, topN=topN)
        text_dict = self.query2passage(query, weight_adjust=weight_adjust)
        return node_dict, text_dict

    def retrieve(self, query, topN=5, **kwargs):
        topN_edges = self.inference_config.topk_edges
        weight_adjust = self.inference_config.weight_adjust

        node_dict, text_dict = self.retrieve_personalization_dict(
            query, topN=topN_edges, weight_adjust=weight_adjust)

        if len(node_dict) == 0:
            # No graph evidence survived filtering: fall back to returning the
            # topN passages by embedding similarity alone.
            sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
            sorted_passages = sorted_passages[:topN]
            sorted_passages_contents = []
            sorted_passage_ids = []
            for passage_id, score in sorted_passages:
                sorted_passages_contents.append(self.passage_dict[passage_id])
                sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
            return sorted_passages_contents, sorted_passage_ids

        personalization_dict = {}
        personalization_dict.update(node_dict)
        personalization_dict.update(text_dict)

        # Run personalized PageRank seeded by the matched nodes and passages.
        pr = nx.pagerank(self.KG,
                         personalization=personalization_dict,
                         alpha=self.inference_config.ppr_alpha,
                         max_iter=self.inference_config.ppr_max_iter,
                         tol=self.inference_config.ppr_tol)

        # Rank passages by their PageRank score, dropping zero-score nodes.
        text_dict_score = {}
        for node in self.text_id_list:
            if pr[node] > 0.0:
                text_dict_score[node] = pr[node]

        sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
        sorted_passages_ids = sorted_passages_ids[:topN]
        sorted_passages_contents = []
        sorted_passage_ids = []
        for passage_id, score in sorted_passages_ids:
            sorted_passages_contents.append(self.passage_dict[passage_id])
            sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
        return sorted_passages_contents, sorted_passage_ids
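
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library). The stubs below
# are assumptions about the minimal interfaces this class touches --
# `encode(texts, query_type=...)`, `ner(text)`,
# `filter_triples_with_entity_event(query, json_str)`, and a FAISS-style
# `search(query_emb, k)` -- not the real atlas_rag components. They exist
# only to show the shape of the `data` dict the retriever expects.
if __name__ == "__main__":

    class StubEncoder:
        """Deterministic toy embeddings; stands in for BaseEmbeddingModel."""
        def encode(self, texts, query_type=None):
            rng = np.random.default_rng(abs(hash(texts[0])) % (2 ** 32))
            v = rng.normal(size=(1, 8))
            return v / np.linalg.norm(v)

    class StubLLM:
        """Stands in for LLMGenerator; keeps every candidate fact."""
        def ner(self, text):
            return text

        def filter_triples_with_entity_event(self, query, fact_json):
            return fact_json

    class StubIndex:
        """Brute-force stand-in for the FAISS edge index."""
        def __init__(self, embeddings):
            self.embeddings = embeddings

        def search(self, query_emb, k):
            scores = self.embeddings @ query_emb[0]
            idx = np.argsort(scores)[-k:][::-1]
            return scores[idx][None, :], idx[None, :]

    # A two-entity, one-passage toy graph; passage nodes need 'id', 'file_id',
    # and 'type' attributes, edges need a 'relation' attribute.
    KG = nx.DiGraph()
    KG.add_node("n1", id="Alice", file_id="f1", type="entity")
    KG.add_node("n2", id="Paris", file_id="f1", type="entity")
    KG.add_node("t1", id="t1", file_id="doc1", type="passage")
    KG.add_edge("n1", "n2", relation="lives in")
    KG.add_edge("n1", "t1", relation="mentioned in")
    KG.add_edge("n2", "t1", relation="mentioned in")

    encoder = StubEncoder()
    node_list = ["n1", "n2"]
    edge_list = [("n1", "n2")]
    edge_embeddings = np.vstack([encoder.encode(["Alice lives in Paris"])[0]])
    data = {
        "node_embeddings": np.vstack([encoder.encode([n])[0] for n in node_list]),
        "node_list": node_list,
        "edge_list": edge_list,
        "edge_embeddings": edge_embeddings,
        "text_embeddings": np.vstack([encoder.encode(["Alice lives in Paris."])[0]]),
        "edge_faiss_index": StubIndex(edge_embeddings),
        "text_dict": {"t1": "Alice lives in Paris."},
        "KG": KG,
    }

    retriever = HippoRAG2Retriever(StubLLM(), encoder, data)
    contents, ids = retriever.retrieve("Where does Alice live?", topN=1)
    print(contents, ids)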