first commit
1
AIEC-RAG/atlas_rag/vectorstore/__init__.py
Normal file
@@ -0,0 +1 @@
from .create_graph_index import create_embeddings_and_index
177
AIEC-RAG/atlas_rag/vectorstore/create_graph_index.py
Normal file
@@ -0,0 +1,177 @@
import os
import pickle
import networkx as nx
from tqdm import tqdm
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
import faiss
import numpy as np
import torch


def compute_graph_embeddings(node_list, edge_list_string, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False):
    # Encode in batches
    node_embeddings = []
    for i in tqdm(range(0, len(node_list), batch_size), desc="Encoding nodes"):
        batch = node_list[i:i + batch_size]
        node_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings))

    edge_embeddings = []
    for i in tqdm(range(0, len(edge_list_string), batch_size), desc="Encoding edges"):
        batch = edge_list_string[i:i + batch_size]
        edge_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings))

    return node_embeddings, edge_embeddings


def build_faiss_index(embeddings):
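    # IndexHNSWFlat(dimension, 64) builds an HNSW graph with 64 links per node over
    # flat (uncompressed) vectors; since the vectors are L2-normalized below and the
    # metric is METRIC_INNER_PRODUCT, search effectively ranks by cosine similarity.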
    dimension = len(embeddings[0])

    faiss_index = faiss.IndexHNSWFlat(dimension, 64, faiss.METRIC_INNER_PRODUCT)
    X = np.array(embeddings).astype('float32')

    # normalize the vectors
    faiss.normalize_L2(X)

    # batched add
    for i in tqdm(range(0, X.shape[0], 32)):
        faiss_index.add(X[i:i+32])
    return faiss_index


def compute_text_embeddings(text_list, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False):
    """Separated text embedding computation."""
    text_embeddings = []

    for i in tqdm(range(0, len(text_list), batch_size), desc="Encoding texts"):
        batch = text_list[i:i + batch_size]
        embeddings = sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings)
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        # extend with the embeddings computed above instead of re-encoding the batch
        text_embeddings.extend(embeddings)
    return text_embeddings

def create_embeddings_and_index(sentence_encoder, model_name: str, working_directory: str, keyword: str, include_events: bool, include_concept: bool,
                                normalize_embeddings: bool = True,
                                text_batch_size=40,
                                node_and_edge_batch_size=256):
    # Extract the last part of the encoder model name for simplified reference
    encoder_model_name = model_name.split('/')[-1]

    print(f"Using encoder model: {encoder_model_name}")
    graph_dir = f"{working_directory}/kg_graphml/{keyword}_graph.graphml"
    if not os.path.exists(graph_dir):
        raise FileNotFoundError(f"Graph file {graph_dir} does not exist. Please check the path or generate the graph first.")

    node_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_faiss.index"
    node_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_node_list.pkl"
    edge_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_faiss.index"
    edge_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_edge_list.pkl"
    node_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_embeddings.pkl"
    edge_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_embeddings.pkl"
    text_embeddings_path = f"{working_directory}/precompute/{keyword}_{encoder_model_name}_text_embeddings.pkl"
    text_index_path = f"{working_directory}/precompute/{keyword}_text_faiss.index"
    original_text_list_path = f"{working_directory}/precompute/{keyword}_text_list.pkl"
    original_text_dict_with_node_id_path = f"{working_directory}/precompute/{keyword}_original_text_dict_with_node_id.pkl"

    if not os.path.exists(f"{working_directory}/precompute"):
        os.makedirs(f"{working_directory}/precompute", exist_ok=True)

    print(f"Loading graph from {graph_dir}")
    with open(graph_dir, "rb") as f:
        KG: nx.DiGraph = nx.read_graphml(f)

    node_list = list(KG.nodes)
    text_list = [node for node in tqdm(node_list) if "passage" in KG.nodes[node]["type"]]

    if not include_events and not include_concept:
        node_list = [node for node in tqdm(node_list) if "entity" in KG.nodes[node]["type"]]
    elif include_events and not include_concept:
        node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
    elif include_events and include_concept:
        node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "concept" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]]
    else:
        raise ValueError("Invalid combination of include_events and include_concept")

    edge_list = list(KG.edges)
    node_set = set(node_list)
    node_list_string = [KG.nodes[node]["id"] for node in node_list]

    # Filter edges so that both endpoints are in the selected node list
    edge_list_index = [i for i, edge in tqdm(enumerate(edge_list)) if edge[0] in node_set and edge[1] in node_set]
    edge_list = [edge_list[i] for i in edge_list_index]
    edge_list_string = [f"{KG.nodes[edge[0]]['id']} {KG.edges[edge]['relation']} {KG.nodes[edge[1]]['id']}" for edge in edge_list]

    original_text_list = []
    original_text_dict_with_node_id = {}
    for text_node in text_list:
        text = KG.nodes[text_node]["id"].strip()
        original_text_list.append(text)
        original_text_dict_with_node_id[text_node] = text

    assert len(original_text_list) == len(original_text_dict_with_node_id)

    with open(original_text_list_path, "wb") as f:
        pickle.dump(original_text_list, f)
    with open(original_text_dict_with_node_id_path, "wb") as f:
        pickle.dump(original_text_dict_with_node_id, f)

    if not os.path.exists(text_index_path) or not os.path.exists(text_embeddings_path):
        print("Computing text embeddings...")
        text_embeddings = compute_text_embeddings(original_text_list, sentence_encoder, text_batch_size, normalize_embeddings)
        text_faiss_index = build_faiss_index(text_embeddings)
        faiss.write_index(text_faiss_index, text_index_path)
        with open(text_embeddings_path, "wb") as f:
            pickle.dump(text_embeddings, f)
    else:
        print("Text embeddings already computed.")
        with open(text_embeddings_path, "rb") as f:
            text_embeddings = pickle.load(f)
        text_faiss_index = faiss.read_index(text_index_path)

    if not os.path.exists(node_embeddings_path) or not os.path.exists(edge_embeddings_path):
        print("Node and edge embeddings not found, computing...")
        node_embeddings, edge_embeddings = compute_graph_embeddings(node_list_string, edge_list_string, sentence_encoder, node_and_edge_batch_size, normalize_embeddings=normalize_embeddings)
    else:
        with open(node_embeddings_path, "rb") as f:
            node_embeddings = pickle.load(f)
        with open(edge_embeddings_path, "rb") as f:
            edge_embeddings = pickle.load(f)
        print("Graph embeddings already computed")

    if not os.path.exists(node_index_path):
        node_faiss_index = build_faiss_index(node_embeddings)
        faiss.write_index(node_faiss_index, node_index_path)
    else:
        node_faiss_index = faiss.read_index(node_index_path)

    if not os.path.exists(edge_index_path):
        edge_faiss_index = build_faiss_index(edge_embeddings)
        faiss.write_index(edge_faiss_index, edge_index_path)
    else:
        edge_faiss_index = faiss.read_index(edge_index_path)

    if not os.path.exists(node_embeddings_path):
        with open(node_embeddings_path, "wb") as f:
            pickle.dump(node_embeddings, f)

    if not os.path.exists(edge_embeddings_path):
        with open(edge_embeddings_path, "wb") as f:
            pickle.dump(edge_embeddings, f)

    with open(node_list_path, "wb") as f:
        pickle.dump(node_list, f)

    with open(edge_list_path, "wb") as f:
        pickle.dump(edge_list, f)

    print("Node and edge indexes and embeddings are ready.")
    # Return all required indices, embeddings, and lists
    return {
        "KG": KG,
        "node_faiss_index": node_faiss_index,
        "edge_faiss_index": edge_faiss_index,
        "text_faiss_index": text_faiss_index,
        "node_embeddings": node_embeddings,
        "edge_embeddings": edge_embeddings,
        "text_embeddings": text_embeddings,
        "node_list": node_list,
        "edge_list": edge_list,
        "text_dict": original_text_dict_with_node_id,
    }
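A minimal usage sketch for create_embeddings_and_index; the encoder model, working directory, and keyword below are placeholders, and SentenceEmbedding is the wrapper defined in embedding_model.py further down.

from sentence_transformers import SentenceTransformer
from atlas_rag.vectorstore import create_embeddings_and_index
from atlas_rag.vectorstore.embedding_model import SentenceEmbedding

# Placeholder encoder; any BaseEmbeddingModel implementation works here.
encoder = SentenceEmbedding(SentenceTransformer("all-MiniLM-L6-v2"))

data = create_embeddings_and_index(
    sentence_encoder=encoder,
    model_name="all-MiniLM-L6-v2",
    working_directory="./working",   # expects ./working/kg_graphml/my_dataset_graph.graphml
    keyword="my_dataset",            # placeholder dataset keyword
    include_events=True,
    include_concept=True,
)

# Indexes, embeddings, and lists are cached under ./working/precompute and returned as a dict.
node_index = data["node_faiss_index"]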
259
AIEC-RAG/atlas_rag/vectorstore/create_neo4j_index.py
Normal file
@@ -0,0 +1,259 @@
import faiss
import numpy as np
import time
import logging
from atlas_rag.kg_construction.utils.csv_processing.csv_to_npy import convert_csv_to_npy


def create_faiss_index(output_directory, filename_pattern, index_type="HNSW,Flat"):
    """
    Create FAISS indexes for the graph. For the available index types, see
    https://github.com/facebookresearch/faiss/wiki/Faiss-indexes

    "IVF65536_HNSW32,Flat" for 1M to 10M nodes
    "HNSW,Flat" for toy datasets
    """
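    # A factory string like "IVF65536_HNSW32,Flat" requests an inverted-file index with
    # 65536 clusters whose coarse quantizer is an HNSW graph (M=32), storing raw ("Flat")
    # vectors; "HNSW,Flat" is a single HNSW graph over raw vectors, which needs no
    # clustering/training step and is sufficient for small collections.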
    # Convert csv to npy
    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/triple_nodes_{filename_pattern}_from_json_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/text_nodes_{filename_pattern}_from_json_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    convert_csv_to_npy(
        csv_path=f"{output_directory}/triples_csv/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.csv",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )


# cannot avoid loading everything into memory when training; simply try to load it all
def build_faiss_from_npy(index_type, index_path, npy_path):
    # check npy size.
    # shapes = []
    start_time = time.time()
    # with open(npy_path, "rb") as f:
    #     while True:
    #         try:
    #             array = np.load(f)
    #             shapes.append(array.shape)
    #         except Exception as e:
    #             print(f"Stopped loading due to: {str(e)}")
    #             break
    # if shapes:
    #     total_rows = sum(shape[0] for shape in shapes)
    #     dimension = shapes[0][1]
    #     print(f"Total embeddings in {npy_path}\n {total_rows}, Dimension: {dimension}")
    # minilm is 32
    # get the dimension from the npy file
    with open(npy_path, "rb") as f:
        array = np.load(f)
        dimension = array.shape[1]
    print(f"Dimension: {dimension}")
    index = faiss.index_factory(dimension, index_type, faiss.METRIC_INNER_PRODUCT)

    if index_type.startswith("IVF"):
        index_ivf = faiss.extract_index_ivf(index)
        clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
        index_ivf.clustering_index = clustering_index

    # Load data to match the training sample size: randomly pick indexes from shapes until
    # their sum exceeds the sample size, read only those arrays for training, and skip
    # np.load for the non-chosen indexes.

    # selected_indices = set()
    # possible_indices = list(range(len(shapes)))
    # selected_training_samples = 0
    # while selected_training_samples < max_training_samples and possible_indices:
    #     idx = random.choice(possible_indices)
    #     selected_indices.add(idx)
    #     selected_training_samples += shapes[idx][0]
    #     possible_indices.remove(idx)
    # print(f"Selected total: {selected_training_samples} samples for training")

    xt = []
    current_index = 0
    with open(npy_path, "rb") as f:
        while True:
            try:
                # array = np.load(f)
                # if current_index in selected_indices:
                array = np.load(f)
                # faiss.normalize_L2(array)
                xt.append(array)
                # current_index += 1
            except Exception as e:
                logging.info(f"Stopped loading due to: {str(e)}")
                break
    if xt:
        xt = np.vstack(xt)
        logging.info(f"Loading time: {time.time() - start_time:.2f} seconds")
        start_time = time.time()
        index.train(xt)
        end_time = time.time()
        logging.info(f"Training time: {end_time - start_time:.2f} seconds")
        del xt

    start_time = time.time()
    with open(npy_path, "rb") as f:
        while True:
            try:
                array = np.load(f)
                # faiss.normalize_L2(array)
                index.add(array)
            except Exception as e:
                logging.info(f"Stopped loading due to: {str(e)}")
                break
    logging.info(f"Adding time: {time.time() - start_time:.2f} seconds")

    # Convert the GPU index to a CPU index for saving
    index = faiss.index_gpu_to_cpu(index)
    # Save the CPU index to a file
    faiss.write_index(index, index_path)


def train_and_write_indexes(keyword, npy_dir="./import"):
    keyword_to_paths = {
        'cc_en': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json_2.npy",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_2.npy",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_2.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF262144_HNSW32,Flat",
            },
            'csv': {
                'node': f"{npy_dir}/triple_nodes_cc_en_from_json.csv",
                # 'edge': f"{npy_dir}/triple_edges_cc_en_from_json.csv",
                'text': f"{npy_dir}/text_nodes_cc_en_from_json_with_emb.csv",
            }
        },
        'pes2o_abstract': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.npy",
                # 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json.npy",
                'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF65536_HNSW32,Flat",
            },
            'csv': {
                'node': f"{npy_dir}/triple_nodes_pes2o_abstract_from_json.csv",
                # 'edge': f"{npy_dir}/triple_edges_pes2o_abstract_from_json.csv",
                'text': f"{npy_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
            }
        },
        'en_simple_wiki_v0': {
            'npy': {
                'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.npy",
                # 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.npy",
                'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.npy",
            },
            'index': {
                'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json_non_norm.index",
                # 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json_non_norm.index",
                'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb_non_norm.index",
            },
            'index_type': {
                'node': "IVF1048576_HNSW32,Flat",
                # 'edge': "IVF1048576_HNSW32,Flat",
                'text': "IVF65536_HNSW32,Flat",
            },
            'csv': {
                'node': f"{npy_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
                # 'edge': f"{npy_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
                'text': f"{npy_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
            }
        }
    }
    emb_list = ['node', 'text']  # Add 'edge' if needed and uncomment the related path lines
    for emb in emb_list:
        npy_path = keyword_to_paths[keyword]['npy'][emb]
        index_path = keyword_to_paths[keyword]['index'][emb]
        index_type = keyword_to_paths[keyword]['index_type'][emb]
        logging.info(f"Index {index_path}, Building...")
        # For cc_en the recommended number of training samples is 600_000_000; for the other
        # datasets we can afford to train on all of the data.
        build_faiss_from_npy(index_type, index_path, npy_path)

    # # Test the index
    # for emb in emb_list:
    #     index_path = keyword_to_paths[keyword]['index'][emb]
    #     print(f"Index {index_path}, Testing...")
    #     test_and_search_faiss_index(index_path, keyword_to_paths[keyword]['csv'][emb])


if __name__ == "__main__":

    x = 1

    # keyword = "cc_en"  # Replace with your actual keyword
    # logging.basicConfig(
    #     filename=f'{keyword}_faiss_creation.log',  # Log file
    #     level=logging.INFO,  # Set the logging level
    #     format='%(asctime)s - %(levelname)s - %(message)s'  # Log format
    # )

    # argparser = argparse.ArgumentParser(description="Train and write FAISS indexes for LKG construction.")
    # argparser.add_argument("--npy_dir", type=str, default="./import", help="Directory containing the .npy files.")
    # argparser.add_argument("--keyword", type=str, default=keyword, help="Keyword to select the dataset.")

    # args = argparser.parse_args()
    # keyword = args.keyword
    # npy_dir = args.npy_dir

    # train_and_write_indexes(keyword, npy_dir)
    # index_type = "IVF65536_HNSW32,Flat"
    index_type = "HNSW,Flat"

    output_directory = "/home/jbai/AutoSchemaKG/import/Dulce"
    filename_pattern = "Dulce"

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/text_nodes_{filename_pattern}_from_json_with_emb.npy",
    )

    build_faiss_from_npy(
        index_type=index_type,
        index_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb_non_norm.index",
        npy_path=f"{output_directory}/vector_index/triple_edges_{filename_pattern}_from_json_with_concept_with_emb.npy",
    )
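build_faiss_from_npy reads its .npy file by calling np.load in a loop until it fails, so the file may contain several arrays written back to back. A small self-contained sketch of that write/read pattern, using a hypothetical embeddings.npy filled with random vectors:

import numpy as np

# Write several embedding chunks into one file: np.save appends each array
# at the current position of the open file handle.
chunks = [np.random.rand(1000, 32).astype("float32") for _ in range(3)]
with open("embeddings.npy", "wb") as f:
    for chunk in chunks:
        np.save(f, chunk)

# Read them back the way build_faiss_from_npy does: one np.load per array,
# stopping when the end of the file raises an exception.
arrays = []
with open("embeddings.npy", "rb") as f:
    while True:
        try:
            arrays.append(np.load(f))
        except Exception:
            break
print(sum(a.shape[0] for a in arrays))  # 3000 rows in total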
197
AIEC-RAG/atlas_rag/vectorstore/embedding_model.py
Normal file
@@ -0,0 +1,197 @@
from sentence_transformers import SentenceTransformer
from transformers import AutoModel
import torch.nn.functional as F
from abc import ABC, abstractmethod
import csv


class BaseEmbeddingModel(ABC):
    def __init__(self, sentence_encoder):
        self.sentence_encoder = sentence_encoder

    @abstractmethod
    def encode(self, query, **kwargs):
        """Abstract method to encode queries."""
        pass

    def compute_kg_embedding(self, node_csv_without_emb, node_csv_file, edge_csv_without_emb, edge_csv_file, text_node_csv_without_emb, text_node_csv, **kwargs):
        with open(node_csv_without_emb, "r") as csvfile_node:
            with open(node_csv_file, "w", newline='') as csvfile_node_emb:
                reader_node = csv.reader(csvfile_node)

                # the reader has [name:ID,type,concepts,synsets,:LABEL]
                writer_node = csv.writer(csvfile_node_emb)
                writer_node.writerow(["name:ID", "type", "file_id", "concepts", "synsets", "embedding:STRING", ":LABEL"])

                # the encoding is processed in batches of `batch_size` (default 2048)
                batch_size = kwargs.get('batch_size', 2048)
                batch_nodes = []
                batch_rows = []
                for row in reader_node:
                    if row[0] == "name:ID":
                        continue
                    batch_nodes.append(row[0])
                    batch_rows.append(row)

                    if len(batch_nodes) == batch_size:
                        node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                        node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                            writer_node.writerow(new_row)

                        batch_nodes = []
                        batch_rows = []

                if len(batch_nodes) > 0:
                    node_embeddings = self.encode(batch_nodes, batch_size=batch_size, show_progress_bar=False)
                    node_embedding_dict = dict(zip(batch_nodes, node_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], "", row[2], row[3], node_embedding_dict[row[0]].tolist(), row[4]]
                        writer_node.writerow(new_row)
                    batch_nodes = []
                    batch_rows = []

        with open(edge_csv_without_emb, "r") as csvfile_edge:
            with open(edge_csv_file, "w", newline='') as csvfile_edge_emb:
                reader_edge = csv.reader(csvfile_edge)
                # [":START_ID",":END_ID","relation","concepts","synsets",":TYPE"]
                writer_edge = csv.writer(csvfile_edge_emb)
                writer_edge.writerow([":START_ID", ":END_ID", "relation", "file_id", "concepts", "synsets", "embedding:STRING", ":TYPE"])

                # the encoding is processed in batches of 2048
                batch_size = 2048
                batch_edges = []
                batch_rows = []
                for row in reader_edge:
                    if row[0] == ":START_ID":
                        continue
                    batch_edges.append(" ".join([row[0], row[2], row[1]]))
                    batch_rows.append(row)

                    if len(batch_edges) == batch_size:
                        edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                        edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                        for row in batch_rows:
                            new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                            writer_edge.writerow(new_row)
                        batch_edges = []
                        batch_rows = []

                if len(batch_edges) > 0:
                    edge_embeddings = self.encode(batch_edges, batch_size=batch_size, show_progress_bar=False)
                    edge_embedding_dict = dict(zip(batch_edges, edge_embeddings))
                    for row in batch_rows:
                        new_row = [row[0], row[1], row[2], "", row[3], row[4], edge_embedding_dict[" ".join([row[0], row[2], row[1]])].tolist(), row[5]]
                        writer_edge.writerow(new_row)
                    batch_edges = []
                    batch_rows = []

        with open(text_node_csv_without_emb, "r") as csvfile_text_node:
            with open(text_node_csv, "w", newline='') as csvfile_text_node_emb:
                reader_text_node = csv.reader(csvfile_text_node)
                # [text_id:ID,original_text,:LABEL]
                writer_text_node = csv.writer(csvfile_text_node_emb)

                writer_text_node.writerow(["text_id:ID", "original_text", ":LABEL", "embedding:STRING"])

                # the encoding is processed in batches of 2048
                batch_size = 2048
                batch_text_nodes = []
                batch_rows = []
                for row in reader_text_node:
                    if row[0] == "text_id:ID":
                        continue

                    batch_text_nodes.append(row[1])
                    batch_rows.append(row)
                    if len(batch_text_nodes) == batch_size:
                        text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                        text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                        for row in batch_rows:
                            embedding = text_node_embedding_dict[row[1]].tolist()
                            new_row = [row[0], row[1], row[2], embedding]
                            writer_text_node.writerow(new_row)

                        batch_text_nodes = []
                        batch_rows = []

                if len(batch_text_nodes) > 0:
                    text_node_embeddings = self.encode(batch_text_nodes, batch_size=batch_size, show_progress_bar=False)
                    text_node_embedding_dict = dict(zip(batch_text_nodes, text_node_embeddings))
                    for row in batch_rows:
                        embedding = text_node_embedding_dict[row[1]].tolist()
                        new_row = [row[0], row[1], row[2], embedding]
                        writer_text_node.writerow(new_row)
                    batch_text_nodes = []
                    batch_rows = []


class NvEmbed(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer | AutoModel):
        self.sentence_encoder = sentence_encoder

    def add_eos(self, input_examples):
        """Add EOS token to input examples."""
        if self.sentence_encoder.tokenizer.eos_token is not None:
            return [input_example + self.sentence_encoder.tokenizer.eos_token for input_example in input_examples]
        else:
            return input_examples

    def encode(self, query, query_type=None, **kwargs):
        """
        Encode the query into embeddings.

        Args:
            query: Input text or list of texts.
            query_type: Type of query (e.g., 'passage', 'entity', 'edge', 'fill_in_edge', 'search').
            **kwargs: Additional arguments (e.g., normalize_embeddings).

        Returns:
            Embeddings as a NumPy array.
        """
        normalize_embeddings = kwargs.get('normalize_embeddings', True)

        # Define prompt prefixes based on query type
        prompt_prefixes = {
            'passage': 'Given a question, retrieve relevant documents that best answer the question.',
            'entity': 'Given a question, retrieve relevant phrases that are mentioned in this question.',
            'edge': 'Given a question, retrieve relevant triplet facts that matches this question.',
            'fill_in_edge': 'Given a triples with only head and relation, retrieve relevant triplet facts that best fill the atomic query.'
        }

        if query_type in prompt_prefixes:
            prompt_prefix = prompt_prefixes[query_type]
            query_prefix = f"Instruct: {prompt_prefix}\nQuery: "
        else:
            query_prefix = None

        # Encode the query
        if isinstance(self.sentence_encoder, SentenceTransformer):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), prompt=query_prefix, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(self.add_eos(query), **kwargs)
        elif isinstance(self.sentence_encoder, AutoModel):
            if query_prefix:
                query_embeddings = self.sentence_encoder.encode(query, instruction=query_prefix, max_length=32768, **kwargs)
            else:
                query_embeddings = self.sentence_encoder.encode(query, max_length=32768, **kwargs)

        # Normalize if required; this also moves the result to CPU and converts it to NumPy
        if normalize_embeddings:
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1).detach().cpu().numpy()

        return query_embeddings


class SentenceEmbedding(BaseEmbeddingModel):
    def __init__(self, sentence_encoder: SentenceTransformer):
        self.sentence_encoder = sentence_encoder

    def encode(self, query, **kwargs):
        return self.sentence_encoder.encode(query, **kwargs)
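A brief usage sketch for the two wrappers above; the model names are placeholders and the query is only illustrative.

from sentence_transformers import SentenceTransformer

# SentenceEmbedding: thin pass-through around any SentenceTransformer model.
st_model = SentenceEmbedding(SentenceTransformer("all-MiniLM-L6-v2"))  # placeholder model
vecs = st_model.encode(["What is a knowledge graph?"], normalize_embeddings=True)

# NvEmbed: adds an instruction prefix per query_type ('passage', 'entity', 'edge', 'fill_in_edge')
# and appends the tokenizer's EOS token; intended for instruction-tuned encoders such as NV-Embed.
# convert_to_tensor=True keeps the output a torch tensor so the F.normalize step in encode can run.
nv_model = NvEmbed(SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True))  # placeholder model
edge_vecs = nv_model.encode(["What is a knowledge graph?"], query_type="edge", convert_to_tensor=True)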