195 lines
8.4 KiB
Python
195 lines
8.4 KiB
Python
|
|
import numpy as np
|
||
|
|
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
|
||
|
|
from atlas_rag.llm_generator.llm_generator import LLMGenerator
|
||
|
|
from typing import Optional
|
||
|
|
from atlas_rag.retriever.base import BaseEdgeRetriever
|
||
|
|
from atlas_rag.retriever.inference_config import InferenceConfig
|
||
|
|
|
||
|
|
class TogRetriever(BaseEdgeRetriever):
|
||
|
|
def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
|
||
|
|
self.KG = data["KG"]
|
||
|
|
|
||
|
|
self.node_list = list(self.KG.nodes)
|
||
|
|
self.edge_list = list(self.KG.edges)
|
||
|
|
self.edge_list_with_relation = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list]
|
||
|
|
self.edge_list_string = [f"{edge[0]} {self.KG.edges[edge]['relation']} {edge[1]}" for edge in self.edge_list]
|
||
|
|
|
||
|
|
self.llm_generator:LLMGenerator = llm_generator
|
||
|
|
self.sentence_encoder:BaseEmbeddingModel = sentence_encoder
|
||
|
|
|
||
|
|
self.node_embeddings = data["node_embeddings"]
|
||
|
|
self.edge_embeddings = data["edge_embeddings"]
|
||
|
|
|
||
|
|
self.inference_config = inference_config if inference_config is not None else InferenceConfig()
|
||
|
|
|
||
|
|
def ner(self, text):
|
||
|
|
messages = [
|
||
|
|
{"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
|
||
|
|
{"role": "user", "content": f"Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
|
||
|
|
{"role": "system", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
|
||
|
|
{"role": "user", "content": f"Extract the named entities from: {text}"},
|
||
|
|
]
|
||
|
|
|
||
|
|
|
||
|
|
response = self.llm_generator.generate_response(messages)
|
||
|
|
generated_text = response
|
||
|
|
# print(generated_text)
|
||
|
|
return generated_text
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
def retrieve_topk_nodes(self, query, topN=5, **kwargs):
|
||
|
|
# extract entities from the query
|
||
|
|
entities = self.ner(query)
|
||
|
|
entities = entities.split(", ")
|
||
|
|
|
||
|
|
if len(entities) == 0:
|
||
|
|
# If the NER cannot extract any entities, we
|
||
|
|
# use the query as the entity to do approximate search
|
||
|
|
entities = [query]
|
||
|
|
|
||
|
|
# evenly distribute the topk for each entity
|
||
|
|
topk_for_each_entity = topN//len(entities)
|
||
|
|
|
||
|
|
# retrieve the top k nodes
|
||
|
|
topk_nodes = []
|
||
|
|
|
||
|
|
for entity_index, entity in enumerate(entities):
|
||
|
|
if entity in self.node_list:
|
||
|
|
topk_nodes.append(entity)
|
||
|
|
|
||
|
|
for entity_index, entity in enumerate(entities):
|
||
|
|
topk_for_this_entity = topk_for_each_entity + 1
|
||
|
|
|
||
|
|
entity_embedding = self.sentence_encoder.encode([entity])
|
||
|
|
# Calculate similarity scores using dot product
|
||
|
|
scores = self.node_embeddings @ entity_embedding[0].T
|
||
|
|
# Get top-k indices
|
||
|
|
top_indices = np.argsort(scores)[-topk_for_this_entity:][::-1]
|
||
|
|
topk_nodes += [self.node_list[i] for i in top_indices]
|
||
|
|
|
||
|
|
topk_nodes = list(set(topk_nodes))
|
||
|
|
|
||
|
|
if len(topk_nodes) > 2*topN:
|
||
|
|
topk_nodes = topk_nodes[:2*topN]
|
||
|
|
return topk_nodes
|
||
|
|
|
||
|
|
def retrieve(self, query, topN=5, **kwargs):
|
||
|
|
"""
|
||
|
|
Retrieve the top N paths that connect the entities in the query.
|
||
|
|
Dmax is the maximum depth of the search.
|
||
|
|
"""
|
||
|
|
Dmax = self.inference_config.Dmax
|
||
|
|
# in the first step, we retrieve the top k nodes
|
||
|
|
initial_nodes = self.retrieve_topk_nodes(query, topN=topN)
|
||
|
|
E = initial_nodes
|
||
|
|
P = [ [e] for e in E]
|
||
|
|
D = 0
|
||
|
|
|
||
|
|
while D <= Dmax:
|
||
|
|
P = self.search(query, P)
|
||
|
|
P = self.prune(query, P, topN)
|
||
|
|
|
||
|
|
if self.reasoning(query, P):
|
||
|
|
generated_text = self.generate(query, P)
|
||
|
|
break
|
||
|
|
|
||
|
|
D += 1
|
||
|
|
|
||
|
|
if D > Dmax:
|
||
|
|
generated_text = self.generate(query, P)
|
||
|
|
|
||
|
|
# print(generated_text)
|
||
|
|
return generated_text
|
||
|
|
|
||
|
|
def search(self, query, P):
|
||
|
|
new_paths = []
|
||
|
|
for path in P:
|
||
|
|
tail_entity = path[-1]
|
||
|
|
sucessors = list(self.KG.successors(tail_entity))
|
||
|
|
predecessors = list(self.KG.predecessors(tail_entity))
|
||
|
|
|
||
|
|
# print(f"tail_entity: {tail_entity}")
|
||
|
|
# print(f"sucessors: {sucessors}")
|
||
|
|
# print(f"predecessors: {predecessors}")
|
||
|
|
|
||
|
|
# # print the attributes of the tail_entity
|
||
|
|
# print(f"attributes of the tail_entity: {self.KG.nodes[tail_entity]}")
|
||
|
|
|
||
|
|
# remove the entity that is already in the path
|
||
|
|
sucessors = [neighbour for neighbour in sucessors if neighbour not in path]
|
||
|
|
predecessors = [neighbour for neighbour in predecessors if neighbour not in path]
|
||
|
|
|
||
|
|
if len(sucessors) == 0 and len(predecessors) == 0:
|
||
|
|
new_paths.append(path)
|
||
|
|
continue
|
||
|
|
for neighbour in sucessors:
|
||
|
|
relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
|
||
|
|
new_path = path + [relation, neighbour]
|
||
|
|
new_paths.append(new_path)
|
||
|
|
|
||
|
|
for neighbour in predecessors:
|
||
|
|
relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
|
||
|
|
new_path = path + [relation, neighbour]
|
||
|
|
new_paths.append(new_path)
|
||
|
|
|
||
|
|
return new_paths
|
||
|
|
|
||
|
|
def prune(self, query, P, topN=3):
|
||
|
|
ratings = []
|
||
|
|
|
||
|
|
for path in P:
|
||
|
|
path_string = ""
|
||
|
|
for index, node_or_relation in enumerate(path):
|
||
|
|
if index % 2 == 0:
|
||
|
|
id_path = self.KG.nodes[node_or_relation]["id"]
|
||
|
|
else:
|
||
|
|
id_path = node_or_relation
|
||
|
|
path_string += f"{id_path} --->"
|
||
|
|
path_string = path_string[:-5]
|
||
|
|
|
||
|
|
prompt = f"Please rating the following path based on the relevance to the question. The ratings should be in the range of 1 to 5. 1 for least relevant and 5 for most relevant. Only provide the rating, do not provide any other information. The output should be a single integer number. If you think the path is not relevant, please provide 0. If you think the path is relevant, please provide a rating between 1 and 5. \n Query: {query} \n path: {path_string}"
|
||
|
|
|
||
|
|
messages = [{"role": "system", "content": "Answer the question following the prompt."},
|
||
|
|
{"role": "user", "content": f"{prompt}"}]
|
||
|
|
|
||
|
|
response = self.llm_generator.generate_response(messages)
|
||
|
|
# print(response)
|
||
|
|
rating = int(response)
|
||
|
|
ratings.append(rating)
|
||
|
|
|
||
|
|
# sort the paths based on the ratings
|
||
|
|
sorted_paths = [path for _, path in sorted(zip(ratings, P), reverse=True)]
|
||
|
|
|
||
|
|
return sorted_paths[:topN]
|
||
|
|
|
||
|
|
def reasoning(self, query, P):
|
||
|
|
triples = []
|
||
|
|
for path in P:
|
||
|
|
for i in range(0, len(path)-2, 2):
|
||
|
|
|
||
|
|
# triples.append((path[i], path[i+1], path[i+2]))
|
||
|
|
triples.append((self.KG.nodes[path[i]]["id"], path[i+1], self.KG.nodes[path[i+2]]["id"]))
|
||
|
|
|
||
|
|
triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
|
||
|
|
triples_string = ". ".join(triples_string)
|
||
|
|
|
||
|
|
prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"
|
||
|
|
|
||
|
|
messages = [{"role": "system", "content": "Answer the question following the prompt."},
|
||
|
|
{"role": "user", "content": f"{prompt}"}]
|
||
|
|
|
||
|
|
response = self.llm_generator.generate_response(messages)
|
||
|
|
return "yes" in response.lower()
|
||
|
|
|
||
|
|
def generate(self, query, P):
|
||
|
|
triples = []
|
||
|
|
for path in P:
|
||
|
|
for i in range(0, len(path)-2, 2):
|
||
|
|
# triples.append((path[i], path[i+1], path[i+2]))
|
||
|
|
triples.append((self.KG.nodes[path[i]]["id"], path[i+1], self.KG.nodes[path[i+2]]["id"]))
|
||
|
|
|
||
|
|
triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
|
||
|
|
|
||
|
|
# response = self.llm_generator.generate_with_context_kg(query, triples_string)
|
||
|
|
return triples_string, ["N/A" for _ in range(len(triples_string))]
|