first commit
This commit is contained in:

atlas_rag/vectorstore/embedding_model.py (new file, 197 lines)

@@ -0,0 +1,197 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
import csv

class BaseEmbeddingModel(ABC):
    def __init__(self, sentence_encoder):
        self.sentence_encoder = sentence_encoder

    @abstractmethod
    def encode(self, query, **kwargs):
        """Abstract method to encode queries."""
        pass
    def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
        """Embed KG nodes, edges, and text nodes from CSVs, writing copies with an added embedding column."""
        with open(node_csv_without_emb, "r") as csvfile_node:
            with open(node_csv_file, "w", newline='') as csvfile_node_emb:
                reader_node = csv.reader(csvfile_node)
                # input columns: [name:ID, type, concepts, synsets, :LABEL]
                writer_node = csv.writer(csvfile_node_emb)
                writer_node.writerow(["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"])

                # encode in batches (default batch size: 2048)
                batch_size = kwargs.get('batch_size', 2048)
                batch_nodes = []
                batch_rows = []
                for row in reader_node:
                    if row[0] == "name:ID":  # skip the header row
                        continue
                    batch_nodes.append(row[0])
                    batch_rows.append(row)

                    if len(batch_nodes) == batch_size:
                        node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                        node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                            writer_node.writerow(new_row)
                        batch_nodes = []
                        batch_rows = []

                # flush the final partial batch
                if len(batch_nodes) > 0:
                    node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                    node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                        writer_node.writerow(new_row)
                    batch_nodes = []
                    batch_rows = []
        with open(edge_csv_without_emb, "r") as csvfile_edge:
            with open(edge_csv_file, "w", newline='') as csvfile_edge_emb:
                reader_edge = csv.reader(csvfile_edge)
                # input columns: [:START_ID, :END_ID, relation, concepts, synsets, :TYPE]
                writer_edge = csv.writer(csvfile_edge_emb)
                writer_edge.writerow([":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"])

                # reuse the same batch size as the node pass; each edge is
                # embedded as the concatenated string "head relation tail"
                batch_edges = []
                batch_rows = []
                for row in reader_edge:
                    if row[0] == ":START_ID":  # skip the header row
                        continue
                    batch_edges.append(" ".join([row[0], row[2], row[1]]))
                    batch_rows.append(row)

                    if len(batch_edges) == batch_size:
                        edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                        edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                            writer_edge.writerow(new_row)
                        batch_edges = []
                        batch_rows = []

                # flush the final partial batch
                if len(batch_edges) > 0:
                    edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                    edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                        writer_edge.writerow(new_row)
                    batch_edges = []
                    batch_rows = []
        with open(text_node_csv_without_emb, "r") as csvfile_text_node:
            with open(text_node_csv, "w", newline='') as csvfile_text_node_emb:
                reader_text_node = csv.reader(csvfile_text_node)
                # input columns: [text_id:ID, original_text, :LABEL]
                writer_text_node = csv.writer(csvfile_text_node_emb)
                writer_text_node.writerow(["text_id:ID", "original_text", ":LABEL", "embedding:STRING"])

                # text nodes are embedded from their original_text column,
                # again using the same batch size
                batch_text_nodes = []
                batch_rows = []
                for row in reader_text_node:
                    if row[0] == "text_id:ID":  # skip the header row
                        continue
                    batch_text_nodes.append(row[1])
                    batch_rows.append(row)

                    if len(batch_text_nodes) == batch_size:
                        text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                        text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                        for row in batch_rows:
                            embedding = text_node_embedding_dict[row[1]].tolist()
                            new_row = [row[0], row[1], row[2], embedding]
                            writer_text_node.writerow(new_row)
                        batch_text_nodes = []
                        batch_rows = []

                # flush the final partial batch
                if len(batch_text_nodes) > 0:
                    text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                    text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                    for row in batch_rows:
                        embedding = text_node_embedding_dict[row[1]].tolist()
                        new_row = [row[0], row[1], row[2], embedding]
                        writer_text_node.writerow(new_row)
                    batch_text_nodes = []
                    batch_rows = []
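
# For reference, the schemas handled above (taken from the reader/writer
# comments), showing how compute_kg_embedding adds a file_id and an
# embedding column to each row:
#
#   node input : name:ID, type, concepts, synsets, :LABEL
#   node output: name:ID, type, file_id, concepts, synsets, embedding:STRING, :LABEL
#   edge input : :START_ID, :END_ID, relation, concepts, synsets, :TYPE
#   edge output: :START_ID, :END_ID, relation, file_id, concepts, synsets, embedding:STRING, :TYPE
#   text input : text_id:ID, original_text, :LABEL
#   text output: text_id:ID, original_text, :LABEL, embedding:STRING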


class NvEmbed(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
        super().__init__(sentence_encoder)

    def add_eos(self, input_examples):
        """Append the tokenizer's EOS token to each input example."""
        if self.sentence_encoder.tokenizer.eos_token is not None:
            return [input_example + self.sentence_encoder.tokenizer.eos_token for input_example in input_examples]
        else:
            return input_examples
    def encode(self, query, query_type=None, **kwargs):
        """
        Encode the query into embeddings.

        Args:
            query: Input text or list of texts.
            query_type: One of 'passage', 'entity', 'edge', 'fill_in_edge';
                any other value (e.g., 'search') adds no instruction prefix.
            **kwargs: Additional arguments (e.g., normalize_embeddings).

        Returns:
            Embeddings as a NumPy array.
        """
        normalize_embeddings = kwargs.get('normalize_embeddings', True)

        # Accept a single string as well as a list of texts
        if isinstance(query, str):
            query = [query]

        # Define the instruction prefix for each query type
        prompt_prefixes = {
            'passage': 'Given a question, retrieve relevant documents that best answer the question.',
            'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
            'edge': 'Given a question, retrieve relevant triplet facts that match this question.',
            'fill_in_edge': 'Given a triple with only a head and relation, retrieve relevant triplet facts that best fill the atomic query.'
        }

        if query_type in prompt_prefixes:
            prompt_prefix = prompt_prefixes[query_type]
            query_prefix = f"Instruct: {prompt_prefix}\nQuery: "
        else:
            query_prefix = None

        # Encode the query
        if isinstance(self.sentence_encoder, SentenceTransformer):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), prompt=query_prefix, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), **kwargs)
        else:
            # AutoModel.from_pretrained returns a model-specific class, so an
            # isinstance check against AutoModel itself would never match;
            # treat every non-SentenceTransformer encoder as an NV-Embed-style
            # model exposing encode(..., instruction=..., max_length=...).
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(query, instruction=query_prefix, max_length=32768, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(query, max_length=32768, **kwargs)

        # Normalize, move to CPU, and convert to NumPy if required
        # (torch.as_tensor handles encoders that return NumPy arrays)
        if normalize_embeddings:
            query_embeddings = F.normalize(torch.as_tensor(query_embeddings), p=2, dim=1).detach().cpu().numpy()

        return query_embeddings
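
# Minimal usage sketch for NvEmbed. The checkpoint name and trust_remote_code
# flag are assumptions about how an NV-Embed encoder is typically loaded; they
# are not defined by this module:
#
#   encoder = SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True)
#   nv = NvEmbed(encoder)
#   emb = nv.encode(["Who wrote Hamlet?"], query_type="entity")  # (1, dim) array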


class SentenceEmbedding(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer):
        super().__init__(sentence_encoder)

    def encode(self, query, **kwargs):
        return self.sentence_encoder.encode(query, **kwargs)
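
# Hypothetical end-to-end sketch: the CSV paths and checkpoint name below are
# placeholders for illustration, not files or models shipped with this commit.
if __name__ == "__main__":
    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    model = SentenceEmbedding(encoder)
    model.compute_kg_embedding(
        node_csv_without_emb="nodes.csv",
        node_csv_file="nodes_with_emb.csv",
        edge_csv_without_emb="edges.csv",
        edge_csv_file="edges_with_emb.csv",
        text_node_csv_without_emb="text_nodes.csv",
        text_node_csv="text_nodes_with_emb.csv",
        batch_size=2048,
    )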