first commit
This commit is contained in:
237
atlas_rag/retriever/hipporag2.py
Normal file
237
atlas_rag/retriever/hipporag2.py
Normal file
@ -0,0 +1,237 @@
|
||||
import networkx as nx
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Tuple
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import json_repair
|
||||
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||||
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||||
from logging import Logger
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from atlas_rag.retriever.base import BasePassageRetriever
|
||||
from atlas_rag.retriever.inference_config import InferenceConfig
|
||||
|
||||
def min_max_normalize(x):
    """Scale the values of *x* into the range [0, 1].

    If every element of *x* is identical (zero range), an array of ones
    with the same shape as *x* is returned instead of dividing by zero.
    """
    lo = np.min(x)
    hi = np.max(x)
    spread = hi - lo

    if spread == 0:
        # Degenerate case: all scores equal -> treat every score as maximal.
        return np.ones_like(x)

    return (x - lo) / spread
||||
class HippoRAG2Retriever(BasePassageRetriever):
    """HippoRAG-2 style passage retriever over a knowledge graph.

    Combines dense node/edge/passage retrieval with personalized PageRank
    over the KG to rank passages for a query.
    """

    def __init__(self, llm_generator: LLMGenerator,
                 sentence_encoder: BaseEmbeddingModel,
                 data: dict,
                 inference_config: Optional[InferenceConfig] = None,
                 logger=None,
                 **kwargs):
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        # Precomputed artifacts produced by the indexing pipeline.
        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.edge_embeddings = data["edge_embeddings"]
        self.text_embeddings = data["text_embeddings"]
        self.edge_faiss_index = data["edge_faiss_index"]
        self.passage_dict = data["text_dict"]
        self.text_id_list = list(self.passage_dict.keys())
        # Restrict the KG to the entity nodes plus the passage nodes.
        self.KG = data["KG"]
        self.KG = self.KG.subgraph(self.node_list + self.text_id_list)

        self.logger = logger
        # self.logging gates every logger call below.
        self.logging = logger is not None

        # Node-retrieval strategy; query2edge is the HippoRAG-2 default.
        hipporag2mode = "query2edge"
        if hipporag2mode == "query2edge":
            self.retrieve_node_fn = self.query2edge
        elif hipporag2mode == "query2node":
            self.retrieve_node_fn = self.query2node
        elif hipporag2mode == "ner2node":
            self.retrieve_node_fn = self.ner2node
        else:
            raise ValueError(f"Invalid mode: {hipporag2mode}. Choose from 'query2edge', 'query2node', or 'query2passage'.")

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

        # Map each KG node to the identifier reported back to callers.
        # NOTE(review): musique passage nodes use "id" while everything else
        # uses "file_id" — presumably a dataset-specific naming quirk.
        node_id_to_file_id = {}
        for node_id in tqdm(list(self.KG.nodes)):
            if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type'] == "passage":
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
            else:
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
        self.node_id_to_file_id = node_id_to_file_id
