first commit
This commit is contained in:
1 atlas_rag/vectorstore/__init__.py Normal file
@@ -0,0 +1 @@
from .create_graph_index import create_embeddings_and_index
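The package __init__.py re-exports create_embeddings_and_index, so downstream code can import it from atlas_rag.vectorstore directly. A one-line import sketch (illustrative, not part of the commit):

    from atlas_rag.vectorstore import create_embeddings_and_index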
BIN atlas_rag/vectorstore/__pycache__/__init__.cpython-311.pyc Normal file
Binary file not shown.
BIN atlas_rag/vectorstore/__pycache__/__init__.cpython-39.pyc Normal file
Binary file not shown.
BIN atlas_rag/vectorstore/__pycache__/embedding_model.cpython-39.pyc Normal file
Binary file not shown.
177 atlas_rag/vectorstore/create_graph_index.py Normal file
@@ -0,0 +1,177 @@
import os
import pickle
import networkx as nx
from tqdm import tqdm
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import faiss
import numpy as np
import torch


def compute_graph_embeddings(node_list, edge_list_string, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False):
    # Encode nodes and edges in batches
    node_embeddings = []
    for i in tqdm(range(0, len(node_list), batch_size), desc="Encoding nodes"):
        batch = node_list[i:i + batch_size]
        node_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings))

    edge_embeddings = []
    for i in tqdm(range(0, len(edge_list_string), batch_size), desc="Encoding edges"):
        batch = edge_list_string[i:i + batch_size]
        edge_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings))

    return node_embeddings, edge_embeddings


def build_faiss_index(embeddings):
    dimension = len(embeddings[0])

    faiss_index = faiss.IndexHNSWFlat(dimension, 64, faiss.METRIC_INNER_PRODUCT)
    X = np.array(embeddings).astype('float32')

    # Normalize the vectors so inner product behaves like cosine similarity
    faiss.normalize_L2(X)

    # Batched add
    for i in tqdm(range(0, X.shape[0], 32)):
        faiss_index.add(X[i:i + 32])
    return faiss_index


def compute_text_embeddings(text_list, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False):
    """Separated text embedding computation."""
    text_embeddings = []

    for i in tqdm(range(0, len(text_list), batch_size), desc="Encoding texts"):
        batch = text_list[i:i + batch_size]
        embeddings = sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings)
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        # Reuse the embeddings computed above instead of encoding the batch a second time
        text_embeddings.extend(embeddings)
    return text_embeddings
def create_embeddings_and_index(sentence_encoder, model_name: str, working_directory: str, keyword: str, include_events: bool, include_concept: bool,
                                normalize_embeddings: bool = True,
                                text_batch_size=40,
                                node_and_edge_batch_size=256):
    # Extract the last part of the encoder model name for simplified reference
    encoder_model_name = model_name.split('/')[-1]

    print(f"Using encoder model: {encoder_model_name}")
    graph_dir = f"{working_directory}/kg_graphml/{keyword}_graph.graphml"
    if not os.path.exists(graph_dir):
        raise FileNotFoundError(f"Graph file {graph_dir} does not exist. Please check the path or generate the graph first.")

    node_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_faiss.index"
    node_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_node_list.pkl"
    edge_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_faiss.index"
    edge_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_edge_list.pkl"
    node_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_embeddings.pkl"
    edge_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_embeddings.pkl"
    text_embeddings_path = f"{working_directory}/precompute/{keyword}_{encoder_model_name}_text_embeddings.pkl"
    text_index_path = f"{working_directory}/precompute/{keyword}_text_faiss.index"
    original_text_list_path = f"{working_directory}/precompute/{keyword}_text_list.pkl"
    original_text_dict_with_node_id_path = f"{working_directory}/precompute/{keyword}_original_text_dict_with_node_id.pkl"

    if not os.path.exists(f"{working_directory}/precompute"):
        os.makedirs(f"{working_directory}/precompute", exist_ok=True)

    print(f"Loading graph from {graph_dir}")
    with open(graph_dir, "rb") as f:
        KG: nx.DiGraph = nx.read_graphml(f)

    node_list = list(KG.nodes)
    text_list = [node for node in tqdm(node_list) if "passage" in KG.nodes[node]["type"]]

    if not include_events and not include_concept:
        node_list = [node for node in tqdm(node_list) if "entity" in KG.nodes[node]["type"]]
    elif include_events and not include_concept:
        node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
    elif include_events and include_concept:
        node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "concept" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
    else:
        raise ValueError("Invalid combination of include_events and include_concept")

    edge_list = list(KG.edges)
    node_set = set(node_list)
    node_list_string = [KG.nodes[node]["id"] for node in node_list]

    # Filter edges so that both endpoints are in the selected node list
    edge_list_index = [i for i, edge in tqdm(enumerate(edge_list)) if edge[0] in node_set and edge[1] in node_set]
    edge_list = [edge_list[i] for i in edge_list_index]
    edge_list_string = [f"{KG.nodes[edge[0]]['id']} {KG.edges[edge]['relation']} {KG.nodes[edge[1]]['id']}" for edge in edge_list]

    original_text_list = []
    original_text_dict_with_node_id = {}
    for text_node in text_list:
        text = KG.nodes[text_node]["id"].strip()
        original_text_list.append(text)
        original_text_dict_with_node_id[text_node] = text

    assert len(original_text_list) == len(original_text_dict_with_node_id)

    with open(original_text_list_path, "wb") as f:
        pickle.dump(original_text_list, f)
    with open(original_text_dict_with_node_id_path, "wb") as f:
        pickle.dump(original_text_dict_with_node_id, f)

    if not os.path.exists(text_index_path) or not os.path.exists(text_embeddings_path):
        print("Computing text embeddings...")
        text_embeddings = compute_text_embeddings(original_text_list, sentence_encoder, text_batch_size, normalize_embeddings)
        text_faiss_index = build_faiss_index(text_embeddings)
        faiss.write_index(text_faiss_index, text_index_path)
        with open(text_embeddings_path, "wb") as f:
            pickle.dump(text_embeddings, f)
    else:
        print("Text embeddings already computed.")
        with open(text_embeddings_path, "rb") as f:
            text_embeddings = pickle.load(f)
        text_faiss_index = faiss.read_index(text_index_path)

    if not os.path.exists(node_embeddings_path) or not os.path.exists(edge_embeddings_path):
        print("Node and edge embeddings not found, computing...")
        node_embeddings, edge_embeddings = compute_graph_embeddings(node_list_string, edge_list_string, sentence_encoder, node_and_edge_batch_size, normalize_embeddings=normalize_embeddings)
    else:
        with open(node_embeddings_path, "rb") as f:
            node_embeddings = pickle.load(f)
        with open(edge_embeddings_path, "rb") as f:
            edge_embeddings = pickle.load(f)
        print("Graph embeddings already computed")

    if not os.path.exists(node_index_path):
        node_faiss_index = build_faiss_index(node_embeddings)
        faiss.write_index(node_faiss_index, node_index_path)
    else:
        node_faiss_index = faiss.read_index(node_index_path)

    if not os.path.exists(edge_index_path):
        edge_faiss_index = build_faiss_index(edge_embeddings)
        faiss.write_index(edge_faiss_index, edge_index_path)
    else:
        edge_faiss_index = faiss.read_index(edge_index_path)

    if not os.path.exists(node_embeddings_path):
        with open(node_embeddings_path, "wb") as f:
            pickle.dump(node_embeddings, f)

    if not os.path.exists(edge_embeddings_path):
        with open(edge_embeddings_path, "wb") as f:
            pickle.dump(edge_embeddings, f)

    with open(node_list_path, "wb") as f:
        pickle.dump(node_list, f)

    with open(edge_list_path, "wb") as f:
        pickle.dump(edge_list, f)

    print("Node and edge embeddings saved.")
    # Return all required indices, embeddings, and lists
    return {
        "KG": KG,
        "node_faiss_index": node_faiss_index,
        "edge_faiss_index": edge_faiss_index,
        "text_faiss_index": text_faiss_index,
        "node_embeddings": node_embeddings,
        "edge_embeddings": edge_embeddings,
        "text_embeddings": text_embeddings,
        "node_list": node_list,
        "edge_list": edge_list,
        "text_dict": original_text_dict_with_node_id,
    }
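Taken together, create_graph_index.py loads the {keyword}_graph.graphml knowledge graph, encodes the selected nodes, edges, and passage texts, and builds (or reloads) HNSW FAISS indexes plus pickled lists under precompute/. A minimal usage sketch follows; the model name, working directory, and keyword are illustrative placeholders, and SentenceEmbedding is the wrapper added in embedding_model.py below:

    # Illustrative sketch, not part of the commit; paths and model name are placeholders.
    from sentence_transformers import SentenceTransformer
    from atlas_rag.vectorstore import create_embeddings_and_index
    from atlas_rag.vectorstore.embedding_model import SentenceEmbedding

    encoder = SentenceEmbedding(SentenceTransformer("all-MiniLM-L6-v2"))
    data = create_embeddings_and_index(
        sentence_encoder=encoder,
        model_name="all-MiniLM-L6-v2",
        working_directory="./working_dir",  # expects kg_graphml/{keyword}_graph.graphml inside
        keyword="demo",
        include_events=True,
        include_concept=True,
    )
    # data["node_faiss_index"], data["edge_faiss_index"], data["text_faiss_index"],
    # data["KG"], and data["text_dict"] are then ready for retrieval.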
259 atlas_rag/vectorstore/create_neo4j_index.py Normal file
@@ -0,0 +1,259 @@
import faiss
import numpy as np
import time
import logging
from atlas_rag.kg_construction.utils.csv_processing.csv_to_npy import convert_csv_to_npy


def create_faiss_index(output_directory, filename_pattern, index_type="HNSW,Flat"):
    """
    Create a FAISS index for the graph. For available index types, see
    https://github.com/facebookresearch/faiss/wiki/Faiss-indexes

    "IVF65536_HNSW32,Flat" for 1M to 10M nodes

    "HNSW,Flat" for toy datasets
    """
    # Convert csv to npy
    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/triple_nodes_{filename_pattern}_from_json_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/text_nodes_{filename_pattern}_from_json_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )
# cannot avoid loading into memory when training
# simply try load all to train
def build_faiss_from_npy(index_type, index_path, npy_path):
    # check npy size.
    # shapes = []
    start_time = time.time()
    # with open(npy_path, "rb") as f:
    #     while True:
    #         try:
    #             array = np.load(f)
    #             shapes.append(array.shape)
    #         except Exception as e:
    #             print(f"Stopped loading due to: {str(e)}")
    #             break
    # if shapes:
    #     total_rows = sum(shape[0] for shape in shapes)
    #     dimension = shapes[0][1]
    #     print(f"Total embeddings in {npy_path}\n {total_rows}, Dimension: {dimension}")
    # minilm is 32
    # get the dimension from the npy file
    with open(npy_path, "rb") as f:
        array = np.load(f)
        dimension = array.shape[1]
    print(f"Dimension: {dimension}")
    index = faiss.index_factory(dimension, index_type, faiss.METRIC_INNER_PRODUCT)

    if index_type.startswith("IVF"):
        index_ivf = faiss.extract_index_ivf(index)
        clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
        index_ivf.clustering_index = clustering_index

    # Load data to match the training sample size.
    # Done by randomly picking indexes from shapes and checking whether the selected rows exceed the sample size.
    # If yes, read them and start training; skip the np.load part for non-chosen indexes.

    # selected_indices = set()
    # possible_indices = list(range(len(shapes)))
    # selected_training_samples = 0
    # while selected_training_samples < max_training_samples and possible_indices:
    #     idx = random.choice(possible_indices)
    #     selected_indices.add(idx)
    #     selected_training_samples += shapes[idx][0]
    #     possible_indices.remove(idx)
    # print(f"Selected total: {selected_training_samples} samples for training")

    xt = []
    current_index = 0
    with open(npy_path, "rb") as f:
        while True:
            try:
                # array = np.load(f)
                # if current_index in selected_indices:
                array = np.load(f)
                # faiss.normalize_L2(array)
                xt.append(array)
                # current_index += 1
            except Exception as e:
                logging.info(f"Stopped loading due to: {str(e)}")
                break
    if xt:
        xt = np.vstack(xt)
        logging.info(f"Loading time: {time.time() - start_time:.2f} seconds")
        start_time = time.time()
        index.train(xt)
        end_time = time.time()
        logging.info(f"Training time: {end_time - start_time:.2f} seconds")
        del xt

    start_time = time.time()
    with open(npy_path, "rb") as f:
        while True:
            try:
                array = np.load(f)
                # faiss.normalize_L2(array)
                index.add(array)
            except Exception as e:
                logging.info(f"Stopped loading due to: {str(e)}")
                break
    logging.info(f"Adding time: {time.time() - start_time:.2f} seconds")

    # Convert the GPU index to a CPU index for saving
    index = faiss.index_gpu_to_cpu(index)
    # Save the CPU index to a file
    faiss.write_index(index, index_path)
def train_and_write_indexes(keyword, npy_dir="./import"):
    keyword_to_paths = {
        'cc_en': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json_2.npy",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_2.npy",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_2.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF262144_HNSW32,Flat",
            },
            'csv': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json.csv",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json.csv",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb.csv",
            }
        },
        'pes2o_abstract': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.npy",
                # 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json.npy",
                'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF65536_HNSW32,Flat",
            },
            'csv': {
                'node_csv': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.csv",
                # 'edge_csv': f"{npy_dir}/triple_edges_pes2o_abstract_from_json.csv",
                'text_csv': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
            }
        },
        'en_simple_wiki_v0': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.npy",
                # 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.npy",
                'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF65536_HNSW32,Flat",
            },
            'csv': {
                'node_csv': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
                # 'edge_csv': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
                'text_csv': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
            }
        }
    }
    emb_list = ['node', 'text']  # Add 'edge' if needed and uncomment the related path lines
    for emb in emb_list:
        npy_path = keyword_to_paths[keyword]['npy'][emb]
        index_path = keyword_to_paths[keyword]['index'][emb]
        index_type = keyword_to_paths[keyword]['index_type'][emb]
        logging.info(f"Index {index_path}, Building...")
        # For cc_en the recommended number of training samples is 600_000_000; for the rest we can afford to train on all data.
        build_faiss_from_npy(index_type, index_path, npy_path)

    # # Test the index
    # for emb in emb_list:
    #     index_path = keyword_to_paths[keyword]['index'][emb]
    #     print(f"Index {index_path}, Testing...")
    #     test_and_search_faiss_index(index_path, keyword_to_paths[keyword]['csv'][emb])
if __name__ == "__main__":

    x = 1

    # keyword = "cc_en"  # Replace with your actual keyword
    # logging.basicConfig(
    #     filename=f'{keyword}_faiss_creation.log',  # Log file
    #     level=logging.INFO,  # Set the logging level
    #     format='%(asctime)s - %(levelname)s - %(message)s'  # Log format
    # )

    # argparser = argparse.ArgumentParser(description="Train and write FAISS indexes for LKG construction.")
    # argparser.add_argument("--npy_dir", type=str, default="./import", help="Directory containing the .npy files.")
    # argparser.add_argument("--keyword", type=str, default=keyword, help="Keyword to select the dataset.")

    # args = argparser.parse_args()
    # keyword = args.keyword
    # npy_dir = args.npy_dir

    # train_and_write_indexes(keyword, npy_dir)
    # index_type = "IVF65536_HNSW32,Flat"
    index_type = "HNSW,Flat"

    output_directory = "/home/jbai/AutoSchemaKG/import/Dulce"
    filename_pattern = "Dulce"

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )
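create_neo4j_index.py first converts the embedding-bearing CSV exports to .npy and then trains FAISS indexes for the Neo4j-backed pipeline; the docstring recommends "HNSW,Flat" for small datasets and IVF*_HNSW32 factory strings once the collection reaches millions of vectors (IVF training clusters on GPU via faiss.index_cpu_to_all_gpus). A hedged usage sketch, with a placeholder directory and filename pattern, assuming the expected triples_csv/ and vector_index/ layout already exists:

    # Illustrative sketch, not part of the commit; the directory and pattern are placeholders.
    from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index

    create_faiss_index(
        output_directory="./import/demo",
        filename_pattern="demo",
        index_type="HNSW,Flat",  # "IVF65536_HNSW32,Flat" is the suggested choice for 1M-10M nodes
    )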
197 atlas_rag/vectorstore/embedding_model.py Normal file
@@ -0,0 +1,197 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
import torch.nn.functional as F
from abc import ABC, abstractmethod
import csv


class BaseEmbeddingModel(ABC):
    def __init__(self, sentence_encoder):
        self.sentence_encoder = sentence_encoder

    @abstractmethod
    def encode(self, query, **kwargs):
        """Abstract method to encode queries."""
        pass
    def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
        with open(node_csv_without_emb, "r") as csvfile_node:
            with open(node_csv_file, "w", newline='') as csvfile_node_emb:
                reader_node = csv.reader(csvfile_node)

                # the reader has [name:ID,type,concepts,synsets,:LABEL]
                writer_node = csv.writer(csvfile_node_emb)
                writer_node.writerow(["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"])

                # the encoding is processed in batches of 2048 by default
                batch_size = kwargs.get('batch_size', 2048)
                batch_nodes = []
                batch_rows = []
                for row in reader_node:
                    if row[0] == "name:ID":
                        continue
                    batch_nodes.append(row[0])
                    batch_rows.append(row)

                    if len(batch_nodes) == batch_size:
                        node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                        node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                            writer_node.writerow(new_row)

                        batch_nodes = []
                        batch_rows = []

                if len(batch_nodes) > 0:
                    node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                    node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                        writer_node.writerow(new_row)
                    batch_nodes = []
                    batch_rows = []

        with open(edge_csv_without_emb, "r") as csvfile_edge:
            with open(edge_csv_file, "w", newline='') as csvfile_edge_emb:
                reader_edge = csv.reader(csvfile_edge)
                # [":START_ID",":END_ID","relation","concepts","synsets",":TYPE"]
                writer_edge = csv.writer(csvfile_edge_emb)
                writer_edge.writerow([":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"])

                # the encoding is processed in batches of 2048
                batch_size = 2048
                batch_edges = []
                batch_rows = []
                for row in reader_edge:
                    if row[0] == ":START_ID":
                        continue
                    batch_edges.append(" ".join([row[0], row[2], row[1]]))
                    batch_rows.append(row)

                    if len(batch_edges) == batch_size:
                        edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                        edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                            writer_edge.writerow(new_row)
                        batch_edges = []
                        batch_rows = []

                if len(batch_edges) > 0:
                    edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                    edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                        writer_edge.writerow(new_row)
                    batch_edges = []
                    batch_rows = []

        with open(text_node_csv_without_emb, "r") as csvfile_text_node:
            with open(text_node_csv, "w", newline='') as csvfile_text_node_emb:
                reader_text_node = csv.reader(csvfile_text_node)
                # [text_id:ID,original_text,:LABEL]
                writer_text_node = csv.writer(csvfile_text_node_emb)

                writer_text_node.writerow(["text_id:ID", "original_text", ":LABEL", "embedding:STRING"])

                # the encoding is processed in batches of 2048
                batch_size = 2048
                batch_text_nodes = []
                batch_rows = []
                for row in reader_text_node:
                    if row[0] == "text_id:ID":
                        continue

                    batch_text_nodes.append(row[1])
                    batch_rows.append(row)
                    if len(batch_text_nodes) == batch_size:
                        text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                        text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                        for row in batch_rows:
                            embedding = text_node_embedding_dict[row[1]].tolist()
                            new_row = [row[0], row[1], row[2], embedding]
                            writer_text_node.writerow(new_row)

                        batch_text_nodes = []
                        batch_rows = []

                if len(batch_text_nodes) > 0:
                    text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                    text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                    for row in batch_rows:
                        embedding = text_node_embedding_dict[row[1]].tolist()
                        new_row = [row[0], row[1], row[2], embedding]
                        writer_text_node.writerow(new_row)
                    batch_text_nodes = []
                    batch_rows = []
class NvEmbed(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
        self.sentence_encoder = sentence_encoder

    def add_eos(self, input_examples):
        """Add EOS token to input examples."""
        if self.sentence_encoder.tokenizer.eos_token is not None:
            return [input_example + self.sentence_encoder.tokenizer.eos_token for input_example in input_examples]
        else:
            return input_examples

    def encode(self, query, query_type=None, **kwargs):
        """
        Encode the query into embeddings.

        Args:
            query: Input text or list of texts.
            query_type: Type of query (e.g., 'passage', 'entity', 'edge', 'fill_in_edge', 'search').
            **kwargs: Additional arguments (e.g., normalize_embeddings).

        Returns:
            Embeddings as a NumPy array.
        """
        normalize_embeddings = kwargs.get('normalize_embeddings', True)

        # Define prompt prefixes based on query type
        prompt_prefixes = {
            'passage': 'Given a question, retrieve relevant documents that best answer the question.',
            'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
            'edge': 'Given a question, retrieve relevant triplet facts that matches this question.',
            'fill_in_edge': 'Given a triples with only head and relation, retrieve relevant triplet facts that best fill the atomic query.'
        }

        if query_type in prompt_prefixes:
            prompt_prefix = prompt_prefixes[query_type]
            query_prefix = f"Instruct: {prompt_prefix}\nQuery: "
        else:
            query_prefix = None

        # Encode the query
        if isinstance(self.sentence_encoder, SentenceTransformer):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), prompt=query_prefix, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), **kwargs)
        elif isinstance(self.sentence_encoder, AutoModel):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(query, instruction=query_prefix, max_length=32768, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(query, max_length=32768, **kwargs)

        # Normalize embeddings if required
        if normalize_embeddings:
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1).detach().cpu().numpy()

        # Move to CPU and convert to NumPy
        return query_embeddings


class SentenceEmbedding(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer):
        self.sentence_encoder = sentence_encoder

    def encode(self, query, **kwargs):
        return self.sentence_encoder.encode(query, **kwargs)
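embedding_model.py hides the encoder backend behind a common encode() interface: SentenceEmbedding is a thin pass-through for a SentenceTransformer, while NvEmbed adds EOS tokens and instruction-style prompt prefixes (selected by query_type) for NV-Embed-style models, and compute_kg_embedding appends embedding columns to the Neo4j import CSVs. A minimal wrapping sketch; the model names are placeholders:

    # Illustrative sketch, not part of the commit; model names are placeholders.
    from sentence_transformers import SentenceTransformer
    from atlas_rag.vectorstore.embedding_model import SentenceEmbedding, NvEmbed

    # Plain sentence-transformers backend, as used by create_embeddings_and_index:
    encoder = SentenceEmbedding(SentenceTransformer("all-MiniLM-L6-v2"))
    vectors = encoder.encode(["a first passage", "a second passage"], normalize_embeddings=True)

    # NvEmbed expects an instruction-tuned encoder whose tokenizer exposes an EOS token;
    # query_type picks one of the prompt prefixes defined above (hypothetical model name):
    # nv_encoder = NvEmbed(SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True))
    # question_vecs = nv_encoder.encode(["who founded the company?"], query_type="entity")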