Files
AIEC-RAG---/AIEC-RAG/atlas_rag/llm_generator/prompt/react.py
2025-09-25 10:33:37 +08:00

108 lines
5.4 KiB
Python

import re
from logging import Logger
from typing import Union

from atlas_rag.llm_generator import LLMGenerator
from atlas_rag.retriever.base import BaseEdgeRetriever, BasePassageRetriever
class ReAct:
    """Answer a question by alternating LLM reasoning steps with retrieval.

    Each iteration asks the LLM (via ``llm.generate_with_react``) whether the
    current context is enough to answer. If the LLM replies with
    "Need more information" and proposes a search action, the retriever is
    queried and any results not seen before are appended to the context.
    """

    # Number of items requested for the initial search and for each follow-up search.
    INITIAL_TOP_N = 5
    FOLLOWUP_TOP_N = 3

    # The action line is expected to look like "... search for <query>"; matching is
    # case-insensitive so that "Search for X" yields "X" rather than the whole action.
    _SEARCH_FOR = re.compile(r"search for", re.IGNORECASE)

    def __init__(self, llm: "LLMGenerator"):
        """Store the LLM generator used for every reasoning step."""
        self.llm = llm

    @staticmethod
    def _context_separator(retriever) -> str:
        """Return the string used to join retrieved items for this retriever kind.

        Raises:
            TypeError: If the retriever is neither an edge nor a passage retriever.
        """
        if isinstance(retriever, BaseEdgeRetriever):
            return ". "
        if isinstance(retriever, BasePassageRetriever):
            return "\n"
        raise TypeError(
            "retriever must be a BaseEdgeRetriever or BasePassageRetriever, "
            f"got {type(retriever).__name__}"
        )

    @staticmethod
    def _parse_response(response: str):
        """Split an LLM response into its (thought, action, answer) fields.

        Thought and action are the remainder of the line following their label
        (leading whitespace is kept); answer is everything after "Answer:", stripped.

        Raises:
            IndexError: If any of the three labels is missing from the response.
        """
        thought = response.split("Thought:")[1].split("\n")[0]
        action = response.split("Action:")[1].split("\n")[0]
        answer = response.split("Answer:")[1].strip()
        return thought, action, answer

    @classmethod
    def _extract_search_query(cls, action: str) -> str:
        """Return the text following the last "search for" in the action line.

        If the phrase does not occur, the whole (stripped) action is returned.
        """
        return cls._SEARCH_FOR.split(action)[-1].strip()

    @staticmethod
    def _take_unseen(retrieved, seen: set) -> list:
        """Return the retrieved items not yet in ``seen``, preserving order.

        ``seen`` is updated in place, so repeated items within one batch and
        across later searches are dropped as well.
        """
        fresh = []
        for item in retrieved:
            if item not in seen:
                seen.add(item)
                fresh.append(item)
        return fresh

    def generate_with_rag_react(self, question: str, retriever: Union['BaseEdgeRetriever', 'BasePassageRetriever'], max_iterations: int = 5, max_new_tokens: int = 1024, logger: Union[Logger, None] = None):
        """
        Generate a response using RAG with the ReAct framework, starting with an
        initial search that uses the original question as the query.

        Args:
            question (str): The question to answer.
            retriever: The retriever instance to use for searching.
            max_iterations (int): Maximum number of ReAct iterations.
            max_new_tokens (int): Maximum number of tokens to generate per iteration.
            logger (Logger, optional): If given, progress and errors are logged to it.

        Returns:
            tuple: (final_answer, search_history)
                - final_answer: The answer produced by the LLM. If the response
                  cannot be parsed or a follow-up search fails, the raw LLM
                  response is returned instead. If the iteration budget runs out,
                  the last parsed answer is returned ("Unable to find answer"
                  when no iteration ran at all).
                - search_history: List of (thought, action, observation) tuples.

        Raises:
            TypeError: If ``retriever`` is not a supported retriever type.
        """
        # Fail early with a clear message instead of an unbound-variable error later.
        separator = self._context_separator(retriever)

        search_history = []
        # Defined up front so the final return is valid even when max_iterations <= 0.
        answer = "Unable to find answer"

        # Perform the initial search with the original question.
        initial_context, _ = retriever.retrieve(question, topN=self.INITIAL_TOP_N)
        initial_context = list(initial_context)
        current_context = separator.join(initial_context)
        # Track retrieved items directly: re-splitting the joined context is unreliable
        # because later results are appended with "\n" regardless of the separator.
        seen = set(initial_context)

        for _step in range(max_iterations):
            # Ask the LLM whether the current context is sufficient.
            analysis_response = self.llm.generate_with_react(
                question=question,
                context=current_context,
                max_new_tokens=max_new_tokens,
                search_history=search_history,
                logger=logger,
            )
            if logger:
                logger.info("Analysis response: %s", analysis_response)

            try:
                thought, action, answer = self._parse_response(analysis_response)
            except (IndexError, AttributeError, TypeError) as e:
                # Malformed response: hand the raw text back to the caller.
                if logger:
                    logger.error("Error parsing ReAct response: %s", e)
                return analysis_response, search_history
            if logger:
                logger.info("Thought: %s", thought)

            # Anything other than the sentinel means the LLM produced an answer.
            if answer.lower() != "need more information":
                search_history.append((thought, action, "Using current context"))
                return answer, search_history

            if "search" not in action.lower():
                # No search requested but no answer either: nothing more to try.
                search_history.append((thought, action, "No action taken but answer not found"))
                return "Unable to find answer", search_history

            search_query = self._extract_search_query(action)
            try:
                retrieved, _ = retriever.retrieve(search_query, topN=self.FOLLOWUP_TOP_N)
                new_context = separator.join(self._take_unseen(retrieved, seen))
            except Exception as e:
                # Deliberate best effort: a failed follow-up search returns the raw
                # LLM response rather than raising to the caller.
                if logger:
                    logger.error("Error retrieving context for ReAct search: %s", e)
                return analysis_response, search_history

            if new_context:
                observation = f"Found information: {new_context}"
                current_context = f"{current_context}\n{new_context}"
                if logger:
                    logger.info("New search results: %s", new_context)
            else:
                observation = "No new information found. Consider searching for related entities or events."
                if logger:
                    logger.info("No new information found, suggesting to try related entities")
            search_history.append((thought, action, observation))

        # Iteration budget exhausted: return the last parsed answer.
        return answer, search_history