first commit

闫旭隆
2025-09-24 09:29:12 +08:00
parent 6339cdebb9
commit 2308536f66
360 changed files with 136381 additions and 0 deletions

4
atlas_rag/retriever/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .hipporag import HippoRAGRetriever
from .hipporag2 import HippoRAG2Retriever
from .simple_retriever import SimpleGraphRetriever, SimpleTextRetriever
from .tog import TogRetriever
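
A minimal usage sketch of these exports (not part of this commit; `llm_generator`, `sentence_encoder` and the prebuilt `data` dict are assumptions, with the expected `data` keys visible in the retriever constructors below):

# Assumed objects: an LLMGenerator, a BaseEmbeddingModel, and a `data` dict
# holding the KG, embeddings and FAISS indexes described in this commit.
from atlas_rag.retriever import HippoRAG2Retriever

retriever = HippoRAG2Retriever(llm_generator, sentence_encoder, data)
passages, passage_ids = retriever.retrieve("Who founded the city?", topN=5)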

Binary file not shown.

Binary file not shown.

27
atlas_rag/retriever/base.py Normal file
View File

@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import List, Tuple


class BaseRetriever(ABC):
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")


class BaseEdgeRetriever(BaseRetriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")


class BasePassageRetriever(BaseRetriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @abstractmethod
    def retrieve(self, query, topk=5, **kwargs) -> Tuple[List[str], List[str]]:
        raise NotImplementedError("This method should be overridden by subclasses.")
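
For illustration, a minimal concrete subclass (a sketch, not part of this commit) showing the contract: retrieve() returns parallel lists of passage contents and ids:

class KeywordRetriever(BaseRetriever):
    """Toy retriever: naive substring matching over an in-memory passage dict."""
    def __init__(self, passages: dict, **kwargs):
        super().__init__(**kwargs)
        self.passages = passages  # {passage_id: text}

    def retrieve(self, query, topk=5, **kwargs):
        hits = [(pid, text) for pid, text in self.passages.items()
                if query.lower() in text.lower()][:topk]
        return [text for _, text in hits], [pid for pid, _ in hits]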

140
atlas_rag/retriever/hipporag.py Normal file
View File

@@ -0,0 +1,140 @@
from tqdm import tqdm
import networkx as nx
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from logging import Logger
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


class HippoRAGRetriever(BasePassageRetriever):
    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict, inference_config: Optional[InferenceConfig] = None, logger=None, **kwargs):
        self.passage_dict = data["text_dict"]
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]

        # Map each file id to the passage nodes that belong to it.
        file_id_to_node_id = {}
        self.KG = data["KG"]
        for node_id in tqdm(list(self.KG.nodes)):
            if self.KG.nodes[node_id]['type'] == "passage":
                if self.KG.nodes[node_id]['file_id'] not in file_id_to_node_id:
                    file_id_to_node_id[self.KG.nodes[node_id]['file_id']] = []
                file_id_to_node_id[self.KG.nodes[node_id]['file_id']].append(node_id)
        self.file_id_to_node_id = file_id_to_node_id

        self.KG: nx.DiGraph = self.KG.subgraph(self.node_list)
        self.node_name_list = [self.KG.nodes[node]["id"] for node in self.node_list]

        self.logger: Logger = logger
        self.logging = self.logger is not None
        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def retrieve_personalization_dict(self, query, topN=10):
        # Extract entities from the query.
        entities = self.llm_generator.ner(query)
        entities = entities.split(", ")
        if self.logging:
            self.logger.info(f"HippoRAG NER Entities: {entities}")
        if len(entities) == 0:
            # If NER cannot extract any entities, use the whole query
            # as the entity and fall back to approximate search.
            entities = [query]

        # Retrieve the top nodes for each entity.
        topk_nodes = []
        for entity in entities:
            if entity in self.node_name_list:
                # Exact match: map the entity name back to its node id.
                index = self.node_name_list.index(entity)
                topk_nodes.append(self.node_list[index])
            else:
                # Approximate match: take the single most similar node.
                topk_for_this_entity = 1
                entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
                scores = self.node_embeddings @ entity_embedding[0].T
                index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
                topk_nodes += [self.node_list[i] for i in index_matrix]

        if self.logging:
            self.logger.info(f"HippoRAG Topk Nodes: {[self.KG.nodes[node]['id'] for node in topk_nodes]}")

        topk_nodes = list(set(topk_nodes))
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]

        # Count the number of files each node appears in, and weight each node
        # by the inverse of that frequency so ubiquitous nodes contribute less
        # to the personalization vector.
        freq_dict_for_nodes = {}
        for node in topk_nodes:
            node_data = self.KG.nodes[node]
            file_ids = node_data["file_id"]
            file_ids_list = list(set(file_ids.split(",")))
            freq_dict_for_nodes[node] = len(file_ids_list)

        personalization_dict = {node: 1 / freq_dict_for_nodes[node] for node in topk_nodes}
        return personalization_dict

    def retrieve(self, query, topN=5, **kwargs):
        topN_nodes = self.inference_config.topk_nodes
        personalization_dict = self.retrieve_personalization_dict(query, topN=topN_nodes)

        # Run personalized PageRank over the entity graph.
        pr = nx.pagerank(self.KG, personalization=personalization_dict)
        for node in pr:
            pr[node] = round(pr[node], 4)
            if pr[node] < 0.001:
                pr[node] = 0

        # Accumulate PageRank mass onto the passages of each node's source files.
        passage_probabilities_sum = {}
        for node in pr:
            node_data = self.KG.nodes[node]
            file_ids = node_data["file_id"]
            file_ids_list = list(set(file_ids.split(",")))
            for file_id in file_ids_list:
                if file_id == 'concept_file':
                    continue
                for node_id in self.file_id_to_node_id[file_id]:
                    if node_id not in passage_probabilities_sum:
                        passage_probabilities_sum[node_id] = 0
                    passage_probabilities_sum[node_id] += pr[node]

        sorted_passages = sorted(passage_probabilities_sum.items(), key=lambda x: x[1], reverse=True)
        top_passages = sorted_passages[:topN]
        top_passages, scores = zip(*top_passages)
        passage_contents = [self.passage_dict[passage_id] for passage_id in top_passages]
        return passage_contents, top_passages
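
The 1/frequency personalization above can be illustrated with made-up numbers: entities that occur in many files receive proportionally less PageRank mass, steering the random walk toward query-specific nodes:

# Toy illustration (values are made up, not from the code above):
freq = {"Portland International Airport": 1, "Oregon": 4}
personalization = {node: 1 / count for node, count in freq.items()}
assert personalization == {"Portland International Airport": 1.0, "Oregon": 0.25}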

237
atlas_rag/retriever/hipporag2.py Normal file
View File

@@ -0,0 +1,237 @@
import networkx as nx
import json
from tqdm import tqdm
import numpy as np
import json_repair
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from typing import Optional
from atlas_rag.retriever.base import BasePassageRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


def min_max_normalize(x):
    min_val = np.min(x)
    max_val = np.max(x)
    range_val = max_val - min_val
    # Handle the case where all values are the same (range is zero).
    if range_val == 0:
        return np.ones_like(x)  # Return an array of ones with the same shape as x
    return (x - min_val) / range_val


class HippoRAG2Retriever(BasePassageRetriever):
    def __init__(self, llm_generator: LLMGenerator,
                 sentence_encoder: BaseEmbeddingModel,
                 data: dict,
                 inference_config: Optional[InferenceConfig] = None,
                 logger=None,
                 **kwargs):
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        self.node_embeddings = data["node_embeddings"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.edge_embeddings = data["edge_embeddings"]
        self.text_embeddings = data["text_embeddings"]
        self.edge_faiss_index = data["edge_faiss_index"]

        self.passage_dict = data["text_dict"]
        self.text_id_list = list(self.passage_dict.keys())

        self.KG = data["KG"]
        self.KG = self.KG.subgraph(self.node_list + self.text_id_list)

        self.logger = logger
        self.logging = self.logger is not None

        hipporag2mode = "query2edge"
        if hipporag2mode == "query2edge":
            self.retrieve_node_fn = self.query2edge
        elif hipporag2mode == "query2node":
            self.retrieve_node_fn = self.query2node
        elif hipporag2mode == "ner2node":
            self.retrieve_node_fn = self.ner2node
        else:
            raise ValueError(f"Invalid mode: {hipporag2mode}. Choose from 'query2edge', 'query2node', or 'ner2node'.")

        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

        node_id_to_file_id = {}
        for node_id in tqdm(list(self.KG.nodes)):
            if self.inference_config.keyword == "musique" and self.KG.nodes[node_id]['type'] == "passage":
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["id"]
            else:
                node_id_to_file_id[node_id] = self.KG.nodes[node_id]["file_id"]
        self.node_id_to_file_id = node_id_to_file_id

    def ner(self, text):
        return self.llm_generator.ner(text)

    def ner2node(self, query, topN=10):
        entities = self.ner(query)
        entities = entities.split(", ")
        if len(entities) == 0:
            entities = [query]

        # Retrieve the most similar node for each extracted entity.
        topk_nodes = []
        node_score_dict = {}
        for entity in entities:
            topk_for_this_entity = 1
            entity_embedding = self.sentence_encoder.encode([entity], query_type="search")
            scores = min_max_normalize(self.node_embeddings @ entity_embedding[0].T)
            index_matrix = np.argsort(scores)[-topk_for_this_entity:][::-1]
            similarity_matrix = [scores[i] for i in index_matrix]
            for index, sim_score in zip(index_matrix, similarity_matrix):
                node = self.node_list[index]
                if node not in topk_nodes:
                    topk_nodes.append(node)
                    node_score_dict[node] = sim_score

        topk_nodes = list(set(topk_nodes))
        result_node_score_dict = {}
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        for node in topk_nodes:
            if node in node_score_dict:
                result_node_score_dict[node] = node_score_dict[node]
        return result_node_score_dict

    def query2node(self, query, topN=10):
        query_emb = self.sentence_encoder.encode([query], query_type="entity")
        scores = min_max_normalize(self.node_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        similarity_matrix = [scores[i] for i in index_matrix]

        result_node_score_dict = {}
        for index, sim_score in zip(index_matrix, similarity_matrix):
            node = self.node_list[index]
            result_node_score_dict[node] = sim_score
        return result_node_score_dict

    def query2edge(self, query, topN=10):
        query_emb = self.sentence_encoder.encode([query], query_type="edge")
        scores = min_max_normalize(self.edge_embeddings @ query_emb[0].T)
        index_matrix = np.argsort(scores)[-topN:][::-1]
        log_edge_list = []
        for index in index_matrix:
            edge = self.edge_list[index]
            edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
            log_edge_list.append(edge_str)

        similarity_matrix = [scores[i] for i in index_matrix]
        # Construct the candidate fact list for LLM filtering.
        before_filter_edge_json = {'fact': []}
        for index, sim_score in zip(index_matrix, similarity_matrix):
            edge = self.edge_list[index]
            edge_str = [self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']]
            before_filter_edge_json['fact'].append(edge_str)
        if self.logging:
            self.logger.info(f"HippoRAG2 Before Filter Edge: {before_filter_edge_json['fact']}")

        filtered_facts = self.llm_generator.filter_triples_with_entity_event(
            query, json.dumps(before_filter_edge_json, ensure_ascii=False))
        filtered_facts = json_repair.loads(filtered_facts)['fact']
        if len(filtered_facts) == 0:
            return {}

        # Map each filtered fact back to an edge id via FAISS, checking it against
        # the original candidate list, and keep the original similarity score for
        # the edge's head and tail nodes.
        node_score_dict = {}
        log_edge_list = []
        for edge in filtered_facts:
            edge_str = f'{edge[0]} {edge[1]} {edge[2]}'
            search_emb = self.sentence_encoder.encode([edge_str], query_type="search")
            D, I = self.edge_faiss_index.search(search_emb, 1)
            filtered_index = I[0][0]
            edge = self.edge_list[filtered_index]
            log_edge_list.append([self.KG.nodes[edge[0]]['id'], self.KG.edges[edge]['relation'], self.KG.nodes[edge[1]]['id']])

            head, tail = edge[0], edge[1]
            sim_score = scores[filtered_index]
            if head not in node_score_dict:
                node_score_dict[head] = [sim_score]
            else:
                node_score_dict[head].append(sim_score)
            if tail not in node_score_dict:
                node_score_dict[tail] = [sim_score]
            else:
                node_score_dict[tail].append(sim_score)

        if self.logging:
            self.logger.info(f"HippoRAG2: Filtered edges: {log_edge_list}")

        # Average the scores collected for each node.
        for node in node_score_dict:
            node_score_dict[node] = sum(node_score_dict[node]) / len(node_score_dict[node])
        return node_score_dict

    def query2passage(self, query, weight_adjust=0.05):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        sim_scores = min_max_normalize(sim_scores) * weight_adjust  # scaled to act as a small prior
        # Return a dict of passage id -> score.
        return dict(zip(self.text_id_list, sim_scores))

    def retrieve_personalization_dict(self, query, topN=30, weight_adjust=0.05):
        node_dict = self.retrieve_node_fn(query, topN=topN)
        text_dict = self.query2passage(query, weight_adjust=weight_adjust)
        return node_dict, text_dict

    def retrieve(self, query, topN=5, **kwargs):
        topN_edges = self.inference_config.topk_edges
        weight_adjust = self.inference_config.weight_adjust

        node_dict, text_dict = self.retrieve_personalization_dict(
            query, topN=topN_edges, weight_adjust=weight_adjust)

        personalization_dict = {}
        if len(node_dict) == 0:
            # No graph evidence survived filtering: fall back to dense passage
            # similarity and return the topN passages directly.
            sorted_passages = sorted(text_dict.items(), key=lambda x: x[1], reverse=True)
            sorted_passages = sorted_passages[:topN]
            sorted_passages_contents = []
            sorted_scores = []
            sorted_passage_ids = []
            for passage_id, score in sorted_passages:
                sorted_passages_contents.append(self.passage_dict[passage_id])
                sorted_scores.append(float(score))
                sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
            return sorted_passages_contents, sorted_passage_ids

        personalization_dict.update(node_dict)
        personalization_dict.update(text_dict)

        # Personalized PageRank over the combined entity + passage graph.
        pr = nx.pagerank(self.KG, personalization=personalization_dict,
                         alpha=self.inference_config.ppr_alpha,
                         max_iter=self.inference_config.ppr_max_iter,
                         tol=self.inference_config.ppr_tol)

        # Rank passages by their PageRank score, dropping zero-score nodes.
        text_dict_score = {}
        for node in self.text_id_list:
            if pr[node] > 0.0:
                text_dict_score[node] = pr[node]

        sorted_passages_ids = sorted(text_dict_score.items(), key=lambda x: x[1], reverse=True)
        sorted_passages_ids = sorted_passages_ids[:topN]

        sorted_passages_contents = []
        sorted_scores = []
        sorted_passage_ids = []
        for passage_id, score in sorted_passages_ids:
            sorted_passages_contents.append(self.passage_dict[passage_id])
            sorted_scores.append(score)
            sorted_passage_ids.append(self.node_id_to_file_id[passage_id])
        return sorted_passages_contents, sorted_passage_ids
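
A toy illustration (made-up numbers) of the personalization merge in retrieve(): filtered-edge node scores dominate, while every passage gets a small weight_adjust-scaled prior so PageRank can still reach passages outside the seeded neighborhood:

node_dict = {"n1": 0.9, "n2": 0.6}            # from query2edge after LLM filtering
text_dict = {"t1": 0.03, "t2": 0.01}          # min-max normalized * weight_adjust
personalization = {**node_dict, **text_dict}  # input to nx.pagerank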

23
atlas_rag/retriever/inference_config.py Normal file
View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass


@dataclass
class InferenceConfig:
    """
    Configuration class for inference settings.

    Attributes:
        keyword (str): Identifier of the benchmark/KG dataset. Default is "musique".
        topk (int): Number of top results to retrieve. Default is 5.
        Dmax (int): Maximum depth for path search. Default is 4.
        weight_adjust (float): Weight adjustment factor for passage retrieval. Default is 1.0.
        topk_edges (int): Number of top edges to retrieve. Default is 50.
        topk_nodes (int): Number of top nodes to retrieve. Default is 10.
        ppr_alpha (float): Damping factor for personalized PageRank. Default is 0.99.
        ppr_max_iter (int): Maximum number of PageRank iterations. Default is 2000.
        ppr_tol (float): Convergence tolerance for PageRank. Default is 1e-7.
    """
    keyword: str = "musique"
    topk: int = 5
    Dmax: int = 4
    weight_adjust: float = 1.0
    topk_edges: int = 50
    topk_nodes: int = 10
    ppr_alpha: float = 0.99
    ppr_max_iter: int = 2000
    ppr_tol: float = 1e-7
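
A construction sketch (field values are arbitrary illustrative overrides; `llm_generator`, `sentence_encoder` and `data` are assumed as above):

from atlas_rag.retriever.inference_config import InferenceConfig
from atlas_rag.retriever.hipporag2 import HippoRAG2Retriever

config = InferenceConfig(keyword="musique", topk_nodes=20, weight_adjust=0.05)
retriever = HippoRAG2Retriever(llm_generator, sentence_encoder, data,
                               inference_config=config)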

40
atlas_rag/retriever/lkg_retriever/base.py Normal file
View File

@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod


class BaseLargeKGRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve passages based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g. topN (number of top
                passages to retrieve), number_of_source_nodes_per_ner (source nodes
                per extracted entity) and sampling_area (graph sampling area).

        Returns:
            List of retrieved passages and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")


class BaseLargeKGEdgeRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve edges / paths based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g. topN,
                number_of_source_nodes_per_ner and sampling_area.

        Returns:
            List of retrieved paths and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

View File

@@ -0,0 +1,313 @@
from difflib import get_close_matches
from logging import Logger
import faiss
from neo4j import GraphDatabase
import time
import string
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from graphdatascience import GraphDataScience
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever


class LargeKGRetriever(BaseLargeKGRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, passage_index: faiss.Index,
                 topN: int = 5,
                 number_of_source_nodes_per_ner: int = 10,
                 sampling_area: int = 250, logger: Logger = None):
        # Instantiate resources for one KG.
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.gds_driver = GraphDataScience(self.neo4j_driver)
        self.topN = topN
        self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
        self.sampling_area = sampling_area

        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = node_index
        self.passage_faiss_index = passage_index

        self.verbose = False if logger is None else True
        self.logger = logger
        self.ppr_weight_threshold = 0.00005

    def set_model(self, model):
        if self.llm_generator.inference_type == 'openai':
            self.llm_generator.model_name = model
        else:
            raise ValueError("Model can only be set for OpenAI inference type.")

    def ner(self, text):
        return self.llm_generator.large_kg_ner(text)

    def convert_numeric_id_to_name(self, numeric_id):
        if numeric_id.isdigit():
            return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
        else:
            return numeric_id

    def has_intersection(self, word_set, input_string):
        cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
        if self.keyword == 'cc_en':
            # Check if any phrase in word_set is a substring of cleaned_string.
            for phrase in word_set:
                if phrase in cleaned_string:
                    return True
            return False
        else:
            # Check if any word in word_set is present in the cleaned_string's words.
            words_in_string = set(cleaned_string.split())
            return not word_set.isdisjoint(words_in_string)

    def pagerank(self, personalization_dict, topN=5, sampling_area=200):
        graph = self.gds_driver.graph.get('largekgrag_graph')
        node_count = graph.node_count()
        sampling_ratio = sampling_area / node_count
        aggregation_node_list = []
        ppr_weight_threshold = self.ppr_weight_threshold
        start_time = time.time()

        # Pre-filter nodes below the ppr_weight threshold and precompute the
        # word set used for name matching on 'cc_en'.
        if self.keyword == 'cc_en':
            filtered_personalization = {
                node_id: ppr_weight
                for node_id, ppr_weight in personalization_dict.items()
                if ppr_weight >= ppr_weight_threshold
            }
            stop_words = set(stopwords.words('english'))
            word_set = set()
            for node_id, ppr_weight in filtered_personalization.items():
                name = self.gds_driver.util.asNode(node_id)['name']
                if name:
                    cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
                    if cleaned_phrase:
                        # Remove stop words and keep the cleaned phrase.
                        filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
                        if filtered_words:
                            word_set.add(' '.join(filtered_words))
            if self.verbose:
                self.logger.info(f"Optimized word set: {word_set}")
        else:
            filtered_personalization = personalization_dict

        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
            self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
            self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")

        # Sample a subgraph around each seed node and run PageRank on it.
        for node_id, ppr_weight in filtered_personalization.items():
            try:
                self.gds_driver.graph.drop('rwr_sample')
                start_time = time.time()
                G_sample, _ = self.gds_driver.graph.sample.rwr(
                    "rwr_sample", graph, concurrency=4, samplingRatio=sampling_ratio,
                    startNodes=[node_id], restartProbability=0.4, logProgress=False)
                if self.verbose:
                    self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")

                start_time = time.time()
                result = self.gds_driver.pageRank.stream(
                    G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
                ).sort_values("score", ascending=False)
                if self.verbose:
                    self.logger.info(f"pagerank type: {type(result)}")
                    self.logger.info(f"pagerank result: {result}")
                    self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")

                start_time = time.time()
                if self.keyword != 'cc_en':
                    result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
                else:
                    result = result.to_dict('records')
                if self.verbose:
                    self.logger.info(f"largekgRAG : result: {result}")

                for entry in result:
                    if self.keyword == 'cc_en':
                        node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
                        if not self.has_intersection(word_set, node_name):
                            continue
                    numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
                    aggregation_node_list.append({
                        'nodeId': numeric_id,
                        'score': entry['score'] * ppr_weight
                    })
            except Exception as e:
                if self.verbose:
                    self.logger.error(f"Error processing node {node_id}: {e}")
                    self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
                continue

        aggregation_node_list = sorted(aggregation_node_list, key=lambda x: x['score'], reverse=True)[:25]
        if self.verbose:
            self.logger.info(f"Aggregation node list: {aggregation_node_list}")
            self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")

        start_time = time.time()
        with self.neo4j_driver.session() as session:
            intermediate_time = time.time()
            # Step 1: Distribute entity scores to connected text nodes and keep the topN.
            query_scores = """
            UNWIND $entries AS entry
            MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
            WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
            ORDER BY total_score DESC
            LIMIT $topN
            RETURN textId, total_score
            """
            result_scores = session.run(query_scores, entries=aggregation_node_list, topN=topN)
            top_numeric_ids = []
            top_scores = []
            for record in result_scores:
                top_numeric_ids.append(record["textId"])
                top_scores.append(record["total_score"])
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")

            # Step 2: Use the top numeric IDs to retrieve the original text.
            intermediate_time = time.time()
            query_text = """
            UNWIND $textIds AS textId
            MATCH (t:Text {numeric_id: textId})
            RETURN t.original_text AS text, t.numeric_id AS textId
            """
            result_texts = session.run(query_text, textIds=top_numeric_ids)
            topN_passages = []
            score_dict = dict(zip(top_numeric_ids, top_scores))
            # Combine original text with scores.
            for record in result_texts:
                original_text = record["text"]
                text_id = record["textId"]
                score = score_dict.get(text_id, 0)
                topN_passages.append((original_text, score))
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")

        # Sort passages by score.
        topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
        top_texts = [item[0] for item in topN_passages][:topN]
        top_scores = [item[1] for item in topN_passages][:topN]
        if self.verbose:
            self.logger.info(f"Total passages retrieved: {len(top_texts)}")
            self.logger.info(f"Top passages: {top_texts}")
            self.logger.info(f"Top scores: {top_scores}")
            self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
        return top_texts, top_scores

    def retrieve_topk_nodes(self, query, top_k_nodes=2):
        # Extract entities from the query.
        entities = self.ner(query)
        if self.verbose:
            self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
        if len(entities) == 0:
            entities = [query]
        num_entities = len(entities)

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
            initial_nodes += [str(i) for i in I[0]]
        if self.verbose:
            self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")

        name_id_map = {}
        for node_id in initial_nodes:
            name = self.convert_numeric_id_to_name(node_id)
            name_id_map[name] = node_id

        # Convert numeric ids to names, let the LLM filter them,
        # then map the surviving names back to numeric ids.
        keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
        filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)

        # Second pass: add the filtered keywords.
        filtered_top_k_nodes = []
        filter_log_dict = {}
        match_threshold = 0.8
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
        for keyword in filtered_keywords:
            # Check for an exact match first.
            if keyword in name_id_map:
                filtered_top_k_nodes.append(name_id_map[keyword])
                filter_log_dict[keyword] = name_id_map[keyword]
            else:
                # Look for close matches using difflib's get_close_matches.
                close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
                if close_matches:
                    # If a close match is found, add the corresponding node.
                    filtered_top_k_nodes.append(name_id_map[close_matches[0]])
                filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")

        topk_nodes = list(set(filtered_top_k_nodes))
        if len(topk_nodes) > 2 * num_entities:
            topk_nodes = topk_nodes[:2 * num_entities]
        return topk_nodes

    def _process_text(self, text):
        """Normalize text for containment checks (lowercase, alphanumeric + spaces)."""
        text = text.lower()
        text = ''.join([c for c in text if c.isalnum() or c.isspace()])
        return set(text.split())

    def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
        topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
        if topk_nodes == []:
            if self.verbose:
                self.logger.info(f"largekgRAG : No nodes found for query: {query}")
            return {}
        if self.verbose:
            self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")

        # Count the number of distinct source files for each node.
        freq_dict_for_nodes = {}
        query_text = """
        UNWIND $nodes AS node
        MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
        RETURN n1.numeric_id as numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
        """
        with self.neo4j_driver.session() as session:
            result = session.run(query_text, nodes=topk_nodes)
            for record in result:
                freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]

        # Create the personalization dictionary, weighting by inverse file count.
        personalization_dict = {
            self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id}): 1 / file_count
            for numeric_id, file_count in freq_dict_for_nodes.items()
        }
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict's number of node: {len(personalization_dict)}")
        return personalization_dict

    def retrieve_passages(self, query):
        if self.verbose:
            self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
        topN = self.topN
        number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
        sampling_area = self.sampling_area
        personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
        if personalization_dict == {}:
            return [], [0]
        topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area=sampling_area)
        return topN_passages, topN_scores
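
Hypothetical wiring for this retriever (URI, credentials and index paths are placeholders; `llm_generator` and `sentence_encoder` are assumed). It expects a Neo4j database with :Node and :Text labels, a GDS graph projection named 'largekgrag_graph', and FAISS indexes whose positions align with numeric_id values:

import faiss
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
node_index = faiss.read_index("node.index")        # placeholder paths
passage_index = faiss.read_index("passage.index")
retriever = LargeKGRetriever("cc_en", driver, llm_generator, sentence_encoder,
                             node_index, passage_index, topN=5)
passages, scores = retriever.retrieve_passages("Who founded Portland?")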

View File

@@ -0,0 +1,469 @@
from neo4j import GraphDatabase
import faiss
import numpy as np
from collections import defaultdict
from typing import List
import time
import logging
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')


class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 filter_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index,
                 topN: int = 5,
                 Dmax: int = 3,
                 Wmax: int = 3,
                 prune_size: int = 10,
                 logger: logging.Logger = None):
        """
        Initialize the LargeKGToGRetriever for billion-scale KG retrieval using Neo4j.

        Args:
            keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
            neo4j_driver (GraphDatabase): Neo4j driver for database access.
            llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
            sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
            filter_encoder (BaseEmbeddingModel): Encoder used to filter candidate paths.
            node_index (faiss.Index): FAISS index for node embeddings.
            topN (int): Number of paths retained after pruning.
            Dmax (int): Maximum depth of path expansion.
            Wmax (int): Maximum expansion width per node.
            prune_size (int): Number of paths rated per LLM call.
            logger (Logger, optional): Logger for verbose output.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.filter_encoder = filter_encoder
        self.node_faiss_index = node_index
        self.verbose = logger is not None
        self.logger = logger
        self.topN = topN
        self.Dmax = Dmax
        self.Wmax = Wmax
        self.prune_size = prune_size

    def ner(self, text: str) -> List[str]:
        """
        Extract named entities from the query text using the LLM.

        Args:
            text (str): The query text.

        Returns:
            List[str]: List of extracted entities.
        """
        entities = self.llm_generator.large_kg_ner(text)
        if self.verbose:
            self.logger.info(f"Extracted entities: {entities}")
        return entities

    def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5):
        """
        Retrieve top-k nodes similar to entities in the query.

        Args:
            query (str): The user query.
            top_k_nodes (int): Number of nodes to retrieve per entity.

        Returns:
            Tuple[List[str], List[str]]: Node numeric_ids and their names.
        """
        start_time = time.time()
        entities = self.ner(query)
        if self.verbose:
            ner_time = time.time() - start_time
            self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
        if not entities:
            entities = [query]

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, 3)
            if self.verbose:
                self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
            if len(I[0]) > 0:  # Check if results exist
                initial_nodes.extend([str(i) for i in I[0]])
        # No further filtering needed here: ToG pruning will handle it.
        topk_nodes_ids = list(set(initial_nodes))

        with self.neo4j_driver.session() as session:
            start_time = time.time()
            cypher = """
            MATCH (n:Node)
            WHERE n.numeric_id IN $topk_nodes_ids
            RETURN n.numeric_id AS id, n.name AS name
            """
            result = session.run(cypher, topk_nodes_ids=topk_nodes_ids)
            topk_nodes_dict = {}
            for record in result:
                topk_nodes_dict[record["id"]] = record["name"]
            if self.verbose:
                neo4j_time = time.time() - start_time
                self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")
        if self.verbose:
            self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
        return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values())

    def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str):
        """
        Expand each path by adding neighbors of its last node.

        Args:
            P (List[List[str]]): Current paths, each a list of alternating node names and relation types.
            PIDS (List[List[str]]): numeric_ids of the nodes on each path.
            PTYPES (List[List[str]]): Types ('Node' or 'Text') of the nodes on each path.
            width (int): Maximum number of expansions kept per frontier node.
            query (str): The user query, used for embedding-based filtering.

        Returns:
            Tuple of expanded paths, their ids, and their types.
        """
        last_node_ids = []
        paths_end_with_text = []
        paths_end_with_text_id = []
        paths_end_with_text_type = []
        if self.verbose:
            self.logger.info(f"Expanding paths, current paths: {P}")

        for p, pid, ptype in zip(P, PIDS, PTYPES):
            if not p or not pid or not ptype:  # Skip empty paths
                continue
            if ptype[-1] == "Text":
                # Paths that already end in a passage are carried over unchanged.
                paths_end_with_text.append(p)
                paths_end_with_text_id.append(pid)
                paths_end_with_text_type.append(ptype)
                continue
            last_node_ids.append(pid[-1])
        if not last_node_ids:
            return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type

        with self.neo4j_driver.session() as session:
            # Query Node-Node relationships around the frontier nodes,
            # time-boxed and randomly capped to bound the expansion cost.
            start_time = time.time()
            outgoing_query = """
            CALL apoc.cypher.runTimeboxed(
            "MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
            WITH n, r, m ORDER BY rand() LIMIT 60000
            RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
            {last_node_ids: $last_node_ids},
            60000
            )
            YIELD value
            RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
            """
            outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
            outgoing = [(record["source"], record['source_name'], record["rel_type"], record["target"], record["target_name"], record["target_type"])
                        for record in outgoing_result]
            if self.verbose:
                outgoing_time = time.time() - start_time
                self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")

        # NOTE: Node -> Text and incoming-edge expansions were drafted here but
        # are currently disabled; only Node -> Node expansion is used.
        stop_words = set(stopwords.words('english'))  # precomputed once per expansion
        last_node_to_new_paths = defaultdict(list)
        last_node_to_new_paths_ids = defaultdict(list)
        last_node_to_new_paths_types = defaultdict(list)
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            last_node = p[-1]
            last_node_id = pid[-1]
            for source, source_name, rel_type, target, target_name, target_type in outgoing:
                if source == last_node_id and target_name not in p:
                    if target_name.lower() in stop_words:
                        continue
                    last_node_to_new_paths[last_node].append(p + [rel_type, target_name])
                    last_node_to_new_paths_ids[last_node].append(pid + [target])
                    last_node_to_new_paths_types[last_node].append(ptype + [target_type])

        num_paths = 0
        for last_node, new_paths in last_node_to_new_paths.items():
            num_paths += len(new_paths)

        new_paths = []
        new_pids = []
        new_ptypes = []
        if self.verbose:
            self.logger.info(f"Number of new paths before filtering: {num_paths}")
            self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")

        if num_paths > len(last_node_ids) * width:
            # Apply embedding-based filtering when total paths exceed the threshold.
            for last_node, new_ps in last_node_to_new_paths.items():
                if len(new_ps) > width:
                    path_embeddings = self.filter_encoder.encode([" ".join(map(str, p)) for p in new_ps])
                    query_embeddings = self.filter_encoder.encode([query])
                    scores = np.dot(path_embeddings, query_embeddings.T).flatten()
                    top_indices = np.argsort(scores)[-width:]
                    new_paths.extend([new_ps[i] for i in top_indices])
                    new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
                    new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
                else:
                    new_paths.extend(new_ps)
                    new_pids.extend(last_node_to_new_paths_ids[last_node])
                    new_ptypes.extend(last_node_to_new_paths_types[last_node])
        else:
            # Keep all paths when the total is at or below the threshold.
            for last_node, new_ps in last_node_to_new_paths.items():
                new_paths.extend(new_ps)
                new_pids.extend(last_node_to_new_paths_ids[last_node])
                new_ptypes.extend(last_node_to_new_paths_types[last_node])

        if self.verbose:
            self.logger.info(f"Expanded paths count: {len(new_paths)}")
            self.logger.info(f"Expanded paths: {new_paths}")
        return new_paths, new_pids, new_ptypes

    def path_to_string(self, path: List[str]) -> str:
        """
        Convert a path of alternating node numeric_ids and relation types
        into a human-readable string for LLM rating.

        Args:
            path (List[str]): Path as a list of node_ids and relation types.

        Returns:
            str: String representation of the path.
        """
        if len(path) < 1:
            return ""
        path_str = []
        with self.neo4j_driver.session() as session:
            for i in range(0, len(path), 2):
                node_id = path[i]
                result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
                record = result.single()  # single() consumes the result, so call it only once
                node_name = record["n.name"] if record else node_id
                if i + 1 < len(path):
                    rel_type = path[i + 1]
                    path_str.append(f"{node_name} ---> {rel_type} --->")
                else:
                    path_str.append(node_name)
        return " ".join(path_str).strip()

    def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5):
        """
        Prune paths to keep the top N based on LLM relevance ratings.

        Args:
            query (str): The user query.
            P (List[List[str]]): List of paths to prune.
            topN (int): Number of paths to retain.

        Returns:
            Tuple of the top N paths, their ids, and their types.
        """
        ratings = []
        # Paths already hold node names, so render each one as a single string.
        path_strings = [" ---> ".join(map(str, p)) for p in P]

        # Rate paths in chunks of prune_size.
        for i in range(0, len(path_strings), self.prune_size):
            chunk = path_strings[i:i + self.prune_size]
            # Construct the user prompt with the current chunk of paths listed.
            user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
            for j, path_str in enumerate(chunk, 1):
                user_prompt += f"{j + i}. {path_str}\n"
            user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
            # The system prompt constrains the response to a list of integers.
            system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
            if self.verbose:
                self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")

            # Parse the response into a list of ratings.
            rating_str = response.strip()
            chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
            if len(chunk_ratings) > len(chunk):
                chunk_ratings = chunk_ratings[:len(chunk)]
                if self.verbose:
                    self.logger.warning(f"Received more ratings than paths in chunk ({len(chunk)}). Trimming ratings.")
            ratings.extend(chunk_ratings)

        if len(ratings) < len(path_strings):
            # The LLM returned too few ratings: fall back to the filter encoder
            # and keep the topN paths by embedding similarity.
            self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
            path_embeddings = self.filter_encoder.encode(path_strings)
            query_embedding = self.filter_encoder.encode([query])
            scores = np.dot(path_embeddings, query_embedding.T).flatten()
            top_indices = np.argsort(scores)[-topN:]
            return [P[i] for i in top_indices], [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
        elif len(ratings) > len(path_strings):
            self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
            ratings = ratings[:len(path_strings)]

        # Sort indices by rating (descending), drop zero-rated paths, keep topN.
        sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
        filtered_indices = [i for i in sorted_indices if ratings[i] > 0]
        top_indices = filtered_indices[:topN]
        if self.verbose:
            self.logger.info(f"Top indices after pruning: {top_indices}")
            self.logger.info(f"length of path_strings: {len(path_strings)}")
        top_paths = [P[i] for i in top_indices]
        top_pids = [PIDS[i] for i in top_indices]
        top_ptypes = [PTYPES[i] for i in top_indices]
        if self.verbose:
            self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
        return top_paths, top_pids, top_ptypes

    def reasoning(self, query: str, P: List[List[str]]) -> bool:
        """
        Ask the LLM whether the current paths suffice to answer the query.

        Args:
            query (str): The user query.
            P (List[List[str]]): Current list of paths.

        Returns:
            bool: True if sufficient, False otherwise.
        """
        # Paths already contain node names, so no database lookup is needed.
        triples = []
        for path in P:
            if len(path) < 3:
                continue
            for i in range(0, len(path) - 2, 2):
                triples.append(f"({path[i]}, {path[i + 1]}, {path[i + 2]})")
        triples_str = ". ".join(triples)
        prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
        messages = [
            {"role": "system", "content": "Answer Yes or No only."},
            {"role": "user", "content": prompt}
        ]
        response = self.llm_generator.generate_response(messages, max_new_tokens=512)
        if self.verbose:
            self.logger.info(f"Reasoning result: {response}")
        return "yes" in response.lower()

    def retrieve_passages(self, query: str):
        """
        Retrieve the top N paths that help answer the query.

        Args:
            query (str): The user query.

        Returns:
            Tuple[List[str], str]: Triples as strings, plus 'N/A' in place of
            passage scores, since this retriever returns paths rather than passages.
        """
        topN = self.topN
        Dmax = self.Dmax
        Wmax = self.Wmax
        if self.verbose:
            self.logger.info(f"Retrieving paths for query: {query}")

        initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
        if not initial_nodes:
            if self.verbose:
                self.logger.info("No initial nodes found.")
            return []

        P = [[node] for node in initial_nodes]
        PIDS = [[node_id] for node_id in initial_nodes_ids]
        PTYPES = [["Node"] for _ in initial_nodes_ids]  # All initial nodes are of type 'Node'

        for D in range(Dmax + 1):
            if self.verbose:
                self.logger.info(f"Depth {D}, Current paths: {len(P)}")
            P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
            if not P:
                if self.verbose:
                    self.logger.info("No paths to expand.")
                break
            P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
            if D == Dmax:
                if self.verbose:
                    self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
                break
            if self.reasoning(query, P):
                if self.verbose:
                    self.logger.info("Paths sufficient, stopping expansion.")
                break

        # Extract the final triples from the retained paths.
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append(f"({path[i]}, {path[i + 1]}, {path[i + 2]})")
        if self.verbose:
            self.logger.info(f"Final triples: {triples}")
        return triples, 'N/A'  # 'N/A': this retriever returns paths, not scored passages
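
A construction sketch (same assumed Neo4j driver, encoders and FAISS index as above): Dmax bounds the path depth, Wmax the per-node beam width, and prune_size the batch size of the LLM rating calls:

tog_retriever = LargeKGToGRetriever("cc_en", driver, llm_generator,
                                    sentence_encoder, filter_encoder,
                                    node_index, topN=5, Dmax=3, Wmax=3)
triples, _ = tog_retriever.retrieve_passages("Where is Portland International Airport located?")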

51
atlas_rag/retriever/simple_retriever.py Normal file
View File

@@ -0,0 +1,51 @@
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever


class SimpleGraphRetriever(BaseEdgeRetriever):
    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict):
        self.KG = data["KG"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = data["node_faiss_index"]
        self.edge_faiss_index = data["edge_faiss_index"]

    def retrieve(self, query, topN=5, **kwargs):
        # Retrieve the top k edges by embedding similarity.
        topk_edges = []
        query_embedding = self.sentence_encoder.encode([query], query_type='edge')
        D, I = self.edge_faiss_index.search(query_embedding, topN)
        topk_edges += [self.edge_list[i] for i in I[0]]

        topk_edges_with_data = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in topk_edges]
        string_edges = [f"{self.KG.nodes[edge[0]]['id']} {edge[1]} {self.KG.nodes[edge[2]]['id']}" for edge in topk_edges_with_data]
        return string_edges, ["N/A" for _ in range(len(string_edges))]


class SimpleTextRetriever(BasePassageRetriever):
    def __init__(self, passage_dict: Dict[str, str], sentence_encoder: BaseEmbeddingModel, data: dict):
        self.sentence_encoder = sentence_encoder
        self.passage_dict = passage_dict
        self.passage_list = list(passage_dict.values())
        self.passage_keys = list(passage_dict.keys())
        self.text_embeddings = data["text_embeddings"]

    def retrieve(self, query, topN=5, **kwargs):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T
        topk_indices = np.argsort(sim_scores)[-topN:][::-1]  # indices of top-k scores
        # Retrieve the top-k passages and their ids.
        topk_passages = [self.passage_list[i] for i in topk_indices]
        topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
        return topk_passages, topk_passages_ids
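
A self-contained sketch of the dense fallback retriever (the encoder is assumed to expose encode(list[str], query_type=...) as used above; text_embeddings must be row-aligned with the passage dict's insertion order):

passage_dict = {"p1": "Portland is a city in Oregon.",
                "p2": "Paris is the capital of France."}
data = {"text_embeddings": sentence_encoder.encode(list(passage_dict.values()))}
retriever = SimpleTextRetriever(passage_dict, sentence_encoder, data)
passages, ids = retriever.retrieve("Which state is Portland in?", topN=1)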

195
atlas_rag/retriever/tog.py Normal file
View File

@@ -0,0 +1,195 @@
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from typing import Optional
from atlas_rag.retriever.base import BaseEdgeRetriever
from atlas_rag.retriever.inference_config import InferenceConfig


class TogRetriever(BaseEdgeRetriever):
    def __init__(self, llm_generator, sentence_encoder, data, inference_config: Optional[InferenceConfig] = None):
        self.KG = data["KG"]
        self.node_list = list(self.KG.nodes)
        self.edge_list = list(self.KG.edges)
        self.edge_list_with_relation = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in self.edge_list]
        self.edge_list_string = [f"{edge[0]} {self.KG.edges[edge]['relation']} {edge[1]}" for edge in self.edge_list]
        self.llm_generator: LLMGenerator = llm_generator
        self.sentence_encoder: BaseEmbeddingModel = sentence_encoder
        self.node_embeddings = data["node_embeddings"]
        self.edge_embeddings = data["edge_embeddings"]
        self.inference_config = inference_config if inference_config is not None else InferenceConfig()

    def ner(self, text):
        # One-shot NER prompt; the example answer is supplied as an assistant turn.
        messages = [
            {"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
            {"role": "user", "content": "Extract the named entities from: Are Portland International Airport and Gerald R. Ford International Airport both located in Oregon?"},
            {"role": "assistant", "content": "Portland International Airport, Gerald R. Ford International Airport, Oregon"},
            {"role": "user", "content": f"Extract the named entities from: {text}"},
        ]
        response = self.llm_generator.generate_response(messages)
        return response

    def retrieve_topk_nodes(self, query, topN=5, **kwargs):
        # Extract entities from the query.
        entities = self.ner(query)
        entities = entities.split(", ")
        if len(entities) == 0:
            # If NER cannot extract any entities, use the query itself
            # and fall back to approximate search.
            entities = [query]

        # Evenly distribute the topk budget across entities.
        topk_for_each_entity = topN // len(entities)

        # Retrieve the top k nodes: exact name matches first, then
        # embedding-based nearest neighbors for every entity.
        topk_nodes = []
        for entity in entities:
            if entity in self.node_list:
                topk_nodes.append(entity)

        for entity in entities:
            topk_for_this_entity = topk_for_each_entity + 1
            entity_embedding = self.sentence_encoder.encode([entity])
            # Similarity scores via dot product.
            scores = self.node_embeddings @ entity_embedding[0].T
            top_indices = np.argsort(scores)[-topk_for_this_entity:][::-1]
            topk_nodes += [self.node_list[i] for i in top_indices]

        topk_nodes = list(set(topk_nodes))
        if len(topk_nodes) > 2 * topN:
            topk_nodes = topk_nodes[:2 * topN]
        return topk_nodes

    def retrieve(self, query, topN=5, **kwargs):
        """
        Retrieve the top N paths that connect the entities in the query.
        Dmax is the maximum depth of the search.
        """
        Dmax = self.inference_config.Dmax

        # First, retrieve the top k seed nodes.
        initial_nodes = self.retrieve_topk_nodes(query, topN=topN)
        E = initial_nodes
        P = [[e] for e in E]
        D = 0

        while D <= Dmax:
            P = self.search(query, P)
            P = self.prune(query, P, topN)
            if self.reasoning(query, P):
                generated_text = self.generate(query, P)
                break
            D += 1
        if D > Dmax:
            generated_text = self.generate(query, P)
        return generated_text

    def search(self, query, P):
        new_paths = []
        for path in P:
            tail_entity = path[-1]
            successors = list(self.KG.successors(tail_entity))
            predecessors = list(self.KG.predecessors(tail_entity))

            # Remove entities that are already on the path.
            successors = [neighbour for neighbour in successors if neighbour not in path]
            predecessors = [neighbour for neighbour in predecessors if neighbour not in path]

            if len(successors) == 0 and len(predecessors) == 0:
                new_paths.append(path)
                continue
            for neighbour in successors:
                relation = self.KG.edges[(tail_entity, neighbour)]["relation"]
                new_paths.append(path + [relation, neighbour])
            for neighbour in predecessors:
                relation = self.KG.edges[(neighbour, tail_entity)]["relation"]
                new_paths.append(path + [relation, neighbour])
        return new_paths

    def prune(self, query, P, topN=3):
        ratings = []
        for path in P:
            path_string = ""
            for index, node_or_relation in enumerate(path):
                if index % 2 == 0:
                    id_path = self.KG.nodes[node_or_relation]["id"]
                else:
                    id_path = node_or_relation
                path_string += f"{id_path} --->"
            path_string = path_string[:-5]

            prompt = f"Please rate the following path based on its relevance to the question, on a scale of 1 to 5: 1 for least relevant and 5 for most relevant, or 0 if it is not relevant at all. Only provide the rating, do not provide any other information. The output should be a single integer number.\n Query: {query} \n path: {path_string}"
            messages = [{"role": "system", "content": "Answer the question following the prompt."},
                        {"role": "user", "content": f"{prompt}"}]
            response = self.llm_generator.generate_response(messages)
            try:
                rating = int(response.strip())
            except ValueError:
                rating = 0  # treat unparsable responses as not relevant
            ratings.append(rating)

        # Sort the paths by rating, highest first.
        sorted_paths = [path for _, path in sorted(zip(ratings, P), key=lambda x: x[0], reverse=True)]
        return sorted_paths[:topN]

    def reasoning(self, query, P):
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))
        triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
        triples_string = ". ".join(triples_string)

        prompt = f"Given a question and the associated retrieved knowledge graph triples (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triples and your knowledge (Yes or No). Query: {query} \n Knowledge triples: {triples_string}"
        messages = [{"role": "system", "content": "Answer the question following the prompt."},
                    {"role": "user", "content": f"{prompt}"}]
        response = self.llm_generator.generate_response(messages)
        return "yes" in response.lower()

    def generate(self, query, P):
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                triples.append((self.KG.nodes[path[i]]["id"], path[i + 1], self.KG.nodes[path[i + 2]]["id"]))
        triples_string = [f"({triple[0]}, {triple[1]}, {triple[2]})" for triple in triples]
        return triples_string, ["N/A" for _ in range(len(triples_string))]
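
A usage sketch (assumed objects as for the other retrievers; the `data` dict here needs the in-memory networkx KG plus node and edge embeddings):

from atlas_rag.retriever.tog import TogRetriever
from atlas_rag.retriever.inference_config import InferenceConfig

tog = TogRetriever(llm_generator, sentence_encoder, data,
                   inference_config=InferenceConfig(Dmax=3))
triples, scores = tog.retrieve("Are both airports located in Oregon?", topN=5)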