Files
AIEC-RAG/atlas_rag/retriever/simple_retriever.py

52 lines
2.3 KiB
Python
Raw Normal View History

2025-09-24 09:29:12 +08:00
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
class SimpleGraphRetriever(BaseEdgeRetriever):
    """Retrieve knowledge-graph edges whose embeddings best match a query.

    The query is embedded with ``sentence_encoder`` and looked up in the
    pre-built edge index; the hits are rendered as plain-text triples.
    """

    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict):
        """Store the graph, lookup tables and indexes needed for retrieval.

        Args:
            llm_generator: Kept on the instance; not used by ``retrieve``.
            sentence_encoder: Model used to embed the query string.
            data: Mapping that must provide the keys ``"KG"``,
                ``"node_list"``, ``"edge_list"``, ``"node_faiss_index"``
                and ``"edge_faiss_index"``.

        Raises:
            KeyError: If any of the required keys is missing from ``data``.
        """
        self.KG = data["KG"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]
        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder
        self.node_faiss_index = data["node_faiss_index"]
        self.edge_faiss_index = data["edge_faiss_index"]

    def retrieve(self, query, topN=5, **kwargs):
        """Return up to ``topN`` edges most similar to ``query``.

        Args:
            query: Query text to embed.
            topN: Maximum number of edges to return. Non-positive values
                yield empty results.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A pair ``(edge_strings, ids)``. ``edge_strings`` holds one
            ``"<head id> <relation> <tail id>"`` string per retrieved edge;
            ``ids`` is a list of ``"N/A"`` placeholders of the same length.
        """
        if topN <= 0:
            return [], []

        query_embedding = self.sentence_encoder.encode([query], query_type='edge')
        _, indices = self.edge_faiss_index.search(query_embedding, topN)

        # A FAISS index pads the label array with -1 when it holds fewer than
        # topN vectors. Without this filter, edge_list[-1] would silently
        # return the last edge as a bogus hit.
        topk_edges = [self.edge_list[i] for i in indices[0] if i >= 0]

        # NOTE(review): assumes each edge_list entry is a (head, tail) key
        # accepted by KG.edges[...] — confirm against the index builder.
        string_edge_edges = [
            f"{self.KG.nodes[edge[0]]['id']} "
            f"{self.KG.edges[edge]['relation']} "
            f"{self.KG.nodes[edge[1]]['id']}"
            for edge in topk_edges
        ]
        return string_edge_edges, ["N/A"] * len(string_edge_edges)
class SimpleTextRetriever(BasePassageRetriever):
    """Retrieve passages ranked by dot-product similarity to a query."""

    def __init__(self, passage_dict: Dict[str, str], sentence_encoder: BaseEmbeddingModel, data: dict):
        """Store the passages, their keys and the pre-computed embeddings.

        Args:
            passage_dict: Mapping from passage id to passage text.
            sentence_encoder: Model used to embed the query string.
            data: Mapping that must provide the key ``"text_embeddings"``.
                NOTE(review): rows are assumed to follow the iteration
                order of ``passage_dict`` — confirm against the code that
                builds the embeddings.

        Raises:
            KeyError: If ``"text_embeddings"`` is missing from ``data``.
        """
        self.sentence_encoder = sentence_encoder
        self.passage_dict = passage_dict
        # Both lists come from the same dict, so position i in one matches
        # position i in the other.
        self.passage_list = list(passage_dict.values())
        self.passage_keys = list(passage_dict.keys())
        self.text_embeddings = data["text_embeddings"]

    def retrieve(self, query, topN=5, **kwargs):
        """Return up to ``topN`` passages with the highest similarity score.

        Args:
            query: Query text to embed.
            topN: Maximum number of passages to return. Non-positive values
                yield empty results.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A pair ``(passages, passage_ids)`` ordered from most to least
            similar.
        """
        # Guard first: with topN == 0 the slice below would be [-0:], i.e.
        # the whole array, and every passage would be returned.
        if topN <= 0:
            return [], []

        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        sim_scores = self.text_embeddings @ query_emb[0].T

        # argsort is ascending: take the last topN entries, then reverse so
        # the best match comes first.
        topk_indices = np.argsort(sim_scores)[-topN:][::-1]

        topk_passages = [self.passage_list[i] for i in topk_indices]
        topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
        return topk_passages, topk_passages_ids