from neo4j import GraphDatabase
import faiss
import numpy as np
import random
from collections import defaultdict
from typing import List, Tuple
import time
import logging
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
import nltk
from nltk.corpus import stopwords

# Only hit the network when the corpus is actually missing; the original
# unconditionally downloaded on every import.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
    """Think-on-Graph style retriever for billion-scale KGs backed by Neo4j.

    Starting from FAISS-matched seed nodes, paths are iteratively expanded
    through the graph, filtered by embedding similarity, pruned by an LLM
    rater (with an embedding-similarity fallback), and expansion stops once
    the LLM judges the collected triples sufficient to answer the query.
    """

    def __init__(self, keyword: str, neo4j_driver: GraphDatabase, llm_generator: LLMGenerator,
                 sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, topN: int = 5, Dmax: int = 3, Wmax: int = 3,
                 prune_size: int = 10, logger: logging.Logger = None):
        """
        Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.

        Args:
            keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
            neo4j_driver (GraphDatabase): Neo4j driver for database access.
            llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
            sentence_encoder (BaseEmbeddingModel): Encoder used to embed entities for FAISS lookup.
            filter_encoder (BaseEmbeddingModel): Encoder used to score candidate paths against the query.
            node_index (faiss.Index): FAISS index over node embeddings.
            topN (int): Number of paths retained after each pruning round.
            Dmax (int): Maximum expansion depth.
            Wmax (int): Maximum expansion width per frontier node.
            prune_size (int): Number of paths rated per LLM call.
            logger (Logger, optional): Logger for verbose output; verbosity is enabled iff provided.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.filter_encoder = filter_encoder
        self.node_faiss_index = node_index
        # Verbose logging is tied to the presence of a logger.
        self.verbose = logger is not None
        self.logger = logger
        self.topN = topN
        self.Dmax = Dmax
        self.Wmax = Wmax
        self.prune_size = prune_size

    def ner(self, text: str) -> List[str]:
        """
        Extract named entities from the query text using the LLM.

        Args:
            text (str): The query text.

        Returns:
            List[str]: List of extracted entities.
        """
        entities = self.llm_generator.large_kg_ner(text)
        if self.verbose:
            self.logger.info(f"Extracted entities: {entities}")
        return entities

    def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> Tuple[List[str], List[str]]:
        """
        Retrieve nodes similar to entities in the query.

        Args:
            query (str): The user query.
            top_k_nodes (int): Number of FAISS neighbors retrieved per entity.
                (Bug fix: this parameter was previously ignored — the search
                was hardcoded to 3 neighbors.)

        Returns:
            Tuple[List[str], List[str]]: Node numeric_ids and the matching node names.
        """
        start_time = time.time()
        entities = self.ner(query)
        if self.verbose:
            ner_time = time.time() - start_time
            self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
        # Fall back to the raw query when the LLM finds no entities.
        if not entities:
            entities = [query]

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
            if len(I[0]) > 0:  # Check if results exist
                initial_nodes.extend([str(i) for i in I[0]])
        # No need for filtering here: the ToG pruning step will handle it.
        topk_nodes_ids = list(set(initial_nodes))

        with self.neo4j_driver.session() as session:
            start_time = time.time()
            # NOTE: a local name distinct from the `query` parameter — the
            # original shadowed the user query with this Cypher string.
            cypher = """
            MATCH (n:Node)
            WHERE n.numeric_id IN $topk_nodes_ids
            RETURN n.numeric_id AS id, n.name AS name
            """
            result = session.run(cypher, topk_nodes_ids=topk_nodes_ids)
            topk_nodes_dict = {record["id"]: record["name"] for record in result}
            if self.verbose:
                neo4j_time = time.time() - start_time
                self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")

        if self.verbose:
            self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
        # numeric_id of nodes returned, plus their names (parallel lists).
        return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values())

    def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]],
                     width: int, query: str) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Expand each path by one hop through neighbors of its last node.

        Args:
            P (List[List[str]]): Current paths, each alternating node names and relation types.
            PIDS (List[List[str]]): Parallel paths of node numeric_ids.
            PTYPES (List[List[str]]): Parallel paths of node types ('Node' / 'Text').
            width (int): Maximum expansions kept per frontier node when filtering applies.
            query (str): The user query, used for similarity-based filtering.

        Returns:
            Tuple of (paths, path ids, path types) after expansion.
            NOTE(review): when any expandable node exists, paths ending in 'Text'
            are dropped from the return (original behavior, preserved).
        """
        last_node_ids = []
        paths_end_with_text = []
        paths_end_with_text_id = []
        paths_end_with_text_type = []
        if self.verbose:
            self.logger.info(f"Expanding paths, current paths: {P}")

        # Split the frontier: 'Text'-terminated paths cannot be expanded.
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            if not p or not pid or not ptype:  # Skip empty paths
                continue
            if ptype[-1] == "Text":
                paths_end_with_text.append(p)
                paths_end_with_text_id.append(pid)
                paths_end_with_text_type.append(ptype)
            else:
                last_node_ids.append(pid[-1])

        # Nothing to expand: return the text-terminated paths unchanged.
        if not last_node_ids:
            return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type

        with self.neo4j_driver.session() as session:
            # Query one-hop Node relationships, time-boxed and randomly sampled
            # server-side to bound the result size.
            start_time = time.time()
            outgoing_query = """
            CALL apoc.cypher.runTimeboxed(
            "MATCH (n:Node)-[r:Relation]-(m:Node)
            WHERE n.numeric_id IN $last_node_ids
            WITH n, r, m
            ORDER BY rand()
            LIMIT 60000
            RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
            {last_node_ids: $last_node_ids},
            60000
            )
            YIELD value
            RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type,
                   value.target AS target, value.target_name AS target_name, value.target_type AS target_type
            """
            outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
            outgoing = [(record["source"], record['source_name'], record["rel_type"],
                         record["target"], record["target_name"], record["target_type"])
                        for record in outgoing_result]
            if self.verbose:
                outgoing_time = time.time() - start_time
                self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")

        # Group edges by source id once, instead of scanning all edges for
        # every path (original was O(paths * edges)).
        outgoing_by_source = defaultdict(list)
        for source, source_name, rel_type, target, target_name, target_type in outgoing:
            outgoing_by_source[source].append((rel_type, target, target_name, target_type))

        # Hoisted: the original rebuilt the stopword list per candidate edge.
        english_stopwords = set(stopwords.words('english'))

        last_node_to_new_paths = defaultdict(list)
        last_node_to_new_paths_ids = defaultdict(list)
        last_node_to_new_paths_types = defaultdict(list)
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            last_node = p[-1]
            last_node_id = pid[-1]
            for rel_type, target, target_name, target_type in outgoing_by_source.get(last_node_id, []):
                # Skip cycles (node name already on the path) and stopword nodes.
                if target_name in p or target_name.lower() in english_stopwords:
                    continue
                last_node_to_new_paths[last_node].append(p + [rel_type, target_name])
                last_node_to_new_paths_ids[last_node].append(pid + [target])
                last_node_to_new_paths_types[last_node].append(ptype + [target_type])

        num_paths = sum(len(v) for v in last_node_to_new_paths.values())
        new_paths = []
        new_pids = []
        new_ptypes = []
        if self.verbose:
            self.logger.info(f"Number of new paths before filtering: {num_paths}")
            self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")

        if num_paths > len(last_node_ids) * width:
            # Apply per-node similarity filtering when the frontier is too wide.
            # Hoisted: the query embedding is constant across the loop.
            query_embeddings = self.filter_encoder.encode([query])
            for last_node, new_ps in last_node_to_new_paths.items():
                if len(new_ps) > width:
                    path_embeddings = self.filter_encoder.encode(new_ps)
                    scores = np.dot(path_embeddings, query_embeddings.T).flatten()
                    top_indices = np.argsort(scores)[-width:]
                    new_paths.extend([new_ps[i] for i in top_indices])
                    new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
                    new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
                else:
                    new_paths.extend(new_ps)
                    new_pids.extend(last_node_to_new_paths_ids[last_node])
                    new_ptypes.extend(last_node_to_new_paths_types[last_node])
        else:
            # Below threshold: keep everything.
            for last_node, new_ps in last_node_to_new_paths.items():
                new_paths.extend(new_ps)
                new_pids.extend(last_node_to_new_paths_ids[last_node])
                new_ptypes.extend(last_node_to_new_paths_types[last_node])

        if self.verbose:
            self.logger.info(f"Expanded paths count: {len(new_paths)}")
            self.logger.info(f"Expanded paths: {new_paths}")
        return new_paths, new_pids, new_ptypes

    def path_to_string(self, path: List[str]) -> str:
        """
        Convert a path to a human-readable string for LLM rating.

        Args:
            path (List[str]): Path as alternating node_ids and relation types.

        Returns:
            str: String representation of the path.
        """
        if len(path) < 1:
            return ""
        path_str = []
        with self.neo4j_driver.session() as session:
            for i in range(0, len(path), 2):
                node_id = path[i]
                result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
                # Bug fix: Result.single() consumes the result — the original
                # called it twice, so the second call could not succeed.
                record = result.single()
                node_name = record["n.name"] if record else node_id
                if i + 1 < len(path):
                    rel_type = path[i + 1]
                    path_str.append(f"{node_name} ---> {rel_type} --->")
                else:
                    path_str.append(node_name)
        return " ".join(path_str).strip()

    def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]],
              topN: int = 5) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Prune paths to keep the top N based on LLM relevance ratings.

        Falls back to filter-encoder similarity ranking when the LLM returns
        fewer ratings than there are paths.

        Args:
            query (str): The user query.
            P (List[List[str]]): List of paths to prune.
            PIDS (List[List[str]]): Parallel path ids.
            PTYPES (List[List[str]]): Parallel path types.
            topN (int): Number of paths to retain.

        Returns:
            Tuple of (top paths, their ids, their types).
        """
        ratings = []
        path_strings = P

        # Rate paths in chunks of prune_size per LLM call.
        for i in range(0, len(path_strings), self.prune_size):
            chunk = path_strings[i:i + self.prune_size]
            # Numbering continues across chunks (i+1, i+2, ...).
            user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
            for j, path_str in enumerate(chunk, 1):
                user_prompt += f"{j + i}. {path_str}\n"
            user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."
            system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
            if self.verbose:
                self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")

            # Parse the response into integer ratings; ignore non-numeric tokens.
            rating_str = response.strip()
            chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
            if len(chunk_ratings) > len(chunk):
                chunk_ratings = chunk_ratings[:len(chunk)]
                if self.verbose:
                    self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
            ratings.extend(chunk_ratings)

        if len(ratings) < len(path_strings):
            # LLM under-produced ratings: fall back to embedding similarity.
            # Bug fix: these warnings were previously emitted without the
            # verbose guard, crashing with AttributeError when logger is None.
            if self.verbose:
                self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
            path_embeddings = self.filter_encoder.encode(path_strings)
            query_embedding = self.filter_encoder.encode([query])
            scores = np.dot(path_embeddings, query_embedding.T).flatten()
            top_indices = np.argsort(scores)[-topN:]
            top_paths = [path_strings[i] for i in top_indices]
            return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
        elif len(ratings) > len(path_strings):
            if self.verbose:
                self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
            ratings = ratings[:len(path_strings)]

        # Keep the highest-rated paths, discarding any rated 0 (irrelevant).
        sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)
        filtered_indices = [i for i in sorted_indices if ratings[i] > 0]
        top_indices = filtered_indices[:topN]
        if self.verbose:
            self.logger.info(f"Top indices after pruning: {top_indices}")
            self.logger.info(f"length of path_strings: {len(path_strings)}")
        top_paths = [path_strings[i] for i in top_indices]
        top_pids = [PIDS[i] for i in top_indices]
        top_ptypes = [PTYPES[i] for i in top_indices]
        if self.verbose:
            self.logger.info(f"Pruned to top {topN} paths: {top_paths}")
        return top_paths, top_pids, top_ptypes

    def reasoning(self, query: str, P: List[List[str]]) -> bool:
        """
        Check if the current paths are sufficient to answer the query.

        Args:
            query (str): The user query.
            P (List[List[str]]): Current list of paths.

        Returns:
            bool: True if the LLM judges the triples sufficient, False otherwise.
        """
        # Paths already contain node names, so no DB lookup is needed
        # (the original opened an unused Neo4j session here).
        triples = []
        for path in P:
            if len(path) < 3:
                continue
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")
        triples_str = ". ".join(triples)
        prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
        messages = [
            {"role": "system", "content": "Answer Yes or No only."},
            {"role": "user", "content": prompt}
        ]
        response = self.llm_generator.generate_response(messages, max_new_tokens=512)
        if self.verbose:
            self.logger.info(f"Reasoning result: {response}")
        return "yes" in response.lower()

    def retrieve_passages(self, query: str) -> Tuple[List[str], str]:
        """
        Retrieve the top N paths to answer the query.

        Expansion runs up to self.Dmax hops, pruning to self.topN paths per
        depth, stopping early when reasoning() deems the triples sufficient.

        Args:
            query (str): The user query.

        Returns:
            Tuple[List[str], str]: Final triples as strings, and 'N/A' as the
            passages score (this retriever does not return passages).
        """
        topN = self.topN
        Dmax = self.Dmax
        Wmax = self.Wmax
        if self.verbose:
            self.logger.info(f"Retrieving paths for query: {query}")
        initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
        if not initial_nodes:
            if self.verbose:
                self.logger.info("No initial nodes found.")
            # Bug fix: keep the two-element return shape (was a bare []).
            return [], 'N/A'

        P = [[node] for node in initial_nodes]
        PIDS = [[node_id] for node_id in initial_nodes_ids]
        PTYPES = [["Node"] for _ in initial_nodes_ids]  # All seeds are of type 'Node'

        for D in range(Dmax + 1):
            if self.verbose:
                self.logger.info(f"Depth {D}, Current paths: {len(P)}")
            P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
            if not P:
                if self.verbose:
                    self.logger.info("No paths to expand.")
                break
            P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
            if D == Dmax:
                if self.verbose:
                    self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
                break
            if self.reasoning(query, P):
                if self.verbose:
                    self.logger.info("Paths sufficient, stopping expansion.")
                break

        # Extract final triples from the surviving paths (names are already
        # in the paths, so no Neo4j session is required here).
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")
        if self.verbose:
            self.logger.info(f"Final triples: {triples}")
        return triples, 'N/A'  # 'N/A' for passages_score as this retriever does not return passages