||||
def ner(self, text):
|
||||
return self.llm_generator.ner(text)
|
||||
|
||||
def ner2node(self, query, topN = 10):
|
||||
entities = self.ner(query)
|
||||
entities = entities.split(", ")
|
||||
|
||||
if len(entities) == 0:
|
||||
entities = [query]
|
||||
# retrieve the top k nodes
|
||||
topk_nodes = []
|
||||
node_score_dict = {}
|
||||
for entity_index, entity in enumerate(entities):
|
||||
topk_for_this_entity = 1
|
||||
entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
|
||||
scores = min_max_normalize(self.node_embeddings@entity_embedding[0].T)
|
||||
index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
|
||||
similarity_matrix = [scores[i] for i in index_matrix]
|
||||
for index, sim_score in zip(index_matrix, similarity_matrix):
|
||||
node = self.node_list[index]
|
||||
if node not in topk_nodes:
|
||||
topk_nodes.append(node)
|
||||
node_score_dict[node] = sim_score
|
||||
|
||||
topk_nodes = list(set(topk_nodes))
|
||||
result_node_score_dict = {}
|
||||
if len(topk_nodes) > 2*topN:
|
||||
topk_nodes = topk_nodes[:2*topN]
|
||||
for node in topk_nodes:
|
||||
if node in node_score_dict:
|
||||
result_node_score_dict[node] = node_score_dict[node]
|
||||
return result_node_score_dict
|
||||
|
||||
def query2node(self, query, topN = 10):
|
||||
query_emb = self.sentence_encoder.encode([query], query_type="entity")
|
||||
scores = min_max_normalize(self.node_embeddings@query_emb[0].T)
|
||||
index_matrix = np.argsort(scores)[-topN:][::-1]
|
||||
similarity_matrix = [scores[i] for i in index_matrix]
|
||||
result_node_score_dict = {}
|
||||
for index, sim_score in zip(index_matrix, similarity_matrix):
|
||||
node = self.node_list[index]
|
||||
result_node_score_dict[node] = sim_score
|
||||
|
||||
return result_node_score_dict
|
||||
|
||||
    def query2edge(self, query, topN = 10):
        # Retrieve topN candidate edges by dense similarity, ask the LLM to
        # filter them down to facts relevant to the query, then spread each
        # surviving edge's similarity score onto its head and tail nodes.
        # Returns: dict node_id -> averaged similarity score ({} if the LLM
        # rejects every candidate fact).
        query_emb = self.sentence_encoder.encode([query], query_type="edge")
        scores = min_max_normalize(self.edge_embeddings@query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        log_edge_list = []
        for index in index_matrix:
            edge = self.edge_list[index]
            # Human-readable [head, relation, tail] triple for logging.
            edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
            log_edge_list.append(edge_str)

        similarity_matrix = [scores[i] for i in index_matrix]
        # construct the edge list
        before_filter_edge_json = {}
        before_filter_edge_json['fact'] = []
        for index, sim_score in zip(index_matrix, similarity_matrix):
            edge = self.edge_list[index]
            edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
            before_filter_edge_json['fact'].append(edge_str)
        if self.logging:
            self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")
        # LLM keeps only triples relevant to the query; json_repair tolerates
        # slightly malformed JSON in the LLM response.
        filtered_facts = self.llm_generator.filter_triples_with_entity_event(query, json.dumps(before_filter_edge_json, ensure_ascii=False))
        filtered_facts = json_repair.loads(filtered_facts)['fact']
        if len(filtered_facts) == 0:
            return {}
        # use filtered facts to get the edge id and check if it exists in the original candidate list.
        node_score_dict = {}
        log_edge_list = []
        for edge in filtered_facts:
            edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
            # Map the (possibly paraphrased) fact back to a concrete KG edge
            # via nearest-neighbour search (k=1) over the edge FAISS index.
            search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
            D, I = self.edge_faiss_index.search(search_emb, 1)
            filtered_index = I[0][0]
            # get the edge and the original score
            edge = self.edge_list[filtered_index]
            log_edge_list.append([self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']])
            head, tail = edge[0], edge[1]
            sim_score = scores[filtered_index]

            # Both endpoints inherit the edge's similarity score; a node hit
            # by several edges accumulates a list of scores.
            if head not in node_score_dict:
                node_score_dict[head] = [sim_score]
            else:
                node_score_dict[head].append(sim_score)
            if tail not in node_score_dict:
                node_score_dict[tail] = [sim_score]
            else:
                node_score_dict[tail].append(sim_score)
        # average the scores
        if self.logging:
            self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")

        # take average of the scores
        for node in node_score_dict:
            node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])

        return node_score_dict
||||
def query2passage(self, query, weight_adjust = 0.05):
|
||||
query_emb = self.sentence_encoder.encode([query], query_type="passage")
|
||||
sim_scores = self.text_embeddings @ query_emb[0].T
|
||||
sim_scores = min_max_normalize(sim_scores)*weight_adjust # converted to probability
|
||||
# create dict of passage id and score
|
||||
return dict(zip(self.text_id_list, sim_scores))
|
||||
|
||||
def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
|
||||
node_dict = self.retrieve_node_fn(query, topN=topN)
|
||||
text_dict = self.query2passage(query, weight_adjust=weight_adjust)
|
||||
|
||||
return node_dict, text_dict
|
||||
|
||||
def retrieve(self, query, topN=5, **kwargs):
|
||||
topN_edges = self.inference_config.topk_edges
|
||||
weight_adjust = self.inference_config.weight_adjust
|
||||
|
||||
node_dict, text_dict = self.retrieve_personalization_dict(query, topN=topN_edges, weight_adjust=weight_adjust)
|
||||
|
||||
personalization_dict = {}
|
||||
if len(node_dict) == 0:
|
||||
# return topN text passages
|
||||
sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
sorted_passages = sorted_passages[:topN]
|
||||
sorted_passages_contents = []
|
||||
sorted_scores = []
|
||||
sorted_passage_ids = []
|
||||
for passage_id, score in sorted_passages:
|
||||
sorted_passages_contents.append(self.passage_dict[passage_id])
|
||||
sorted_scores.append(float(score))
|
||||
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
|
||||
return sorted_passages_contents, sorted_passage_ids
|
||||
|
||||
personalization_dict.update(node_dict)
|
||||
personalization_dict.update(text_dict)
|
||||
# retrieve the top N passages
|
||||
pr = nx.pagerank(self.KG, personalization=personalization_dict,
|
||||
alpha = self.inference_config.ppr_alpha,
|
||||
max_iter=self.inference_config.ppr_max_iter,
|
||||
tol=self.inference_config.ppr_tol)
|
||||
|
||||
# get the top N passages based on the text_id list and pagerank score
|
||||
text_dict_score = {}
|
||||
for node in self.text_id_list:
|
||||
# filter out nodes that have 0 score
|
||||
if pr[node] > 0.0:
|
||||
text_dict_score[node] = pr[node]
|
||||
|
||||
# return topN passages
|
||||
sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
|
||||
sorted_passages_ids = sorted_passages_ids[:topN]
|
||||
|
||||
sorted_passages_contents = []
|
||||
sorted_scores = []
|
||||
sorted_passage_ids = []
|
||||
for passage_id, score in sorted_passages_ids:
|
||||
sorted_passages_contents.append(self.passage_dict[passage_id])
|
||||
sorted_scores.append(score)
|
||||
sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
|
||||
return sorted_passages_contents, sorted_passage_ids
|
||||
Reference in New Issue
Block a user