74 lines
3.2 KiB
Python
74 lines
3.2 KiB
Python
import faiss
|
|
from neo4j import Driver
|
|
import time
|
|
from graphdatascience import GraphDataScience
|
|
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
|
|
|
|
def build_projection_graph(driver: GraphDataScience):
|
|
project_graph_1 = "largekgrag_graph"
|
|
is_project_graph_1_exist = False
|
|
# is_project_graph_2_exist = False
|
|
result = driver.graph.list()
|
|
for index, row in result.iterrows():
|
|
if row['graphName'] == project_graph_1:
|
|
is_project_graph_1_exist = True
|
|
# if row['graphName'] == project_graph_2:
|
|
# is_project_graph_2_exist = True
|
|
|
|
if not is_project_graph_1_exist:
|
|
start_time = time.time()
|
|
node_properties = ["Node"]
|
|
relation_projection = [ "Relation"]
|
|
result = driver.graph.project(
|
|
project_graph_1,
|
|
node_properties,
|
|
relation_projection
|
|
)
|
|
graph = driver.graph.get(project_graph_1)
|
|
print(f"Projection graph {project_graph_1} created in {time.time() - start_time:.2f} seconds")
|
|
|
|
def build_neo4j_label_index(driver: GraphDataScience):
|
|
with driver.session() as session:
|
|
index_name = f"NodeNumericIDIndex"
|
|
# Check if the index already exists
|
|
existing_indexes = session.run("SHOW INDEXES").data()
|
|
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
|
# Drop the index if it exists
|
|
if not index_exists:
|
|
start_time = time.time()
|
|
session.run(f"CREATE INDEX {index_name} FOR (n:Node) ON (n.numeric_id)")
|
|
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
|
|
|
index_name = f"TextNumericIDIndex"
|
|
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
|
if not index_exists:
|
|
start_time = time.time()
|
|
session.run(f"CREATE INDEX {index_name} FOR (t:Text) ON (t.numeric_id)")
|
|
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
|
|
|
index_name = f"EntityEventEdgeNumericIDIndex"
|
|
index_exists = any(index['name'] == index_name for index in existing_indexes)
|
|
if not index_exists:
|
|
start_time = time.time()
|
|
session.run(f"CREATE INDEX {index_name} FOR ()-[r:Relation]-() on (r.numeric_id)")
|
|
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
|
|
|
|
def load_indexes(path_dict):
|
|
for key, value in path_dict.items():
|
|
if key == 'node':
|
|
node_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
|
print(f"Node index loaded from {value}")
|
|
elif key == 'edge':
|
|
edge_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
|
print(f"Edge index loaded from {value}")
|
|
elif key == 'text':
|
|
passage_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
|
|
print(f"Passage index loaded from {value}")
|
|
return node_index, edge_index, passage_index
|
|
|
|
def start_up_large_kg_index_graph(neo4j_driver: Driver)->BaseLargeKGRetriever:
|
|
gds_driver = GraphDataScience(neo4j_driver)
|
|
# build label index and projection graph
|
|
build_neo4j_label_index(neo4j_driver)
|
|
build_projection_graph(gds_driver)
|