381 lines
22 KiB
Python
381 lines
22 KiB
Python
import json
|
||
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
|
||
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
|
||
from copy import deepcopy
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import time
|
||
|
||
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
|
||
from atlas_rag.llm_generator.format.validate_json_output import validate_filter_output, messages as filter_messages
|
||
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, validate_keyword_output, keyword_filtering_prompt
|
||
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
|
||
from atlas_rag.llm_generator.format.validate_json_output import fix_and_validate_response
|
||
|
||
from transformers.pipelines import Pipeline
|
||
import jsonschema
|
||
from typing import Union
|
||
from logging import Logger
|
||
|
||
stage_to_prompt_type = {
|
||
1: "entity_relation",
|
||
2: "event_entity",
|
||
3: "event_relation",
|
||
}
|
||
retry_decorator = retry(
|
||
stop=(stop_after_delay(120) | stop_after_attempt(5)),
|
||
wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
|
||
)
|
||
|
||
class LLMGenerator:
|
||
def __init__(self, client, model_name):
|
||
self.model_name = model_name
|
||
self.client: OpenAI | Pipeline = client
|
||
if isinstance(client, (OpenAI, AzureOpenAI)):
|
||
self.inference_type = "openai"
|
||
elif isinstance(client, Pipeline):
|
||
self.inference_type = "pipeline"
|
||
else:
|
||
raise ValueError("Unsupported client type6Please provide either an OpenAI client or a Huggingface Pipeline Object.")
|
||
|
||
@retry_decorator
|
||
def _generate_response(self, messages, do_sample=True, max_new_tokens=8192, temperature=0.7,
|
||
frequency_penalty=None, response_format={"type": "text"}, return_text_only=True,
|
||
return_thinking=False, reasoning_effort=None):
|
||
if temperature == 0.0:
|
||
do_sample = False
|
||
if self.inference_type == "openai":
|
||
start_time = time.time()
|
||
response = self.client.chat.completions.create(
|
||
model=self.model_name,
|
||
messages=messages,
|
||
max_tokens=max_new_tokens,
|
||
temperature=temperature,
|
||
frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
|
||
response_format=response_format if response_format is not None else {"type": "text"},
|
||
timeout=120,
|
||
reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
|
||
)
|
||
time_cost = time.time() - start_time
|
||
content = response.choices[0].message.content
|
||
if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
|
||
content = response.choices[0].message.reasoning_content
|
||
else:
|
||
content = response.choices[0].message.content
|
||
if '</think>' in content and not return_thinking:
|
||
content = content.split('</think>')[-1].strip()
|
||
else:
|
||
if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None:
|
||
content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
|
||
if return_text_only:
|
||
return content
|
||
else:
|
||
completion_usage_dict = response.usage.model_dump()
|
||
completion_usage_dict['time'] = time_cost
|
||
return content, completion_usage_dict
|
||
elif self.inference_type == "pipeline":
|
||
start_time = time.time()
|
||
if hasattr(self.client, 'tokenizer'):
|
||
input_text = self.client.tokenizer.apply_chat_template(messages, tokenize=False)
|
||
else:
|
||
input_text = "\n".join([msg["content"] for msg in messages])
|
||
response = self.client(
|
||
input_text,
|
||
max_new_tokens=max_new_tokens,
|
||
temperature=temperature,
|
||
do_sample=do_sample,
|
||
)
|
||
time_cost = time.time() - start_time
|
||
content = response[0]['generated_text'].strip()
|
||
if '</think>' in content and not return_thinking:
|
||
content = content.split('</think>')[-1].strip()
|
||
if return_text_only:
|
||
return content
|
||
else:
|
||
token_count = len(content.split())
|
||
completion_usage_dict = {
|
||
'completion_tokens': token_count,
|
||
'time': time_cost
|
||
}
|
||
return content, completion_usage_dict
|
||
|
||
def _generate_batch_responses(self, batch_messages, do_sample=True, max_new_tokens=8192,
|
||
temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
|
||
return_text_only=True, return_thinking=False, reasoning_effort=None):
|
||
if self.inference_type == "openai":
|
||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||
futures = [
|
||
executor.submit(
|
||
self._generate_response, messages, do_sample, max_new_tokens, temperature,
|
||
frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort
|
||
) for messages in batch_messages
|
||
]
|
||
results = [future.result() for future in futures]
|
||
return results
|
||
elif self.inference_type == "pipeline":
|
||
if not hasattr(self.client, 'tokenizer'):
|
||
raise ValueError("Pipeline must have a tokenizer for batch processing.")
|
||
batch_inputs = [self.client.tokenizer.apply_chat_template(messages, tokenize=False) for messages in batch_messages]
|
||
start_time = time.time()
|
||
responses = self.client(
|
||
batch_inputs,
|
||
max_new_tokens=max_new_tokens,
|
||
temperature=temperature,
|
||
do_sample=do_sample,
|
||
)
|
||
time_cost = time.time() - start_time
|
||
contents = [resp['generated_text'].strip() for resp in responses]
|
||
if not return_thinking:
|
||
contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
|
||
if return_text_only:
|
||
return contents
|
||
else:
|
||
usage_dicts = [{
|
||
'completion_tokens': len(content.split()),
|
||
'time': time_cost / len(batch_messages)
|
||
} for content in contents]
|
||
return list(zip(contents, usage_dicts))
|
||
|
||
def generate_cot(self, questions, max_new_tokens=1024):
|
||
if isinstance(questions, str):
|
||
messages = [{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
|
||
{"role": "user", "content": questions}]
|
||
return self._generate_response(messages, max_new_tokens=max_new_tokens)
|
||
elif isinstance(questions, list):
|
||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
|
||
{"role": "user", "content": q}] for q in questions]
|
||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
|
||
|
||
def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
|
||
if isinstance(question, str):
|
||
messages = [{"role": "system", "content": "".join(cot_system_instruction)},
|
||
{"role": "user", "content": f"{context}\n\n{question}\nThought:"}]
|
||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
elif isinstance(question, list):
|
||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction)},
|
||
{"role": "user", "content": f"{context}\n\n{q}\nThought:"}] for q in question]
|
||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
|
||
def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
|
||
if isinstance(question, str):
|
||
messages = deepcopy(prompt_template)
|
||
messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
|
||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
elif isinstance(question, list):
|
||
batch_messages = [deepcopy(prompt_template) + [{"role": "user", "content": f"{context}\n\nQuestions:{q}\nThought:"}]
|
||
for q in question]
|
||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
|
||
def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
|
||
if isinstance(question, str):
|
||
messages = [{"role": "system", "content": "".join(cot_system_instruction_kg)},
|
||
{"role": "user", "content": f"{context}\n\n{question}"}]
|
||
return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
elif isinstance(question, list):
|
||
batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_kg)},
|
||
{"role": "user", "content": f"{context}\n\n{q}"}] for q in question]
|
||
return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)
|
||
|
||
@retry_decorator
|
||
def filter_triples_with_entity(self, question, nodes, max_new_tokens=1024):
|
||
if isinstance(question, str):
|
||
messages = [{"role": "system", "content": """
|
||
Your task is to filter text candidates based on their relevance to a given query...
|
||
"""}, {"role": "user", "content": f"{question} \n Output Before Filter: {nodes} \n Output After Filter:"}]
|
||
try:
|
||
response = json.loads(self._generate_response(messages, max_new_tokens=max_new_tokens))
|
||
return response
|
||
except Exception:
|
||
return json.loads(nodes)
|
||
elif isinstance(question, list):
|
||
batch_messages = [[{"role": "system", "content": """
|
||
Your task is to filter text candidates based on their relevance to a given query...
|
||
"""}, {"role": "user", "content": f"{q} \n Output Before Filter: {nodes} \n Output After Filter:"}]
|
||
for q in question]
|
||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
|
||
return [json.loads(resp) if json.loads(resp) else json.loads(nodes) for resp in responses]
|
||
|
||
@retry_decorator
|
||
def filter_triples_with_entity_event(self, question, triples):
|
||
if isinstance(question, str):
|
||
messages = deepcopy(filter_messages)
|
||
messages.append({"role": "user", "content": f"[ ## question ## ]]\n{question}\n[[ ## fact_before_filter ## ]]\n{triples}"})
|
||
try:
|
||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
|
||
cleaned_data = validate_filter_output(response)
|
||
return cleaned_data['fact']
|
||
except Exception:
|
||
return []
|
||
elif isinstance(question, list):
|
||
batch_messages = [deepcopy(filter_messages) + [{"role": "user", "content": f"[ ## question ## ]]\n{q}\n[[ ## fact_before_filter ## ]]\n{triples}"}]
|
||
for q in question]
|
||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
|
||
return [validate_filter_output(resp)['fact'] if validate_filter_output(resp) else [] for resp in responses]
|
||
|
||
def generate_with_custom_messages(self, custom_messages, do_sample=True, max_new_tokens=1024, temperature=0.8, frequency_penalty=None):
|
||
if isinstance(custom_messages[0], dict):
|
||
return self._generate_response(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
|
||
elif isinstance(custom_messages[0], list):
|
||
return self._generate_batch_responses(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
|
||
|
||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||
def large_kg_filter_keywords_with_entity(self, question, keywords):
|
||
if isinstance(question, str):
|
||
messages = deepcopy(keyword_filtering_prompt)
|
||
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}"})
|
||
try:
|
||
response = self._generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
|
||
cleaned_data = validate_keyword_output(response)
|
||
return cleaned_data['keywords']
|
||
except Exception:
|
||
return keywords
|
||
elif isinstance(question, list):
|
||
batch_messages = [deepcopy(keyword_filtering_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## keywords_before_filter ## ]]\n{k}"}]
|
||
for q, k in zip(question, keywords)]
|
||
responses = self._generate_batch_responses(batch_messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
|
||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else keywords for resp in responses]
|
||
|
||
def ner(self, text):
|
||
if isinstance(text, str):
|
||
messages = [{"role": "system", "content": "Please extract the entities..."},
|
||
{"role": "user", "content": f"Extract the named entities from: {text}"}]
|
||
return self._generate_response(messages)
|
||
elif isinstance(text, list):
|
||
batch_messages = [[{"role": "system", "content": "Please extract the entities..."},
|
||
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
|
||
return self._generate_batch_responses(batch_messages)
|
||
|
||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||
def large_kg_ner(self, text):
|
||
if isinstance(text, str):
|
||
messages = deepcopy(ner_prompt)
|
||
messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{text}"})
|
||
try:
|
||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||
cleaned_data = validate_keyword_output(response)
|
||
return cleaned_data['keywords']
|
||
except Exception:
|
||
return []
|
||
elif isinstance(text, list):
|
||
batch_messages = [deepcopy(ner_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{t}"}] for t in text]
|
||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
|
||
|
||
@retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
|
||
def large_kg_tog_ner(self, text):
|
||
if isinstance(text, str):
|
||
messages = [{"role": "system", "content": "You are an advanced AI assistant..."},
|
||
{"role": "user", "content": f"Extract the named entities from: {text}"}]
|
||
try:
|
||
response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||
cleaned_data = validate_keyword_output(response)
|
||
return cleaned_data['keywords']
|
||
except Exception:
|
||
return []
|
||
elif isinstance(text, list):
|
||
batch_messages = [[{"role": "system", "content": "You are an advanced AI assistant..."},
|
||
{"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
|
||
responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
|
||
return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]
|
||
|
||
def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
|
||
# Implementation remains single-input focused as it’s iterative; batching not applicable here
|
||
react_system_instruction = (
|
||
'You are an advanced AI assistant that uses the ReAct framework...'
|
||
)
|
||
full_context = []
|
||
if search_history:
|
||
for i, (thought, action, observation) in enumerate(search_history):
|
||
full_context.append(f"\nPrevious search attempt {i}:\n{action}\n Result: {observation}\n")
|
||
if context:
|
||
full_context.append(f"Current Retrieved Context:\n{context}\n")
|
||
messages = [{"role": "system", "content": react_system_instruction},
|
||
{"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
|
||
if full_context else f"Question: {question}"}]
|
||
return self._generate_response(messages, max_new_tokens=max_new_tokens)
|
||
|
||
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'],
|
||
max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
|
||
# Single-input iterative process; batching not applicable
|
||
search_history = []
|
||
if isinstance(retriever, BaseEdgeRetriever):
|
||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||
current_context = ". ".join(initial_context)
|
||
elif isinstance(retriever, BasePassageRetriever):
|
||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||
current_context = "\n".join(initial_context)
|
||
for iteration in range(max_iterations):
|
||
analysis_response = self.generate_with_react(
|
||
question=question, context=current_context, max_new_tokens=max_new_tokens, search_history=search_history, logger=logger
|
||
)
|
||
try:
|
||
thought = analysis_response.split("Thought:")[1].split("\n")[0]
|
||
action = analysis_response.split("Action:")[1].split("\n")[0]
|
||
answer = analysis_response.split("Answer:")[1].strip()
|
||
if answer.lower() != "need more information":
|
||
search_history.append((thought, action, "Using current context"))
|
||
return answer, search_history
|
||
if "search" in action.lower():
|
||
search_query = action.split("search for")[-1].strip()
|
||
if isinstance(retriever, BaseEdgeRetriever):
|
||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||
current_contexts = current_context.split(". ")
|
||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||
new_context = ". ".join(new_context)
|
||
elif isinstance(retriever, BasePassageRetriever):
|
||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||
current_contexts = current_context.split("\n")
|
||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||
new_context = "\n".join(new_context)
|
||
observation = f"Found information: {new_context}" if new_context else "No new information found..."
|
||
search_history.append((thought, action, observation))
|
||
if new_context:
|
||
current_context = f"{current_context}\n{new_context}"
|
||
else:
|
||
search_history.append((thought, action, "No action taken but answer not found"))
|
||
return "Unable to find answer", search_history
|
||
except Exception as e:
|
||
return analysis_response, search_history
|
||
return answer, search_history
|
||
|
||
def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False):
|
||
if isinstance(messages[0], dict):
|
||
messages = [messages]
|
||
responses = self._generate_batch_responses(
|
||
batch_messages=messages,
|
||
max_new_tokens=max_tokens,
|
||
temperature=0.0,
|
||
do_sample=False,
|
||
frequency_penalty=0.5,
|
||
reasoning_effort="none",
|
||
return_text_only=not record
|
||
)
|
||
processed_responses = []
|
||
for response in responses:
|
||
if record:
|
||
content, usage_dict = response
|
||
else:
|
||
content = response
|
||
usage_dict = None
|
||
try:
|
||
prompt_type = stage_to_prompt_type.get(stage, None)
|
||
if prompt_type:
|
||
corrected, error = fix_and_validate_response(content, prompt_type)
|
||
if error:
|
||
raise ValueError(f"Validation failed for prompt_type '{prompt_type}'")
|
||
else:
|
||
corrected = content
|
||
if corrected and corrected.strip():
|
||
if record:
|
||
processed_responses.append((corrected, usage_dict))
|
||
else:
|
||
processed_responses.append(corrected)
|
||
else:
|
||
raise ValueError("Invalid response")
|
||
except Exception as e:
|
||
print(f"Failed to process response: {str(e)}")
|
||
if record:
|
||
usage_dict = {'completion_tokens': 0, 'total_tokens': 0, 'time': 0}
|
||
processed_responses.append(("[]", usage_dict))
|
||
else:
|
||
processed_responses.append("[]")
|
||
return processed_responses |