first commit
40  AIEC-RAG/atlas_rag/retriever/lkg_retriever/base.py  Normal file
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod


class BaseLargeKGRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve passages based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g.:
                topN (int): Number of top passages to retrieve.
                number_of_source_nodes_per_ner (int): Number of source nodes per recognized named entity.
                sampling_area (int): Area for sampling in the graph.

        Returns:
            List of retrieved passages and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")


class BaseLargeKGEdgeRetriever(ABC):
    def __init__(self):
        raise NotImplementedError("This is a base class and cannot be instantiated directly.")

    @abstractmethod
    def retrieve_passages(self, query, retriever_config: dict):
        """
        Retrieve edges / paths based on the query.

        Args:
            query (str): The input query.
            retriever_config (dict): Retrieval settings, e.g.:
                topN (int): Number of top paths to retrieve.
                number_of_source_nodes_per_ner (int): Number of source nodes per recognized named entity.
                sampling_area (int): Area for sampling in the graph.

        Returns:
            List of retrieved edges/paths and their scores.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")
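Usage note (not part of the commit): the two base classes form the retriever interface for the package. A minimal sketch of a conforming subclass, assuming only what the docstrings above describe; the class name and corpus below are hypothetical:

from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever

class ToyRetriever(BaseLargeKGRetriever):
    def __init__(self, corpus):
        # Override the base __init__, which raises NotImplementedError by design.
        self.corpus = corpus  # a list of passages; hypothetical stand-in for a real KG

    def retrieve_passages(self, query, retriever_config: dict):
        # A real subclass would rank passages against the query; this one
        # just truncates to the assumed 'topN' config key.
        top_n = retriever_config.get("topN", 5)
        passages = self.corpus[:top_n]
        return passages, [1.0] * len(passages)

passages, scores = ToyRetriever(["p1", "p2", "p3"]).retrieve_passages("q", {"topN": 2})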
313  AIEC-RAG/atlas_rag/retriever/lkg_retriever/lkgr.py  Normal file
@@ -0,0 +1,313 @@
from difflib import get_close_matches
from logging import Logger
import faiss
from neo4j import GraphDatabase
import time
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from graphdatascience import GraphDataScience
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import string
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever


class LargeKGRetriever(BaseLargeKGRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index, passage_index: faiss.Index,
                 topN: int = 5,
                 number_of_source_nodes_per_ner: int = 10,
                 sampling_area: int = 250, logger: Logger = None):
        # instantiate the resources of one KG
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.gds_driver = GraphDataScience(self.neo4j_driver)
        self.topN = topN
        self.number_of_source_nodes_per_ner = number_of_source_nodes_per_ner
        self.sampling_area = sampling_area
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = node_index
        # self.edge_faiss_index = self.edge_indexes[keyword]
        self.passage_faiss_index = passage_index

        self.verbose = logger is not None
        self.logger = logger

        self.ppr_weight_threshold = 0.00005

    def set_model(self, model):
        if self.llm_generator.inference_type == 'openai':
            self.llm_generator.model_name = model
        else:
            raise ValueError("Model can only be set for the OpenAI inference type.")

    def ner(self, text):
        return self.llm_generator.large_kg_ner(text)

    def convert_numeric_id_to_name(self, numeric_id):
        if numeric_id.isdigit():
            return self.gds_driver.util.asNode(self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id})).get('name')
        else:
            return numeric_id

    def has_intersection(self, word_set, input_string):
        cleaned_string = input_string.translate(str.maketrans('', '', string.punctuation)).lower()
        if self.keyword == 'cc_en':
            # Check if any phrase in word_set is a substring of cleaned_string
            for phrase in word_set:
                if phrase in cleaned_string:
                    return True
            return False
        else:
            # Check if any word in word_set is present in the cleaned_string's words
            words_in_string = set(cleaned_string.split())
            return not word_set.isdisjoint(words_in_string)
    def pagerank(self, personalization_dict, topN=5, sampling_area=200):
        graph = self.gds_driver.graph.get('largekgrag_graph')
        node_count = graph.node_count()
        sampling_ratio = sampling_area / node_count
        aggregation_node_dict = []
        ppr_weight_threshold = self.ppr_weight_threshold
        start_time = time.time()

        # Pre-filter nodes based on the ppr_weight threshold and
        # precompute word sets depending on the keyword
        if self.keyword == 'cc_en':
            filtered_personalization = {
                node_id: ppr_weight
                for node_id, ppr_weight in personalization_dict.items()
                if ppr_weight >= ppr_weight_threshold
            }
            stop_words = set(stopwords.words('english'))
            word_set_phrases = set()
            word_set_words = set()
            for node_id, ppr_weight in filtered_personalization.items():
                name = self.gds_driver.util.asNode(node_id)['name']
                if name:
                    cleaned_phrase = name.translate(str.maketrans('', '', string.punctuation)).lower().strip()
                    if cleaned_phrase:
                        # For 'cc_en': remove stop words and keep the cleaned phrase
                        filtered_words = [word for word in cleaned_phrase.split() if word not in stop_words]
                        if filtered_words:
                            cleaned_phrase_filtered = ' '.join(filtered_words)
                            word_set_phrases.add(cleaned_phrase_filtered)

            word_set = word_set_phrases if self.keyword == 'cc_en' else word_set_words
            if self.verbose:
                self.logger.info(f"Optimized word set: {word_set}")
        else:
            filtered_personalization = personalization_dict
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict: {filtered_personalization}")
            self.logger.info(f"largekgRAG : Sampling ratio: {sampling_ratio}")
            self.logger.info(f"largekgRAG : PPR weight threshold: {ppr_weight_threshold}")
        # Process each node in the filtered personalization dict
        for node_id, ppr_weight in filtered_personalization.items():
            try:
                self.gds_driver.graph.drop('rwr_sample')
                start_time = time.time()
                G_sample, _ = self.gds_driver.graph.sample.rwr("rwr_sample", graph, concurrency=4,
                                                               samplingRatio=sampling_ratio, startNodes=[node_id],
                                                               restartProbability=0.4, logProgress=False)
                if self.verbose:
                    self.logger.info(f"largekgRAG : Sampled graph for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                result = self.gds_driver.pageRank.stream(
                    G_sample, maxIterations=30, sourceNodes=[node_id], logProgress=False
                ).sort_values("score", ascending=False)

                if self.verbose:
                    self.logger.info(f"pagerank type: {type(result)}")
                    self.logger.info(f"pagerank result: {result}")
                    self.logger.info(f"largekgRAG : PageRank calculated for node {node_id} in {time.time() - start_time:.2f} seconds")
                start_time = time.time()
                if self.keyword != 'cc_en':
                    result = result[result['score'] > 0.0].nlargest(50, 'score').to_dict('records')
                else:
                    result = result.to_dict('records')
                if self.verbose:
                    self.logger.info(f"largekgRAG : result: {result}")
                for entry in result:
                    if self.keyword == 'cc_en':
                        node_name = self.gds_driver.util.asNode(entry['nodeId'])['name']
                        if not self.has_intersection(word_set, node_name):
                            continue

                    numeric_id = self.gds_driver.util.asNode(entry['nodeId'])['numeric_id']
                    aggregation_node_dict.append({
                        'nodeId': numeric_id,
                        'score': entry['score'] * ppr_weight
                    })

            except Exception as e:
                if self.verbose:
                    self.logger.error(f"Error processing node {node_id}: {e}")
                    self.logger.error(f"Node is filtered out: {self.gds_driver.util.asNode(node_id)['name']}")
                continue

        aggregation_node_dict = sorted(aggregation_node_dict, key=lambda x: x['score'], reverse=True)[:25]
        if self.verbose:
            self.logger.info(f"Aggregation node dict: {aggregation_node_dict}")
            self.logger.info(f"Time taken to sample and calculate PageRank: {time.time() - start_time:.2f} seconds")
        start_time = time.time()
        with self.neo4j_driver.session() as session:
            intermediate_time = time.time()
            # Step 1: Distribute entity scores to connected text nodes and find the top N
            query_scores = """
                UNWIND $entries AS entry
                MATCH (n:Node {numeric_id: entry.nodeId})-[:Source]->(t:Text)
                WITH t.numeric_id AS textId, SUM(entry.score) AS total_score
                ORDER BY total_score DESC
                LIMIT $topN
                RETURN textId, total_score
            """
            # Execute the query to aggregate scores
            result_scores = session.run(query_scores, entries=aggregation_node_dict, topN=topN)
            top_numeric_ids = []
            top_scores = []

            # Extract the top text node IDs and scores
            for record in result_scores:
                top_numeric_ids.append(record["textId"])
                top_scores.append(record["total_score"])

            # Step 2: Use the top numeric IDs to retrieve the original text
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 1: {time.time() - intermediate_time:.2f} seconds")
            intermediate_time = time.time()
            query_text = """
                UNWIND $textIds AS textId
                MATCH (t:Text {numeric_id: textId})
                RETURN t.original_text AS text, t.numeric_id AS textId
            """
            result_texts = session.run(query_text, textIds=top_numeric_ids)
            topN_passages = []
            score_dict = dict(zip(top_numeric_ids, top_scores))

            # Combine the original text with its score
            for record in result_texts:
                original_text = record["text"]
                text_id = record["textId"]
                score = score_dict.get(text_id, 0)
                topN_passages.append((original_text, score))
            if self.verbose:
                self.logger.info(f"Time taken to prepare query 2: {time.time() - intermediate_time:.2f} seconds")
            # Sort passages by score
            topN_passages = sorted(topN_passages, key=lambda x: x[1], reverse=True)
            top_texts = [item[0] for item in topN_passages][:topN]
            top_scores = [item[1] for item in topN_passages][:topN]
            if self.verbose:
                self.logger.info(f"Total passages retrieved: {len(top_texts)}")
                self.logger.info(f"Top passages: {top_texts}")
                self.logger.info(f"Top scores: {top_scores}")
                self.logger.info(f"Neo4j Query Time: {time.time() - start_time:.2f} seconds")
        return top_texts, top_scores
    def retrieve_topk_nodes(self, query, top_k_nodes=2):
        # Extract entities from the query
        entities = self.ner(query)
        if self.verbose:
            self.logger.info(f"largekgRAG : LLM Extracted entities: {entities}")
        if len(entities) == 0:
            entities = [query]
        num_entities = len(entities)
        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, top_k_nodes)
            if self.verbose:
                self.logger.info(f"largekgRAG : Search results - Distances: {D}, Indices: {I}")
            initial_nodes += [str(i) for i in I[0]]
        if self.verbose:
            self.logger.info(f"largekgRAG : Initial nodes: {initial_nodes}")
        name_id_map = {}
        for node_id in initial_nodes:
            name = self.convert_numeric_id_to_name(node_id)
            name_id_map[name] = node_id

        # Convert the numeric ids to names, filter them with the LLM, then map the survivors back to numeric ids
        keywords_before_filter = [self.convert_numeric_id_to_name(n) for n in initial_nodes]
        filtered_keywords = self.llm_generator.large_kg_filter_keywords_with_entity(query, keywords_before_filter)

        # Second pass: add the filtered keywords
        filtered_top_k_nodes = []
        filter_log_dict = {}
        match_threshold = 0.8
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered Before Match Keywords Candidate: {filtered_keywords}")
        for keyword in filtered_keywords:
            # Check for an exact match first
            if keyword in name_id_map:
                filtered_top_k_nodes.append(name_id_map[keyword])
                filter_log_dict[keyword] = name_id_map[keyword]
            else:
                # Look for close matches using difflib's get_close_matches
                close_matches = get_close_matches(keyword, name_id_map.keys(), n=1, cutoff=match_threshold)

                if close_matches:
                    # If a close match is found, add the corresponding node
                    filtered_top_k_nodes.append(name_id_map[close_matches[0]])

                filter_log_dict[keyword] = name_id_map[close_matches[0]] if close_matches else None
        if self.verbose:
            self.logger.info(f"largekgRAG : Filtered After Match Keywords Candidate: {filter_log_dict}")

        topk_nodes = list(set(filtered_top_k_nodes))
        if len(topk_nodes) > 2 * num_entities:
            topk_nodes = topk_nodes[:2 * num_entities]
        return topk_nodes

    def _process_text(self, text):
        """Normalize text for containment checks (lowercase, alphanumeric + spaces)."""
        text = text.lower()
        text = ''.join([c for c in text if c.isalnum() or c.isspace()])
        return set(text.split())
    def retrieve_personalization_dict(self, query, number_of_source_nodes_per_ner=5):
        topk_nodes = self.retrieve_topk_nodes(query, number_of_source_nodes_per_ner)
        if topk_nodes == []:
            if self.verbose:
                self.logger.info(f"largekgRAG : No nodes found for query: {query}")
            return {}
        if self.verbose:
            self.logger.info(f"largekgRAG : Topk nodes: {[self.convert_numeric_id_to_name(node_id) for node_id in topk_nodes]}")
        freq_dict_for_nodes = {}
        cypher_query = """
            UNWIND $nodes AS node
            MATCH (n1:Node {numeric_id: node})-[r:Source]-(n2:Text)
            RETURN n1.numeric_id AS numeric_id, COUNT(DISTINCT n2.text_id) AS fileCount
        """
        with self.neo4j_driver.session() as session:
            result = session.run(cypher_query, nodes=topk_nodes)
            for record in result:
                freq_dict_for_nodes[record["numeric_id"]] = record["fileCount"]
        # Create the personalization dictionary, weighting each node by the inverse of its file count
        personalization_dict = {self.gds_driver.find_node_id(["Node"], {"numeric_id": numeric_id}): 1 / file_count
                                for numeric_id, file_count in freq_dict_for_nodes.items()}
        if self.verbose:
            self.logger.info(f"largekgRAG : Personalization dict's number of nodes: {len(personalization_dict)}")
        return personalization_dict

    def retrieve_passages(self, query):
        if self.verbose:
            self.logger.info(f"largekgRAG : Retrieving passages for query: {query}")
        topN = self.topN
        number_of_source_nodes_per_ner = self.number_of_source_nodes_per_ner
        sampling_area = self.sampling_area
        personalization_dict = self.retrieve_personalization_dict(query, number_of_source_nodes_per_ner)
        if personalization_dict == {}:
            return [], [0]
        topN_passages, topN_scores = self.pagerank(personalization_dict, topN, sampling_area=sampling_area)
        return topN_passages, topN_scores
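Usage note (not part of the commit): a hedged wiring sketch for LargeKGRetriever. The constructor keywords follow the signature above; the Neo4j URI and credentials, the FAISS index paths, and the my_llm / my_encoder objects are illustrative placeholders, not values shipped with this repository:

import faiss
from neo4j import GraphDatabase
from atlas_rag.retriever.lkg_retriever.lkgr import LargeKGRetriever

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))  # placeholder URI/credentials
node_index = faiss.read_index("indexes/cc_en_nodes.index")        # placeholder path
passage_index = faiss.read_index("indexes/cc_en_passages.index")  # placeholder path

retriever = LargeKGRetriever(
    keyword="cc_en",                 # dataset identifier used throughout the class
    neo4j_driver=driver,
    llm_generator=my_llm,            # an LLMGenerator instance, constructed elsewhere
    sentence_encoder=my_encoder,     # a BaseEmbeddingModel instance, constructed elsewhere
    node_index=node_index,
    passage_index=passage_index,
    topN=5,
)
passages, scores = retriever.retrieve_passages("Who proposed the theory of general relativity?")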
469  AIEC-RAG/atlas_rag/retriever/lkg_retriever/tog.py  Normal file
@@ -0,0 +1,469 @@
from neo4j import GraphDatabase
import faiss
import numpy as np
import random
from collections import defaultdict
from typing import List, Tuple
import time
import logging
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGEdgeRetriever
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

class LargeKGToGRetriever(BaseLargeKGEdgeRetriever):
    def __init__(self, keyword: str, neo4j_driver: GraphDatabase,
                 llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel, filter_encoder: BaseEmbeddingModel,
                 node_index: faiss.Index,
                 topN: int = 5,
                 Dmax: int = 3,
                 Wmax: int = 3,
                 prune_size: int = 10,
                 logger: logging.Logger = None):
        """
        Initialize the LargeKGToGRetriever for billion-level KG retrieval using Neo4j.

        Args:
            keyword (str): Identifier for the KG dataset (e.g., 'cc_en').
            neo4j_driver (GraphDatabase): Neo4j driver for database access.
            llm_generator (LLMGenerator): LLM for NER, rating, and reasoning.
            sentence_encoder (BaseEmbeddingModel): Encoder for generating embeddings.
            filter_encoder (BaseEmbeddingModel): Encoder used to score paths against the query when pruning.
            node_index (faiss.Index): FAISS index for node embeddings.
            topN (int): Number of paths to retain after pruning.
            Dmax (int): Maximum depth of path expansion.
            Wmax (int): Maximum width of path expansion per node.
            prune_size (int): Number of paths rated per LLM call.
            logger (Logger, optional): Logger for verbose output.
        """
        self.keyword = keyword
        self.neo4j_driver = neo4j_driver
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.filter_encoder = filter_encoder
        self.node_faiss_index = node_index
        self.verbose = logger is not None
        self.logger = logger
        self.topN = topN
        self.Dmax = Dmax
        self.Wmax = Wmax
        self.prune_size = prune_size

    def ner(self, text: str) -> List[str]:
        """
        Extract named entities from the query text using the LLM.

        Args:
            text (str): The query text.

        Returns:
            List[str]: List of extracted entities.
        """
        entities = self.llm_generator.large_kg_ner(text)
        if self.verbose:
            self.logger.info(f"Extracted entities: {entities}")
        return entities
    def retrieve_topk_nodes(self, query: str, top_k_nodes: int = 5) -> Tuple[List[str], List[str]]:
        """
        Retrieve top-k nodes similar to entities in the query.

        Args:
            query (str): The user query.
            top_k_nodes (int): Number of nodes to retrieve per entity.

        Returns:
            Tuple[List[str], List[str]]: Node numeric_ids and the corresponding node names.
        """
        start_time = time.time()
        entities = self.ner(query)
        if self.verbose:
            ner_time = time.time() - start_time
            self.logger.info(f"NER took {ner_time:.2f} seconds, entities: {entities}")
        if not entities:
            entities = [query]

        initial_nodes = []
        for entity in entities:
            entity_embedding = self.sentence_encoder.encode([entity])
            D, I = self.node_faiss_index.search(entity_embedding, 3)  # top 3 candidates per entity (top_k_nodes is not used here)
            if self.verbose:
                self.logger.info(f"Entity: {entity}, FAISS Distances: {D}, Indices: {I}")
            if len(I[0]) > 0:  # Check if results exist
                initial_nodes.extend([str(i) for i in I[0]])
        # no need for filtering here, as ToG pruning will handle it
        topk_nodes_ids = list(set(initial_nodes))

        with self.neo4j_driver.session() as session:
            start_time = time.time()
            cypher_query = """
                MATCH (n:Node)
                WHERE n.numeric_id IN $topk_nodes_ids
                RETURN n.numeric_id AS id, n.name AS name
            """
            result = session.run(cypher_query, topk_nodes_ids=topk_nodes_ids)
            topk_nodes_dict = {}
            for record in result:
                topk_nodes_dict[record["id"]] = record["name"]
            if self.verbose:
                neo4j_time = time.time() - start_time
                self.logger.info(f"Neo4j query took {neo4j_time:.2f} seconds, count: {len(topk_nodes_ids)}")

        if self.verbose:
            self.logger.info(f"Top-k nodes: {topk_nodes_dict}")
        return list(topk_nodes_dict.keys()), list(topk_nodes_dict.values())  # numeric_ids and names of the nodes
    def expand_paths(self, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], width: int, query: str) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Expand each path by adding neighbors of its last node.

        Args:
            P (List[List[str]]): Current list of paths, each a list of alternating node names and relation types.
            PIDS (List[List[str]]): numeric_ids of the nodes along each path.
            PTYPES (List[List[str]]): Types ('Node' or 'Text') of the nodes along each path.
            width (int): Maximum number of expansions kept per last node.
            query (str): The user query, used to score candidate paths when filtering.

        Returns:
            Tuple of the expanded paths, their ids, and their types.
        """
        last_nodes = []
        last_node_ids = []
        last_node_types = []

        paths_end_with_text = []
        paths_end_with_text_id = []
        paths_end_with_text_type = []

        paths_end_with_node = []
        paths_end_with_node_id = []
        paths_end_with_node_type = []
        if self.verbose:
            self.logger.info(f"Expanding paths, current paths: {P}")
        for p, pid, ptype in zip(P, PIDS, PTYPES):
            if not p or not pid or not ptype:  # Skip empty paths
                continue
            t = ptype[-1]
            if t == "Text":
                paths_end_with_text.append(p)
                paths_end_with_text_id.append(pid)
                paths_end_with_text_type.append(ptype)
                continue
            last_node = p[-1]  # Last node in the path
            last_node_id = pid[-1]  # Last node numeric_id

            last_nodes.append(last_node)
            last_node_ids.append(last_node_id)
            last_node_types.append(t)

            paths_end_with_node.append(p)
            paths_end_with_node_id.append(pid)
            paths_end_with_node_type.append(ptype)

        assert len(last_nodes) == len(last_node_ids) == len(last_node_types), "Mismatch in last nodes, ids, and types lengths"

        if not last_node_ids:
            return paths_end_with_text, paths_end_with_text_id, paths_end_with_text_type

        with self.neo4j_driver.session() as session:
            # Query Node relationships (time-boxed, randomly truncated to 60k rows)
            start_time = time.time()
            outgoing_query = """
            CALL apoc.cypher.runTimeboxed(
                "MATCH (n:Node)-[r:Relation]-(m:Node) WHERE n.numeric_id IN $last_node_ids
                 WITH n, r, m ORDER BY rand() LIMIT 60000
                 RETURN n.numeric_id AS source, n.name AS source_name, r.relation AS rel_type, m.numeric_id AS target, m.name AS target_name, 'Node' AS target_type",
                {last_node_ids: $last_node_ids},
                60000
            )
            YIELD value
            RETURN value.source AS source, value.source_name AS source_name, value.rel_type AS rel_type, value.target AS target, value.target_name AS target_name, value.target_type AS target_type
            """
            outgoing_result = session.run(outgoing_query, last_node_ids=last_node_ids)
            outgoing = [(record["source"], record["source_name"], record["rel_type"], record["target"], record["target_name"], record["target_type"])
                        for record in outgoing_result]
            if self.verbose:
                outgoing_time = time.time() - start_time
                self.logger.info(f"Outgoing relationships query took {outgoing_time:.2f} seconds, count: {len(outgoing)}")

            # # Query outgoing Node -> Text relationships
            # start_time = time.time()
            # outgoing_text_query = """
            #     MATCH (n:Node)-[r:Source]->(t:Text)
            #     WHERE n.numeric_id IN $last_node_ids
            #     RETURN n.numeric_id AS source, n.name AS source_name, 'from Source' AS rel_type, t.numeric_id AS target, t.original_text AS target_name, 'Text' AS target_type
            # """
            # outgoing_text_result = session.run(outgoing_text_query, last_node_ids=last_node_ids)
            # outgoing_text = [(record["source"], record["source_name"], record["rel_type"], record["target"], record["target_name"], record["target_type"])
            #                  for record in outgoing_text_result]
            # if self.verbose:
            #     outgoing_text_time = time.time() - start_time
            #     self.logger.info(f"Outgoing Node->Text relationships query took {outgoing_text_time:.2f} seconds, count: {len(outgoing_text)}")

        last_node_to_new_paths = defaultdict(list)
        last_node_to_new_paths_ids = defaultdict(list)
        last_node_to_new_paths_types = defaultdict(list)
        stop_words = set(stopwords.words('english'))  # hoisted so it is not rebuilt per candidate

        for p, pid, ptype in zip(P, PIDS, PTYPES):
            last_node = p[-1]
            last_node_id = pid[-1]
            # Outgoing Node -> Node
            for source, source_name, rel_type, target, target_name, target_type in outgoing:
                if source == last_node_id and target_name not in p:
                    if target_name.lower() in stop_words:
                        continue
                    new_path = p + [rel_type, target_name]
                    last_node_to_new_paths[last_node].append(new_path)
                    last_node_to_new_paths_ids[last_node].append(pid + [target])
                    last_node_to_new_paths_types[last_node].append(ptype + [target_type])

            # # Outgoing Node -> Text
            # for source, source_name, rel_type, target, target_name, target_type in outgoing_text:
            #     if source == last_node_id and target_name not in p:
            #         new_path = p + [rel_type, target_name]
            #         last_node_to_new_paths_text[last_node].append(new_path)
            #         last_node_to_new_paths_text_ids[last_node].append(pid + [target])
            #         last_node_to_new_paths_text_types[last_node].append(ptype + [target_type])

            # # Incoming Node -> Node
            # for source, rel_type, target, source_name, source_type in incoming:
            #     if target == last_node_id and source not in p:
            #         new_path = p + [rel_type, source_name]

        num_paths = 0
        for last_node, new_paths in last_node_to_new_paths.items():
            num_paths += len(new_paths)
        # for last_node, new_paths in last_node_to_new_paths_text.items():
        #     num_paths += len(new_paths)
        new_paths = []
        new_pids = []
        new_ptypes = []
        if self.verbose:
            self.logger.info(f"Number of new paths before filtering: {num_paths}")
            self.logger.info(f"last nodes: {last_node_to_new_paths.keys()}")
        if num_paths > len(last_node_ids) * width:
            # Apply filtering when the total number of paths exceeds the threshold
            for last_node, new_ps in last_node_to_new_paths.items():
                if len(new_ps) > width:
                    path_embeddings = self.filter_encoder.encode(new_ps)
                    query_embeddings = self.filter_encoder.encode([query])
                    scores = np.dot(path_embeddings, query_embeddings.T).flatten()
                    top_indices = np.argsort(scores)[-width:]
                    new_paths.extend([new_ps[i] for i in top_indices])
                    new_pids.extend([last_node_to_new_paths_ids[last_node][i] for i in top_indices])
                    new_ptypes.extend([last_node_to_new_paths_types[last_node][i] for i in top_indices])
                else:
                    new_paths.extend(new_ps)
                    new_pids.extend(last_node_to_new_paths_ids[last_node])
                    new_ptypes.extend(last_node_to_new_paths_types[last_node])
        else:
            # Collect all paths without filtering when the total is at or below the threshold
            for last_node, new_ps in last_node_to_new_paths.items():
                new_paths.extend(new_ps)
                new_pids.extend(last_node_to_new_paths_ids[last_node])
                new_ptypes.extend(last_node_to_new_paths_types[last_node])

        if self.verbose:
            self.logger.info(f"Expanded paths count: {len(new_paths)}")
            self.logger.info(f"Expanded paths: {new_paths}")
        return new_paths, new_pids, new_ptypes
    def path_to_string(self, path: List[str]) -> str:
        """
        Convert a path to a human-readable string for LLM rating.

        Args:
            path (List[str]): Path as a list of alternating node ids and relation types.

        Returns:
            str: String representation of the path.
        """
        if len(path) < 1:
            return ""
        path_str = []
        with self.neo4j_driver.session() as session:
            for i in range(0, len(path), 2):
                node_id = path[i]
                result = session.run("MATCH (n:Node {numeric_id: $node_id}) RETURN n.name", node_id=node_id)
                record = result.single()  # consume the result once; calling single() twice exhausts it
                node_name = record["n.name"] if record else node_id
                if i + 1 < len(path):
                    rel_type = path[i + 1]
                    path_str.append(f"{node_name} ---> {rel_type} --->")
                else:
                    path_str.append(node_name)
        return " ".join(path_str).strip()
    def prune(self, query: str, P: List[List[str]], PIDS: List[List[str]], PTYPES: List[List[str]], topN: int = 5) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
        """
        Prune paths to keep the top N based on LLM relevance ratings.

        Args:
            query (str): The user query.
            P (List[List[str]]): List of paths to prune.
            PIDS (List[List[str]]): numeric_ids of the nodes along each path.
            PTYPES (List[List[str]]): Types of the nodes along each path.
            topN (int): Number of paths to retain.

        Returns:
            Tuple of the top N paths, their ids, and their types.
        """
        ratings = []
        path_strings = P

        # Process paths in chunks of prune_size
        for i in range(0, len(path_strings), self.prune_size):
            chunk = path_strings[i:i + self.prune_size]

            # Construct the user prompt with the current chunk of paths listed
            user_prompt = f"Please rate the following paths based on how well they help answer the query (1-5, 0 if not relevant).\n\nQuery: {query}\n\nPaths:\n"
            for j, path_str in enumerate(chunk, 1):
                user_prompt += f"{j + i}. {path_str}\n"
            user_prompt += "\nProvide a list of integers, each corresponding to the rating of the path's ability to help answer the query."

            # Define a system prompt that expects a list of integers
            system_prompt = "You are a rating machine that only provides a list of comma-separated integers (0-5) as a response, each rating how well the corresponding path helps answer the query."

            # Send the prompt to the language model
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]

            response = self.llm_generator.generate_response(messages, max_new_tokens=1024, temperature=0.0)
            if self.verbose:
                self.logger.info(f"LLM response for chunk {i // self.prune_size + 1}: {response}")

            # Parse the response into a list of ratings
            rating_str = response.strip()
            chunk_ratings = [int(r) for r in rating_str.split(',') if r.strip().isdigit()]
            if len(chunk_ratings) > len(chunk):
                if self.verbose:
                    self.logger.warning(f"Received more ratings ({len(chunk_ratings)}) than paths in chunk ({len(chunk)}). Trimming ratings.")
                chunk_ratings = chunk_ratings[:len(chunk)]
            ratings.extend(chunk_ratings)  # Concatenate ratings

        # If we received fewer ratings than paths, fall back to the filter encoder to pick the topN
        if len(ratings) < len(path_strings):
            if self.verbose:
                self.logger.warning(f"Number of ratings ({len(ratings)}) does not match number of paths ({len(path_strings)}). Using filter encoder to get topN paths.")
            path_embeddings = self.filter_encoder.encode(path_strings)
            query_embedding = self.filter_encoder.encode([query])
            scores = np.dot(path_embeddings, query_embedding.T).flatten()
            top_indices = np.argsort(scores)[-topN:]
            top_paths = [path_strings[i] for i in top_indices]
            return top_paths, [PIDS[i] for i in top_indices], [PTYPES[i] for i in top_indices]
        elif len(ratings) > len(path_strings):
            if self.verbose:
                self.logger.warning(f"Number of ratings ({len(ratings)}) exceeds number of paths ({len(path_strings)}). Trimming ratings.")
            ratings = ratings[:len(path_strings)]

        # Sort indices based on ratings in descending order
        sorted_indices = sorted(range(len(ratings)), key=lambda i: ratings[i], reverse=True)

        # Filter out indices where the rating is 0
        filtered_indices = [i for i in sorted_indices if ratings[i] > 0]

        # Take the top N indices from the filtered list
        top_indices = filtered_indices[:topN]

        # Use the filtered indices to get the top paths, PIDS, and PTYPES
        if self.verbose:
            self.logger.info(f"Top indices after pruning: {top_indices}")
            self.logger.info(f"length of path_strings: {len(path_strings)}")
        top_paths = [path_strings[i] for i in top_indices]
        top_pids = [PIDS[i] for i in top_indices]
        top_ptypes = [PTYPES[i] for i in top_indices]

        # Log top paths if verbose mode is enabled
        if self.verbose:
            self.logger.info(f"Pruned to top {topN} paths: {top_paths}")

        return top_paths, top_pids, top_ptypes
    def reasoning(self, query: str, P: List[List[str]]) -> bool:
        """
        Check if the current paths are sufficient to answer the query.

        Args:
            query (str): The user query.
            P (List[List[str]]): Current list of paths.

        Returns:
            bool: True if sufficient, False otherwise.
        """
        triples = []
        for path in P:
            if len(path) < 3:
                continue
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")
        triples_str = ". ".join(triples)
        prompt = f"Are these triples, along with your knowledge, sufficient to answer the query?\nQuery: {query}\nTriples: {triples_str}"
        messages = [
            {"role": "system", "content": "Answer Yes or No only."},
            {"role": "user", "content": prompt}
        ]
        response = self.llm_generator.generate_response(messages, max_new_tokens=512)
        if self.verbose:
            self.logger.info(f"Reasoning result: {response}")
        return "yes" in response.lower()
    def retrieve_passages(self, query: str) -> Tuple[List[str], str]:
        """
        Retrieve the top N paths to answer the query.

        Uses the instance settings topN (number of paths to return), Dmax (maximum
        depth of path expansion) and Wmax (maximum width of path expansion).

        Args:
            query (str): The user query.

        Returns:
            Tuple[List[str], str]: List of triples as strings, and 'N/A' as the score placeholder.
        """
        topN = self.topN
        Dmax = self.Dmax
        Wmax = self.Wmax
        if self.verbose:
            self.logger.info(f"Retrieving paths for query: {query}")

        initial_nodes_ids, initial_nodes = self.retrieve_topk_nodes(query, top_k_nodes=topN)
        if not initial_nodes:
            if self.verbose:
                self.logger.info("No initial nodes found.")
            return [], 'N/A'

        P = [[node] for node in initial_nodes]
        PIDS = [[node_id] for node_id in initial_nodes_ids]
        PTYPES = [["Node"] for _ in initial_nodes_ids]  # All initial nodes are of type 'Node'
        for D in range(Dmax + 1):
            if self.verbose:
                self.logger.info(f"Depth {D}, Current paths: {len(P)}")
            P, PIDS, PTYPES = self.expand_paths(P, PIDS, PTYPES, Wmax, query)
            if not P:
                if self.verbose:
                    self.logger.info("No paths to expand.")
                break
            P, PIDS, PTYPES = self.prune(query, P, PIDS, PTYPES, topN)
            if D == Dmax:
                if self.verbose:
                    self.logger.info(f"Reached maximum depth {Dmax}, stopping expansion.")
                break
            if self.reasoning(query, P):
                if self.verbose:
                    self.logger.info("Paths sufficient, stopping expansion.")
                break

        # Extract final triples
        triples = []
        for path in P:
            for i in range(0, len(path) - 2, 2):
                node1_name = path[i]
                rel = path[i + 1]
                node2_name = path[i + 2]
                triples.append(f"({node1_name}, {rel}, {node2_name})")

        if self.verbose:
            self.logger.info(f"Final triples: {triples}")
        return triples, 'N/A'  # 'N/A' for passages_score, as this retriever does not return scored passages
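Usage note (not part of the commit): LargeKGToGRetriever wires up the same way as LargeKGRetriever, plus a filter encoder and the beam-search knobs Dmax / Wmax; a sketch under the same placeholder assumptions as the previous one:

from atlas_rag.retriever.lkg_retriever.tog import LargeKGToGRetriever

tog_retriever = LargeKGToGRetriever(
    keyword="cc_en",
    neo4j_driver=driver,           # the Neo4j driver from the previous sketch
    llm_generator=my_llm,          # placeholder LLMGenerator instance
    sentence_encoder=my_encoder,   # placeholder BaseEmbeddingModel instance
    filter_encoder=my_encoder,     # may be a separate, cheaper encoder for path scoring
    node_index=node_index,         # the FAISS node index from the previous sketch
    topN=5, Dmax=3, Wmax=3, prune_size=10,
)
triples, score = tog_retriever.retrieve_passages("Where was Marie Curie born?")
# 'score' is the string 'N/A': this retriever returns triples, not scored passages.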