Files
AIEC-RAG/atlas_rag/vectorstore/embedding_model.py
2025-09-24 09:29:12 +08:00

198 lines
9.5 KiB
Python

import csv
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
class BaseEmbeddingModel(ABC):
    """Common interface for embedding back-ends.

    Subclasses implement :meth:`encode`; :meth:`compute_kg_embedding` uses it to
    add an ``embedding:STRING`` column to the node, edge and text-node CSV files
    of a knowledge graph.
    """

    def __init__(self, sentence_encoder):
        # The wrapped encoder object; its concrete type depends on the subclass.
        self.sentence_encoder = sentence_encoder

    @abstractmethod
    def encode(self, query, **kwargs):
        """Abstract method to encode queries."""
        pass

    def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
        """Write copies of the node, edge and text-node CSVs with an embedding column.

        Args:
            node_csv_without_emb: Input node CSV with columns
                ``name:ID, type, concepts, synsets, :LABEL``.
            node_csv_file: Output node CSV. It gains an empty ``file_id`` column
                and an ``embedding:STRING`` column holding the embedding of the
                node name.
            edge_csv_without_emb: Input edge CSV with columns
                ``:START_ID, :END_ID, relation, concepts, synsets, :TYPE``.
            edge_csv_file: Output edge CSV. It gains an empty ``file_id`` column
                and an ``embedding:STRING`` column holding the embedding of
                ``"<start> <relation> <end>"``.
            text_node_csv_without_emb: Input text-node CSV with columns
                ``text_id:ID, original_text, :LABEL``.
            text_node_csv: Output text-node CSV. It gains a trailing
                ``embedding:STRING`` column computed from ``original_text``.
            **kwargs: ``batch_size`` (default 2048) is the number of rows sent
                to :meth:`encode` per call. It applies to all three files.
        """
        batch_size = kwargs.get("batch_size", 2048)

        self._embed_csv(
            node_csv_without_emb,
            node_csv_file,
            header=["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"],
            skip_marker="name:ID",
            text_of=lambda row: row[0],
            to_output=lambda row, emb: [row[0], row[1], "", row[2], row[3], emb, row[4]],
            batch_size=batch_size,
        )
        self._embed_csv(
            edge_csv_without_emb,
            edge_csv_file,
            header=[":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"],
            skip_marker=":START_ID",
            text_of=lambda row: " ".join([row[0], row[2], row[1]]),
            to_output=lambda row, emb: [row[0], row[1], row[2], "", row[3], row[4], emb, row[5]],
            batch_size=batch_size,
        )
        self._embed_csv(
            text_node_csv_without_emb,
            text_node_csv,
            header=["text_id:ID", "original_text", ":LABEL", "embedding:STRING"],
            skip_marker="text_id:ID",
            text_of=lambda row: row[1],
            to_output=lambda row, emb: [row[0], row[1], row[2], emb],
            batch_size=batch_size,
        )

    def _embed_csv(self, src_path, dst_path, header, skip_marker, text_of, to_output, batch_size):
        """Stream ``src_path``, embed ``text_of(row)`` in batches and write ``to_output`` rows.

        Rows whose first field equals ``skip_marker`` (the input header) and
        blank lines are skipped. ``header`` is always written first.
        """
        # newline='' on both files, as the csv module requires, so that quoted
        # fields containing line breaks (e.g. passage text) round-trip unchanged.
        with open(src_path, "r", newline="") as src, open(dst_path, "w", newline="") as dst:
            reader = csv.reader(src)
            writer = csv.writer(dst)
            writer.writerow(header)
            batch = []
            for row in reader:
                if not row or row[0] == skip_marker:
                    continue
                batch.append(row)
                if len(batch) == batch_size:
                    self._write_embedded_batch(batch, writer, text_of, to_output, batch_size)
                    batch = []
            if batch:
                self._write_embedded_batch(batch, writer, text_of, to_output, batch_size)

    def _write_embedded_batch(self, rows, writer, text_of, to_output, batch_size):
        """Encode one batch of rows and write each row with its own embedding."""
        texts = [text_of(row) for row in rows]
        embeddings = self.encode(texts, batch_size=batch_size, show_progress_bar=False)
        # encode() returns embeddings in input order, so pair them positionally.
        for row, embedding in zip(rows, embeddings):
            writer.writerow(to_output(row, embedding.tolist()))
