first commit
1
AIEC-RAG/atlas_rag/evaluation/__init__.py
Normal file
@@ -0,0 +1 @@
from .benchmark import BenchMarkConfig, RAGBenchmark
Binary files not shown (8 files).
236
AIEC-RAG/atlas_rag/evaluation/benchmark.py
Normal file
@@ -0,0 +1,236 @@
import os
import json
import numpy as np
from logging import Logger
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
from typing import List
from datetime import datetime
from transformers import AutoModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
import torch.nn.functional as F
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.evaluation.evaluation import QAJudger
from dataclasses import dataclass
from atlas_rag.llm_generator.prompt.react import ReAct


def normalize_embeddings(embeddings):
    """Normalize the embeddings to unit length (L2 norm)."""
    if isinstance(embeddings, torch.Tensor):
        # Handle PyTorch tensors
        norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
    elif isinstance(embeddings, np.ndarray):
        # Handle numpy arrays
        norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
    else:
        raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")

    return norm_emb

@dataclass
class BenchMarkConfig:
    """
    Configuration class for benchmarking.

    Attributes:
        dataset_name (str): Name of the dataset. Default is "hotpotqa".
        question_file (str): Path to the question file. Default is "hotpotqa".
        include_events (bool): Whether to include events. Default is False.
        include_concept (bool): Whether to include concepts. Default is False.
        reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
        encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
        number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
        react_max_iterations (int): Maximum number of ReAct iterations. Default is 5.
    """
    dataset_name: str = "hotpotqa"
    question_file: str = "hotpotqa"
    include_events: bool = False
    include_concept: bool = False
    reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
    encoder_model_name: str = "nvidia/NV-Embed-v2"
    number_of_samples: int = -1  # Default to -1 to use all samples
    react_max_iterations: int = 5


class RAGBenchmark:
    def __init__(self, config: BenchMarkConfig, logger: Logger = None):
        self.config = config
        self.logger = logger
        self.logging: bool = self.logger is not None

    def load_encoder_model(self, encoder_model_name, **kwargs):
        if encoder_model_name == "nvidia/NV-Embed-v2":
            sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
            return NvEmbed(sentence_encoder)
        else:
            sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
            return SentenceEmbedding(sentence_encoder)

    def run(self, retrievers: List[BaseRetriever],
            llm_generator: LLMGenerator,
            use_react: bool = False):
        qa_judge = QAJudger()
        if use_react:
            react_agent = ReAct(llm_generator=llm_generator)
        result_list = []
        with open(self.config.question_file, "r") as f:
            data = json.load(f)
            print(f"Data loaded from {self.config.question_file}")
        if self.config.number_of_samples > 0:
            data = data[:self.config.number_of_samples]
            print(f"Using only the first {self.config.number_of_samples} samples from the dataset")
        for sample in tqdm(data):
            question = sample["question"]
            answer = sample["answer"]

            gold_file_ids = []
            if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                for fact in sample["supporting_facts"]:
                    gold_file_ids.append(fact[0])
            elif self.config.dataset_name == "musique":
                for paragraph in sample["paragraphs"]:
                    if paragraph["is_supporting"]:
                        gold_file_ids.append(paragraph["paragraph_text"])
            else:
                print("Dataset not supported")
                continue

            result = {
                "question": question,
                "answer": answer,
                "gold_file_ids": gold_file_ids,
            }

            if self.logging:
                self.logger.info(f"Question: {question}")
            for retriever in retrievers:
                if use_react:
                    # Use RAG with ReAct
                    llm_generated_answer, search_history = react_agent.generate_with_rag_react(
                        question=question,
                        retriever=retriever,
                        max_iterations=self.config.react_max_iterations,
                        max_new_tokens=2048,
                        logger=self.logger
                    )
                    if self.logging:
                        self.logger.info(f"Search history: {search_history}")
                        self.logger.info(f"Final answer: {llm_generated_answer}")
                    # Store search history in results
                    result[f"{retriever.__class__.__name__}_search_history"] = search_history

                    # Extract all retrieved contexts from search history
                    all_contexts = []
                    for _, action, observation in search_history:
                        if "search" in action.lower() or "look up" in action.lower():
                            all_contexts.append(observation)

                    sorted_context = "\n".join(all_contexts)
                    sorted_context_ids = []  # We don't track IDs in ReAct mode
                else:
                    # Original RAG implementation
                    sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)

                    if isinstance(retriever, BaseEdgeRetriever):
                        retrieved_context = ". ".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
                    elif isinstance(retriever, BasePassageRetriever):
                        retrieved_context = "\n".join(sorted_context)
                        llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)

                if self.logging:
                    self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
                    self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")

                short_answer = qa_judge.split_answer(llm_generated_answer)
                em, f1 = qa_judge.judge(short_answer, answer)

                result[f"{retriever.__class__.__name__}_em"] = em
                result[f"{retriever.__class__.__name__}_f1"] = f1
                result[f"{retriever.__class__.__name__}_passages"] = sorted_context
                if not use_react:
                    result[f"{retriever.__class__.__name__}_id"] = sorted_context_ids
                result[f"{retriever.__class__.__name__}_generated_answer"] = llm_generated_answer
                result[f"{retriever.__class__.__name__}_short_answer"] = short_answer

                # Calculate recall
                if not use_react:  # Only calculate recall for non-ReAct mode
                    if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
                        recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
                    elif self.config.dataset_name == "musique":
                        recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)

                    result[f"{retriever.__class__.__name__}_recall@2"] = recall_2
                    result[f"{retriever.__class__.__name__}_recall@5"] = recall_5

            result_list.append(result)

        self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])

    def save_results(self, result_list, retriever_names: List[str]):
        current_time = datetime.now()
        formatted_time = current_time.strftime("%Y%m%d%H%M%S")

        dataset_name = self.config.dataset_name
        include_events = self.config.include_events
        include_concept = self.config.include_concept
        encoder_model_name = self.config.encoder_model_name
        reader_model_name = self.config.reader_model_name

        # use last part of model name as identifier
        if "/" in encoder_model_name:
            encoder_model_name = encoder_model_name.split("/")[-1]
        if "/" in reader_model_name:
            reader_model_name = reader_model_name.split("/")[-1]

        summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        if not os.path.exists(os.path.dirname(summary_file)):
            os.makedirs(os.path.dirname(summary_file), exist_ok=True)

        result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
        if not os.path.exists(os.path.dirname(result_dir)):
            os.makedirs(os.path.dirname(result_dir), exist_ok=True)

        summary_dict = self.calculate_summary(result_list, retriever_names)

        with open(summary_file, "w") as f_summary:
            json.dump(summary_dict, f_summary)
            f_summary.write("\n")

        with open(result_dir, "w") as f:
            for result in result_list:
                json.dump(result, f)
                f.write("\n")

    def calculate_summary(self, result_list, method):
        summary_dict = {}
        for retriever_name in method:
            if not all(f"{retriever_name}_em" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_em in results")
            if not all(f"{retriever_name}_f1" in result for result in result_list):
                raise ValueError(f"Missing {retriever_name}_f1 in results")

            average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
            average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)

            # Only calculate recall metrics if they exist in the results
            if all(f"{retriever_name}_recall@2" in result for result in result_list):
                average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
                average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                    f"{retriever_name}_average_recall@2": average_recall_2,
                    f"{retriever_name}_average_recall@5": average_recall_5,
                })
            else:
                # For ReAct mode where recall metrics don't exist
                summary_dict.update({
                    f"{retriever_name}_average_f1": average_f1,
                    f"{retriever_name}_average_em": average_em,
                })

        return summary_dict
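Usage note: a minimal sketch of how the benchmark above could be driven. The question-file path and the retriever/generator objects are illustrative assumptions, not part of this commit; only BenchMarkConfig and RAGBenchmark come from the files above.

# Hedged usage sketch (not part of the commit).
from atlas_rag.evaluation import BenchMarkConfig, RAGBenchmark

config = BenchMarkConfig(
    dataset_name="hotpotqa",
    question_file="data/hotpotqa.json",  # hypothetical path to a JSON list of {question, answer, supporting_facts}
    number_of_samples=100,               # evaluate only the first 100 samples
)
benchmark = RAGBenchmark(config)
# Retriever and generator construction live in other modules and are shown only as placeholders:
# benchmark.run(retrievers=[some_passage_retriever], llm_generator=some_llm_generator)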
158
AIEC-RAG/atlas_rag/evaluation/evaluation.py
Normal file
@@ -0,0 +1,158 @@
import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm


class QAJudger:
    def __init__(self):
        pass

    def split_answer(self, generated_text):
        if "Answer:" in generated_text:
            generated_text = generated_text.split("Answer:")[-1]
        elif "answer:" in generated_text:
            generated_text = generated_text.split("answer:")[-1]
        # if answer is none
        if not generated_text:
            return "none"
        return generated_text

    def normalize_answer(self, answer: str) -> str:
        """Direct copy of the normalization from QAExactMatch/QAF1Score"""
        # Lowercase and normalize whitespace
        answer = answer.lower()
        # Replace hyphens with spaces
        answer = answer.replace('-', ' ')
        # Remove all other punctuation
        answer = re.sub(r'[^\w\s]', '', answer)
        # Standardize whitespace
        return ' '.join(answer.split())

    def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
        """Direct port of the original scoring logic"""
        # Extract answer from generated text
        pred_answer = self.split_answer(generated_text)

        # Normalize both answers
        pred_norm = self.normalize_answer(pred_answer)
        ref_norm = self.normalize_answer(reference_text)

        # Exact match calculation
        em = 1 if pred_norm == ref_norm else 0

        # F1 calculation (direct port from QAF1Score)
        pred_tokens = pred_norm.split()
        ref_tokens = ref_norm.split()

        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            return em, 0.0

        precision = num_same / len(pred_tokens) if pred_tokens else 0.0
        recall = num_same / len(ref_tokens) if ref_tokens else 0.0

        if (precision + recall) == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)

        return em, f1

    def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
        """Calculates recall at k based on the top k retrieved texts."""
        successful_retrievals = 0

        # Limit the retrieved texts to the top k entries
        limited_retrieved_text = retrieved_text[:k]

        for ref_text in reference_text:
            for ret_text in limited_retrieved_text:
                if ref_text in ret_text:
                    successful_retrievals += 1
                    break

        recall = successful_retrievals / len(reference_text) if reference_text else 0
        return recall

    # recall for 1 answer
    def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
        """Calculates recall@2 and recall@5 and returns them as a tuple."""
        recall_values = {
            'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
            'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
        }
        return recall_values['recall@2'], recall_values['recall@5']


if __name__ == "__main__":
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON file containing results.")
    args = argument_parser.parse_args()

    # Initialize the QAJudger
    llm_judge = QAJudger()

    # Load results from the JSON Lines file (one JSON object per line)
    result_list = []
    with open(args.file_path, 'r') as file:
        for line in file:
            if line.strip():  # Make sure the line is not empty
                try:
                    result = json.loads(line.strip())
                    result_list.append(result)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

    # Debugging output to inspect the loaded data structure
    # print("Loaded data structure:", result_list)

    # Evaluate each entry in result_list
    for result in tqdm(result_list):
        if isinstance(result, dict):  # Ensure each result is a dictionary
            question = result["question"]
            answer = result["answer"]

            # Evaluate generated answers with Hippo and Hippo2
            hippo_generated_answer = result["hippo_generated_answer"]
            hippo2_generated_answer = result["hippo2_generated_answer"]

            # Split and judge the answers
            hippo_short_answer = llm_judge.split_answer(hippo_generated_answer)
            hippo_em, hippo_f1 = llm_judge.judge(hippo_short_answer, answer)

            hippo2_short_answer = llm_judge.split_answer(hippo2_generated_answer)
            hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_short_answer, answer)

            # Store the scores back in the result dictionary
            result["hippo_em"] = hippo_em
            result["hippo_f1"] = hippo_f1
            result["hippo2_em"] = hippo2_em
            result["hippo2_f1"] = hippo2_f1

            result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
            result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])

    # Calculate averages
    average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
    average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)

    average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
    average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)

    # 'recall@2'/'recall@5' were computed from hippo2_id; 'recall@2_hippo'/'recall@5_hippo' from hippo_id
    average_recall2_with_hippo2 = sum(result['recall@2'] for result in result_list) / len(result_list)
    average_recall5_with_hippo2 = sum(result['recall@5'] for result in result_list) / len(result_list)
    average_recall2_with_hippo = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
    average_recall5_with_hippo = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)

    # Output the averages
    print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
    print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
    print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
    print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")

    print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
    print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")
    print(f"Average Recall@2 with Hippo2: {average_recall2_with_hippo2:.4f}")
    print(f"Average Recall@5 with Hippo2: {average_recall5_with_hippo2:.4f}")
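Usage note: a minimal sketch of the scoring behaviour defined in evaluation.py above; the expected values follow directly from the EM/F1 and recall logic in QAJudger.

# Hedged example of QAJudger scoring (not part of the commit).
from atlas_rag.evaluation.evaluation import QAJudger

judge = QAJudger()

em, f1 = judge.judge("Answer: Barack Obama", "Barack Obama")
print(em, round(f1, 2))   # 1 1.0  -- normalization lowercases and strips punctuation

em, f1 = judge.judge("Answer: Obama", "Barack Obama")
print(em, round(f1, 2))   # 0 0.67 -- one of the two reference tokens is matched

recall_2, recall_5 = judge.recall(["passage about Paris", "passage about Rome"], ["Paris"])
print(recall_2, recall_5)  # 1.0 1.0 -- the gold string occurs in a top-k passage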