first commit
0
AIEC-RAG/atlas_rag/kg_construction/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
282
AIEC-RAG/atlas_rag/kg_construction/concept_generation.py
Normal file
@@ -0,0 +1,282 @@
from tqdm import tqdm
import random
import logging
import csv
import os
import hashlib
import re
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.triple_config import ProcessingConfig
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import get_node_id
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import CONCEPT_INSTRUCTIONS
import pickle

# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024)  # 10 MB limit


def build_batch_data(sessions, batch_size):
    batched_sessions = []
    for i in range(0, len(sessions), batch_size):
        batched_sessions.append(sessions[i:i+batch_size])
    return batched_sessions


# Function to compute a hash ID from text
def compute_hash_id(text):
    # Use SHA-256 to generate a hash
    hash_object = hashlib.sha256(text.encode('utf-8'))
    return hash_object.hexdigest()  # Return hash as a hex string


def convert_attribute(value):
    """Convert attributes to GDS-compatible types."""
    if isinstance(value, list):
        return [str(v) for v in value]
    elif isinstance(value, (int, float)):
        return value
    else:
        return str(value)


def clean_text(text):
    # Replace control characters with spaces, remove NUL, and collapse whitespace.
    # "\x1b" (ESC) is used instead of the invalid escape "\e".
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\x1b", " ").replace(";", ",")
    new_text = new_text.replace("\x00", "")
    new_text = re.sub(r'\s+', ' ', new_text).strip()
    return new_text


def remove_NUL(text):
    return text.replace("\x00", "")


def build_batched_events(all_node_list, batch_size):
    """Node types are Entity, Event, or Relation."""
    event_nodes = [node[0] for node in all_node_list if node[1].lower() == "event"]
    batched_events = []

    for i in range(0, len(event_nodes), batch_size):
        batched_events.append(event_nodes[i:i+batch_size])

    return batched_events


def build_batched_entities(all_node_list, batch_size):
    entity_nodes = [node[0] for node in all_node_list if node[1].lower() == "entity"]
    batched_entities = []

    for i in range(0, len(entity_nodes), batch_size):
        batched_entities.append(entity_nodes[i:i+batch_size])

    return batched_entities


def build_batched_relations(all_node_list, batch_size):
    relations = [node[0] for node in all_node_list if node[1].lower() == "relation"]
    # relations = list(set(relations))
    batched_relations = []

    for i in range(0, len(relations), batch_size):
        batched_relations.append(relations[i:i+batch_size])

    return batched_relations


def batched_inference(model: LLMGenerator, inputs, record=False, **kwargs):
    responses = model.generate_response(inputs, return_text_only=not record, **kwargs)
    answers = []
    if record:
        text_responses = [response[0] for response in responses]
        usages = [response[1] for response in responses]
    else:
        text_responses = responses
    for i in range(len(text_responses)):
        answer = text_responses[i]
        answers.append([x.strip().lower() for x in answer.split(",")])

    if record:
        return answers, usages
    else:
        return answers


def load_data_with_shard(input_file, shard_idx, num_shards):
    with open(input_file, "r") as f:
        csv_reader = list(csv.reader(f))

    # data = csv_reader
    data = csv_reader[1:]
    # Randomly shuffle the data before splitting into shards
    random.shuffle(data)

    total_lines = len(data)
    lines_per_shard = (total_lines + num_shards - 1) // num_shards
    start_idx = shard_idx * lines_per_shard
    end_idx = min((shard_idx + 1) * lines_per_shard, total_lines)

    return data[start_idx:end_idx]


def generate_concept(model: LLMGenerator,
                     input_file='processed_data/triples_csv',
                     output_folder='processed_data/triples_conceptualized',
                     output_file='output.json',
                     logging_file='processed_data/logging.txt',
                     config: ProcessingConfig = None,
                     batch_size=32,
                     shard=0,
                     num_shards=1,
                     **kwargs):
    log_dir = os.path.dirname(logging_file)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # Create the log file if it doesn't exist
    if not os.path.exists(logging_file):
        open(logging_file, 'w').close()

    language = kwargs.get('language', 'en')
    record = kwargs.get('record', False)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(logging_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logging.getLogger().addHandler(file_handler)

    with open(f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl", "rb") as f:
        temp_kg = pickle.load(f)

    # read data
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    all_missing_nodes = load_data_with_shard(
        input_file,
        shard_idx=shard,
        num_shards=num_shards
    )

    batched_events = build_batched_events(all_missing_nodes, batch_size)
    batched_entities = build_batched_entities(all_missing_nodes, batch_size)
    batched_relations = build_batched_relations(all_missing_nodes, batch_size)

    all_batches = []
    all_batches.extend(('event', batch) for batch in batched_events)
    all_batches.extend(('entity', batch) for batch in batched_entities)
    all_batches.extend(('relation', batch) for batch in batched_relations)

    print("all_batches", len(all_batches))

    output_file = output_folder + f"/{output_file.rsplit('.', 1)[0]}_shard_{shard}.csv"
    with open(output_file, "w", newline='') as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(["node", "conceptualized_node", "node_type"])

        for batch_type, batch in tqdm(all_batches, desc="Shard_{}".format(shard)):
            # print("batch_type", batch_type)
            # print("batch", batch)
            replace_context_token = None
            if batch_type == 'event':
                template = CONCEPT_INSTRUCTIONS[language]['event']
                node_type = 'event'
                replace_token = '[EVENT]'
            elif batch_type == 'entity':
                template = CONCEPT_INSTRUCTIONS[language]['entity']
                node_type = 'entity'
                replace_token = '[ENTITY]'
                replace_context_token = '[CONTEXT]'
            elif batch_type == 'relation':
                template = CONCEPT_INSTRUCTIONS[language]['relation']
                node_type = 'relation'
                replace_token = '[RELATION]'

            inputs = []
            for node in batch:
                # sample neighbours of the given node and fill in the context token
                if replace_context_token:
                    node_id = get_node_id(node)
                    entity_predecessors = list(temp_kg.predecessors(node_id))
                    entity_successors = list(temp_kg.successors(node_id))

                    context = ""

                    if len(entity_predecessors) > 0:
                        random_two_neighbors = random.sample(entity_predecessors, min(1, len(entity_predecessors)))
                        context += ", ".join([f"{temp_kg.nodes[neighbor]['id']} {temp_kg[neighbor][node_id]['relation']}" for neighbor in random_two_neighbors])

                    if len(entity_successors) > 0:
                        random_two_neighbors = random.sample(entity_successors, min(1, len(entity_successors)))
                        context += ", ".join([f"{temp_kg[node_id][neighbor]['relation']} {temp_kg.nodes[neighbor]['id']}" for neighbor in random_two_neighbors])

                    prompt = template.replace(replace_token, node).replace(replace_context_token, context)
                else:
                    prompt = template.replace(replace_token, node)
                constructed_input = [
                    {"role": "system", "content": "You are a helpful AI assistant."},
                    {"role": "user", "content": f"{prompt}"},
                ]
                inputs.append(constructed_input)

            try:
                # print("inputs", inputs)
                if record:
                    # If recording, we get both answers and usage statistics
                    answers, usages = batched_inference(model, inputs, record=record, max_workers=config.max_workers)
                else:
                    answers = batched_inference(model, inputs, record=record, max_workers=config.max_workers)
                    usages = None
                # print("answers", answers)
            except Exception as e:
                logging.error(f"Error processing {batch_type} batch: {e}")
                raise e
                # try:
                #     answers = batched_inference(llm, sampling_params, inputs)
                # except Exception as e:
                #     logging.error(f"Error processing {batch_type} batch: {e}")
                #     continue

            for i, (node, answer) in enumerate(zip(batch, answers)):
                # print(node, answer, node_type)
                if usages is not None:
                    logging.info(f"Usage log: Node {node}, completion_usage: {usages[i]}")
                csv_writer.writerow([node, ", ".join(answer), node_type])
                file.flush()

    # count unique conceptualized nodes
    conceptualized_nodes = []

    conceptualized_events = []
    conceptualized_entities = []
    conceptualized_relations = []

    with open(output_file, "r") as file:
        reader = csv.reader(file)
        next(reader)
        for row in reader:
            conceptualized_nodes.extend(row[1].split(","))
            if row[2] == "event":
                conceptualized_events.extend(row[1].split(","))
            elif row[2] == "entity":
                conceptualized_entities.extend(row[1].split(","))
            elif row[2] == "relation":
                conceptualized_relations.extend(row[1].split(","))

    conceptualized_nodes = [x.strip() for x in conceptualized_nodes]
    conceptualized_events = [x.strip() for x in conceptualized_events]
    conceptualized_entities = [x.strip() for x in conceptualized_entities]
    conceptualized_relations = [x.strip() for x in conceptualized_relations]

    unique_conceptualized_nodes = list(set(conceptualized_nodes))
    unique_conceptualized_events = list(set(conceptualized_events))
    unique_conceptualized_entities = list(set(conceptualized_entities))
    unique_conceptualized_relations = list(set(conceptualized_relations))

    print(f"Number of unique conceptualized nodes: {len(unique_conceptualized_nodes)}")
    print(f"Number of unique conceptualized events: {len(unique_conceptualized_events)}")
    print(f"Number of unique conceptualized entities: {len(unique_conceptualized_entities)}")
    print(f"Number of unique conceptualized relations: {len(unique_conceptualized_relations)}")

    return
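A minimal invocation sketch for generate_concept, assuming an LLMGenerator has already been constructed (its constructor is not shown in this commit) and that the missing-concepts CSV and the *_without_concept.pkl graph from the earlier pipeline steps exist; the paths and the model placeholder below are assumptions, not part of this file:

# Sketch only: build_llm_generator() is a hypothetical placeholder for however the
# LLMGenerator is constructed in this project; file paths are assumptions.
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.triple_config import ProcessingConfig

config = ProcessingConfig(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
    data_directory="./data",
    filename_pattern="en_simple_wiki_v0",
    output_directory="./generation_result_debug",
)
model = build_llm_generator(config.model_path)  # hypothetical helper, not in this commit
generate_concept(
    model,
    input_file=f"{config.output_directory}/triples_csv/missing_concepts_en_simple_wiki_v0_from_json.csv",
    output_folder=f"{config.output_directory}/concepts",
    output_file="concept.json",
    logging_file=f"{config.output_directory}/concepts/logging.txt",
    config=config,
    batch_size=config.batch_size_concept,
    language="en",
    record=False,
)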
153
AIEC-RAG/atlas_rag/kg_construction/concept_to_csv.py
Normal file
@@ -0,0 +1,153 @@
import ast
import uuid
import csv
from tqdm import tqdm
import hashlib
import os


def generate_uuid():
    """Generate a random UUID."""
    return str(uuid.uuid4())


def parse_concepts(s):
    """Parse the concepts field and filter out empty values."""
    try:
        parsed = ast.literal_eval(s) if s and s != '[]' else []
        return [c.strip() for c in parsed if c.strip()]
    except (ValueError, SyntaxError):
        return []


# Function to compute a hash ID from text
def compute_hash_id(text):
    # Use SHA-256 to generate a hash
    text = text + '_concept'
    hash_object = hashlib.sha256(text.encode('utf-8'))
    return hash_object.hexdigest()  # Return hash as a hex string


def all_concept_triples_csv_to_csv(node_file, edge_file, concepts_file, output_node_file, output_edge_file, output_full_concept_triple_edges):
    # To output the concept nodes, concept edges, and the new full triple edges,
    # we first read the concept maps into memory, since they are usually not too large.
    # Then we iterate over the triple nodes to create concept edges,
    # and finally iterate over the triple edges to create the full triple edges.

    # Read missing concepts
    # relation_concepts_mapping = {}
    # all_missing_concepts = []

    # check if all output directories exist
    output_dir = os.path.dirname(output_node_file)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_dir = os.path.dirname(output_edge_file)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_dir = os.path.dirname(output_full_concept_triple_edges)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    node_to_concepts = {}
    relation_to_concepts = {}

    all_concepts = set()
    with open(concepts_file, 'r', encoding='utf-8') as f:
        reader = csv.DictReader(f)

        # Load the concept mappings
        print("Loading concepts...")
        for row in tqdm(reader):
            if row['node_type'] == 'relation':
                relation = row['node']
                concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]

                if relation not in relation_to_concepts:
                    relation_to_concepts[relation] = concepts
                else:
                    relation_to_concepts[relation].extend(concepts)
                    relation_to_concepts[relation] = list(set(relation_to_concepts[relation]))

            else:
                node = row['node']
                concepts = [c.strip() for c in row['conceptualized_node'].split(',') if c.strip()]

                if node not in node_to_concepts:
                    node_to_concepts[node] = concepts
                else:
                    node_to_concepts[node].extend(concepts)
                    node_to_concepts[node] = list(set(node_to_concepts[node]))

    print("Loading concepts done.")
    print(f"Relation to concepts: {len(relation_to_concepts)}")
    print(f"Node to concepts: {len(node_to_concepts)}")

    # Read triple nodes and write the concept edges file
    print("Processing triple nodes...")
    with open(node_file, 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        # name:ID,type,concepts,synsets,:LABEL
        header = next(reader)

        with open(output_edge_file, 'w', newline='', encoding='utf-8') as f_out:
            writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
            writer.writerow([':START_ID', ':END_ID', 'relation', ':TYPE'])

            for row in tqdm(reader):
                node_name = row[0]
                if node_name in node_to_concepts:
                    for concept in node_to_concepts[node_name]:
                        concept_id = compute_hash_id(concept)
                        writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
                        all_concepts.add(concept)

                for concept in parse_concepts(row[2]):
                    concept_id = compute_hash_id(concept)
                    writer.writerow([row[0], concept_id, 'has_concept', 'Concept'])
                    all_concepts.add(concept)

    # Write the concept nodes file
    print("Processing concept nodes...")
    with open(output_node_file, 'w', newline='', encoding='utf-8') as f_out:
        writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
        writer.writerow(['concept_id:ID', 'name', ':LABEL'])

        for concept in tqdm(all_concepts):
            concept_id = compute_hash_id(concept)
            writer.writerow([concept_id, concept, 'Concept'])

    # Read triple edges and write the full concept triple edges file
    print("Processing triple edges...")
    with open(edge_file, 'r', encoding='utf-8') as f:
        with open(output_full_concept_triple_edges, 'w', newline='', encoding='utf-8') as f_out:
            reader = csv.reader(f)
            writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)

            header = next(reader)
            writer.writerow([':START_ID', ':END_ID', 'relation', 'concepts', 'synsets', ':TYPE'])

            for row in tqdm(reader):
                src_id = row[0]
                end_id = row[1]
                relation = row[2]
                concepts = row[3]
                synsets = row[4]

                original_concepts = parse_concepts(concepts)

                if relation in relation_to_concepts:
                    for concept in relation_to_concepts[relation]:
                        if concept not in original_concepts:
                            original_concepts.append(concept)
                    original_concepts = list(set(original_concepts))

                writer.writerow([src_id, end_id, relation, original_concepts, synsets, 'Relation'])
    return
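A short usage sketch for all_concept_triples_csv_to_csv; the file names follow the pattern used by KnowledgeGraphExtractor.create_concept_csv elsewhere in this commit, and the output directory is an assumption:

from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv

out = "./generation_result_debug"  # assumed pipeline output directory
all_concept_triples_csv_to_csv(
    node_file=f"{out}/triples_csv/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
    edge_file=f"{out}/triples_csv/triple_edges_en_simple_wiki_v0_from_json_without_emb.csv",
    concepts_file=f"{out}/triples_csv/en_simple_wiki_v0_from_json_with_concept.csv",
    output_node_file=f"{out}/concept_csv/concept_nodes_en_simple_wiki_v0_from_json_with_concept.csv",
    output_edge_file=f"{out}/concept_csv/concept_edges_en_simple_wiki_v0_from_json_with_concept.csv",
    output_full_concept_triple_edges=f"{out}/concept_csv/triple_edges_en_simple_wiki_v0_from_json_with_concept.csv",
)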
267
AIEC-RAG/atlas_rag/kg_construction/neo4j/neo4j_api.py
Normal file
@@ -0,0 +1,267 @@
import time
import uvicorn
from fastapi import FastAPI, HTTPException, Response
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from logging import Logger
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever, BaseLargeKGEdgeRetriever
from atlas_rag.kg_construction.neo4j.utils import start_up_large_kg_index_graph
from atlas_rag.llm_generator import LLMGenerator
from neo4j import Driver
from dataclasses import dataclass
import traceback


@dataclass
class LargeKGConfig:
    largekg_retriever: BaseLargeKGRetriever | BaseLargeKGEdgeRetriever = None
    reader_llm_generator: LLMGenerator = None
    driver: Driver = None
    logger: Logger = None
    is_felm: bool = False
    is_mmlu: bool = False

    # Prompts that bypass retrieval; accessed below as LargeKGConfig.rag_exemption_list.
    rag_exemption_list = [
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to assess whether each text segment contains errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: 8923164*7236571?\nSegments: \n1. The product of 8923164 and 7236571 is: 6,461,216,222,844\n2. So, 8923164 multiplied by 7236571 is equal to 6,461,216,222,844.\n\nBelow are your outputs:\nAnswer: 1,2\nIt means segment 1,2 contain errors.""",
"""I will show you a question and a list of text segments. All the segments can be concatenated to form a complete answer to the question. Your task is to determine whether each text segment contains factual errors or not. \nPlease generate using the following format:\nAnswer: List the ids of the segments with errors (separated by commas). Please only output the ids, no more details. If all the segments are correct, output \"ALL_CORRECT\".\n\nHere is one example:\nQuestion: A company offers a 10% discount on all purchases over $100. A customer purchases three items, each costing $80. Does the customer qualify for the discount?\nSegments: \n1. To solve this problem, we need to use deductive reasoning. We know that the company offers a 10% discount on purchases over $100, so we need to calculate the total cost of the customer's purchase.\n2. The customer purchased three items, each costing $80, so the total cost of the purchase is: 3 x $80 = $200.\n3. Since the total cost of the purchase is greater than $100, the customer qualifies for the discount. \n4. To calculate the discounted price, we can multiply the total cost by 0.1 (which represents the 10% discount): $200 x 0.1 = $20.\n5. So the customer is eligible for a discount of $20, and the final cost of the purchase would be: $200 - $20 = $180.\n6. Therefore, the customer would pay a total of $216 for the three items with the discount applied.\n\nBelow are your outputs:\nAnswer: 2,3,4,5,6\nIt means segment 2,3,4,5,6 contains errors.""",
    ]

    mmlu_check_list = [
        """Given the following question and four candidate answers (A, B, C and D), choose the answer."""
    ]


app = FastAPI()


@app.on_event("startup")
async def startup():
    global large_kg_config
    start_up_large_kg_index_graph(large_kg_config.driver)


@app.on_event("shutdown")
async def shutdown():
    global large_kg_config
    print("Shutting down the model...")
    del large_kg_config


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "test"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["user", "system", "assistant"]
    content: str = None
    name: Optional[str] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    repetition_penalty: Optional[float] = 1.1
    retriever_config: Optional[dict] = {
        "topN": 5,
        "number_of_source_nodes_per_ner": 10,
        "sampling_area": 250,
        "Dmax": 2,
        "Wmax": 3
    }

    class Config:
        extra = "allow"


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]
    index: int


class ChatCompletionResponse(BaseModel):
    model: str
    id: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/health")
async def health_check():
    return Response(status_code=200)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global large_kg_config
    try:
        if len(request.messages) < 1 or request.messages[-1].role == "assistant":
            print(request)
            # raise HTTPException(status_code=400, detail="Invalid request")
        if large_kg_config.logger is not None:
            large_kg_config.logger.info(f"Request: {request}")
        gen_params = dict(
            messages=request.messages,
            temperature=0.8,
            top_p=request.top_p,
            max_tokens=request.max_tokens or 1024,
            echo=False,
            stream=False,
            repetition_penalty=request.repetition_penalty,
            tools=request.tools,
        )

        last_message = request.messages[-1]
        system_prompt = 'You are a helpful assistant.'
        question = last_message.content if last_message.role == 'user' else request.messages[-2].content

        is_exemption = any(exemption in question for exemption in LargeKGConfig.rag_exemption_list)
        is_mmlu = any(exemption in question for exemption in LargeKGConfig.mmlu_check_list)
        print(f"Is exemption: {is_exemption}, Is MMLU: {is_mmlu}")
        if is_mmlu:
            rag_text = question
        else:
            parts = question.rsplit("Question:", 1)
            rag_text = parts[-1] if len(parts) > 1 else None
        print(f"RAG text: {rag_text}")
        if not is_exemption:
            passages, passages_score = large_kg_config.largekg_retriever.retrieve_passages(rag_text)
            context = "No retrieved context, Please answer the question with your own knowledge." if not passages else "\n".join([f"Passage {i+1}: {text}" for i, text in enumerate(reversed(passages))])

        if is_mmlu:
            rag_chat_content = [
                {
                    "role": "system",
                    "content": f"{system_prompt}"
                },
                {
                    "role": "user",
                    "content": f"""Here is the context: {context} \n\n
                    If the context is not useful, you can answer the question with your own knowledge. \n {question}\nThink step by step. Your response should end with 'The answer is ([the_answer_letter])' where the [the_answer_letter] is one of A, B, C and D."""
                }
            ]
        elif not is_exemption:
            rag_chat_content = [
                {
                    "role": "system",
                    "content": f"{system_prompt}"
                },
                {
                    "role": "user",
                    "content": f"""{question} Reference doc: {context}"""
                }
            ]
        else:
            rag_chat_content = [
                {
                    "role": "system",
                    "content": f"{system_prompt}"
                },
                {
                    "role": "user",
                    "content": f"""{question} """
                }
            ]
        if large_kg_config.logger is not None:
            large_kg_config.logger.info(rag_chat_content)

        response = large_kg_config.reader_llm_generator.generate_response(
            batch_messages=rag_chat_content,
            max_new_tokens=gen_params["max_tokens"],
            temperature=gen_params["temperature"],
            frequency_penalty=1.1
        )
        message = ChatMessage(
            role="assistant",
            content=response.strip()
        )
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
            finish_reason="stop"
        )
        return ChatCompletionResponse(
            model=request.model,
            id="",
            object="chat.completion",
            choices=[choice_data]
        )
    except Exception as e:
        print("ERROR: ", e)
        print("Caught error")
        traceback.print_exc()
        system_prompt = 'You are a helpful assistant.'
        gen_params = dict(
            messages=request.messages,
            temperature=0.8,
            top_p=request.top_p,
            max_tokens=request.max_tokens or 1024,
            echo=False,
            stream=False,
            repetition_penalty=request.repetition_penalty,
            tools=request.tools,
        )
        last_message = request.messages[-1]
        system_prompt = 'You are a helpful assistant.'
        question = last_message.content if last_message.role == 'user' else request.messages[-2].content
        rag_chat_content = [
            {
                "role": "system",
                "content": f"{system_prompt}"
            },
            {
                "role": "user",
                "content": f"""{question} """
            }
        ]
        response = large_kg_config.reader_llm_generator.generate_response(
            batch_messages=rag_chat_content,
            max_new_tokens=gen_params["max_tokens"],
            temperature=gen_params["temperature"],
            frequency_penalty=1.1
        )
        message = ChatMessage(
            role="assistant",
            content=response.strip()
        )
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=message,
            finish_reason="stop"
        )
        return ChatCompletionResponse(
            model=request.model,
            id="",
            object="chat.completion",
            choices=[choice_data]
        )


def start_app(user_config: LargeKGConfig, host="0.0.0.0", port=10090, reload=False):
    """Start the FastAPI application."""
    global large_kg_config
    large_kg_config = user_config  # Use the passed config if provided

    uvicorn.run("atlas_rag.kg_construction.neo4j.neo4j_api:app", host=host, port=port, reload=reload)
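A minimal sketch of serving this endpoint, assuming a Neo4j instance is reachable and that a retriever and a reader LLMGenerator are constructed elsewhere; both builder placeholders and the credentials below are assumptions, not APIs defined in this file:

# Sketch only: build_retriever() and build_reader_llm() are hypothetical placeholders.
from neo4j import GraphDatabase
from atlas_rag.kg_construction.neo4j.neo4j_api import LargeKGConfig, start_app

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))  # assumed URI/credentials
config = LargeKGConfig(
    largekg_retriever=build_retriever(driver),  # hypothetical
    reader_llm_generator=build_reader_llm(),    # hypothetical
    driver=driver,
    logger=None,
)
start_app(config, host="0.0.0.0", port=10090)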
73
AIEC-RAG/atlas_rag/kg_construction/neo4j/utils.py
Normal file
@@ -0,0 +1,73 @@
import faiss
from neo4j import Driver
import time
from graphdatascience import GraphDataScience
from atlas_rag.retriever.lkg_retriever.base import BaseLargeKGRetriever


def build_projection_graph(driver: GraphDataScience):
    project_graph_1 = "largekgrag_graph"
    is_project_graph_1_exist = False
    # is_project_graph_2_exist = False
    result = driver.graph.list()
    for index, row in result.iterrows():
        if row['graphName'] == project_graph_1:
            is_project_graph_1_exist = True
        # if row['graphName'] == project_graph_2:
        #     is_project_graph_2_exist = True

    if not is_project_graph_1_exist:
        start_time = time.time()
        node_properties = ["Node"]
        relation_projection = ["Relation"]
        result = driver.graph.project(
            project_graph_1,
            node_properties,
            relation_projection
        )
        graph = driver.graph.get(project_graph_1)
        print(f"Projection graph {project_graph_1} created in {time.time() - start_time:.2f} seconds")


def build_neo4j_label_index(driver: Driver):
    with driver.session() as session:
        index_name = "NodeNumericIDIndex"
        # Check if the index already exists
        existing_indexes = session.run("SHOW INDEXES").data()
        index_exists = any(index['name'] == index_name for index in existing_indexes)
        # Create the index only if it does not exist yet
        if not index_exists:
            start_time = time.time()
            session.run(f"CREATE INDEX {index_name} FOR (n:Node) ON (n.numeric_id)")
            print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")

        index_name = "TextNumericIDIndex"
        index_exists = any(index['name'] == index_name for index in existing_indexes)
        if not index_exists:
            start_time = time.time()
            session.run(f"CREATE INDEX {index_name} FOR (t:Text) ON (t.numeric_id)")
            print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")

        index_name = "EntityEventEdgeNumericIDIndex"
        index_exists = any(index['name'] == index_name for index in existing_indexes)
        if not index_exists:
            start_time = time.time()
            session.run(f"CREATE INDEX {index_name} FOR ()-[r:Relation]-() ON (r.numeric_id)")
            print(f"Index {index_name} created in {time.time() - start_time:.2f} seconds")


def load_indexes(path_dict):
    for key, value in path_dict.items():
        if key == 'node':
            node_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
            print(f"Node index loaded from {value}")
        elif key == 'edge':
            edge_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
            print(f"Edge index loaded from {value}")
        elif key == 'text':
            passage_index = faiss.read_index(value, faiss.IO_FLAG_MMAP)
            print(f"Passage index loaded from {value}")
    # Note: all three keys ('node', 'edge', 'text') must be present in path_dict,
    # otherwise the corresponding variable below is unbound.
    return node_index, edge_index, passage_index


def start_up_large_kg_index_graph(neo4j_driver: Driver) -> BaseLargeKGRetriever:
    gds_driver = GraphDataScience(neo4j_driver)
    # build label index and projection graph
    build_neo4j_label_index(neo4j_driver)
    build_projection_graph(gds_driver)
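A small usage sketch for load_indexes; all three keys must be present, and the .index file paths are assumptions:

from atlas_rag.kg_construction.neo4j.utils import load_indexes

node_index, edge_index, passage_index = load_indexes({
    'node': "./import/node.index",  # assumed path to a FAISS index file
    'edge': "./import/edge.index",  # assumed path
    'text': "./import/text.index",  # assumed path
})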
22
AIEC-RAG/atlas_rag/kg_construction/triple_config.py
Normal file
@@ -0,0 +1,22 @@
from dataclasses import dataclass


@dataclass
class ProcessingConfig:
    """Configuration for text processing pipeline."""
    model_path: str
    data_directory: str
    filename_pattern: str
    batch_size_triple: int = 16
    batch_size_concept: int = 64
    output_directory: str = "./generation_result_debug"
    total_shards_triple: int = 1
    current_shard_triple: int = 0
    total_shards_concept: int = 1
    current_shard_concept: int = 0
    use_8bit: bool = False
    debug_mode: bool = False
    resume_from: int = 0
    record: bool = False
    max_new_tokens: int = 8192
    max_workers: int = 8
    remove_doc_spaces: bool = False
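Only the first three fields are required; a typical single-shard debug configuration might look like the following (values are illustrative):

config = ProcessingConfig(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model id
    data_directory="./data",
    filename_pattern="en_simple_wiki_v0",
    batch_size_triple=16,
    batch_size_concept=64,
    output_directory="./generation_result_debug",
    debug_mode=True,
)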
497
AIEC-RAG/atlas_rag/kg_construction/triple_extraction.py
Normal file
@@ -0,0 +1,497 @@
#!/usr/bin/env python3
"""
Knowledge Graph Extraction Pipeline
Extracts entities, relations, and events from text data using transformer models.
"""
import re
import json
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path
import torch
from datasets import load_dataset
from tqdm import tqdm
import json_repair
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
from atlas_rag.kg_construction.triple_config import ProcessingConfig

# Constants
TOKEN_LIMIT = 1024
INSTRUCTION_TOKEN_ESTIMATE = 200
CHAR_TO_TOKEN_RATIO = 3.5


class TextChunker:
    """Handles text chunking based on token limits."""

    def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
        self.max_tokens = max_tokens
        self.instruction_tokens = instruction_tokens
        self.char_ratio = CHAR_TO_TOKEN_RATIO

    def calculate_max_chars(self) -> int:
        """Calculate the maximum number of characters per chunk."""
        available_tokens = self.max_tokens - self.instruction_tokens
        return int(available_tokens * self.char_ratio)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks that fit within the token limit."""
        max_chars = self.calculate_max_chars()
        chunks = []

        while len(text) > max_chars:
            chunks.append(text[:max_chars])
            text = text[max_chars:]

        if text:  # Add remaining text
            chunks.append(text)

        return chunks


class DatasetProcessor:
    """Processes and prepares the dataset for knowledge graph extraction."""

    def __init__(self, config: ProcessingConfig):
        self.config = config
        self.chunker = TextChunker()

    def filter_language_content(self, sample: Dict[str, Any]) -> bool:
        """Check whether the sample's language is supported."""
        metadata = sample.get("metadata", {})
        language = metadata.get("lang", "en")  # Default to English if not specified
        supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
        return language in supported_languages

    def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Create chunks from a single sample."""
        original_text = sample.get("text", "")
        if self.config.remove_doc_spaces:
            original_text = re.sub(r'\s+', ' ', original_text).strip()
        text_chunks = self.chunker.split_text(original_text)
        chunks = []

        for chunk_idx, chunk_text in enumerate(text_chunks):
            chunk_data = {
                "id": sample["id"],
                "text": chunk_text,
                "chunk_id": chunk_idx,
                "metadata": sample["metadata"]
            }
            chunks.append(chunk_data)

        return chunks

    def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
        """Process the raw dataset into chunks, with generalized shard slicing."""
        processed_samples = []
        total_texts = len(raw_dataset)

        # Handle edge cases
        if total_texts == 0:
            print(f"No texts found for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
            return processed_samples

        # Calculate base size and remainder for fair distribution
        base_texts_per_shard = total_texts // self.config.total_shards_triple
        remainder = total_texts % self.config.total_shards_triple

        # Calculate start index
        if self.config.current_shard_triple < remainder:
            start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
        else:
            start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard

        # Calculate end index
        if self.config.current_shard_triple < remainder:
            end_idx = start_idx + (base_texts_per_shard + 1)
        else:
            end_idx = start_idx + base_texts_per_shard

        # Ensure indices are within bounds
        start_idx = min(start_idx, total_texts)
        end_idx = min(end_idx, total_texts)

        print(f"Processing shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple} "
              f"(texts {start_idx}-{end_idx-1} of {total_texts}, {end_idx - start_idx} documents)")

        # Process documents in the assigned shard
        for idx in range(start_idx, end_idx):
            sample = raw_dataset[idx]

            # Filter by language
            if not self.filter_language_content(sample):
                print(f"Unsupported language in sample {idx}, skipping.")
                continue

            # Create chunks
            chunks = self.create_sample_chunks(sample)
            processed_samples.extend(chunks)

            # Debug mode early termination
            if self.config.debug_mode and len(processed_samples) >= 20:
                print("Debug mode: Stopping at 20 chunks")
                break

        print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
        return processed_samples


class CustomDataLoader:
    """Custom data loader for knowledge graph extraction."""

    def __init__(self, dataset, processor: DatasetProcessor):
        self.raw_dataset = dataset
        self.processor = processor
        self.processed_data = processor.prepare_dataset(dataset)
        self.stage_to_prompt_dict = {
            "stage_1": "entity_relation",
            "stage_2": "event_entity",
            "stage_3": "event_relation"
        }

    def __len__(self) -> int:
        return len(self.processed_data)

    def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> Dict[str, List[List[Dict[str, str]]]]:
        messages_dict = {
            'stage_1': [],
            'stage_2': [],
            'stage_3': []
        }
        for item in batch_data:
            # get the language, falling back to the English instructions
            language = item.get("metadata", {}).get("lang", "en")
            instructions = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])
            system_msg = instructions['system']
            stage_1_msg = instructions['entity_relation'] + instructions['passage_start'] + '\n' + item["text"]
            stage_2_msg = instructions['event_entity'] + instructions['passage_start'] + '\n' + item["text"]
            stage_3_msg = instructions['event_relation'] + instructions['passage_start'] + '\n' + item["text"]
            stage_one_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_1_msg}
            ]
            stage_two_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_2_msg}
            ]
            stage_three_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_3_msg}
            ]
            messages_dict['stage_1'].append(stage_one_message)
            messages_dict['stage_2'].append(stage_two_message)
            messages_dict['stage_3'].append(stage_three_message)

        return messages_dict

    def __iter__(self):
        """Iterate through batches."""
        batch_size = self.processor.config.batch_size_triple
        start_idx = self.processor.config.resume_from * batch_size

        for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
            batch_data = self.processed_data[i:i + batch_size]

            # Prepare instructions
            instructions = self.create_batch_instructions(batch_data)

            # Extract batch information
            batch_ids = [item["id"] for item in batch_data]
            batch_metadata = [item["metadata"] for item in batch_data]
            batch_texts = [item["text"] for item in batch_data]

            yield instructions, batch_ids, batch_texts, batch_metadata


class OutputParser:
    """Parses model outputs and extracts structured data."""

    def __init__(self):
        pass

    def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
        """Extract structured data from model outputs."""
        results = []
        for output in outputs:
            parsed_data = json_repair.loads(output)
            results.append(parsed_data)

        return results


class KnowledgeGraphExtractor:
    """Main class for the knowledge graph extraction pipeline."""

    def __init__(self, model: LLMGenerator, config: ProcessingConfig):
        self.config = config
        self.model = model
        self.model_name = model.model_name
        self.parser = OutputParser()

    def load_dataset(self) -> Any:
        """Load and prepare the dataset."""
        data_path = Path(self.config.data_directory)
        all_files = os.listdir(data_path)

        valid_files = [
            filename for filename in all_files
            if filename.startswith(self.config.filename_pattern) and
            (filename.endswith(".json.gz") or filename.endswith(".json") or filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
        ]

        print(f"Found data files: {valid_files}")
        data_files = valid_files
        dataset_config = {"train": data_files}
        return load_dataset(self.config.data_directory, data_files=dataset_config["train"])

    def process_stage(self, instructions: List[List[Dict[str, str]]], stage=1) -> Tuple[List[str], List[List[Dict[str, Any]]]]:
        """Run one extraction stage (1: entity-relation, 2: event-entity, 3: event-relation)."""
        outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
        if self.config.record:
            text_outputs = [output[0] for output in outputs]
        else:
            text_outputs = outputs
        structured_data = self.parser.extract_structured_data(text_outputs)
        return outputs, structured_data

    def create_output_filename(self) -> str:
        """Create the output filename with timestamp and shard info."""
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        model_name_safe = self.config.model_path.replace("/", "_")

        filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
                    f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")

        extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
        os.makedirs(extraction_dir, exist_ok=True)

        return os.path.join(extraction_dir, filename)

    def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
        """Prepare the result dictionary for a single sample."""
        ids, original_texts, metadata = batch_data
        (stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
        if self.config.record:
            stage1_outputs = [output[0] for output in stage1_results]
            stage1_usage = [output[1] for output in stage1_results]
            stage2_outputs = [output[0] for output in stage2_results]
            stage2_usage = [output[1] for output in stage2_results]
            stage3_outputs = [output[0] for output in stage3_results]
            stage3_usage = [output[1] for output in stage3_results]
        else:
            stage1_outputs = stage1_results
            stage2_outputs = stage2_results
            stage3_outputs = stage3_results
        result = {
            "id": ids[index],
            "metadata": metadata[index],
            "original_text": original_texts[index],
            "entity_relation_dict": entity_relations[index],
            "event_entity_relation_dict": event_entities[index],
            "event_relation_dict": event_relations[index],
            "output_stage_one": stage1_outputs[index],
            "output_stage_two": stage2_outputs[index],
            "output_stage_three": stage3_outputs[index],
        }
        if self.config.record:
            result['usage_stage_one'] = stage1_usage[index]
            result['usage_stage_two'] = stage2_usage[index]
            result['usage_stage_three'] = stage3_usage[index]

        # Handle date serialization
        if 'date_download' in result['metadata']:
            result['metadata']['date_download'] = str(result['metadata']['date_download'])

        return result

    def debug_print_result(self, result: Dict[str, Any]):
        """Print a result for debugging."""
        for key, value in result.items():
            print(f"{key}: {value}")
        print("-" * 100)

    def run_extraction(self):
        """Run the complete knowledge graph extraction pipeline."""
        # Setup
        os.makedirs(self.config.output_directory + '/kg_extraction', exist_ok=True)
        dataset = self.load_dataset()

        if self.config.debug_mode:
            print("Debug mode: Processing only 20 samples")

        # Create data processor and loader
        processor = DatasetProcessor(self.config)
        data_loader = CustomDataLoader(dataset["train"], processor)

        output_file = self.create_output_filename()
        print(f"Model: {self.config.model_path}")

        batch_counter = 0

        with torch.no_grad():
            with open(output_file, "w") as output_stream:
                for batch in data_loader:
                    batch_counter += 1
                    messages_dict, batch_ids, batch_texts, batch_metadata = batch

                    # Process all three stages
                    stage1_results = self.process_stage(messages_dict['stage_1'], 1)
                    stage2_results = self.process_stage(messages_dict['stage_2'], 2)
                    stage3_results = self.process_stage(messages_dict['stage_3'], 3)

                    # Combine results
                    batch_data = (batch_ids, batch_texts, batch_metadata)
                    stage_outputs = (stage1_results, stage2_results, stage3_results)

                    # Write results
                    print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
                    for i in range(len(batch_ids)):
                        result = self.prepare_result_dict(batch_data, stage_outputs, i)

                        if self.config.debug_mode:
                            self.debug_print_result(result)

                        output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
                        output_stream.flush()

    def convert_json_to_csv(self):
        json2csv(
            dataset=self.config.filename_pattern,
            output_dir=f"{self.config.output_directory}/triples_csv",
            data_dir=f"{self.config.output_directory}/kg_extraction"
        )
        csvs_to_temp_graphml(
            triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
            config=self.config
        )

    def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
        generate_concept(
            model=self.model,
            input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
            output_folder=f"{self.config.output_directory}/concepts",
            output_file="concept.json",
            logging_file=f"{self.config.output_directory}/concepts/logging.txt",
            config=self.config,
            batch_size=batch_size if batch_size else self.config.batch_size_concept,
            shard=self.config.current_shard_concept,
            num_shards=self.config.total_shards_concept,
            record=self.config.record,
            **kwargs
        )

    def create_concept_csv(self):
        merge_csv_files(
            output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
            input_dir=f"{self.config.output_directory}/concepts",
        )
        all_concept_triples_csv_to_csv(
            node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
            edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
            concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
            output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
            output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
            output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
        )

    def convert_to_graphml(self):
        csvs_to_graphml(
            triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
            triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
            concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
        )

    def add_numeric_id(self):
        add_csv_columns(
            node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
            edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
            text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
        )

    def compute_kg_embedding(self, encoder_model: BaseEmbeddingModel, batch_size: int = 2048):
        encoder_model.compute_kg_embedding(
            node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
            edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
            text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
            batch_size=batch_size
        )

    def create_faiss_index(self, index_type="HNSW,Flat"):
        create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)


def parse_command_line_arguments() -> ProcessingConfig:
    """Parse command line arguments and return the configuration."""
    parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")

    parser.add_argument("-m", "--model", type=str, required=True,
                        default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help="Model path for knowledge extraction")
    parser.add_argument("--data_dir", type=str, default="your_data_dir",
                        help="Directory containing input data")
    parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
                        help="Filename pattern to match")
    parser.add_argument("-b", "--batch_size", type=int, default=16,
                        help="Batch size for processing")
    parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
                        help="Output directory for results")
    parser.add_argument("--total_shards_triple", type=int, default=1,
                        help="Total number of data shards")
    parser.add_argument("--shard", type=int, default=0,
                        help="Current shard index")
    parser.add_argument("--bit8", action="store_true",
                        help="Use 8-bit quantization")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")
    parser.add_argument("--resume", type=int, default=0,
                        help="Resume from a specific batch")

    args = parser.parse_args()

    return ProcessingConfig(
        model_path=args.model,
        data_directory=args.data_dir,
        filename_pattern=args.file_name,
        batch_size_triple=args.batch_size,
        output_directory=args.output_dir,
        total_shards_triple=args.total_shards_triple,
        current_shard_triple=args.shard,
        use_8bit=args.bit8,
        debug_mode=args.debug,
        resume_from=args.resume
    )


def main():
    """Main entry point for the knowledge graph extraction pipeline."""
    config = parse_command_line_arguments()
    # NOTE: KnowledgeGraphExtractor also expects an LLMGenerator instance as its
    # first argument; a model built from config.model_path must be supplied here.
    extractor = KnowledgeGraphExtractor(config)
    extractor.run_extraction()


if __name__ == "__main__":
    main()
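The intended end-to-end order of the pipeline, as suggested by the file dependencies between the methods above, is sketched below; the LLMGenerator and embedding model builders are hypothetical placeholders (not defined in this commit), and the final embedding/index steps are only needed when vector search is wanted:

# Sketch only: build_llm_generator() is a hypothetical placeholder; other names
# come from this module (triple_extraction.py).
config = parse_command_line_arguments()
model = build_llm_generator(config.model_path)  # hypothetical
extractor = KnowledgeGraphExtractor(model, config)

extractor.run_extraction()             # stages 1-3 of triple extraction -> kg_extraction/*.json
extractor.convert_json_to_csv()        # JSON -> triples_csv + temporary graph pickle
extractor.generate_concept_csv_temp()  # conceptualize missing nodes and relations
extractor.create_concept_csv()         # merge concept shards and build concept_csv files
extractor.convert_to_graphml()         # final kg_graphml/<pattern>_graph.graphml
extractor.add_numeric_id()             # numeric ids for Neo4j import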
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,160 @@
import csv
from tqdm import tqdm


def check_created_csv_header(keyword, csv_dir):
    keyword_to_paths = {
        'cc_en': {
            'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
            'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
            'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv",
            'concept_with_numeric_id': f"{csv_dir}/concept_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
        },
        'pes2o_abstract': {
            'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
            'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
            'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv",
        },
        'en_simple_wiki_v0': {
            'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
            'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
            'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv",
        },
    }
    for key, path in keyword_to_paths[keyword].items():
        with open(path) as infile:
            reader = csv.reader(infile)
            header = next(reader)
            print(f"Header of {key}: {header}")

            # print the first row as a sample
            for i, row in enumerate(reader):
                if i < 1:
                    print(row)
                else:
                    break


def add_csv_columns(node_csv, edge_csv, text_csv, node_with_numeric_id, edge_with_numeric_id, text_with_numeric_id):
    with open(node_csv) as infile, open(node_with_numeric_id, 'w', newline='') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile)
        header = next(reader)
        print(header)
        label_index = header.index(':LABEL')
        header.insert(label_index, 'numeric_id')  # Add new column name
        writer.writerow(header)
        for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
            row.insert(label_index, row_number)  # Add numeric ID before ':LABEL'
            writer.writerow(row)
    with open(edge_csv) as infile, open(edge_with_numeric_id, 'w', newline='') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile)
        header = next(reader)
        print(header)
        label_index = header.index(':TYPE')
        header.insert(label_index, 'numeric_id')  # Add new column name
        writer.writerow(header)
        for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
            row.insert(label_index, row_number)  # Add numeric ID before ':TYPE'
            writer.writerow(row)
    with open(text_csv) as infile, open(text_with_numeric_id, 'w', newline='') as outfile:
        reader = csv.reader(infile)
        writer = csv.writer(outfile)
        header = next(reader)
        print(header)
        label_index = header.index(':LABEL')
        header.insert(label_index, 'numeric_id')  # Add new column name
        writer.writerow(header)
        for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
            row.insert(label_index, row_number)  # Add numeric ID before ':LABEL'
            writer.writerow(row)


# def add_csv_columns(keyword, csv_dir):
#     keyword_to_paths = {
#         'cc_en': {
#             'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb.csv",
#             'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb.csv",
#             'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json.csv",
#
#             'node_with_numeric_id': f"{csv_dir}/triple_nodes_cc_en_from_json_without_emb_with_numeric_id.csv",
#             'edge_with_numeric_id': f"{csv_dir}/triple_edges_cc_en_from_json_without_emb_with_numeric_id.csv",
#             'text_with_numeric_id': f"{csv_dir}/text_nodes_cc_en_from_json_with_numeric_id.csv"
#         },
#         'pes2o_abstract': {
#             'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb.csv",
#             'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept.csv",
#             'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json.csv",
#
#             'node_with_numeric_id': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json_without_emb_with_numeric_id.csv",
#             'edge_with_numeric_id': f"{csv_dir}/triple_edges_pes2o_abstract_from_json_without_emb_full_concept_with_numeric_id.csv",
#             'text_with_numeric_id': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_numeric_id.csv"
#         },
#         'en_simple_wiki_v0': {
#             'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb.csv",
#             'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept.csv",
#             'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json.csv",
#
#             'node_with_numeric_id': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json_without_emb_with_numeric_id.csv",
#             'edge_with_numeric_id': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json_without_emb_full_concept_with_numeric_id.csv",
#             'text_with_numeric_id': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_numeric_id.csv"
#         },
#     }
#     # output node
#     with open(keyword_to_paths[keyword]['node_csv']) as infile, open(keyword_to_paths[keyword]['node_with_numeric_id'], 'w') as outfile:
#         reader = csv.reader(infile)
#         writer = csv.writer(outfile)
#
#         # Read the header
#         header = next(reader)
#         print(header)
#         # Insert 'numeric_id' before ':LABEL'
#         label_index = header.index(':LABEL')
#         header.insert(label_index, 'numeric_id')  # Add new column name
#         writer.writerow(header)
#
#         # Process each row and add a numeric ID
#         for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
#             row.insert(label_index, row_number)  # Add numeric ID before ':LABEL'
#             writer.writerow(row)
#
#     # output edge (TYPE instead of LABEL for edge)
#     with open(keyword_to_paths[keyword]['edge_csv']) as infile, open(keyword_to_paths[keyword]['edge_with_numeric_id'], 'w') as outfile:
#         reader = csv.reader(infile)
#         writer = csv.writer(outfile)
#
#         # Read the header
#         header = next(reader)
#         print(header)
#         # Insert 'numeric_id' before ':TYPE'
#         label_index = header.index(':TYPE')
#         header.insert(label_index, 'numeric_id')  # Add new column name
#         writer.writerow(header)
#
#         # Process each row and add a numeric ID
#         for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
#             row.insert(label_index, row_number)  # Add numeric ID before ':TYPE'
#             writer.writerow(row)
#
#     # output text
#     with open(keyword_to_paths[keyword]['text_csv']) as infile, open(keyword_to_paths[keyword]['text_with_numeric_id'], 'w') as outfile:
#         reader = csv.reader(infile)
#         writer = csv.writer(outfile)
#
#         # Read the header
#         header = next(reader)
#         print(header)
#         # Insert 'numeric_id' before ':LABEL'
#         label_index = header.index(':LABEL')
#         header.insert(label_index, 'numeric_id')  # Add new column name
#         writer.writerow(header)
#
#         # Process each row and add a numeric ID
#         for row_number, row in tqdm(enumerate(reader), desc="Adding numeric ID"):
#             row.insert(label_index, row_number)  # Add numeric ID before ':LABEL'
#             writer.writerow(row)


if __name__ == "__main__":
    keyword = "en_simple_wiki_v0"
    csv_dir = "./import"  # Change this to your CSV directory
    # The refactored add_csv_columns takes explicit file paths, so build them for the chosen keyword.
    add_csv_columns(
        node_csv=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb.csv",
        edge_csv=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept.csv",
        text_csv=f"{csv_dir}/text_nodes_{keyword}_from_json.csv",
        node_with_numeric_id=f"{csv_dir}/triple_nodes_{keyword}_from_json_without_emb_with_numeric_id.csv",
        edge_with_numeric_id=f"{csv_dir}/triple_edges_{keyword}_from_json_without_emb_full_concept_with_numeric_id.csv",
        text_with_numeric_id=f"{csv_dir}/text_nodes_{keyword}_from_json_with_numeric_id.csv",
    )
    # check_created_csv_header(keyword, csv_dir)
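To make the effect of add_csv_columns concrete, here is a small sketch of the header rewrite it performs; the node header shown is the one written by json2csv later in this commit, and the example row values are illustrative only:

# Node CSV header before:  ["name:ID", "type", "concepts", "synsets", ":LABEL"]
# Node CSV header after:   ["name:ID", "type", "concepts", "synsets", "numeric_id", ":LABEL"]
# Each data row gains its 0-based row number just before the ":LABEL" (or ":TYPE") column, e.g.
# ["Alice", "entity", "[]", "[]", "Node"]  ->  ["Alice", "entity", "[]", "[]", 0, "Node"]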
@ -0,0 +1,189 @@
import networkx as nx
import csv
import ast
import hashlib
import os
from atlas_rag.kg_construction.triple_config import ProcessingConfig
import pickle


def get_node_id(entity_name, entity_to_id={}):
    """Returns existing or creates new nX ID for an entity using a hash-based approach."""
    if entity_name not in entity_to_id:
        # Use a hash function to generate a unique ID
        hash_object = hashlib.sha256(entity_name.encode('utf-8'))
        hash_hex = hash_object.hexdigest()  # Get the hexadecimal representation of the hash
        entity_to_id[entity_name] = hash_hex
    return entity_to_id[entity_name]


def csvs_to_temp_graphml(triple_node_file, triple_edge_file, config:ProcessingConfig=None):
    g = nx.DiGraph()
    entity_to_id = {}

    # Add triple nodes
    with open(triple_node_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            node_id = row["name:ID"]
            mapped_id = get_node_id(node_id, entity_to_id)
            if mapped_id not in g.nodes:
                g.add_node(mapped_id, id=node_id, type=row["type"])

    # Add triple edges
    with open(triple_edge_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            start_id = get_node_id(row[":START_ID"], entity_to_id)
            end_id = get_node_id(row[":END_ID"], entity_to_id)
            # Check if edge already exists to prevent duplicates
            if not g.has_edge(start_id, end_id):
                g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])

    # Save the graph to a pickle file
    output_name = f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl"
    # Check if the output file directory exists
    output_dir = os.path.dirname(output_name)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Store the graph to a pickle file
    with open(output_name, 'wb') as output_file:
        pickle.dump(g, output_file)


def csvs_to_graphml(triple_node_file, text_node_file, concept_node_file,
                    triple_edge_file, text_edge_file, concept_edge_file,
                    output_file):
    '''
    Convert multiple CSV files into a single GraphML file.

    Types of nodes to be added to the graph:
    - Triple nodes: Nodes representing triples, with properties like subject, predicate, object.
    - Text nodes: Nodes representing text, with properties like text content.
    - Concept nodes: Nodes representing concepts, with properties like concept name and type.

    Types of edges to be added to the graph:
    - Triple edges: Edges representing relationships between triples, with properties like relation type.
    - Text edges: Edges representing relationships between text and nodes, with properties like text type.
    - Concept edges: Edges representing relationships between concepts and nodes, with properties like concept type.

    DiGraph networkx attributes:
    Node:
    - type: Type of the node (e.g., entity, event, text, concept).
    - file_id: List of text IDs the node is associated with.
    - id: Node name
    Edge:
    - relation: Relation name
    - file_id: List of text IDs the edge is associated with.
    - type: Type of the edge (e.g., Source, Relation, Concept).
    - synsets: List of synsets associated with the edge.
    '''
    g = nx.DiGraph()
    entity_to_id = {}

    # Add triple nodes
    with open(triple_node_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            node_id = row["name:ID"]
            mapped_id = get_node_id(node_id, entity_to_id)
            # Check if node already exists to prevent duplicates
            if mapped_id not in g.nodes:
                g.add_node(mapped_id, id=node_id, type=row["type"])

    # Add text nodes
    with open(text_node_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            node_id = row["text_id:ID"]
            # Check if node already exists to prevent duplicates
            if node_id not in g.nodes:
                g.add_node(node_id, file_id=node_id, id=row["original_text"], type="passage")

    # Add concept nodes
    with open(concept_node_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            node_id = row["concept_id:ID"]
            # Check if node already exists to prevent duplicates
            if node_id not in g.nodes:
                g.add_node(node_id, file_id="concept_file", id=row["name"], type="concept")

    # Add file_id for triple nodes and concept nodes while adding the edges

    # Add triple edges
    with open(triple_edge_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            start_id = get_node_id(row[":START_ID"], entity_to_id)
            end_id = get_node_id(row[":END_ID"], entity_to_id)
            # Check if edge already exists to prevent duplicates
            if not g.has_edge(start_id, end_id):
                g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])
            # Add file_id to start and end nodes if they are triple or concept nodes
            for node_id in [start_id, end_id]:
                if g.nodes[node_id]['type'] in ['triple', 'concept'] and 'file_id' not in g.nodes[node_id]:
                    g.nodes[node_id]['file_id'] = row.get("file_id", "triple_file")

            # Add concepts to the edge
            concepts = ast.literal_eval(row["concepts"])
            for concept in concepts:
                if "concepts" not in g.edges[start_id, end_id]:
                    g.edges[start_id, end_id]['concepts'] = str(concept)
                else:
                    # Avoid duplicate concepts by checking if concept is already in the list
                    current_concepts = g.edges[start_id, end_id]['concepts'].split(",")
                    if str(concept) not in current_concepts:
                        g.edges[start_id, end_id]['concepts'] += "," + str(concept)

    # Add text edges
    with open(text_edge_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            start_id = get_node_id(row[":START_ID"], entity_to_id)
            end_id = row[":END_ID"]
            # Check if edge already exists to prevent duplicates
            if not g.has_edge(start_id, end_id):
                g.add_edge(start_id, end_id, relation="mention in", type=row[":TYPE"])
            # Add file_id to the start node if it is a triple or concept node
            if 'file_id' in g.nodes[start_id]:
                g.nodes[start_id]['file_id'] += "," + str(end_id)
            else:
                g.nodes[start_id]['file_id'] = str(end_id)

    # Add concept edges between triple nodes and concept nodes
    with open(concept_edge_file, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            start_id = get_node_id(row[":START_ID"], entity_to_id)
            end_id = row[":END_ID"]  # end id is the concept node id
            if not g.has_edge(start_id, end_id):
                g.add_edge(start_id, end_id, relation=row["relation"], type=row[":TYPE"])

    # Write to GraphML
    # Check if the output file directory exists
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    nx.write_graphml(g, output_file, infer_numeric_types=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Convert CSV files to GraphML format.')
    parser.add_argument('--triple_node_file', type=str, required=True, help='Path to the triple node CSV file.')
    parser.add_argument('--text_node_file', type=str, required=True, help='Path to the text node CSV file.')
    parser.add_argument('--concept_node_file', type=str, required=True, help='Path to the concept node CSV file.')
    parser.add_argument('--triple_edge_file', type=str, required=True, help='Path to the triple edge CSV file.')
    parser.add_argument('--text_edge_file', type=str, required=True, help='Path to the text edge CSV file.')
    parser.add_argument('--concept_edge_file', type=str, required=True, help='Path to the concept edge CSV file.')
    parser.add_argument('--output_file', type=str, required=True, help='Path to the output GraphML file.')

    args = parser.parse_args()

    csvs_to_graphml(args.triple_node_file, args.text_node_file, args.concept_node_file,
                    args.triple_edge_file, args.text_edge_file, args.concept_edge_file,
                    args.output_file)
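A small sketch of reading the two artifacts this module writes back into memory; the file paths are illustrative placeholders following the patterns used above:

import pickle
import networkx as nx

# Graph pickled by csvs_to_temp_graphml
with open("./output/kg_graphml/en_simple_wiki_v0_without_concept.pkl", "rb") as f:
    g_temp = pickle.load(f)

# Full graph written by csvs_to_graphml
g_full = nx.read_graphml("./output/kg_graphml/en_simple_wiki_v0.graphml")
print(g_temp.number_of_nodes(), g_full.number_of_edges())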
@ -0,0 +1,70 @@
import pandas as pd
import numpy as np
from ast import literal_eval  # Safer string-to-list conversion
import os

CHUNKSIZE = 100_000  # Adjust based on your RAM (100K rows per chunk)
EMBEDDING_COL = "embedding:STRING"  # Column name with embeddings
# DIMENSION = 32  # Update with your embedding dimension
ENTITY_ONLY = True


def parse_embedding(embed_str):
    """Convert an embedding string to a numpy array."""
    # Remove brackets and convert to a list
    return np.array(literal_eval(embed_str), dtype=np.float32)


# Stream the CSV embeddings into a single .npy file, one chunk at a time
def convert_csv_to_npy(csv_path, npy_path):
    total_embeddings = 0
    # Check that the output directory exists; if not, create it
    os.makedirs(os.path.dirname(npy_path), exist_ok=True)

    with open(npy_path, "wb") as f:
        pass  # Initialize empty file

    # Process CSV in chunks
    for chunk_idx, df_chunk in enumerate(
        pd.read_csv(csv_path, chunksize=CHUNKSIZE, usecols=[EMBEDDING_COL])
    ):
        # Parse embeddings
        embeddings = np.stack(
            df_chunk[EMBEDDING_COL].apply(parse_embedding).values
        )

        # Verify dimensions
        # assert embeddings.shape[1] == DIMENSION, \
        #     f"Dimension mismatch at chunk {chunk_idx}"
        total_embeddings += embeddings.shape[0]
        # Append to .npy file
        with open(npy_path, "ab") as f:
            np.save(f, embeddings.astype(np.float32))

        print(f"Processed chunk {chunk_idx} ({CHUNKSIZE*(chunk_idx+1)} rows)")
    print(f"Total number of embeddings: {total_embeddings}")
    print("Conversion complete!")


if __name__ == "__main__":
    keyword = 'cc_en'  # Change this to your desired keyword
    csv_dir = "./import"  # Change this to your CSV directory
    keyword_to_paths = {
        'cc_en': {
            'node_csv': f"{csv_dir}/triple_nodes_cc_en_from_json_2.csv",
            # 'edge_csv': f"{csv_dir}/triple_edges_cc_en_from_json_2.csv",
            'text_csv': f"{csv_dir}/text_nodes_cc_en_from_json_with_emb.csv",
        },
        'pes2o_abstract': {
            'node_csv': f"{csv_dir}/triple_nodes_pes2o_abstract_from_json.csv",
            # 'edge_csv': f"{csv_dir}/triple_edges_pes2o_abstract_from_json.csv",
            'text_csv': f"{csv_dir}/text_nodes_pes2o_abstract_from_json_with_emb.csv",
        },
        'en_simple_wiki_v0': {
            'node_csv': f"{csv_dir}/triple_nodes_en_simple_wiki_v0_from_json.csv",
            # 'edge_csv': f"{csv_dir}/triple_edges_en_simple_wiki_v0_from_json.csv",
            'text_csv': f"{csv_dir}/text_nodes_en_simple_wiki_v0_from_json_with_emb.csv",
        },
    }
    for key, path in keyword_to_paths[keyword].items():
        npy_path = path.replace(".csv", ".npy")
        convert_csv_to_npy(path, npy_path)
        print(f"Converted {path} to {npy_path}")
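Because convert_csv_to_npy appends one np.save record per chunk to the same file, reading the result back requires repeated np.load calls rather than a single one. A minimal sketch; the helper name is ours, not part of the repo, and the exact exception raised at end-of-file varies across numpy versions, hence the broad catch:

import numpy as np

def load_appended_npy(npy_path):
    # Read every array that was appended with np.save and stack them into one matrix.
    chunks = []
    with open(npy_path, "rb") as f:
        while True:
            try:
                chunks.append(np.load(f))
            except (EOFError, ValueError, OSError):
                break  # no more serialized arrays left in the file
    return np.concatenate(chunks) if chunks else np.empty((0,), dtype=np.float32)

# embeddings = load_appended_npy("./import/text_nodes_cc_en_from_json_with_emb.npy")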
@ -0,0 +1,27 @@
import os
import glob


def merge_csv_files(output_file, input_dir):
    """
    Merge all CSV files in the input directory into a single output file.

    Args:
        output_file (str): Path to the output CSV file.
        input_dir (str): Directory containing the input CSV files.
    """
    # Delete the output file if it exists
    if os.path.exists(output_file):
        os.remove(output_file)

    # Write the header to the output file
    with open(output_file, 'w') as outfile:
        outfile.write("node,conceptualized_node,node_type\n")

    # Append the contents of all CSV files in the input directory
    for csv_file in glob.glob(os.path.join(input_dir, '*.csv')):
        with open(csv_file, 'r') as infile:
            # Skip the header line
            next(infile)
            # Append the remaining lines to the output file
            with open(output_file, 'a') as outfile:
                outfile.writelines(infile)
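A usage sketch for merge_csv_files; both paths are placeholders, not values shipped with the repo:

# Collect per-shard concept CSVs into one file with the expected
# "node,conceptualized_node,node_type" header.
merge_csv_files("./concept_csv/merged_concepts.csv", "./concept_csv/shards")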
@ -0,0 +1,277 @@
from tqdm import tqdm
import argparse
import os
import csv
import json
import re
import hashlib

# Increase the field size limit
csv.field_size_limit(10 * 1024 * 1024)  # 10 MB limit


# Function to compute a hash ID from text
def compute_hash_id(text):
    # Use SHA-256 to generate a hash
    hash_object = hashlib.sha256(text.encode('utf-8'))
    return hash_object.hexdigest()  # Return hash as a hex string


def clean_text(text):
    # Replace control characters, remove NUL as well
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\e", " ").replace(";", ",")
    new_text = new_text.replace("\x00", "")
    new_text = re.sub(r'\s+', ' ', new_text).strip()

    return new_text


def remove_NUL(text):
    return text.replace("\x00", "")


def json2csv(dataset, data_dir, output_dir, test=False):
    """
    Convert JSON files to CSV files for nodes, edges, and missing concepts.

    Args:
        dataset (str): Name of the dataset.
        data_dir (str): Directory containing the JSON files.
        output_dir (str): Directory to save the output CSV files.
        test (bool): If True, run in test mode (process only 3 files).
    """
    visited_nodes = set()
    visited_hashes = set()

    all_entities = set()
    all_events = set()
    all_relations = set()

    file_dir_list = [f for f in os.listdir(data_dir) if dataset in f]
    file_dir_list = sorted(file_dir_list)
    if test:
        file_dir_list = file_dir_list[:3]
    print("Loading data from the json files")
    print("Number of files: ", len(file_dir_list))

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Define output file paths
    node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb.csv")
    edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb.csv")
    node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json.csv")
    edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json.csv")
    missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json.csv")

    if test:
        node_text_file = os.path.join(output_dir, f"text_nodes_{dataset}_from_json_test.csv")
        edge_text_file = os.path.join(output_dir, f"text_edges_{dataset}_from_json_test.csv")
        node_csv_without_emb = os.path.join(output_dir, f"triple_nodes_{dataset}_from_json_without_emb_test.csv")
        edge_csv_without_emb = os.path.join(output_dir, f"triple_edges_{dataset}_from_json_without_emb_test.csv")
        missing_concepts_file = os.path.join(output_dir, f"missing_concepts_{dataset}_from_json_test.csv")

    # Open CSV files for writing
    with open(node_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node_text, \
         open(edge_text_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge_text, \
         open(node_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_node, \
         open(edge_csv_without_emb, "w", newline='', encoding='utf-8', errors='ignore') as csvfile_edge:

        csv_writer_node_text = csv.writer(csvfile_node_text)
        csv_writer_edge_text = csv.writer(csvfile_edge_text)
        writer_node = csv.writer(csvfile_node)
        writer_edge = csv.writer(csvfile_edge)

        # Write headers
        csv_writer_node_text.writerow(["text_id:ID", "original_text", ":LABEL"])
        csv_writer_edge_text.writerow([":START_ID", ":END_ID", ":TYPE"])
        writer_node.writerow(["name:ID", "type", "concepts", "synsets", ":LABEL"])
        writer_edge.writerow([":START_ID", ":END_ID", "relation", "concepts", "synsets", ":TYPE"])

        # Process each file
        for file_dir in tqdm(file_dir_list):
            print("Processing file for file ids: ", file_dir)
            with open(os.path.join(data_dir, file_dir), "r") as jsonfile:
                for line in jsonfile:
                    data = json.loads(line.strip())
                    original_text = data["original_text"]
                    original_text = remove_NUL(original_text)
                    if "Here is the passage." in original_text:
                        original_text = original_text.split("Here is the passage.")[-1]
                    eot_token = "<|eot_id|>"
                    original_text = original_text.split(eot_token)[0]

                    text_hash_id = compute_hash_id(original_text)

                    # Write the original text as nodes
                    if text_hash_id not in visited_hashes:
                        visited_hashes.add(text_hash_id)
                        csv_writer_node_text.writerow([text_hash_id, original_text, "Text"])

                    file_id = str(data["id"])
                    entity_relation_dict = data["entity_relation_dict"]
                    event_entity_relation_dict = data["event_entity_relation_dict"]
                    event_relation_dict = data["event_relation_dict"]

                    # Process entity triples
                    entity_triples = []
                    for entity_triple in entity_relation_dict:
                        try:
                            assert isinstance(entity_triple["Head"], str)
                            assert isinstance(entity_triple["Relation"], str)
                            assert isinstance(entity_triple["Tail"], str)

                            head_entity = entity_triple["Head"]
                            relation = entity_triple["Relation"]
                            tail_entity = entity_triple["Tail"]

                            # Clean the text
                            head_entity = clean_text(head_entity)
                            relation = clean_text(relation)
                            tail_entity = clean_text(tail_entity)

                            if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
                                continue

                            entity_triples.append((head_entity, relation, tail_entity))
                        except:
                            print(f"Error processing entity triple: {entity_triple}")
                            continue

                    # Process event triples
                    event_triples = []
                    for event_triple in event_relation_dict:
                        try:
                            assert isinstance(event_triple["Head"], str)
                            assert isinstance(event_triple["Relation"], str)
                            assert isinstance(event_triple["Tail"], str)
                            head_event = event_triple["Head"]
                            relation = event_triple["Relation"]
                            tail_event = event_triple["Tail"]

                            # Clean the text
                            head_event = clean_text(head_event)
                            relation = clean_text(relation)
                            tail_event = clean_text(tail_event)

                            if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
                                continue

                            event_triples.append((head_event, relation, tail_event))
                        except:
                            print(f"Error processing event triple: {event_triple}")

                    # Process event-entity triples
                    event_entity_triples = []
                    for event_entity_participations in event_entity_relation_dict:
                        if "Event" not in event_entity_participations or "Entity" not in event_entity_participations:
                            continue
                        if not isinstance(event_entity_participations["Event"], str) or not isinstance(event_entity_participations["Entity"], list):
                            continue

                        for entity in event_entity_participations["Entity"]:
                            if not isinstance(entity, str):
                                continue

                            entity = clean_text(entity)
                            event = clean_text(event_entity_participations["Event"])

                            if event.isspace() or len(event) == 0 or entity.isspace() or len(entity) == 0:
                                continue

                            event_entity_triples.append((event, "is participated by", entity))

                    # Write nodes and edges to CSV files
                    for entity_triple in entity_triples:
                        head_entity, relation, tail_entity = entity_triple
                        if head_entity is None or tail_entity is None or relation is None:
                            continue
                        if head_entity.isspace() or tail_entity.isspace() or relation.isspace():
                            continue
                        if len(head_entity) == 0 or len(tail_entity) == 0 or len(relation) == 0:
                            continue

                        # Add nodes to files
                        if head_entity not in visited_nodes:
                            visited_nodes.add(head_entity)
                            all_entities.add(head_entity)
                            writer_node.writerow([head_entity, "entity", [], [], "Node"])
                            csv_writer_edge_text.writerow([head_entity, text_hash_id, "Source"])

                        if tail_entity not in visited_nodes:
                            visited_nodes.add(tail_entity)
                            all_entities.add(tail_entity)
                            writer_node.writerow([tail_entity, "entity", [], [], "Node"])
                            csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])

                        all_relations.add(relation)
                        writer_edge.writerow([head_entity, tail_entity, relation, [], [], "Relation"])

                    for event_triple in event_triples:
                        head_event, relation, tail_event = event_triple
                        if head_event is None or tail_event is None or relation is None:
                            continue
                        if head_event.isspace() or tail_event.isspace() or relation.isspace():
                            continue
                        if len(head_event) == 0 or len(tail_event) == 0 or len(relation) == 0:
                            continue

                        # Add nodes to files
                        if head_event not in visited_nodes:
                            visited_nodes.add(head_event)
                            all_events.add(head_event)
                            writer_node.writerow([head_event, "event", [], [], "Node"])
                            csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])

                        if tail_event not in visited_nodes:
                            visited_nodes.add(tail_event)
                            all_events.add(tail_event)
                            writer_node.writerow([tail_event, "event", [], [], "Node"])
                            csv_writer_edge_text.writerow([tail_event, text_hash_id, "Source"])

                        all_relations.add(relation)
                        writer_edge.writerow([head_event, tail_event, relation, [], [], "Relation"])

                    for event_entity_triple in event_entity_triples:
                        head_event, relation, tail_entity = event_entity_triple
                        if head_event is None or tail_entity is None or relation is None:
                            continue
                        if head_event.isspace() or tail_entity.isspace() or relation.isspace():
                            continue
                        if len(head_event) == 0 or len(tail_entity) == 0 or len(relation) == 0:
                            continue

                        # Add nodes to files
                        if head_event not in visited_nodes:
                            visited_nodes.add(head_event)
                            all_events.add(head_event)
                            writer_node.writerow([head_event, "event", [], [], "Node"])
                            csv_writer_edge_text.writerow([head_event, text_hash_id, "Source"])

                        if tail_entity not in visited_nodes:
                            visited_nodes.add(tail_entity)
                            all_entities.add(tail_entity)
                            writer_node.writerow([tail_entity, "entity", [], [], "Node"])
                            csv_writer_edge_text.writerow([tail_entity, text_hash_id, "Source"])

                        all_relations.add(relation)
                        writer_edge.writerow([head_event, tail_entity, relation, [], [], "Relation"])

    # Write missing concepts to CSV
    with open(missing_concepts_file, "w", newline='', encoding='utf-8', errors='ignore') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Name", "Type"])
        for entity in all_entities:
            writer.writerow([entity, "Entity"])
        for event in all_events:
            writer.writerow([event, "Event"])
        for relation in all_relations:
            writer.writerow([relation, "Relation"])

    print("Data to CSV completed successfully.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True, help="[pes2o_abstract, en_simple_wiki_v0, cc_en]")
    parser.add_argument("--data_dir", type=str, required=True, help="Directory containing the graph raw JSON files")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the output CSV files")
    parser.add_argument("--test", action="store_true", help="Test the script")
    args = parser.parse_args()
    json2csv(dataset=args.dataset, data_dir=args.data_dir, output_dir=args.output_dir, test=args.test)
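For orientation, a sketch of one JSONL line that json2csv expects. The field names come from the parsing code above; the values are invented for illustration (written here as a Python literal):

example_record = {
    "id": 42,
    "original_text": "Alice founded Acme Corp in 1999.",
    "entity_relation_dict": [
        {"Head": "Alice", "Relation": "founded", "Tail": "Acme Corp"}
    ],
    "event_relation_dict": [
        {"Head": "Alice founded Acme Corp", "Relation": "happened in", "Tail": "1999"}
    ],
    "event_entity_relation_dict": [
        {"Event": "Alice founded Acme Corp", "Entity": ["Alice", "Acme Corp"]}
    ],
}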
@ -0,0 +1,169 @@
import networkx as nx
import json
from tqdm import tqdm
import os
import hashlib


def get_node_id(entity_name, entity_to_id):
    """Returns existing or creates new nX ID for an entity using a hash-based approach."""
    if entity_name not in entity_to_id:
        # Use a hash function to generate a unique ID
        hash_object = hashlib.md5(entity_name.encode())  # Use MD5 or another hashing algorithm
        hash_hex = hash_object.hexdigest()  # Get the hexadecimal representation of the hash
        # Use the first 16 characters of the hash as the ID (you can adjust the length as needed)
        entity_to_id[entity_name] = f'n{hash_hex[:16]}'
    return entity_to_id[entity_name]


def clean_text(text):
    # Replace control characters, remove NUL as well
    new_text = text.replace("\n", " ").replace("\r", " ").replace("\t", " ").replace("\v", " ").replace("\f", " ").replace("\b", " ").replace("\a", " ").replace("\e", " ").replace(";", ",")
    new_text = new_text.replace("\x00", "")
    return new_text


def process_kg_data(input_passage_dir, input_triple_dir, output_dir, keyword):
    # Get file names containing the keyword
    file_names = [file for file in list(os.listdir(input_triple_dir)) if keyword in file]
    print(f"Keyword: {keyword}")
    print(f"Number of files: {len(file_names)}")
    print(file_names)

    passage_file_names = [file for file in list(os.listdir(input_passage_dir)) if keyword in file]
    print(f'Passage file names: {passage_file_names}')

    g = nx.DiGraph()
    print("Graph created.")
    entity_to_id = {}
    # Check if the output directory exists; if not, create it
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Output directory {output_dir} created.")

    output_path = f"{output_dir}/{keyword}_kg_from_corpus.graphml"

    # Create the original_text to node_id dictionary and add passage nodes to the graph
    with open(f"{input_passage_dir}/{passage_file_names[0]}") as f:
        data = json.load(f)
        for item in tqdm(data, desc="Processing passages"):
            passage_id = item["id"]
            passage_text = item["text"]
            node_id = get_node_id(passage_text, entity_to_id)
            if passage_text.isspace() or len(passage_text) == 0:
                continue
            # Add the passage node to the graph
            g.add_node(node_id, type="passage", id=passage_text, file_id=passage_id)

    for file_name in tqdm(file_names):
        print(f"Processing {file_name}")
        input_file_path = f"{input_triple_dir}/{file_name}"
        with open(input_file_path) as f:
            for line in tqdm(f):
                data = json.loads(line)
                metadata = data["metadata"]
                file_id = data["id"]
                original_text = data["original_text"]
                entity_relation_dict = data["entity_relation_dict"]
                event_entity_relation_dict = data["event_entity_relation_dict"]
                event_relation_dict = data["event_relation_dict"]

                # Process entity triples
                entity_triples = []
                for entity_triple in entity_relation_dict:
                    if not all(key in entity_triple for key in ["Head", "Relation", "Tail"]):
                        continue
                    head_entity = clean_text(entity_triple["Head"])
                    relation = clean_text(entity_triple["Relation"])
                    tail_entity = clean_text(entity_triple["Tail"])
                    if head_entity.isspace() or len(head_entity) == 0 or tail_entity.isspace() or len(tail_entity) == 0:
                        continue
                    entity_triples.append((head_entity, relation, tail_entity))

                # Add entity triples to the graph
                for triple in entity_triples:
                    head_id = get_node_id(triple[0], entity_to_id)
                    tail_id = get_node_id(triple[2], entity_to_id)
                    g.add_node(head_id, type="entity", id=triple[0])
                    g.add_node(tail_id, type="entity", id=triple[2])
                    g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
                    g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
                    g.add_edge(head_id, tail_id, relation=triple[1])
                    for node_id in [head_id, tail_id]:
                        if "file_id" not in g.nodes[node_id]:
                            g.nodes[node_id]["file_id"] = str(file_id)
                        else:
                            g.nodes[node_id]["file_id"] += "," + str(file_id)
                    edge = g.edges[head_id, tail_id]
                    if "file_id" not in edge:
                        edge["file_id"] = str(file_id)
                    else:
                        edge["file_id"] += "," + str(file_id)

                # Process event triples
                event_triples = []
                for event_triple in event_relation_dict:
                    if not all(key in event_triple for key in ["Head", "Relation", "Tail"]):
                        continue
                    head_event = clean_text(event_triple["Head"])
                    relation = clean_text(event_triple["Relation"])
                    tail_event = clean_text(event_triple["Tail"])
                    if head_event.isspace() or len(head_event) == 0 or tail_event.isspace() or len(tail_event) == 0:
                        continue
                    event_triples.append((head_event, relation, tail_event))

                # Add event triples to the graph
                for triple in event_triples:
                    head_id = get_node_id(triple[0], entity_to_id)
                    tail_id = get_node_id(triple[2], entity_to_id)
                    g.add_node(head_id, type="event", id=triple[0])
                    g.add_node(tail_id, type="event", id=triple[2])
                    g.add_edge(head_id, get_node_id(original_text, entity_to_id), relation='mention in')
                    g.add_edge(tail_id, get_node_id(original_text, entity_to_id), relation='mention in')
                    g.add_edge(head_id, tail_id, relation=triple[1])
                    for node_id in [head_id, tail_id]:
                        if "file_id" not in g.nodes[node_id]:
                            g.nodes[node_id]["file_id"] = str(file_id)
                        else:
                            g.nodes[node_id]["file_id"] += "," + str(file_id)
                    edge = g.edges[head_id, tail_id]
                    if "file_id" not in edge:
                        edge["file_id"] = str(file_id)
                    else:
                        edge["file_id"] += "," + str(file_id)

                # Process event-entity triples
                event_entity_triples = []
                for event_entity_participations in event_entity_relation_dict:
                    if not all(key in event_entity_participations for key in ["Event", "Entity"]):
                        continue
                    event = clean_text(event_entity_participations["Event"])
                    if event.isspace() or len(event) == 0:
                        continue
                    for entity in event_entity_participations["Entity"]:
                        if not isinstance(entity, str) or entity.isspace() or len(entity) == 0:
                            continue
                        entity = clean_text(entity)
                        event_entity_triples.append((event, "is participated by", entity))

                # Add event-entity triples to the graph
                for triple in event_entity_triples:
                    head_id = get_node_id(triple[0], entity_to_id)
                    tail_id = get_node_id(triple[2], entity_to_id)
                    g.add_node(head_id, type="event", id=triple[0])
                    g.add_node(tail_id, type="entity", id=triple[2])
                    g.add_edge(head_id, tail_id, relation=triple[1])
                    for node_id in [head_id, tail_id]:
                        if "file_id" not in g.nodes[node_id]:
                            g.nodes[node_id]["file_id"] = str(file_id)
                    edge = g.edges[head_id, tail_id]
                    if "file_id" not in edge:
                        edge["file_id"] = str(file_id)
                    else:
                        edge["file_id"] += "," + str(file_id)

    print(f"Number of nodes: {g.number_of_nodes()}")
    print(f"Number of edges: {g.number_of_edges()}")
    print(f"Graph density: {nx.density(g)}")
    with open(output_path, 'wb') as f:
        nx.write_graphml(g, f, infer_numeric_types=True)


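A usage sketch for process_kg_data; the directory names are placeholders, and the expected inputs follow the parsing above (a JSON list of {"id", "text"} passages, plus JSONL triple files whose names contain the keyword):

process_kg_data(
    input_passage_dir="./passages",
    input_triple_dir="./generation_result_debug",
    output_dir="./graphml",
    keyword="en_simple_wiki_v0",
)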
@ -0,0 +1,63 @@
import argparse
import json
import os
import sys
from pathlib import Path

# Set up argument parser
parser = argparse.ArgumentParser(description="Convert all Markdown files in a folder to separate JSON files.")
parser.add_argument(
    "--input", required=True, help="Path to the folder containing Markdown files"
)
parser.add_argument(
    "--output", default=None, help="Output folder for JSON files (defaults to input folder if not specified)"
)

# Parse arguments
args = parser.parse_args()

# Resolve input folder path
input_folder = Path(args.input)
if not input_folder.is_dir():
    print(f"Error: '{args.input}' is not a directory.", file=sys.stderr)
    sys.exit(1)

# Set output folder (use input folder if not specified)
output_folder = Path(args.output) if args.output else input_folder
output_folder.mkdir(parents=True, exist_ok=True)

# Find all .md files in the input folder
markdown_files = [f for f in input_folder.iterdir() if f.suffix.lower() == ".md"]

if not markdown_files:
    print(f"Error: No Markdown files found in '{args.input}'.", file=sys.stderr)
    sys.exit(1)

# Process each Markdown file
for file in markdown_files:
    try:
        # Read the content of the file
        with open(file, "r", encoding="utf-8") as f:
            content = f.read()

        # Create the JSON object
        obj = {
            "id": "1",
            "text": content,
            "metadata": {
                "lang": "en"
            }
        }

        # Create output JSON filename (e.g., file1.md -> file1.json)
        output_file = output_folder / f"{file.stem}.json"

        # Write JSON to file
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump([obj], f, indent=4)

        print(f"Successfully converted '{file}' to '{output_file}'")
    except FileNotFoundError:
        print(f"Error: File '{file}' not found.", file=sys.stderr)
    except Exception as e:
        print(f"Error processing file '{file}': {e}", file=sys.stderr)