first commit

闫旭隆
2025-10-17 09:31:28 +08:00
commit 4698145045
589 changed files with 196795 additions and 0 deletions


@@ -0,0 +1,140 @@
from tqdm import tqdm
import networkx as nx
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from logging import Logger
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig
class HippoRAGRetriever(BasePassageRetriever):
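    """HippoRAG-style passage retriever over a knowledge graph.

    Entities extracted from the query by the LLM generator are matched to KG
    nodes (exactly by name, or by embedding similarity) and used as the
    personalization vector for PageRank; the resulting node scores are then
    aggregated per passage through their file ids to rank passages.
    """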
    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict, inference_config: Optional[InferenceConfig] = None, logger: Optional[Logger] = None, **kwargs):
self.passage_dict = data["text_dict"]
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_embeddings = data["node_embeddings"]
self.node_list = data["node_list"]
        # map each passage file id to the passage node ids that belong to it
        file_id_to_node_id = {}
        self.KG = data["KG"]
        for node_id in tqdm(list(self.KG.nodes)):
            node_data = self.KG.nodes[node_id]
            if node_data['type'] == "passage":
                file_id_to_node_id.setdefault(node_data['file_id'], []).append(node_id)
        self.file_id_to_node_id = file_id_to_node_id
        # restrict the KG to the nodes that have embeddings (node_list)
        self.KG: nx.DiGraph = self.KG.subgraph(self.node_list)
        self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]
        self.logger: Logger = logger
        self.logging = logger is not None
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
def retrieve_personalization_dict(self, query, topN=10):
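        """Build the PageRank personalization dict for a query.

        Each selected node is weighted by 1 / (number of documents it appears
        in), so nodes that occur in many documents are down-weighted.
        """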
        # extract entities from the query with the LLM's NER prompt
        entities = self.llm_generator.ner(query)
        entities = [entity.strip() for entity in entities.split(",") if entity.strip()]
        if self.logging:
            self.logger.info(f"HippoRAG NER Entities: {entities}")
        if len(entities) == 0:
            # if NER extracts nothing, fall back to the whole query
            # as a single pseudo-entity for approximate search
            entities = [query]
        # evenly distribute the topk budget across entities
        topk_for_each_entity = topN // len(entities)
        # retrieve the top-k KG nodes for each entity
        topk_nodes = []
        for entity in entities:
            if entity in self.node_name_list:
                # exact name match: map the entity back to its node id
                index = self.node_name_list.index(entity)
                topk_nodes.append(self.node_list[index])
            else:
                # no exact match: take the nearest node by embedding similarity
                topk_for_this_entity = 1
                entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
                scores = self.node_embeddings @ entity_embedding[0].T
                index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
                topk_nodes += [self.node_list[i] for i in index_matrix]
if self.logging:
self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")
        # deduplicate and cap the number of personalization nodes
        topk_nodes = list(set(topk_nodes))
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        # count the number of documents each node appears in
        freq_dict_for_nodes = {}
        for node in topk_nodes:
            node_data = self.KG.nodes[node]
            file_ids = node_data["file_id"]
            file_ids_list = list(set(file_ids.split(",")))
            freq_dict_for_nodes[node] = len(file_ids_list)
        # weight each node by the inverse of its document frequency
        personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
        return personalization_dict
def retrieve(self, query, topN=5, **kwargs):
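        """Return the topN passage texts and passage node ids for the query."""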
        topN_nodes = self.inference_config.topk_nodes
        personalization_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)
        # run personalized PageRank over the KG, pruning negligible scores
        pr = nx.pagerank(self.KG, personalization=personalization_dict)
        for node in pr:
            pr[node] = round(pr[node], 4)
            if pr[node] < 0.001:
                pr[node] = 0
        # aggregate PageRank mass per passage: every node passes its score to
        # the passage nodes of each document (file id) it appears in
        passage_probabilities_sum = {}
        for node in pr:
            node_data = self.KG.nodes[node]
            file_ids_list = list(set(node_data["file_id"].split(",")))
            for file_id in file_ids_list:
                if file_id == 'concept_file':
                    continue
                for node_id in self.file_id_to_node_id[file_id]:
                    if node_id not in passage_probabilities_sum:
                        passage_probabilities_sum[node_id] = 0
                    passage_probabilities_sum[node_id] += pr[node]
        # rank passages by accumulated score and return the topN texts and ids
        sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
        top_passages = sorted_passages[:topN]
        top_passages, scores = zip(*top_passages)
        passage_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
        return passage_contents, top_passages
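
# A minimal end-to-end sketch (illustrative only): `_ToyLLM` and `_ToyEncoder`
# are stand-ins for real LLMGenerator / BaseEmbeddingModel instances, the tiny
# graph below only mimics the node attributes the retriever expects, and the
# InferenceConfig defaults are assumed to provide topk_nodes.
if __name__ == "__main__":
    class _ToyLLM:
        def ner(self, query):
            # the retriever expects a comma-separated string of entities
            return "Alice, Bob"

    class _ToyEncoder:
        def encode(self, texts, query_type="search"):
            # deterministic toy embeddings, one 4-dim row per input text
            return np.array([[len(t) % 7 + 1.0, 1.0, 0.0, 0.0] for t in texts])

    kg = nx.DiGraph()
    kg.add_node("p1", type="passage", file_id="f1", id="passage 1")
    kg.add_node("p2", type="passage", file_id="f2", id="passage 2")
    kg.add_node("e1", type="entity", file_id="f1", id="Alice")
    kg.add_node("e2", type="entity", file_id="f1,f2", id="Bob")
    kg.add_edge("e1", "e2")
    kg.add_edge("e2", "e1")

    data = {
        "KG": kg,
        "node_list": ["e1", "e2"],            # nodes with embeddings
        "node_embeddings": np.ones((2, 4)),   # rows aligned with node_list
        "text_dict": {"p1": "Alice founded Acme.", "p2": "Bob joined Acme later."},
    }
    retriever = HippoRAGRetriever(_ToyLLM(), _ToyEncoder(), data)
    passages, passage_ids = retriever.retrieve("Who founded Acme?", topN=2)
    print(passage_ids, passages)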