AIEC-RAG/atlas_rag/llm_generator/llm_generator_legacy.py

import json
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
import time
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
from atlas_rag.llm_generator.format.validate_json_output import validate_filter_output, messages as filter_messages
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, validate_keyword_output, keyword_filtering_prompt
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
from atlas_rag.llm_generator.format.validate_json_output import fix_and_validate_response
from transformers.pipelines import Pipeline
import jsonschema
from typing import Union
from logging import Logger
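# Maps triple-extraction stages to the prompt type used to validate model output.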
stage_to_prompt_type = {
1: "entity_relation",
2: "event_entity",
3: "event_relation",
}
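# Shared retry policy: give up after 120 s or 5 attempts; exponential backoff (2-30 s) plus up to 2 s of random jitter.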
retry_decorator = retry(
stop=(stop_after_delay(120) | stop_after_attempt(5)),
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
)
class LLMGenerator:
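    """Unified generation wrapper over an OpenAI-compatible client or a Hugging Face transformers Pipeline."""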
def __init__(self, client, model_name):
self.model_name = model_name
self.client: OpenAI | Pipeline = client
if isinstance(client, (OpenAI, AzureOpenAI)):
self.inference_type = "openai"
elif isinstance(client, Pipeline):
self.inference_type = "pipeline"
else:
raise ValueError("Unsupported client type6Please provide either an OpenAI client or a Huggingface Pipeline Object.")
@retry_decorator
def _generate_response(self, messages, do_sample=True, max_new_tokens=8192, temperature=0.7,
frequency_penalty=None, response_format={"type": "text"}, return_text_only=True,
return_thinking=False, reasoning_effort=None):
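        """Run one chat completion; returns text, or (text, usage dict with latency) when return_text_only is False."""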
if temperature == 0.0:
do_sample = False
if self.inference_type == "openai":
start_time = time.time()
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_new_tokens,
temperature=temperature,
frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
response_format=response_format if response_format is not None else {"type": "text"},
timeout=120,
reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
)
time_cost = time.time() - start_time
            content = response.choices[0].message.content
            # Some reasoning models return their text in `reasoning_content` instead of `content`.
            if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
                content = response.choices[0].message.reasoning_content
if '</think>' in content and not return_thinking:
content = content.split('</think>')[-1].strip()
            elif hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None:
                content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
if return_text_only:
return content
else:
completion_usage_dict = response.usage.model_dump()
completion_usage_dict['time'] = time_cost
return content, completion_usage_dict
elif self.inference_type == "pipeline":
start_time = time.time()
if hasattr(self.client, 'tokenizer'):
input_text = self.client.tokenizer.apply_chat_template(messages, tokenize=False)
else:
input_text = "\n".join([msg["content"] for msg in messages])
response = self.client(
input_text,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
)
time_cost = time.time() - start_time
content = response[0]['generated_text'].strip()
if '</think>' in content and not return_thinking:
content = content.split('</think>')[-1].strip()
if return_text_only:
return content
else:
token_count = len(content.split())
completion_usage_dict = {
'completion_tokens': token_count,
'time': time_cost
}
return content, completion_usage_dict
def _generate_batch_responses(self, batch_messages, do_sample=True, max_new_tokens=8192,
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
return_text_only=True, return_thinking=False, reasoning_effort=None):
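        """Dispatch a batch of message lists: a thread pool for OpenAI clients, native batching for pipelines."""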
if self.inference_type == "openai":
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(
self._generate_response, messages, do_sample, max_new_tokens, temperature,
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort
) for messages in batch_messages
]
results = [future.result() for future in futures]
return results
elif self.inference_type == "pipeline":
if not hasattr(self.client, 'tokenizer'):
raise ValueError("Pipeline must have a tokenizer for batch processing.")
batch_inputs = [self.client.tokenizer.apply_chat_template(messages, tokenize=False) for messages in batch_messages]
start_time = time.time()
responses = self.client(
batch_inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=do_sample,
)
time_cost = time.time() - start_time
contents = [resp['generated_text'].strip() for resp in responses]
if not return_thinking:
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
if return_text_only:
return contents
else:
usage_dicts = [{
'completion_tokens': len(content.split()),
'time': time_cost / len(batch_messages)
} for content in contents]
return list(zip(contents, usage_dicts))
def generate_cot(self, questions, max_new_tokens=1024):
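        """Generate a chain-of-thought answer without retrieved context; accepts a single question or a list."""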
if isinstance(questions, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
{"role": "user", "content": questions}]
return self._generate_response(messages, max_new_tokens=max_new_tokens)
elif isinstance(questions, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
{"role": "user", "content": q}] for q in questions]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
if isinstance(question, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction)},
{"role": "user", "content": f"{context}\n\n{question}\nThought:"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction)},
{"role": "user", "content": f"{context}\n\n{q}\nThought:"}] for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
if isinstance(question, str):
messages = deepcopy(prompt_template)
messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [deepcopy(prompt_template) + [{"role": "user", "content": f"{context}\n\nQuestions:{q}\nThought:"}]
for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
if isinstance(question, str):
messages = [{"role": "system", "content": "".join(cot_system_instruction_kg)},
{"role": "user", "content": f"{context}\n\n{question}"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_kg)},
{"role": "user", "content": f"{context}\n\n{q}"}] for q in question]
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
@retry_decorator
def filter_triples_with_entity(self, question, nodes, max_new_tokens=1024):
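        """Filter candidate nodes by relevance to the question; returns the unfiltered input if parsing fails."""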
if isinstance(question, str):
messages = [{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{question} \n Output Before Filter: {nodes} \n Output After Filter:"}]
try:
response = json.loads(self._generate_response(messages, max_new_tokens=max_new_tokens))
return response
except Exception:
return json.loads(nodes)
elif isinstance(question, list):
batch_messages = [[{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{q} \n Output Before Filter: {nodes} \n Output After Filter:"}]
for q in question]
            responses = self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
            # Parse each response independently; fall back to the unfiltered nodes on invalid JSON.
            results = []
            for resp in responses:
                try:
                    results.append(json.loads(resp))
                except Exception:
                    results.append(json.loads(nodes))
            return results
@retry_decorator
def filter_triples_with_entity_event(self, question, triples):
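        """Filter candidate facts against the question using a JSON-validated response; returns [] on failure."""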
if isinstance(question, str):
messages = deepcopy(filter_messages)
            messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## fact_before_filter ## ]]\n{triples}"})
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
cleaned_data = validate_filter_output(response)
return cleaned_data['fact']
except Exception:
return []
elif isinstance(question, list):
            batch_messages = [deepcopy(filter_messages) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## fact_before_filter ## ]]\n{triples}"}]
for q in question]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
            # Validate each response independently; keep [] for any that fail, matching the single-question path.
            results = []
            for resp in responses:
                try:
                    results.append(validate_filter_output(resp)['fact'])
                except Exception:
                    results.append([])
            return results
def generate_with_custom_messages(self, custom_messages, do_sample=True, max_new_tokens=1024, temperature=0.8, frequency_penalty=None):
if isinstance(custom_messages[0], dict):
return self._generate_response(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
elif isinstance(custom_messages[0], list):
return self._generate_batch_responses(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_filter_keywords_with_entity(self, question, keywords):
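        """Filter keywords for large-KG retrieval relevance; falls back to the unfiltered keywords on error."""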
if isinstance(question, str):
messages = deepcopy(keyword_filtering_prompt)
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}"})
try:
response = self._generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return keywords
elif isinstance(question, list):
batch_messages = [deepcopy(keyword_filtering_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## keywords_before_filter ## ]]\n{k}"}]
for q, k in zip(question, keywords)]
responses = self._generate_batch_responses(batch_messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
            # Fall back to each question's own unfiltered keywords on validation failure.
            results = []
            for resp, k in zip(responses, keywords):
                try:
                    results.append(validate_keyword_output(resp)['keywords'])
                except Exception:
                    results.append(k)
            return results
def ner(self, text):
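        """Extract named entities from a string or a batch of strings."""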
if isinstance(text, str):
messages = [{"role": "system", "content": "Please extract the entities..."},
{"role": "user", "content": f"Extract the named entities from: {text}"}]
return self._generate_response(messages)
elif isinstance(text, list):
batch_messages = [[{"role": "system", "content": "Please extract the entities..."},
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
return self._generate_batch_responses(batch_messages)
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_ner(self, text):
if isinstance(text, str):
messages = deepcopy(ner_prompt)
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{text}"})
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return []
elif isinstance(text, list):
batch_messages = [deepcopy(ner_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{t}"}] for t in text]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
            # Validate each response independently; keep [] for any that fail.
            results = []
            for resp in responses:
                try:
                    results.append(validate_keyword_output(resp)['keywords'])
                except Exception:
                    results.append([])
            return results
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
def large_kg_tog_ner(self, text):
if isinstance(text, str):
messages = [{"role": "system", "content": "You are an advanced AI assistant..."},
{"role": "user", "content": f"Extract the named entities from: {text}"}]
try:
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
cleaned_data = validate_keyword_output(response)
return cleaned_data['keywords']
except Exception:
return []
elif isinstance(text, list):
batch_messages = [[{"role": "system", "content": "You are an advanced AI assistant..."},
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
            # Validate each response independently; keep [] for any that fail.
            results = []
            for resp in responses:
                try:
                    results.append(validate_keyword_output(resp)['keywords'])
                except Exception:
                    results.append([])
            return results
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
        # Single-input implementation because the process is iterative; batching is not applicable here.
react_system_instruction = (
'You are an advanced AI assistant that uses the ReAct framework...'
)
full_context = []
if search_history:
for i, (thought, action, observation) in enumerate(search_history):
full_context.append(f"\nPrevious search attempt {i}:\n{action}\n Result: {observation}\n")
if context:
full_context.append(f"Current Retrieved Context:\n{context}\n")
messages = [{"role": "system", "content": react_system_instruction},
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
if full_context else f"Question: {question}"}]
return self._generate_response(messages, max_new_tokens=max_new_tokens)
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'],
max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
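        """ReAct-style loop: retrieve, reason, and optionally search again until an answer or max_iterations is reached."""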
# Single-input iterative process; batching not applicable
search_history = []
if isinstance(retriever, BaseEdgeRetriever):
initial_context, _ = retriever.retrieve(question, topN=5)
current_context = ". ".join(initial_context)
        elif isinstance(retriever, BasePassageRetriever):
            initial_context, _ = retriever.retrieve(question, topN=5)
            current_context = "\n".join(initial_context)
        else:
            raise ValueError("Unsupported retriever type: expected BaseEdgeRetriever or BasePassageRetriever.")
for iteration in range(max_iterations):
analysis_response = self.generate_with_react(
question=question, context=current_context, max_new_tokens=max_new_tokens, search_history=search_history, logger=logger
)
try:
thought = analysis_response.split("Thought:")[1].split("\n")[0]
action = analysis_response.split("Action:")[1].split("\n")[0]
answer = analysis_response.split("Answer:")[1].strip()
if answer.lower() != "need more information":
search_history.append((thought, action, "Using current context"))
return answer, search_history
if "search" in action.lower():
search_query = action.split("search for")[-1].strip()
if isinstance(retriever, BaseEdgeRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
current_contexts = current_context.split(". ")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = ". ".join(new_context)
elif isinstance(retriever, BasePassageRetriever):
new_context, _ = retriever.retrieve(search_query, topN=3)
current_contexts = current_context.split("\n")
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
new_context = "\n".join(new_context)
observation = f"Found information: {new_context}" if new_context else "No new information found..."
search_history.append((thought, action, observation))
if new_context:
current_context = f"{current_context}\n{new_context}"
else:
search_history.append((thought, action, "No action taken but answer not found"))
return "Unable to find answer", search_history
            except Exception:
                # Response did not follow the Thought/Action/Answer format; return it as-is.
                return analysis_response, search_history
        # Max iterations exhausted; return the last parsed answer ("need more information").
        return answer, search_history
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False):
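        """Run triple extraction over one or many message lists; validates per-stage output and substitutes "[]" on failure."""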
if isinstance(messages[0], dict):
messages = [messages]
responses = self._generate_batch_responses(
batch_messages=messages,
max_new_tokens=max_tokens,
temperature=0.0,
do_sample=False,
frequency_penalty=0.5,
reasoning_effort="none",
return_text_only=not record
)
processed_responses = []
for response in responses:
if record:
content, usage_dict = response
else:
content = response
usage_dict = None
try:
prompt_type = stage_to_prompt_type.get(stage, None)
if prompt_type:
corrected, error = fix_and_validate_response(content, prompt_type)
if error:
raise ValueError(f"Validation failed for prompt_type '{prompt_type}'")
else:
corrected = content
if corrected and corrected.strip():
if record:
processed_responses.append((corrected, usage_dict))
else:
processed_responses.append(corrected)
else:
raise ValueError("Invalid response")
except Exception as e:
print(f"Failed to process response: {str(e)}")
if record:
usage_dict = {'completion_tokens': 0, 'total_tokens': 0, 'time': 0}
processed_responses.append(("[]", usage_dict))
else:
processed_responses.append("[]")
return processed_responses
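

# Illustrative usage sketch (an assumption, not part of the original module).
# The API key, base_url, and model name below are placeholders; substitute your
# own OpenAI-compatible endpoint and deployment before running.
if __name__ == "__main__":
    client = OpenAI(api_key="sk-...", base_url="https://api.openai.com/v1")  # placeholder credentials
    generator = LLMGenerator(client, model_name="gpt-4o-mini")  # placeholder model name
    print(generator.generate_cot("What is retrieval-augmented generation?"))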