first commit

Author: 闫旭隆
Date: 2025-09-24 09:29:12 +08:00
parent 6339cdebb9
commit 2308536f66
360 changed files with 136381 additions and 0 deletions

@@ -0,0 +1,282 @@
from tqdm import tqdm
import random
import logging
import csv
import os
import hashlib
import re
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.triple_config import ProcessingConfig
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import get_node_id
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import CONCEPT_INSTRUCTIONS
import pickle
# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
def build_batch_data(sessions, batch_size):
batched_sessions = []
for i in range(0, len(sessions), batch_size):
batched_sessions.append(sessions[i:i+batch_size])
return batched_sessions
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def convert_attribute(value):
""" Convert attributes to GDS-compatible types. """
if isinstance(value, list):
return [str(v) for v in value]
elif isinstance(value, (int, float)):
return value
else:
return str(value)
def clean_text(text):
    # Replace newline/control characters (including ESC, "\x1b") with spaces and ";" with ","; NUL bytes are stripped below
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
new_text = re.sub(r'\s+', ' ', new_text).strip()
return new_text
def remove_NUL(text):
return text.replace("\x00", "")
def build_batched_events(all_node_list, batch_size):
"The types are in Entity Event Relation"
event_nodes = [node[0] for node in all_node_list if node[1].lower() == "event"]
batched_events = []
for i in range(0, len(event_nodes), batch_size):
batched_events.append(event_nodes[i:i+batch_size])
return batched_events
def build_batched_entities(all_node_list, batch_size):
entity_nodes = [node[0] for node in all_node_list if node[1].lower() == "entity"]
batched_entities = []
for i in range(0, len(entity_nodes), batch_size):
batched_entities.append(entity_nodes[i:i+batch_size])
return batched_entities
def build_batched_relations(all_node_list, batch_size):
relations = [node[0] for node in all_node_list if node[1].lower() == "relation"]
# relations = list(set(relations))
batched_relations = []
for i in range(0, len(relations), batch_size):
batched_relations.append(relations[i:i+batch_size])
return batched_relations
def batched_inference(model:LLMGenerator, inputs, record=False, **kwargs):
responses = model.generate_response(inputs, return_text_only = not record, **kwargs)
answers = []
if record:
text_responses = [response[0] for response in responses]
usages = [response[1] for response in responses]
else:
text_responses = responses
for i in range(len(text_responses)):
answer = text_responses[i]
answers.append([x.strip().lower() for x in answer.split(",")])
if record:
return answers, usages
else:
return answers
def load_data_with_shard(input_file, shard_idx, num_shards):
with open(input_file, "r") as f:
csv_reader = list(csv.reader(f))
# data = csv_reader
data = csv_reader[1:]
    # Shuffle before splitting into shards; seed the RNG beforehand if every shard
    # worker must see the same permutation (otherwise shards may overlap or miss rows)
    random.shuffle(data)
total_lines = len(data)
lines_per_shard = (total_lines + num_shards - 1) // num_shards
start_idx = shard_idx * lines_per_shard
end_idx = min((shard_idx + 1) * lines_per_shard, total_lines)
return data[start_idx:end_idx]
def generate_concept(model: LLMGenerator,
input_file = 'processed_data/triples_csv',
output_folder = 'processed_data/triples_conceptualized',
output_file = 'output.json',
logging_file = 'processed_data/logging.txt',
config:ProcessingConfig=None,
batch_size=32,
shard=0,
num_shards=1,
**kwargs):
log_dir = os.path.dirname(logging_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir)
# Create the log file if it doesn't exist
if not os.path.exists(logging_file):
open(logging_file, 'w').close()
language = kwargs.get('language', 'en')
record = kwargs.get('record', False)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(logging_file)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logging.getLogger().addHandler(file_handler)
with open(f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl", "rb") as f:
temp_kg = pickle.load(f)
# read data
if not os.path.exists(output_folder):
os.makedirs(output_folder)
all_missing_nodes = load_data_with_shard(
input_file,
shard_idx=shard,
num_shards=num_shards
)
batched_events = build_batched_events(all_missing_nodes, batch_size)
batched_entities = build_batched_entities(all_missing_nodes, batch_size)
batched_relations = build_batched_relations(all_missing_nodes, batch_size)
all_batches = []
all_batches.extend(('event', batch) for batch in batched_events)
all_batches.extend(('entity', batch) for batch in batched_entities)
all_batches.extend(('relation', batch) for batch in batched_relations)
print("all_batches", len(all_batches))
output_file = output_folder + f"/{output_file.rsplit('.', 1)[0]}_shard_{shard}.csv"
with open(output_file, "w", newline='') as file:
csv_writer = csv.writer(file)
csv_writer.writerow(["node", "conceptualized_node", "node_type"])
# for batch_type, batch in tqdm(all_batches, total=total_batches, desc="Generating concepts"):
for batch_type, batch in tqdm(all_batches, desc="Shard_{}".format(shard)):
# print("batch_type", batch_type)
# print("batch", batch)
replace_context_token = None
if batch_type == 'event':
template = CONCEPT_INSTRUCTIONS[language]['event']
node_type = 'event'
replace_token = '[EVENT]'
elif batch_type == 'entity':
template = CONCEPT_INSTRUCTIONS[language]['entity']
node_type = 'entity'
replace_token = '[ENTITY]'
replace_context_token = '[CONTEXT]'
elif batch_type == 'relation':
template = CONCEPT_INSTRUCTIONS[language]['relation']
node_type = 'relation'
replace_token = '[RELATION]'
inputs = []
for node in batch:
# sample node from given node and replace context token.
if replace_context_token:
node_id = get_node_id(node)
entity_predecessors = list(temp_kg.predecessors(node_id))
entity_successors = list(temp_kg.successors(node_id))
context = ""
if len(entity_predecessors) > 0:
random_two_neighbors = random.sample(entity_predecessors, min(1, len(entity_predecessors)))
context += ", ".join([f"{temp_kg.nodes[neighbor]['id']} {temp_kg[neighbor][node_id]['relation']}" for neighbor in random_two_neighbors])
if len(entity_successors) > 0:
random_two_neighbors = random.sample(entity_successors, min(1, len(entity_successors)))
context += ", ".join([f"{temp_kg[node_id][neighbor]['relation']} {temp_kg.nodes[neighbor]['id']}" for neighbor in random_two_neighbors])
prompt = template.replace(replace_token, node).replace(replace_context_token, context)
else:
prompt = template.replace(replace_token, node)
constructed_input = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": f"{prompt}"},
]
inputs.append(constructed_input)
try:
# print("inputs", inputs)
if record:
# If recording, we will get both answers and responses
answers, usages = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
else:
answers = batched_inference(model, inputs, record=record, max_workers = config.max_workers)
usages = None
# print("answers", answers)
except Exception as e:
logging.error(f"Error processing {batch_type} batch: {e}")
raise e
# try:
# answers = batched_inference(llm, sampling_params, inputs)
# except Exception as e:
# logging.error(f"Error processing {batch_type} batch: {e}")
# continue
for i,(node, answer) in enumerate(zip(batch, answers)):
# print(node, answer, node_type)
if usages is not None:
logging.info(f"Usage log: Node {node}, completion_usage: {usages[i]}")
csv_writer.writerow([node, ", ".join(answer), node_type])
file.flush()
# count unique conceptualized nodes
conceptualized_nodes = []
conceptualized_events = []
conceptualized_entities = []
conceptualized_relations = []
with open(output_file, "r") as file:
reader = csv.reader(file)
next(reader)
for row in reader:
conceptualized_nodes.extend(row[1].split(","))
if row[2] == "event":
conceptualized_events.extend(row[1].split(","))
elif row[2] == "entity":
conceptualized_entities.extend(row[1].split(","))
elif row[2] == "relation":
conceptualized_relations.extend(row[1].split(","))
conceptualized_nodes = [x.strip() for x in conceptualized_nodes]
conceptualized_events = [x.strip() for x in conceptualized_events]
conceptualized_entities = [x.strip() for x in conceptualized_entities]
conceptualized_relations = [x.strip() for x in conceptualized_relations]
unique_conceptualized_nodes = list(set(conceptualized_nodes))
unique_conceptualized_events = list(set(conceptualized_events))
unique_conceptualized_entities = list(set(conceptualized_entities))
unique_conceptualized_relations = list(set(conceptualized_relations))
print(f"Number of unique conceptualized nodes: {len(unique_conceptualized_nodes)}")
print(f"Number of unique conceptualized events: {len(unique_conceptualized_events)}")
print(f"Number of unique conceptualized entities: {len(unique_conceptualized_entities)}")
print(f"Number of unique conceptualized relations: {len(unique_conceptualized_relations)}")
return
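# A minimal usage sketch (assumptions: `llm` is an already-constructed LLMGenerator and
# the ProcessingConfig's output_directory already contains the
# "<filename_pattern>_without_concept.pkl" graph written by the CSV-to-GraphML step):
#
#   from atlas_rag.kg_construction.triple_config import ProcessingConfig
#   config = ProcessingConfig(model_path="your-model", data_directory="./data",
#                             filename_pattern="en_simple_wiki_v0",
#                             output_directory="./generation_result_debug")
#   generate_concept(llm,
#                    input_file=f"{config.output_directory}/triples_csv/missing_concepts_en_simple_wiki_v0_from_json.csv",
#                    output_folder=f"{config.output_directory}/concepts",
#                    output_file="concept.json",
#                    logging_file=f"{config.output_directory}/concepts/logging.txt",
#                    config=config, batch_size=32, shard=0, num_shards=1)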

@@ -0,0 +1,153 @@
import ast
import uuid
import csv
from tqdm import tqdm
import hashlib
import os
def generate_uuid():
"""Generate a random UUID"""
return str(uuid.uuid4())
def parse_concepts(s):
"""Parse concepts field and filter empty values"""
try:
parsed = ast.literal_eval(s) if s and s != '[]' else []
return [c.strip() for c in parsed if c.strip()]
    except (ValueError, SyntaxError):
return []
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
text = text + '_concept'
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def all_concept_triples_csv_to_csv(node_file, edge_file, concepts_file, output_node_file, output_edge_file, output_full_concept_triple_edges):
    # To output the concept nodes, concept edges, and the new full triple edges, we first
    # read the concept map into memory (it is usually not too large), then iterate over the
    # triple nodes to create concept edges, and finally iterate over the triple edges to
    # build the full triple edges.
# Read missing concept
# relation_concepts_mapping = {}
# all_missing_concepts = []
# check if all output directories exist
output_dir = os.path.dirname(output_node_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_dir = os.path.dirname(output_edge_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output_dir = os.path.dirname(output_full_concept_triple_edges)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
node_to_concepts = {}
relation_to_concepts = {}
all_concepts = set()
with open(concepts_file, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
# Load missing concepts list
print("Loading concepts...")
for row in tqdm(reader):
if row['node_type'] == 'relation':
relation = row['node']
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
if relation not in relation_to_concepts:
relation_to_concepts[relation] = concepts
else:
relation_to_concepts[relation].extend(concepts)
relation_to_concepts[relation] = list(set(relation_to_concepts[relation]))
else:
node = row['node']
concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]
if node not in node_to_concepts:
node_to_concepts[node] = concepts
else:
node_to_concepts[node].extend(concepts)
node_to_concepts[node] = list(set(node_to_concepts[node]))
print("Loading concepts done.")
print(f"Relation to concepts: {len(relation_to_concepts)}")
print(f"Node to concepts: {len(node_to_concepts)}")
# Read triple nodes and write to output concept edges files
print("Processing triple nodes...")
with open(node_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
# name:ID,type,concepts,synsets,:LABEL
header = next(reader)
with open (output_edge_file, 'w', newline='', encoding='utf-8') as f_out:
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
writer.writerow([':START_ID', ':END_ID', 'relation', ':TYPE'])
for row in tqdm(reader):
node_name = row[0]
if node_name in node_to_concepts:
for concept in node_to_concepts[node_name]:
concept_id = compute_hash_id(concept)
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
all_concepts.add(concept)
for concept in parse_concepts(row[2]):
concept_id = compute_hash_id(concept)
writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
all_concepts.add(concept)
# Read the concept nodes and write to output concept nodes file
print("Processing concept nodes...")
with open (output_node_file, 'w', newline='', encoding='utf-8') as f_out:
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
writer.writerow(['concept_id:ID', 'name', ':LABEL'])
for concept in tqdm(all_concepts):
concept_id = compute_hash_id(concept)
writer.writerow([concept_id, concept, 'Concept'])
# Read triple edges and write to output full concept triple edges file
print("Processing triple edges...")
with open(edge_file, 'r', encoding='utf-8') as f:
with open(output_full_concept_triple_edges, 'w', newline='', encoding='utf-8') as f_out:
reader = csv.reader(f)
writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
header = next(reader)
writer.writerow([':START_ID', ':END_ID', 'relation', 'concepts', 'synsets', ':TYPE'])
for row in tqdm(reader):
src_id = row[0]
end_id = row[1]
relation = row[2]
concepts = row[3]
synsets = row[4]
original_concepts = parse_concepts(concepts)
if relation in relation_to_concepts:
for concept in relation_to_concepts[relation]:
if concept not in original_concepts:
original_concepts.append(concept)
original_concepts = list(set(original_concepts))
writer.writerow([src_id, end_id, relation, original_concepts, synsets, 'Relation'])
return
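# A usage sketch; the paths mirror the ones produced by the extraction pipeline, with
# <output_dir> and the "en_simple_wiki_v0" pattern standing in for your own run:
#
#   all_concept_triples_csv_to_csv(
#       node_file="<output_dir>/triples_csv/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
#       edge_file="<output_dir>/triples_csv/triple_edges_en_simple_wiki_v0_from_json_without_emb.csv",
#       concepts_file="<output_dir>/triples_csv/en_simple_wiki_v0_from_json_with_concept.csv",
#       output_node_file="<output_dir>/concept_csv/concept_nodes_en_simple_wiki_v0_from_json_with_concept.csv",
#       output_edge_file="<output_dir>/concept_csv/concept_edges_en_simple_wiki_v0_from_json_with_concept.csv",
#       output_full_concept_triple_edges="<output_dir>/concept_csv/triple_edges_en_simple_wiki_v0_from_json_with_concept.csv",
#   )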

@@ -0,0 +1,267 @@
import time
import uvicorn
from fastapi import FastAPI, HTTPException, Response
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from logging import Logger
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever, BaseLargeKGEdgeRetriever
from atlas_rag.kg_construction.neo4j.utils import start_up_large_kg_index_graph
from atlas_rag.llm_generator import LLMGenerator
from neo4j import Driver
from dataclasses import dataclass
import traceback
@dataclass
class LargeKGConfig:
largekg_retriever: BaseLargeKGRetriever | BaseLargeKGEdgeRetriever = None
reader_llm_generator : LLMGenerator = None
driver: Driver = None
logger: Logger = None
is_felm: bool = False
is_mmlu: bool = False
rag_exemption_list = [
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to assess whether each text segment contains errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: 8923164*7236571?\nSegments: \n1. The product of 8923164 and 7236571 is: 6,461,216,222,844\n2. So, 8923164 multiplied by 7236571 is equal to 6,461,216,222,844.\n\nBelow are your outputs:\nAnswer: 1,2\nIt means segment 1,2 contain errors.""",
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to determine whether each text segment contains factual errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: A company offers a 10% discount on all purchases over $100. A customer purchases three items, each costing $80. Does the customer qualify for the discount?\nSegments: \n1. To solve this problem, we need to use deductive reasoning. We know that the company offers a 10% discount on purchases over $100, so we need to calculate the total cost of the customer's purchase.\n2. The customer purchased three items, each costing $80, so the total cost of the purchase is: 3 x $80 = $200.\n3. Since the total cost of the purchase is greater than $100, the customer qualifies for the discount. \n4. To calculate the discounted price, we can multiply the total cost by 0.1 (which represents the 10% discount): $200 x 0.1 = $20.\n5. So the customer is eligible for a discount of $20, and the final cost of the purchase would be: $200 - $20 = $180.\n6. Therefore, the customer would pay a total of $216 for the three items with the discount applied.\n\nBelow are your outputs:\nAnswer: 2,3,4,5,6\nIt means segment 2,3,4,5,6 contains errors.""",
]
mmlu_check_list = [
"""Given the following question and four candidate answers (A, B, C and D), choose the answer."""
]
app = FastAPI()
@app.on_event("startup")
async def startup():
global large_kg_config
start_up_large_kg_index_graph(large_kg_config.driver)
@app.on_event("shutdown")
async def shutdown():
global large_kg_config
print("Shutting down the model...")
del large_kg_config
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "test"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: str = None
name: Optional[str] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.8
max_tokens: Optional[int] = None
stream: Optional[bool] = False
tools: Optional[Union[dict, List[dict]]] = None
repetition_penalty: Optional[float] = 1.1
retriever_config: Optional[dict] = {
"topN": 5,
"number_of_source_nodes_per_ner": 10,
"sampling_area": 250,
"Dmax": 2,
"Wmax": 3
}
class Config:
extra = "allow"
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Literal["stop", "length", "function_call"]
class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]]
index: int
class ChatCompletionResponse(BaseModel):
model: str
id: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
usage: Optional[UsageInfo] = None
@app.get("/health")
async def health_check():
return Response(status_code=200)
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global large_kg_config
try:
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
print(request)
# raise HTTPException(status_code=400, detail="Invalid request")
if large_kg_config.logger is not None:
large_kg_config.logger.info(f"Request: {request}")
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
system_prompt = 'You are a helpful assistant.'
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
is_exemption = any(exemption in question for exemption in LargeKGConfig.rag_exemption_list)
is_mmlu = any(exemption in question for exemption in LargeKGConfig.mmlu_check_list)
print(f"Is exemption: {is_exemption}, Is MMLU: {is_mmlu}")
if is_mmlu:
rag_text = question
else:
parts = question.rsplit("Question:", 1)
rag_text = parts[-1] if len(parts) > 1 else None
print(f"RAG text: {rag_text}")
if not is_exemption:
passages, passages_score = large_kg_config.largekg_retriever.retrieve_passages(rag_text)
context = "No retrieved context, Please answer the question with your own knowledge." if not passages else "\n".join([f"Passage {i+1}: {text}" for i, text in enumerate(reversed(passages))])
if is_mmlu:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""Here is the context: {context} \n\n
If the context is not useful, you can answer the question with your own knowledge. \n {question}\nThink step by step. Your response should end with 'The answer is ([the_answer_letter])' where the [the_answer_letter] is one of A, B, C and D."""
}
]
elif not is_exemption:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} Reference doc: {context}"""
}
]
else:
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
if large_kg_config.logger is not None:
large_kg_config.logger.info(rag_chat_content)
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
except Exception as e:
print("ERROR: ", e)
print("Catched error")
traceback.print_exc()
system_prompt = 'You are a helpful assistant.'
gen_params = dict(
messages=request.messages,
temperature=0.8,
top_p=request.top_p,
max_tokens=request.max_tokens or 1024,
echo=False,
stream=False,
repetition_penalty=request.repetition_penalty,
tools=request.tools,
)
last_message = request.messages[-1]
question = last_message.content if last_message.role == 'user' else request.messages[-2].content
rag_chat_content = [
{
"role": "system",
"content": f"{system_prompt}"
},
{
"role": "user",
"content": f"""{question} """
}
]
response = large_kg_config.reader_llm_generator.generate_response(
batch_messages=rag_chat_content,
max_new_tokens=gen_params["max_tokens"],
temperature=gen_params["temperature"],
frequency_penalty = 1.1
)
message = ChatMessage(
role="assistant",
content=response.strip()
)
choice_data = ChatCompletionResponseChoice(
index=0,
message=message,
finish_reason="stop"
)
return ChatCompletionResponse(
model=request.model,
id="",
object="chat.completion",
choices=[choice_data]
)
def start_app(user_config:LargeKGConfig, host="0.0.0.0", port=10090, reload=False):
"""Function to start the FastAPI application."""
global large_kg_config
large_kg_config = user_config # Use the passed context if provided
uvicorn.run(f"atlas_rag.kg_construction.neo4j.neo4j_api:app", host=host, port=port, reload=reload)
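# A minimal launch sketch (assumptions: a reachable Neo4j instance, a retriever that
# implements BaseLargeKGRetriever, and an LLMGenerator `reader_llm`; the URI and
# credentials below are placeholders):
#
#   from neo4j import GraphDatabase
#   driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
#   config = LargeKGConfig(largekg_retriever=my_retriever,
#                          reader_llm_generator=reader_llm,
#                          driver=driver,
#                          logger=None)
#   start_app(config, host="0.0.0.0", port=10090)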

@@ -0,0 +1,73 @@
import faiss
from neo4j import Driver
import time
from graphdatascience import GraphDataScience
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever
def build_projection_graph(driver: GraphDataScience):
project_graph_1 = "largekgrag_graph"
is_project_graph_1_exist = False
# is_project_graph_2_exist = False
result = driver.graph.list()
for index, row in result.iterrows():
if row['graphName'] == project_graph_1:
is_project_graph_1_exist = True
# if row['graphName'] == project_graph_2:
# is_project_graph_2_exist = True
if not is_project_graph_1_exist:
start_time = time.time()
node_properties = ["Node"]
relation_projection = [ "Relation"]
result = driver.graph.project(
project_graph_1,
node_properties,
relation_projection
)
graph = driver.graph.get(project_graph_1)
print(f"Projection graph {project_graph_1} created in {time.time() - start_time:.2f} seconds")
def build_neo4j_label_index(driver: Driver):
with driver.session() as session:
index_name = f"NodeNumericIDIndex"
# Check if the index already exists
existing_indexes = session.run("SHOW INDEXES").data()
index_exists = any(index['name'] == index_name for index in existing_indexes)
        # Create the index only if it does not already exist
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR (n:Node) ON (n.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
index_name = f"TextNumericIDIndex"
index_exists = any(index['name'] == index_name for index in existing_indexes)
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR (t:Text) ON (t.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
index_name = f"EntityEventEdgeNumericIDIndex"
index_exists = any(index['name'] == index_name for index in existing_indexes)
if not index_exists:
start_time = time.time()
session.run(f"CREATE INDEX {index_name} FOR ()-[r:Relation]-() on (r.numeric_id)")
print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")
def load_indexes(path_dict):
    # Default to None so the return value is well-defined even if a key is missing
    node_index = edge_index = passage_index = None
    for key, value in path_dict.items():
if key == 'node':
node_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Node index loaded from {value}")
elif key == 'edge':
edge_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Edge index loaded from {value}")
elif key == 'text':
passage_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
print(f"Passage index loaded from {value}")
return node_index, edge_index, passage_index
def start_up_large_kg_index_graph(neo4j_driver: Driver) -> None:
gds_driver = GraphDataScience(neo4j_driver)
# build label index and projection graph
build_neo4j_label_index(neo4j_driver)
build_projection_graph(gds_driver)
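# A usage sketch (assumptions: the Neo4j URI/credentials and FAISS index paths below
# are placeholders; load_indexes expects index files keyed by 'node', 'edge' and 'text'):
#
#   from neo4j import GraphDatabase
#   driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
#   start_up_large_kg_index_graph(driver)  # builds the numeric-id indexes and the GDS projection
#   node_index, edge_index, passage_index = load_indexes({
#       "node": "./import/node.index",
#       "edge": "./import/edge.index",
#       "text": "./import/text.index",
#   })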

@@ -0,0 +1,22 @@
from dataclasses import dataclass
@dataclass
class ProcessingConfig:
"""Configuration for text processing pipeline."""
model_path: str
data_directory: str
filename_pattern: str
batch_size_triple: int = 16
batch_size_concept: int = 64
output_directory: str = "./generation_result_debug"
total_shards_triple: int = 1
current_shard_triple: int = 0
total_shards_concept: int = 1
current_shard_concept: int = 0
use_8bit: bool = False
debug_mode: bool = False
resume_from: int = 0
record : bool = False
max_new_tokens: int = 8192
max_workers: int = 8
remove_doc_spaces: bool = False
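# Example construction (a sketch; the model path and directories are placeholders):
if __name__ == "__main__":
    example_config = ProcessingConfig(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",
        data_directory="./data",
        filename_pattern="en_simple_wiki_v0",
        batch_size_triple=16,
        batch_size_concept=64,
        output_directory="./generation_result_debug",
        record=True,
        max_workers=8,
    )
    print(example_config)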

@@ -0,0 +1,497 @@
#!/usr/bin/env python3
"""
Knowledge Graph Extraction Pipeline
Extracts entities, relations, and events from text data using transformer models.
"""
import re
import json
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path
import torch
from datasets import load_dataset
from tqdm import tqdm
import json_repair
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
from atlas_rag.kg_construction.triple_config import ProcessingConfig
# Constants
TOKEN_LIMIT = 1024
INSTRUCTION_TOKEN_ESTIMATE = 200
CHAR_TO_TOKEN_RATIO = 3.5
class TextChunker:
"""Handles text chunking based on token limits."""
def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
self.max_tokens = max_tokens
self.instruction_tokens = instruction_tokens
self.char_ratio = CHAR_TO_TOKEN_RATIO
def calculate_max_chars(self) -> int:
"""Calculate maximum characters per chunk."""
available_tokens = self.max_tokens - self.instruction_tokens
return int(available_tokens * self.char_ratio)
def split_text(self, text: str) -> List[str]:
"""Split text into chunks that fit within token limits."""
max_chars = self.calculate_max_chars()
chunks = []
while len(text) > max_chars:
chunks.append(text[:max_chars])
text = text[max_chars:]
if text: # Add remaining text
chunks.append(text)
return chunks
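# Worked example of the chunking budget, using the constants above:
#   available_tokens = TOKEN_LIMIT - INSTRUCTION_TOKEN_ESTIMATE = 1024 - 200 = 824
#   max_chars        = int(824 * CHAR_TO_TOKEN_RATIO) = int(824 * 3.5) = 2884
# so split_text() cuts the input into consecutive 2884-character slices, with the
# final slice holding whatever remains.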
class DatasetProcessor:
"""Processes and prepares dataset for knowledge graph extraction."""
def __init__(self, config: ProcessingConfig):
self.config = config
self.chunker = TextChunker()
def filter_language_content(self, sample: Dict[str, Any]) -> bool:
"""Check if content is in English."""
metadata = sample.get("metadata", {})
language = metadata.get("lang", "en") # Default to English if not specified
supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
return language in supported_languages
def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Create chunks from a single sample."""
original_text = sample.get("text", "")
if self.config.remove_doc_spaces:
original_text = re.sub(r'\s+', ' ',original_text).strip()
text_chunks = self.chunker.split_text(original_text)
chunks = []
for chunk_idx, chunk_text in enumerate(text_chunks):
chunk_data = {
"id": sample["id"],
"text": chunk_text,
"chunk_id": chunk_idx,
"metadata": sample["metadata"]
}
chunks.append(chunk_data)
return chunks
def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
"""Process raw dataset into chunks suitable for processing with generalized slicing."""
processed_samples = []
total_texts = len(raw_dataset)
# Handle edge cases
if total_texts == 0:
print(f"No texts found for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
# Calculate base and remainder for fair distribution
base_texts_per_shard = total_texts // self.config.total_shards_triple
remainder = total_texts % self.config.total_shards_triple
# Calculate start index
if self.config.current_shard_triple < remainder:
start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
else:
start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard
# Calculate end index
if self.config.current_shard_triple < remainder:
end_idx = start_idx + (base_texts_per_shard + 1)
else:
end_idx = start_idx + base_texts_per_shard
# Ensure indices are within bounds
start_idx = min(start_idx, total_texts)
end_idx = min(end_idx, total_texts)
print(f"Processing shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple} "
f"(texts {start_idx}-{end_idx-1} of {total_texts}, {end_idx - start_idx} documents)")
# Process documents in assigned shard
for idx in range(start_idx, end_idx):
sample = raw_dataset[idx]
# Filter by language
if not self.filter_language_content(sample):
print(f"Unsupported language in sample {idx}, skipping.")
continue
# Create chunks
chunks = self.create_sample_chunks(sample)
processed_samples.extend(chunks)
# Debug mode early termination
if self.config.debug_mode and len(processed_samples) >= 20:
print("Debug mode: Stopping at 20 chunks")
break
print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
class CustomDataLoader:
"""Custom data loader for knowledge graph extraction."""
def __init__(self, dataset, processor: DatasetProcessor):
self.raw_dataset = dataset
self.processor = processor
self.processed_data = processor.prepare_dataset(dataset)
self.stage_to_prompt_dict = {
"stage_1": "entity_relation",
"stage_2": "event_entity",
"stage_3": "event_relation"
}
def __len__(self) -> int:
return len(self.processed_data)
def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> List[str]:
messages_dict = {
'stage_1': [],
'stage_2': [],
'stage_3': []
}
for item in batch_data:
# get language
language = item.get("metadata",{}).get("lang", "en")
system_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['system']
stage_1_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['entity_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n' + item["text"]
stage_2_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_entity'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
stage_3_msg = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['event_relation'] + TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])['passage_start'] + '\n'+ item["text"]
stage_one_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_1_msg}
]
stage_two_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_2_msg}
]
stage_three_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_3_msg}
]
messages_dict['stage_1'].append(stage_one_message)
messages_dict['stage_2'].append(stage_two_message)
messages_dict['stage_3'].append(stage_three_message)
return messages_dict
def __iter__(self):
"""Iterate through batches."""
batch_size = self.processor.config.batch_size_triple
start_idx = self.processor.config.resume_from * batch_size
for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
batch_data = self.processed_data[i:i + batch_size]
# Prepare instructions
instructions = self.create_batch_instructions(batch_data)
# Extract batch information
batch_ids = [item["id"] for item in batch_data]
batch_metadata = [item["metadata"] for item in batch_data]
batch_texts = [item["text"] for item in batch_data]
yield instructions, batch_ids, batch_texts, batch_metadata
class OutputParser:
"""Parses model outputs and extracts structured data."""
def __init__(self):
pass
def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
"""Extract structured data from model outputs."""
results = []
for output in outputs:
parsed_data = json_repair.loads(output)
results.append(parsed_data)
return results
class KnowledgeGraphExtractor:
"""Main class for knowledge graph extraction pipeline."""
def __init__(self, model:LLMGenerator, config: ProcessingConfig):
        self.config = config
        self.model = model
        self.model_name = model.model_name
        self.parser = OutputParser()
def load_dataset(self) -> Any:
"""Load and prepare dataset."""
data_path = Path(self.config.data_directory)
all_files = os.listdir(data_path)
valid_files = [
filename for filename in all_files
if filename.startswith(self.config.filename_pattern) and
(filename.endswith(".json.gz") or filename.endswith(".json") or filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
]
print(f"Found data files: {valid_files}")
data_files = valid_files
dataset_config = {"train": data_files}
return load_dataset(self.config.data_directory, data_files=dataset_config["train"])
    def process_stage(self, instructions: Dict[str, str], stage=1) -> Tuple[List[str], List[List[Dict[str, Any]]]]:
        """Run one extraction stage (1: entity-relation, 2: event-entity, 3: event-relation)."""
outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
if self.config.record:
text_outputs = [output[0] for output in outputs]
else:
text_outputs = outputs
structured_data = self.parser.extract_structured_data(text_outputs)
return outputs, structured_data
def create_output_filename(self) -> str:
"""Create output filename with timestamp and shard info."""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
model_name_safe = self.config.model_path.replace("/", "_")
filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")
extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
os.makedirs(extraction_dir, exist_ok=True)
return os.path.join(extraction_dir, filename)
def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
"""Prepare result dictionary for a single sample."""
ids, original_texts, metadata = batch_data
(stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
if self.config.record:
stage1_outputs = [output[0] for output in stage1_results]
stage1_usage = [output[1] for output in stage1_results]
stage2_outputs = [output[0] for output in stage2_results]
stage2_usage = [output[1] for output in stage2_results]
stage3_outputs = [output[0] for output in stage3_results]
stage3_usage = [output[1] for output in stage3_results]
else:
stage1_outputs = stage1_results
stage2_outputs = stage2_results
stage3_outputs = stage3_results
result = {
"id": ids[index],
"metadata": metadata[index],
"original_text": original_texts[index],
"entity_relation_dict": entity_relations[index],
"event_entity_relation_dict": event_entities[index],
"event_relation_dict": event_relations[index],
"output_stage_one": stage1_outputs[index],
"output_stage_two": stage2_outputs[index],
"output_stage_three": stage3_outputs[index],
}
if self.config.record:
result['usage_stage_one'] = stage1_usage[index]
result['usage_stage_two'] = stage2_usage[index]
result['usage_stage_three'] = stage3_usage[index]
# Handle date serialization
if 'date_download' in result['metadata']:
result['metadata']['date_download'] = str(result['metadata']['date_download'])
return result
def debug_print_result(self, result: Dict[str, Any]):
"""Print result for debugging."""
for key, value in result.items():
print(f"{key}: {value}")
print("-" * 100)
def run_extraction(self):
"""Run the complete knowledge graph extraction pipeline."""
# Setup
os.makedirs(self.config.output_directory+'/kg_extraction', exist_ok=True)
dataset = self.load_dataset()
if self.config.debug_mode:
print("Debug mode: Processing only 20 samples")
# Create data processor and loader
processor = DatasetProcessor(self.config)
data_loader = CustomDataLoader(dataset["train"], processor)
output_file = self.create_output_filename()
print(f"Model: {self.config.model_path}")
batch_counter = 0
with torch.no_grad():
with open(output_file, "w") as output_stream:
for batch in data_loader:
batch_counter += 1
messages_dict, batch_ids, batch_texts, batch_metadata = batch
# Process all three stages
stage1_results = self.process_stage(messages_dict['stage_1'],1)
stage2_results = self.process_stage(messages_dict['stage_2'],2)
stage3_results = self.process_stage(messages_dict['stage_3'],3)
# Combine results
batch_data = (batch_ids, batch_texts, batch_metadata)
stage_outputs = (stage1_results, stage2_results, stage3_results)
# Write results
print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
for i in range(len(batch_ids)):
result = self.prepare_result_dict(batch_data, stage_outputs, i)
if self.config.debug_mode:
self.debug_print_result(result)
output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
output_stream.flush()
def convert_json_to_csv(self):
json2csv(
dataset = self.config.filename_pattern,
output_dir=f"{self.config.output_directory}/triples_csv",
data_dir=f"{self.config.output_directory}/kg_extraction"
)
csvs_to_temp_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
config = self.config
)
def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
generate_concept(
model=self.model,
input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
output_folder=f"{self.config.output_directory}/concepts",
output_file="concept.json",
logging_file=f"{self.config.output_directory}/concepts/logging.txt",
config=self.config,
batch_size=batch_size if batch_size else self.config.batch_size_concept,
shard=self.config.current_shard_concept,
num_shards=self.config.total_shards_concept,
record = self.config.record,
**kwargs
)
def create_concept_csv(self):
merge_csv_files(
output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
input_dir=f"{self.config.output_directory}/concepts",
)
all_concept_triples_csv_to_csv(
node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
)
def convert_to_graphml(self):
csvs_to_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
)
def add_numeric_id(self):
add_csv_columns(
node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
)
def compute_kg_embedding(self, encoder_model:BaseEmbeddingModel, batch_size: int = 2048):
encoder_model.compute_kg_embedding(
node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
            batch_size=batch_size
)
def create_faiss_index(self, index_type="HNSW,Flat"):
create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)
def parse_command_line_arguments() -> ProcessingConfig:
"""Parse command line arguments and return configuration."""
parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")
parser.add_argument("-m", "--model", type=str, required=True,
default="meta-llama/Meta-Llama-3-8B-Instruct",
help="Model path for knowledge extraction")
parser.add_argument("--data_dir", type=str, default="your_data_dir",
help="Directory containing input data")
parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
help="Filename pattern to match")
parser.add_argument("-b", "--batch_size", type=int, default=16,
help="Batch size for processing")
parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
help="Output directory for results")
parser.add_argument("--total_shards_triple", type=int, default=1,
help="Total number of data shards")
parser.add_argument("--shard", type=int, default=0,
help="Current shard index")
parser.add_argument("--bit8", action="store_true",
help="Use 8-bit quantization")
parser.add_argument("--debug", action="store_true",
help="Enable debug mode")
parser.add_argument("--resume", type=int, default=0,
help="Resume from specific batch")
args = parser.parse_args()
return ProcessingConfig(
model_path=args.model,
data_directory=args.data_dir,
filename_pattern=args.file_name,
        batch_size_triple=args.batch_size,
output_directory=args.output_dir,
total_shards_triple=args.total_shards_triple,
current_shard_triple=args.shard,
use_8bit=args.bit8,
debug_mode=args.debug,
resume_from=args.resume
)
def main():
    """Main entry point for the knowledge graph extraction pipeline."""
    config = parse_command_line_arguments()
    # KnowledgeGraphExtractor also requires an LLMGenerator built for config.model_path
    # as its first argument; construct one and pass it in alongside the config.
    extractor = KnowledgeGraphExtractor(config)
    extractor.run_extraction()
if __name__ == "__main__":
main()
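# End-to-end usage sketch (assumptions: `llm` is an LLMGenerator built for the chosen
# model and `encoder` is a BaseEmbeddingModel; both are constructed separately):
#
#   config = ProcessingConfig(model_path="meta-llama/Meta-Llama-3-8B-Instruct",
#                             data_directory="your_data_dir",
#                             filename_pattern="en_simple_wiki_v0",
#                             output_directory="./generation_result_debug")
#   extractor = KnowledgeGraphExtractor(llm, config)
#   extractor.run_extraction()               # stage 1-3 triple extraction to JSONL
#   extractor.convert_json_to_csv()          # JSON -> triples_csv + temporary GraphML pickle
#   extractor.generate_concept_csv_temp()    # conceptualize entities, events and relations
#   extractor.create_concept_csv()           # merge concept shards and build concept CSVs
#   extractor.convert_to_graphml()           # final GraphML with concept nodes/edges
#   extractor.add_numeric_id()               # numeric ids for Neo4j import
#   extractor.compute_kg_embedding(encoder)  # node/edge/text embeddings
#   extractor.create_faiss_index()           # FAISS indexes for retrieval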

@@ -0,0 +1,160 @@
import csv
from tqdm import tqdm
def check_created_csv_header(keyword, csv_dir):
keyword_to_paths ={
'cc_en':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv",
'concept_with_numeric_id': f"{csv_dir}/concept_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
},
'pes2o_abstract':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv",
},
'en_simple_wiki_v0':{
'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv",
},
}
for key, path in keyword_to_paths[keyword].items():
with open(path) as infile:
reader = csv.reader(infile)
header = next(reader)
print(f"Header of {key}: {header}")
            # print the first data row as a sample
for i, row in enumerate(reader):
if i < 1:
print(row)
else:
break
def add_csv_columns(node_csv, edge_csv, text_csv, node_with_numeric_id, edge_with_numeric_id, text_with_numeric_id):
with open(node_csv) as infile, open(node_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':LABEL')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
with open(edge_csv) as infile, open(edge_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':TYPE')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
with open(text_csv) as infile, open(text_with_numeric_id, 'w', newline='') as outfile:
reader = csv.reader(infile)
writer = csv.writer(outfile)
header = next(reader)
print(header)
label_index = header.index(':LABEL')
header.insert(label_index, 'numeric_id') # Add new column name
writer.writerow(header)
for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
writer.writerow(row)
# def add_csv_columns(keyword, csv_dir):
# keyword_to_paths ={
# 'cc_en':{
# 'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb.csv",
# 'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv"
# },
# 'pes2o_abstract':{
# 'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept.csv",
# 'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv"
# },
# 'en_simple_wiki_v0':{
# 'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept.csv",
# 'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json.csv",
# 'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
# 'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
# 'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv"
# },
# }
# # ouput node
# with open(keyword_to_paths[keyword]['node_csv']) as infile, open(keyword_to_paths[keyword]['node_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':LABEL'
# label_index = header.index(':LABEL')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
# # output edge (TYPE instead of LABEL for edge)
# with open(keyword_to_paths[keyword]['edge_csv']) as infile, open(keyword_to_paths[keyword]['edge_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':TYPE'
# label_index = header.index(':TYPE')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
# # output text
# with open(keyword_to_paths[keyword]['text_csv']) as infile, open(keyword_to_paths[keyword]['text_with_numeric_id'], 'w') as outfile:
# reader = csv.reader(infile)
# writer = csv.writer(outfile)
# # Read the header
# header = next(reader)
# print(header)
# # Insert 'numeric_id' before ':LABEL'
# label_index = header.index(':LABEL')
# header.insert(label_index, 'numeric_id') # Add new column name
# writer.writerow(header)
# # Process each row and add a numeric ID
# for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
# row.insert(label_index, row_number) # Add numeric ID before ':LABEL'
# writer.writerow(row)
if __name__ == "__main__":
    keyword = "en_simple_wiki_v0"
    csv_dir = "./import"  # Change this to your CSV directory
    # add_csv_columns now takes explicit input/output CSV paths (see the usage sketch
    # below); inspect the headers of the generated files here instead.
    check_created_csv_header(keyword, csv_dir)
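# Usage sketch for add_csv_columns; the paths follow the pipeline's add_numeric_id step,
# with <out> and the "en_simple_wiki_v0" pattern standing in for your own run:
#
#   add_csv_columns(
#       node_csv="<out>/triples_csv/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
#       edge_csv="<out>/concept_csv/triple_edges_en_simple_wiki_v0_from_json_with_concept.csv",
#       text_csv="<out>/triples_csv/text_nodes_en_simple_wiki_v0_from_json.csv",
#       node_with_numeric_id="<out>/triples_csv/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
#       edge_with_numeric_id="<out>/triples_csv/triple_edges_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
#       text_with_numeric_id="<out>/triples_csv/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv",
#   )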

@@ -0,0 +1,189 @@
import networkx as nx
import csv
import ast
import hashlib
import os
from atlas_rag.kg_construction.triple_config import ProcessingConfig
import pickle
def get_node_id(entity_name, entity_to_id={}):
"""Returns existing or creates new nX ID for an entity using a hash-based approach."""
if entity_name not in entity_to_id:
# Use a hash function to generate a unique ID
hash_object = hashlib.sha256(entity_name.encode('utf-8'))
hash_hex = hash_object.hexdigest() # Get the hexadecimal representation of the hash
# Use the first 8 characters of the hash as the ID (you can adjust the length as needed)
entity_to_id[entity_name] = hash_hex
return entity_to_id[entity_name]
def csvs_to_temp_graphml(triple_node_file, triple_edge_file, config:ProcessingConfig=None):
g = nx.DiGraph()
entity_to_id = {}
# Add triple nodes
with open(triple_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["name:ID"]
mapped_id = get_node_id(node_id, entity_to_id)
if mapped_id not in g.nodes:
g.add_node(mapped_id, id=node_id, type=row["type"])
# Add triple edges
with open(triple_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = get_node_id(row[":END_ID"], entity_to_id)
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# save graph to
output_name = f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl"
# check if output file directory exists
output_dir = os.path.dirname(output_name)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
# store the graph to a pickle file
with open(output_name, 'wb') as output_file:
pickle.dump(g, output_file)
def csvs_to_graphml(triple_node_file, text_node_file, concept_node_file,
triple_edge_file, text_edge_file, concept_edge_file,
output_file):
'''
Convert multiple CSV files into a single GraphML file.
Types of nodes to be added to the graph:
- Triple nodes: Nodes representing triples, with properties like subject, predicate, object.
- Text nodes: Nodes representing text, with properties like text content.
- Concept nodes: Nodes representing concepts, with properties like concept name and type.
Types of edges to be added to the graph:
- Triple edges: Edges representing relationships between triples, with properties like relation type.
- Text edges: Edges representing relationships between text and nodes, with properties like text type.
- Concept edges: Edges representing relationships between concepts and nodes, with properties like concept type.
DiGraph networkx attributes:
Node:
- type: Type of the node (e.g., entity, event, text, concept).
- file_id: List of text IDs the node is associated with.
- id: Node Name
Edge:
- relation: relation name
- file_id: List of text IDs the edge is associated with.
- type: Type of the edge (e.g., Source, Relation, Concept).
- synsets: List of synsets associated with the edge.
'''
g = nx.DiGraph()
entity_to_id = {}
# Add triple nodes
with open(triple_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["name:ID"]
mapped_id = get_node_id(node_id, entity_to_id)
# Check if node already exists to prevent duplicates
if mapped_id not in g.nodes:
g.add_node(mapped_id, id=node_id, type=row["type"])
# Add text nodes
with open(text_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["text_id:ID"]
# Check if node already exists to prevent duplicates
if node_id not in g.nodes:
g.add_node(node_id, file_id=node_id, id=row["original_text"], type="passage")
# Add concept nodes
with open(concept_node_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
node_id = row["concept_id:ID"]
# Check if node already exists to prevent duplicates
if node_id not in g.nodes:
g.add_node(node_id, file_id="concept_file", id=row["name"], type="concept")
    # File ids for triple and concept nodes are filled in below while adding the edges
# Add triple edges
with open(triple_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = get_node_id(row[":END_ID"], entity_to_id)
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# Add file_id to start and end nodes if they are triple or concept nodes
for node_id in [start_id, end_id]:
if g.nodes[node_id]['type'] in ['triple', 'concept'] and 'file_id' not in g.nodes[node_id]:
g.nodes[node_id]['file_id'] = row.get("file_id", "triple_file")
# Add concepts to the edge
concepts = ast.literal_eval(row["concepts"])
for concept in concepts:
if "concepts" not in g.edges[start_id, end_id]:
g.edges[start_id, end_id]['concepts'] = str(concept)
else:
# Avoid duplicate concepts by checking if concept is already in the list
current_concepts = g.edges[start_id, end_id]['concepts'].split(",")
if str(concept) not in current_concepts:
g.edges[start_id, end_id]['concepts'] += "," + str(concept)
# Add text edges
with open(text_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = row[":END_ID"]
# Check if edge already exists to prevent duplicates
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation="mention in", type=row[":TYPE"])
# Add file_id to start node if it is a triple or concept node
if 'file_id' in g.nodes[start_id]:
g.nodes[start_id]['file_id'] += "," + str(end_id)
else:
g.nodes[start_id]['file_id'] = str(end_id)
# Add concept edges between triple nodes and concept nodes
with open(concept_edge_file, 'r') as f:
reader = csv.DictReader(f)
for row in reader:
start_id = get_node_id(row[":START_ID"], entity_to_id)
end_id = row[":END_ID"] # end id is concept node id
if not g.has_edge(start_id, end_id):
g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
# Write to GraphML
# check if output file directory exists
output_dir = os.path.dirname(output_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
nx.write_graphml(g, output_file, infer_numeric_types=True)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Convert CSV files to GraphML format.')
parser.add_argument('--triple_node_file', type=str, required=True, help='Path to the triple node CSV file.')
parser.add_argument('--text_node_file', type=str, required=True, help='Path to the text node CSV file.')
parser.add_argument('--concept_node_file', type=str, required=True, help='Path to the concept node CSV file.')
parser.add_argument('--triple_edge_file', type=str, required=True, help='Path to the triple edge CSV file.')
parser.add_argument('--text_edge_file', type=str, required=True, help='Path to the text edge CSV file.')
parser.add_argument('--concept_edge_file', type=str, required=True, help='Path to the concept edge CSV file.')
parser.add_argument('--output_file', type=str, required=True, help='Path to the output GraphML file.')
args = parser.parse_args()
csvs_to_graphml(args.triple_node_file, args.text_node_file, args.concept_node_file,
args.triple_edge_file, args.text_edge_file, args.concept_edge_file,
args.output_file)
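# Hedged inspection sketch (not part of the original module): once csvs_to_graphml has written a
# GraphML file, the node schema described in its docstring can be checked with networkx. The
# function name and default path below are illustrative assumptions.
def _inspect_graphml(path="./import/kg_graphml/example_with_concept.graphml"):
    """Minimal sketch: load an exported graph and count nodes per type (entity, event, passage, concept)."""
    g = nx.read_graphml(path)
    type_counts = {}
    for _, attrs in g.nodes(data=True):
        node_type = attrs.get("type", "unknown")
        type_counts[node_type] = type_counts.get(node_type, 0) + 1
    return type_counts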

View File

@ -0,0 +1,70 @@
import pandas as pd
import numpy as np
from ast import literal_eval # Safer string-to-list conversion
import os
CHUNKSIZE = 100_000 # Adjust based on your RAM (100K rows per chunk)
EMBEDDING_COL = "embedding:STRING" # Column name with embeddings
# DIMENSION = 32 # Update with your embedding dimension
ENTITY_ONLY = True
def parse_embedding(embed_str):
"""Convert embedding string to numpy array"""
# Remove brackets and convert to list
return np.array(literal_eval(embed_str), dtype=np.float32)
# Create memory-mapped numpy file
def convert_csv_to_npy(csv_path, npy_path):
total_embeddings = 0
# check dir exist, if not then create it
os.makedirs(os.path.dirname(npy_path), exist_ok=True)
with open(npy_path, "wb") as f:
pass # Initialize empty file
# Process CSV in chunks
for chunk_idx, df_chunk in enumerate(
pd.read_csv(csv_path, chunksize=CHUNKSIZE, usecols=[EMBEDDING_COL])
):
# Parse embeddings
embeddings = np.stack(
df_chunk[EMBEDDING_COL].apply(parse_embedding).values
)
# Verify dimensions
# assert embeddings.shape[1] == DIMENSION, \
# f"Dimension mismatch at chunk {chunk_idx}"
total_embeddings += embeddings.shape[0]
# Append to .npy file
with open(npy_path, "ab") as f:
np.save(f, embeddings.astype(np.float32))
print(f"Processed chunk {chunk_idx} ({CHUNKSIZE*(chunk_idx+1)} rows)")
print(f"Total number of embeddings: {total_embeddings}")
print("Conversion complete!")
if __name__ == "__main__":
keyword = 'cc_en' # Change this to your desired keyword
csv_dir="./import" # Change this to your CSV directory
keyword_to_paths ={
'cc_en':{
'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_2.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_2.csv",
'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json_with_emb.csv",
},
'pes2o_abstract':{
'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json.csv",
'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
},
'en_simple_wiki_v0':{
'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
# 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
},
}
for key, path in keyword_to_paths[keyword].items():
npy_path = path.replace(".csv", ".npy")
convert_csv_to_npy(path, npy_path)
print(f"Converted {path} to {npy_path}")

View File

@ -0,0 +1,27 @@
import os
import glob
def merge_csv_files(output_file, input_dir):
"""
Merge all CSV files in the input directory into a single output file.
Args:
output_file (str): Path to the output CSV file.
input_dir (str): Directory containing the input CSV files.
"""
# Delete the output file if it exists
if os.path.exists(output_file):
os.remove(output_file)
# Write the header to the output file
with open(output_file, 'w') as outfile:
outfile.write("node,conceptualized_node,node_type\n")
# Append the contents of all CSV files in the input directory
for csv_file in glob.glob(os.path.join(input_dir, '*.csv')):
with open(csv_file, 'r') as infile:
            # Skip the header line (tolerate files that are empty)
            next(infile, None)
# Append the remaining lines to the output file
with open(output_file, 'a') as outfile:
outfile.writelines(infile)
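# Hedged usage sketch (not part of the original module); the paths below are illustrative
# assumptions about where the per-shard concept CSV files might live.
if __name__ == "__main__":
    merge_csv_files(
        output_file="./processed_data/merged_concepts.csv",  # single CSV with one header row
        input_dir="./processed_data/concept_csvs",           # folder of per-shard concept CSVs
    )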

View File

@ -0,0 +1,277 @@
from tqdm import tqdm
import argparse
import os
import csv
import json
import re
import hashlib
# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024) # 10 MB limit
# Function to compute a hash ID from text
def compute_hash_id(text):
# Use SHA-256 to generate a hash
hash_object = hashlib.sha256(text.encode('utf-8'))
return hash_object.hexdigest() # Return hash as a hex string
def clean_text(text):
    # Collapse newlines, tabs and other control characters to spaces, and replace ";" with "," to keep the CSV well-formed
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
new_text = re.sub(r'\s+', ' ', new_text).strip()
return new_text
def remove_NUL(text):
return text.replace("\x00", "")
def json2csv(dataset, data_dir, output_dir, test=False):
"""
Convert JSON files to CSV files for nodes, edges, and missing concepts.
Args:
dataset (str): Name of the dataset.
data_dir (str): Directory containing the JSON files.
output_dir (str): Directory to save the output CSV files.
test (bool): If True, run in test mode (process only 3 files).
"""
visited_nodes = set()
visited_hashes = set()
all_entities = set()
all_events = set()
all_relations = set()
file_dir_list = [f for f in os.listdir(data_dir) if dataset in f]
file_dir_list = sorted(file_dir_list)
if test:
file_dir_list = file_dir_list[:3]
print("Loading data from the json files")
print("Number of files: ", len(file_dir_list))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Define output file paths
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb.csv")
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb.csv")
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json.csv")
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json.csv")
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json.csv")
if test:
node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json_test.csv")
edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json_test.csv")
node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb_test.csv")
edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb_test.csv")
missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json_test.csv")
# Open CSV files for writing
with open(node_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node_text, \
open(edge_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge_text, \
open(node_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node, \
open(edge_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge:
csv_writer_node_text = csv.writer(csvfile_node_text)
csv_writer_edge_text = csv.writer(csvfile_edge_text)
writer_node = csv.writer(csvfile_node)
writer_edge = csv.writer(csvfile_edge)
# Write headers
csv_writer_node_text.writerow(["text_id:ID", "original_text", ":LABEL"])
csv_writer_edge_text.writerow([":START_ID", ":END_ID", ":TYPE"])
writer_node.writerow(["name:ID", "type", "concepts", "synsets", ":LABEL"])
writer_edge.writerow([":START_ID", ":END_ID", "relation", "concepts", "synsets", ":TYPE"])
# Process each file
for file_dir in tqdm(file_dir_list):
print("Processing file for file ids: ", file_dir)
with open(os.path.join(data_dir, file_dir), "r") as jsonfile:
for line in jsonfile:
data = json.loads(line.strip())
original_text = data["original_text"]
original_text = remove_NUL(original_text)
if "Here is the passage." in original_text:
original_text = original_text.split("Here is the passage.")[-1]
eot_token = "<|eot_id|>"
original_text = original_text.split(eot_token)[0]
text_hash_id = compute_hash_id(original_text)
# Write the original text as nodes
if text_hash_id not in visited_hashes:
visited_hashes.add(text_hash_id)
csv_writer_node_text.writerow([text_hash_id, original_text, "Text"])
file_id = str(data["id"])
entity_relation_dict = data["entity_relation_dict"]
event_entity_relation_dict = data["event_entity_relation_dict"]
event_relation_dict = data["event_relation_dict"]
# Process entity triples
entity_triples = []
for entity_triple in entity_relation_dict:
try:
assert isinstance(entity_triple["Head"], str)
assert isinstance(entity_triple["Relation"], str)
assert isinstance(entity_triple["Tail"], str)
head_entity = entity_triple["Head"]
relation = entity_triple["Relation"]
tail_entity = entity_triple["Tail"]
# Clean the text
head_entity = clean_text(head_entity)
relation = clean_text(relation)
tail_entity = clean_text(tail_entity)
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
continue
entity_triples.append((head_entity, relation, tail_entity))
                        except Exception:
print(f"Error processing entity triple: {entity_triple}")
continue
# Process event triples
event_triples = []
for event_triple in event_relation_dict:
try:
assert isinstance(event_triple["Head"], str)
assert isinstance(event_triple["Relation"], str)
assert isinstance(event_triple["Tail"], str)
head_event = event_triple["Head"]
relation = event_triple["Relation"]
tail_event = event_triple["Tail"]
# Clean the text
head_event = clean_text(head_event)
relation = clean_text(relation)
tail_event = clean_text(tail_event)
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
continue
event_triples.append((head_event, relation, tail_event))
                        except Exception:
print(f"Error processing event triple: {event_triple}")
# Process event-entity triples
event_entity_triples = []
for event_entity_participations in event_entity_relation_dict:
if "Event" not in event_entity_participations or "Entity" not in event_entity_participations:
continue
if not isinstance(event_entity_participations["Event"], str) or not isinstance(event_entity_participations["Entity"], list):
continue
for entity in event_entity_participations["Entity"]:
if not isinstance(entity, str):
continue
entity = clean_text(entity)
event = clean_text(event_entity_participations["Event"])
if event.isspace() or len(event) == 0 or entity.isspace() or len(entity) == 0:
continue
event_entity_triples.append((event, "is participated by", entity))
# Write nodes and edges to CSV files
for entity_triple in entity_triples:
head_entity, relation, tail_entity = entity_triple
if head_entity is None or tail_entity is None or relation is None:
continue
if head_entity.isspace() or tail_entity.isspace() or relation.isspace():
continue
if len(head_entity) == 0 or len(tail_entity) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_entity not in visited_nodes:
visited_nodes.add(head_entity)
all_entities.add(head_entity)
writer_node.writerow([head_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([head_entity, text_hash_id, "Source"])
if tail_entity not in visited_nodes:
visited_nodes.add(tail_entity)
all_entities.add(tail_entity)
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_entity, tail_entity, relation, [], [], "Relation"])
for event_triple in event_triples:
head_event, relation, tail_event = event_triple
if head_event is None or tail_event is None or relation is None:
continue
if head_event.isspace() or tail_event.isspace() or relation.isspace():
continue
if len(head_event) == 0 or len(tail_event) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_event not in visited_nodes:
visited_nodes.add(head_event)
all_events.add(head_event)
writer_node.writerow([head_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
if tail_event not in visited_nodes:
visited_nodes.add(tail_event)
all_events.add(tail_event)
writer_node.writerow([tail_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([tail_event, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_event, tail_event, relation, [], [], "Relation"])
for event_entity_triple in event_entity_triples:
head_event, relation, tail_entity = event_entity_triple
if head_event is None or tail_entity is None or relation is None:
continue
if head_event.isspace() or tail_entity.isspace() or relation.isspace():
continue
if len(head_event) == 0 or len(tail_entity) == 0 or len(relation) == 0:
continue
# Add nodes to files
if head_event not in visited_nodes:
visited_nodes.add(head_event)
all_events.add(head_event)
writer_node.writerow([head_event, "event", [], [], "Node"])
csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])
if tail_entity not in visited_nodes:
visited_nodes.add(tail_entity)
all_entities.add(tail_entity)
writer_node.writerow([tail_entity, "entity", [], [], "Node"])
csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])
all_relations.add(relation)
writer_edge.writerow([head_event, tail_entity, relation, [], [], "Relation"])
# Write missing concepts to CSV
with open(missing_concepts_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["Name", "Type"])
for entity in all_entities:
writer.writerow([entity, "Entity"])
for event in all_events:
writer.writerow([event, "Event"])
for relation in all_relations:
writer.writerow([relation, "Relation"])
print("Data to CSV completed successfully.")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True, help="[pes2o_abstract, en_simple_wiki_v0, cc_en]")
parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the graph raw JSON files")
parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the output CSV files")
parser.add_argument("--test", action="store_true", help="Test the script")
args = parser.parse_args()
json2csv(dataset=args.dataset, data_dir=args.data_dir, output_dir=args.output_dir, test=args.test)
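# Hedged illustration (not in the original script): json2csv expects every line of the input JSON
# files to be an object shaped roughly like the record below. All concrete values here are made up
# for illustration only.
EXAMPLE_INPUT_LINE = {
    "id": "0",
    "original_text": "Marie Curie discovered polonium.",
    "entity_relation_dict": [
        {"Head": "Marie Curie", "Relation": "discovered", "Tail": "polonium"}
    ],
    "event_relation_dict": [
        {"Head": "Marie Curie discovered polonium", "Relation": "happened before", "Tail": "Marie Curie won the Nobel Prize"}
    ],
    "event_entity_relation_dict": [
        {"Event": "Marie Curie discovered polonium", "Entity": ["Marie Curie", "polonium"]}
    ],
}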

View File

@ -0,0 +1,169 @@
import networkx as nx
import json
from tqdm import tqdm
import os
import hashlib
def get_node_id(entity_name, entity_to_id):
    """Return the existing hash-based 'n...' ID for an entity, creating and caching a new one if needed."""
    if entity_name not in entity_to_id:
        # Use a hash function to generate a deterministic ID
        hash_object = hashlib.md5(entity_name.encode())  # MD5 is sufficient here; no cryptographic strength is needed
        hash_hex = hash_object.hexdigest()  # Get the hexadecimal representation of the hash
        # Use the first 16 hex characters of the hash, prefixed with 'n', as the ID (adjust the length as needed)
        entity_to_id[entity_name] = f'n{hash_hex[:16]}'
    return entity_to_id[entity_name]
def clean_text(text):
    # Collapse newlines, tabs and other control characters to spaces, and replace ";" with ","
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
new_text = new_text.replace("\x00", "")
return new_text
def process_kg_data(input_passage_dir, input_triple_dir, output_dir, keyword):
# Get file names containing the keyword
file_names = [file for file in list(os.listdir(input_triple_dir)) if keyword in file]
print(f"Keyword: {keyword}")
print(f"Number of files: {len(file_names)}")
print(file_names)
passage_file_names = [file for file in list(os.listdir(input_passage_dir)) if keyword in file]
print(f'Passage file names: {passage_file_names}')
g = nx.DiGraph()
print("Graph created.")
entity_to_id = {}
# check if output directory exists, if not create it
if not os.path.exists(output_dir):
os.makedirs(output_dir)
print(f"Output directory {output_dir} created.")
output_path = f"{output_dir}/{keyword}_kg_from_corpus.graphml"
# Create the original_text to node_id dictionary and add passage node to the graph
with open(f"{input_passage_dir}/{passage_file_names[0]}") as f:
data = json.load(f)
for item in tqdm(data, desc="Processing passages"):
passage_id = item["id"]
passage_text = item["text"]
node_id = get_node_id(passage_text, entity_to_id)
if passage_text.isspace() or len(passage_text) == 0:
continue
# Add the passage node to the graph
g.add_node(node_id, type="passage", id=passage_text, file_id=passage_id)
for file_name in tqdm(file_names):
print(f"Processing {file_name}")
input_file_path = f"{input_triple_dir}/{file_name}"
with open(input_file_path) as f:
for line in tqdm(f):
data = json.loads(line)
metadata = data["metadata"]
file_id = data["id"]
original_text = data["original_text"]
entity_relation_dict = data["entity_relation_dict"]
event_entity_relation_dict = data["event_entity_relation_dict"]
event_relation_dict = data["event_relation_dict"]
# Process entity triples
entity_triples = []
for entity_triple in entity_relation_dict:
if not all(key in entity_triple for key in ["Head", "Relation", "Tail"]):
continue
head_entity = clean_text(entity_triple["Head"])
relation = clean_text(entity_triple["Relation"])
tail_entity = clean_text(entity_triple["Tail"])
if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
continue
entity_triples.append((head_entity, relation, tail_entity))
# Add entity triples to the graph
for triple in entity_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="entity", id=triple[0])
g.add_node(tail_id, type="entity", id=triple[2])
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(head_id, tail_id, relation=triple[1])
for node_id in [head_id, tail_id]:
if "file_id" not in g.nodes[node_id]:
g.nodes[node_id]["file_id"] = str(file_id)
else:
g.nodes[node_id]["file_id"] += "," + str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
# Process event triples
event_triples = []
for event_triple in event_relation_dict:
if not all(key in event_triple for key in ["Head", "Relation", "Tail"]):
continue
head_event = clean_text(event_triple["Head"])
relation = clean_text(event_triple["Relation"])
tail_event = clean_text(event_triple["Tail"])
if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
continue
event_triples.append((head_event, relation, tail_event))
# Add event triples to the graph
for triple in event_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="event", id=triple[0])
g.add_node(tail_id, type="event", id=triple[2])
g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
g.add_edge(head_id, tail_id, relation=triple[1])
for node_id in [head_id, tail_id]:
if "file_id" not in g.nodes[node_id]:
g.nodes[node_id]["file_id"] = str(file_id)
else:
g.nodes[node_id]["file_id"] += "," + str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
# Process event-entity triples
event_entity_triples = []
for event_entity_participations in event_entity_relation_dict:
if not all(key in event_entity_participations for key in ["Event", "Entity"]):
continue
event = clean_text(event_entity_participations["Event"])
if event.isspace() or len(event) == 0:
continue
for entity in event_entity_participations["Entity"]:
if not isinstance(entity, str) or entity.isspace() or len(entity) == 0:
continue
entity = clean_text(entity)
event_entity_triples.append((event, "is participated by", entity))
# Add event-entity triples to the graph
for triple in event_entity_triples:
head_id = get_node_id(triple[0], entity_to_id)
tail_id = get_node_id(triple[2], entity_to_id)
g.add_node(head_id, type="event", id=triple[0])
g.add_node(tail_id, type="entity", id=triple[2])
g.add_edge(head_id, tail_id, relation=triple[1])
                    for node_id in [head_id, tail_id]:
                        if "file_id" not in g.nodes[node_id]:
                            g.nodes[node_id]["file_id"] = str(file_id)
                        else:
                            g.nodes[node_id]["file_id"] += "," + str(file_id)
edge = g.edges[head_id, tail_id]
if "file_id" not in edge:
edge["file_id"] = str(file_id)
else:
edge["file_id"] += "," + str(file_id)
print(f"Number of nodes: {g.number_of_nodes()}")
print(f"Number of edges: {g.number_of_edges()}")
print(f"Graph density: {nx.density(g)}")
with open(output_path, 'wb') as f:
nx.write_graphml(g, f, infer_numeric_types=True)
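# Hedged usage sketch (not part of the original module): the directories and keyword below are
# illustrative assumptions about how process_kg_data might be invoked.
if __name__ == "__main__":
    process_kg_data(
        input_passage_dir="./data/passages",   # JSON files shaped like [{"id": ..., "text": ...}, ...]
        input_triple_dir="./data/triples",     # JSONL files with the extracted triple dictionaries
        output_dir="./data/kg_graphml",        # destination for <keyword>_kg_from_corpus.graphml
        keyword="en_simple_wiki_v0",
    )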

View File

@ -0,0 +1,63 @@
import argparse
import json
import os
import sys
from pathlib import Path
# Set up argument parser
parser = argparse.ArgumentParser(description="Convert all Markdown files in a folder to separate JSON files.")
parser.add_argument(
"--input", required=True, help="Path to the folder containing Markdown files"
)
parser.add_argument(
"--output", default=None, help="Output folder for JSON files (defaults to input folder if not specified)"
)
# Parse arguments
args = parser.parse_args()
# Resolve input folder path
input_folder = Path(args.input)
if not input_folder.is_dir():
print(f"Error: '{args.input}' is not a directory.", file=sys.stderr)
sys.exit(1)
# Set output folder (use input folder if not specified)
output_folder = Path(args.output) if args.output else input_folder
output_folder.mkdir(parents=True, exist_ok=True)
# Find all .md files in the input folder
markdown_files = [f for f in input_folder.iterdir() if f.suffix.lower() == ".md"]
if not markdown_files:
print(f"Error: No Markdown files found in '{args.input}'.", file=sys.stderr)
sys.exit(1)
# Process each Markdown file
for file in markdown_files:
try:
# Read the content of the file
with open(file, "r", encoding="utf-8") as f:
content = f.read()
# Create the JSON object
obj = {
"id": "1",
"text": content,
"metadata": {
"lang": "en"
}
}
# Create output JSON filename (e.g., file1.md -> file1.json)
output_file = output_folder / f"{file.stem}.json"
# Write JSON to file
with open(output_file, "w", encoding="utf-8") as f:
json.dump([obj], f, indent=4)
print(f"Successfully converted '{file}' to '{output_file}'")
except FileNotFoundError:
print(f"Error: File '{file}' not found.", file=sys.stderr)
except Exception as e:
print(f"Error processing file '{file}': {e}", file=sys.stderr)