#!/usr/bin/env python3
"""
Knowledge Graph Extraction Pipeline
Extracts entities, relations, and events from text data using transformer models.
"""
import re
import json
import os
import argparse
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path
import torch
from datasets import load_dataset
from tqdm import tqdm
import json_repair
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.utils.json_processing.json_to_csv import json2csv
from atlas_rag.kg_construction.concept_generation import generate_concept
from atlas_rag.kg_construction.utils.csv_processing.merge_csv import merge_csv_files
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import csvs_to_graphml, csvs_to_temp_graphml
from atlas_rag.kg_construction.concept_to_csv import all_concept_triples_csv_to_csv
from atlas_rag.kg_construction.utils.csv_processing.csv_add_numeric_id import add_csv_columns
from atlas_rag.vectorstore.embedding_model import BaseEmbeddingModel
from atlas_rag.vectorstore.create_neo4j_index import create_faiss_index
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import TRIPLE_INSTRUCTIONS
from atlas_rag.kg_construction.triple_config import ProcessingConfig
# Constants
TOKEN_LIMIT = 1024
INSTRUCTION_TOKEN_ESTIMATE = 200
CHAR_TO_TOKEN_RATIO = 3.5
class TextChunker:
"""Handles text chunking based on token limits."""
def __init__(self, max_tokens: int = TOKEN_LIMIT, instruction_tokens: int = INSTRUCTION_TOKEN_ESTIMATE):
self.max_tokens = max_tokens
self.instruction_tokens = instruction_tokens
self.char_ratio = CHAR_TO_TOKEN_RATIO
def calculate_max_chars(self) -> int:
"""Calculate maximum characters per chunk."""
available_tokens = self.max_tokens - self.instruction_tokens
return int(available_tokens * self.char_ratio)
def split_text(self, text: str) -> List[str]:
"""Split text into chunks that fit within token limits."""
max_chars = self.calculate_max_chars()
chunks = []
while len(text) > max_chars:
chunks.append(text[:max_chars])
text = text[max_chars:]
if text: # Add remaining text
chunks.append(text)
return chunks
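# Illustrative sizing (not from the source): with the default limits above,
# calculate_max_chars() returns int((1024 - 200) * 3.5) = 2884, so a
# 6,000-character document is split into chunks of roughly 2884, 2884,
# and 232 characters.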
class DatasetProcessor:
"""Processes and prepares dataset for knowledge graph extraction."""
def __init__(self, config: ProcessingConfig):
self.config = config
self.chunker = TextChunker()
def filter_language_content(self, sample: Dict[str, Any]) -> bool:
"""Check if content is in English."""
metadata = sample.get("metadata", {})
language = metadata.get("lang", "en") # Default to English if not specified
supported_languages = list(TRIPLE_INSTRUCTIONS.keys())
return language in supported_languages
def create_sample_chunks(self, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Create chunks from a single sample."""
original_text = sample.get("text", "")
if self.config.remove_doc_spaces:
original_text = re.sub(r'\s+', ' ', original_text).strip()
text_chunks = self.chunker.split_text(original_text)
chunks = []
for chunk_idx, chunk_text in enumerate(text_chunks):
chunk_data = {
"id": sample["id"],
"text": chunk_text,
"chunk_id": chunk_idx,
"metadata": sample["metadata"]
}
chunks.append(chunk_data)
return chunks
def prepare_dataset(self, raw_dataset) -> List[Dict[str, Any]]:
"""Process raw dataset into chunks suitable for processing with generalized slicing."""
processed_samples = []
total_texts = len(raw_dataset)
# Handle edge cases
if total_texts == 0:
print(f"No texts found for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
# Calculate base and remainder for fair distribution
base_texts_per_shard = total_texts // self.config.total_shards_triple
remainder = total_texts % self.config.total_shards_triple
# Calculate start index
if self.config.current_shard_triple < remainder:
start_idx = self.config.current_shard_triple * (base_texts_per_shard + 1)
else:
start_idx = remainder * (base_texts_per_shard + 1) + (self.config.current_shard_triple - remainder) * base_texts_per_shard
# Calculate end index
if self.config.current_shard_triple < remainder:
end_idx = start_idx + (base_texts_per_shard + 1)
else:
end_idx = start_idx + base_texts_per_shard
# Ensure indices are within bounds
start_idx = min(start_idx, total_texts)
end_idx = min(end_idx, total_texts)
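# Worked example (illustrative): 10 documents over 3 shards gives
# base_texts_per_shard = 3 and remainder = 1, so shard 0 covers texts 0-3,
# shard 1 covers texts 4-6, and shard 2 covers texts 7-9.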
print(f"Processing shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple} "
f"(texts {start_idx}-{end_idx-1} of {total_texts}, {end_idx - start_idx} documents)")
# Process documents in assigned shard
for idx in range(start_idx, end_idx):
sample = raw_dataset[idx]
# Filter by language
if not self.filter_language_content(sample):
print(f"Unsupported language in sample {idx}, skipping.")
continue
# Create chunks
chunks = self.create_sample_chunks(sample)
processed_samples.extend(chunks)
# Debug mode early termination
if self.config.debug_mode and len(processed_samples) >= 20:
print("Debug mode: Stopping at 20 chunks")
break
print(f"Generated {len(processed_samples)} chunks for shard {self.config.current_shard_triple+1}/{self.config.total_shards_triple}")
return processed_samples
class CustomDataLoader:
"""Custom data loader for knowledge graph extraction."""
def __init__(self, dataset, processor: DatasetProcessor):
self.raw_dataset = dataset
self.processor = processor
self.processed_data = processor.prepare_dataset(dataset)
self.stage_to_prompt_dict = {
"stage_1": "entity_relation",
"stage_2": "event_entity",
"stage_3": "event_relation"
}
def __len__(self) -> int:
return len(self.processed_data)
def create_batch_instructions(self, batch_data: List[Dict[str, Any]]) -> Dict[str, List[List[Dict[str, str]]]]:
"""Build per-stage chat messages for every item in the batch."""
messages_dict = {
'stage_1': [],
'stage_2': [],
'stage_3': []
}
for item in batch_data:
# Resolve the prompt set for the item's language, falling back to English
language = item.get("metadata", {}).get("lang", "en")
prompts = TRIPLE_INSTRUCTIONS.get(language, TRIPLE_INSTRUCTIONS["en"])
system_msg = prompts['system']
stage_1_msg = prompts['entity_relation'] + prompts['passage_start'] + '\n' + item["text"]
stage_2_msg = prompts['event_entity'] + prompts['passage_start'] + '\n' + item["text"]
stage_3_msg = prompts['event_relation'] + prompts['passage_start'] + '\n' + item["text"]
stage_one_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_1_msg}
]
stage_two_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_2_msg}
]
stage_three_message = [
{"role": "system", "content": system_msg},
{"role": "user", "content": stage_3_msg}
]
messages_dict['stage_1'].append(stage_one_message)
messages_dict['stage_2'].append(stage_two_message)
messages_dict['stage_3'].append(stage_three_message)
return messages_dict
def __iter__(self):
"""Iterate through batches."""
batch_size = self.processor.config.batch_size_triple
start_idx = self.processor.config.resume_from * batch_size
for i in tqdm(range(start_idx, len(self.processed_data), batch_size)):
batch_data = self.processed_data[i:i + batch_size]
# Prepare instructions
instructions = self.create_batch_instructions(batch_data)
# Extract batch information
batch_ids = [item["id"] for item in batch_data]
batch_metadata = [item["metadata"] for item in batch_data]
batch_texts = [item["text"] for item in batch_data]
yield instructions, batch_ids, batch_texts, batch_metadata
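# Note: resume_from counts batches, not chunks; e.g. (illustrative numbers)
# resume_from=10 with batch_size_triple=16 skips the shard's first 160 chunks
# before iteration begins.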
class OutputParser:
"""Parses model outputs and extracts structured data."""
def __init__(self):
pass
def extract_structured_data(self, outputs: List[str]) -> List[List[Dict[str, Any]]]:
"""Extract structured data from model outputs."""
results = []
for output in outputs:
parsed_data = json_repair.loads(output)
results.append(parsed_data)
return results
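# json_repair.loads tolerates malformed model output; for example (illustrative)
# a truncated string such as '{"entities": ["Paris", "France"' is repaired into
# a valid object instead of raising the way json.loads would.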
class KnowledgeGraphExtractor:
"""Main class for knowledge graph extraction pipeline."""
def __init__(self, model: LLMGenerator, config: ProcessingConfig):
self.config = config
self.model = model
self.model_name = model.model_name
self.parser = OutputParser()
def load_dataset(self) -> Any:
"""Load and prepare dataset."""
data_path = Path(self.config.data_directory)
all_files = os.listdir(data_path)
valid_files = [
filename for filename in all_files
if filename.startswith(self.config.filename_pattern) and
(filename.endswith(".json.gz") or filename.endswith(".json") or filename.endswith(".jsonl") or filename.endswith(".jsonl.gz"))
]
print(f"Found data files: {valid_files}")
data_files = valid_files
dataset_config = {"train": data_files}
return load_dataset(self.config.data_directory, data_files=dataset_config["train"])
def process_stage(self, instructions: List[List[Dict[str, str]]], stage: int = 1) -> Tuple[List[Any], List[List[Dict[str, Any]]]]:
"""Run one extraction stage (1: entity-relation, 2: event-entity, 3: event-relation)."""
outputs = self.model.triple_extraction(messages=instructions, max_tokens=self.config.max_new_tokens, stage=stage, record=self.config.record)
if self.config.record:
text_outputs = [output[0] for output in outputs]
else:
text_outputs = outputs
structured_data = self.parser.extract_structured_data(text_outputs)
return outputs, structured_data
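# When config.record is True, each element of `outputs` is a (text, usage) pair,
# which prepare_result_dict unpacks below; otherwise it is the raw text output.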
def create_output_filename(self) -> str:
"""Create output filename with timestamp and shard info."""
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
model_name_safe = self.config.model_path.replace("/", "_")
filename = (f"{model_name_safe}_{self.config.filename_pattern}_output_"
f"{timestamp}_{self.config.current_shard_triple + 1}_in_{self.config.total_shards_triple}.json")
extraction_dir = os.path.join(self.config.output_directory, "kg_extraction")
os.makedirs(extraction_dir, exist_ok=True)
return os.path.join(extraction_dir, filename)
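# Example output path (illustrative): with model_path "meta-llama/Meta-Llama-3-8B-Instruct",
# filename_pattern "en_simple_wiki_v0", and the first of 4 shards, this yields something like
# <output_directory>/kg_extraction/meta-llama_Meta-Llama-3-8B-Instruct_en_simple_wiki_v0_output_20250924092912_1_in_4.json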
def prepare_result_dict(self, batch_data: Tuple, stage_outputs: Tuple, index: int) -> Dict[str, Any]:
"""Prepare result dictionary for a single sample."""
ids, original_texts, metadata = batch_data
(stage1_results, entity_relations), (stage2_results, event_entities), (stage3_results, event_relations) = stage_outputs
if self.config.record:
stage1_outputs = [output[0] for output in stage1_results]
stage1_usage = [output[1] for output in stage1_results]
stage2_outputs = [output[0] for output in stage2_results]
stage2_usage = [output[1] for output in stage2_results]
stage3_outputs = [output[0] for output in stage3_results]
stage3_usage = [output[1] for output in stage3_results]
else:
stage1_outputs = stage1_results
stage2_outputs = stage2_results
stage3_outputs = stage3_results
result = {
"id": ids[index],
"metadata": metadata[index],
"original_text": original_texts[index],
"entity_relation_dict": entity_relations[index],
"event_entity_relation_dict": event_entities[index],
"event_relation_dict": event_relations[index],
"output_stage_one": stage1_outputs[index],
"output_stage_two": stage2_outputs[index],
"output_stage_three": stage3_outputs[index],
}
if self.config.record:
result['usage_stage_one'] = stage1_usage[index]
result['usage_stage_two'] = stage2_usage[index]
result['usage_stage_three'] = stage3_usage[index]
# Handle date serialization
if 'date_download' in result['metadata']:
result['metadata']['date_download'] = str(result['metadata']['date_download'])
return result
def debug_print_result(self, result: Dict[str, Any]):
"""Print result for debugging."""
for key, value in result.items():
print(f"{key}: {value}")
print("-" * 100)
def run_extraction(self):
"""Run the complete knowledge graph extraction pipeline."""
# Setup
os.makedirs(self.config.output_directory+'/kg_extraction', exist_ok=True)
dataset = self.load_dataset()
if self.config.debug_mode:
print("Debug mode: Processing only 20 samples")
# Create data processor and loader
processor = DatasetProcessor(self.config)
data_loader = CustomDataLoader(dataset["train"], processor)
output_file = self.create_output_filename()
print(f"Model: {self.config.model_path}")
batch_counter = 0
with torch.no_grad():
with open(output_file, "w") as output_stream:
for batch in data_loader:
batch_counter += 1
messages_dict, batch_ids, batch_texts, batch_metadata = batch
# Process all three stages
stage1_results = self.process_stage(messages_dict['stage_1'],1)
stage2_results = self.process_stage(messages_dict['stage_2'],2)
stage3_results = self.process_stage(messages_dict['stage_3'],3)
# Combine results
batch_data = (batch_ids, batch_texts, batch_metadata)
stage_outputs = (stage1_results, stage2_results, stage3_results)
# Write results
print(f"Processed {batch_counter} batches ({batch_counter * self.config.batch_size_triple} chunks)")
for i in range(len(batch_ids)):
result = self.prepare_result_dict(batch_data, stage_outputs, i)
if self.config.debug_mode:
self.debug_print_result(result)
output_stream.write(json.dumps(result, ensure_ascii=False) + "\n")
output_stream.flush()
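# The output file is JSONL: one JSON object per chunk, with the fields built in
# prepare_result_dict (extracted triples, raw stage outputs, and optional usage).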
def convert_json_to_csv(self):
json2csv(
dataset = self.config.filename_pattern,
output_dir=f"{self.config.output_directory}/triples_csv",
data_dir=f"{self.config.output_directory}/kg_extraction"
)
csvs_to_temp_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
triple_edge_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv",
config = self.config
)
def generate_concept_csv_temp(self, batch_size: int = None, **kwargs):
generate_concept(
model=self.model,
input_file=f"{self.config.output_directory}/triples_csv/missing_concepts_{self.config.filename_pattern}_from_json.csv",
output_folder=f"{self.config.output_directory}/concepts",
output_file="concept.json",
logging_file=f"{self.config.output_directory}/concepts/logging.txt",
config=self.config,
batch_size=batch_size if batch_size else self.config.batch_size_concept,
shard=self.config.current_shard_concept,
num_shards=self.config.total_shards_concept,
record = self.config.record,
**kwargs
)
def create_concept_csv(self):
merge_csv_files(
output_file=f"{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv",
input_dir=f"{self.config.output_directory}/concepts",
)
all_concept_triples_csv_to_csv(
node_file=f'{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv',
edge_file=f'{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb.csv',
concepts_file=f'{self.config.output_directory}/triples_csv/{self.config.filename_pattern}_from_json_with_concept.csv',
output_node_file=f'{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv',
output_edge_file=f'{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
output_full_concept_triple_edges=f'{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv',
)
def convert_to_graphml(self):
csvs_to_graphml(
triple_node_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
text_node_file=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
concept_node_file=f"{self.config.output_directory}/concept_csv/concept_nodes_{self.config.filename_pattern}_from_json_with_concept.csv",
triple_edge_file=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_edge_file=f"{self.config.output_directory}/triples_csv/text_edges_{self.config.filename_pattern}_from_json.csv",
concept_edge_file=f"{self.config.output_directory}/concept_csv/concept_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
output_file=f"{self.config.output_directory}/kg_graphml/{self.config.filename_pattern}_graph.graphml",
)
def add_numeric_id(self):
add_csv_columns(
node_csv=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
edge_csv=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
text_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
node_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
edge_with_numeric_id=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_without_emb_with_numeric_id.csv",
text_with_numeric_id=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_numeric_id.csv",
)
def compute_kg_embedding(self, encoder_model:BaseEmbeddingModel, batch_size: int = 2048):
encoder_model.compute_kg_embedding(
node_csv_without_emb=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_without_emb.csv",
node_csv_file=f"{self.config.output_directory}/triples_csv/triple_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
edge_csv_without_emb=f"{self.config.output_directory}/concept_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept.csv",
edge_csv_file=f"{self.config.output_directory}/triples_csv/triple_edges_{self.config.filename_pattern}_from_json_with_concept_with_emb.csv",
text_node_csv_without_emb=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json.csv",
text_node_csv=f"{self.config.output_directory}/triples_csv/text_nodes_{self.config.filename_pattern}_from_json_with_emb.csv",
batch_size=batch_size  # honor the method's batch_size argument rather than hard-coding 2048
)
def create_faiss_index(self, index_type="HNSW,Flat"):
create_faiss_index(self.config.output_directory, self.config.filename_pattern, index_type)
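# Typical end-to-end order for these helpers (a sketch, not enforced by the class):
#   run_extraction()            -> JSONL triples under <output_dir>/kg_extraction
#   convert_json_to_csv()       -> node/edge CSVs under <output_dir>/triples_csv
#   generate_concept_csv_temp() -> concept JSON under <output_dir>/concepts
#   create_concept_csv()        -> merged concept CSVs under <output_dir>/concept_csv
#   convert_to_graphml()        -> <output_dir>/kg_graphml/<pattern>_graph.graphml
#   add_numeric_id()            -> numeric-id CSVs for indexing
#   compute_kg_embedding(...)   -> node/edge/text CSVs with embeddings
#   create_faiss_index(...)     -> FAISS index for retrieval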
def parse_command_line_arguments() -> ProcessingConfig:
"""Parse command line arguments and return configuration."""
parser = argparse.ArgumentParser(description="Knowledge Graph Extraction Pipeline")
parser.add_argument("-m", "--model", type=str, required=True,
default="meta-llama/Meta-Llama-3-8B-Instruct",
help="Model path for knowledge extraction")
parser.add_argument("--data_dir", type=str, default="your_data_dir",
help="Directory containing input data")
parser.add_argument("--file_name", type=str, default="en_simple_wiki_v0",
help="Filename pattern to match")
parser.add_argument("-b", "--batch_size", type=int, default=16,
help="Batch size for processing")
parser.add_argument("--output_dir", type=str, default="./generation_result_debug",
help="Output directory for results")
parser.add_argument("--total_shards_triple", type=int, default=1,
help="Total number of data shards")
parser.add_argument("--shard", type=int, default=0,
help="Current shard index")
parser.add_argument("--bit8", action="store_true",
help="Use 8-bit quantization")
parser.add_argument("--debug", action="store_true",
help="Enable debug mode")
parser.add_argument("--resume", type=int, default=0,
help="Resume from specific batch")
args = parser.parse_args()
return ProcessingConfig(
model_path=args.model,
data_directory=args.data_dir,
filename_pattern=args.file_name,
batch_size=args.batch_size,
output_directory=args.output_dir,
total_shards_triple=args.total_shards_triple,
current_shard_triple=args.shard,
use_8bit=args.bit8,
debug_mode=args.debug,
resume_from=args.resume
)
def main():
"""Main entry point for the knowledge graph extraction pipeline."""
config = parse_command_line_arguments()
# NOTE: KnowledgeGraphExtractor takes (model, config); an LLMGenerator instance
# for config.model_path must be constructed and passed in alongside the config.
extractor = KnowledgeGraphExtractor(config)
extractor.run_extraction()
if __name__ == "__main__":
main()
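# Example command-line invocation (illustrative; paths are placeholders, and see
# the NOTE in main() about supplying an LLMGenerator):
#   python triple_extraction.py \
#       -m meta-llama/Meta-Llama-3-8B-Instruct \
#       --data_dir ./data --file_name en_simple_wiki_v0 \
#       -b 16 --output_dir ./generation_result_debug \
#       --total_shards_triple 4 --shard 0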