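"""LargeKGRetriever: passage retrieval over a large knowledge graph.

Links query entities to graph nodes with FAISS, builds a personalized PageRank
seed distribution, runs PageRank on sampled subgraphs via the Neo4j Graph Data
Science client, and returns the top-scoring passages from connected Text nodes.
"""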
import string
import time
from difflib import get_close_matches
from logging import Logger

import faiss
import nltk
from graphdatascience import GraphDataScience
from neo4j import GraphDatabase
from nltk.corpus import stopwords

from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel

nltk.download('stopwords')

class LargeKGRetriever(BaseLargeKGRetriever):
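    """Retriever for a single large knowledge graph stored in Neo4j.

    The retrieval pipeline is: LLM NER on the query -> FAISS search over node
    embeddings -> LLM keyword filtering -> personalization dictionary ->
    personalized PageRank over random-walk-with-restart samples of the
    'largekgrag_graph' GDS projection -> passage scoring via connected Text nodes.

    Example (a minimal sketch; the driver, encoder, generator, and FAISS indexes
    are assumed to be built elsewhere):

        retriever = LargeKGRetriever(
            keyword='cc_en',
            neo4j_driver=driver,
            llm_generator=llm_generator,
            sentence_encoder=sentence_encoder,
            node_index=node_index,
            passage_index=passage_index,
        )
        passages, scores = retriever.retrieve_passages("example question")
    """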
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, passage_index: faiss.Index,
                 topN: int = 5,
                 number_of_source_nodes_per_ner: int = 10,
                 sampling_area: int = 250, logger: Logger = None):
        # Instantiate the resources for one knowledge graph.
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.gds_driver = GraphDataScience(self.neo4j_driver)
        self.topN = topN
        self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
        self.sampling_area = sampling_area
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = node_index
        # self.edge_faiss_index = self.edge_indexes[keyword]
        self.passage_faiss_index = passage_index

        self.verbose = False if logger is None else True
        self.logger = logger

        self.ppr_weight_threshold = 0.00005

    def set_model(self, model):
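        """Set the generator's model name; only supported for the OpenAI inference type."""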
        if self.llm_generator.inference_type == 'openai':
            self.llm_generator.model_name = model
        else:
            raise ValueError("Model can only be set for OpenAI inference type.")

    def ner(self, text):
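        """Extract named entities from the text with the LLM generator."""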
        return self.llm_generator.large_kg_ner(text)

    def convert_numeric_id_to_name(self, numeric_id):
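        """Resolve a numeric node id to its `name` property via GDS; non-numeric ids are returned unchanged."""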
        if numeric_id.isdigit():
            return self.gds_driver.util.asNode(
                self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})
            ).get('name')
        else:
            return numeric_id

    def has_intersection(self, word_set, input_string):
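        """Return True if `input_string` overlaps `word_set`: substring matching on
        phrases for the 'cc_en' graph, word-level intersection otherwise."""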
        cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
        if self.keyword == 'cc_en':
            # Check if any phrase in word_set is a substring of cleaned_string.
            for phrase in word_set:
                if phrase in cleaned_string:
                    return True
            return False
        else:
            # Check if any word in word_set is present among the words of cleaned_string.
            words_in_string = set(cleaned_string.split())
            return not word_set.isdisjoint(words_in_string)

    def pagerank(self, personalization_dict, topN=5, sampling_area=200):
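        """Score passages with personalized PageRank.

        For each seed node, draw a random-walk-with-restart sample of the
        'largekgrag_graph' projection, stream PageRank from that node, weight the
        resulting scores by the seed's personalization weight, aggregate them onto
        connected Text nodes in Neo4j, and return the topN passages with scores.
        """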
        graph = self.gds_driver.graph.get('largekgrag_graph')
        node_count = graph.node_count()
        sampling_ratio = sampling_area / node_count
        aggregation_node_dict = []
        ppr_weight_threshold = self.ppr_weight_threshold
        start_time = time.time()

        # Pre-filter seed nodes by PPR weight and precompute the word set for 'cc_en'.
        if self.keyword == 'cc_en':
            filtered_personalization = {
                node_id: ppr_weight
                for node_id, ppr_weight in personalization_dict.items()
                if ppr_weight >= ppr_weight_threshold
            }
            stop_words = set(stopwords.words('english'))
            word_set_phrases = set()
            word_set_words = set()
            for node_id, ppr_weight in filtered_personalization.items():
                name = self.gds_driver.util.asNode(node_id)['name']
                if name:
                    cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
                    if cleaned_phrase:
                        # For 'cc_en': drop stop words and keep the cleaned phrase.
                        filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
                        if filtered_words:
                            cleaned_phrase_filtered = ' '.join(filtered_words)
                            word_set_phrases.add(cleaned_phrase_filtered)

            word_set = word_set_phrases if self.keyword == 'cc_en' else word_set_words
            if self.verbose:
                self.logger.info(f"Optimized word set: {word_set}")
        else:
            filtered_personalization = personalization_dict

        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
            self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
            self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")

        # Process each node in the filtered personalization dict.
        for node_id, ppr_weight in filtered_personalization.items():
            try:
                self.gds_driver.graph.drop('rwr_sample')
                start_time = time.time()
                G_sample, _ = self.gds_driver.graph.sample.rwr(
                    "rwr_sample", graph, concurrency=4, samplingRatio=sampling_ratio,
                    startNodes=[node_id], restartProbability=0.4, logProgress=False
                )
                if self.verbose:
                    self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                result = self.gds_driver.pageRank.stream(
                    G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
                ).sort_values("score", ascending=False)

                if self.verbose:
                    self.logger.info(f"pagerank type: {type(result)}")
                    self.logger.info(f"pagerank result: {result}")
                    self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                if self.keyword != 'cc_en':
                    result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
                else:
                    result = result.to_dict('records')
                if self.verbose:
                    self.logger.info(f"largekgRAG : result: {result}")
                for entry in result:
                    if self.keyword == 'cc_en':
                        node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
                        if not self.has_intersection(word_set, node_name):
                            continue

                    numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
                    aggregation_node_dict.append({
                        'nodeId': numeric_id,
                        'score': entry['score'] * ppr_weight
                    })

            except Exception as e:
                if self.verbose:
                    self.logger.error(f"Error processing node {node_id}: {e}")
                    self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
                continue

        aggregation_node_dict = sorted(aggregation_node_dict, key=lambda x: x['score'], reverse=True)[:25]
        if self.verbose:
            self.logger.info(f"Aggregation node dict: {aggregation_node_dict}")
            self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
        start_time = time.time()
        with self.neo4j_driver.session() as session:
            intermediate_time = time.time()
            # Step 1: Distribute entity scores to connected text nodes and keep the topN.
            query_scores = """
                UNWIND $entries AS entry
                MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
                WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
                ORDER BY total_score DESC
                LIMIT $topN
                RETURN textId, total_score
            """
            # Execute the query to aggregate scores.
            result_scores = session.run(query_scores, entries=aggregation_node_dict, topN=topN)
            top_numeric_ids = []
            top_scores = []

            # Extract the top text node IDs and scores.
            for record in result_scores:
                top_numeric_ids.append(record["textId"])
                top_scores.append(record["total_score"])

            # Step 2: Use the top numeric IDs to retrieve the original text.
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")
            intermediate_time = time.time()
            query_text = """
                UNWIND $textIds AS textId
                MATCH (t:Text {numeric_id: textId})
                RETURN t.original_text AS text, t.numeric_id AS textId
            """
            result_texts = session.run(query_text, textIds=top_numeric_ids)
            topN_passages = []
            score_dict = dict(zip(top_numeric_ids, top_scores))

            # Combine each original text with its score.
            for record in result_texts:
                original_text = record["text"]
                text_id = record["textId"]
                score = score_dict.get(text_id, 0)
                topN_passages.append((original_text, score))
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")
            # Sort passages by score.
            topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
            top_texts = [item[0] for item in topN_passages][:topN]
            top_scores = [item[1] for item in topN_passages][:topN]
            if self.verbose:
                self.logger.info(f"Total passages retrieved: {len(top_texts)}")
                self.logger.info(f"Top passages: {top_texts}")
                self.logger.info(f"Top scores: {top_scores}")
                self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
        return top_texts, top_scores

    def retrieve_topk_nodes(self, query, top_k_nodes=2):
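        """Map query entities to candidate graph node ids.

        Entities from LLM NER are embedded and matched against the FAISS node index;
        the candidate names are then filtered by the LLM and mapped back to node ids
        (exact name match first, then difflib close matches). At most
        2 * number-of-entities ids are returned.
        """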
        # Extract entities from the query.
        entities = self.ner(query)
        if self.verbose:
            self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
        if len(entities) == 0:
            entities = [query]
        num_entities = len(entities)
        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
            initial_nodes += [str(i) for i in I[0]]
        if self.verbose:
            self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
        name_id_map = {}
        for node_id in initial_nodes:
            name = self.convert_numeric_id_to_name(node_id)
            name_id_map[name] = node_id

        topk_nodes = list(set(initial_nodes))
        # Convert the numeric ids to names, filter them with the LLM, then map back to numeric ids.
        keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
        filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)

        # Second pass: add the filtered keywords.
        filtered_top_k_nodes = []
        filter_log_dict = {}
        match_threshold = 0.8
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
        for keyword in filtered_keywords:
            # Check for an exact match first.
            if keyword in name_id_map:
                filtered_top_k_nodes.append(name_id_map[keyword])
                filter_log_dict[keyword] = name_id_map[keyword]
            else:
                # Look for close matches using difflib's get_close_matches.
                close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)

                if close_matches:
                    # If a close match is found, add the corresponding node.
                    filtered_top_k_nodes.append(name_id_map[close_matches[0]])

                filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")

        topk_nodes = list(set(filtered_top_k_nodes))
        if len(topk_nodes) > 2 * num_entities:
            topk_nodes = topk_nodes[:2 * num_entities]
        return topk_nodes

    def _process_text(self, text):
        """Normalize text for containment checks (lowercase, alphanumeric and spaces only)."""
        text = text.lower()
        text = ''.join([c for c in text if c.isalnum() or c.isspace()])
        return set(text.split())

    def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
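        """Build the personalized-PageRank seed dictionary for a query.

        Each retrieved node is weighted by the inverse of the number of distinct
        Text nodes it is connected to, so entities that appear in many passages
        contribute less to the seed distribution.
        """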
        topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
        if topk_nodes == []:
            if self.verbose:
                self.logger.info(f"largekgRAG : No nodes found for query: {query}")
            return {}
        if self.verbose:
            self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
        freq_dict_for_nodes = {}
        count_query = """
            UNWIND $nodes AS node
            MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
            RETURN n1.numeric_id AS numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
        """
        with self.neo4j_driver.session() as session:
            result = session.run(count_query, nodes=topk_nodes)
            for record in result:
                freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
        # Create the personalization dictionary, weighting each node by 1 / passage count.
        personalization_dict = {
            self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id}): 1 / file_count
            for numeric_id, file_count in freq_dict_for_nodes.items()
        }
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict's number of nodes: {len(personalization_dict)}")
        return personalization_dict

    def retrieve_passages(self, query):
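        """Retrieve the topN passages and their scores for a query.

        Returns ([], [0]) when no seed nodes could be found for the query.
        """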
        if self.verbose:
            self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
        topN = self.topN
        number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
        sampling_area = self.sampling_area
        personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
        if personalization_dict == {}:
            return [], [0]
        topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area=sampling_area)
        return topN_passages, topN_scores