Files
AIEC-RAG/atlas_rag/retriever/lkg_retriever/tog.py
2025-09-24 09:29:12 +08:00

470 lines
21 KiB
Python

import logging
import random
import time
from collections import defaultdict
from typing import List, Optional, Tuple

import faiss
import nltk
import numpy as np
from neo4j import GraphDatabase
from nltk.corpus import stopwords

from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
nltk.download('stopwords')
class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
    """ToG retriever that walks a Neo4j-hosted knowledge graph.

    Retrieval seeds single-node paths from FAISS nearest neighbours of the
    query entities, then alternates between expanding every path by one
    relation (``expand_paths``), pruning with LLM ratings (``prune``) and
    asking the LLM whether the collected triples suffice (``reasoning``).

    Paths are kept as three parallel lists:
      * ``P``      - alternating node names and relation labels,
      * ``PIDS``   - the ``numeric_id`` of every node on the path,
      * ``PTYPES`` - the type label (``"Node"`` / ``"Text"``) of every node.
    """

    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 filter_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index,
                 topN : int = 5,
                 Dmax : int = 3,
                 Wmax : int = 3,
                 prune_size: int = 10,
                 logger: Optional[logging.Logger] = None):
        """
        Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.
        Args:
            keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
            neo4j_driver (GraphDatabase): Neo4j driver for database access.
            llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
            sentence_encoder (BaseEmbeddingModel): Encoder used to embed query entities
                for the FAISS node search.
            filter_encoder (BaseEmbeddingModel): Encoder used to score candidate paths
                against the query when filtering by width or when LLM rating fails.
            node_index (faiss.Index): FAISS index for node embeddings.
            topN (int): Number of paths kept after each pruning step.
            Dmax (int): Depth bound of the expansion loop in ``retrieve_passages``.
            Wmax (int): Width passed to ``expand_paths`` (max new paths per last node
                once filtering kicks in).
            prune_size (int): Number of paths rated per LLM call in ``prune``.
            logger (Logger, optional): Logger for verbose output; verbose logging is
                enabled exactly when a logger is supplied.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.filter_encoder = filter_encoder
        self.node_faiss_index = node_index
        self.verbose = logger is not None
        self.logger = logger
        self.topN = topN
        self.Dmax = Dmax
        self.Wmax = Wmax
        self.prune_size = prune_size

    def ner(self, text: str) -> List[str]:
        """
        Extract named entities from the query text using the LLM.
        Args:
            text (str): The query text.
        Returns:
            List[str]: Whatever ``llm_generator.large_kg_ner`` returns for ``text``.
        """
        entities = self.llm_generator.large_kg_ner(text)
        if self.verbose:
            self.logger.info(f"Extracted entities: {entities}")
        return entities

    def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> Tuple[List[str], List[str]]:
        """
        Retrieve the nodes most similar to the entities found in the query.
        Args:
            query (str): The user query.
            top_k_nodes (int): Number of nearest nodes to fetch per entity.
        Returns:
            Tuple[List[str], List[str]]: ``(numeric_ids, names)`` of the matched nodes,
            index-aligned. Both lists are empty when nothing matched.
        """
        start_time = time.time()
        entities = self.ner(query)
        if self.verbose:
            ner_time = time.time() - start_time
            self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
        if not entities:
            # Fall back to searching with the whole query when NER finds nothing.
            entities = [query]

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            # Honour the caller-supplied neighbour count (was hard-coded to 3).
            distances, indices = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"Entity: {entity}, FAISS Distances: {distances}, Indices: {indices}")
            # FAISS pads missing neighbours with -1; those are not node ids.
            initial_nodes.extend(str(i) for i in indices[0] if i >= 0)

        # no need filtering as ToG pruning will handle it.
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        topk_nodes_ids = list(dict.fromkeys(initial_nodes))
        if not topk_nodes_ids:
            return [], []

        cypher = """
        MATCH (n:Node)
        WHERE n.numeric_id IN $topk_nodes_ids
        RETURN n.numeric_id AS id, n.name AS name
        """
        with self.neo4j_driver.session() as session:
            start_time = time.time()
            result = session.run(cypher, topk_nodes_ids=topk_nodes_ids)
            topk_nodes_dict = {record["id"]: record["name"] for record in result}
            if self.verbose:
                neo4j_time = time.time() - start_time
                self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")

        if self.verbose:
            self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
        return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values())

    def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]],
                     width: int, query: str) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Expand each path by one hop through the neighbours of its last node.
        Args:
            P (List[List[str]]): Current paths (alternating node names and relation labels).
            PIDS (List[List[str]]): numeric_ids of the nodes on each path.
            PTYPES (List[List[str]]): Type labels of the nodes on each path.
            width (int): Max number of new paths kept per last node once the total
                number of candidates exceeds ``len(expandable paths) * width``.
            query (str): The user query, used to score candidates when filtering.
        Returns:
            Tuple of three index-aligned lists ``(paths, path_ids, path_types)``
            holding the expanded paths. When no path ends in an expandable node,
            the paths ending in a ``"Text"`` node are returned unchanged instead.
        """
        text_paths, text_pids, text_ptypes = [], [], []
        node_paths, node_pids, node_ptypes = [], [], []
        if self.verbose:
            self.logger.info(f"Expanding paths, current paths: {P}")

        for p, pid, ptype in zip(P, PIDS, PTYPES):
            if not p or not pid or not ptype:  # Skip empty paths
                continue
            if ptype[-1] == "Text":
                text_paths.append(p)
                text_pids.append(pid)
                text_ptypes.append(ptype)
                continue
            node_paths.append(p)
            node_pids.append(pid)
            node_ptypes.append(ptype)

        last_node_ids = [pid[-1] for pid in node_pids]
        if not last_node_ids:
            return text_paths, text_pids, text_ptypes
        # NOTE(review): when at least one node-ending path exists, Text-ending paths
        # are not carried into the result (same as before) - confirm that is intended.

        neighbour_query = """
        CALL apoc.cypher.runTimeboxed(
        "MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
        WITH n, r, m ORDER BY rand() LIMIT 60000
        RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
        {last_node_ids: $last_node_ids},
        60000
        )
        YIELD value
        RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
        """
        # Group edges by source id so each path only scans its own neighbours.
        edges_by_source = defaultdict(list)
        edge_count = 0
        with self.neo4j_driver.session() as session:
            start_time = time.time()
            for record in session.run(neighbour_query, last_node_ids=last_node_ids):
                edges_by_source[record["source"]].append(
                    (record["rel_type"], record["target"], record["target_name"], record["target_type"])
                )
                edge_count += 1
            if self.verbose:
                outgoing_time = time.time() - start_time
                self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {edge_count}")

        # Load the stopword list once; the corpus reader re-reads it on every call.
        stop_words = set(stopwords.words('english'))

        # last node name -> list of (path, path_ids, path_types) candidates.
        candidates = defaultdict(list)
        # Only the validated node-ending paths are expanded; iterating the raw input
        # would index into empty paths and expand Text-ending ones.
        for p, pid, ptype in zip(node_paths, node_pids, node_ptypes):
            last_node = p[-1]
            for rel_type, target, target_name, target_type in edges_by_source.get(pid[-1], ()):
                # Skip unnamed targets and targets already on the path (no cycles).
                if not target_name or target_name in p:
                    continue
                if target_name.lower() in stop_words:
                    continue
                candidates[last_node].append(
                    (p + [rel_type, target_name], pid + [target], ptype + [target_type])
                )

        num_paths = sum(len(group) for group in candidates.values())
        if self.verbose:
            self.logger.info(f"Number of new paths before filtering: {num_paths}")
            self.logger.info(f"last nodes: {candidates.keys()}")

        # Width filtering only applies when the total exceeds the overall budget.
        apply_width = num_paths > len(last_node_ids) * width
        query_embedding = None
        new_paths, new_pids, new_ptypes = [], [], []
        for group in candidates.values():
            if apply_width and len(group) > width:
                if query_embedding is None:
                    # Encode the query once instead of once per group.
                    query_embedding = self.filter_encoder.encode([query])
                # NOTE(review): paths are handed to the encoder as lists of strings,
                # as before - confirm the filter encoder accepts that input shape.
                path_embeddings = self.filter_encoder.encode([cand[0] for cand in group])
                scores = np.dot(path_embeddings, query_embedding.T).flatten()
                keep = np.argsort(scores)[-width:]
                group = [group[i] for i in keep]
            for cand_path, cand_pid, cand_ptype in group:
                new_paths.append(cand_path)
                new_pids.append(cand_pid)
                new_ptypes.append(cand_ptype)

        if self.verbose:
            self.logger.info(f"Expanded paths count: {len(new_paths)}")
            self.logger.info(f"Expanded paths: {new_paths}")
        return new_paths, new_pids, new_ptypes

    def path_to_string(self, path: List[str]) -> str:
        """
        Convert a path of node ids and relation types to a readable string.
        Args:
            path (List[str]): Alternating node numeric_ids and relation types.
        Returns:
            str: ``"name ---> rel ---> name ..."``; a node id is used verbatim when
            no node (or no name) is found for it. Empty string for an empty path.
        """
        if len(path) < 1:
            return ""
        parts = []
        with self.neo4j_driver.session() as session:
            for i in range(0, len(path), 2):
                node_id = path[i]
                # Fetch the record once: a second single() call on the same result
                # cannot return the record again.
                record = session.run(
                    "MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id
                ).single()
                node_name = record["n.name"] if record and record["n.name"] is not None else node_id
                if i + 1 < len(path):
                    rel_type = path[i + 1]
                    parts.append(f"{node_name} ---> {rel_type} --->")
                else:
                    parts.append(str(node_name))
        return " ".join(parts).strip()

    @staticmethod
    def _parse_ratings(response: str) -> List[int]:
        """Parse a comma-separated rating reply, tolerating surrounding brackets."""
        # Without stripping, "[3, 4, 5]" would lose its first and last rating.
        cleaned = response.strip().strip("[]()")
        return [int(token) for token in cleaned.split(',') if token.strip().isdecimal()]

    @staticmethod
    def _paths_to_triples(P: List[List[str]]) -> List[str]:
        """Flatten alternating node/relation paths into "(head, rel, tail)" strings."""
        return [
            f"({path[i]}, {path[i + 1]}, {path[i + 2]})"
            for path in P
            for i in range(0, len(path) - 2, 2)
        ]

    def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]],
              topN: int = 5) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Prune paths to keep the top N based on LLM relevance ratings.

        Paths are rated in chunks of ``self.prune_size``. If the LLM returns fewer
        ratings than paths, ranking falls back to the dot product between the
        filter-encoder embeddings of the paths and of the query.
        Args:
            query (str): The user query.
            P (List[List[str]]): List of paths to prune.
            PIDS (List[List[str]]): Node ids of each path, aligned with ``P``.
            PTYPES (List[List[str]]): Node types of each path, aligned with ``P``.
            topN (int): Number of paths to retain.
        Returns:
            Tuple ``(paths, path_ids, path_types)`` of at most ``topN`` entries,
            best first. Paths rated 0 by the LLM are dropped.
        """
        ratings = []
        path_strings = P
        system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."

        for i in range(0, len(path_strings), self.prune_size):
            chunk = path_strings[i:i + self.prune_size]
            # Paths are numbered globally (j + i) across chunks.
            listing = "".join(f"{j + i}. {path_str}\n" for j, path_str in enumerate(chunk, 1))
            user_prompt = (
                f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
                + listing
                + "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
            )
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
            if self.verbose:
                self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")

            chunk_ratings = self._parse_ratings(response)
            if len(chunk_ratings) > len(chunk):
                if self.verbose:
                    # Log before trimming so the reported count is the received one.
                    self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
                chunk_ratings = chunk_ratings[:len(chunk)]
            ratings.extend(chunk_ratings)

        # Every chunk contributes at most len(chunk) ratings, so the total can only
        # be short, never long. A short total means some chunk is misaligned.
        if len(ratings) < len(path_strings):
            if self.verbose:  # logger is None when not verbose
                self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
            # NOTE(review): paths are encoded as lists of strings, as before -
            # confirm the filter encoder accepts that input shape.
            path_embeddings = self.filter_encoder.encode(path_strings)
            query_embedding = self.filter_encoder.encode([query])
            scores = np.dot(path_embeddings, query_embedding.T).flatten()
            # Highest score first, matching the ordering of the rated branch.
            top_indices = np.argsort(scores)[::-1][:topN]
            return (
                [path_strings[i] for i in top_indices],
                [PIDS[i] for i in top_indices],
                [PTYPES[i] for i in top_indices],
            )

        # Rank by rating (stable, descending), drop zero-rated paths, keep topN.
        sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
        top_indices = [i for i in sorted_indices if ratings[i] > 0][:topN]
        if self.verbose:
            self.logger.info(f"Top indices after pruning: {top_indices}")
            self.logger.info(f"length of path_strings: {len(path_strings)}")

        top_paths = [path_strings[i] for i in top_indices]
        top_pids = [PIDS[i] for i in top_indices]
        top_ptypes = [PTYPES[i] for i in top_indices]
        if self.verbose:
            self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
        return top_paths, top_pids, top_ptypes

    def reasoning(self, query: str, P: List[List[str]]) -> bool:
        """
        Ask the LLM whether the current paths are sufficient to answer the query.
        Args:
            query (str): The user query.
            P (List[List[str]]): Current list of paths.
        Returns:
            bool: True when the LLM reply contains "yes" (case-insensitive).
        """
        # Paths already hold node names, so no database access is needed here.
        triples_str = ". ".join(self._paths_to_triples(P))
        prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
        messages = [
            {"role": "system", "content": "Answer Yes or No only."},
            {"role": "user", "content": prompt}
        ]
        response = self.llm_generator.generate_response(messages, max_new_tokens=512)
        if self.verbose:
            self.logger.info(f"Reasoning result: {response}")
        return "yes" in response.lower()

    def retrieve_passages(self, query: str) -> Tuple[List[str], str]:
        """
        Retrieve triples that help answer the query.

        Uses ``self.topN`` (seed nodes per entity and paths kept per round),
        ``self.Dmax`` (depth bound) and ``self.Wmax`` (expansion width).
        Args:
            query (str): The user query.
        Returns:
            Tuple[List[str], str]: ``(triples, 'N/A')`` where each triple is a
            ``"(head, relation, tail)"`` string. The second element is always
            ``'N/A'`` because this retriever produces no passage scores.
        """
        topN = self.topN
        Dmax = self.Dmax
        Wmax = self.Wmax
        if self.verbose:
            self.logger.info(f"Retrieving paths for query: {query}")

        initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
        if not initial_nodes:
            if self.verbose:
                self.logger.info("No initial nodes found.")
            # Same shape as the normal return so callers can always unpack.
            return [], 'N/A'

        P = [[node] for node in initial_nodes]
        PIDS = [[node_id] for node_id in initial_nodes_ids]
        PTYPES = [["Node"] for _ in initial_nodes_ids]  # Assuming all initial nodes are of type 'Node'

        for D in range(Dmax + 1):
            if self.verbose:
                self.logger.info(f"Depth {D}, Current paths: {len(P)}")
            new_P, new_PIDS, new_PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
            if not new_P:
                # Nothing left to explore: keep the paths pruned in the previous
                # round instead of discarding them.
                if self.verbose:
                    self.logger.info("No paths to expand.")
                break
            P, PIDS, PTYPES = self.prune(query, new_P, new_PIDS, new_PTYPES, topN)
            if not P:
                # Every candidate was rated irrelevant; asking the LLM about an
                # empty triple set would only waste a call.
                break
            if D == Dmax:
                if self.verbose:
                    self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
                break
            if self.reasoning(query, P):
                if self.verbose:
                    self.logger.info("Paths sufficient, stopping expansion.")
                break

        # Extract final triples
        triples = self._paths_to_triples(P)
        if self.verbose:
            self.logger.info(f"Final triples: {triples}")
        return triples, 'N/A'  # 'N/A' for passages_score as this retriever does not return passages