first commit
AIEC-RAG/atlas_rag/evaluation/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .benchmark import BenchMarkConfig, RAGBenchmark
Binary files not shown (5 files).
AIEC-RAG/atlas_rag/evaluation/benchmark.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import os
import json
import numpy as np
from logging import Logger
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
from typing import List
from datetime import datetime
from transformers import AutoModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
import torch.nn.functional as F
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.evaluation.evaluation import QAJudger
from dataclasses import dataclass
from atlas_rag.llm_generator.prompt.react import ReAct


def normalize_embeddings(embeddings):
    """Normalize the embeddings to unit length (L2 norm)."""
    if isinstance(embeddings, torch.Tensor):
        # Handle PyTorch tensors
        norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
    elif isinstance(embeddings, np.ndarray):
        # Handle numpy arrays
        norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
    else:
        raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")

    return norm_emb


@dataclass
class BenchMarkConfig:
    """
    Configuration class for benchmarking.

    Attributes:
        dataset_name (str): Name of the dataset. Default is "hotpotqa".
        question_file (str): Path to the question file. Default is "hotpotqa".
        include_events (bool): Whether to include events. Default is False.
        include_concept (bool): Whether to include concepts. Default is False.
        reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
        encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
        number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
        react_max_iterations (int): Maximum number of ReAct iterations. Default is 5.
    """
    dataset_name: str = "hotpotqa"
    question_file: str = "hotpotqa"
    include_events: bool = False
    include_concept: bool = False
    reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
    encoder_model_name: str = "nvidia/NV-Embed-v2"
    number_of_samples: int = -1  # Default to -1 to use all samples
    react_max_iterations: int = 5


class RAGBenchmark:
    def __init__(self, config: BenchMarkConfig, logger: Logger = None):
        self.config = config
        self.logger = logger
        self.logging: bool = self.logger is not None

    def load_encoder_model(self, encoder_model_name, **kwargs):
        # NV-Embed-v2 is loaded through transformers; other encoders go through sentence-transformers.
        if encoder_model_name == "nvidia/NV-Embed-v2":
            sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
            return NvEmbed(sentence_encoder)
        else:
            sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
            return SentenceEmbedding(sentence_encoder)

    def run(self, retrievers: List[BaseRetriever],
            llm_generator: LLMGenerator,
            use_react: bool = False):
        qa_judge = QAJudger()
        if use_react:
            react_agent = ReAct(llm_generator=llm_generator)
        result_list = []
        with open(self.config.question_file, "r") as f:
            data = json.load(f)
            print(f"Data loaded from {self.config.question_file}")
            if self.config.number_of_samples > 0:
                data = data[:self.config.number_of_samples]
                print(f"Using only the first {self.config.number_of_samples} samples from the dataset")

        for sample in tqdm(data):
            question = sample["question"]
            answer = sample["answer"]

            # Collect the gold supporting titles/passages used later for recall computation.
            gold_file_ids = []
            if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                for fact in sample["supporting_facts"]:
                    gold_file_ids.append(fact[0])
            elif self.config.dataset_name == "musique":
                for paragraph in sample["paragraphs"]:
                    if paragraph["is_supporting"]:
                        gold_file_ids.append(paragraph["paragraph_text"])
            else:
                print("Dataset not supported")
                continue

            result = {
                "question": question,
                "answer": answer,
                "gold_file_ids": gold_file_ids,
            }

            if self.logging:
                self.logger.info(f"Question: {question}")
            for retriever in retrievers:
                if use_react:
                    # Use RAG with ReAct
                    llm_generated_answer, search_history = react_agent.generate_with_rag_react(
                        question=question,
                        retriever=retriever,
                        max_iterations=self.config.react_max_iterations,
                        max_new_tokens=2048,
                        logger=self.logger
                    )
                    if self.logging:
                        self.logger.info(f"Search history: {search_history}")
                        self.logger.info(f"Final answer: {llm_generated_answer}")
                    # Store search history in results
                    result[f"{retriever.__class__.__name__}_search_history"] = search_history

                    # Extract all retrieved contexts from search history
                    all_contexts = []
                    for _, action, observation in search_history:
                        if "search" in action.lower() or "look up" in action.lower():
                            all_contexts.append(observation)

                    sorted_context = "\n".join(all_contexts)
                    sorted_context_ids = []  # We don't track IDs in ReAct mode
                else:
                    # Original RAG implementation
                    sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)

                    if isinstance(retriever, BaseEdgeRetriever):
                        retrieved_context = ". ".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
                    elif isinstance(retriever, BasePassageRetriever):
                        retrieved_context = "\n".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)

                    if self.logging:
                        self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
                        self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")

                short_answer = qa_judge.split_answer(llm_generated_answer)
                em, f1 = qa_judge.judge(short_answer, answer)

                result[f"{retriever.__class__.__name__}_em"] = em
                result[f"{retriever.__class__.__name__}_f1"] = f1
                result[f"{retriever.__class__.__name__}_passages"] = sorted_context
                if not use_react:
                    result[f"{retriever.__class__.__name__}_id"] = sorted_context_ids
                result[f"{retriever.__class__.__name__}_generated_answer"] = llm_generated_answer
                result[f"{retriever.__class__.__name__}_short_answer"] = short_answer

                # Calculate recall
                if not use_react:  # Only calculate recall for non-ReAct mode
                    if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                        recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
                    elif self.config.dataset_name == "musique":
                        recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)

                    result[f"{retriever.__class__.__name__}_recall@2"] = recall_2
                    result[f"{retriever.__class__.__name__}_recall@5"] = recall_5

            result_list.append(result)

        self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])

    def save_results(self, result_list, retriever_names: List[str]):
        current_time = datetime.now()
        formatted_time = current_time.strftime("%Y%m%d%H%M%S")

        dataset_name = self.config.dataset_name
        include_events = self.config.include_events
        include_concept = self.config.include_concept
        encoder_model_name = self.config.encoder_model_name
        reader_model_name = self.config.reader_model_name

        # Use the last part of the model name as identifier
        if "/" in encoder_model_name:
            encoder_model_name = encoder_model_name.split("/")[-1]
        if "/" in reader_model_name:
            reader_model_name = reader_model_name.split("/")[-1]

        summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        os.makedirs(os.path.dirname(summary_file), exist_ok=True)

        result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        os.makedirs(os.path.dirname(result_dir), exist_ok=True)

        summary_dict = self.calculate_summary(result_list, retriever_names)

        with open(summary_file, "w") as f_summary:
            json.dump(summary_dict, f_summary)
            f_summary.write("\n")

        # Per-question results are written as JSON Lines, one record per line.
        with open(result_dir, "w") as f:
            for result in result_list:
                json.dump(result, f)
                f.write("\n")

    def calculate_summary(self, result_list, retriever_names):
        summary_dict = {}
        for retriever_name in retriever_names:
            if not all(f"{retriever_name}_em" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_em in results")
            if not all(f"{retriever_name}_f1" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_f1 in results")

            average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
            average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)

            # Only calculate recall metrics if they exist in the results
            if all(f"{retriever_name}_recall@2" in result for result in result_list):
                average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
                average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                    f"{retriever_name}_average_recall@2": average_recall_2,
                    f"{retriever_name}_average_recall@5": average_recall_5,
                })
            else:
                # For ReAct mode where recall metrics don't exist
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                })

        return summary_dict
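A minimal usage sketch of the benchmark entry point, assuming a retriever and an LLMGenerator have already been constructed elsewhere in AIEC-RAG; the `my_retriever`, `my_llm` names and the question-file path below are placeholders, not part of this commit:

import logging
from atlas_rag.evaluation import BenchMarkConfig, RAGBenchmark

config = BenchMarkConfig(
    dataset_name="hotpotqa",
    question_file="data/hotpotqa_questions.json",  # hypothetical path to the question JSON
    number_of_samples=100,                          # evaluate a subset; -1 uses all samples
)
logger = logging.getLogger("rag_benchmark")
benchmark = RAGBenchmark(config, logger=logger)

# my_retriever: an instance of a BaseEdgeRetriever or BasePassageRetriever subclass
# my_llm: an LLMGenerator wrapping the reader model named in the config
benchmark.run(retrievers=[my_retriever], llm_generator=my_llm, use_react=False)

Results land under ./result/<dataset_name>/: one summary JSON per run plus a JSON Lines file with per-question EM, F1, recall@2/@5, retrieved passages, and generated answers.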
AIEC-RAG/atlas_rag/evaluation/evaluation.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm


class QAJudger:
    def __init__(self):
        pass

    def split_answer(self, generated_text):
        # Keep only the text after the final "Answer:" marker, if present.
        if "Answer:" in generated_text:
            generated_text = generated_text.split("Answer:")[-1]
        elif "answer:" in generated_text:
            generated_text = generated_text.split("answer:")[-1]
        generated_text = generated_text.strip()
        # If the extracted answer is empty, fall back to "none".
        if not generated_text:
            return "none"
        return generated_text

    def normalize_answer(self, answer: str) -> str:
        """Direct copy of the normalization from QAExactMatch/QAF1Score."""
        # Lowercase and normalize whitespace
        answer = answer.lower()
        # Replace hyphens with spaces
        answer = answer.replace('-', ' ')
        # Remove all other punctuation
        answer = re.sub(r'[^\w\s]', '', answer)
        # Standardize whitespace
        return ' '.join(answer.split())

    def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
        """Direct port of the original scoring logic."""
        # Extract answer from generated text
        pred_answer = self.split_answer(generated_text)

        # Normalize both answers
        pred_norm = self.normalize_answer(pred_answer)
        ref_norm = self.normalize_answer(reference_text)

        # Exact match calculation
        em = 1 if pred_norm == ref_norm else 0

        # F1 calculation (direct port from QAF1Score)
        pred_tokens = pred_norm.split()
        ref_tokens = ref_norm.split()

        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return em, 0.0

        precision = num_same / len(pred_tokens) if pred_tokens else 0.0
        recall = num_same / len(ref_tokens) if ref_tokens else 0.0

        if (precision + recall) == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)

        return em, f1

    def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
        """Calculates recall at k based on the top k retrieved texts."""
        successful_retrievals = 0

        # Limit the retrieved texts to the top k entries
        limited_retrieved_text = retrieved_text[:k]

        # A gold entry counts as retrieved if it appears as a substring of any top-k text.
        for ref_text in reference_text:
            for ret_text in limited_retrieved_text:
                if ref_text in ret_text:
                    successful_retrievals += 1
                    break

        recall = successful_retrievals / len(reference_text) if reference_text else 0
        return recall

    def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
        """Calculates recall@2 and recall@5 for a single question."""
        recall_values = {
            'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
            'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
        }
        return recall_values['recall@2'], recall_values['recall@5']


if __name__ == "__main__":
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON Lines file containing results.")
    args = argument_parser.parse_args()

    # Initialize the QAJudger
    llm_judge = QAJudger()

    # Load results from the JSON Lines file
    result_list = []
    with open(args.file_path, 'r') as file:
        for line in file:
            if line.strip():  # Make sure the line is not empty
                try:
                    result = json.loads(line.strip())
                    result_list.append(result)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

    # Evaluate each entry in result_list
    for result in tqdm(result_list):
        if isinstance(result, dict):  # Ensure each result is a dictionary
            question = result["question"]
            answer = result["answer"]

            # Evaluate generated answers with Hippo and Hippo2
            hippo_generated_answer = result["hippo_generated_answer"]
            hippo2_generated_answer = result["hippo2_generated_answer"]

            # Split and judge the answers
            hippo_short_answer = llm_judge.split_answer(hippo_generated_answer)
            hippo_em, hippo_f1 = llm_judge.judge(hippo_short_answer, answer)

            hippo2_short_answer = llm_judge.split_answer(hippo2_generated_answer)
            hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_short_answer, answer)

            # Store the scores back in the result dictionary
            result["hippo_em"] = hippo_em
            result["hippo_f1"] = hippo_f1
            result["hippo2_em"] = hippo2_em
            result["hippo2_f1"] = hippo2_f1

            # recall@2/recall@5 are computed from the Hippo2 ids, the *_hippo variants from the Hippo ids.
            result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
            result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])

    # Calculate averages
    average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
    average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)

    average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
    average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)

    average_recall2_with_hippo2 = sum(result['recall@2'] for result in result_list) / len(result_list)
    average_recall5_with_hippo2 = sum(result['recall@5'] for result in result_list) / len(result_list)
    average_recall2_with_hippo = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
    average_recall5_with_hippo = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)

    # Output the averages
    print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
    print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
    print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
    print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")

    print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
    print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")
    print(f"Average Recall@2 with Hippo2: {average_recall2_with_hippo2:.4f}")
    print(f"Average Recall@5 with Hippo2: {average_recall5_with_hippo2:.4f}")
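A short, self-contained sketch of how the scoring helpers behave on toy inputs; the values in the comments are what the code above computes on these made-up strings, not measured benchmark numbers:

from atlas_rag.evaluation.evaluation import QAJudger

judge = QAJudger()

# split_answer keeps only the text after the final "Answer:" marker.
short = judge.split_answer("Reasoning... Answer: Barack Obama")
em, f1 = judge.judge(short, "Barack Obama")   # em == 1, f1 == 1.0

# recall() checks whether each gold id/passage is a substring of a top-k retrieved entry.
retrieved = ["Paris is the capital of France", "Berlin", "Madrid", "Rome", "Lisbon"]
gold = ["Paris"]
recall_2, recall_5 = judge.recall(retrieved, gold)  # both 1.0 here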