# Source file: AIEC-RAG---/AIEC-RAG/atlas_rag/retriever/tog.py
# Repository web-view metadata: 195 lines, 8.4 KiB, Python
# History timestamp shown by the web view: 2025-09-25 10:33:37 +08:00
import re
from typing import Optional

import numpy as np

from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever
from atlas_rag.retriever.inference_config import InferenceConfig
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
class TogRetriever(BaseEdgeRetriever):
    """Path-based knowledge-graph retriever driven by an LLM.

    Starting from nodes matched to the entities of a query, the retriever
    repeatedly expands paths by one hop (``search``), asks the LLM to rate and
    keep the best ones (``prune``), and stops once the LLM judges the collected
    triples sufficient (``reasoning``) or the configured depth is exhausted.

    A path is a flat list alternating nodes and relations:
    ``[node, relation, node, relation, node, ...]``.
    """

    def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
        """Store the graph, embeddings and helpers used during retrieval.

        Args:
            llm_generator: Object exposing ``generate_response(messages)``.
            sentence_encoder: Object exposing ``encode(list_of_texts)``.
            data: Mapping with the keys ``"KG"``, ``"node_embeddings"`` and
                ``"edge_embeddings"``.
            inference_config: Retrieval settings; a default ``InferenceConfig``
                is created when omitted.
        """
        # The graph must provide .nodes, .edges, .successors and .predecessors
        # (presumably a networkx DiGraph -- confirm against the caller).
        self.KG = data["KG"]
        self.node_list = list(self.KG.nodes)
        self.edge_list = list(self.KG.edges)
        self.edge_list_with_relation = [
            (edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list
        ]
        self.edge_list_string = [
            f"{head} {relation} {tail}" for head, relation, tail in self.edge_list_with_relation
        ]
        self.llm_generator: LLMGenerator = llm_generator
        self.sentence_encoder: BaseEmbeddingModel = sentence_encoder
        self.node_embeddings = data["node_embeddings"]
        self.edge_embeddings = data["edge_embeddings"]
        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def ner(self, text):
        """Ask the LLM to list the entities mentioned in ``text``.

        Returns:
            The raw LLM response. The prompt requests a comma-separated list
            (``entity1, entity2, ...``), but the format is not validated here.
        """
        # NOTE(review): the one-shot example answer is sent with role "system";
        # chat few-shot examples normally use "assistant" -- confirm intent.
        messages = [
            {"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
            {"role": "user", "content": "Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
            {"role": "system", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
            {"role": "user", "content": f"Extract the named entities from: {text}"},
        ]
        return self.llm_generator.generate_response(messages)

    def retrieve_topk_nodes(self, query, topN=5, **kwargs):
        """Return graph nodes that best match the entities found in ``query``.

        Exact name matches come first, followed by the nearest nodes by
        dot-product similarity for every entity. The result is de-duplicated in
        that order and capped at ``2 * topN`` entries.

        Args:
            query: Natural-language question.
            topN: Budget of similar nodes, split evenly across the entities.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A list of node identifiers taken from ``self.node_list``.
        """
        # Split on commas and drop blanks so an empty or whitespace-only NER
        # reply actually triggers the fallback below.
        entities = [name.strip() for name in self.ner(query).split(",")]
        entities = [name for name in entities if name]
        if not entities:
            # Nothing extracted: use the whole query for approximate search.
            entities = [query]

        # Evenly distribute the budget; "+ 1" guarantees at least one hit each.
        per_entity_k = topN // len(entities) + 1

        known_nodes = set(self.node_list)
        candidates = [entity for entity in entities if entity in known_nodes]
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            # Similarity scores via dot product against every node embedding.
            scores = self.node_embeddings @ entity_embedding[0].T
            best_first = np.argsort(scores)[-per_entity_k:][::-1]
            candidates.extend(self.node_list[i] for i in best_first)

        # dict.fromkeys de-duplicates while keeping insertion order, so the
        # truncation below is deterministic and never drops the exact matches
        # in favour of weaker ones.
        unique_nodes = list(dict.fromkeys(candidates))
        return unique_nodes[:2 * topN]

    def retrieve(self, query, topN=5, **kwargs):
        """Retrieve knowledge-graph triples relevant to ``query``.

        Paths are grown from the initial nodes for up to ``Dmax + 1`` rounds
        (``Dmax`` comes from the inference config); the loop ends early when
        ``reasoning`` reports that the current paths are sufficient.

        Args:
            query: Natural-language question.
            topN: Number of paths kept after each pruning round.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The value of ``generate``: a list of triple strings and a
            same-length list of ``"N/A"`` placeholders.
        """
        max_depth = self.inference_config.Dmax
        paths = [[node] for node in self.retrieve_topk_nodes(query, topN=topN)]
        depth = 0
        while depth <= max_depth:
            paths = self.search(query, paths)
            paths = self.prune(query, paths, topN)
            if self.reasoning(query, paths):
                break
            depth += 1
        # Whether the loop stopped early or ran out of depth, the answer is
        # built from the paths that survived the last pruning round.
        return self.generate(query, paths)

    def search(self, query, P):
        """Extend every path in ``P`` by one hop in either edge direction.

        Neighbours already on the path are skipped to avoid cycles. A path
        with no unvisited neighbour is carried over unchanged.

        Args:
            query: Unused; kept for a uniform step signature.
            P: List of paths (alternating node / relation lists).

        Returns:
            A new list of paths.
        """
        # NOTE(review): a predecessor hop is appended as [relation, neighbour]
        # exactly like a successor hop, so the edge direction is not recorded
        # in the path -- confirm that downstream triples may ignore direction.
        new_paths = []
        for path in P:
            tail_entity = path[-1]
            successors = [n for n in self.KG.successors(tail_entity) if n not in path]
            predecessors = [n for n in self.KG.predecessors(tail_entity) if n not in path]
            if not successors and not predecessors:
                new_paths.append(path)
                continue
            for neighbour in successors:
                relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
                new_paths.append(path + [relation, neighbour])
            for neighbour in predecessors:
                relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
                new_paths.append(path + [relation, neighbour])
        return new_paths

    @staticmethod
    def _parse_rating(response):
        """Return the first integer found in an LLM reply, or 0 if there is none."""
        match = re.search(r"-?\d+", str(response))
        return int(match.group()) if match else 0

    def prune(self, query, P, topN=3):
        """Keep the ``topN`` paths the LLM rates as most relevant to ``query``.

        One LLM call is made per path. Replies that are not a bare integer
        (for example ``"Rating: 4"``) are parsed leniently; replies without
        any number count as 0.

        Args:
            query: Natural-language question.
            P: List of candidate paths.
            topN: Maximum number of paths to return.

        Returns:
            Up to ``topN`` paths, highest rating first; equally rated paths
            keep their original relative order.
        """
        ratings = []
        for path in P:
            # Even positions hold nodes (shown by their "id" attribute),
            # odd positions hold relation labels.
            parts = [
                self.KG.nodes[item]["id"] if index % 2 == 0 else item
                for index, item in enumerate(path)
            ]
            path_string = " --->".join(str(part) for part in parts)
            prompt = f"Please rating the following path based on the relevance to the question. The ratings should be in the range of 1 to 5. 1 for least relevant and 5 for most relevant. Only provide the rating, do not provide any other information. The output should be a single integer number. If you think the path is not relevant, please provide 0. If you think the path is relevant, please provide a rating between 1 and 5. \n Query: {query} \n path: {path_string}"
            messages = [
                {"role": "system", "content": "Answer the question following the prompt."},
                {"role": "user", "content": f"{prompt}"},
            ]
            response = self.llm_generator.generate_response(messages)
            ratings.append(self._parse_rating(response))

        # Sort on the rating only: comparing the paths themselves as a
        # tie-break is arbitrary and can raise TypeError for mixed node types.
        ranked = sorted(zip(ratings, P), key=lambda pair: pair[0], reverse=True)
        return [path for _, path in ranked[:topN]]

    def _triple_strings(self, P):
        """Render every consecutive (node, relation, node) step of the paths as text."""
        nodes = self.KG.nodes
        rendered = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                head = nodes[path[i]]["id"]
                tail = nodes[path[i + 2]]["id"]
                rendered.append(f"({head}, {path[i + 1]}, {tail})")
        return rendered

    def reasoning(self, query, P):
        """Ask the LLM whether the triples in ``P`` suffice to answer ``query``.

        Returns:
            ``True`` when the reply contains "yes" (case-insensitive).
        """
        triples_string = ". ".join(self._triple_strings(P))
        prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"
        messages = [
            {"role": "system", "content": "Answer the question following the prompt."},
            {"role": "user", "content": f"{prompt}"},
        ]
        response = self.llm_generator.generate_response(messages)
        return "yes" in response.lower()

    def generate(self, query, P):
        """Return the triples of ``P`` as strings plus placeholder identifiers.

        Args:
            query: Unused; kept for a uniform step signature.
            P: List of paths.

        Returns:
            A tuple ``(triple_strings, placeholders)`` where ``placeholders``
            is a list of ``"N/A"`` with the same length as ``triple_strings``.
        """
        triples_string = self._triple_strings(P)
        return triples_string, ["N/A"] * len(triples_string)