first commit
1  AIEC-RAG/atlas_rag/llm_generator/__init__.py  Normal file
@@ -0,0 +1 @@
from .llm_generator import LLMGenerator
Binary file not shown.
Binary file not shown.
0  AIEC-RAG/atlas_rag/llm_generator/format/__init__.py  Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
144  AIEC-RAG/atlas_rag/llm_generator/format/validate_json_output.py  Normal file
@@ -0,0 +1,144 @@
import json
from typing import List, Any
import json_repair
import jsonschema

def normalize_key(key):
    return key.strip().lower()


# recover function can be fix_triple_extraction_response, fix_filter_triplets
def validate_output(output_str, **kwargs):
    schema = kwargs.get("schema")
    fix_function = kwargs.get("fix_function", None)
    allow_empty = kwargs.get("allow_empty", True)
    if fix_function:
        parsed_data = fix_function(output_str, **kwargs)
    else:
        # no fix_function supplied: fall back to repairing the raw string so
        # parsed_data is always defined before schema validation
        parsed_data = json_repair.loads(output_str)
    jsonschema.validate(instance=parsed_data, schema=schema)
    if not allow_empty and (not parsed_data or len(parsed_data) == 0):
        raise ValueError("Parsed data is empty after validation.")
    return json.dumps(parsed_data, ensure_ascii=False)

def fix_filter_triplets(data: str, **kwargs) -> dict:
    data = json_repair.loads(data)
    processed_facts = []

    def find_triplet(element: Any) -> List[str] | None:
        # Base case: a valid triplet
        if isinstance(element, list) and len(element) == 3 and all(isinstance(item, str) for item in element):
            return element
        # Recursive case: dig deeper into nested lists
        elif isinstance(element, list):
            for sub_element in element:
                result = find_triplet(sub_element)
                if result:
                    return result
        return None

    for item in data.get("fact", []):
        triplet = find_triplet(item)
        if triplet:
            processed_facts.append(triplet)

    return {"fact": processed_facts}


def fix_triple_extraction_response(response: str, **kwargs) -> str:
    """Attempt to fix and validate JSON response based on the prompt type."""
    # Extract the JSON list from the response
    # raise error if prompt_type is not provided
    if "prompt_type" not in kwargs:
        raise ValueError("The 'prompt_type' argument is required.")
    prompt_type = kwargs.get("prompt_type")

    json_start_token = response.find("[")
    if json_start_token == -1:
        # add [ at the start
        response = "[" + response.strip() + "]"
    parsed_objects = json_repair.loads(response)
    if len(parsed_objects) == 0:
        return []

    # Define required keys for each prompt type
    required_keys = {
        "entity_relation": {"Head", "Relation", "Tail"},
        "event_entity": {"Event", "Entity"},
        "event_relation": {"Head", "Relation", "Tail"}
    }

    corrected_data = []
    seen_triples = set()
    for idx, item in enumerate(parsed_objects):
        if not isinstance(item, dict):
            print(f"Item {idx} must be a JSON object. Problematic item: {item}")
            continue

        # Correct the keys
        corrected_item = {}
        for key, value in item.items():
            norm_key = normalize_key(key)
            matching_expected_keys = [exp_key for exp_key in required_keys[prompt_type] if normalize_key(exp_key) in norm_key]
            if len(matching_expected_keys) == 1:
                corrected_key = matching_expected_keys[0]
                corrected_item[corrected_key] = value
            else:
                corrected_item[key] = value

        # Check for missing keys in corrected_item
        missing = required_keys[prompt_type] - corrected_item.keys()
        if missing:
            print(f"Item {idx} missing required keys: {missing}. Problematic item: {item}")
            continue

        # Validate and correct the values in corrected_item
        if prompt_type == "entity_relation":
            for key in ["Head", "Relation", "Tail"]:
                if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
                    print(f"Item {idx} {key} must be a non-empty string. Problematic item: {corrected_item}")
                    continue

        elif prompt_type == "event_entity":
            if not isinstance(corrected_item["Event"], str) or not corrected_item["Event"].strip():
                print(f"Item {idx} Event must be a non-empty string. Problematic item: {corrected_item}")
                continue
            if not isinstance(corrected_item["Entity"], list) or not corrected_item["Entity"]:
                print(f"Item {idx} Entity must be a non-empty array. Problematic item: {corrected_item}")
                continue
            else:
                corrected_item["Entity"] = [ent.strip() for ent in corrected_item["Entity"] if isinstance(ent, str)]

        elif prompt_type == "event_relation":
            for key in ["Head", "Tail", "Relation"]:
                if not isinstance(corrected_item[key], str) or not corrected_item[key].strip():
                    print(f"Item {idx} {key} must be a non-empty sentence. Problematic item: {corrected_item}")
                    continue

        triple_tuple = tuple((k, str(v)) for k, v in corrected_item.items())
        if triple_tuple in seen_triples:
            print(f"Item {idx} is a duplicate triple: {corrected_item}")
            continue
        else:
            seen_triples.add(triple_tuple)
        corrected_data.append(corrected_item)

    if not corrected_data:
        return []

    return corrected_data

def fix_lkg_keywords(data: str, **kwargs) -> dict:
    """
    Extract and flatten keywords into a list of strings, filtering invalid types.
    """
    data = json_repair.loads(data)
    processed_keywords = []

    def collect_strings(element: Any) -> None:
        if isinstance(element, str):
            if len(element) <= 200:  # Filter out keywords longer than 200 characters
                processed_keywords.append(element)
        elif isinstance(element, list):
            for item in element:
                collect_strings(item)

    # Start processing from the root "keywords" field
    collect_strings(data.get("keywords", []))

    return {"keywords": processed_keywords}
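A minimal usage sketch of the validators above, assuming the schemas defined in validate_json_schema.py below; the input string is a hypothetical model output, not part of this commit:

# Hypothetical usage sketch (not part of this commit): validate a model's
# triple-filtering output with the schema and fix function from this package.
from atlas_rag.llm_generator.format.validate_json_output import validate_output, fix_filter_triplets
from atlas_rag.llm_generator.format.validate_json_schema import filter_fact_json_schema

raw_output = '{"fact": [["Paris", "is capital of", "France"], ["bad", "row"]]}'  # example string
cleaned = validate_output(
    raw_output,
    schema=filter_fact_json_schema,
    fix_function=fix_filter_triplets,
    allow_empty=False,
)
print(cleaned)  # '{"fact": [["Paris", "is capital of", "France"]]}'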
93  AIEC-RAG/atlas_rag/llm_generator/format/validate_json_schema.py  Normal file
@@ -0,0 +1,93 @@
filter_fact_json_schema = {
    "type": "object",
    "properties": {
        "fact": {
            "type": "array",
            "items": {
                "type": "array",
                "items": {
                    "type": "string"  # All items in the inner array must be strings
                },
                "minItems": 3,
                "maxItems": 3,
                "additionalItems": False  # Block extra items
            },
        }
    },
    "required": ["fact"]
}


lkg_keyword_json_schema = {
    "type": "object",
    "properties": {
        "keywords": {
            "type": "array",
            "items": {
                "type": "string"
            },
            "minItems": 1,
        }
    },
    "required": ["keywords"]
}


triple_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "Head": {
                "type": "string"
            },
            "Relation": {
                "type": "string"
            },
            "Tail": {
                "type": "string"
            }
        },
        "required": ["Head", "Relation", "Tail"]
    },
}

event_relation_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "Head": {
                "type": "string"
            },
            "Relation": {
                "type": "string",
            },
            "Tail": {
                "type": "string"
            }
        },
        "required": ["Head", "Relation", "Tail"]
    },
}

event_entity_json_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "Event": {
                "type": "string"
            },
            "Entity": {
                "type": "array",
                "items": {
                    "type": "string"
                },
                "minItems": 1
            }
        },
        "required": ["Event", "Entity"]
    },
}

stage_to_schema = {
    1: triple_json_schema,
    2: event_entity_json_schema,
    3: event_relation_json_schema
}
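A short sketch of how these schemas can be exercised directly with jsonschema; the example data is hypothetical and not part of the commit:

# Hypothetical sketch: validate an extracted triple list against triple_json_schema.
import jsonschema
from atlas_rag.llm_generator.format.validate_json_schema import triple_json_schema, stage_to_schema

triples = [{"Head": "ATLAS", "Relation": "is a", "Tail": "RAG framework"}]  # example data
jsonschema.validate(instance=triples, schema=triple_json_schema)  # raises on mismatch
assert stage_to_schema[1] is triple_json_schema  # stage 1 = entity_relation extraction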
364  AIEC-RAG/atlas_rag/llm_generator/llm_generator.py  Normal file
@@ -0,0 +1,364 @@
import json
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, keyword_filtering_prompt
from atlas_rag.llm_generator.prompt.rag_prompt import filter_triple_messages

from atlas_rag.llm_generator.format.validate_json_output import *
from atlas_rag.llm_generator.format.validate_json_schema import filter_fact_json_schema, lkg_keyword_json_schema, stage_to_schema

from transformers.pipelines import Pipeline
import jsonschema

import time

stage_to_prompt_type = {
    1: "entity_relation",
    2: "event_entity",
    3: "event_relation",
}
retry_decorator = retry(
    stop=(stop_after_delay(120) | stop_after_attempt(5)),  # Max 2 minutes or 5 attempts
    wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
)


class LLMGenerator():
    def __init__(self, client, model_name):
        self.model_name = model_name
        self.client: OpenAI | Pipeline = client
        if isinstance(client, (OpenAI, AzureOpenAI)):
            self.inference_type = "openai"
        elif isinstance(client, Pipeline):
            self.inference_type = "pipeline"
        else:
            raise ValueError("Unsupported client type. Please provide either an OpenAI client or a Huggingface Pipeline Object.")

    @retry_decorator
    def _api_inference(self, message, max_new_tokens=8192,
                       temperature=0.7,
                       frequency_penalty=None,
                       response_format={"type": "text"},
                       return_text_only=True,
                       return_thinking=False,
                       reasoning_effort=None,
                       **kwargs):
        start_time = time.time()
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=message,
            max_tokens=max_new_tokens,
            temperature=temperature,
            frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
            response_format=response_format if response_format is not None else {"type": "text"},
            timeout=120,
            reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
        )
        time_cost = time.time() - start_time
        content = response.choices[0].message.content
        if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
            content = response.choices[0].message.reasoning_content
        validate_function = kwargs.get('validate_function', None)
        content = validate_function(content, **kwargs) if validate_function else content

        if '</think>' in content and not return_thinking:
            content = content.split('</think>')[-1].strip()
        else:
            if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None and return_thinking:
                content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content

        if return_text_only:
            return content
        else:
            completion_usage_dict = response.usage.model_dump()
            completion_usage_dict['time'] = time_cost
            return content, completion_usage_dict

    def generate_response(self, batch_messages, do_sample=True, max_new_tokens=8192,
                          temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
                          return_text_only=True, return_thinking=False, reasoning_effort=None, **kwargs):
        if temperature == 0.0:
            do_sample = False
        # single = list of dict, batch = list of list of dict
        is_batch = isinstance(batch_messages[0], list)
        if not is_batch:
            batch_messages = [batch_messages]
        results = [None] * len(batch_messages)
        to_process = list(range(len(batch_messages)))
        if self.inference_type == "openai":
            max_workers = kwargs.get('max_workers', 3)  # Default to 3 workers if not specified
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                def process_message(i):
                    try:
                        return self._api_inference(
                            batch_messages[i], max_new_tokens, temperature,
                            frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort, **kwargs
                        )
                    except Exception as e:
                        print(f"Error processing message {i}: {e}")
                        return ""
                futures = [executor.submit(process_message, i) for i in to_process]
                for i, future in enumerate(futures):
                    results[i] = future.result()

        elif self.inference_type == "pipeline":
            max_retries = kwargs.get('max_retries', 3)  # Default to 3 retries if not specified
            start_time = time.time()
            # Initial processing of all messages
            responses = self.client(
                batch_messages,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
                return_full_text=False
            )
            time_cost = time.time() - start_time

            # Extract contents
            contents = [resp[0]['generated_text'].strip() for resp in responses]

            # Validate and collect failed indices
            validate_function = kwargs.get('validate_function', None)
            failed_indices = []
            for i, content in enumerate(contents):
                if validate_function:
                    try:
                        contents[i] = validate_function(content, **kwargs)
                    except Exception as e:
                        print(f"Validation failed for index {i}: {e}")
                        failed_indices.append(i)

            # Retry failed messages in batches
            for attempt in range(max_retries):
                if not failed_indices:
                    break  # No more failures to retry
                print(f"Retry attempt {attempt + 1}/{max_retries} for {len(failed_indices)} failed messages")
                # Prepare batch of failed messages
                failed_messages = [batch_messages[i] for i in failed_indices]
                try:
                    # Process failed messages as a batch
                    retry_responses = self.client(
                        failed_messages,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        do_sample=do_sample,
                        return_full_text=False
                    )
                    retry_contents = [resp[0]['generated_text'].strip() for resp in retry_responses]

                    # Validate retry results and update contents
                    new_failed_indices = []
                    for j, i in enumerate(failed_indices):
                        try:
                            if validate_function:
                                retry_contents[j] = validate_function(retry_contents[j], **kwargs)
                            contents[i] = retry_contents[j]
                        except Exception as e:
                            print(f"Validation failed for index {i} on retry {attempt + 1}: {e}")
                            new_failed_indices.append(i)
                    failed_indices = new_failed_indices  # Update failed indices for next retry
                except Exception as e:
                    print(f"Batch retry {attempt + 1} failed: {e}")
                    # If batch processing fails, keep all indices in failed_indices
                    if attempt == max_retries - 1:
                        for i in failed_indices:
                            contents[i] = ""  # Set to "" if all retries fail

            # Set remaining failed messages to "" after all retries
            for i in failed_indices:
                contents[i] = ""

            # Process thinking tags
            if not return_thinking:
                contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]

            if return_text_only:
                results = contents
            else:
                usage_dicts = [{
                    'completion_tokens': len(content.split()),
                    'time': time_cost / len(batch_messages)
                } for content in contents]
                results = list(zip(contents, usage_dicts))
        return results[0] if not is_batch else results

    def generate_cot(self, question, max_new_tokens=1024):
        messages = [
            {"role": "system", "content": "".join(cot_system_instruction_no_doc)},
            {"role": "user", "content": question},
        ]
        return self.generate_response(messages, max_new_tokens=max_new_tokens)

    def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
        messages = [
            {"role": "system", "content": "".join(cot_system_instruction)},
            {"role": "user", "content": f"{context}\n\n{question}\nThought:"},
        ]
        return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)

    def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
        messages = deepcopy(prompt_template)
        messages.append(
            {"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"},
        )
        return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)

    def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
        messages = [
            {"role": "system", "content": "".join(cot_system_instruction_kg)},
            {"role": "user", "content": f"{context}\n\n{question}"},
        ]
        return self.generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)

    @retry_decorator
    def filter_triples_with_entity_event(self, question, triples):
        messages = deepcopy(filter_triple_messages)
        messages.append(
            {"role": "user", "content": f"""[[ ## question ## ]]
{question}

[[ ## fact_before_filter ## ]]
{triples}"""})
        try:
            validate_args = {
                "schema": filter_fact_json_schema,
                "fix_function": fix_filter_triplets,
            }
            response = self.generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"},
                                              validate_function=validate_output, **validate_args)
            return response
        except Exception as e:
            # If all retries fail, return the original triples
            return triples

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_filter_keywords_with_entity(self, question, keywords):
        messages = deepcopy(keyword_filtering_prompt)

        messages.append({
            "role": "user",
            "content": f"""[[ ## question ## ]]
{question}
[[ ## keywords_before_filter ## ]]
{keywords}"""
        })

        try:
            response = self.generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)

            # Validate and clean the response (validate_output takes keyword arguments and returns a JSON string)
            cleaned_data = json.loads(validate_output(response, schema=lkg_keyword_json_schema, fix_function=fix_lkg_keywords))

            return cleaned_data['keywords']
        except Exception as e:
            return keywords

    def ner(self, text):
        messages = [
            {"role": "system", "content": "Please extract the entities from the following question and output them separated by comma, in the following format: entity1, entity2, ..."},
            {"role": "user", "content": f"Extract the named entities from: {text}"},
        ]
        return self.generate_response(messages)

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_ner(self, text):
        messages = deepcopy(ner_prompt)
        messages.append(
            {
                "role": "user",
                "content": f"[[ ## question ## ]]\n{text}"
            }
        )
        validation_args = {
            "schema": lkg_keyword_json_schema,
            "fix_function": fix_lkg_keywords
        }
        # Generate raw response from LLM
        raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)

        try:
            # Validate and clean the response
            cleaned_data = json_repair.loads(raw_response)
            return cleaned_data['keywords']

        except (json.JSONDecodeError, jsonschema.ValidationError) as e:
            return []  # Fallback to empty list or raise custom exception

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_tog_ner(self, text):
        messages = [
            {"role": "system", "content": "You are an advanced AI assistant that extracts named entities from given text. "},
            {"role": "user", "content": f"Extract the named entities from: {text}"}
        ]

        # Generate raw response from LLM
        validation_args = {
            "schema": lkg_keyword_json_schema,
            "fix_function": fix_lkg_keywords
        }
        raw_response = self.generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"}, validate_function=validate_output, **validation_args)

        try:
            # Validate and clean the response
            cleaned_data = json_repair.loads(raw_response)
            return cleaned_data['keywords']

        except (json.JSONDecodeError, jsonschema.ValidationError) as e:
            return []  # Fallback to empty list or raise custom exception

    def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
        react_system_instruction = (
            'You are an advanced AI assistant that uses the ReAct framework to solve problems through iterative search. '
            'Follow these steps in your response:\n'
            '1. Thought: Think step by step and analyze if the current context is sufficient to answer the question. If not, review the current context and think critically about what can be searched to help answer the question.\n'
            ' - Break down the question into *1-hop* sub-questions if necessary (e.g., identify key entities like people or places before addressing specific events).\n'
            ' - Use the available context to make inferences about key entities and their relationships.\n'
            ' - If a previous search query (prefix with "Previous search attempt") was not useful, reflect on why and adjust your strategy—avoid repeating similar queries and consider searching for general information about key entities or related concepts.\n'
            '2. Action: Choose one of:\n'
            ' - Search for [Query]: If you need more information, specify a new query. The [Query] must differ from previous searches in wording and direction to explore new angles.\n'
            ' - No Action: If the current context is sufficient.\n'
            '3. Answer: Provide one of:\n'
            ' - A concise, definitive response as a noun phrase if you can answer.\n'
            ' - "Need more information" if you need to search.\n\n'
            'Format your response exactly as:\n'
            'Thought: [your reasoning]\n'
            'Action: [Search for [Query] or No Action]\n'
            'Answer: [concise noun phrase if you can answer, or "Need more information" if you need to search]\n\n'
        )

        # Build context with search history if available
        full_context = []
        if search_history:
            for i, (thought, action, observation) in enumerate(search_history):
                search_history_text = f"\nPrevious search attempt {i}:\n"
                search_history_text += f"{action}\n Result: {observation}\n"
                full_context.append(search_history_text)
        if context:
            full_context_text = f"Current Retrieved Context:\n{context}\n"
            full_context.append(full_context_text)
        if logger:
            logger.info(f"Full context for ReAct generation: {full_context}")

        # Combine few-shot examples with system instruction and user query
        messages = [
            {"role": "system", "content": react_system_instruction},
            {"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
             if full_context else f"Question: {question}"}
        ]
        if logger:
            logger.info(f"Messages for ReAct generation: {search_history}Question: {question}")
        return self.generate_response(messages, max_new_tokens=max_new_tokens)

    def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False, allow_empty=False):
        if isinstance(messages[0], dict):
            messages = [messages]
        validate_kwargs = {
            'schema': stage_to_schema.get(stage, None),
            'fix_function': fix_triple_extraction_response,
            'prompt_type': stage_to_prompt_type.get(stage, None),
            'allow_empty': allow_empty
        }
        result = self.generate_response(messages, max_new_tokens=max_tokens, validate_function=validate_output, return_text_only=not record, **validate_kwargs)
        return result
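A hedged construction and usage sketch for LLMGenerator; the endpoint, API key, model name, and message contents below are placeholders, not values from this repository:

# Hypothetical usage sketch (placeholder endpoint/model): wrap an OpenAI-compatible
# client in LLMGenerator, answer a question over retrieved context, and run a
# stage-1 (entity_relation) triple extraction.
from openai import OpenAI
from atlas_rag.llm_generator import LLMGenerator

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint
generator = LLMGenerator(client, model_name="my-model")                # placeholder model name

answer = generator.generate_with_context(
    question="Who directed The Last Horse?",
    context="The Last Horse is a 1950 Spanish comedy film directed by Edgar Neville.",
    max_new_tokens=256,
    temperature=0.0,
)

triples = generator.triple_extraction(
    [{"role": "user", "content": "Edgar Neville directed The Last Horse."}],
    stage=1,  # stage 1 maps to the entity_relation prompt type and triple_json_schema
)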
381  AIEC-RAG/atlas_rag/llm_generator/llm_generator_legacy.py  Normal file
@@ -0,0 +1,381 @@
import json
from openai import OpenAI, AzureOpenAI, NOT_GIVEN
from tenacity import retry, stop_after_attempt, stop_after_delay, wait_fixed, wait_exponential, wait_random
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
import time

from atlas_rag.llm_generator.prompt.rag_prompt import cot_system_instruction, cot_system_instruction_kg, cot_system_instruction_no_doc, prompt_template
from atlas_rag.llm_generator.format.validate_json_output import validate_filter_output, messages as filter_messages
from atlas_rag.llm_generator.prompt.lkg_prompt import ner_prompt, validate_keyword_output, keyword_filtering_prompt
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
from atlas_rag.llm_generator.format.validate_json_output import fix_and_validate_response

from transformers.pipelines import Pipeline
import jsonschema
from typing import Union
from logging import Logger

stage_to_prompt_type = {
    1: "entity_relation",
    2: "event_entity",
    3: "event_relation",
}
retry_decorator = retry(
    stop=(stop_after_delay(120) | stop_after_attempt(5)),
    wait=wait_exponential(multiplier=1, min=2, max=30) + wait_random(min=0, max=2),
)


class LLMGenerator:
    def __init__(self, client, model_name):
        self.model_name = model_name
        self.client: OpenAI | Pipeline = client
        if isinstance(client, (OpenAI, AzureOpenAI)):
            self.inference_type = "openai"
        elif isinstance(client, Pipeline):
            self.inference_type = "pipeline"
        else:
            raise ValueError("Unsupported client type. Please provide either an OpenAI client or a Huggingface Pipeline Object.")

    @retry_decorator
    def _generate_response(self, messages, do_sample=True, max_new_tokens=8192, temperature=0.7,
                           frequency_penalty=None, response_format={"type": "text"}, return_text_only=True,
                           return_thinking=False, reasoning_effort=None):
        if temperature == 0.0:
            do_sample = False
        if self.inference_type == "openai":
            start_time = time.time()
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=temperature,
                frequency_penalty=NOT_GIVEN if frequency_penalty is None else frequency_penalty,
                response_format=response_format if response_format is not None else {"type": "text"},
                timeout=120,
                reasoning_effort=NOT_GIVEN if reasoning_effort is None else reasoning_effort,
            )
            time_cost = time.time() - start_time
            content = response.choices[0].message.content
            if content is None and hasattr(response.choices[0].message, 'reasoning_content'):
                content = response.choices[0].message.reasoning_content
            else:
                content = response.choices[0].message.content
            if '</think>' in content and not return_thinking:
                content = content.split('</think>')[-1].strip()
            else:
                if hasattr(response.choices[0].message, 'reasoning_content') and response.choices[0].message.reasoning_content is not None:
                    content = '<think>' + response.choices[0].message.reasoning_content + '</think>' + content
            if return_text_only:
                return content
            else:
                completion_usage_dict = response.usage.model_dump()
                completion_usage_dict['time'] = time_cost
                return content, completion_usage_dict
        elif self.inference_type == "pipeline":
            start_time = time.time()
            if hasattr(self.client, 'tokenizer'):
                input_text = self.client.tokenizer.apply_chat_template(messages, tokenize=False)
            else:
                input_text = "\n".join([msg["content"] for msg in messages])
            response = self.client(
                input_text,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
            )
            time_cost = time.time() - start_time
            content = response[0]['generated_text'].strip()
            if '</think>' in content and not return_thinking:
                content = content.split('</think>')[-1].strip()
            if return_text_only:
                return content
            else:
                token_count = len(content.split())
                completion_usage_dict = {
                    'completion_tokens': token_count,
                    'time': time_cost
                }
                return content, completion_usage_dict

    def _generate_batch_responses(self, batch_messages, do_sample=True, max_new_tokens=8192,
                                  temperature=0.7, frequency_penalty=None, response_format={"type": "text"},
                                  return_text_only=True, return_thinking=False, reasoning_effort=None):
        if self.inference_type == "openai":
            with ThreadPoolExecutor(max_workers=3) as executor:
                futures = [
                    executor.submit(
                        self._generate_response, messages, do_sample, max_new_tokens, temperature,
                        frequency_penalty, response_format, return_text_only, return_thinking, reasoning_effort
                    ) for messages in batch_messages
                ]
                results = [future.result() for future in futures]
            return results
        elif self.inference_type == "pipeline":
            if not hasattr(self.client, 'tokenizer'):
                raise ValueError("Pipeline must have a tokenizer for batch processing.")
            batch_inputs = [self.client.tokenizer.apply_chat_template(messages, tokenize=False) for messages in batch_messages]
            start_time = time.time()
            responses = self.client(
                batch_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=do_sample,
            )
            time_cost = time.time() - start_time
            contents = [resp['generated_text'].strip() for resp in responses]
            if not return_thinking:
                contents = [content.split('</think>')[-1].strip() if '</think>' in content else content for content in contents]
            if return_text_only:
                return contents
            else:
                usage_dicts = [{
                    'completion_tokens': len(content.split()),
                    'time': time_cost / len(batch_messages)
                } for content in contents]
                return list(zip(contents, usage_dicts))

    def generate_cot(self, questions, max_new_tokens=1024):
        if isinstance(questions, str):
            messages = [{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
                        {"role": "user", "content": questions}]
            return self._generate_response(messages, max_new_tokens=max_new_tokens)
        elif isinstance(questions, list):
            batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_no_doc)},
                               {"role": "user", "content": q}] for q in questions]
            return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)

    def generate_with_context(self, question, context, max_new_tokens=1024, temperature=0.7):
        if isinstance(question, str):
            messages = [{"role": "system", "content": "".join(cot_system_instruction)},
                        {"role": "user", "content": f"{context}\n\n{question}\nThought:"}]
            return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
        elif isinstance(question, list):
            batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction)},
                               {"role": "user", "content": f"{context}\n\n{q}\nThought:"}] for q in question]
            return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)

    def generate_with_context_one_shot(self, question, context, max_new_tokens=4096, temperature=0.7):
        if isinstance(question, str):
            messages = deepcopy(prompt_template)
            messages.append({"role": "user", "content": f"{context}\n\nQuestions:{question}\nThought:"})
            return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
        elif isinstance(question, list):
            batch_messages = [deepcopy(prompt_template) + [{"role": "user", "content": f"{context}\n\nQuestions:{q}\nThought:"}]
                              for q in question]
            return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)

    def generate_with_context_kg(self, question, context, max_new_tokens=1024, temperature=0.7):
        if isinstance(question, str):
            messages = [{"role": "system", "content": "".join(cot_system_instruction_kg)},
                        {"role": "user", "content": f"{context}\n\n{question}"}]
            return self._generate_response(messages, max_new_tokens=max_new_tokens, temperature=temperature)
        elif isinstance(question, list):
            batch_messages = [[{"role": "system", "content": "".join(cot_system_instruction_kg)},
                               {"role": "user", "content": f"{context}\n\n{q}"}] for q in question]
            return self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens, temperature=temperature)

    @retry_decorator
    def filter_triples_with_entity(self, question, nodes, max_new_tokens=1024):
        if isinstance(question, str):
            messages = [{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{question} \n Output Before Filter: {nodes} \n Output After Filter:"}]
            try:
                response = json.loads(self._generate_response(messages, max_new_tokens=max_new_tokens))
                return response
            except Exception:
                return json.loads(nodes)
        elif isinstance(question, list):
            batch_messages = [[{"role": "system", "content": """
Your task is to filter text candidates based on their relevance to a given query...
"""}, {"role": "user", "content": f"{q} \n Output Before Filter: {nodes} \n Output After Filter:"}]
                              for q in question]
            responses = self._generate_batch_responses(batch_messages, max_new_tokens=max_new_tokens)
            return [json.loads(resp) if json.loads(resp) else json.loads(nodes) for resp in responses]

    @retry_decorator
    def filter_triples_with_entity_event(self, question, triples):
        if isinstance(question, str):
            messages = deepcopy(filter_messages)
            messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## fact_before_filter ## ]]\n{triples}"})
            try:
                response = self._generate_response(messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
                cleaned_data = validate_filter_output(response)
                return cleaned_data['fact']
            except Exception:
                return []
        elif isinstance(question, list):
            batch_messages = [deepcopy(filter_messages) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## fact_before_filter ## ]]\n{triples}"}]
                              for q in question]
            responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.0, response_format={"type": "json_object"})
            return [validate_filter_output(resp)['fact'] if validate_filter_output(resp) else [] for resp in responses]

    def generate_with_custom_messages(self, custom_messages, do_sample=True, max_new_tokens=1024, temperature=0.8, frequency_penalty=None):
        if isinstance(custom_messages[0], dict):
            return self._generate_response(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)
        elif isinstance(custom_messages[0], list):
            return self._generate_batch_responses(custom_messages, do_sample, max_new_tokens, temperature, frequency_penalty)

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_filter_keywords_with_entity(self, question, keywords):
        if isinstance(question, str):
            messages = deepcopy(keyword_filtering_prompt)
            messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{question}\n[[ ## keywords_before_filter ## ]]\n{keywords}"})
            try:
                response = self._generate_response(messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
                cleaned_data = validate_keyword_output(response)
                return cleaned_data['keywords']
            except Exception:
                return keywords
        elif isinstance(question, list):
            batch_messages = [deepcopy(keyword_filtering_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{q}\n[[ ## keywords_before_filter ## ]]\n{k}"}]
                              for q, k in zip(question, keywords)]
            responses = self._generate_batch_responses(batch_messages, response_format={"type": "json_object"}, temperature=0.0, max_new_tokens=2048)
            return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else keywords for resp in responses]

    def ner(self, text):
        if isinstance(text, str):
            messages = [{"role": "system", "content": "Please extract the entities..."},
                        {"role": "user", "content": f"Extract the named entities from: {text}"}]
            return self._generate_response(messages)
        elif isinstance(text, list):
            batch_messages = [[{"role": "system", "content": "Please extract the entities..."},
                               {"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
            return self._generate_batch_responses(batch_messages)

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_ner(self, text):
        if isinstance(text, str):
            messages = deepcopy(ner_prompt)
            messages.append({"role": "user", "content": f"[[ ## question ## ]]\n{text}"})
            try:
                response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
                cleaned_data = validate_keyword_output(response)
                return cleaned_data['keywords']
            except Exception:
                return []
        elif isinstance(text, list):
            batch_messages = [deepcopy(ner_prompt) + [{"role": "user", "content": f"[[ ## question ## ]]\n{t}"}] for t in text]
            responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
            return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]

    @retry(stop=(stop_after_delay(60) | stop_after_attempt(6)), wait=wait_fixed(2))
    def large_kg_tog_ner(self, text):
        if isinstance(text, str):
            messages = [{"role": "system", "content": "You are an advanced AI assistant..."},
                        {"role": "user", "content": f"Extract the named entities from: {text}"}]
            try:
                response = self._generate_response(messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
                cleaned_data = validate_keyword_output(response)
                return cleaned_data['keywords']
            except Exception:
                return []
        elif isinstance(text, list):
            batch_messages = [[{"role": "system", "content": "You are an advanced AI assistant..."},
                               {"role": "user", "content": f"Extract the named entities from: {t}"}] for t in text]
            responses = self._generate_batch_responses(batch_messages, max_new_tokens=4096, temperature=0.7, frequency_penalty=1.1, response_format={"type": "json_object"})
            return [validate_keyword_output(resp)['keywords'] if validate_keyword_output(resp) else [] for resp in responses]

    def generate_with_react(self, question, context=None, max_new_tokens=1024, search_history=None, logger=None):
        # Implementation remains single-input focused as it's iterative; batching not applicable here
        react_system_instruction = (
            'You are an advanced AI assistant that uses the ReAct framework...'
        )
        full_context = []
        if search_history:
            for i, (thought, action, observation) in enumerate(search_history):
                full_context.append(f"\nPrevious search attempt {i}:\n{action}\n Result: {observation}\n")
        if context:
            full_context.append(f"Current Retrieved Context:\n{context}\n")
        messages = [{"role": "system", "content": react_system_instruction},
                    {"role": "user", "content": f"Search History:\n\n{''.join(full_context)}\n\nQuestion: {question}"
                     if full_context else f"Question: {question}"}]
        return self._generate_response(messages, max_new_tokens=max_new_tokens)

    def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'],
                                max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
        # Single-input iterative process; batching not applicable
        search_history = []
        if isinstance(retriever, BaseEdgeRetriever):
            initial_context, _ = retriever.retrieve(question, topN=5)
            current_context = ". ".join(initial_context)
        elif isinstance(retriever, BasePassageRetriever):
            initial_context, _ = retriever.retrieve(question, topN=5)
            current_context = "\n".join(initial_context)
        for iteration in range(max_iterations):
            analysis_response = self.generate_with_react(
                question=question, context=current_context, max_new_tokens=max_new_tokens, search_history=search_history, logger=logger
            )
            try:
                thought = analysis_response.split("Thought:")[1].split("\n")[0]
                action = analysis_response.split("Action:")[1].split("\n")[0]
                answer = analysis_response.split("Answer:")[1].strip()
                if answer.lower() != "need more information":
                    search_history.append((thought, action, "Using current context"))
                    return answer, search_history
                if "search" in action.lower():
                    search_query = action.split("search for")[-1].strip()
                    if isinstance(retriever, BaseEdgeRetriever):
                        new_context, _ = retriever.retrieve(search_query, topN=3)
                        current_contexts = current_context.split(". ")
                        new_context = [ctx for ctx in new_context if ctx not in current_contexts]
                        new_context = ". ".join(new_context)
                    elif isinstance(retriever, BasePassageRetriever):
                        new_context, _ = retriever.retrieve(search_query, topN=3)
                        current_contexts = current_context.split("\n")
                        new_context = [ctx for ctx in new_context if ctx not in current_contexts]
                        new_context = "\n".join(new_context)
                    observation = f"Found information: {new_context}" if new_context else "No new information found..."
                    search_history.append((thought, action, observation))
                    if new_context:
                        current_context = f"{current_context}\n{new_context}"
                else:
                    search_history.append((thought, action, "No action taken but answer not found"))
                    return "Unable to find answer", search_history
            except Exception as e:
                return analysis_response, search_history
        return answer, search_history

    def triple_extraction(self, messages, max_tokens=4096, stage=None, record=False):
        if isinstance(messages[0], dict):
            messages = [messages]
        responses = self._generate_batch_responses(
            batch_messages=messages,
            max_new_tokens=max_tokens,
            temperature=0.0,
            do_sample=False,
            frequency_penalty=0.5,
            reasoning_effort="none",
            return_text_only=not record
        )
        processed_responses = []
        for response in responses:
            if record:
                content, usage_dict = response
            else:
                content = response
                usage_dict = None
            try:
                prompt_type = stage_to_prompt_type.get(stage, None)
                if prompt_type:
                    corrected, error = fix_and_validate_response(content, prompt_type)
                    if error:
                        raise ValueError(f"Validation failed for prompt_type '{prompt_type}'")
                else:
                    corrected = content
                if corrected and corrected.strip():
                    if record:
                        processed_responses.append((corrected, usage_dict))
                    else:
                        processed_responses.append(corrected)
                else:
                    raise ValueError("Invalid response")
            except Exception as e:
                print(f"Failed to process response: {str(e)}")
                if record:
                    usage_dict = {'completion_tokens': 0, 'total_tokens': 0, 'time': 0}
                    processed_responses.append(("[]", usage_dict))
                else:
                    processed_responses.append("[]")
        return processed_responses
0  AIEC-RAG/atlas_rag/llm_generator/prompt/__init__.py  Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
263  AIEC-RAG/atlas_rag/llm_generator/prompt/lkg_prompt.py  Normal file
@@ -0,0 +1,263 @@
import json
import jsonschema
from typing import Any

ner_prompt = [
    {"role": "system",
     "content": """
You are a domain analysis engine. You must provide keywords for searching relevant documents.
When given any academic question, follow these steps:
1. **Identify Tested Skills:** Determine the *abstract knowledge/skills* required to solve the problem (e.g., "translating universal statements into predicate logic"), not the concrete entities in the question (e.g., "children," "school").
2. **Extract Domain-Specific Terms:** Extract domain-specific technical terms (e.g., "school" in *educational policy*), and exclude common nouns/verbs describing the question's *scenario* (e.g., "child," "school," "goes to").
3. **Prioritize Formal Structures:** For logic/math problems, focus on notation rules (e.g., quantifier order, implication vs. conjunction), not scenario labels.
4. **Capture Rare Technical Terms:** Include uncommon domain-specific terms critical to the question (e.g., "epigenetics" in biology, "monad" in computer science), even if they appear infrequently in general language.
Your input fields are:
1. question (str): Query for keyword extraction
Your output fields are:
1. keywords (array): Extracted keywords in JSON format
All interactions will be structured as:
[[ ## question ## ]]
{question}
[[ ## keywords ## ]]
{keywords}
The output must be parseable according to JSON schema:
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
"""},
    {
        "role": "user",
        "content": "[[ ## question ## ]]\nSolve \(x^2 - 5x + 6 = 0\)."
    },
    {
        "role": "assistant",
        "content": "{\"keywords\": [\"quadratic equation\", \"factoring\", \"roots\", \"algebraic manipulation\"]}"
    },
    {
        "role": "user",
        "content": "[[ ## question ## ]]\nExplain the socio-economic causes of the French Revolution."
    },
    {
        "role": "assistant",
        "content": "{\"keywords\": [\"historical causation\", \"class struggle\", \"economic inequality\", \"Enlightenment philosophy\"]}"
    },
    {
        "role": "user",
        "content": "[[ ## question ## ]]\nProve that the square root of 2 is irrational."
    },
    {
        "role": "assistant",
        "content": "{\"keywords\": [\"proof by contradiction\", \"irrational numbers\", \"number theory\", \"rational/irrational distinction\"]}"
    },
    {
        "role": "user",
        "content": "[[ ## question ## ]]\nExplain the implications of Heisenberg's uncertainty principle on quantum measurements."
    },
    {
        "role": "assistant",
        "content": "{\"keywords\": [\"Heisenberg uncertainty principle\", \"quantum mechanics\", \"measurement theory\", \"quantum state collapse\", \"observable quantities\"]}"
    }
]

keyword_filtering_prompt = [
    {
        "role": "system",
        "content": """You are a precision-focused component of a knowledge retrieval system used by researchers and educators.
Your task is to filter keywords based on their relevance to the **core domain knowledge** required to answer a given question.
You must critically evaluate whether each item represents:
1. Academic concepts/methodologies (entities)
2. Critical processes or relationships (events)
The output should be in JSON format, e.g., {"keywords": ["concept1", "concept2"]}. If no keywords are relevant, return {"keywords": []}.
Accuracy is crucial, as these keywords drive document retrieval for critical research. Do not generate any new keywords or explanations.
**Do not change the content of each object in the list. You must only use text from the candidate list and cannot generate new text.**
**Include all characters of the selected keywords.**
**Do not include any duplicate keywords.**
**The keywords can be a sentence describing an event as long as it is helpful for searching.**
---

**Input Fields:**
1. question (str): Query requiring knowledge analysis
2. keywords_before_filter (list): Candidate keywords to evaluate

**Output Field:**
1. keywords_after_filter (list): Filtered keywords in JSON format

---

**Interaction Structure:**
[[ ## question ## ]]
{question}
[[ ## keywords_before_filter ## ]]
{keywords_before_filter}
[[ ## keywords_after_filter ## ]]
{keywords_after_filter}

**JSON Schema Validation:**
{"type": "object", "properties": {"keywords": {"type": "array", "items": {"type": "string"}}}, "required": ["keywords"]}
"""},
    {
        "role": "user",
        "content": """[[ ## question ## ]]
Explain the causes of the French Revolution.
[[ ## keywords_before_filter ## ]]
["French Revolution", "social inequality", "Enlightenment ideas spreading", "economic crisis", "Bastille storming event"]"""
    },
    {
        "role": "assistant",
        "content": """{"keywords": ["social inequality", "Enlightenment ideas spreading", "economic crisis"]}"""
    },
    {
        "role": "user",
        "content": """[[ ## question ## ]]
Translate 'All children go to some school' into predicate logic.
[[ ## keywords_before_filter ## ]]
["predicate logic", "children", "school", "universal quantifier (∀)", "attendance"]"""
    },
    {
        "role": "assistant",
        "content": """{"keywords": ["predicate logic", "universal quantifier (∀)"]}"""
    },

    {
        "role": "user",
        "content": """[[ ## question ## ]]
Solve \(x^2 - 5x + 6 = 0\).
[[ ## keywords_before_filter ## ]]
["quadratic equation", "polynomial", "algebra", "roots", "classroom teaching"]"""
    },
    {
        "role": "assistant",
        "content": """{"keywords": ["quadratic equation", "polynomial", "roots"]}"""
    }
]

test_messages = [
|
||||
[
|
||||
{
|
||||
"role":"system",
|
||||
"content":"You are a helpful asssistant. You will answer the question based on the context provided."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content":"""Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
|
||||
Question: What is the phenotype of a congenital disorder impairing the secretion of leptin?
|
||||
A. Normal energy intake, normal body weight and hyperthyroidism
|
||||
B. Obesity, excess energy intake, normal growth and hypoinsulinaemia
|
||||
C. Obesity, abnormal growth, hypothyroidism, hyperinsulinaemia
D. Underweight, abnormal growth, hypothyroidism, hyperinsulinaemia"""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: Which of the following is not an advantage of stratified random sampling over simple random sampling?
A. When done correctly, a stratified random sample is less biased than a simple random sample.
B. When done correctly, a stratified random sampling process has less variability from sample to sample than a simple random sample.
C. When done correctly, a stratified random sample can provide, with a smaller sample size, an estimate that is just as reliable as that of a simple random sample with a larger sample size.
D. A stratified random sample provides information about each stratum in the population as well as an estimate for the population as a whole, and a simple random sample does not."""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A. 0
B. 4
C. 2
D. 6"""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: A lesion causing compression of the facial nerve at the stylomastoid foramen will cause ipsilateral
A. paralysis of the facial muscles.
B. paralysis of the facial muscles and loss of taste.
C. paralysis of the facial muscles, loss of taste and lacrimation.
D. paralysis of the facial muscles, loss of taste, lacrimation and decreased salivation."""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: What is true for a type-Ia ("type one-a") supernova?
A. This type occurs in binary systems.
B. This type occurs in young galaxies.
C. This type produces gamma-ray bursts.
D. This type produces high amounts of X-rays."""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: _______ such as bitcoin are becoming increasingly mainstream and have a whole host of associated ethical implications, for example, they are ______ and more ______. However, they have also been used to engage in _______.
A. Cryptocurrencies, Expensive, Secure, Financial Crime
B. Traditional currency, Cheap, Unsecure, Charitable giving
C. Cryptocurrencies, Cheap, Secure, Financial crime
D. Traditional currency, Expensive, Unsecure, Charitable giving"""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: The access matrix approach to protection has the difficulty that
A. the matrix, if stored directly, is large and can be clumsy to manage
B. it is not capable of expressing complex protection requirements
C. deciding whether a process has access to a resource is undecidable
D. there is no way to express who has rights to change the access matrix itself"""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: In the NoNicks operating system, the time required by a single file-read operation has four nonoverlapping components: disk seek time - 25 msec, disk latency time - 8 msec, disk transfer time - 1 msec per 1,000 bytes, operating system overhead - 1 msec per 1,000 bytes + 10 msec. In version 1 of the system, the file read retrieved blocks of 1,000 bytes. In version 2, the file read (along with the underlying layout on disk) was modified to retrieve blocks of 4,000 bytes. The ratio of the time required to read a large file under version 2 to the time required to read the same large file under version 1 is approximately
A. 1:4
B. 1:3.5
C. 1:1
D. 1.1:1"""
}
],
[
{
"role": "system",
"content": "You are a helpful assistant. You will answer the question based on the context provided."
},
{
"role": "user",
"content": """Query: Given the following question and four candidate answers (A, B, C and D), choose the answer.
Question: Which of the following propositions is an immediate (one-step) consequence in PL of the given premises? ~E ⊃ ~F, G ⊃ F, H ∨ ~E, H ⊃ I, ~I
A. E ⊃ F
B. F ⊃ G
C. H ⊃ ~E
D. ~H"""
}
],
]
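The list closed above holds one chat-message sequence per multiple-choice question. A minimal sketch of replaying such a list against an OpenAI-compatible chat endpoint follows; the variable name mcq_message_lists, the client construction and the model name are placeholders for illustration, not identifiers from this repository.

from openai import OpenAI

client = OpenAI()  # assumes an OpenAI-compatible endpoint configured via environment variables

def answer_mcqs(message_lists):
    """Send each prepared message list and collect one answer string per question."""
    answers = []
    for messages in message_lists:
        reply = client.chat.completions.create(
            model="gpt-4o-mini",   # placeholder model name
            messages=messages,
            max_tokens=16,
        )
        answers.append(reply.choices[0].message.content.strip())
    return answers

# answers = answer_mcqs(mcq_message_lists)  # `mcq_message_lists` stands for the list defined above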
115
AIEC-RAG/atlas_rag/llm_generator/prompt/rag_prompt.py
Normal file
@ -0,0 +1,115 @@
one_shot_rag_qa_docs = (
    """Wikipedia Title: The Last Horse\nThe Last Horse (Spanish:El último caballo) is a 1950 Spanish comedy film directed by Edgar Neville starring Fernando Fernán Gómez.\n"""
    """Wikipedia Title: Southampton\nThe University of Southampton, which was founded in 1862 and received its Royal Charter as a university in 1952, has over 22,000 students. The university is ranked in the top 100 research universities in the world in the Academic Ranking of World Universities 2010. In 2010, the THES - QS World University Rankings positioned the University of Southampton in the top 80 universities in the world. The university considers itself one of the top 5 research universities in the UK. The university has a global reputation for research into engineering sciences, oceanography, chemistry, cancer sciences, sound and vibration research, computer science and electronics, optoelectronics and textile conservation at the Textile Conservation Centre (which is due to close in October 2009.) It is also home to the National Oceanography Centre, Southampton (NOCS), the focus of Natural Environment Research Council-funded marine research.\n"""
    """Wikipedia Title: Stanton Township, Champaign County, Illinois\nStanton Township is a township in Champaign County, Illinois, USA. As of the 2010 census, its population was 505 and it contained 202 housing units.\n"""
    """Wikipedia Title: Neville A. Stanton\nNeville A. Stanton is a British Professor of Human Factors and Ergonomics at the University of Southampton. Prof Stanton is a Chartered Engineer (C.Eng), Chartered Psychologist (C.Psychol) and Chartered Ergonomist (C.ErgHF). He has written and edited over forty books and over three hundred peer-reviewed journal papers on applications of the subject. Stanton is a Fellow of the British Psychological Society, a Fellow of The Institute of Ergonomics and Human Factors and a member of the Institution of Engineering and Technology. He has been published in academic journals including "Nature". He has also helped organisations design new human-machine interfaces, such as the Adaptive Cruise Control system for Jaguar Cars.\n"""
    """Wikipedia Title: Finding Nemo\nFinding Nemo Theatrical release poster Directed by Andrew Stanton Produced by Graham Walters Screenplay by Andrew Stanton Bob Peterson David Reynolds Story by Andrew Stanton Starring Albert Brooks Ellen DeGeneres Alexander Gould Willem Dafoe Music by Thomas Newman Cinematography Sharon Calahan Jeremy Lasky Edited by David Ian Salter Production company Walt Disney Pictures Pixar Animation Studios Distributed by Buena Vista Pictures Distribution Release date May 30, 2003 (2003 - 05 - 30) Running time 100 minutes Country United States Language English Budget $94 million Box office $940.3 million"""
)

one_shot_ircot_demo = (
    f'{one_shot_rag_qa_docs}'
    '\n\nQuestion: '
    f"When was Neville A. Stanton's employer founded?"
    '\nThought: '
    f"The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. So the answer is: 1862."
    '\n\n'
)

rag_qa_system = (
    'As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. '
    'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
    'Conclude with "Answer: " to present a concise, definitive response, devoid of additional elaborations.'
)

one_shot_rag_qa_input = (
    f"{one_shot_rag_qa_docs}"
    "\n\nQuestion: "
    "When was Neville A. Stanton's employer founded?"
    '\nThought: '
)

one_shot_rag_qa_output = (
    "The employer of Neville A. Stanton is University of Southampton. The University of Southampton was founded in 1862. "
    "\nAnswer: 1862."
)


prompt_template = [
    {"role": "system", "content": rag_qa_system},
    {"role": "user", "content": one_shot_rag_qa_input},
    {"role": "assistant", "content": one_shot_rag_qa_output},
]
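prompt_template above ends with the one-shot assistant answer; at query time the retrieved passages and the live question would typically be appended as one more user turn shaped like one_shot_rag_qa_input. A minimal sketch, assuming that calling convention and a helper name that does not appear in the repository:

from copy import deepcopy

def build_rag_qa_messages(retrieved_docs: str, question: str) -> list:
    """Append the live passages and question to the one-shot template above."""
    messages = deepcopy(prompt_template)
    messages.append({
        "role": "user",
        "content": f"{retrieved_docs}\n\nQuestion: {question}\nThought: ",
    })
    return messages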

# from https://github.com/OSU-NLP-Group/HippoRAG/blob/main/src/qa/qa_reader.py

cot_system_instruction = (
    'As an advanced reading comprehension assistant, your task is to analyze text passages and corresponding questions meticulously. If the information is not enough, you can use your own knowledge to answer the question. '
    'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
    'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.'
)

cot_system_instruction_no_doc = (
    'As an advanced reading comprehension assistant, your task is to analyze the questions and then answer them. '
    'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
    'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.'
)

cot_system_instruction_kg = (
    'As an advanced reading comprehension assistant, your task is to analyze extracted information and corresponding questions meticulously. If the knowledge graph information is not enough, you can use your own knowledge to answer the question. '
    'Your response starts after "Thought: ", where you will methodically break down the reasoning process, illustrating how you arrive at conclusions. '
    'Conclude with "Answer: " to present a concise, definitive response as a noun phrase, no elaborations.'
)

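The three instruction strings above differ only in whether passages or knowledge-graph facts accompany the question. A minimal sketch of pairing them with a user turn; the turn layout and the helper name are assumptions, not fixed by this file:

def build_cot_messages(question: str, passages: str = None) -> list:
    """Pair the matching CoT system instruction with a single user turn."""
    if passages is None:
        system = cot_system_instruction_no_doc
        user = f"Question: {question}\nThought: "
    else:
        system = cot_system_instruction
        user = f"{passages}\n\nQuestion: {question}\nThought: "
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]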
filter_triple_messages = [
    {
        "role": "system",
        "content": """You are a critical component of a high-stakes question-answering system used by top researchers and decision-makers worldwide.
Your task is to filter facts based on their relevance to a given query.
The query requires careful analysis and possibly multi-hop reasoning to connect different pieces of information.
You must select all relevant facts from the provided candidate list, aiding in reasoning and providing an accurate answer.
The output should be in JSON format, e.g., {"fact": [["s1", "p1", "o1"], ["s2", "p2", "o2"]]}, and if no facts are relevant, return an empty list, {"fact": []}.
The accuracy of your response is paramount, as it will directly impact the decisions made by these high-level stakeholders. You must only use facts from the candidate list and not generate new facts.
The future of critical decision-making relies on your ability to accurately filter and present relevant information.

Your input fields are:
1. question (str): Query for retrieval
2. fact_before_filter (str): Candidate facts to be filtered

Your output fields are:
1. fact_after_filter (Fact): Filtered facts in JSON format

All interactions will be structured as:
[[ ## question ## ]]
{question}

[[ ## fact_before_filter ## ]]
{fact_before_filter}

[[ ## fact_after_filter ## ]]
{fact_after_filter}

The output must be parseable according to JSON schema: {"type": "object", "properties": {"fact": {"type": "array", "items": {"type": "array", "items": {"type": "string"}}}}, "required": ["fact"]}"""
    },
    # Example 1
    {
        "role": "user",
        "content": """[[ ## question ## ]]
Are Imperial River (Florida) and Amaradia (Dolj) both located in the same country?

[[ ## fact_before_filter ## ]]
{"fact": [["imperial river", "is located in", "florida"], ["imperial river", "is a river in", "united states"], ["imperial river", "may refer to", "south america"], ["amaradia", "flows through", "ro ia de amaradia"], ["imperial river", "may refer to", "united states"]]}"""
    },
    {
        "role": "assistant",
        "content": """{"fact":[["imperial river","is located in","florida"],["imperial river","is a river in","united states"],["amaradia","flows through","ro ia de amaradia"]]}"""
    },
    # Example 2
    {
        "role": "user",
        "content": """[[ ## question ## ]]
When is the director of film The Ancestor 's birthday?

[[ ## fact_before_filter ## ]]
{"fact": [["jean jacques annaud", "born on", "1 october 1943"], ["tsui hark", "born on", "15 february 1950"], ["pablo trapero", "born on", "4 october 1971"], ["the ancestor", "directed by", "guido brignone"], ["benh zeitlin", "born on", "october 14 1982"]]}"""
    },
    {
        "role": "assistant",
        "content": """{"fact":[["the ancestor","directed by","guido brignone"]]}"""
    },
]
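The system message above pins down both the exchange format and the JSON schema the filtered facts must satisfy. A minimal sketch of checking a model reply against that schema, reusing the schema literal quoted in the prompt; the helper name is illustrative, and a repairing parser could replace json.loads if replies are malformed:

import json
import jsonschema

FACT_SCHEMA = {
    "type": "object",
    "properties": {
        "fact": {"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
    },
    "required": ["fact"],
}

def parse_filtered_facts(raw_output: str) -> list:
    """Parse and validate a fact-filtering reply of the form {"fact": [[s, p, o], ...]}."""
    data = json.loads(raw_output)
    jsonschema.validate(instance=data, schema=FACT_SCHEMA)
    return data["fact"]

# parse_filtered_facts('{"fact":[["the ancestor","directed by","guido brignone"]]}')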
108
AIEC-RAG/atlas_rag/llm_generator/prompt/react.py
Normal file
@ -0,0 +1,108 @@
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
from typing import Union
from logging import Logger


class ReAct():
    def __init__(self, llm: LLMGenerator):
        self.llm = llm

    def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'], max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
        """
        Generate a response using RAG with the ReAct framework, starting with an initial search using the original query.

        Args:
            question (str): The question to answer
            retriever: The retriever instance to use for searching
            max_iterations (int): Maximum number of ReAct iterations
            max_new_tokens (int): Maximum number of tokens to generate per iteration

        Returns:
            tuple: (final_answer, search_history)
                - final_answer: The final answer generated
                - search_history: List of (thought, action, observation) tuples
        """
        search_history = []

        # Perform initial search with the original query
        if isinstance(retriever, BaseEdgeRetriever):
            initial_context, _ = retriever.retrieve(question, topN=5)
            current_context = ". ".join(initial_context)
        elif isinstance(retriever, BasePassageRetriever):
            initial_context, _ = retriever.retrieve(question, topN=5)
            current_context = "\n".join(initial_context)

        # Start ReAct process with the initial context
        for iteration in range(max_iterations):
            # First, analyze if we can answer with current context
            analysis_response = self.llm.generate_with_react(
                question=question,
                context=current_context,
                max_new_tokens=max_new_tokens,
                search_history=search_history,
                logger=logger
            )

            if logger:
                logger.info(f"Analysis response: {analysis_response}")

            try:
                # Parse the analysis response
                thought = analysis_response.split("Thought:")[1].split("\n")[0]
                if logger:
                    logger.info(f"Thought: {thought}")
                action = analysis_response.split("Action:")[1].split("\n")[0]
                answer = analysis_response.split("Answer:")[1].strip()

                # If the answer indicates we can answer with current context
                if answer.lower() != "need more information":
                    search_history.append((thought, action, "Using current context"))
                    return answer, search_history

                # If we need more information, perform the search
                if "search" in action.lower():
                    # Extract search query from the action
                    search_query = action.split("search for")[-1].strip()

                    # Perform the search
                    if isinstance(retriever, BaseEdgeRetriever):
                        new_context, _ = retriever.retrieve(search_query, topN=3)
                        # Filter out contexts that are already in current_context
                        current_contexts = current_context.split(". ")
                        new_context = [ctx for ctx in new_context if ctx not in current_contexts]
                        new_context = ". ".join(new_context)
                    elif isinstance(retriever, BasePassageRetriever):
                        new_context, _ = retriever.retrieve(search_query, topN=3)
                        # Filter out contexts that are already in current_context
                        current_contexts = current_context.split("\n")
                        new_context = [ctx for ctx in new_context if ctx not in current_contexts]
                        new_context = "\n".join(new_context)

                    # Store the search results as observation
                    if new_context:
                        observation = f"Found information: {new_context}"
                    else:
                        observation = "No new information found. Consider searching for related entities or events."
                    search_history.append((thought, action, observation))

                    # Update context with new search results
                    if new_context:
                        current_context = f"{current_context}\n{new_context}"
                        if logger:
                            logger.info(f"New search results: {new_context}")
                    else:
                        if logger:
                            logger.info("No new information found, suggesting to try related entities")

                else:
                    # If no search is needed but we can't answer, something went wrong
                    search_history.append((thought, action, "No action taken but answer not found"))
                    return "Unable to find answer", search_history

            except Exception as e:
                if logger:
                    logger.error(f"Error parsing ReAct response: {e}")
                return analysis_response, search_history

        # If we've reached max iterations, return the last answer
        return answer, search_history
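A minimal usage sketch for the class above, assuming an already-constructed LLMGenerator and retriever (how those are built is deployment-specific and outside this file); the function name and logger name are illustrative:

import logging
from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.llm_generator.prompt.react import ReAct

def run_react(llm: LLMGenerator, retriever, question: str):
    """Run one ReAct session and log the intermediate (thought, action, observation) steps."""
    logger = logging.getLogger("react_demo")
    agent = ReAct(llm)
    answer, history = agent.generate_with_rag_react(question, retriever, max_iterations=5, logger=logger)
    for thought, action, observation in history:
        logger.info("Thought: %s | Action: %s | Observation: %s", thought, action, observation)
    return answer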
@ -0,0 +1,313 @@
TRIPLE_INSTRUCTIONS = {
    "en": {
        "system": "You are a helpful assistant who always responds in a valid array of JSON objects without any explanation",
        "entity_relation": """Given a passage, summarize all the important entities and the relations between them in a concise manner. Relations should briefly capture the connections between entities, without repeating information from the head and tail entities. The entities should be as specific as possible. Exclude pronouns from being considered as entities.
You must **strictly output in the following JSON format**:\n
[
{
"Head": "{a noun}",
"Relation": "{a verb}",
"Tail": "{a noun}",
}...
]""",

        "event_entity": """Please analyze and summarize the participation relations between the events and entities in the given paragraph. Each event is a single independent sentence. Additionally, identify all the entities that participated in the events. Do not use ellipses.
You must **strictly output in the following JSON format**:\n
[
{
"Event": "{a simple sentence describing an event}",
"Entity": ["entity 1", "entity 2", "..."]
}...
]""",

        "event_relation": """Please analyze and summarize the relationships between the events in the paragraph. Each event is a single independent sentence. Identify temporal and causal relationships between the events using the following types: before, after, at the same time, because, and as a result. Each extracted triple should be specific, meaningful, and able to stand alone. Do not use ellipses.
You must **strictly output in the following JSON format**:\n
[
{
"Head": "{a simple sentence describing the event 1}",
"Relation": "{temporal or causality relation between the events}",
"Tail": "{a simple sentence describing the event 2}"
}...
]""",
        "passage_start": """Here is the passage."""
    },
    "zh-CN": {
        "system": """你是一个始终以有效JSON数组格式回应的助手""",
        "entity_relation": """给定一段文字,提取所有重要实体及其关系,并以简洁的方式总结。关系描述应清晰表达实体间的联系,且不重复头尾实体的信息。实体需具体明确,排除代词。

**重要格式要求:**
1. Head字段:必须是一个字符串,不能为空
2. Relation字段:必须是一个字符串且不能为空,如果不确定请用"相关"
3. Tail字段:必须是一个字符串,不能为空。如果有多个项目,请用顿号(、)连接

返回格式必须为以下JSON结构,内容需用简体中文表述:
[
{
"Head": "{名词}",
"Relation": "{动词或关系描述,不能为空}",
"Tail": "{名词,多个用顿号连接}"
}...
]

示例:
输入:"企业内部审计数字化产品技术要求包括审计作业、审计管理、数据建设及应用"
输出:
[
{"Head": "企业内部审计数字化产品", "Relation": "技术要求包括", "Tail": "审计作业、审计管理、数据建设及应用"}
]""",

        "event_entity": """分析段落中的事件及其参与实体。每个事件应为独立单句,列出所有相关实体(需具体,不含代词)。
返回格式必须为以下JSON结构,内容需用简体中文表述:
[
{
"Event": "{描述事件的简单句子}",
"Entity": ["实体1", "实体2", "..."]
}...
]""",

        "event_relation": """分析事件间的时序或因果关系,关系类型包括:之前,之后,同时,因为,结果.每个事件应为独立单句。
返回格式必须为以下JSON结构.内容需用简体中文表述.
[
{
"Head": "{事件1描述}",
"Relation": "{时序/因果关系}",
"Tail": "{事件2描述}"
}...
]""",

        "passage_start": "给定以下段落:"
    },
    "zh-HK": {
        "system": "你是一個始終以有效JSON數組格式回覆的助手",
        "entity_relation": """給定一段文字,提取所有重要實體及其關係,並以簡潔的方式總結。關係描述應清晰表達實體間的聯繫,且不重複頭尾實體的信息。實體需具體明確,排除代詞。
返回格式必須為以下JSON結構,內容需用繁體中文表述:
[
{
"Head": "{名詞}",
"Relation": "{動詞或關係描述}",
"Tail": "{名詞}"
}...
]""",

        "event_entity": """分析段落中的事件及其參與實體。每個事件應為獨立單句,列出所有相關實體(需具體,不含代詞)。
返回格式必須為以下JSON結構,內容需用繁體中文表述:
[
{
"Event": "{描述事件的簡單句子}",
"Entity": ["實體1", "實體2", "..."]
}...
]""",

        "event_relation": """分析事件間的時序或因果關係,關係類型包括:之前,之後,同時,因為,結果.每個事件應為獨立單句。
返回格式必須為以下JSON結構.內容需用繁體中文表述.
[
{
"Head": "{事件1描述}",
"Relation": "{時序/因果關係}",
"Tail": "{事件2描述}"
}...
]""",

        "passage_start": "給定以下段落:"
    }
}
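TRIPLE_INSTRUCTIONS above bundles, per language, a system prompt, three extraction prompts and a passage preamble. A minimal sketch of assembling a chat request for one passage and loosely parsing the reply; the helper names and the exact concatenation order are assumptions, and json_repair is used because model output may not be strictly valid JSON:

import json_repair

def build_extraction_messages(passage: str, task: str = "entity_relation", lang: str = "en") -> list:
    """Assemble a chat prompt from TRIPLE_INSTRUCTIONS for one passage."""
    block = TRIPLE_INSTRUCTIONS[lang]
    user_content = f"{block[task]}\n{block['passage_start']}\n{passage}"
    return [
        {"role": "system", "content": block["system"]},
        {"role": "user", "content": user_content},
    ]

def parse_extraction_output(raw_output: str) -> list:
    """Best-effort parse of the model reply into a list of triple dicts."""
    parsed = json_repair.loads(raw_output)
    return parsed if isinstance(parsed, list) else []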
CONCEPT_INSTRUCTIONS = {
    "en": {
        "event": """I will give you an EVENT. You need to give several phrases containing 1-2 words for the ABSTRACT EVENT of this EVENT.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract event words should fulfill the following requirements.
1. The ABSTRACT EVENT phrases can well represent the EVENT, and it could be the type of the EVENT or the related concepts of the EVENT.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.

EVENT: A man retreats to mountains and forests.
Your answer: retreat, relaxation, escape, nature, solitude
EVENT: A cat chased a prey into its shelter
Your answer: hunting, escape, predation, hiding, stalking
EVENT: Sam playing with his dog
Your answer: relaxing event, petting, playing, bonding, friendship
EVENT: [EVENT]
Your answer:""",
        "entity": """I will give you an ENTITY. You need to give several phrases containing 1-2 words for the ABSTRACT ENTITY of this ENTITY.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract entity words should fulfill the following requirements.
1. The ABSTRACT ENTITY phrases can well represent the ENTITY, and it could be the type of the ENTITY or the related concepts of the ENTITY.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.

ENTITY: Soul
CONTEXT: premiered BFI London Film Festival, became highest-grossing Pixar release
Your answer: movie, film

ENTITY: Thinkpad X60
CONTEXT: Richard Stallman announced he is using Trisquel on a Thinkpad X60
Your answer: Thinkpad, laptop, machine, device, hardware, computer, brand

ENTITY: Harry Callahan
CONTEXT: bluffs another robber, tortures Scorpio
Your answer: person, American, character, police officer, detective

ENTITY: Black Mountain College
CONTEXT: was started by John Andrew Rice, attracted faculty
Your answer: college, university, school, liberal arts college

ENTITY: 1st April
CONTEXT: Utkal Dibas celebrates
Your answer: date, day, time, festival

ENTITY: [ENTITY]
CONTEXT: [CONTEXT]
Your answer:""",
        "relation": """I will give you a RELATION. You need to give several phrases containing 1-2 words for the ABSTRACT RELATION of this RELATION.
You must return your answer in the following format: phrases1, phrases2, phrases3,...
You can't return anything other than answers.
These abstract relation words should fulfill the following requirements.
1. The ABSTRACT RELATION phrases can well represent the RELATION, and it could be the type of the RELATION or the simplest concepts of the RELATION.
2. Strictly follow the provided format, do not add extra characters or words.
3. Write at least 3 or more phrases at different abstract level if possible.
4. Do not repeat the same word and the input in the answer.
5. Stop immediately if you can't think of any more phrases, and no explanation is needed.

RELATION: participated in
Your answer: become part of, attend, take part in, engage in, involve in
RELATION: be included in
Your answer: join, be a part of, be a member of, be a component of
RELATION: [RELATION]
Your answer:"""
    },
    "zh-CN": {
        "event": """我将给你一个事件。你需要为这个事件的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象事件短语应满足以下要求:
1. 能很好地代表该事件的类型或相关概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释

事件:一个人退隐到山林中
你的回答:退隐, 放松, 逃避, 自然, 独处
事件:一只猫将猎物追进巢穴
你的回答:捕猎, 逃跑, 捕食, 躲藏, 潜行
事件:山姆和他的狗玩耍
你的回答:休闲活动, 抚摸, 玩耍, bonding, 友谊
事件:[EVENT]
请以简体中文输出你的回答:""",
        "entity": """我将给你一个实体。你需要为这个实体的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象实体短语应满足以下要求:
1. 能很好地代表该实体的类型或相关概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释

实体:心灵奇旅
上下文:在BFI伦敦电影节首映,成为皮克斯最卖座影片
你的回答:电影, 影片
实体:Thinkpad X60
上下文:Richard Stallman宣布他在Thinkpad X60上使用Trisquel系统
你的回答:Thinkpad, 笔记本电脑, 机器, 设备, 硬件, 电脑, 品牌
实体:哈利·卡拉汉
上下文:吓退另一个劫匪,折磨天蝎座
你的回答:人物, 美国人, 角色, 警察, 侦探
实体:黑山学院
上下文:由John Andrew Rice创办,吸引了众多教员
你的回答:学院, 大学, 学校, 文理学院
实体:4月1日
上下文:庆祝Utkal Dibas
你的回答:日期, 日子, 时间, 节日
实体:[ENTITY]
上下文:[CONTEXT]
请以简体中文输出你的回答:""",
        "relation": """我将给你一个关系。你需要为这个关系的抽象概念提供几个1-2个词的短语。
你必须按照以下格式返回答案:短语1, 短语2, 短语3,...
除了答案外不要返回任何其他内容,请以简体中文输出。
这些抽象关系短语应满足以下要求:
1. 能很好地代表该关系的类型或最基本概念
2. 严格遵循给定格式,不要添加额外字符或词语
3. 尽可能提供3个或以上不同抽象层次的短语
4. 不要重复相同词语或输入内容
5. 如果无法想出更多短语立即停止,不需要解释

关系:参与
你的回答:成为一部分, 参加, 参与其中, 涉及, 卷入
关系:被包含在
你的回答:加入, 成为一部分, 成为成员, 成为组成部分
关系:[RELATION]
请以简体中文输出你的回答:"""
    },
    "zh-HK": {
        "event": """我將給你一個事件。你需要為這個事件的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象事件短語應滿足以下要求:
1. 能很好地代表該事件的類型或相關概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋

事件:一個人退隱到山林中
你的回答:退隱, 放鬆, 逃避, 自然, 獨處
事件:一隻貓將獵物追進巢穴
你的回答:捕獵, 逃跑, 捕食, 躲藏, 潛行
事件:山姆和他的狗玩耍
你的回答:休閒活動, 撫摸, 玩耍, bonding, 友誼
事件:[EVENT]
請以繁體中文輸出你的回答:""",
        "entity": """我將給你一個實體。你需要為這個實體的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象實體短語應滿足以下要求:
1. 能很好地代表該實體的類型或相關概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋

實體:心靈奇旅
上下文:在BFI倫敦電影節首映,成為皮克斯最賣座影片
你的回答:電影, 影片
實體:Thinkpad X60
上下文:Richard Stallman宣布他在Thinkpad X60上使用Trisquel系統
你的回答:Thinkpad, 筆記本電腦, 機器, 設備, 硬件, 電腦, 品牌
實體:哈利·卡拉漢
上下文:嚇退另一個劫匪,折磨天蠍座
你的回答:人物, 美國人, 角色, 警察, 偵探
實體:黑山學院
上下文:由John Andrew Rice創辦,吸引了眾多教員
你的回答:學院, 大學, 學校, 文理學院
實體:4月1日
上下文:慶祝Utkal Dibas
你的回答:日期, 日子, 時間, 節日
實體:[ENTITY]
上下文:[CONTEXT]
請以繁體中文輸出你的回答:""",
        "relation": """我將給你一個關係。你需要為這個關係的抽象概念提供幾個1-2個詞的短語。
你必須按照以下格式返回答案:短語1, 短語2, 短語3,...
除了答案外不要返回任何其他內容,請以繁體中文輸出。
這些抽象關係短語應滿足以下要求:
1. 能很好地代表該關係的類型或最基本概念
2. 嚴格遵循給定格式,不要添加額外字符或詞語
3. 盡可能提供3個或以上不同抽象層次的短語
4. 不要重複相同詞語或輸入內容
5. 如果無法想出更多短語立即停止,不需要解釋

關係:參與
你的回答:成為一部分, 參加, 參與其中, 涉及, 捲入
關係:被包含在
你的回答:加入, 成為一部分, 成為成員, 成為組成部分
關係:[RELATION]
請以繁體中文輸出你的回答:"""
    }
}
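CONCEPT_INSTRUCTIONS above relies on literal [EVENT], [ENTITY], [RELATION] and [CONTEXT] placeholders. A minimal sketch of filling them for one item, with the reply expected to be a comma-separated list of phrases; the helper name is illustrative and not part of the repository:

def build_concept_prompt(kind: str, text: str, context: str = "", lang: str = "en") -> str:
    """Fill the [EVENT]/[ENTITY]/[RELATION] and [CONTEXT] placeholders for one item."""
    template = CONCEPT_INSTRUCTIONS[lang][kind]
    prompt = template.replace(f"[{kind.upper()}]", text)
    if kind == "entity":
        prompt = prompt.replace("[CONTEXT]", context)
    return prompt

# build_concept_prompt("event", "A man retreats to mountains and forests.")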