# File snapshot: 2025-09-24 09:29:12 +08:00
# 313 lines, 15 KiB, Python

from difflib import get_close_matches
from logging import Logger
import faiss
from neo4j import GraphDatabase
import time
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from graphdatascience import GraphDataScience
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import string
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
class LargeKGRetriever(BaseLargeKGRetriever):
    """Passage retriever over a large knowledge graph stored in Neo4j.

    ``retrieve_passages`` runs the whole pipeline:

    1. LLM NER on the query, then a FAISS search over node embeddings to
       collect candidate graph nodes (``retrieve_topk_nodes``).
    2. An LLM filter keeps the candidates relevant to the query; each kept
       node is weighted by ``1 / (distinct text_id values of its linked Text
       nodes)`` (``retrieve_personalization_dict``).
    3. Each weighted seed gets a personalized PageRank run on a
       random-walk-with-restart sample of the GDS projection
       ``largekgrag_graph``; node scores are summed onto linked ``Text``
       nodes and the best passages are returned (``pagerank``).
    """

    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, passage_index: faiss.Index,
                 topN: int = 5,
                 number_of_source_nodes_per_ner: int = 10,
                 sampling_area: int = 250, logger: Logger = None):
        """Instantiate the resources for one knowledge graph.

        Args:
            keyword: Identifier of the KG. ``'cc_en'`` switches ``pagerank``
                to weight-thresholded seeds and phrase-based result filtering.
            neo4j_driver: Neo4j driver used for Cypher queries; it is also
                wrapped in a ``GraphDataScience`` client.
            llm_generator: LLM wrapper used for NER and keyword filtering.
            sentence_encoder: Embedding model used to encode query entities.
            node_index: FAISS index over node embeddings; its row ids are used
                as the nodes' ``numeric_id`` values.
            passage_index: FAISS index over passages (stored on the instance).
            topN: Number of passages returned by ``retrieve_passages``.
            number_of_source_nodes_per_ner: FAISS neighbours fetched per entity.
            sampling_area: Approximate number of nodes per PageRank sample.
            logger: Optional logger; verbose logging is enabled iff given.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.gds_driver = GraphDataScience(self.neo4j_driver)
        self.topN = topN
        self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
        self.sampling_area = sampling_area
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = node_index
        self.passage_faiss_index = passage_index
        self.verbose = logger is not None
        self.logger = logger
        # For the 'cc_en' graph, seeds lighter than this are ignored.
        self.ppr_weight_threshold = 0.00005

    def set_model(self, model):
        """Switch the LLM model name; only supported for OpenAI inference.

        Raises:
            ValueError: If the generator's inference type is not 'openai'.
        """
        if self.llm_generator.inference_type == 'openai':
            self.llm_generator.model_name = model
        else:
            raise ValueError("Model can only be set for OpenAI inference type.")

    def ner(self, text):
        """Extract entities from ``text`` with the LLM generator."""
        return self.llm_generator.large_kg_ner(text)

    def convert_numeric_id_to_name(self, numeric_id):
        """Map a digit-string ``numeric_id`` to its node's ``name``.

        Non-digit strings are returned unchanged. The name may be ``None``
        when the node has no ``name`` property.
        """
        if numeric_id.isdigit():
            return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
        else:
            return numeric_id

    def has_intersection(self, word_set, input_string):
        """Return True if ``input_string`` matches any entry of ``word_set``.

        The string is lower-cased and stripped of ASCII punctuation first.
        For the 'cc_en' graph an entry matches when it is a substring of the
        cleaned string; otherwise when it equals one of its whitespace-separated
        words. Empty or ``None`` strings never match.
        """
        if not input_string:
            # Nodes without a name used to raise here and abort the caller's loop.
            return False
        cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
        if self.keyword == 'cc_en':
            return any(phrase in cleaned_string for phrase in word_set)
        return not word_set.isdisjoint(cleaned_string.split())

    def _build_phrase_filter(self, personalization_dict):
        """Apply the 'cc_en' seed threshold and build the phrase filter set.

        Returns ``(filtered_personalization, phrases)``: the seeds whose weight
        is at least ``ppr_weight_threshold``, and their names lower-cased and
        stripped of punctuation and English stop words (empty results dropped).
        """
        threshold = self.ppr_weight_threshold
        filtered = {
            node_id: weight
            for node_id, weight in personalization_dict.items()
            if weight >= threshold
        }
        stop_words = set(stopwords.words('english'))
        strip_punctuation = str.maketrans('', '', string.punctuation)
        phrases = set()
        for node_id in filtered:
            name = self.gds_driver.util.asNode(node_id)['name']
            if not name:
                continue
            words = [
                word for word in name.translate(strip_punctuation).lower().split()
                if word not in stop_words
            ]
            if words:
                phrases.add(' '.join(words))
        if self.verbose:
            self.logger.info(f"Optimized word set: {phrases}")
        return filtered, phrases

    def pagerank(self, personalization_dict, topN=5, sampling_area=200):
        """Rank passages by personalized PageRank around the seed nodes.

        For each seed, a random-walk-with-restart sample of the
        ``largekgrag_graph`` projection is drawn around it and PageRank is run
        on the sample with the seed as the only source node. Every node score
        is multiplied by the seed's weight; the 25 best entries across all
        seeds are then summed onto linked ``Text`` nodes in Neo4j.

        For ``keyword == 'cc_en'``, seeds below ``ppr_weight_threshold`` are
        skipped, all PageRank results are considered, and a node is kept only
        if its name contains one of the seeds' stop-word-free names. For other
        graphs only the 50 best positive-score nodes per seed are kept.

        Args:
            personalization_dict: Maps GDS internal node ids to seed weights.
            topN: Maximum number of passages to return.
            sampling_area: Approximate number of nodes in each sample; it is
                converted to a ratio of the projection's node count.

        Returns:
            ``(texts, scores)``: parallel lists, highest score first.
        """
        graph = self.gds_driver.graph.get('largekgrag_graph')
        node_count = graph.node_count()
        # RWR sampling takes a ratio in (0, 1]; clamp so a sampling area at
        # least as large as the graph samples all of it instead of failing,
        # and an empty projection no longer divides by zero.
        sampling_ratio = min(1.0, sampling_area / node_count) if node_count else 1.0
        ppr_weight_threshold = self.ppr_weight_threshold
        ppr_start = time.time()
        is_cc_en = self.keyword == 'cc_en'

        if is_cc_en:
            filtered_personalization, word_set = self._build_phrase_filter(personalization_dict)
        else:
            filtered_personalization, word_set = personalization_dict, set()
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
            self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
            self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")

        aggregation_nodes = []  # {'nodeId': numeric_id, 'score': weighted score}
        for node_id, ppr_weight in filtered_personalization.items():
            try:
                # A single catalog name is reused for every per-seed sample.
                self.gds_driver.graph.drop('rwr_sample')
                step_start = time.time()
                G_sample, _ = self.gds_driver.graph.sample.rwr(
                    "rwr_sample", graph, concurrency=4, samplingRatio=sampling_ratio,
                    startNodes=[node_id], restartProbability=0.4, logProgress=False,
                )
                if self.verbose:
                    self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - step_start:.2f} seconds")
                step_start = time.time()
                result = self.gds_driver.pageRank.stream(
                    G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
                ).sort_values("score", ascending=False)
                if self.verbose:
                    self.logger.info(f"pagerank type: {type(result)}")
                    self.logger.info(f"pagerank result: {result}")
                    self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - step_start:.2f} seconds")
                if is_cc_en:
                    records = result.to_dict('records')
                else:
                    records = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
                if self.verbose:
                    self.logger.info(f"largekgRAG :result: {records}")
                for entry in records:
                    # One lookup per node supplies both name and numeric_id.
                    node = self.gds_driver.util.asNode(entry['nodeId'])
                    if is_cc_en and not self.has_intersection(word_set, node['name']):
                        continue
                    aggregation_nodes.append({
                        'nodeId': node['numeric_id'],
                        'score': entry['score'] * ppr_weight,
                    })
            except Exception as e:
                # Best effort: a failing seed must not abort the whole retrieval.
                if self.verbose:
                    self.logger.error(f"Error processing node {node_id}: {e}")
                    try:
                        seed_name = self.gds_driver.util.asNode(node_id)['name']
                    except Exception:
                        seed_name = node_id
                    self.logger.error(f"Node is filtered out: {seed_name}")

        aggregation_nodes = sorted(aggregation_nodes, key=lambda x: x['score'], reverse=True)[:25]
        if self.verbose:
            self.logger.info(f"Aggregation node dict: {aggregation_nodes}")
            self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - ppr_start:.2f} seconds")
        return self._fetch_top_passages(aggregation_nodes, topN)

    def _fetch_top_passages(self, aggregation_nodes, topN):
        """Sum entity scores onto linked Text nodes and load the best passages.

        Returns ``(texts, scores)`` for at most ``topN`` Text nodes, highest
        aggregated score first.
        """
        start_time = time.time()
        with self.neo4j_driver.session() as session:
            intermediate_time = time.time()
            # Step 1: each entry's score is added to every Text node its entity
            # is a Source of; keep the topN Text nodes by total score.
            query_scores = """
            UNWIND $entries AS entry
            MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
            WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
            ORDER BY total_score DESC
            LIMIT $topN
            RETURN textId, total_score
            """
            score_dict = {}
            for record in session.run(query_scores, entries=aggregation_nodes, topN=topN):
                score_dict[record["textId"]] = record["total_score"]
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")
            intermediate_time = time.time()
            # Step 2: load the original text of the selected Text nodes.
            query_text = """
            UNWIND $textIds AS textId
            MATCH (t:Text {numeric_id: textId})
            RETURN t.original_text AS text, t.numeric_id AS textId
            """
            result_texts = session.run(query_text, textIds=list(score_dict))
            topN_passages = [
                (record["text"], score_dict.get(record["textId"], 0))
                for record in result_texts
            ]
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")
        # Query 2 returns rows in arbitrary order; re-rank by score.
        topN_passages.sort(key=lambda x: x[1], reverse=True)
        top_passages = topN_passages[:topN]
        top_texts = [text for text, _ in top_passages]
        top_scores = [score for _, score in top_passages]
        if self.verbose:
            self.logger.info(f"Total passages retrieved: {len(top_texts)}")
            self.logger.info(f"Top passages: {top_texts}")
            self.logger.info(f"Top scores: {top_scores}")
            self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
        return top_texts, top_scores

    def retrieve_topk_nodes(self, query, top_k_nodes = 2):
        """Find the graph nodes most relevant to ``query``.

        Entities extracted by the LLM (or the whole query if none) are
        embedded and searched in the node FAISS index. The resulting node
        names are filtered by the LLM and mapped back to node ids, by exact
        name or by closest name with similarity >= 0.8.

        Args:
            query: The user query.
            top_k_nodes: FAISS neighbours fetched per entity.

        Returns:
            Up to ``2 * number of entities`` distinct numeric-id strings, in
            the order the LLM filter returned their keywords.
        """
        entities = self.ner(query)
        if self.verbose:
            self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
        if not entities:
            entities = [query]
        num_entities = len(entities)

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
            # FAISS pads missing neighbours with -1; those are not graph nodes.
            initial_nodes += [str(i) for i in I[0] if i >= 0]
        if self.verbose:
            self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")

        # Resolve each candidate's name once. Nodes without a name cannot be
        # matched against keywords and are dropped.
        candidates = []
        for node_id in initial_nodes:
            name = self.convert_numeric_id_to_name(node_id)
            if name is not None:
                candidates.append((name, node_id))
        name_id_map = dict(candidates)  # on duplicate names the last node id wins
        keywords_before_filter = [name for name, _ in candidates]
        filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter) or []

        match_threshold = 0.8
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
        filtered_top_k_nodes = []
        filter_log_dict = {}
        for keyword in filtered_keywords:
            if keyword in name_id_map:
                matched = keyword
            else:
                # Fall back to the most similar candidate name above the cutoff.
                close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
                matched = close_matches[0] if close_matches else None
            if matched is None:
                filter_log_dict[keyword] = None
                continue
            filtered_top_k_nodes.append(name_id_map[matched])
            filter_log_dict[keyword] = name_id_map[matched]
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")

        # Deduplicate keeping first-seen order so the truncation below is
        # deterministic (set() iteration order of str varies between runs).
        topk_nodes = list(dict.fromkeys(filtered_top_k_nodes))
        return topk_nodes[:2 * num_entities]

    def _process_text(self, text):
        """Normalize text for containment checks (lowercase, alphanumeric+spaces)"""
        text = text.lower()
        text = ''.join([c for c in text if c.isalnum() or c.isspace()])
        return set(text.split())

    def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
        """Build PageRank seed weights for the nodes relevant to ``query``.

        Each selected node gets weight ``1 / fileCount``, where ``fileCount``
        is the number of distinct ``text_id`` values among the Text nodes it
        is linked to by a ``Source`` relationship. Nodes with a zero count
        are skipped.

        Returns:
            Dict mapping GDS internal node ids to weights; empty if no node
            was found.
        """
        topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
        if not topk_nodes:
            if self.verbose:
                self.logger.info(f"largekgRAG : No nodes found for query: {query}")
            return {}
        if self.verbose:
            self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")

        # Named `cypher` so it does not shadow the `query` argument.
        cypher = """
        UNWIND $nodes AS node
        MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
        RETURN n1.numeric_id as numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
        """
        freq_dict_for_nodes = {}
        with self.neo4j_driver.session() as session:
            for record in session.run(cypher, nodes=topk_nodes):
                freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]

        # Weight inversely to passage frequency; a zero count (e.g. Text nodes
        # without text_id) would otherwise raise ZeroDivisionError.
        personalization_dict = {
            self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id}): 1 / file_count
            for numeric_id, file_count in freq_dict_for_nodes.items()
            if file_count
        }
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict's number of node: {len(personalization_dict)}")
        return personalization_dict

    def retrieve_passages(self, query):
        """Retrieve the top ``self.topN`` passages and their scores for ``query``.

        Returns:
            ``(passages, scores)`` from ``pagerank``. When no seed node is
            found the result is ``([], [0])`` (note: scores is ``[0]``, not
            empty).
        """
        if self.verbose:
            self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
        personalization_dict = self.retrieve_personalization_dict(query, self.number_of_source_nodes_per_ner)
        if not personalization_dict:
            return [], [0]
        return self.pagerank(personalization_dict, self.topN, sampling_area=self.sampling_area)