first commit

This commit is contained in:
闫旭隆
2025-09-24 09:29:12 +08:00
parent 6339cdebb9
commit 2308536f66
360 changed files with 136381 additions and 0 deletions

View File

@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
class BaseLargeKGRetriever(ABC):
def __init__(self):
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
@abstractmethod
def retrieve_passages(self, query, retriever_config: dict):
"""
Retrieve passages based on the query.
Args:
query (str): The input query.
retriever_config (dict): Retrieval settings, e.g. topN (number of top passages to retrieve), number_of_source_nodes_per_ner (source nodes per recognized entity), and sampling_area (area for sampling in the graph).
Returns:
List of retrieved passages and their scores.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
class BaseLargeKGEdgeRetriever(ABC):
def __init__(self):
raise NotImplementedError("This is a base class and cannot be instantiated directly.")
@abstractmethod
def retrieve_passages(self, query, retriever_config: dict):
"""
Retrieve edges / paths based on the query.
Args:
query (str): The input query.
retriever_config (dict): Retrieval settings, e.g. topN (number of top paths to retrieve), number_of_source_nodes_per_ner (source nodes per recognized entity), and sampling_area (area for sampling in the graph).
Returns:
List of retrieved edges/paths and their scores.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

View File

@@ -0,0 +1,313 @@
from difflib import get_close_matches
from logging import Logger
import faiss
from neo4j import GraphDatabase
import time
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from graphdatascience import GraphDataScience
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import string
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
class LargeKGRetriever(BaseLargeKGRetriever):
def __init__(self, keyword:str, neo4j_driver: GraphDatabase,
llm_generator:LLMGenerator, sentence_encoder:BaseEmbeddingModel,
node_index:faiss.Index, passage_index:faiss.Index,
topN: int = 5,
number_of_source_nodes_per_ner: int = 10,
sampling_area: int = 250, logger: Logger = None):
# instantiate resources for one KG
self.keyword = keyword
self.neo4j_driver = neo4j_driver
self.gds_driver = GraphDataScience(self.neo4j_driver)
self.topN = topN
self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
self.sampling_area = sampling_area
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.node_faiss_index = node_index
# self.edge_faiss_index = self.edge_indexes[keyword]
self.passage_faiss_index = passage_index
self.verbose = logger is not None
self.logger = logger
self.ppr_weight_threshold = 0.00005
def set_model(self, model):
if self.llm_generator.inference_type == 'openai':
self.llm_generator.model_name = model
else:
raise ValueError("Model can only be set for OpenAI inference type.")
def ner(self, text):
return self.llm_generator.large_kg_ner(text)
def convert_numeric_id_to_name(self, numeric_id):
if numeric_id.isdigit():
return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
else:
return numeric_id
def has_intersection(self, word_set, input_string):
cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
if self.keyword == 'cc_en':
# Check if any phrase in word_set is a substring of cleaned_string
for phrase in word_set:
if phrase in cleaned_string:
return True
return False
else:
# Check if any word in word_set is present in the cleaned_string's words
words_in_string = set(cleaned_string.split())
return not word_set.isdisjoint(words_in_string)
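# Illustration (hypothetical values): for keyword 'cc_en', word_set={'barack obama'}
# matches "Barack Obama visited Paris." by substring containment once punctuation is
# stripped and the text is lowercased; for other KGs, word_set={'obama'} matches
# because the string's word set {'barack', 'obama', 'visited', 'paris'} intersects it.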
def pagerank(self, personalization_dict, topN=5, sampling_area=200):
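# Flow summary: for each seed node above the ppr_weight threshold, sample a subgraph
# around it with random walk with restarts (GDS rwr), run PageRank from that seed on
# the sample, scale each resulting score by the seed's ppr_weight, aggregate the
# weighted scores onto Text nodes via Source edges, and return the topN passages.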
graph = self.gds_driver.graph.get('largekgrag_graph')
node_count = graph.node_count()
sampling_ratio = sampling_area / node_count
aggregation_node_dict = []  # list of {'nodeId', 'score'} dicts, despite the name
ppr_weight_threshold = self.ppr_weight_threshold
start_time = time.time()
# Pre-filter nodes based on ppr_weight threshold
# Precompute word sets based on keyword
if self.keyword == 'cc_en':
filtered_personalization = {
node_id: ppr_weight
for node_id, ppr_weight in personalization_dict.items()
if ppr_weight >= ppr_weight_threshold
}
stop_words = set(stopwords.words('english'))
word_set_phrases = set()
for node_id, ppr_weight in filtered_personalization.items():
name = self.gds_driver.util.asNode(node_id)['name']
if name:
cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
if cleaned_phrase:
# Process for 'cc_en': remove stop words and add cleaned phrase
filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
if filtered_words:
cleaned_phrase_filtered = ' '.join(filtered_words)
word_set_phrases.add(cleaned_phrase_filtered)
word_set = word_set_phrases  # this branch only runs when keyword == 'cc_en'
if self.verbose:
self.logger.info(f"Optimized word set: {word_set}")
else:
filtered_personalization = personalization_dict
if self.verbose:
self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")
# Process each node in the filtered personalization dict
for node_id, ppr_weight in filtered_personalization.items():
try:
self.gds_driver.graph.drop('rwr_sample')
start_time = time.time()
G_sample, _ = self.gds_driver.graph.sample.rwr("rwr_sample", graph, concurrency=4, samplingRatio=sampling_ratio, startNodes=[node_id], restartProbability=0.4, logProgress=False)
if self.verbose:
self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
start_time = time.time()
result = self.gds_driver.pageRank.stream(
G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
).sort_values("score", ascending=False)
if self.verbose:
self.logger.info(f"pagerank type: {type(result)}")
self.logger.info(f"pagerank result: {result}")
self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
start_time = time.time()
if self.keyword != 'cc_en':
result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
else:
result = result.to_dict('records')
if self.verbose:
self.logger.info(f"largekgRAG :result: {result}")
for entry in result:
if self.keyword == 'cc_en':
node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
if not self.has_intersection(word_set, node_name):
continue
numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
aggregation_node_dict.append({
'nodeId': numeric_id,
'score': entry['score'] * ppr_weight
})
except Exception as e:
if self.verbose:
self.logger.error(f"Error processing node {node_id}: {e}")
self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
else:
continue
aggregation_node_dict = sorted(aggregation_node_dict, key=lambda x: x['score'], reverse=True)[:25]
if self.verbose:
self.logger.info(f"Aggregation node dict: {aggregation_node_dict}")
if self.verbose:
self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
start_time = time.time()
with self.neo4j_driver.session() as session:
intermediate_time = time.time()
# Step 1: Distribute entity scores to connected text nodes and find the top 5
query_scores = """
UNWIND $entries AS entry
MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
ORDER BY total_score DESC
LIMIT $topN
RETURN textId, total_score
"""
# Execute query to aggregate scores
result_scores = session.run(query_scores, entries=aggregation_node_dict, topN=topN)
top_numeric_ids = []
top_scores = []
# Extract the top text node IDs and scores
for record in result_scores:
top_numeric_ids.append(record["textId"])
top_scores.append(record["total_score"])
# Step 2: Use top numeric IDs to retrieve the original text
if self.verbose:
self.logger.info(f"Time taken to prepare query 1 : {time.time() - intermediate_time:.2f} seconds")
intermediate_time = time.time()
query_text = """
UNWIND $textIds AS textId
MATCH (t:Text {numeric_id: textId})
RETURN t.original_text AS text, t.numeric_id AS textId
"""
result_texts = session.run(query_text, textIds=top_numeric_ids)
topN_passages = []
score_dict = dict(zip(top_numeric_ids, top_scores))
# Combine original text with scores
for record in result_texts:
original_text = record["text"]
text_id = record["textId"]
score = score_dict.get(text_id, 0)
topN_passages.append((original_text, score))
if self.verbose:
self.logger.info(f"Time taken to prepare query 2 : {time.time() - intermediate_time:.2f} seconds")
# Sort passages by score
topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
top_texts = [item[0] for item in topN_passages][:topN]
top_scores = [item[1] for item in topN_passages][:topN]
if self.verbose:
self.logger.info(f"Total passages retrieved: {len(top_texts)}")
self.logger.info(f"Top passages: {top_texts}")
self.logger.info(f"Top scores: {top_scores}")
if self.verbose:
self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
return top_texts, top_scores
def retrieve_topk_nodes(self, query, top_k_nodes=2):
# extract entities from the query
entities = self.ner(query)
if self.verbose:
self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
if len(entities) == 0:
entities = [query]
num_entities = len(entities)
initial_nodes = []
for entity in entities:
entity_embedding = self.sentence_encoder.encode([entity])
D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
if self.verbose:
self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
initial_nodes += [str(i) for i in I[0]]
if self.verbose:
self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
name_id_map = {}
for node_id in initial_nodes:
name = self.convert_numeric_id_to_name(node_id)
name_id_map[name] = node_id
# convert numeric ids to names, let the LLM filter them, then map back to numeric ids
keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)
# Second pass: Add filtered keywords
filtered_top_k_nodes = []
filter_log_dict = {}
match_threshold = 0.8
if self.verbose:
self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
for keyword in filtered_keywords:
# Check for an exact match first
if keyword in name_id_map:
filtered_top_k_nodes.append(name_id_map[keyword])
filter_log_dict[keyword] = name_id_map[keyword]
else:
# Look for close matches using difflib's get_close_matches
close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)
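# e.g. get_close_matches("barack obama", ["barak obama", "michelle obama"], n=1, cutoff=0.8)
# returns ["barak obama"]: its difflib similarity ratio (~0.96) clears the 0.8 cutoff.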
if close_matches:
# If a close match is found, add the corresponding node
filtered_top_k_nodes.append(name_id_map[close_matches[0]])
filter_log_dict[keyword] = name_id_map[close_matches[0]]
if self.verbose:
self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")
topk_nodes = list(set(filtered_top_k_nodes))
if len(topk_nodes) > 2 * num_entities:
topk_nodes = topk_nodes[:2 * num_entities]
return topk_nodes
def _process_text(self, text):
"""Normalize text for containment checks (lowercase, alphanumeric+spaces)"""
text = text.lower()
text = ''.join([c for c in text if c.isalnum() or c.isspace()])
return set(text.split())
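# e.g. _process_text("Hello, World!") returns {"hello", "world"}: punctuation is
# dropped because only alphanumeric characters and spaces survive before splitting.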
def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
if not topk_nodes:
if self.verbose:
self.logger.info(f"largekgRAG : No nodes found for query: {query}")
return {}
if self.verbose:
self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
freq_dict_for_nodes = {}
query = """
UNWIND $nodes AS node
MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
RETURN n1.numeric_id as numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
"""
with self.neo4j_driver.session() as session:
result = session.run(cypher_query, nodes=topk_nodes)
for record in result:
freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
# Create the personalization dictionary
personalization_dict = {self.gds_driver.find_node_id(["Node"],{"numeric_id": numeric_id}): 1 / file_count for numeric_id, file_count in freq_dict_for_nodes.items()}
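# Worked example (hypothetical counts): a node mentioned by 4 source texts gets
# weight 1/4 = 0.25 while a node mentioned by a single text gets 1.0, so rarer,
# more specific entities dominate the personalization.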
if self.verbose:
self.logger.info(f"largekgRAG : Personalization dict's number of node: {len(personalization_dict)}")
return personalization_dict
def retrieve_passages(self, query):
if self.verbose:
self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
topN = self.topN
number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
sampling_area = self.sampling_area
personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
if personalization_dict == {}:
return [], [0]
topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area = sampling_area)
return topN_passages, topN_scores
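A usage sketch for LargeKGRetriever (hedged: the URI, credentials, index files, and the llm_generator / sentence_encoder objects below are placeholders, not values from this commit):

from neo4j import GraphDatabase
import faiss

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
node_index = faiss.read_index("node_embeddings.index")        # hypothetical path
passage_index = faiss.read_index("passage_embeddings.index")  # hypothetical path
retriever = LargeKGRetriever(
    keyword="cc_en",
    neo4j_driver=driver,
    llm_generator=llm_generator,        # an LLMGenerator instance
    sentence_encoder=sentence_encoder,  # a BaseEmbeddingModel instance
    node_index=node_index,
    passage_index=passage_index,
    topN=5,
)
passages, scores = retriever.retrieve_passages("Who proposed the Turing test?")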

View File

@@ -0,0 +1,469 @@
from neo4j import GraphDatabase
import faiss
import numpy as np
from collections import defaultdict
from typing import List, Tuple
import time
import logging
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
node_index: faiss.Index,
topN: int = 5,
Dmax: int = 3,
Wmax: int = 3,
prune_size: int = 10,
logger: logging.Logger = None):
"""
Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.
Args:
keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
neo4j_driver (GraphDatabase): Neo4j driver for database access.
llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
node_index (faiss.Index): FAISS index for node embeddings.
logger (Logger, optional): Logger for verbose output.
"""
self.keyword = keyword
self.neo4j_driver = neo4j_driver
self.llm_generator = llm_generator
self.sentence_encoder = sentence_encoder
self.filter_encoder = filter_encoder
self.node_faiss_index = node_index
self.verbose = logger is not None
self.logger = logger
self.topN = topN
self.Dmax = Dmax
self.Wmax = Wmax
self.prune_size = prune_size
def ner(self, text: str) -> List[str]:
"""
Extract named entities from the query text using the LLM.
Args:
text (str): The query text.
Returns:
List[str]: List of extracted entities.
"""
entities = self.llm_generator.large_kg_ner(text)
if self.verbose:
self.logger.info(f"Extracted entities: {entities}")
return entities
def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> Tuple[List[str], List[str]]:
"""
Retrieve top-k nodes similar to entities in the query.
Args:
query (str): The user query.
top_k_nodes (int): Number of nodes to retrieve per entity.
Returns:
Tuple[List[str], List[str]]: Node numeric_ids and the corresponding node names.
"""
start_time = time.time()
entities = self.ner(query)
if self.verbose:
ner_time = time.time() - start_time
self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
if not entities:
entities = [query]
initial_nodes = []
for entity in entities:
entity_embedding = self.sentence_encoder.encode([entity])
D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
if self.verbose:
self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
if len(I[0]) > 0: # Check if results exist
initial_nodes.extend([str(i) for i in I[0]])
# no need filtering as ToG pruning will handle it.
topk_nodes_ids = list(set(initial_nodes))
with self.neo4j_driver.session() as session:
start_time = time.time()
query = """
MATCH (n:Node)
WHERE n.numeric_id IN $topk_nodes_ids
RETURN n.numeric_id AS id, n.name AS name
"""
result = session.run(cypher_query, topk_nodes_ids=topk_nodes_ids)
topk_nodes_dict = {}
for record in result:
topk_nodes_dict[record["id"]] = record["name"]
if self.verbose:
neo4j_time = time.time() - start_time
self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")
if self.verbose:
self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values()) # (numeric_ids, names)
def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
"""
Expand each path by adding neighbors of the last node.
Args:
P (List[List[str]]): Current paths, each a list of alternating node names and relation types.
PIDS (List[List[str]]): Parallel numeric_ids for the nodes on each path.
PTYPES (List[List[str]]): Parallel node types ('Node' or 'Text') on each path.
width (int): Maximum number of expansions kept per frontier node.
query (str): The user query, used for embedding-based pruning.
Returns:
Tuple of expanded paths, their ids, and their types.
"""
last_nodes = []
last_node_ids = []
last_node_types = []
paths_end_with_text = []
paths_end_with_text_id = []
paths_end_with_text_type = []
paths_end_with_node = []
paths_end_with_node_id = []
paths_end_with_node_type = []
if self.verbose:
self.logger.info(f"Expanding paths, current paths: {P}")
for p, pid, ptype in zip(P, PIDS, PTYPES):
if not p or not pid or not ptype: # Skip empty paths
continue
t = ptype[-1]
if t == "Text":
paths_end_with_text.append(p)
paths_end_with_text_id.append(pid)
paths_end_with_text_type.append(ptype)
continue
last_node = p[-1] # Last node in the path
last_node_id = pid[-1] # Last node numeric_id
last_nodes.append(last_node)
last_node_ids.append(last_node_id)
last_node_types.append(t)
paths_end_with_node.append(p)
paths_end_with_node_id.append(pid)
paths_end_with_node_type.append(ptype)
assert len(last_nodes) == len(last_node_ids) == len(last_node_types), "Mismatch in last nodes, ids, and types lengths"
if not last_node_ids:
return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type
with self.neo4j_driver.session() as session:
# Query Node relationships
start_time = time.time()
outgoing_query = """
CALL apoc.cypher.runTimeboxed(
"MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
WITH n, r, m ORDER BY rand() LIMIT 60000
RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
{last_node_ids: $last_node_ids},
60000
)
YIELD value
RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
"""
outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
outgoing = [(record["source"], record['source_name'], record["rel_type"], record["target"], record["target_name"], record["target_type"])
for record in outgoing_result]
if self.verbose:
outgoing_time = time.time() - start_time
self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")
# # Query outgoing Node -> Text relationships
# start_time = time.time()
# outgoing_text_query = """
# MATCH (n:Node)-[r:Source]->(t:Text)
# WHERE n.numeric_id IN $last_node_ids
# RETURN n.numeric_id AS source, n.name AS source_name, 'from Source' AS rel_type, t.numeric_id as target, t.original_text AS target_name, 'Text' AS target_type
# """
# outgoing_text_result = session.run(outgoing_text_query, last_node_ids=last_node_ids)
# outgoing_text = [(record["source"], record["source_name"], record["rel_type"], record["target"], record["target_name"], record["target_type"])
# for record in outgoing_text_result]
# if self.verbose:
# outgoing_text_time = time.time() - start_time
# self.logger.info(f"Outgoing Node->Text relationships query took {outgoing_text_time:.2f} seconds, count: {len(outgoing_text)}")
last_node_to_new_paths = defaultdict(list)
last_node_to_new_paths_ids = defaultdict(list)
last_node_to_new_paths_types = defaultdict(list)
stop_words = set(stopwords.words('english'))  # hoisted so the corpus is read once, not per candidate
for p, pid, ptype in zip(P, PIDS, PTYPES):
last_node = p[-1]
last_node_id = pid[-1]
# Outgoing Node -> Node
for source, source_name, rel_type, target, target_name, target_type in outgoing:
if source == last_node_id and target_name not in p:
if target_name.lower() in stop_words:
continue
new_path = p + [rel_type, target_name]
last_node_to_new_paths[last_node].append(new_path)
last_node_to_new_paths_ids[last_node].append(pid + [target])
last_node_to_new_paths_types[last_node].append(ptype + [target_type])
# # Outgoing Node -> Text
# for source, source_name, rel_type, target, target_name, target_type in outgoing_text:
# if source == last_node_id and target_name not in p:
# new_path = p + [rel_type, target_name]
# last_node_to_new_paths_text[last_node].append(new_path)
# last_node_to_new_paths_text_ids[last_node].append(pid + [target])
# last_node_to_new_paths_text_types[last_node].append(ptype + [target_type])
# # Incoming Node -> Node
# for source, rel_type, target, source_name, source_type in incoming:
# if target == last_node_id and source not in p:
# new_path = p + [rel_type, source_name]
num_paths = 0
for last_node, new_paths in last_node_to_new_paths.items():
num_paths += len(new_paths)
# for last_node, new_paths in last_node_to_new_paths_text.items():
# num_paths += len(new_paths)
new_paths = []
new_pids = []
new_ptypes = []
if self.verbose:
self.logger.info(f"Number of new paths before filtering: {num_paths}")
self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")
if num_paths > len(last_node_ids) * width:
# Apply filtering when total paths exceed threshold
for last_node, new_ps in last_node_to_new_paths.items():
if len(new_ps) > width:
path_embeddings = self.filter_encoder.encode(new_ps)
query_embeddings = self.filter_encoder.encode([query])
scores = np.dot(path_embeddings, query_embeddings.T).flatten()
top_indices = np.argsort(scores)[-width:]
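# e.g. scores=[0.21, 0.87, 0.55] with width=2: np.argsort gives [0, 2, 1], so
# top_indices=[2, 1] keeps the two highest-scoring expansions for this node.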
new_paths.extend([new_ps[i] for i in top_indices])
new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
else:
new_paths.extend(new_ps)
new_pids.extend(last_node_to_new_paths_ids[last_node])
new_ptypes.extend(last_node_to_new_paths_types[last_node])
else:
# Collect all paths without filtering when total is at or below threshold
for last_node, new_ps in last_node_to_new_paths.items():
new_paths.extend(new_ps)
new_pids.extend(last_node_to_new_paths_ids[last_node])
new_ptypes.extend(last_node_to_new_paths_types[last_node])
if self.verbose:
self.logger.info(f"Expanded paths count: {len(new_paths)}")
self.logger.info(f"Expanded paths: {new_paths}")
return new_paths, new_pids, new_ptypes
def path_to_string(self, path: List[str]) -> str:
"""
Convert a path to a human-readable string for LLM rating.
Args:
path (List[str]): Path as a list of node_ids and relation types.
Returns:
str: String representation of the path.
"""
if len(path) < 1:
return ""
path_str = []
with self.neo4j_driver.session() as session:
for i in range(0, len(path), 2):
node_id = path[i]
result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
node_name = result.single()["n.name"] if result.single() else node_id
if i + 1 < len(path):
rel_type = path[i + 1]
path_str.append(f"{node_name} ---> {rel_type} --->")
else:
path_str.append(node_name)
return " ".join(path_str).strip()
def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
"""
Prune paths to keep the top N based on LLM relevance ratings.
Args:
query (str): The user query.
P (List[List[str]]): List of paths to prune.
PIDS (List[List[str]]): Parallel numeric_ids for each path.
PTYPES (List[List[str]]): Parallel node types for each path.
topN (int): Number of paths to retain.
Returns:
Top N paths with their parallel PIDS and PTYPES lists.
"""
ratings = []
path_strings = P  # each entry is still a list; it is rendered inline via f-string below
# Process paths in chunks of 10
for i in range(0, len(path_strings), self.prune_size):
chunk = path_strings[i:i + self.prune_size]
# Construct user prompt with the current chunk of paths listed
user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
for j, path_str in enumerate(chunk, 1):
user_prompt += f"{j + i}. {path_str}\n"
user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
# Define system prompt to expect a list of integers
system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."
# Send the prompt to the language model
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
if self.verbose:
self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")
# Parse the response into a list of ratings
rating_str = response.strip()
chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
if len(chunk_ratings) > len(chunk):
chunk_ratings = chunk_ratings[:len(chunk)]
if self.verbose:
self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
ratings.extend(chunk_ratings) # Concatenate ratings
# Ensure ratings length matches number of paths, padding with 0s if necessary
if len(ratings) < len(path_strings):
# self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Padding with 0s.")
# ratings += [0] * (len(path_strings) - len(ratings))
# fall back to use filter encoder to get topN
self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
path_embeddings = self.filter_encoder.encode(path_strings)
query_embedding = self.filter_encoder.encode([query])
scores = np.dot(path_embeddings, query_embedding.T).flatten()
top_indices = np.argsort(scores)[-topN:]
top_paths = [path_strings[i] for i in top_indices]
return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
elif len(ratings) > len(path_strings):
self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
ratings = ratings[:len(path_strings)]
# Sort indices based on ratings in descending order
sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
# Filter out indices where the rating is 0
filtered_indices = [i for i in sorted_indices if ratings[i] > 0]
# Take the top N indices from the filtered list
top_indices = filtered_indices[:topN]
# Use the filtered indices to get the top paths, PIDS, and PTYPES
if self.verbose:
self.logger.info(f"Top indices after pruning: {top_indices}")
self.logger.info(f"length of path_strings: {len(path_strings)}")
top_paths = [path_strings[i] for i in top_indices]
top_pids = [PIDS[i] for i in top_indices]
top_ptypes = [PTYPES[i] for i in top_indices]
# Log top paths if verbose mode is enabled
if self.verbose:
self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
return top_paths, top_pids, top_ptypes
def reasoning(self, query: str, P: List[List[str]]) -> bool:
"""
Check if the current paths are sufficient to answer the query.
Args:
query (str): The user query.
P (List[List[str]]): Current list of paths.
Returns:
bool: True if sufficient, False otherwise.
"""
triples = []
for path in P:
if len(path) < 3:
continue
for i in range(0, len(path) - 2, 2):
node1_name = path[i]
rel = path[i + 1]
node2_name = path[i + 2]
triples.append(f"({node1_name}, {rel}, {node2_name})")
triples_str = ". ".join(triples)
prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
messages = [
{"role": "system", "content": "Answer Yes or No only."},
{"role": "user", "content": prompt}
]
response = self.llm_generator.generate_response(messages, max_new_tokens=512)
if self.verbose:
self.logger.info(f"Reasoning result: {response}")
return "yes" in response.lower()
def retrieve_passages(self, query: str) -> Tuple[List[str], str]:
"""
Retrieve the top N reasoning paths for the query.
Search budgets come from instance attributes: topN (paths to return), Dmax (maximum expansion depth), and Wmax (maximum expansion width per node).
Args:
query (str): The user query.
Returns:
Tuple[List[str], str]: Triples as strings, plus 'N/A' in place of passage scores.
"""
topN = self.topN
Dmax = self.Dmax
Wmax = self.Wmax
if self.verbose:
self.logger.info(f"Retrieving paths for query: {query}")
initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
if not initial_nodes:
if self.verbose:
self.logger.info("No initial nodes found.")
return [], 'N/A'  # keep the (triples, score placeholder) contract
P = [[node] for node in initial_nodes]
PIDS = [[node_id] for node_id in initial_nodes_ids]
PTYPES = [["Node"] for _ in initial_nodes_ids] # Assuming all initial nodes are of type 'Node'
for D in range(Dmax + 1):
if self.verbose:
self.logger.info(f"Depth {D}, Current paths: {len(P)}")
P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
if not P:
if self.verbose:
self.logger.info("No paths to expand.")
break
P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
if D == Dmax:
if self.verbose:
self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
break
if self.reasoning(query, P):
if self.verbose:
self.logger.info("Paths sufficient, stopping expansion.")
break
# Extract final triples
triples = []
for path in P:
for i in range(0, len(path) - 2, 2):
node1_name = path[i]
rel = path[i + 1]
node2_name = path[i + 2]
triples.append(f"({node1_name}, {rel}, {node2_name})")
if self.verbose:
self.logger.info(f"Final triples: {triples}")
return triples, 'N/A' # 'N/A' for passages_score as this retriever does not return passages
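A matching usage sketch for LargeKGToGRetriever (the driver, encoders, and index are the same kind of placeholders as above; Dmax/Wmax mirror the constructor defaults):

retriever = LargeKGToGRetriever(
    keyword="cc_en",
    neo4j_driver=driver,              # neo4j driver, as above
    llm_generator=llm_generator,      # LLM used for NER, path rating, and reasoning
    sentence_encoder=sentence_encoder,
    filter_encoder=filter_encoder,    # second encoder used only for pruning
    node_index=node_index,            # faiss.Index over node embeddings
    topN=5, Dmax=3, Wmax=3,
)
triples, _ = retriever.retrieve_passages("Where was Alan Turing born?")
# triples is a list of "(head, relation, tail)" strings from the retained paths.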