198 lines
9.5 KiB
Python
198 lines
9.5 KiB
Python
|
|
from sentence_transformers import SentenceTransformer
|
||
|
|
from transformers import AutoModel
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
import csv
|
||
|
|
|
||
|
|
class BaseEmbeddingModel(ABC):
|
||
|
|
def __init__(self, sentence_encoder):
|
||
|
|
self.sentence_encoder = sentence_encoder
|
||
|
|
|
||
|
|
@abstractmethod
|
||
|
|
def encode(self, query, **kwargs):
|
||
|
|
"""Abstract method to encode queries."""
|
||
|
|
pass
|
||
|
|
|
||
|
|
def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
|
||
|
|
with open(node_csv_without_emb, "r") as csvfile_node:
|
||
|
|
with open(node_csv_file, "w", newline='') as csvfile_node_emb:
|
||
|
|
reader_node = csv.reader(csvfile_node)
|
||
|
|
|
||
|
|
# the reader has [name:ID,type,concepts,synsets,:LABEL]
|
||
|
|
writer_node = csv.writer(csvfile_node_emb)
|
||
|
|
writer_node.writerow(["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"])
|
||
|
|
|
||
|
|
# the encoding will be processed in batch of 2048
|
||
|
|
batch_size = kwargs.get('batch_size', 2048)
|
||
|
|
batch_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
for row in reader_node:
|
||
|
|
if row[0] == "name:ID":
|
||
|
|
continue
|
||
|
|
batch_nodes.append(row[0])
|
||
|
|
batch_rows.append(row)
|
||
|
|
|
||
|
|
if len(batch_nodes) == batch_size:
|
||
|
|
node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
|
||
|
|
new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
|
||
|
|
writer_node.writerow(new_row)
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
batch_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
if len(batch_nodes) > 0:
|
||
|
|
node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
|
||
|
|
writer_node.writerow(new_row)
|
||
|
|
batch_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
|
||
|
|
with open(edge_csv_without_emb, "r") as csvfile_edge:
|
||
|
|
with open(edge_csv_file, "w", newline='') as csvfile_edge_emb:
|
||
|
|
reader_edge = csv.reader(csvfile_edge)
|
||
|
|
# [":START_ID",":END_ID","relation","concepts","synsets",":TYPE"]
|
||
|
|
writer_edge = csv.writer(csvfile_edge_emb)
|
||
|
|
writer_edge.writerow([":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"])
|
||
|
|
|
||
|
|
# the encoding will be processed in batch of 4096
|
||
|
|
batch_size = 2048
|
||
|
|
batch_edges = []
|
||
|
|
batch_rows = []
|
||
|
|
for row in reader_edge:
|
||
|
|
if row[0] == ":START_ID":
|
||
|
|
continue
|
||
|
|
batch_edges.append(" ".join([row[0], row[2], row[1]]))
|
||
|
|
batch_rows.append(row)
|
||
|
|
|
||
|
|
if len(batch_edges) == batch_size:
|
||
|
|
edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
|
||
|
|
writer_edge.writerow(new_row)
|
||
|
|
batch_edges = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
if len(batch_edges) > 0:
|
||
|
|
edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
|
||
|
|
writer_edge.writerow(new_row)
|
||
|
|
batch_edges = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
|
||
|
|
with open(text_node_csv_without_emb, "r") as csvfile_text_node:
|
||
|
|
with open(text_node_csv, "w", newline='') as csvfile_text_node_emb:
|
||
|
|
reader_text_node = csv.reader(csvfile_text_node)
|
||
|
|
# [text_id:ID,original_text,:LABEL]
|
||
|
|
writer_text_node = csv.writer(csvfile_text_node_emb)
|
||
|
|
|
||
|
|
writer_text_node.writerow(["text_id:ID", "original_text", ":LABEL", "embedding:STRING"])
|
||
|
|
|
||
|
|
# the encoding will be processed in batch of 2048
|
||
|
|
batch_size = 2048
|
||
|
|
batch_text_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
for row in reader_text_node:
|
||
|
|
if row[0] == "text_id:ID":
|
||
|
|
continue
|
||
|
|
|
||
|
|
batch_text_nodes.append(row[1])
|
||
|
|
batch_rows.append(row)
|
||
|
|
if len(batch_text_nodes) == batch_size:
|
||
|
|
text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
embedding = text_node_embedding_dict[row[1]].tolist()
|
||
|
|
new_row = [row[0], row[1], row[2], embedding]
|
||
|
|
writer_text_node.writerow(new_row)
|
||
|
|
|
||
|
|
batch_text_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
if len(batch_text_nodes) > 0:
|
||
|
|
text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
|
||
|
|
text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
|
||
|
|
for row in batch_rows:
|
||
|
|
embedding = text_node_embedding_dict[row[1]].tolist()
|
||
|
|
new_row = [row[0], row[1], row[2], embedding]
|
||
|
|
|
||
|
|
writer_text_node.writerow(new_row)
|
||
|
|
batch_text_nodes = []
|
||
|
|
batch_rows = []
|
||
|
|
|
||
|
|
class NvEmbed(BaseEmbeddingModel):
    """Embedding wrapper for NV-Embed-style encoders with instruction prefixes."""

    def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
        # Delegate to the base class instead of duplicating the assignment.
        super().__init__(sentence_encoder)

    def add_eos(self, input_examples):
        """Add EOS token to input examples (returned unchanged if the tokenizer has none)."""
        eos = self.sentence_encoder.tokenizer.eos_token
        if eos is not None:
            return [input_example + eos for input_example in input_examples]
        return input_examples

    def encode(self, query, query_type=None, **kwargs):
        """
        Encode the query into embeddings.

        Args:
            query: Input text or list of texts.
            query_type: Type of query (e.g., 'passage', 'entity', 'edge', 'fill_in_edge', 'search').
            **kwargs: Additional arguments (e.g., normalize_embeddings) forwarded
                to the underlying encoder.

        Returns:
            Embeddings as a NumPy array.

        Raises:
            TypeError: If the wrapped encoder is neither a SentenceTransformer
                nor an AutoModel.
        """
        # pop (not get): otherwise the flag stays in **kwargs and is also
        # forwarded to the underlying encoder, normalizing twice (or raising
        # on encoders that do not accept the keyword).
        normalize_embeddings = kwargs.pop('normalize_embeddings', True)

        # Instruction prefixes per query type; unknown/None types get no prefix.
        prompt_prefixes = {
            'passage': 'Given a question, retrieve relevant documents that best answer the question.',
            'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
            'edge': 'Given a question, retrieve relevant triplet facts that matches this question.',
            'fill_in_edge': 'Given a triples with only head and relation, retrieve relevant triplet facts that best fill the atomic query.'
        }

        if query_type in prompt_prefixes:
            prompt_prefix = prompt_prefixes[query_type]
            query_prefix = f"Instruct: {prompt_prefix}\nQuery: "
        else:
            query_prefix = None

        # Encode the query.
        if isinstance(self.sentence_encoder, SentenceTransformer):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), prompt=query_prefix, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), **kwargs)
        elif isinstance(self.sentence_encoder, AutoModel):
            # NOTE(review): AutoModel.from_pretrained usually returns a concrete
            # model subclass, not AutoModel itself — confirm this isinstance
            # check matches the models actually passed in.
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(query, instruction=query_prefix, max_length=32768, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(query, max_length=32768, **kwargs)
        else:
            # Previously this fell through to a confusing NameError below.
            raise TypeError(f"Unsupported sentence encoder type: {type(self.sentence_encoder)!r}")

        # Normalize embeddings if required.
        # NOTE(review): F.normalize assumes the encoder returned a torch tensor
        # (e.g. convert_to_tensor=True for SentenceTransformer) — confirm.
        if normalize_embeddings:
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1).detach().cpu().numpy()

        return query_embeddings
|
||
|
|
|
||
|
|
class SentenceEmbedding(BaseEmbeddingModel):
    """Thin wrapper that delegates encoding directly to a SentenceTransformer."""

    def __init__(self, sentence_encoder: SentenceTransformer):
        # Reuse the base-class initializer instead of duplicating the assignment.
        super().__init__(sentence_encoder)

    def encode(self, query, **kwargs):
        """Encode *query* (a string or list of strings) with the wrapped encoder."""
        return self.sentence_encoder.encode(query, **kwargs)
|