class NvEmbed(BaseEmbeddingModel):
    """Embedding wrapper for NV-Embed style models that take instruction prompts.

    ``sentence_encoder`` is either a ``SentenceTransformer`` or a Hugging Face
    model (e.g. from ``AutoModel.from_pretrained``) whose ``encode`` accepts
    ``instruction`` and ``max_length`` keyword arguments.
    """

    # Instruction text prepended to queries, keyed by ``query_type``.
    _PROMPT_PREFIXES = {
        'passage': 'Given a question, retrieve relevant documents that best answer the question.',
        'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
        'edge': 'Given a question, retrieve relevant triplet facts that matches this question.',
        'fill_in_edge': 'Given a triples with only head and relation, retrieve relevant triplet facts that best fill the atomic query.'
    }

    def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
        super().__init__(sentence_encoder)

    def add_eos(self, input_examples):
        """Append the tokenizer's EOS token to every example (unchanged if it has none)."""
        if self.sentence_encoder.tokenizer.eos_token is not None:
            return [input_example + self.sentence_encoder.tokenizer.eos_token for input_example in input_examples]
        else:
            return input_examples

    def encode(self, query, query_type=None, **kwargs):
        """
        Encode the query into embeddings.

        Args:
            query: Input text or list of texts.
            query_type: 'passage', 'entity', 'edge' or 'fill_in_edge' selects an
                instruction prefix. Any other value (e.g. 'search' or None)
                encodes the text without an instruction.
            **kwargs: ``normalize_embeddings`` (default True) L2-normalizes each
                embedding. All other keywords are forwarded to the wrapped
                encoder's ``encode``. For non-SentenceTransformer encoders,
                ``max_length`` defaults to 32768.

        Returns:
            Embeddings as a NumPy array, one per input text. When ``query`` is a
            single string, only that text's embedding is returned.
        """
        # Consumed here. Forwarding it as well would only normalize twice.
        normalize_embeddings = kwargs.pop('normalize_embeddings', True)

        # A bare string must be wrapped. Both add_eos and the model's encode
        # iterate over their input and would otherwise embed each character.
        single_text = isinstance(query, str)
        texts = [query] if single_text else list(query)

        prompt_prefix = self._PROMPT_PREFIXES.get(query_type)
        query_prefix = f"Instruct: {prompt_prefix}\nQuery: " if prompt_prefix else None

        if isinstance(self.sentence_encoder, SentenceTransformer):
            if query_prefix:
                embeddings = self.sentence_encoder.encode(self.add_eos(texts), prompt=query_prefix, **kwargs)
            else:
                embeddings = self.sentence_encoder.encode(self.add_eos(texts), **kwargs)
        else:
            # AutoModel.from_pretrained returns a concrete model class, never an
            # AutoModel instance, so isinstance(..., AutoModel) cannot match.
            # Every non-SentenceTransformer encoder takes the HF path.
            kwargs.setdefault('max_length', 32768)
            if query_prefix:
                embeddings = self.sentence_encoder.encode(texts, instruction=query_prefix, **kwargs)
            else:
                embeddings = self.sentence_encoder.encode(texts, **kwargs)

        # SentenceTransformer returns NumPy by default, but F.normalize needs a tensor.
        if not isinstance(embeddings, torch.Tensor):
            embeddings = torch.as_tensor(embeddings)
        if normalize_embeddings:
            embeddings = F.normalize(embeddings, p=2, dim=1)
        # Move to CPU and convert to NumPy
        embeddings = embeddings.detach().cpu().numpy()
        return embeddings[0] if single_text else embeddings
class SentenceEmbedding(BaseEmbeddingModel):
    """Thin adapter exposing a ``SentenceTransformer`` through the BaseEmbeddingModel API."""

    def __init__(self, sentence_encoder: SentenceTransformer):
        super().__init__(sentence_encoder)

    def encode(self, query, **kwargs):
        """Delegate to the wrapped encoder's ``encode``, forwarding every keyword unchanged."""
        encoder = self.sentence_encoder
        return encoder.encode(query, **kwargs)