first commit
4
atlas_rag/retriever/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .hipporag import HippoRAGRetriever
from .hipporag2 import HippoRAG2Retriever
from .simple_retriever import SimpleGraphRetriever, SimpleTextRetriever
from .tog import TogRetriever
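
The four retriever classes re-exported above can then be imported directly from the package (usage sketch, not part of the commit):

    from atlas_rag.retriever import (
        HippoRAGRetriever,
        HippoRAG2Retriever,
        SimpleGraphRetriever,
        SimpleTextRetriever,
        TogRetriever,
    )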
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/base.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag.cpython-39.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/hipporag2.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/inference_config.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/simple_retriever.cpython-311.pyc
Normal file
Binary file not shown.
BIN
atlas_rag/retriever/__pycache__/tog.cpython-311.pyc
Normal file
Binary file not shown.
27
atlas_rag/retriever/base.py
Normal file
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import List, Tuple


class BaseRetriever(ABC):
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")


class BaseEdgeRetriever(BaseRetriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")


class BasePassageRetriever(BaseRetriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")
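
A minimal sketch of how a new retriever plugs into these base classes; the keyword-overlap scorer below is hypothetical, not part of the commit:

    from typing import List, Tuple

    from atlas_rag.retriever.base import BasePassageRetriever

    class KeywordOverlapRetriever(BasePassageRetriever):
        """Hypothetical retriever: ranks passages by naive word overlap with the query."""
        def __init__(self, passage_dict, **kwargs):
            super().__init__(**kwargs)
            self.passage_dict = passage_dict

        def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
            query_words = set(query.lower().split())
            # Score each passage by how many query words it shares.
            ranked = sorted(self.passage_dict.items(),
                            key=lambda kv: len(query_words & set(kv[1].lower().split())),
                            reverse=True)[:topk]
            return [text for _, text in ranked], [pid for pid, _ in ranked]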
140
atlas_rag/retriever/hipporag.py
Normal file
@@ -0,0 +1,140 @@
from tqdm import tqdm
import networkx as nx
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from logging import Logger
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


class HippoRAGRetriever(BasePassageRetriever):
    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict, inference_config: Optional[InferenceConfig] = None, logger=None, **kwargs):
        self.passage_dict = data["text_dict"]
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]
        file_id_to_node_id = {}
        self.KG = data["KG"]
        for node_id in tqdm(list(self.KG.nodes)):
            if self.KG.nodes[node_id]['type'] == "passage":
                if self.KG.nodes[node_id]['file_id'] not in file_id_to_node_id:
                    file_id_to_node_id[self.KG.nodes[node_id]['file_id']] = []
                file_id_to_node_id[self.KG.nodes[node_id]['file_id']].append(node_id)
        self.file_id_to_node_id = file_id_to_node_id

        self.KG: nx.DiGraph = self.KG.subgraph(self.node_list)
        self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]

        self.logger: Logger = logger
        self.logging = self.logger is not None

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def retrieve_personalization_dict(self, query, topN=10):
        # Extract entities from the query.
        entities = self.llm_generator.ner(query)
        entities = entities.split(", ")
        if self.logging:
            self.logger.info(f"HippoRAG NER Entities: {entities}")

        if len(entities) == 0:
            # If NER cannot extract any entities, fall back to using the
            # whole query for approximate search.
            entities = [query]

        # Evenly distribute the topN budget across entities.
        topk_for_each_entity = topN // len(entities)

        # Retrieve the top-k nodes for each entity.
        topk_nodes = []
        for entity_index, entity in enumerate(entities):
            if entity in self.node_name_list:
                # Exact match: take the corresponding node directly.
                index = self.node_name_list.index(entity)
                topk_nodes.append(self.node_list[index])
            else:
                topk_for_this_entity = 1
                entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
                scores = self.node_embeddings @ entity_embedding[0].T
                index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
                topk_nodes += [self.node_list[i] for i in index_matrix]

        if self.logging:
            self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")

        topk_nodes = list(set(topk_nodes))
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]

        # Count the number of distinct documents each node appears in.
        freq_dict_for_nodes = {}
        for node in topk_nodes:
            node_data = self.KG.nodes[node]
            file_ids = node_data["file_id"]
            file_ids_list = list(set(file_ids.split(",")))
            freq_dict_for_nodes[node] = len(file_ids_list)

        # Weight each node inversely to its document frequency.
        personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
        return personalization_dict

    def retrieve(self, query, topN=5, **kwargs):
        topN_nodes = self.inference_config.topk_nodes
        personalization_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)

        # Run personalized PageRank over the KG, then aggregate node
        # probabilities into passage scores.
        pr = nx.pagerank(self.KG, personalization=personalization_dict)

        for node in pr:
            pr[node] = round(pr[node], 4)
            if pr[node] < 0.001:
                pr[node] = 0

        passage_probabilities_sum = {}
        for node in pr:
            node_data = self.KG.nodes[node]
            file_ids = node_data["file_id"]
            # For each file id, accumulate scores over its text nodes.
            file_ids_list = list(set(file_ids.split(",")))
            for file_id in file_ids_list:
                if file_id == 'concept_file':
                    continue
                for node_id in self.file_id_to_node_id[file_id]:
                    if node_id not in passage_probabilities_sum:
                        passage_probabilities_sum[node_id] = 0
                    passage_probabilities_sum[node_id] += pr[node]

        sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
        top_passages = sorted_passages[:topN]
        top_passages, scores = zip(*top_passages)

        passage_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
        return passage_contents, top_passages
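
A hedged wiring sketch for the class above; the pickle artifact, file name, and pre-built llm_generator / sentence_encoder objects are assumptions, only the data-dict keys come from the constructor:

    import pickle

    with open("kg_data.pkl", "rb") as f:   # hypothetical pre-built artifact
        data = pickle.load(f)              # must contain: KG, node_list,
                                           # node_embeddings, text_dict
    retriever = HippoRAGRetriever(
        llm_generator=llm_generator,       # an LLMGenerator instance
        sentence_encoder=sentence_encoder, # a BaseEmbeddingModel instance
        data=data,
    )
    passages, passage_ids = retriever.retrieve("Who founded the company?", topN=5)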
237
atlas_rag/retriever/hipporag2.py
Normal file
@@ -0,0 +1,237 @@
import json

import json_repair
import networkx as nx
import numpy as np
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
from logging import Logger
from dataclasses import dataclass

from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


def min_max_normalize(x):
    min_val = np.min(x)
    max_val = np.max(x)
    range_val = max_val - min_val

    # Handle the case where all values are the same (range is zero).
    if range_val == 0:
        return np.ones_like(x)  # Return an array of ones with the same shape as x.

    return (x - min_val) / range_val


class HippoRAG2Retriever(BasePassageRetriever):
    def __init__(self, llm_generator: LLMGenerator,
                 sentence_encoder: BaseEmbeddingModel,
                 data: dict,
                 inference_config: Optional[InferenceConfig] = None,
                 logger=None,
                 **kwargs):
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.edge_embeddings = data["edge_embeddings"]
        self.text_embeddings = data["text_embeddings"]
        self.edge_faiss_index = data["edge_faiss_index"]
        self.passage_dict = data["text_dict"]
        self.text_id_list = list(self.passage_dict.keys())
        self.KG = data["KG"]
        self.KG = self.KG.subgraph(self.node_list + self.text_id_list)

        self.logger = logger
        self.logging = self.logger is not None

        hipporag2mode = "query2edge"
        if hipporag2mode == "query2edge":
            self.retrieve_node_fn = self.query2edge
        elif hipporag2mode == "query2node":
            self.retrieve_node_fn = self.query2node
        elif hipporag2mode == "ner2node":
            self.retrieve_node_fn = self.ner2node
        else:
            raise ValueError(f"Invalid mode: {hipporag2mode}. Choose from 'query2edge', 'query2node', or 'ner2node'.")

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()
        node_id_to_file_id = {}
        for node_id in tqdm(list(self.KG.nodes)):
            if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type'] == "passage":
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
            else:
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
        self.node_id_to_file_id = node_id_to_file_id

    def ner(self, text):
        return self.llm_generator.ner(text)

    def ner2node(self, query, topN=10):
        entities = self.ner(query)
        entities = entities.split(", ")

        if len(entities) == 0:
            entities = [query]
        # Retrieve the top-k nodes per extracted entity.
        topk_nodes = []
        node_score_dict = {}
        for entity_index, entity in enumerate(entities):
            topk_for_this_entity = 1
            entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
            scores = min_max_normalize(self.node_embeddings @ entity_embedding[0].T)
            index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
            similarity_matrix = [scores[i] for i in index_matrix]
            for index, sim_score in zip(index_matrix, similarity_matrix):
                node = self.node_list[index]
                if node not in topk_nodes:
                    topk_nodes.append(node)
                    node_score_dict[node] = sim_score

        topk_nodes = list(set(topk_nodes))
        result_node_score_dict = {}
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        for node in topk_nodes:
            if node in node_score_dict:
                result_node_score_dict[node] = node_score_dict[node]
        return result_node_score_dict

    def query2node(self, query, topN=10):
        query_emb = self.sentence_encoder.encode([query], query_type="entity")
        scores = min_max_normalize(self.node_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        similarity_matrix = [scores[i] for i in index_matrix]
        result_node_score_dict = {}
        for index, sim_score in zip(index_matrix, similarity_matrix):
            node = self.node_list[index]
            result_node_score_dict[node] = sim_score

        return result_node_score_dict

    def query2edge(self, query, topN=10):
        query_emb = self.sentence_encoder.encode([query], query_type="edge")
        scores = min_max_normalize(self.edge_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        similarity_matrix = [scores[i] for i in index_matrix]
        # Construct the candidate fact list for LLM filtering.
        before_filter_edge_json = {}
        before_filter_edge_json['fact'] = []
        for index, sim_score in zip(index_matrix, similarity_matrix):
            edge = self.edge_list[index]
            edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
            before_filter_edge_json['fact'].append(edge_str)
        if self.logging:
            self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")
        filtered_facts = self.llm_generator.filter_triples_with_entity_event(query, json.dumps(before_filter_edge_json, ensure_ascii=False))
        filtered_facts = json_repair.loads(filtered_facts)['fact']
        if len(filtered_facts) == 0:
            return {}
        # Map each filtered fact back to an edge id, checking that it exists
        # in the original candidate list.
        node_score_dict = {}
        log_edge_list = []
        for edge in filtered_facts:
            edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
            search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
            D, I = self.edge_faiss_index.search(search_emb, 1)
            filtered_index = I[0][0]
            # Get the edge and its original score.
            edge = self.edge_list[filtered_index]
            log_edge_list.append([self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']])
            head, tail = edge[0], edge[1]
            sim_score = scores[filtered_index]

            if head not in node_score_dict:
                node_score_dict[head] = [sim_score]
            else:
                node_score_dict[head].append(sim_score)
            if tail not in node_score_dict:
                node_score_dict[tail] = [sim_score]
            else:
                node_score_dict[tail].append(sim_score)

        if self.logging:
            self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")

        # Average the scores collected for each node.
        for node in node_score_dict:
            node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])

        return node_score_dict

    def query2passage(self, query, weight_adjust=0.05):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        sim_scores = min_max_normalize(sim_scores) * weight_adjust  # scaled to act as probabilities
        # Return a dict of passage id -> score.
        return dict(zip(self.text_id_list, sim_scores))

    def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
        node_dict = self.retrieve_node_fn(query, topN=topN)
        text_dict = self.query2passage(query, weight_adjust=weight_adjust)

        return node_dict, text_dict

    def retrieve(self, query, topN=5, **kwargs):
        topN_edges = self.inference_config.topk_edges
        weight_adjust = self.inference_config.weight_adjust

        node_dict, text_dict = self.retrieve_personalization_dict(query, topN=topN_edges, weight_adjust=weight_adjust)

        personalization_dict = {}
        if len(node_dict) == 0:
            # No graph signal: fall back to the topN text passages by
            # embedding similarity alone.
            sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
            sorted_passages = sorted_passages[:topN]
            sorted_passages_contents = []
            sorted_scores = []
            sorted_passage_ids = []
            for passage_id, score in sorted_passages:
                sorted_passages_contents.append(self.passage_dict[passage_id])
                sorted_scores.append(float(score))
                sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
            return sorted_passages_contents, sorted_passage_ids

        personalization_dict.update(node_dict)
        personalization_dict.update(text_dict)
        # Run personalized PageRank to score passages.
        pr = nx.pagerank(self.KG, personalization=personalization_dict,
                         alpha=self.inference_config.ppr_alpha,
                         max_iter=self.inference_config.ppr_max_iter,
                         tol=self.inference_config.ppr_tol)

        # Collect PageRank scores for text nodes, filtering out zero scores.
        text_dict_score = {}
        for node in self.text_id_list:
            if pr[node] > 0.0:
                text_dict_score[node] = pr[node]

        # Return the topN passages.
        sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
        sorted_passages_ids = sorted_passages_ids[:topN]

        sorted_passages_contents = []
        sorted_scores = []
        sorted_passage_ids = []
        for passage_id, score in sorted_passages_ids:
            sorted_passages_contents.append(self.passage_dict[passage_id])
            sorted_scores.append(score)
            sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
        return sorted_passages_contents, sorted_passage_ids
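
A quick numeric illustration of the score scaling used by query2passage above (the input scores are made up):

    import numpy as np

    scores = np.array([0.2, 0.5, 0.9])
    normalized = (scores - scores.min()) / (scores.max() - scores.min())
    print(normalized)         # [0.     0.4286 1.    ]
    print(normalized * 0.05)  # [0.     0.0214 0.05  ] -> passage priors stay
    # small relative to node scores, so the graph signal dominates the
    # PageRank personalization vector.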
23
atlas_rag/retriever/inference_config.py
Normal file
@@ -0,0 +1,23 @@
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    """
    Configuration class for inference settings.

    Attributes:
        keyword (str): Dataset identifier (e.g., 'musique'). Default is "musique".
        topk (int): Number of top results to retrieve. Default is 5.
        Dmax (int): Maximum depth for search. Default is 4.
        weight_adjust (float): Weight adjustment factor for passage retrieval. Default is 1.0.
        topk_edges (int): Number of top edges to retrieve. Default is 50.
        topk_nodes (int): Number of top nodes to retrieve. Default is 10.
        ppr_alpha (float): Damping factor for personalized PageRank. Default is 0.99.
        ppr_max_iter (int): Maximum iterations for personalized PageRank. Default is 2000.
        ppr_tol (float): Convergence tolerance for personalized PageRank. Default is 1e-7.
    """
    keyword: str = "musique"
    topk: int = 5
    Dmax: int = 4
    weight_adjust: float = 1.0
    topk_edges: int = 50
    topk_nodes: int = 10
    ppr_alpha: float = 0.99
    ppr_max_iter: int = 2000
    ppr_tol: float = 1e-7
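
Usage sketch, reusing the hypothetical llm_generator / sentence_encoder / data objects from the sketches above; field names come from the dataclass:

    config = InferenceConfig(topk_nodes=20, weight_adjust=0.05, ppr_alpha=0.95)
    retriever = HippoRAG2Retriever(llm_generator, sentence_encoder, data,
                                   inference_config=config)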
0
atlas_rag/retriever/lkg_retriever/__init__.py
Normal file
40
atlas_rag/retriever/lkg_retriever/base.py
Normal file
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod


class BaseLargeKGRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve passages based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g. topN (number of
                top passages to retrieve), number_of_source_nodes_per_ner
                (source nodes per recognized entity), and sampling_area
                (graph area to sample).

        Returns:
            List of retrieved passages and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")


class BaseLargeKGEdgeRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve edges / paths based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g. topN (number of
                top paths to retrieve), number_of_source_nodes_per_ner
                (source nodes per recognized entity), and sampling_area
                (graph area to sample).

        Returns:
            List of retrieved edges/paths and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
313
atlas_rag/retriever/lkg_retriever/lkgr.py
Normal file
@@ -0,0 +1,313 @@
import string
import time
from difflib import get_close_matches
from logging import Logger

import faiss
import nltk
from neo4j import GraphDatabase
from nltk.corpus import stopwords
from graphdatascience import GraphDataScience

from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever

nltk.download('stopwords')


class LargeKGRetriever(BaseLargeKGRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, passage_index: faiss.Index,
                 topN: int = 5,
                 number_of_source_nodes_per_ner: int = 10,
                 sampling_area: int = 250, logger: Logger = None):
        # Instantiate the resources for one KG.
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.gds_driver = GraphDataScience(self.neo4j_driver)
        self.topN = topN
        self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
        self.sampling_area = sampling_area
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = node_index
        self.passage_faiss_index = passage_index

        self.verbose = logger is not None
        self.logger = logger

        self.ppr_weight_threshold = 0.00005

    def set_model(self, model):
        if self.llm_generator.inference_type == 'openai':
            self.llm_generator.model_name = model
        else:
            raise ValueError("Model can only be set for OpenAI inference type.")

    def ner(self, text):
        return self.llm_generator.large_kg_ner(text)

    def convert_numeric_id_to_name(self, numeric_id):
        if numeric_id.isdigit():
            return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
        else:
            return numeric_id

    def has_intersection(self, word_set, input_string):
        cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
        if self.keyword == 'cc_en':
            # Check whether any phrase in word_set is a substring of cleaned_string.
            for phrase in word_set:
                if phrase in cleaned_string:
                    return True
            return False
        else:
            # Check whether any word in word_set appears among the cleaned_string's words.
            words_in_string = set(cleaned_string.split())
            return not word_set.isdisjoint(words_in_string)

    def pagerank(self, personalization_dict, topN=5, sampling_area=200):
        graph = self.gds_driver.graph.get('largekgrag_graph')
        node_count = graph.node_count()
        sampling_ratio = sampling_area / node_count
        aggregation_node_list = []
        ppr_weight_threshold = self.ppr_weight_threshold
        start_time = time.time()

        # Pre-filter nodes on the ppr_weight threshold and precompute word
        # sets when working with the cc_en KG.
        if self.keyword == 'cc_en':
            filtered_personalization = {
                node_id: ppr_weight
                for node_id, ppr_weight in personalization_dict.items()
                if ppr_weight >= ppr_weight_threshold
            }
            stop_words = set(stopwords.words('english'))
            word_set_phrases = set()
            word_set_words = set()
            for node_id, ppr_weight in filtered_personalization.items():
                name = self.gds_driver.util.asNode(node_id)['name']
                if name:
                    cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
                    if cleaned_phrase:
                        # Remove stop words and keep the cleaned phrase.
                        filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
                        if filtered_words:
                            cleaned_phrase_filtered = ' '.join(filtered_words)
                            word_set_phrases.add(cleaned_phrase_filtered)

            word_set = word_set_phrases if self.keyword == 'cc_en' else word_set_words
            if self.verbose:
                self.logger.info(f"Optimized word set: {word_set}")
        else:
            filtered_personalization = personalization_dict

        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
            self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
            self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")

        # Process each node in the filtered personalization dict.
        for node_id, ppr_weight in filtered_personalization.items():
            try:
                self.gds_driver.graph.drop('rwr_sample')
                start_time = time.time()
                G_sample, _ = self.gds_driver.graph.sample.rwr("rwr_sample", graph, concurrency=4, samplingRatio=sampling_ratio, startNodes=[node_id],
                                                               restartProbability=0.4, logProgress=False)
                if self.verbose:
                    self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                result = self.gds_driver.pageRank.stream(
                    G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
                ).sort_values("score", ascending=False)

                if self.verbose:
                    self.logger.info(f"pagerank type: {type(result)}")
                    self.logger.info(f"pagerank result: {result}")
                    self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                if self.keyword != 'cc_en':
                    result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
                else:
                    result = result.to_dict('records')
                if self.verbose:
                    self.logger.info(f"largekgRAG : result: {result}")
                for entry in result:
                    if self.keyword == 'cc_en':
                        node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
                        if not self.has_intersection(word_set, node_name):
                            continue

                    numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
                    aggregation_node_list.append({
                        'nodeId': numeric_id,
                        'score': entry['score'] * ppr_weight
                    })

            except Exception as e:
                if self.verbose:
                    self.logger.error(f"Error processing node {node_id}: {e}")
                    self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")

        aggregation_node_list = sorted(aggregation_node_list, key=lambda x: x['score'], reverse=True)[:25]
        if self.verbose:
            self.logger.info(f"Aggregation node list: {aggregation_node_list}")
            self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
        start_time = time.time()
        with self.neo4j_driver.session() as session:
            intermediate_time = time.time()
            # Step 1: Distribute entity scores to connected text nodes and keep the top N.
            query_scores = """
                UNWIND $entries AS entry
                MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
                WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
                ORDER BY total_score DESC
                LIMIT $topN
                RETURN textId, total_score
            """
            # Execute the query to aggregate scores.
            result_scores = session.run(query_scores, entries=aggregation_node_list, topN=topN)
            top_numeric_ids = []
            top_scores = []

            # Extract the top text node IDs and scores.
            for record in result_scores:
                top_numeric_ids.append(record["textId"])
                top_scores.append(record["total_score"])

            # Step 2: Use the top numeric IDs to retrieve the original text.
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 1: {time.time() - intermediate_time:.2f} seconds")
            intermediate_time = time.time()
            query_text = """
                UNWIND $textIds AS textId
                MATCH (t:Text {numeric_id: textId})
                RETURN t.original_text AS text, t.numeric_id AS textId
            """
            result_texts = session.run(query_text, textIds=top_numeric_ids)
            topN_passages = []
            score_dict = dict(zip(top_numeric_ids, top_scores))

            # Combine the original text with the scores.
            for record in result_texts:
                original_text = record["text"]
                text_id = record["textId"]
                score = score_dict.get(text_id, 0)
                topN_passages.append((original_text, score))
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 2: {time.time() - intermediate_time:.2f} seconds")
            # Sort passages by score.
            topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
            top_texts = [item[0] for item in topN_passages][:topN]
            top_scores = [item[1] for item in topN_passages][:topN]
            if self.verbose:
                self.logger.info(f"Total passages retrieved: {len(top_texts)}")
                self.logger.info(f"Top passages: {top_texts}")
                self.logger.info(f"Top scores: {top_scores}")
                self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
            return top_texts, top_scores

    def retrieve_topk_nodes(self, query, top_k_nodes=2):
        # Extract entities from the query.
        entities = self.ner(query)
        if self.verbose:
            self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
        if len(entities) == 0:
            entities = [query]
        num_entities = len(entities)
        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
            initial_nodes += [str(i) for i in I[0]]
        if self.verbose:
            self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
        name_id_map = {}
        for node_id in initial_nodes:
            name = self.convert_numeric_id_to_name(node_id)
            name_id_map[name] = node_id

        # Convert the numeric ids to names, filter with the LLM, then map back to numeric ids.
        keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
        filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)

        # Second pass: add the filtered keywords.
        filtered_top_k_nodes = []
        filter_log_dict = {}
        match_threshold = 0.8
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
        for keyword in filtered_keywords:
            # Check for an exact match first.
            if keyword in name_id_map:
                filtered_top_k_nodes.append(name_id_map[keyword])
                filter_log_dict[keyword] = name_id_map[keyword]
            else:
                # Look for close matches using difflib's get_close_matches.
                close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
                if close_matches:
                    # If a close match is found, add the corresponding node.
                    filtered_top_k_nodes.append(name_id_map[close_matches[0]])
                filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")

        topk_nodes = list(set(filtered_top_k_nodes))
        if len(topk_nodes) > 2 * num_entities:
            topk_nodes = topk_nodes[:2 * num_entities]
        return topk_nodes

    def _process_text(self, text):
        """Normalize text for containment checks (lowercase, alphanumeric + spaces)."""
        text = text.lower()
        text = ''.join([c for c in text if c.isalnum() or c.isspace()])
        return set(text.split())

    def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
        topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
        if topk_nodes == []:
            if self.verbose:
                self.logger.info(f"largekgRAG : No nodes found for query: {query}")
            return {}
        if self.verbose:
            self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
        freq_dict_for_nodes = {}
        query = """
            UNWIND $nodes AS node
            MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
            RETURN n1.numeric_id AS numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
        """
        with self.neo4j_driver.session() as session:
            result = session.run(query, nodes=topk_nodes)
            for record in result:
                freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
        # Create the personalization dictionary, weighting nodes inversely to document frequency.
        personalization_dict = {self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id}): 1 / file_count for numeric_id, file_count in freq_dict_for_nodes.items()}
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict's number of nodes: {len(personalization_dict)}")
        return personalization_dict

    def retrieve_passages(self, query):
        if self.verbose:
            self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
        topN = self.topN
        number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
        sampling_area = self.sampling_area
        personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
        if personalization_dict == {}:
            return [], [0]
        topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area=sampling_area)
        return topN_passages, topN_scores
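
A hedged construction sketch for LargeKGRetriever; the connection URI, credentials, and index file names are assumptions for illustration:

    import faiss
    from neo4j import GraphDatabase

    driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    node_index = faiss.read_index("node.index")        # assumed pre-built indexes
    passage_index = faiss.read_index("passage.index")

    retriever = LargeKGRetriever("cc_en", driver, llm_generator, sentence_encoder,
                                 node_index, passage_index, topN=5)
    passages, scores = retriever.retrieve_passages("Where was Marie Curie born?")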
469
atlas_rag/retriever/lkg_retriever/tog.py
Normal file
@@ -0,0 +1,469 @@
import logging
import random
import time
from collections import defaultdict
from typing import List, Tuple

import faiss
import nltk
import numpy as np
from neo4j import GraphDatabase
from nltk.corpus import stopwords

from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever

nltk.download('stopwords')


class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index,
                 topN: int = 5,
                 Dmax: int = 3,
                 Wmax: int = 3,
                 prune_size: int = 10,
                 logger: logging.Logger = None):
        """
        Initialize the LargeKGToGRetriever for billion-scale KG retrieval using Neo4j.

        Args:
            keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
            neo4j_driver (GraphDatabase): Neo4j driver for database access.
            llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
            sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
            filter_encoder (BaseEmbeddingModel): Encoder used to filter candidate paths.
            node_index (faiss.Index): FAISS index for node embeddings.
            topN (int): Number of paths retained per pruning step.
            Dmax (int): Maximum depth of path expansion.
            Wmax (int): Maximum width of path expansion per node.
            prune_size (int): Number of paths rated per LLM call.
            logger (Logger, optional): Logger for verbose output.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.filter_encoder = filter_encoder
        self.node_faiss_index = node_index
        self.verbose = logger is not None
        self.logger = logger
        self.topN = topN
        self.Dmax = Dmax
        self.Wmax = Wmax
        self.prune_size = prune_size

    def ner(self, text: str) -> List[str]:
        """
        Extract named entities from the query text using the LLM.

        Args:
            text (str): The query text.

        Returns:
            List[str]: List of extracted entities.
        """
        entities = self.llm_generator.large_kg_ner(text)
        if self.verbose:
            self.logger.info(f"Extracted entities: {entities}")
        return entities

    def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> Tuple[List[str], List[str]]:
        """
        Retrieve top-k nodes similar to entities in the query.

        Args:
            query (str): The user query.
            top_k_nodes (int): Number of nodes to retrieve per entity.

        Returns:
            Tuple[List[str], List[str]]: Node numeric_ids and their names.
        """
        start_time = time.time()
        entities = self.ner(query)
        if self.verbose:
            ner_time = time.time() - start_time
            self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
        if not entities:
            entities = [query]

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
            if len(I[0]) > 0:  # Check that results exist.
                initial_nodes.extend([str(i) for i in I[0]])
        # No filtering is needed here, as ToG pruning will handle it.
        topk_nodes_ids = list(set(initial_nodes))

        with self.neo4j_driver.session() as session:
            start_time = time.time()
            query = """
                MATCH (n:Node)
                WHERE n.numeric_id IN $topk_nodes_ids
                RETURN n.numeric_id AS id, n.name AS name
            """
            result = session.run(query, topk_nodes_ids=topk_nodes_ids)
            topk_nodes_dict = {}
            for record in result:
                topk_nodes_dict[record["id"]] = record["name"]
            if self.verbose:
                neo4j_time = time.time() - start_time
                self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")

        if self.verbose:
            self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
        return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values())  # numeric_ids and names

    def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Expand each path by adding neighbors of its last node.

        Args:
            P (List[List[str]]): Current paths, each a list of alternating node names and relation types.
            PIDS (List[List[str]]): Numeric ids of the nodes along each path.
            PTYPES (List[List[str]]): Node types ('Node' or 'Text') along each path.
            width (int): Maximum number of expansions kept per last node.
            query (str): The user query, used for embedding-based filtering.

        Returns:
            Tuple of the expanded paths, their ids, and their types.
        """
        last_nodes = []
        last_node_ids = []
        last_node_types = []

        paths_end_with_text = []
        paths_end_with_text_id = []
        paths_end_with_text_type = []

        paths_end_with_node = []
        paths_end_with_node_id = []
        paths_end_with_node_type = []
        if self.verbose:
            self.logger.info(f"Expanding paths, current paths: {P}")
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            if not p or not pid or not ptype:  # Skip empty paths.
                continue
            t = ptype[-1]
            if t == "Text":
                paths_end_with_text.append(p)
                paths_end_with_text_id.append(pid)
                paths_end_with_text_type.append(ptype)
                continue
            last_node = p[-1]  # Last node name in the path.
            last_node_id = pid[-1]  # Last node numeric_id.

            last_nodes.append(last_node)
            last_node_ids.append(last_node_id)
            last_node_types.append(t)

            paths_end_with_node.append(p)
            paths_end_with_node_id.append(pid)
            paths_end_with_node_type.append(ptype)

        assert len(last_nodes) == len(last_node_ids) == len(last_node_types), "Mismatch in last nodes, ids, and types lengths"

        if not last_node_ids:
            return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type

        with self.neo4j_driver.session() as session:
            # Query Node-Node relationships, time-boxed and randomly capped.
            start_time = time.time()
            outgoing_query = """
            CALL apoc.cypher.runTimeboxed(
              "MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
               WITH n, r, m ORDER BY rand() LIMIT 60000
               RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
              {last_node_ids: $last_node_ids},
              60000
            )
            YIELD value
            RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
            """
            outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
            outgoing = [(record["source"], record['source_name'], record["rel_type"], record["target"], record["target_name"], record["target_type"])
                        for record in outgoing_result]
            if self.verbose:
                outgoing_time = time.time() - start_time
                self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")

        # (Node -> Text expansion via the Source relationship is intentionally disabled here.)

        last_node_to_new_paths = defaultdict(list)
        last_node_to_new_paths_ids = defaultdict(list)
        last_node_to_new_paths_types = defaultdict(list)

        english_stopwords = set(stopwords.words('english'))
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            last_node = p[-1]
            last_node_id = pid[-1]
            # Outgoing Node -> Node expansions.
            for source, source_name, rel_type, target, target_name, target_type in outgoing:
                if source == last_node_id and target_name not in p:
                    if target_name.lower() in english_stopwords:
                        continue
                    new_path = p + [rel_type, target_name]
                    last_node_to_new_paths[last_node].append(new_path)
                    last_node_to_new_paths_ids[last_node].append(pid + [target])
                    last_node_to_new_paths_types[last_node].append(ptype + [target_type])

        num_paths = 0
        for last_node, new_paths in last_node_to_new_paths.items():
            num_paths += len(new_paths)
        new_paths = []
        new_pids = []
        new_ptypes = []
        if self.verbose:
            self.logger.info(f"Number of new paths before filtering: {num_paths}")
            self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")
        if num_paths > len(last_node_ids) * width:
            # Apply embedding-based filtering when the total paths exceed the threshold.
            for last_node, new_ps in last_node_to_new_paths.items():
                if len(new_ps) > width:
                    path_embeddings = self.filter_encoder.encode(new_ps)
                    query_embeddings = self.filter_encoder.encode([query])
                    scores = np.dot(path_embeddings, query_embeddings.T).flatten()
                    top_indices = np.argsort(scores)[-width:]
                    new_paths.extend([new_ps[i] for i in top_indices])
                    new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
                    new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
                else:
                    new_paths.extend(new_ps)
                    new_pids.extend(last_node_to_new_paths_ids[last_node])
                    new_ptypes.extend(last_node_to_new_paths_types[last_node])
        else:
            # Collect all paths without filtering when the total is at or below the threshold.
            for last_node, new_ps in last_node_to_new_paths.items():
                new_paths.extend(new_ps)
                new_pids.extend(last_node_to_new_paths_ids[last_node])
                new_ptypes.extend(last_node_to_new_paths_types[last_node])

        if self.verbose:
            self.logger.info(f"Expanded paths count: {len(new_paths)}")
            self.logger.info(f"Expanded paths: {new_paths}")
        return new_paths, new_pids, new_ptypes

    def path_to_string(self, path: List[str]) -> str:
        """
        Convert a path to a human-readable string for LLM rating.

        Args:
            path (List[str]): Path as a list of node_ids and relation types.

        Returns:
            str: String representation of the path.
        """
        if len(path) < 1:
            return ""
        path_str = []
        with self.neo4j_driver.session() as session:
            for i in range(0, len(path), 2):
                node_id = path[i]
                result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
                record = result.single()  # single() consumes the result, so call it only once
                node_name = record["n.name"] if record else node_id
                if i + 1 < len(path):
                    rel_type = path[i + 1]
                    path_str.append(f"{node_name} ---> {rel_type} --->")
                else:
                    path_str.append(node_name)
        return " ".join(path_str).strip()

    def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Prune paths to keep the top N based on LLM relevance ratings.

        Args:
            query (str): The user query.
            P (List[List[str]]): List of paths to prune.
            PIDS / PTYPES: Ids and types aligned with P.
            topN (int): Number of paths to retain.

        Returns:
            Tuple of the top N paths, their ids, and their types.
        """
        ratings = []
        path_strings = P

        # Rate paths in chunks of prune_size.
        for i in range(0, len(path_strings), self.prune_size):
            chunk = path_strings[i:i + self.prune_size]

            # Construct a user prompt listing the current chunk of paths.
            user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
            for j, path_str in enumerate(chunk, 1):
                user_prompt += f"{j + i}. {path_str}\n"
            user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."

            # System prompt instructing the model to return only a list of integers.
            system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."

            # Send the prompt to the language model.
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
            if self.verbose:
                self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")

            # Parse the response into a list of ratings.
            rating_str = response.strip()
            chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
            if len(chunk_ratings) > len(chunk):
                if self.verbose:
                    self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
                chunk_ratings = chunk_ratings[:len(chunk)]
            ratings.extend(chunk_ratings)  # Concatenate ratings.

        if len(ratings) < len(path_strings):
            # The LLM returned too few ratings: fall back to the filter encoder to pick the topN paths.
            self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
            path_embeddings = self.filter_encoder.encode(path_strings)
            query_embedding = self.filter_encoder.encode([query])
            scores = np.dot(path_embeddings, query_embedding.T).flatten()
            top_indices = np.argsort(scores)[-topN:]
            top_paths = [path_strings[i] for i in top_indices]
            return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
        elif len(ratings) > len(path_strings):
            self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
            ratings = ratings[:len(path_strings)]

        # Sort indices by rating in descending order.
        sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)

        # Filter out indices whose rating is 0.
        filtered_indices = [i for i in sorted_indices if ratings[i] > 0]

        # Take the top N indices from the filtered list.
        top_indices = filtered_indices[:topN]

        # Use the filtered indices to get the top paths, PIDS, and PTYPES.
        if self.verbose:
            self.logger.info(f"Top indices after pruning: {top_indices}")
            self.logger.info(f"length of path_strings: {len(path_strings)}")
        top_paths = [path_strings[i] for i in top_indices]
        top_pids = [PIDS[i] for i in top_indices]
        top_ptypes = [PTYPES[i] for i in top_indices]

        if self.verbose:
            self.logger.info(f"Pruned to top {topN} paths: {top_paths}")

        return top_paths, top_pids, top_ptypes

    def reasoning(self, query: str, P: List[List[str]]) -> bool:
        """
        Check if the current paths are sufficient to answer the query.

        Args:
            query (str): The user query.
            P (List[List[str]]): Current list of paths.

        Returns:
            bool: True if sufficient, False otherwise.
        """
        triples = []
        for path in P:
            if len(path) < 3:
                continue
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")
        triples_str = ". ".join(triples)
        prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
        messages = [
            {"role": "system", "content": "Answer Yes or No only."},
            {"role": "user", "content": prompt}
        ]
        response = self.llm_generator.generate_response(messages, max_new_tokens=512)
        if self.verbose:
            self.logger.info(f"Reasoning result: {response}")
        return "yes" in response.lower()

    def retrieve_passages(self, query: str) -> Tuple[List[str], str]:
        """
        Retrieve the top N paths that answer the query.

        Args:
            query (str): The user query. topN (paths to return), Dmax (maximum
                expansion depth), and Wmax (maximum expansion width) are taken
                from the instance.

        Returns:
            Tuple[List[str], str]: List of triples as strings, plus 'N/A' in
            place of passage scores.
        """
        topN = self.topN
        Dmax = self.Dmax
        Wmax = self.Wmax
        if self.verbose:
            self.logger.info(f"Retrieving paths for query: {query}")

        initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
        if not initial_nodes:
            if self.verbose:
                self.logger.info("No initial nodes found.")
            return [], 'N/A'

        P = [[node] for node in initial_nodes]
        PIDS = [[node_id] for node_id in initial_nodes_ids]
        PTYPES = [["Node"] for _ in initial_nodes_ids]  # All initial nodes are entity nodes.
        for D in range(Dmax + 1):
            if self.verbose:
                self.logger.info(f"Depth {D}, Current paths: {len(P)}")
            P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
            if not P:
                if self.verbose:
                    self.logger.info("No paths to expand.")
                break
            P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
            if D == Dmax:
                if self.verbose:
                    self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
                break
            if self.reasoning(query, P):
                if self.verbose:
                    self.logger.info("Paths sufficient, stopping expansion.")
                break

        # Extract the final triples.
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")

        if self.verbose:
            self.logger.info(f"Final triples: {triples}")
        return triples, 'N/A'  # 'N/A' because this retriever returns paths rather than scored passages.
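
A hedged sketch of running the ToG-style retriever against the same hypothetical Neo4j/FAISS deployment wired up for LargeKGRetriever above:

    tog_retriever = LargeKGToGRetriever(
        keyword="cc_en",
        neo4j_driver=driver,
        llm_generator=llm_generator,
        sentence_encoder=sentence_encoder,
        filter_encoder=sentence_encoder,  # a cheaper encoder could be swapped in
        node_index=node_index,
        topN=5, Dmax=3, Wmax=3,
    )
    triples, _ = tog_retriever.retrieve_passages("Who founded the Mongol Empire?")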
51
atlas_rag/retriever/simple_retriever.py
Normal file
51
atlas_rag/retriever/simple_retriever.py
Normal file
@ -0,0 +1,51 @@
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever

class SimpleGraphRetriever(BaseEdgeRetriever):

    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict):
        self.KG = data["KG"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]

        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        self.node_faiss_index = data["node_faiss_index"]
        self.edge_faiss_index = data["edge_faiss_index"]

    def retrieve(self, query, topN=5, **kwargs):
        # Embed the query and look up the top-N edges in the FAISS index
        query_embedding = self.sentence_encoder.encode([query], query_type='edge')
        D, I = self.edge_faiss_index.search(query_embedding, topN)
        topk_edges = [self.edge_list[i] for i in I[0]]

        # Resolve each edge to (head, relation, tail) and render it as text
        topk_edges_with_data = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in topk_edges]
        string_edges = [f"{self.KG.nodes[edge[0]]['id']} {edge[1]} {self.KG.nodes[edge[2]]['id']}" for edge in topk_edges_with_data]

        return string_edges, ["N/A" for _ in range(len(string_edges))]
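If the FAISS call pattern is unfamiliar: index.search returns a pair of arrays, scores D and integer indices I, each of shape (num_queries, topN). A minimal runnable sketch with random vectors (toy dimensions, not the project's real index):

import faiss
import numpy as np

dim = 32
edge_vectors = np.random.rand(100, dim).astype("float32")  # stand-ins for edge embeddings
index = faiss.IndexFlatIP(dim)  # inner-product (dot-product) index
index.add(edge_vectors)

query = np.random.rand(1, dim).astype("float32")
D, I = index.search(query, 5)  # D: (1, 5) scores, I: (1, 5) row indices
print(I[0])  # positions of the five best-matching edges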

class SimpleTextRetriever(BasePassageRetriever):
    def __init__(self, passage_dict: Dict[str, str], sentence_encoder: BaseEmbeddingModel, data: dict):
        self.sentence_encoder = sentence_encoder
        self.passage_dict = passage_dict
        self.passage_list = list(passage_dict.values())
        self.passage_keys = list(passage_dict.keys())
        self.text_embeddings = data["text_embeddings"]

    def retrieve(self, query, topN=5, **kwargs):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        topk_indices = np.argsort(sim_scores)[-topN:][::-1]  # indices of the top-N scores, best first

        # Map indices back to passage texts and their ids
        topk_passages = [self.passage_list[i] for i in topk_indices]
        topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
        return topk_passages, topk_passages_ids
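SimpleTextRetriever skips FAISS entirely and ranks by a raw dot product over precomputed embeddings. The same ranking, reduced to a runnable sketch with random stand-in vectors:

import numpy as np

text_embeddings = np.random.rand(4, 8)   # 4 passages, embedding dim 8
query_emb = np.random.rand(1, 8)

sim_scores = text_embeddings @ query_emb[0].T     # shape (4,)
topk_indices = np.argsort(sim_scores)[-2:][::-1]  # best two, highest first
print(topk_indices, sim_scores[topk_indices])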
195
atlas_rag/retriever/tog.py
Normal file
195
atlas_rag/retriever/tog.py
Normal file
@ -0,0 +1,195 @@
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from typing import Optional
from atlas_rag.retriever.base import BaseEdgeRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


class TogRetriever(BaseEdgeRetriever):
    def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
        self.KG = data["KG"]

        self.node_list = list(self.KG.nodes)
        self.edge_list = list(self.KG.edges)
        self.edge_list_with_relation = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list]
        self.edge_list_string = [f"{edge[0]} {self.KG.edges[edge]['relation']} {edge[1]}" for edge in self.edge_list]

        self.llm_generator: LLMGenerator = llm_generator
        self.sentence_encoder: BaseEmbeddingModel = sentence_encoder

        self.node_embeddings = data["node_embeddings"]
        self.edge_embeddings = data["edge_embeddings"]

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def ner(self, text):
        # One-shot prompt: the answer to the demo question is carried in an
        # assistant turn so the model learns the comma-separated format.
        messages = [
            {"role": "system", "content": "Please extract the entities from the following question and output them separated by commas, in the following format: entity1, entity2, ..."},
            {"role": "user", "content": "Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
            {"role": "assistant", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
            {"role": "user", "content": f"Extract the named entities from: {text}"},
        ]
        return self.llm_generator.generate_response(messages)
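The reply is expected to be a single comma-separated line, which retrieve_topk_nodes then splits. A sketch of that parse on a hypothetical model reply:

raw = "Portland International Airport, Gerald R. Ford International Airport, Oregon"
entities = [e.strip() for e in raw.split(",") if e.strip()]
print(entities)
# ['Portland International Airport', 'Gerald R. Ford International Airport', 'Oregon']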

    def retrieve_topk_nodes(self, query, topN=5, **kwargs):
        # Extract entities from the query; fall back to the raw query for
        # approximate search if the LLM returns nothing usable.
        entities = [e.strip() for e in self.ner(query).split(",") if e.strip()]
        if not entities:
            entities = [query]

        # Evenly distribute the top-k budget across the extracted entities
        topk_for_each_entity = topN // len(entities)

        topk_nodes = []

        # Exact matches first: an entity that is itself a KG node is kept as-is
        for entity in entities:
            if entity in self.node_list:
                topk_nodes.append(entity)

        # Then approximate matches via embedding similarity
        for entity in entities:
            topk_for_this_entity = topk_for_each_entity + 1

            entity_embedding = self.sentence_encoder.encode([entity])
            # Similarity scores via dot product against all node embeddings
            scores = self.node_embeddings @ entity_embedding[0].T
            top_indices = np.argsort(scores)[-topk_for_this_entity:][::-1]
            topk_nodes += [self.node_list[i] for i in top_indices]

        topk_nodes = list(set(topk_nodes))

        # Cap the candidate set at twice the requested size
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        return topk_nodes
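A worked example of the per-entity budget, assuming topN = 5 and two extracted entities (the numbers are illustrative):

topN, n_entities = 5, 2
topk_for_each_entity = topN // n_entities            # 2
topk_for_this_entity = topk_for_each_entity + 1      # 3 embedding candidates per entity
max_candidates = n_entities * topk_for_this_entity   # 6, before exact matches and dedup
cap = 2 * topN                                       # final list truncated to 10
print(topk_for_this_entity, max_candidates, cap)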

    def retrieve(self, query, topN=5, **kwargs):
        """
        Retrieve the top N paths that connect the entities in the query.
        Dmax is the maximum depth of the search.
        """
        Dmax = self.inference_config.Dmax
        # Step 1: seed the search with the top-k matching nodes
        initial_nodes = self.retrieve_topk_nodes(query, topN=topN)
        E = initial_nodes
        P = [[e] for e in E]
        D = 0

        while D <= Dmax:
            P = self.search(query, P)
            P = self.prune(query, P, topN)

            # Stop early once the LLM judges the gathered paths sufficient
            if self.reasoning(query, P):
                generated_text = self.generate(query, P)
                break

            D += 1

        # Depth budget exhausted without a "yes": generate from what we have
        if D > Dmax:
            generated_text = self.generate(query, P)

        return generated_text
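The loop has two exits: an early break when reasoning says yes (so D stays <= Dmax), and a fallback generate when the depth budget runs out (D ends at Dmax + 1). A toy, self-contained trace of that control flow, with the LLM verdict replaced by a fixed depth:

Dmax = 2
yes_at_depth = 1   # pretend reasoning() first answers "yes" at depth 1
D = 0
generated = None
while D <= Dmax:
    if D == yes_at_depth:      # stands in for self.reasoning(query, P)
        generated = "answer from sufficient paths"
        break
    D += 1
if D > Dmax:                   # only reached when no early break happened
    generated = "best-effort answer at depth budget"
print(D, generated)            # 1 answer from sufficient paths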

    def search(self, query, P):
        new_paths = []
        for path in P:
            tail_entity = path[-1]
            successors = list(self.KG.successors(tail_entity))
            predecessors = list(self.KG.predecessors(tail_entity))

            # Drop neighbours already on the path to avoid cycles
            successors = [neighbour for neighbour in successors if neighbour not in path]
            predecessors = [neighbour for neighbour in predecessors if neighbour not in path]

            # Dead end: keep the path as-is
            if len(successors) == 0 and len(predecessors) == 0:
                new_paths.append(path)
                continue

            # Expand along outgoing edges...
            for neighbour in successors:
                relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
                new_paths.append(path + [relation, neighbour])

            # ...and along incoming edges
            for neighbour in predecessors:
                relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
                new_paths.append(path + [relation, neighbour])

        return new_paths
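Because the KG is a directed graph, expansion looks both ways from the path's tail. A tiny runnable networkx illustration of the two neighbour loops (the graph contents are hypothetical):

import networkx as nx

G = nx.DiGraph()
G.add_edge("A", "B", relation="likes")
G.add_edge("C", "A", relation="knows")

path = ["A"]
for nbr in G.successors("A"):      # outgoing: A -> B
    print(path + [G.edges[("A", nbr)]["relation"], nbr])   # ['A', 'likes', 'B']
for nbr in G.predecessors("A"):    # incoming: C -> A
    print(path + [G.edges[(nbr, "A")]["relation"], nbr])   # ['A', 'knows', 'C']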

    def prune(self, query, P, topN=3):
        ratings = []

        for path in P:
            # Render the path as "id ---> relation ---> id ---> ..."
            path_string = ""
            for index, node_or_relation in enumerate(path):
                if index % 2 == 0:
                    id_path = self.KG.nodes[node_or_relation]["id"]
                else:
                    id_path = node_or_relation
                path_string += f"{id_path} --->"
            path_string = path_string[:-5]

            prompt = f"Please rate the following path based on its relevance to the question, on a scale of 1 to 5: 1 for least relevant, 5 for most relevant, and 0 if the path is not relevant at all. Only output a single integer, with no other information. \n Query: {query} \n path: {path_string}"

            messages = [{"role": "system", "content": "Answer the question following the prompt."},
                        {"role": "user", "content": f"{prompt}"}]

            response = self.llm_generator.generate_response(messages)
            # Parse defensively: fall back to 0 if the model does not return an integer
            try:
                rating = int(response.strip())
            except ValueError:
                rating = 0
            ratings.append(rating)

        # Keep the topN paths, sorted by rating (stable for ties)
        sorted_paths = [path for _, path in sorted(zip(ratings, P), key=lambda pair: pair[0], reverse=True)]

        return sorted_paths[:topN]
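One subtlety in the sort: keying on the rating alone keeps Python's sort stable, so tied paths stay in their original order, whereas a bare sorted(zip(...), reverse=True) would fall through to comparing the paths themselves on ties. A minimal demonstration with made-up ratings:

ratings = [3, 5, 3]
paths = [["a"], ["b"], ["c"]]
ranked = [p for _, p in sorted(zip(ratings, paths), key=lambda pair: pair[0], reverse=True)]
print(ranked)  # [['b'], ['a'], ['c']] -- ties keep their original order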

    def reasoning(self, query, P):
        # Flatten the pruned paths into (entity, relation, entity) triples
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))

        triples_string = ". ".join(f"({t[0]}, {t[1]}, {t[2]})" for t in triples)

        prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"

        messages = [{"role": "system", "content": "Answer the question following the prompt."},
                    {"role": "user", "content": f"{prompt}"}]

        response = self.llm_generator.generate_response(messages)
        return "yes" in response.lower()

    def generate(self, query, P):
        # Flatten the final paths into triple strings for the caller
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))

        triples_string = [f"({t[0]}, {t[1]}, {t[2]})" for t in triples]

        # "N/A" placeholders stand in for scores, matching the other retrievers
        return triples_string, ["N/A" for _ in range(len(triples_string))]
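For reference, the tuple that retrieve ultimately hands back has this shape (the triple contents are hypothetical):

triples_string = [
    "(Inception, directed by, Christopher Nolan)",
    "(Christopher Nolan, born in, London)",
]
scores = ["N/A" for _ in range(len(triples_string))]
result = (triples_string, scores)  # what retrieve() returns to the caller
print(result)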