first commit
This commit is contained in:
140
AIEC-RAG/atlas_rag/retriever/hipporag.py
Normal file
140
AIEC-RAG/atlas_rag/retriever/hipporag.py
Normal file
@ -0,0 +1,140 @@
|
||||
from tqdm import tqdm
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from logging import Logger
|
||||
from typing import Optional
|
||||
from atlas_rag.retriever.base import BasePassageRetriever
|
||||
from atlas_rag.retriever.inference_config import InferenceConfig
|
||||
|
||||
|
||||
class HippoRAGRetriever(BasePassageRetriever):
|
||||
def __init__(self, llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
|
||||
data:dict, inference_config: Optional[InferenceConfig] = None, logger = None, **kwargs):
|
||||
self.passage_dict = data["text_dict"]
|
||||
self.llm_generator = llm_generator
|
||||
self.sentence_encoder = sentence_encoder
|
||||
self.node_embeddings = data["node_embeddings"]
|
||||
self.node_list = data["node_list"]
|
||||
file_id_to_node_id = {}
|
||||
self.KG = data["KG"]
|
||||
for node_id in tqdm(list(self.KG.nodes)):
|
||||
if self.KG.nodes[node_id]['type'] == "passage":
|
||||
if self.KG.nodes[node_id]['file_id'] not in file_id_to_node_id:
|
||||
file_id_to_node_id[self.KG.nodes[node_id]['file_id']] = []
|
||||
file_id_to_node_id[self.KG.nodes[node_id]['file_id']].append(node_id)
|
||||
self.file_id_to_node_id = file_id_to_node_id
|
||||
|
||||
self.KG:nx.DiGraph = self.KG.subgraph(self.node_list)
|
||||
self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]
|
||||
|
||||
|
||||
self.logger :Logger = logger
|
||||
if self.logger is None:
|
||||
self.logging = False
|
||||
else:
|
||||
self.logging = True
|
||||
|
||||
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
|
||||
|
||||
def retrieve_personalization_dict(self, query, topN=10):
|
||||
|
||||
# extract entities from the query
|
||||
entities = self.llm_generator.ner(query)
|
||||
entities = entities.split(", ")
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG NER Entities: {entities}")
|
||||
# print("Entities:", entities)
|
||||
|
||||
if len(entities) == 0:
|
||||
# If the NER cannot extract any entities, we
|
||||
# use the query as the entity to do approximate search
|
||||
entities = [query]
|
||||
|
||||
# evenly distribute the topk for each entity
|
||||
topk_for_each_entity = topN//len(entities)
|
||||
|
||||
# retrieve the top k nodes
|
||||
topk_nodes = []
|
||||
|
||||
for entity_index, entity in enumerate(entities):
|
||||
if entity in self.node_name_list:
|
||||
# get the index of the entity in the node list
|
||||
index = self.node_name_list.index(entity)
|
||||
topk_nodes.append(self.node_list[index])
|
||||
else:
|
||||
topk_for_this_entity = 1
|
||||
|
||||
# print("Topk for this entity:", topk_for_this_entity)
|
||||
|
||||
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
|
||||
scores = self.node_embeddings@entity_embedding[0].T
|
||||
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
|
||||
|
||||
topk_nodes += [self.node_list[i] for i in index_matrix]
|
||||
|
||||
if self.logging:
|
||||
self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")
|
||||
|
||||
topk_nodes = list(set(topk_nodes))
|
||||
|
||||
# assert len(topk_nodes) <= topN
|
||||
if len(topk_nodes) > 2*topN:
|
||||
topk_nodes = topk_nodes[:2*topN]
|
||||
|
||||
|
||||
# print("Topk nodes:", topk_nodes)
|
||||
# find the number of docs that one work appears in
|
||||
freq_dict_for_nodes = {}
|
||||
for node in topk_nodes:
|
||||
node_data = self.KG.nodes[node]
|
||||
# print(node_data)
|
||||
file_ids = node_data["file_id"]
|
||||
file_ids_list = file_ids.split(",")
|
||||
#uniq this list
|
||||
file_ids_list = list(set(file_ids_list))
|
||||
freq_dict_for_nodes[node] = len(file_ids_list)
|
||||
|
||||
personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
|
||||
|
||||
# print("personalization dict: ")
|
||||
return personalization_dict
|
||||
|
||||
def retrieve(self, query, topN=5, **kwargs):
|
||||
topN_nodes = self.inference_config.topk_nodes
|
||||
personaliation_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)
|
||||
|
||||
# retrieve the top N passages
|
||||
pr = nx.pagerank(self.KG, personalization=personaliation_dict)
|
||||
|
||||
for node in pr:
|
||||
pr[node] = round(pr[node], 4)
|
||||
if pr[node] < 0.001:
|
||||
pr[node] = 0
|
||||
|
||||
passage_probabilities_sum = {}
|
||||
for node in pr:
|
||||
node_data = self.KG.nodes[node]
|
||||
file_ids = node_data["file_id"]
|
||||
# for each file id check through each text_id
|
||||
file_ids_list = file_ids.split(",")
|
||||
#uniq this list
|
||||
file_ids_list = list(set(file_ids_list))
|
||||
# file id to node id
|
||||
|
||||
for file_id in file_ids_list:
|
||||
if file_id == 'concept_file':
|
||||
continue
|
||||
for node_id in self.file_id_to_node_id[file_id]:
|
||||
if node_id not in passage_probabilities_sum:
|
||||
passage_probabilities_sum[node_id] = 0
|
||||
passage_probabilities_sum[node_id] += pr[node]
|
||||
|
||||
sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
|
||||
top_passages = sorted_passages[:topN]
|
||||
top_passages, scores = zip(*top_passages)
|
||||
|
||||
passag_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
|
||||
|
||||
return passag_contents, top_passages
|
||||
Reference in New Issue
Block a user