import os import pickle import networkx as nx from tqdm import tqdm from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel import faiss import numpy as np import torch def compute_graph_embeddings(node_list, edge_list_string, sentence_encoder: BaseEmbeddingModel, batch_size=40, normalize_embeddings: bool = False): # Encode in batches node_embeddings = [] for i in tqdm(range(0, len(node_list), batch_size), desc="Encoding nodes"): batch = node_list[i:i + batch_size] node_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings = normalize_embeddings)) edge_embeddings = [] for i in tqdm(range(0, len(edge_list_string), batch_size), desc="Encoding edges"): batch = edge_list_string[i:i + batch_size] edge_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings = normalize_embeddings)) return node_embeddings, edge_embeddings def build_faiss_index(embeddings): dimension = len(embeddings[0]) faiss_index = faiss.IndexHNSWFlat(dimension, 64, faiss.METRIC_INNER_PRODUCT) X = np.array(embeddings).astype('float32') # normalize the vectors faiss.normalize_L2(X) # batched add for i in tqdm(range(0, X.shape[0], 32)): faiss_index.add(X[i:i+32]) return faiss_index def compute_text_embeddings(text_list, sentence_encoder: BaseEmbeddingModel, batch_size = 40, normalize_embeddings: bool = False): """Separated text embedding computation""" text_embeddings = [] for i in tqdm(range(0, len(text_list), batch_size), desc="Encoding texts"): batch = text_list[i:i + batch_size] embeddings = sentence_encoder.encode(batch, normalize_embeddings=normalize_embeddings) if isinstance(embeddings, torch.Tensor): embeddings = embeddings.cpu().numpy() text_embeddings.extend(sentence_encoder.encode(batch, normalize_embeddings = normalize_embeddings)) return text_embeddings def create_embeddings_and_index(sentence_encoder, model_name: str, working_directory: str, keyword: str, include_events: bool, include_concept: bool, normalize_embeddings: bool = True, text_batch_size = 40, node_and_edge_batch_size = 256): # Extract the last part of the encoder_model_name for simplified reference encoder_model_name = model_name.split('/')[-1] print(f"Using encoder model: {encoder_model_name}") graph_dir = f"{working_directory}/kg_graphml/{keyword}_graph.graphml" if not os.path.exists(graph_dir): raise FileNotFoundError(f"Graph file {graph_dir} does not exist. Please check the path or generate the graph first.") node_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_faiss.index" node_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_node_list.pkl" edge_index_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_faiss.index" edge_list_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_edge_list.pkl" node_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_node_embeddings.pkl" edge_embeddings_path = f"{working_directory}/precompute/{keyword}_event{include_events}_concept{include_concept}_{encoder_model_name}_edge_embeddings.pkl" text_embeddings_path = f"{working_directory}/precompute/{keyword}_{encoder_model_name}_text_embeddings.pkl" text_index_path = f"{working_directory}/precompute/{keyword}_text_faiss.index" original_text_list_path = f"{working_directory}/precompute/{keyword}_text_list.pkl" original_text_dict_with_node_id_path = f"{working_directory}/precompute/{keyword}_original_text_dict_with_node_id.pkl" if not os.path.exists(f"{working_directory}/precompute"): os.makedirs(f"{working_directory}/precompute", exist_ok=True) print(f"Loading graph from {graph_dir}") with open(graph_dir, "rb") as f: KG: nx.DiGraph = nx.read_graphml(f) node_list = list(KG.nodes) text_list = [node for node in tqdm(node_list) if "passage" in KG.nodes[node]["type"]] if not include_events and not include_concept: node_list = [node for node in tqdm(node_list) if "entity" in KG.nodes[node]["type"]] elif include_events and not include_concept: node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]] elif include_events and include_concept: node_list = [node for node in tqdm(node_list) if "event" in KG.nodes[node]["type"] or "concept" in KG.nodes[node]["type"] or "entity" in KG.nodes[node]["type"]] else: raise ValueError("Invalid combination of include_events and include_concept") edge_list = list(KG.edges) node_set = set(node_list) node_list_string = [KG.nodes[node]["id"] for node in node_list] # Filter edges based on node list edge_list_index = [i for i, edge in tqdm(enumerate(edge_list)) if edge[0] in node_set and edge[1] in node_set] edge_list = [edge_list[i] for i in edge_list_index] edge_list_string = [f"{KG.nodes[edge[0]]['id']} {KG.edges[edge]['relation']} {KG.nodes[edge[1]]['id']}" for edge in edge_list] original_text_list = [] original_text_dict_with_node_id = {} for text_node in text_list: text = KG.nodes[text_node]["id"].strip() original_text_list.append(text) original_text_dict_with_node_id[text_node] = text assert len(original_text_list) == len(original_text_dict_with_node_id) with open(original_text_list_path, "wb") as f: pickle.dump(original_text_list, f) with open(original_text_dict_with_node_id_path, "wb") as f: pickle.dump(original_text_dict_with_node_id, f) if not os.path.exists(text_index_path) or not os.path.exists(text_embeddings_path): print("Computing text embeddings...") text_embeddings = compute_text_embeddings(original_text_list, sentence_encoder, text_batch_size, normalize_embeddings) text_faiss_index = build_faiss_index(text_embeddings) faiss.write_index(text_faiss_index, text_index_path) with open(text_embeddings_path, "wb") as f: pickle.dump(text_embeddings, f) else: print("Text embeddings already computed.") with open(text_embeddings_path, "rb") as f: text_embeddings = pickle.load(f) text_faiss_index = faiss.read_index(text_index_path) if not os.path.exists(node_embeddings_path) or not os.path.exists(edge_embeddings_path): print("Node and edge embeddings not found, computing...") node_embeddings, edge_embeddings = compute_graph_embeddings(node_list_string, edge_list_string, sentence_encoder, node_and_edge_batch_size, normalize_embeddings=normalize_embeddings) # Assumes this function is defined else: with open(node_embeddings_path, "rb") as f: node_embeddings = pickle.load(f) with open(edge_embeddings_path, "rb") as f: edge_embeddings = pickle.load(f) print("Graph embeddings already computed") if not os.path.exists(node_index_path): node_faiss_index = build_faiss_index(node_embeddings) faiss.write_index(node_faiss_index, node_index_path) else: node_faiss_index = faiss.read_index(node_index_path) if not os.path.exists(edge_index_path): edge_faiss_index = build_faiss_index(edge_embeddings) faiss.write_index(edge_faiss_index, edge_index_path) else: edge_faiss_index = faiss.read_index(edge_index_path) if not os.path.exists(node_embeddings_path): with open(node_embeddings_path, "wb") as f: pickle.dump(node_embeddings, f) if not os.path.exists(edge_embeddings_path): with open(edge_embeddings_path, "wb") as f: pickle.dump(edge_embeddings, f) with open(node_list_path, "wb") as f: pickle.dump(node_list, f) with open(edge_list_path, "wb") as f: pickle.dump(edge_list, f) print("Node and edge embeddings already computed.") # Return all required indices, embeddings, and lists return { "KG": KG, "node_faiss_index": node_faiss_index, "edge_faiss_index": edge_faiss_index, "text_faiss_index": text_faiss_index, "node_embeddings": node_embeddings, "edge_embeddings": edge_embeddings, "text_embeddings": text_embeddings, "node_list": node_list, "edge_list": edge_list, "text_dict": original_text_dict_with_node_id, }