from tqdm import tqdm
import random
import logging
import csv
import os
import hashlib
import re
import pickle

from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.kg_construction.triple_config import ProcessingConfig
from atlas_rag.kg_construction.utils.csv_processing.csv_to_graphml import get_node_id
from atlas_rag.llm_generator.prompt.triple_extraction_prompt import CONCEPT_INSTRUCTIONS

# Increase the CSV field size limit to 10 MB so large triple rows can be read.
csv.field_size_limit(10 * 1024 * 1024)


def build_batch_data(sessions, batch_size):
    """Split a list of sessions into consecutive batches of at most batch_size items."""
    batched_sessions = []
    for i in range(0, len(sessions), batch_size):
        batched_sessions.append(sessions[i:i + batch_size])
    return batched_sessions


def compute_hash_id(text):
    """Compute a stable SHA-256 hash ID for a piece of text, returned as a hex string."""
    hash_object = hashlib.sha256(text.encode('utf-8'))
    return hash_object.hexdigest()


def convert_attribute(value):
    """Convert attributes to GDS-compatible types."""
    if isinstance(value, list):
        return [str(v) for v in value]
    elif isinstance(value, (int, float)):
        return value
    else:
        return str(value)


def clean_text(text):
    """Normalize whitespace, strip control characters (including NUL), and replace semicolons."""
    new_text = (
        text.replace("\n", " ")
        .replace("\r", " ")
        .replace("\t", " ")
        .replace("\v", " ")
        .replace("\f", " ")
        .replace("\b", " ")
        .replace("\a", " ")
        .replace("\x1b", " ")  # ESC character; "\e" is not a valid Python escape sequence
        .replace(";", ",")
    )
    new_text = new_text.replace("\x00", "")
    new_text = re.sub(r'\s+', ' ', new_text).strip()
    return new_text


def remove_NUL(text):
    """Remove NUL characters from text."""
    return text.replace("\x00", "")


def build_batched_events(all_node_list, batch_size):
    """Node types are one of: Entity, Event, Relation."""
    event_nodes = [node[0] for node in all_node_list if node[1].lower() == "event"]
    batched_events = []
    for i in range(0, len(event_nodes), batch_size):
        batched_events.append(event_nodes[i:i + batch_size])
    return batched_events


def build_batched_entities(all_node_list, batch_size):
    entity_nodes = [node[0] for node in all_node_list if node[1].lower() == "entity"]
    batched_entities = []
    for i in range(0, len(entity_nodes), batch_size):
        batched_entities.append(entity_nodes[i:i + batch_size])
    return batched_entities


def build_batched_relations(all_node_list, batch_size):
    relations = [node[0] for node in all_node_list if node[1].lower() == "relation"]
    batched_relations = []
    for i in range(0, len(relations), batch_size):
        batched_relations.append(relations[i:i + batch_size])
    return batched_relations


def batched_inference(model: LLMGenerator, inputs, record=False, **kwargs):
    """Run the LLM on a batch of prompts and split each response into lowercase, comma-separated concepts."""
    responses = model.generate_response(inputs, return_text_only=not record, **kwargs)
    if record:
        text_responses = [response[0] for response in responses]
        usages = [response[1] for response in responses]
    else:
        text_responses = responses
    answers = []
    for answer in text_responses:
        answers.append([x.strip().lower() for x in answer.split(",")])
    if record:
        return answers, usages
    return answers


def load_data_with_shard(input_file, shard_idx, num_shards):
    """Read the CSV (skipping the header), shuffle the rows, and return the slice for this shard."""
    with open(input_file, "r") as f:
        csv_reader = list(csv.reader(f))
    data = csv_reader[1:]  # skip the header row

    # Randomly shuffle the data before splitting it into shards.
    # NOTE: the shuffle is unseeded, so separate shard processes may see overlapping rows;
    # set a fixed seed beforehand if strictly disjoint shards are required.
    random.shuffle(data)

    total_lines = len(data)
    lines_per_shard = (total_lines + num_shards - 1) // num_shards
    start_idx = shard_idx * lines_per_shard
    end_idx = min((shard_idx + 1) * lines_per_shard, total_lines)

    return data[start_idx:end_idx]
def generate_concept(model: LLMGenerator, input_file='processed_data/triples_csv',
                     output_folder='processed_data/triples_conceptualized', output_file='output.json',
                     logging_file='processed_data/logging.txt', config: ProcessingConfig = None,
                     batch_size=32, shard=0, num_shards=1, **kwargs):
    """Conceptualize missing event, entity, and relation nodes with the LLM and write one CSV per shard."""
    if config is None:
        raise ValueError("generate_concept requires a ProcessingConfig instance.")

    # Make sure the logging directory and file exist.
    log_dir = os.path.dirname(logging_file)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    if not os.path.exists(logging_file):
        open(logging_file, 'w').close()

    language = kwargs.get('language', 'en')
    record = kwargs.get('record', False)

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(logging_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logging.getLogger().addHandler(file_handler)

    # Load the knowledge graph built before conceptualization; it supplies neighbor context for entity prompts.
    with open(f"{config.output_directory}/kg_graphml/{config.filename_pattern}_without_concept.pkl", "rb") as f:
        temp_kg = pickle.load(f)

    # Read this shard's missing nodes and split them into batches per node type.
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    all_missing_nodes = load_data_with_shard(input_file, shard_idx=shard, num_shards=num_shards)

    batched_events = build_batched_events(all_missing_nodes, batch_size)
    batched_entities = build_batched_entities(all_missing_nodes, batch_size)
    batched_relations = build_batched_relations(all_missing_nodes, batch_size)

    all_batches = []
    all_batches.extend(('event', batch) for batch in batched_events)
    all_batches.extend(('entity', batch) for batch in batched_entities)
    all_batches.extend(('relation', batch) for batch in batched_relations)
    print("all_batches", len(all_batches))

    output_file = os.path.join(output_folder, f"{output_file.rsplit('.', 1)[0]}_shard_{shard}.csv")
    with open(output_file, "w", newline='') as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(["node", "conceptualized_node", "node_type"])

        for batch_type, batch in tqdm(all_batches, desc=f"Shard_{shard}"):
            # Pick the prompt template and placeholder tokens for this node type.
            replace_context_token = None
            if batch_type == 'event':
                template = CONCEPT_INSTRUCTIONS[language]['event']
                node_type = 'event'
                replace_token = '[EVENT]'
            elif batch_type == 'entity':
                template = CONCEPT_INSTRUCTIONS[language]['entity']
                node_type = 'entity'
                replace_token = '[ENTITY]'
                replace_context_token = '[CONTEXT]'
            elif batch_type == 'relation':
                template = CONCEPT_INSTRUCTIONS[language]['relation']
                node_type = 'relation'
                replace_token = '[RELATION]'

            inputs = []
            for node in batch:
                # For entities, sample neighboring edges from the KG to fill the context token.
                if replace_context_token:
                    node_id = get_node_id(node)
                    entity_predecessors = list(temp_kg.predecessors(node_id))
                    entity_successors = list(temp_kg.successors(node_id))
                    context = ""
                    if len(entity_predecessors) > 0:
                        # Sample one incoming neighbor and describe it as "<neighbor> <relation>".
                        sampled_neighbors = random.sample(entity_predecessors, min(1, len(entity_predecessors)))
                        context += ", ".join(
                            f"{temp_kg.nodes[neighbor]['id']} {temp_kg[neighbor][node_id]['relation']}"
                            for neighbor in sampled_neighbors
                        )
                    if len(entity_successors) > 0:
                        # Sample one outgoing neighbor and describe it as "<relation> <neighbor>".
                        sampled_neighbors = random.sample(entity_successors, min(1, len(entity_successors)))
                        context += ", ".join(
                            f"{temp_kg[node_id][neighbor]['relation']} {temp_kg.nodes[neighbor]['id']}"
                            for neighbor in sampled_neighbors
                        )
                    prompt = template.replace(replace_token, node).replace(replace_context_token, context)
                else:
                    prompt = template.replace(replace_token, node)

                constructed_input = [
                    {"role": "system", "content": "You are a helpful AI assistant."},
                    {"role": "user", "content": prompt},
                ]
                inputs.append(constructed_input)

            try:
                if record:
                    # When recording, batched_inference also returns per-request completion usage.
                    answers, usages = batched_inference(model, inputs, record=record, max_workers=config.max_workers)
                else:
                    answers = batched_inference(model, inputs, record=record, max_workers=config.max_workers)
                    usages = None
            except Exception as e:
                logging.error(f"Error processing {batch_type} batch: {e}")
                raise e

            for i, (node, answer) in enumerate(zip(batch, answers)):
                if usages is not None:
                    logging.info(f"Usage log: Node {node}, completion_usage: {usages[i]}")
                csv_writer.writerow([node, ", ".join(answer), node_type])
            file.flush()

    # Count unique conceptualized nodes per node type.
    conceptualized_nodes = []
    conceptualized_events = []
    conceptualized_entities = []
    conceptualized_relations = []
    with open(output_file, "r") as file:
        reader = csv.reader(file)
        next(reader)  # skip the header row
        for row in reader:
            conceptualized_nodes.extend(row[1].split(","))
            if row[2] == "event":
                conceptualized_events.extend(row[1].split(","))
            elif row[2] == "entity":
                conceptualized_entities.extend(row[1].split(","))
            elif row[2] == "relation":
                conceptualized_relations.extend(row[1].split(","))

    conceptualized_nodes = [x.strip() for x in conceptualized_nodes]
    conceptualized_events = [x.strip() for x in conceptualized_events]
    conceptualized_entities = [x.strip() for x in conceptualized_entities]
    conceptualized_relations = [x.strip() for x in conceptualized_relations]

    unique_conceptualized_nodes = list(set(conceptualized_nodes))
    unique_conceptualized_events = list(set(conceptualized_events))
    unique_conceptualized_entities = list(set(conceptualized_entities))
    unique_conceptualized_relations = list(set(conceptualized_relations))

    print(f"Number of unique conceptualized nodes: {len(unique_conceptualized_nodes)}")
    print(f"Number of unique conceptualized events: {len(unique_conceptualized_events)}")
    print(f"Number of unique conceptualized entities: {len(unique_conceptualized_entities)}")
    print(f"Number of unique conceptualized relations: {len(unique_conceptualized_relations)}")
    return
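

# A minimal smoke-test sketch of the batching helpers above. The toy node tuples are made up for
# illustration; only their first two fields (name, type) mirror the rows that load_data_with_shard
# yields. Running the full pipeline instead requires an already-constructed LLMGenerator and a
# populated ProcessingConfig (output_directory, filename_pattern, max_workers), roughly:
#   generate_concept(model, input_file=..., output_folder=..., output_file=..., config=config,
#                    batch_size=32, shard=0, num_shards=1)
if __name__ == "__main__":
    toy_nodes = [
        ("alice", "Entity"),
        ("bob", "Entity"),
        ("alice met bob", "Event"),
        ("met", "Relation"),
    ]
    print(build_batched_entities(toy_nodes, batch_size=2))   # [['alice', 'bob']]
    print(build_batched_events(toy_nodes, batch_size=2))     # [['alice met bob']]
    print(build_batched_relations(toy_nodes, batch_size=2))  # [['met']]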