import os
import json
from dataclasses import dataclass
from datetime import datetime
from logging import Logger
from typing import List

import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModel

from atlas_rag.evaluation.evaluation import QAJudger
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.llm_generator.prompt.react import ReAct
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding


def normalize_embeddings(embeddings):
    """Normalize the embeddings to unit length (L2 norm)."""
    if isinstance(embeddings, torch.Tensor):
        # Handle PyTorch tensors
        norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
    elif isinstance(embeddings, np.ndarray):
        # Handle numpy arrays
        norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
    else:
        raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")

    return norm_emb
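
# Illustrative usage (added sketch, not part of the original module): the helper accepts
# a 2-D torch.Tensor or np.ndarray and returns row-wise L2-normalized numpy vectors, e.g.
#
#     emb = np.random.rand(4, 8).astype("float32")    # hypothetical embeddings
#     unit_emb = normalize_embeddings(emb)
#     assert np.allclose(np.linalg.norm(unit_emb, axis=1), 1.0)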


@dataclass
class BenchMarkConfig:
    """
    Configuration class for benchmarking.

    Attributes:
        dataset_name (str): Name of the dataset. Default is "hotpotqa".
        question_file (str): Path to the question file. Default is "hotpotqa".
        include_events (bool): Whether to include events. Default is False.
        include_concept (bool): Whether to include concepts. Default is False.
        reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
        encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
        number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
        react_max_iterations (int): Maximum number of ReAct iterations. Default is 5.
    """
    dataset_name: str = "hotpotqa"
    question_file: str = "hotpotqa"
    include_events: bool = False
    include_concept: bool = False
    reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
    encoder_model_name: str = "nvidia/NV-Embed-v2"
    number_of_samples: int = -1  # Default to -1 to use all samples
    react_max_iterations: int = 5


class RAGBenchmark:
    def __init__(self, config: BenchMarkConfig, logger: Logger = None):
        self.config = config
        self.logger = logger
        self.logging: bool = self.logger is not None

    def load_encoder_model(self, encoder_model_name, **kwargs):
        if encoder_model_name == "nvidia/NV-Embed-v2":
            sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
            return NvEmbed(sentence_encoder)
        else:
            sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
            return SentenceEmbedding(sentence_encoder)
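
    # Illustrative call (added note, not part of the original module): extra keyword
    # arguments are forwarded to AutoModel.from_pretrained / SentenceTransformer, and
    # NV-Embed-v2 typically needs trust_remote_code=True, e.g.
    #
    #     encoder = benchmark.load_encoder_model("nvidia/NV-Embed-v2", trust_remote_code=True)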

    def run(self, retrievers: List[BaseRetriever],
            llm_generator: LLMGenerator,
            use_react: bool = False):
        qa_judge = QAJudger()
        if use_react:
            react_agent = ReAct(llm_generator=llm_generator)
        result_list = []
        with open(self.config.question_file, "r") as f:
            data = json.load(f)
        print(f"Data loaded from {self.config.question_file}")

        if self.config.number_of_samples > 0:
            data = data[:self.config.number_of_samples]
            print(f"Using only the first {self.config.number_of_samples} samples from the dataset")

        for sample in tqdm(data):
            question = sample["question"]
            answer = sample["answer"]

            gold_file_ids = []
            if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                for fact in sample["supporting_facts"]:
                    gold_file_ids.append(fact[0])
            elif self.config.dataset_name == "musique":
                for paragraph in sample["paragraphs"]:
                    if paragraph["is_supporting"]:
                        gold_file_ids.append(paragraph["paragraph_text"])
            else:
                print("Dataset not supported")
                continue

            result = {
                "question": question,
                "answer": answer,
                "gold_file_ids": gold_file_ids,
            }

            if self.logging:
                self.logger.info(f"Question: {question}")
            for retriever in retrievers:
                if use_react:
                    # Use RAG with ReAct
                    llm_generated_answer, search_history = react_agent.generate_with_rag_react(
                        question=question,
                        retriever=retriever,
                        max_iterations=self.config.react_max_iterations,
                        max_new_tokens=2048,
                        logger=self.logger
                    )
                    if self.logging:
                        self.logger.info(f"Search history: {search_history}")
                        self.logger.info(f"Final answer: {llm_generated_answer}")
                    # Store search history in results
                    result[f"{retriever.__class__.__name__}_search_history"] = search_history

                    # Extract all retrieved contexts from search history
                    all_contexts = []
                    for _, action, observation in search_history:
                        if "search" in action.lower() or "look up" in action.lower():
                            all_contexts.append(observation)

                    sorted_context = "\n".join(all_contexts)
                    sorted_context_ids = []  # We don't track IDs in ReAct mode
                else:
                    # Original RAG implementation
                    sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)

                    if isinstance(retriever, BaseEdgeRetriever):
                        retrieved_context = ". ".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
                    elif isinstance(retriever, BasePassageRetriever):
                        retrieved_context = "\n".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)

                if self.logging:
                    self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
                    self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")

                short_answer = qa_judge.split_answer(llm_generated_answer)
                em, f1 = qa_judge.judge(short_answer, answer)

                result[f"{retriever.__class__.__name__}_em"] = em
                result[f"{retriever.__class__.__name__}_f1"] = f1
                result[f"{retriever.__class__.__name__}_passages"] = sorted_context
                if not use_react:
                    result[f"{retriever.__class__.__name__}_id"] = sorted_context_ids
                result[f"{retriever.__class__.__name__}_generated_answer"] = llm_generated_answer
                result[f"{retriever.__class__.__name__}_short_answer"] = short_answer

                # Calculate recall
                if not use_react:  # Only calculate recall for non-ReAct mode
                    if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                        recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
                    elif self.config.dataset_name == "musique":
                        recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)

                    result[f"{retriever.__class__.__name__}_recall@2"] = recall_2
                    result[f"{retriever.__class__.__name__}_recall@5"] = recall_5

            result_list.append(result)

        self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])

    def save_results(self, result_list, retriever_names: List[str]):
        current_time = datetime.now()
        formatted_time = current_time.strftime("%Y%m%d%H%M%S")

        dataset_name = self.config.dataset_name
        include_events = self.config.include_events
        include_concept = self.config.include_concept
        encoder_model_name = self.config.encoder_model_name
        reader_model_name = self.config.reader_model_name

        # use last part of model name as identifier
        if "/" in encoder_model_name:
            encoder_model_name = encoder_model_name.split("/")[-1]
        if "/" in reader_model_name:
            reader_model_name = reader_model_name.split("/")[-1]

        summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        if not os.path.exists(os.path.dirname(summary_file)):
            os.makedirs(os.path.dirname(summary_file), exist_ok=True)

        result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        if not os.path.exists(os.path.dirname(result_dir)):
            os.makedirs(os.path.dirname(result_dir), exist_ok=True)

        summary_dict = self.calculate_summary(result_list, retriever_names)

        with open(summary_file, "w") as f_summary:
            json.dump(summary_dict, f_summary)
            f_summary.write("\n")

        with open(result_dir, "w") as f:
            for result in result_list:
                json.dump(result, f)
                f.write("\n")
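
    # Example output layout (illustrative; timestamp and model names vary):
    #     ./result/hotpotqa/summary_20250101120000_eventFalse_conceptFalse_NV-Embed-v2_Llama-2-7b-chat-hf.json
    #     ./result/hotpotqa/result_20250101120000_eventFalse_conceptFalse_NV-Embed-v2_Llama-2-7b-chat-hf.json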

    def calculate_summary(self, result_list, retriever_names: List[str]):
        summary_dict = {}
        for retriever_name in retriever_names:
            if not all(f"{retriever_name}_em" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_em in results")
            if not all(f"{retriever_name}_f1" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_f1 in results")

            average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
            average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)

            # Only calculate recall metrics if they exist in the results
            if all(f"{retriever_name}_recall@2" in result for result in result_list):
                average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
                average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                    f"{retriever_name}_average_recall@2": average_recall_2,
                    f"{retriever_name}_average_recall@5": average_recall_5,
                })
            else:
                # For ReAct mode where recall metrics don't exist
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                })

        return summary_dict
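

if __name__ == "__main__":
    # Minimal wiring sketch (added illustration, not part of the original module).
    # The commented-out lines are placeholders: retriever and LLMGenerator
    # construction live in other modules and their signatures are assumptions here.
    config = BenchMarkConfig(
        dataset_name="hotpotqa",
        question_file="data/hotpotqa.json",  # hypothetical path
        number_of_samples=10,
    )
    benchmark = RAGBenchmark(config)
    # encoder = benchmark.load_encoder_model(config.encoder_model_name, trust_remote_code=True)
    # retrievers: List[BaseRetriever] = [...]   # build retrievers elsewhere
    # llm = LLMGenerator(...)                   # backend-specific constructor
    # benchmark.run(retrievers, llm_generator=llm, use_react=False)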