first commit
AIEC-RAG/atlas_rag/retriever/simple_retriever.py  (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import Dict
import numpy as np
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever


class SimpleGraphRetriever(BaseEdgeRetriever):

    def __init__(self, llm_generator: LLMGenerator, sentence_encoder: BaseEmbeddingModel,
                 data: dict):

        self.KG = data["KG"]
        self.node_list = data["node_list"]
        self.edge_list = data["edge_list"]

        self.llm_generator = llm_generator
        self.sentence_encoder = sentence_encoder

        self.node_faiss_index = data["node_faiss_index"]
        self.edge_faiss_index = data["edge_faiss_index"]

    def retrieve(self, query, topN=5, **kwargs):
        # retrieve the top k edges from the FAISS edge index
        topk_edges = []
        query_embedding = self.sentence_encoder.encode([query], query_type='edge')
        D, I = self.edge_faiss_index.search(query_embedding, topN)

        topk_edges += [self.edge_list[i] for i in I[0]]

        # map the FAISS hits back to (head, relation, tail) triples and verbalise them via node ids
        topk_edges_with_data = [(edge[0], self.KG.edges[edge]["relation"], edge[1]) for edge in topk_edges]
        string_edge_edges = [f"{self.KG.nodes[edge[0]]['id']} {edge[1]} {self.KG.nodes[edge[2]]['id']}" for edge in topk_edges_with_data]

        return string_edge_edges, ["N/A" for _ in range(len(string_edge_edges))]


class SimpleTextRetriever(BasePassageRetriever):
    def __init__(self, passage_dict: Dict[str, str], sentence_encoder: BaseEmbeddingModel, data: dict):
        self.sentence_encoder = sentence_encoder
        self.passage_dict = passage_dict
        self.passage_list = list(passage_dict.values())
        self.passage_keys = list(passage_dict.keys())
        self.text_embeddings = data["text_embeddings"]

    def retrieve(self, query, topN=5, **kwargs):
        query_emb = self.sentence_encoder.encode([query], query_type="passage")
        # dense dot-product scoring against the pre-computed passage embeddings
        sim_scores = self.text_embeddings @ query_emb[0].T
        topk_indices = np.argsort(sim_scores)[-topN:][::-1]  # Get indices of top-k scores

        # Retrieve top-k passages
        topk_passages = [self.passage_list[i] for i in topk_indices]
        topk_passages_ids = [self.passage_keys[i] for i in topk_indices]
        return topk_passages, topk_passages_ids
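For context, a minimal usage sketch of SimpleTextRetriever follows. It is not part of this commit: it assumes an already-constructed BaseEmbeddingModel implementation bound to a hypothetical `encoder` variable, placeholder passage ids and texts, and that the same encode() call can pre-embed the corpus (the exact signature for corpus text depends on the embedding model implementation in use).

# Hypothetical usage sketch -- everything except the SimpleTextRetriever API
# shown above is an assumption made for illustration.
import numpy as np
from atlas_rag.retriever.simple_retriever import SimpleTextRetriever

passage_dict = {
    "doc-1": "Paris is the capital of France.",        # placeholder corpus
    "doc-2": "The Eiffel Tower is located in Paris.",
}

# Pre-embed the corpus with the same encoder that will embed queries.
text_embeddings = np.asarray(encoder.encode(list(passage_dict.values())))

retriever = SimpleTextRetriever(
    passage_dict=passage_dict,
    sentence_encoder=encoder,
    data={"text_embeddings": text_embeddings},
)

passages, passage_ids = retriever.retrieve("Where is the Eiffel Tower?", topN=1)
print(passage_ids[0], passages[0])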