first commit
This commit is contained in:
108
AIEC-RAG/atlas_rag/llm_generator/prompt/react.py
Normal file
108
AIEC-RAG/atlas_rag/llm_generator/prompt/react.py
Normal file
@ -0,0 +1,108 @@
|
||||
from atlas_rag.llm_generator import LLMGenerator
|
||||
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
|
||||
from typing import Union
|
||||
from logging import Logger
|
||||
|
||||
class ReAct():
|
||||
def __init__(self, llm:LLMGenerator):
|
||||
self.llm = llm
|
||||
|
||||
def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'], max_iterations: int = 5, max_new_tokens: int = 1024, logger: Logger = None):
|
||||
"""
|
||||
Generate a response using RAG with ReAct framework, starting with an initial search using the original query.
|
||||
|
||||
Args:
|
||||
question (str): The question to answer
|
||||
retriever: The retriever instance to use for searching
|
||||
max_iterations (int): Maximum number of ReAct iterations
|
||||
max_new_tokens (int): Maximum number of tokens to generate per iteration
|
||||
|
||||
Returns:
|
||||
tuple: (final_answer, search_history)
|
||||
- final_answer: The final answer generated
|
||||
- search_history: List of (thought, action, observation) tuples
|
||||
"""
|
||||
search_history = []
|
||||
|
||||
# Perform initial search with the original query
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = ". ".join(initial_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
initial_context, _ = retriever.retrieve(question, topN=5)
|
||||
current_context = "\n".join(initial_context)
|
||||
|
||||
# Start ReAct process with the initial context
|
||||
for iteration in range(max_iterations):
|
||||
# First, analyze if we can answer with current context
|
||||
analysis_response = self.llm.generate_with_react(
|
||||
question=question,
|
||||
context=current_context,
|
||||
max_new_tokens=max_new_tokens,
|
||||
search_history=search_history,
|
||||
logger = logger
|
||||
)
|
||||
|
||||
if logger:
|
||||
logger.info(f"Analysis response: {analysis_response}")
|
||||
|
||||
try:
|
||||
# Parse the analysis response
|
||||
thought = analysis_response.split("Thought:")[1].split("\n")[0]
|
||||
if logger:
|
||||
logger.info(f"Thought: {thought}")
|
||||
action = analysis_response.split("Action:")[1].split("\n")[0]
|
||||
answer = analysis_response.split("Answer:")[1].strip()
|
||||
|
||||
# If the answer indicates we can answer with current context
|
||||
if answer.lower() != "need more information":
|
||||
search_history.append((thought, action, "Using current context"))
|
||||
return answer, search_history
|
||||
|
||||
# If we need more information, perform the search
|
||||
if "search" in action.lower():
|
||||
# Extract search query from the action
|
||||
search_query = action.split("search for")[-1].strip()
|
||||
|
||||
# Perform the search
|
||||
if isinstance(retriever, BaseEdgeRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
# Filter out contexts that are already in current_context
|
||||
current_contexts = current_context.split(". ")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = ". ".join(new_context)
|
||||
elif isinstance(retriever, BasePassageRetriever):
|
||||
new_context, _ = retriever.retrieve(search_query, topN=3)
|
||||
# Filter out contexts that are already in current_context
|
||||
current_contexts = current_context.split("\n")
|
||||
new_context = [ctx for ctx in new_context if ctx not in current_contexts]
|
||||
new_context = "\n".join(new_context)
|
||||
|
||||
# Store the search results as observation
|
||||
if new_context:
|
||||
observation = f"Found information: {new_context}"
|
||||
else:
|
||||
observation = "No new information found. Consider searching for related entities or events."
|
||||
search_history.append((thought, action, observation))
|
||||
|
||||
# Update context with new search results
|
||||
if new_context:
|
||||
current_context = f"{current_context}\n{new_context}"
|
||||
if logger:
|
||||
logger.info(f"New search results: {new_context}")
|
||||
else:
|
||||
if logger:
|
||||
logger.info("No new information found, suggesting to try related entities")
|
||||
|
||||
else:
|
||||
# If no search is needed but we can't answer, something went wrong
|
||||
search_history.append((thought, action, "No action taken but answer not found"))
|
||||
return "Unable to find answer", search_history
|
||||
|
||||
except Exception as e:
|
||||
if logger:
|
||||
logger.error(f"Error parsing ReAct response: {e}")
|
||||
return analysis_response, search_history
|
||||
|
||||
# If we've reached max iterations, return the last answer
|
||||
return answer, search_history
|
||||
Reference in New Issue
Block a user