#!/usr/bin/env python3
"""
Knowledge Graph Extraction Pipeline

Extracts entities, relations, and events from text data using transformer models.
"""
import re
import json
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm
import json_repair

from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
from atlas_rag.kg_construction.triple_config import ProcessingConfig

# Constants
TOKEN_LIMIT = 1024
INSTRUCTION_TOKEN_ESTIMATE = 200
CHAR_TO_TOKEN_RATIO = 3.5


class TextChunker:
    """Handles text chunking based on token limits."""

    def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
        self.max_tokens = max_tokens
        self.instruction_tokens = instruction_tokens
        self.char_ratio = CHAR_TO_TOKEN_RATIO

    def calculate_max_chars(self) -> int:
        """Calculate maximum characters per chunk."""
        available_tokens = self.max_tokens - self.instruction_tokens
        return int(available_tokens * self.char_ratio)

    def split_text(self, text: str) -> List[str]:
        """Split text into chunks that fit within token limits."""
        max_chars = self.calculate_max_chars()
        chunks = []

        while len(text) > max_chars:
            chunks.append(text[:max_chars])
            text = text[max_chars:]

        if text:  # Add remaining text
            chunks.append(text)

        return chunks


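# Illustrative sketch (not part of the original file): how TextChunker slices a long
# document. With the defaults, 1024 tokens minus the ~200-token instruction estimate
# leaves 824 tokens, i.e. int(824 * 3.5) = 2884 characters per chunk.
def _demo_text_chunking() -> None:
    chunker = TextChunker()
    sample_text = "word " * 2000  # ~10,000 characters
    chunks = chunker.split_text(sample_text)
    print(f"max chars per chunk: {chunker.calculate_max_chars()}")  # 2884
    print(f"number of chunks: {len(chunks)}")  # 4 chunks for this sample

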
class DatasetProcessor:
    """Processes and prepares a dataset for knowledge graph extraction."""

    def __init__(self, config: ProcessingConfig):
        self.config = config
        self.chunker = TextChunker()

    def filter_language_content(self, sample: Dict[str, Any]) -> bool:
        """Check whether the sample's language is supported by the extraction prompts."""
        metadata = sample.get("metadata", {})
        language = metadata.get("lang", "en")  # Default to English if not specified
        supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
        return language in supported_languages

    def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Create chunks from a single sample."""
        original_text = sample.get("text", "")
        if self.config.remove_doc_spaces:
            original_text = re.sub(r'\s+', ' ', original_text).strip()
        text_chunks = self.chunker.split_text(original_text)
        chunks = []

        for chunk_idx, chunk_text in enumerate(text_chunks):
            chunk_data = {
                "id": sample["id"],
                "text": chunk_text,
                "chunk_id": chunk_idx,
                "metadata": sample["metadata"]
            }
            chunks.append(chunk_data)

        return chunks

    def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
        """Slice the raw dataset into this shard's documents and split each into chunks."""
        processed_samples = []
        total_texts = len(raw_dataset)

        # Handle edge cases
        if total_texts == 0:
            print(f"No texts found for shard {self.config.current_shard_triple + 1}/{self.config.total_shards_triple}")
            return processed_samples

        # Calculate base and remainder for fair distribution
        base_texts_per_shard = total_texts // self.config.total_shards_triple
        remainder = total_texts % self.config.total_shards_triple

        # Calculate start index
        if self.config.current_shard_triple < remainder:
            start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
        else:
            start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard

        # Calculate end index
        if self.config.current_shard_triple < remainder:
            end_idx = start_idx + (base_texts_per_shard + 1)
        else:
            end_idx = start_idx + base_texts_per_shard

        # Ensure indices are within bounds
        start_idx = min(start_idx, total_texts)
        end_idx = min(end_idx, total_texts)

        print(f"Processing shard {self.config.current_shard_triple + 1}/{self.config.total_shards_triple} "
              f"(texts {start_idx}-{end_idx - 1} of {total_texts}, {end_idx - start_idx} documents)")

        # Process documents in the assigned shard
        for idx in range(start_idx, end_idx):
            sample = raw_dataset[idx]

            # Filter by language
            if not self.filter_language_content(sample):
                print(f"Unsupported language in sample {idx}, skipping.")
                continue

            # Create chunks
            chunks = self.create_sample_chunks(sample)
            processed_samples.extend(chunks)

            # Debug mode early termination
            if self.config.debug_mode and len(processed_samples) >= 20:
                print("Debug mode: Stopping at 20 chunks")
                break

        print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple + 1}/{self.config.total_shards_triple}")
        return processed_samples


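# Illustrative sketch (not part of the original file): the shard slicing above gives the
# first `remainder` shards one extra document each, so every document is assigned to
# exactly one shard. For example, 10 documents over 3 shards yields [0, 4), [4, 7), [7, 10).
def _demo_shard_slices(total_texts: int, total_shards: int) -> list:
    slices = []
    base, remainder = divmod(total_texts, total_shards)
    for shard in range(total_shards):
        if shard < remainder:
            start = shard * (base + 1)
            end = start + base + 1
        else:
            start = remainder * (base + 1) + (shard - remainder) * base
            end = start + base
        slices.append((start, end))
    return slices  # _demo_shard_slices(10, 3) == [(0, 4), (4, 7), (7, 10)]

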
class CustomDataLoader:
    """Custom data loader for knowledge graph extraction."""

    def __init__(self, dataset, processor: DatasetProcessor):
        self.raw_dataset = dataset
        self.processor = processor
        self.processed_data = processor.prepare_dataset(dataset)
        self.stage_to_prompt_dict = {
            "stage_1": "entity_relation",
            "stage_2": "event_entity",
            "stage_3": "event_relation"
        }

    def __len__(self) -> int:
        return len(self.processed_data)

    def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
        """Build the three-stage chat messages for every chunk in the batch."""
        messages_dict = {
            'stage_1': [],
            'stage_2': [],
            'stage_3': []
        }
        for item in batch_data:
            # Select prompts for the item's language, falling back to English
            language = item.get("metadata", {}).get("lang", "en")
            prompts = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])
            system_msg = prompts['system']
            stage_1_msg = prompts['entity_relation'] + prompts['passage_start'] + '\n' + item["text"]
            stage_2_msg = prompts['event_entity'] + prompts['passage_start'] + '\n' + item["text"]
            stage_3_msg = prompts['event_relation'] + prompts['passage_start'] + '\n' + item["text"]
            stage_one_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_1_msg}
            ]
            stage_two_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_2_msg}
            ]
            stage_three_message = [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": stage_3_msg}
            ]
            messages_dict['stage_1'].append(stage_one_message)
            messages_dict['stage_2'].append(stage_two_message)
            messages_dict['stage_3'].append(stage_three_message)

        return messages_dict

    def __iter__(self):
        """Iterate through batches."""
        batch_size = self.processor.config.batch_size_triple
        start_idx = self.processor.config.resume_from * batch_size

        for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
            batch_data = self.processed_data[i:i + batch_size]

            # Prepare instructions
            instructions = self.create_batch_instructions(batch_data)

            # Extract batch information
            batch_ids = [item["id"] for item in batch_data]
            batch_metadata = [item["metadata"] for item in batch_data]
            batch_texts = [item["text"] for item in batch_data]

            yield instructions, batch_ids, batch_texts, batch_metadata


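# Illustrative sketch (not part of the original file): the per-chunk message layout the
# loader produces. Each stage is a two-turn chat that shares the system prompt and
# differs only in the stage-specific instruction prepended to the passage.
def _demo_stage_messages(text: str, language: str = "en") -> Dict[str, List[Dict[str, str]]]:
    prompts = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])
    return {
        stage: [
            {"role": "system", "content": prompts["system"]},
            {"role": "user", "content": prompts[key] + prompts["passage_start"] + "\n" + text},
        ]
        for stage, key in (("stage_1", "entity_relation"),
                           ("stage_2", "event_entity"),
                           ("stage_3", "event_relation"))
    }

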
class OutputParser:
    """Parses model outputs and extracts structured data."""

    def __init__(self):
        pass

    def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
        """Extract structured data from model outputs."""
        results = []
        for output in outputs:
            parsed_data = json_repair.loads(output)
            results.append(parsed_data)

        return results


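# Illustrative sketch (not part of the original file): json_repair tolerates the
# truncated or slightly malformed JSON that language models sometimes emit, so the
# parser above recovers a Python object where json.loads would raise. The field names
# below are placeholders, not the pipeline's actual output schema.
def _demo_json_repair() -> None:
    malformed = '[{"Head": "Alice", "Relation": "works at", "Tail": "Acme"'  # unclosed brackets
    parsed = json_repair.loads(malformed)
    print(parsed)  # e.g. [{'Head': 'Alice', 'Relation': 'works at', 'Tail': 'Acme'}]

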
class KnowledgeGraphExtractor:
    """Main class for the knowledge graph extraction pipeline."""

    def __init__(self, model: LLMGenerator, config: ProcessingConfig):
        self.config = config
        self.model = model
        self.model_name = model.model_name
        self.parser = OutputParser()

    def load_dataset(self) -> Any:
        """Load and prepare dataset."""
        data_path = Path(self.config.data_directory)
        all_files = os.listdir(data_path)

        valid_files = [
            filename for filename in all_files
            if filename.startswith(self.config.filename_pattern) and
            (filename.endswith(".json.gz") or filename.endswith(".json") or
             filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
        ]

        print(f"Found data files: {valid_files}")
        data_files = valid_files
        dataset_config = {"train": data_files}
        return load_dataset(self.config.data_directory, data_files=dataset_config["train"])

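    # Illustrative note (not part of the original file): load_dataset() expects JSON,
    # JSONL, or their .gz variants in `data_directory` whose names start with
    # `filename_pattern`. Each record should carry at least
    #   {"id": ..., "text": ..., "metadata": {"lang": ...}}
    # since those are the fields DatasetProcessor reads downstream.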
    def process_stage(self, instructions: List[List[Dict[str, str]]], stage=1) -> Tuple[List[str], List[List[Dict[str, Any]]]]:
        """Run one extraction stage (1: entity-relation, 2: event-entity, 3: event-relation)."""
        outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
        if self.config.record:
            text_outputs = [output[0] for output in outputs]
        else:
            text_outputs = outputs
        structured_data = self.parser.extract_structured_data(text_outputs)
        return outputs, structured_data

    def create_output_filename(self) -> str:
        """Create output filename with timestamp and shard info."""
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        model_name_safe = self.config.model_path.replace("/", "_")

        filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
                    f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")

        extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
        os.makedirs(extraction_dir, exist_ok=True)

        return os.path.join(extraction_dir, filename)

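    # Illustrative note (not part of the original file): with the defaults this yields
    # something like
    #   kg_extraction/meta-llama_Meta-Llama-3-8B-Instruct_en_simple_wiki_v0_output_20250101123000_1_in_1.json
    # (model path, filename pattern, timestamp, then "<shard>_in_<total shards>").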
    def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
        """Prepare result dictionary for a single sample."""
        ids, original_texts, metadata = batch_data
        (stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
        if self.config.record:
            stage1_outputs = [output[0] for output in stage1_results]
            stage1_usage = [output[1] for output in stage1_results]
            stage2_outputs = [output[0] for output in stage2_results]
            stage2_usage = [output[1] for output in stage2_results]
            stage3_outputs = [output[0] for output in stage3_results]
            stage3_usage = [output[1] for output in stage3_results]
        else:
            stage1_outputs = stage1_results
            stage2_outputs = stage2_results
            stage3_outputs = stage3_results
        result = {
            "id": ids[index],
            "metadata": metadata[index],
            "original_text": original_texts[index],
            "entity_relation_dict": entity_relations[index],
            "event_entity_relation_dict": event_entities[index],
            "event_relation_dict": event_relations[index],
            "output_stage_one": stage1_outputs[index],
            "output_stage_two": stage2_outputs[index],
            "output_stage_three": stage3_outputs[index],
        }
        if self.config.record:
            result['usage_stage_one'] = stage1_usage[index]
            result['usage_stage_two'] = stage2_usage[index]
            result['usage_stage_three'] = stage3_usage[index]

        # Handle date serialization
        if 'date_download' in result['metadata']:
            result['metadata']['date_download'] = str(result['metadata']['date_download'])

        return result

    def debug_print_result(self, result: Dict[str, Any]):
        """Print result for debugging."""
        for key, value in result.items():
            print(f"{key}: {value}")
        print("-" * 100)

    def run_extraction(self):
        """Run the complete knowledge graph extraction pipeline."""
        # Setup
        os.makedirs(os.path.join(self.config.output_directory, 'kg_extraction'), exist_ok=True)
        dataset = self.load_dataset()

        if self.config.debug_mode:
            print("Debug mode: Processing only 20 samples")

        # Create data processor and loader
        processor = DatasetProcessor(self.config)
        data_loader = CustomDataLoader(dataset["train"], processor)

        output_file = self.create_output_filename()
        print(f"Model: {self.config.model_path}")

        batch_counter = 0

        with torch.no_grad():
            with open(output_file, "w") as output_stream:
                for batch in data_loader:
                    batch_counter += 1
                    messages_dict, batch_ids, batch_texts, batch_metadata = batch

                    # Process all three stages
                    stage1_results = self.process_stage(messages_dict['stage_1'], 1)
                    stage2_results = self.process_stage(messages_dict['stage_2'], 2)
                    stage3_results = self.process_stage(messages_dict['stage_3'], 3)

                    # Combine results
                    batch_data = (batch_ids, batch_texts, batch_metadata)
                    stage_outputs = (stage1_results, stage2_results, stage3_results)

                    # Write results
                    print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
                    for i in range(len(batch_ids)):
                        result = self.prepare_result_dict(batch_data, stage_outputs, i)

                        if self.config.debug_mode:
                            self.debug_print_result(result)

                        output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
                        output_stream.flush()

    def convert_json_to_csv(self):
        """Convert the extraction JSONL output into triple node/edge CSVs and a temporary GraphML file."""
        json2csv(
            dataset=self.config.filename_pattern,
            output_dir=f"{self.config.output_directory}/triples_csv",
            data_dir=f"{self.config.output_directory}/kg_extraction"
        )
        csvs_to_temp_graphml(
            triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
            config=self.config
        )

    def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
        """Run LLM-based concept generation for triples listed in the missing-concepts CSV."""
        generate_concept(
            model=self.model,
            input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
            output_folder=f"{self.config.output_directory}/concepts",
            output_file="concept.json",
            logging_file=f"{self.config.output_directory}/concepts/logging.txt",
            config=self.config,
            batch_size=batch_size if batch_size else self.config.batch_size_concept,
            shard=self.config.current_shard_concept,
            num_shards=self.config.total_shards_concept,
            record=self.config.record,
            **kwargs
        )

    def create_concept_csv(self):
        """Merge generated concepts and produce the concept node/edge CSVs."""
        merge_csv_files(
            output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
            input_dir=f"{self.config.output_directory}/concepts",
        )
        all_concept_triples_csv_to_csv(
            node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
            edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
            concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
            output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
            output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
            output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
        )

    def convert_to_graphml(self):
        """Combine triple, text, and concept CSVs into the final GraphML knowledge graph."""
        csvs_to_graphml(
            triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
            triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
            concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
        )

    def add_numeric_id(self):
        """Add numeric id columns to the node, edge, and text CSVs."""
        add_csv_columns(
            node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
            edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
            text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
        )

    def compute_kg_embedding(self, encoder_model: BaseEmbeddingModel, batch_size: int = 2048):
        """Compute embeddings for the node, edge, and text CSVs with the given encoder."""
        encoder_model.compute_kg_embedding(
            node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
            node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
            edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
            edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
            text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
            text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
            batch_size=batch_size  # use the caller-supplied batch size instead of a hard-coded 2048
        )

    def create_faiss_index(self, index_type="HNSW,Flat"):
        """Build a FAISS index over the embeddings; delegates to the imported create_faiss_index helper."""
        create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)


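# Illustrative sketch (not part of the original file): the helpers above are meant to be
# chained roughly in this order after extraction. `model` and `encoder` are assumed to be
# pre-built LLMGenerator / BaseEmbeddingModel instances; constructing them is outside the
# scope of this script.
def _demo_full_pipeline(model: LLMGenerator, encoder: BaseEmbeddingModel, config: ProcessingConfig) -> None:
    extractor = KnowledgeGraphExtractor(model, config)
    extractor.run_extraction()               # stages 1-3: triple extraction to JSONL
    extractor.convert_json_to_csv()          # JSONL -> node/edge CSVs (+ temporary GraphML)
    extractor.generate_concept_csv_temp()    # LLM concept generation for the extracted triples
    extractor.create_concept_csv()           # merge concepts into concept node/edge CSVs
    extractor.convert_to_graphml()           # build the final GraphML knowledge graph
    extractor.add_numeric_id()               # add numeric id columns for downstream indexing
    extractor.compute_kg_embedding(encoder)  # compute node/edge/text embeddings
    extractor.create_faiss_index()           # build a FAISS index over the embeddings

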
def parse_command_line_arguments() -> ProcessingConfig:
    """Parse command line arguments and return configuration."""
    parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")

    parser.add_argument("-m", "--model", type=str, required=True,
                        default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help="Model path for knowledge extraction")
    parser.add_argument("--data_dir", type=str, default="your_data_dir",
                        help="Directory containing input data")
    parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
                        help="Filename pattern to match")
    parser.add_argument("-b", "--batch_size", type=int, default=16,
                        help="Batch size for processing")
    parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
                        help="Output directory for results")
    parser.add_argument("--total_shards_triple", type=int, default=1,
                        help="Total number of data shards")
    parser.add_argument("--shard", type=int, default=0,
                        help="Current shard index")
    parser.add_argument("--bit8", action="store_true",
                        help="Use 8-bit quantization")
    parser.add_argument("--debug", action="store_true",
                        help="Enable debug mode")
    parser.add_argument("--resume", type=int, default=0,
                        help="Resume from specific batch")

    args = parser.parse_args()

    return ProcessingConfig(
        model_path=args.model,
        data_directory=args.data_dir,
        filename_pattern=args.file_name,
        batch_size=args.batch_size,
        output_directory=args.output_dir,
        total_shards_triple=args.total_shards_triple,
        current_shard_triple=args.shard,
        use_8bit=args.bit8,
        debug_mode=args.debug,
        resume_from=args.resume
    )


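# Example invocation (script name and paths are placeholders):
#   python kg_extraction.py -m meta-llama/Meta-Llama-3-8B-Instruct \
#       --data_dir ./data --file_name en_simple_wiki_v0 -b 16 \
#       --output_dir ./generation_result --total_shards_triple 4 --shard 0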
def main():
    """Main entry point for the knowledge graph extraction pipeline."""
    config = parse_command_line_arguments()
    # NOTE: KnowledgeGraphExtractor also expects an LLMGenerator instance as its first
    # argument; construct one for config.model_path and pass it in before running.
    extractor = KnowledgeGraphExtractor(config)
    extractor.run_extraction()


if __name__ == "__main__":
    main()