first commit

This commit is contained in:
闫旭隆
2025-09-25 10:33:37 +08:00
commit 34839c2654
387 changed files with 149159 additions and 0 deletions

View File

@@ -0,0 +1 @@
from .create_graph_index import create_embeddings_and_index

View File

@@ -0,0 +1,177 @@
import os
import pickle
import networkx as nx
from tqdm import tqdm
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import faiss
import numpy as np
import torch
def compute_graph_embeddings(node_list, edge_list_string, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False):
# Encode in batches
node_embeddings = []
for i in tqdm(range(0, len(node_list), batch_size), desc="Encoding nodes"):
batch = node_list[i:i + batch_size]
node_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings = normalize_embeddings))
edge_embeddings = []
for i in tqdm(range(0, len(edge_list_string), batch_size), desc="Encoding edges"):
batch = edge_list_string[i:i + batch_size]
edge_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings = normalize_embeddings))
return node_embeddings, edge_embeddings
def build_faiss_index(embeddings):
dimension = len(embeddings[0])
faiss_index = faiss.IndexHNSWFlat(dimension, 64, faiss.METRIC_INNER_PRODUCT)
X = np.array(embeddings).astype('float32')
# normalize the vectors
faiss.normalize_L2(X)
# batched add
for i in tqdm(range(0, X.shape[0], 32)):
faiss_index.add(X[i:i+32])
return faiss_index
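def _example_search_faiss_index(faiss_index, query_embedding, top_k=5):
    # Hypothetical helper, not part of the original commit: a minimal sketch of how an
    # index built by build_faiss_index might be queried. Because the stored vectors are
    # L2-normalized and the metric is inner product, the query must be normalized the same way.
    query = np.asarray(query_embedding, dtype='float32').reshape(1, -1)
    faiss.normalize_L2(query)
    scores, ids = faiss_index.search(query, top_k)  # returns (similarities, vector ids)
    return scores[0], ids[0]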
def compute_text_embeddings(text_list, sentence_encoder: BaseEmbeddingModel, batch_size = 40, normalize_embeddings: bool = False):
"""Separated text embedding computation"""
text_embeddings = []
for i in tqdm(range(0, len(text_list), batch_size), desc="Encoding texts"):
batch = text_list[i:i + batch_size]
        embeddings = sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings)
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        text_embeddings.extend(embeddings)
return text_embeddings
def create_embeddings_and_index(sentence_encoder, model_name: str, working_directory: str, keyword: str, include_events: bool, include_concept: bool,
normalize_embeddings: bool = True,
text_batch_size = 40,
node_and_edge_batch_size = 256):
    # Use the last path component of model_name as a short encoder name for file naming
encoder_model_name = model_name.split('/')[-1]
print(f"Using encoder model: {encoder_model_name}")
graph_dir = f"{working_directory}/kg_graphml/{keyword}_graph.graphml"
if not os.path.exists(graph_dir):
raise FileNotFoundError(f"Graph file {graph_dir} does not exist. Please check the path or generate the graph first.")
node_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_faiss.index"
node_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_node_list.pkl"
edge_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_faiss.index"
edge_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_edge_list.pkl"
node_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_embeddings.pkl"
edge_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_embeddings.pkl"
text_embeddings_path = f"{working_directory}/precompute/{keyword}_{encoder_model_name}_text_embeddings.pkl"
text_index_path = f"{working_directory}/precompute/{keyword}_text_faiss.index"
original_text_list_path = f"{working_directory}/precompute/{keyword}_text_list.pkl"
original_text_dict_with_node_id_path = f"{working_directory}/precompute/{keyword}_original_text_dict_with_node_id.pkl"
if not os.path.exists(f"{working_directory}/precompute"):
os.makedirs(f"{working_directory}/precompute", exist_ok=True)
print(f"Loading graph from {graph_dir}")
with open(graph_dir, "rb") as f:
KG: nx.DiGraph = nx.read_graphml(f)
node_list = list(KG.nodes)
text_list = [node for node in tqdm(node_list) if "passage" in KG.nodes[node]["type"]]
if not include_events and not include_concept:
node_list = [node for node in tqdm(node_list) if "entity" in KG.nodes[node]["type"]]
elif include_events and not include_concept:
node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
elif include_events and include_concept:
node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "concept" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
else:
raise ValueError("Invalid combination of include_events and include_concept")
edge_list = list(KG.edges)
node_set = set(node_list)
node_list_string = [KG.nodes[node]["id"] for node in node_list]
# Filter edges based on node list
edge_list_index = [i for i, edge in tqdm(enumerate(edge_list)) if edge[0] in node_set and edge[1] in node_set]
edge_list = [edge_list[i] for i in edge_list_index]
edge_list_string = [f"{KG.nodes[edge[0]]['id']} {KG.edges[edge]['relation']} {KG.nodes[edge[1]]['id']}" for edge in edge_list]
original_text_list = []
original_text_dict_with_node_id = {}
for text_node in text_list:
text = KG.nodes[text_node]["id"].strip()
original_text_list.append(text)
original_text_dict_with_node_id[text_node] = text
assert len(original_text_list) == len(original_text_dict_with_node_id)
with open(original_text_list_path, "wb") as f:
pickle.dump(original_text_list, f)
with open(original_text_dict_with_node_id_path, "wb") as f:
pickle.dump(original_text_dict_with_node_id, f)
if not os.path.exists(text_index_path) or not os.path.exists(text_embeddings_path):
print("Computing text embeddings...")
text_embeddings = compute_text_embeddings(original_text_list, sentence_encoder, text_batch_size, normalize_embeddings)
text_faiss_index = build_faiss_index(text_embeddings)
faiss.write_index(text_faiss_index, text_index_path)
with open(text_embeddings_path, "wb") as f:
pickle.dump(text_embeddings, f)
else:
print("Text embeddings already computed.")
with open(text_embeddings_path, "rb") as f:
text_embeddings = pickle.load(f)
text_faiss_index = faiss.read_index(text_index_path)
if not os.path.exists(node_embeddings_path) or not os.path.exists(edge_embeddings_path):
print("Node and edge embeddings not found, computing...")
        node_embeddings, edge_embeddings = compute_graph_embeddings(node_list_string, edge_list_string, sentence_encoder, node_and_edge_batch_size, normalize_embeddings=normalize_embeddings)
else:
with open(node_embeddings_path, "rb") as f:
node_embeddings = pickle.load(f)
with open(edge_embeddings_path, "rb") as f:
edge_embeddings = pickle.load(f)
print("Graph embeddings already computed")
if not os.path.exists(node_index_path):
node_faiss_index = build_faiss_index(node_embeddings)
faiss.write_index(node_faiss_index, node_index_path)
else:
node_faiss_index = faiss.read_index(node_index_path)
if not os.path.exists(edge_index_path):
edge_faiss_index = build_faiss_index(edge_embeddings)
faiss.write_index(edge_faiss_index, edge_index_path)
else:
edge_faiss_index = faiss.read_index(edge_index_path)
if not os.path.exists(node_embeddings_path):
with open(node_embeddings_path, "wb") as f:
pickle.dump(node_embeddings, f)
if not os.path.exists(edge_embeddings_path):
with open(edge_embeddings_path, "wb") as f:
pickle.dump(edge_embeddings, f)
with open(node_list_path, "wb") as f:
pickle.dump(node_list, f)
with open(edge_list_path, "wb") as f:
pickle.dump(edge_list, f)
print("Node and edge embeddings already computed.")
# Return all required indices, embeddings, and lists
return {
"KG": KG,
"node_faiss_index": node_faiss_index,
"edge_faiss_index": edge_faiss_index,
"text_faiss_index": text_faiss_index,
"node_embeddings": node_embeddings,
"edge_embeddings": edge_embeddings,
"text_embeddings": text_embeddings,
"node_list": node_list,
"edge_list": edge_list,
"text_dict": original_text_dict_with_node_id,
}
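# Hypothetical usage sketch, not part of the original pipeline: one way this entry point
# might be called. The model name and working directory are assumptions; "Dulce" mirrors
# the toy dataset referenced elsewhere in this commit, and the GraphML file must already exist.
def _example_create_embeddings_and_index():
    from sentence_transformers import SentenceTransformer
    from atlas_rag.vectorstore.embedding_model import SentenceEmbedding
    encoder = SentenceEmbedding(SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"))
    return create_embeddings_and_index(
        sentence_encoder=encoder,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        working_directory="./working",
        keyword="Dulce",
        include_events=True,
        include_concept=True,
    )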

View File

@@ -0,0 +1,259 @@
import faiss
import numpy as np
import time
import logging
from atlas_rag.kg_construction.utils.csv_processing.csv_to_npy import convert_csv_to_npy
def create_faiss_index(output_directory, filename_pattern, index_type="HNSW,Flat"):
"""
Create faiss index for the graph, for index type, see https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
"IVF65536_HNSW32,Flat" for 1M to 10M nodes
"HNSW,Flat" for toy dataset
"""
# Convert csv to npy
convert_csv_to_npy(
csv_path=f"{output_directory}/triples_csv/triple_nodes_{filename_pattern}_from_json_with_emb.csv",
npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
)
convert_csv_to_npy(
csv_path=f"{output_directory}/triples_csv/text_nodes_{filename_pattern}_from_json_with_emb.csv",
npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
)
convert_csv_to_npy(
csv_path=f"{output_directory}/triples_csv/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.csv",
npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
)
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
)
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
)
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
)
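# Hypothetical usage sketch (paths and keyword are assumptions): for a small dataset whose
# CSV exports already live under {output_directory}/triples_csv, a single call builds all
# three non-normalized indexes.
# create_faiss_index(output_directory="./import/Dulce", filename_pattern="Dulce", index_type="HNSW,Flat")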
# The training data cannot avoid being loaded into memory,
# so simply try to load everything for training.
def build_faiss_from_npy(index_type, index_path, npy_path):
# check npy size.
# shapes = []
start_time = time.time()
# with open(npy_path, "rb") as f:
# while True:
# try:
# array = np.load(f)
# shapes.append(array.shape)
# except Exception as e:
# print(f"Stopped loading due to: {str(e)}")
# break
# if shapes:
# total_rows = sum(shape[0] for shape in shapes)
# dimension = shapes[0][1]
# print(f"Total embeddings in {npy_path}\n {total_rows}, Dimension: {dimension}")
# minilm is 32
# get the dimension from the npy file
with open(npy_path, "rb") as f:
array = np.load(f)
dimension = array.shape[1]
print(f"Dimension: {dimension}")
index = faiss.index_factory(dimension, index_type, faiss.METRIC_INNER_PRODUCT)
if index_type.startswith("IVF"):
index_ivf = faiss.extract_index_ivf(index)
clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
    # Load data to match the training sample size.
    # This was done by randomly picking indexes from `shapes` until their summed row counts exceeded
    # the sample budget, then reading only the chosen chunks for training and skipping np.load for the rest.
# selected_indices = set()
# possible_indices = list(range(len(shapes)))
# selected_training_samples = 0
# while selected_training_samples < max_training_samples and possible_indices:
# idx = random.choice(possible_indices)
# selected_indices.add(idx)
# selected_training_samples += shapes[idx][0]
# possible_indices.remove(idx)
# print(f"Selected total: {selected_training_samples} samples for training")
xt = []
current_index = 0
with open(npy_path, "rb") as f:
while True:
try:
# array = np.load(f)
# if current_index in selected_indices:
array = np.load(f)
# faiss.normalize_L2(array)
xt.append(array)
# current_index += 1
except Exception as e:
logging.info(f"Stopped loading due to: {str(e)}")
break
if xt:
xt = np.vstack(xt)
logging.info(f"Loading time: {time.time() - start_time:.2f} seconds")
start_time = time.time()
index.train(xt)
end_time = time.time()
logging.info(f"Training time: {end_time - start_time:.2f} seconds")
del xt
start_time = time.time()
with open(npy_path, "rb") as f:
while True:
try:
array = np.load(f)
# faiss.normalize_L2(array)
index.add(array)
except Exception as e:
logging.info(f"Stopped loading due to: {str(e)}")
break
logging.info(f"Adding time: {time.time() - start_time:.2f} seconds")
    # Convert the GPU index back to a CPU index for saving (guarded, since index_gpu_to_cpu only exists in GPU builds of faiss)
    if hasattr(faiss, "index_gpu_to_cpu"):
        index = faiss.index_gpu_to_cpu(index)
# Save the CPU index to a file
faiss.write_index(index, index_path)
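def _example_write_appended_npy(npy_path, chunks):
    # Hypothetical helper, not part of the original commit: a sketch of the on-disk layout that
    # the read loops in build_faiss_from_npy appear to expect, i.e. several 2-D float32 arrays
    # appended to one .npy file via consecutive np.save calls. Whether convert_csv_to_npy writes
    # exactly this layout is an assumption.
    with open(npy_path, "wb") as f:
        for chunk in chunks:  # each chunk: (n_i, dimension) float32 ndarray
            np.save(f, chunk)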
def train_and_write_indexes(keyword, npy_dir="./import"):
keyword_to_paths = {
'cc_en': {
'npy':{
'node': f"{npy_dir}/triple_nodes_cc_en_from_json_2.npy",
# 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_2.npy",
'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_2.npy",
},
'index':{
'node': f"{npy_dir}/triple_nodes_cc_en_from_json_non_norm.index",
# 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_non_norm.index",
'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_non_norm.index",
},
'index_type':{
'node': "IVF1048576_HNSW32,Flat",
# 'edge': "IVF1048576_HNSW32,Flat",
'text': "IVF262144_HNSW32,Flat",
},
'csv':{
'node': f"{npy_dir}/triple_nodes_cc_en_from_json.csv",
# 'edge': ff"{npy_dir}/triple_edges_cc_en_from_json.csv",
'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb.csv",
}
},
'pes2o_abstract': {
'npy':{
'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.npy",
# 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json.npy",
'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.npy",
},
'index':{
'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json_non_norm.index",
# 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json_non_norm.index",
'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb_non_norm.index",
},
'index_type':{
'node': "IVF1048576_HNSW32,Flat",
# 'edge': "IVF1048576_HNSW32,Flat",
'text': "IVF65536_HNSW32,Flat",
},
'csv':{
'node_csv': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.csv",
# 'edge_csv': ff"{npy_dir}/triple_edges_pes2o_abstract_from_json.csv",
'text_csv': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
}
},
'en_simple_wiki_v0': {
'npy':{
'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.npy",
# 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.npy",
'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.npy",
},
'index':{
'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json_non_norm.index",
# 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json_non_norm.index",
'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb_non_norm.index",
},
'index_type':{
'node': "IVF1048576_HNSW32,Flat",
# 'edge': "IVF1048576_HNSW32,Flat",
'text': "IVF65536_HNSW32,Flat",
},
'csv':{
'node_csv': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
# 'edge_csv': ff"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
'text_csv': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
}
}
}
emb_list = ['node', 'text'] # Add 'edge' if needed and uncomment the related path lines
for emb in emb_list:
npy_path = keyword_to_paths[keyword]['npy'][emb]
index_path = keyword_to_paths[keyword]['index'][emb]
index_type = keyword_to_paths[keyword]['index_type'][emb]
logging.info(f"Index {index_path}, Building...")
        # For cc_en the recommended number of training samples is 600_000_000; for the rest we can afford to train on all data.
build_faiss_from_npy(index_type, index_path, npy_path)
# # Test the index
# for emb in emb_list:
# index_path = keyword_to_paths[keyword]['index'][emb]
# print(f"Index {index_path}, Testing...")
# test_and_search_faiss_index(index_path, keyword_to_paths[keyword]['csv'][emb])
if __name__ == "__main__":
x = 1
# keyword = "cc_en" # Replace with your actual keyword
# logging.basicConfig(
# filename=f'{keyword}_faiss_creation.log', # Log file
# level=logging.INFO, # Set the logging level
# format='%(asctime)s - %(levelname)s - %(message)s' # Log format
# )
# argparser = argparse.ArgumentParser(description="Train and write FAISS indexes for LKG construction.")
# argparser.add_argument("--npy_dir", type=str, default="./import", help="Directory containing the .npy files.")
# argparser.add_argument("--keyword", type=str, default=keyword, help="Keyword to select the dataset.")
# args = argparser.parse_args()
# keyword = args.keyword
# npy_dir = args.npy_dir
# train_and_write_indexes(keyword,npy_dir)
# index_type = "IVF65536_HNSW32,Flat"
index_type = "HNSW,Flat"
output_directory = "/home/jbai/AutoSchemaKG/import/Dulce"
filename_pattern = "Dulce"
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
)
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
)
build_faiss_from_npy(
index_type=index_type,
index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
)
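    # Hypothetical sanity check (assumes the writes above succeeded): read one index back
    # and report how many vectors it holds.
    # node_index = faiss.read_index(f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index")
    # print(f"Node index contains {node_index.ntotal} vectors")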

View File

@@ -0,0 +1,197 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
import csv
class BaseEmbeddingModel(ABC):
def __init__(self, sentence_encoder):
self.sentence_encoder = sentence_encoder
@abstractmethod
def encode(self, query, **kwargs):
"""Abstract method to encode queries."""
pass
def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
with open(node_csv_without_emb, "r") as csvfile_node:
with open(node_csv_file, "w", newline='') as csvfile_node_emb:
reader_node = csv.reader(csvfile_node)
# the reader has [name:ID,type,concepts,synsets,:LABEL]
writer_node = csv.writer(csvfile_node_emb)
writer_node.writerow(["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"])
                # the encoding is processed in batches of 2048 by default
batch_size = kwargs.get('batch_size', 2048)
batch_nodes = []
batch_rows = []
for row in reader_node:
if row[0] == "name:ID":
continue
batch_nodes.append(row[0])
batch_rows.append(row)
if len(batch_nodes) == batch_size:
node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
for row in batch_rows:
new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
writer_node.writerow(new_row)
batch_nodes = []
batch_rows = []
if len(batch_nodes) > 0:
node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
for row in batch_rows:
new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
writer_node.writerow(new_row)
batch_nodes = []
batch_rows = []
with open(edge_csv_without_emb, "r") as csvfile_edge:
with open(edge_csv_file, "w", newline='') as csvfile_edge_emb:
reader_edge = csv.reader(csvfile_edge)
# [":START_ID",":END_ID","relation","concepts","synsets",":TYPE"]
writer_edge = csv.writer(csvfile_edge_emb)
writer_edge.writerow([":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"])
                # the encoding is processed in batches of 2048
batch_size = 2048
batch_edges = []
batch_rows = []
for row in reader_edge:
if row[0] == ":START_ID":
continue
batch_edges.append(" ".join([row[0], row[2], row[1]]))
batch_rows.append(row)
if len(batch_edges) == batch_size:
edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
for row in batch_rows:
new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
writer_edge.writerow(new_row)
batch_edges = []
batch_rows = []
if len(batch_edges) > 0:
edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
for row in batch_rows:
new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
writer_edge.writerow(new_row)
batch_edges = []
batch_rows = []
with open(text_node_csv_without_emb, "r") as csvfile_text_node:
with open(text_node_csv, "w", newline='') as csvfile_text_node_emb:
reader_text_node = csv.reader(csvfile_text_node)
# [text_id:ID,original_text,:LABEL]
writer_text_node = csv.writer(csvfile_text_node_emb)
writer_text_node.writerow(["text_id:ID", "original_text", ":LABEL", "embedding:STRING"])
                # the encoding is processed in batches of 2048
batch_size = 2048
batch_text_nodes = []
batch_rows = []
for row in reader_text_node:
if row[0] == "text_id:ID":
continue
batch_text_nodes.append(row[1])
batch_rows.append(row)
if len(batch_text_nodes) == batch_size:
text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
for row in batch_rows:
embedding = text_node_embedding_dict[row[1]].tolist()
new_row = [row[0], row[1], row[2], embedding]
writer_text_node.writerow(new_row)
batch_text_nodes = []
batch_rows = []
if len(batch_text_nodes) > 0:
text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
for row in batch_rows:
embedding = text_node_embedding_dict[row[1]].tolist()
new_row = [row[0], row[1], row[2], embedding]
writer_text_node.writerow(new_row)
batch_text_nodes = []
batch_rows = []
class NvEmbed(BaseEmbeddingModel):
def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
self.sentence_encoder = sentence_encoder
def add_eos(self, input_examples):
"""Add EOS token to input examples."""
if self.sentence_encoder.tokenizer.eos_token is not None:
return [input_example + self.sentence_encoder.tokenizer.eos_token for input_example in input_examples]
else:
return input_examples
def encode(self, query, query_type=None, **kwargs):
"""
Encode the query into embeddings.
Args:
query: Input text or list of texts.
            query_type: Type of query ('passage', 'entity', 'edge', or 'fill_in_edge'); any other value gets no instruction prefix.
**kwargs: Additional arguments (e.g., normalize_embeddings).
Returns:
Embeddings as a NumPy array.
"""
normalize_embeddings = kwargs.get('normalize_embeddings', True)
# Define prompt prefixes based on query type
prompt_prefixes = {
'passage': 'Given a question, retrieve relevant documents that best answer the question.',
'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
'edge': 'Given a question, retrieve relevant triplet facts that matches this question.',
'fill_in_edge': 'Given a triples with only head and relation, retrieve relevant triplet facts that best fill the atomic query.'
}
if query_type in prompt_prefixes:
prompt_prefix = prompt_prefixes[query_type]
query_prefix = f"Instruct: {prompt_prefix}\nQuery: "
else:
query_prefix = None
# Encode the query
if isinstance(self.sentence_encoder, SentenceTransformer):
if query_prefix:
query_embeddings = self.sentence_encoder.encode(self.add_eos(query), prompt=query_prefix, **kwargs)
else:
query_embeddings = self.sentence_encoder.encode(self.add_eos(query), **kwargs)
        else:
            # Models loaded via AutoModel.from_pretrained (e.g. NV-Embed) are not instances of AutoModel itself,
            # so fall back to their own encode() here rather than an isinstance check that can never match.
if query_prefix:
query_embeddings = self.sentence_encoder.encode(query, instruction=query_prefix, max_length = 32768, **kwargs)
else:
query_embeddings = self.sentence_encoder.encode(query, max_length = 32768, **kwargs)
        # Normalize embeddings if required; SentenceTransformer.encode may return a NumPy
        # array, so convert to a tensor before applying F.normalize
        if normalize_embeddings:
            if not isinstance(query_embeddings, torch.Tensor):
                query_embeddings = torch.as_tensor(query_embeddings)
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1).detach().cpu().numpy()
        return query_embeddings
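# Hypothetical usage sketch, not part of the original commit (the checkpoint name and
# trust_remote_code flag are assumptions about an NV-Embed-style model):
# nv = NvEmbed(SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True))
# q_emb = nv.encode(["who proposed the knowledge graph schema"], query_type="edge")  # instruction-prefixed query
# d_emb = nv.encode(["The schema was induced automatically."])                        # plain passage encoding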
class SentenceEmbedding(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer):
self.sentence_encoder = sentence_encoder
def encode(self, query, **kwargs):
return self.sentence_encoder.encode(query, **kwargs)
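# Hypothetical usage sketch, not part of the original commit (all file paths below are
# placeholder assumptions): wrap a SentenceTransformer and add embedding columns to the
# exported node, edge, and text-node CSVs.
# encoder = SentenceEmbedding(SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"))
# encoder.compute_kg_embedding(
#     node_csv_without_emb="triple_nodes.csv", node_csv_file="triple_nodes_with_emb.csv",
#     edge_csv_without_emb="triple_edges.csv", edge_csv_file="triple_edges_with_emb.csv",
#     text_node_csv_without_emb="text_nodes.csv", text_node_csv="text_nodes_with_emb.csv",
# )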