first commit

commit 2308536f66 (parent 6339cdebb9)
Author: 闫旭隆
Date: 2025-09-24 09:29:12 +08:00
360 changed files with 136381 additions and 0 deletions


@@ -0,0 +1 @@
from .benchmark import BenchMarkConfig, RAGBenchmark


@@ -0,0 +1,236 @@
import os
import json
import numpy as np
from logging import Logger
from atlas_rag.retriever.base import BaseRetriever, BaseEdgeRetriever, BasePassageRetriever
from typing import List
from datetime import datetime
from transformers import AutoModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import torch
import torch.nn.functional as F
from atlas_rag.vectorstore.embedding_model import NvEmbed, SentenceEmbedding
from atlas_rag.llm_generator.llm_generator import LLMGenerator
from atlas_rag.evaluation.evaluation import QAJudger
from dataclasses import dataclass
from atlas_rag.llm_generator.prompt.react import ReAct
def normalize_embeddings(embeddings):
"""Normalize the embeddings to unit length (L2 norm)."""
if isinstance(embeddings, torch.Tensor):
# Handle PyTorch tensors
norm_emb = F.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
elif isinstance(embeddings, np.ndarray):
# Handle numpy arrays
norm_emb = F.normalize(torch.tensor(embeddings), p=2, dim=1).detach().cpu().numpy()
else:
raise TypeError(f"Unsupported input type: {type(embeddings)}. Must be torch.Tensor or np.ndarray")
return norm_emb
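# Illustrative example (not part of the original file): row-wise L2 normalization
# maps each embedding to a unit vector, e.g.
#   normalize_embeddings(np.array([[3.0, 4.0]]))  ->  array([[0.6, 0.8]])
# so that inner products between normalized embeddings equal cosine similarities.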
@dataclass
class BenchMarkConfig:
"""
Configuration class for benchmarking.
Attributes:
dataset_name (str): Name of the dataset. Default is "hotpotqa".
question_file (str): Path to the question file. Default is "hotpotqa".
graph_file (str): Path to the graph file. Default is "hotpotqa_concept.graphml".
include_events (bool): Whether to include events. Default is False.
include_concept (bool): Whether to include concepts. Default is False.
reader_model_name (str): Name of the reader model. Default is "meta-llama/Llama-2-7b-chat-hf".
encoder_model_name (str): Name of the encoder model. Default is "nvidia/NV-Embed-v2".
number_of_samples (int): Number of samples to use from the dataset. Default is -1 (use all samples).
"""
dataset_name: str = "hotpotqa"
question_file: str = "hotpotqa"
include_events: bool = False
include_concept: bool = False
reader_model_name: str = "meta-llama/Llama-2-7b-chat-hf"
encoder_model_name: str = "nvidia/NV-Embed-v2"
number_of_samples: int = -1 # Default to -1 to use all samples
react_max_iterations: int = 5
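# Illustrative sketch (not part of the original file): a config for a small
# HotpotQA run could look like the following; the question-file path is hypothetical.
#   config = BenchMarkConfig(
#       dataset_name="hotpotqa",
#       question_file="benchmark_data/hotpotqa.json",   # hypothetical path
#       include_concept=True,
#       reader_model_name="meta-llama/Llama-3.1-8B-Instruct",
#       number_of_samples=100,
#   )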
class RAGBenchmark:
def __init__(self, config:BenchMarkConfig, logger:Logger = None):
self.config = config
self.logger = logger
self.logging : bool = self.logger is not None
def load_encoder_model(self, encoder_model_name, **kwargs):
if encoder_model_name == "nvidia/NV-Embed-v2":
sentence_encoder = AutoModel.from_pretrained("nvidia/NV-Embed-v2", **kwargs)
return NvEmbed(sentence_encoder)
else:
sentence_encoder = SentenceTransformer(encoder_model_name, **kwargs)
return SentenceEmbedding(sentence_encoder)
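    # Note (assumption, not from the original file): nvidia/NV-Embed-v2 ships custom
    # modeling code on the Hugging Face Hub, so loading it typically requires passing
    # trust_remote_code=True through **kwargs, e.g.
    #   encoder = benchmark.load_encoder_model("nvidia/NV-Embed-v2", trust_remote_code=True)
    # Any other model name is forwarded to SentenceTransformer instead.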
def run(self, retrievers:List[BaseRetriever],
llm_generator:LLMGenerator,
use_react: bool = False):
qa_judge = QAJudger()
if use_react:
react_agent = ReAct(llm_generator=llm_generator)
result_list = []
with open(self.config.question_file, "r") as f:
data = json.load(f)
print(f"Data loaded from {self.config.question_file}")
if self.config.number_of_samples > 0:
data = data[:self.config.number_of_samples]
print(f"Using only the first {self.config.number_of_samples} samples from the dataset")
for sample in tqdm(data):
question = sample["question"]
answer = sample["answer"]
gold_file_ids = []
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
for fact in sample["supporting_facts"]:
gold_file_ids.append(fact[0])
elif self.config.dataset_name == "musique":
for paragraph in sample["paragraphs"]:
if paragraph["is_supporting"]:
gold_file_ids.append(paragraph["paragraph_text"])
else:
print("Dataset not supported")
continue
result = {
"question": question,
"answer": answer,
"gold_file_ids": gold_file_ids,
}
if self.logging:
self.logger.info(f"Question: {question}")
for retriever in retrievers:
if use_react:
# Use RAG with ReAct
llm_generated_answer, search_history = react_agent.generate_with_rag_react(
question=question,
retriever=retriever,
max_iterations=self.config.react_max_iterations,
max_new_tokens=2048,
logger=self.logger
)
self.logger.info(f"Search history: {search_history}")
self.logger.info(f"Final answer: {llm_generated_answer}")
# Store search history in results
result[f"{retriever.__class__.__name__}_search_history"] = search_history
# Extract all retrieved contexts from search history
all_contexts = []
for _, action, observation in search_history:
if "search" in action.lower() or "look up" in action.lower():
all_contexts.append(observation)
sorted_context = "\n".join(all_contexts)
sorted_context_ids = [] # We don't track IDs in ReAct mode
else:
# Original RAG implementation
sorted_context, sorted_context_ids = retriever.retrieve(question, topN=5)
if isinstance(retriever, BaseEdgeRetriever):
retrieved_context = ". ".join(sorted_context)
llm_generated_answer = llm_generator.generate_with_context_kg(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
elif isinstance(retriever, BasePassageRetriever):
retrieved_context = "\n".join(sorted_context)
llm_generated_answer = llm_generator.generate_with_context(question, retrieved_context, max_new_tokens=2048, temperature=0.5)
if self.logging:
self.logger.info(f"{retriever.__class__.__name__} retrieved passages: {sorted_context}")
self.logger.info(f"{retriever.__class__.__name__} generated answer: {llm_generated_answer}")
short_answer = qa_judge.split_answer(llm_generated_answer)
em, f1 = qa_judge.judge(short_answer, answer)
result[f"{retriever.__class__.__name__ }_em"] = em
result[f"{retriever.__class__.__name__ }_f1"] = f1
result[f"{retriever.__class__.__name__ }_passages"] = sorted_context
if not use_react:
result[f"{retriever.__class__.__name__ }_id"] = sorted_context_ids
result[f"{retriever.__class__.__name__ }_generated_answer"] = llm_generated_answer
result[f"{retriever.__class__.__name__ }short_answer"] = short_answer
# Calculate recall
if not use_react: # Only calculate recall for non-ReAct mode
if self.config.dataset_name in ("hotpotqa", "2wikimultihopqa"):
recall_2, recall_5 = qa_judge.recall(sorted_context_ids, gold_file_ids)
elif self.config.dataset_name == "musique":
recall_2, recall_5 = qa_judge.recall(sorted_context, gold_file_ids)
result[f"{retriever.__class__.__name__ }_recall@2"] = recall_2
result[f"{retriever.__class__.__name__ }_recall@5"] = recall_5
result_list.append(result)
self.save_results(result_list, [retriever.__class__.__name__ for retriever in retrievers])
def save_results(self, result_list, retriever_names:List[str]):
current_time = datetime.now()
formatted_time = current_time.strftime("%Y%m%d%H%M%S")
dataset_name = self.config.dataset_name
include_events = self.config.include_events
include_concept = self.config.include_concept
encoder_model_name = self.config.encoder_model_name
reader_model_name = self.config.reader_model_name
# use last part of model name as identifier
if "/" in encoder_model_name:
encoder_model_name = encoder_model_name.split("/")[-1]
if "/" in reader_model_name:
reader_model_name = reader_model_name.split("/")[-1]
summary_file = f"./result/{dataset_name}/summary_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
if not os.path.exists(os.path.dirname(summary_file)):
os.makedirs(os.path.dirname(summary_file), exist_ok=True)
result_dir = f"./result/{dataset_name}/result_{formatted_time}_event{include_events}_concept{include_concept}_{encoder_model_name}_{reader_model_name}.json"
if not os.path.exists(os.path.dirname(result_dir)):
os.makedirs(os.path.dirname(result_dir), exist_ok=True)
summary_dict = self.calculate_summary(result_list, retriever_names)
with open(summary_file, "w") as f_summary:
json.dump(summary_dict, f_summary)
f_summary.write("\n")
with open(result_dir, "w") as f:
for result in result_list:
json.dump(result, f)
f.write("\n")
    def calculate_summary(self, result_list, retriever_names: List[str]):
        summary_dict = {}
        for retriever_name in retriever_names:
if not all(f"{retriever_name}_em" in result for result in result_list):
raise ValueError(f"Missing {retriever_name}_em in results")
if not all(f"{retriever_name}_f1" in result for result in result_list):
raise ValueError(f"Missing {retriever_name}_f1 in results")
average_em = sum([result[f"{retriever_name}_em"] for result in result_list]) / len(result_list)
average_f1 = sum([result[f"{retriever_name}_f1"] for result in result_list]) / len(result_list)
# Only calculate recall metrics if they exist in the results
if all(f"{retriever_name}_recall@2" in result for result in result_list):
average_recall_2 = sum([result[f"{retriever_name}_recall@2"] for result in result_list]) / len(result_list)
average_recall_5 = sum([result[f"{retriever_name}_recall@5"] for result in result_list]) / len(result_list)
summary_dict.update({
f"{retriever_name}_average_f1": average_f1,
f"{retriever_name}_average_em": average_em,
f"{retriever_name}_average_recall@2": average_recall_2,
f"{retriever_name}_average_recall@5": average_recall_5,
})
else:
# For ReAct mode where recall metrics don't exist
summary_dict.update({
f"{retriever_name}_average_f1": average_f1,
f"{retriever_name}_average_em": average_em,
})
return summary_dict
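# Illustrative end-to-end usage (a sketch under assumed names, not part of the
# original file): the retriever objects depend on the concrete BaseRetriever
# subclasses available in atlas_rag.retriever.
#   config = BenchMarkConfig(dataset_name="hotpotqa", question_file="data/hotpotqa.json")
#   benchmark = RAGBenchmark(config, logger=my_logger)        # my_logger: logging.Logger or None
#   retrievers = [my_passage_retriever, my_edge_retriever]    # hypothetical BasePassageRetriever / BaseEdgeRetriever instances
#   benchmark.run(retrievers, llm_generator=my_llm_generator, use_react=False)
# Results land under ./result/<dataset_name>/ as a summary JSON plus a JSON-lines
# result file, both stamped with the current time and the encoder/reader model names.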


@@ -0,0 +1,158 @@
import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm
class QAJudger:
def __init__(self):
pass
    def split_answer(self, generated_text):
        if "Answer:" in generated_text:
            generated_text = generated_text.split("Answer:")[-1]
        elif "answer:" in generated_text:
            generated_text = generated_text.split("answer:")[-1]
        # Fall back to "none" when no answer text remains
        generated_text = generated_text.strip()
        if not generated_text:
            return "none"
        return generated_text
def normalize_answer(self, answer: str) -> str:
"""Direct copy of the normalization from QAExactMatch/QAF1Score"""
# Lowercase and normalize whitespace
answer = answer.lower()
# Replace hyphens with spaces
answer = answer.replace('-', ' ')
# Remove all other punctuation
answer = re.sub(r'[^\w\s]', '', answer)
# Standardize whitespace
return ' '.join(answer.split())
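    # Illustrative example (not part of the original file):
    #   normalize_answer("The Eiffel-Tower!")  ->  "the eiffel tower"
    # (lowercased, hyphen replaced by a space, punctuation stripped, whitespace collapsed)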
def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
"""Direct port of the original scoring logic"""
# Extract answer from generated text
pred_answer = self.split_answer(generated_text)
# Normalize both answers
pred_norm = self.normalize_answer(pred_answer)
ref_norm = self.normalize_answer(reference_text)
# Exact match calculation
em = 1 if pred_norm == ref_norm else 0
# F1 calculation (direct port from QAF1Score)
pred_tokens = pred_norm.split()
ref_tokens = ref_norm.split()
common = Counter(pred_tokens) & Counter(ref_tokens)
num_same = sum(common.values())
if num_same == 0:
return em, 0.0
precision = num_same / len(pred_tokens) if pred_tokens else 0.0
recall = num_same / len(ref_tokens) if ref_tokens else 0.0
if (precision + recall) == 0:
f1 = 0.0
else:
f1 = 2 * (precision * recall) / (precision + recall)
return em, f1
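    # Illustrative example (not part of the original file): for a prediction of
    # "Answer: Barack Obama" against the reference "Obama":
    #   EM = 0   (normalized strings "barack obama" != "obama")
    #   F1 = 2 * (0.5 * 1.0) / (0.5 + 1.0) ≈ 0.667   (1 shared token, precision 1/2, recall 1/1)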
def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
"""Calculates recall at k based on the top k retrieved texts."""
successful_retrievals = 0
# Limit the retrieved texts to the top k entries
limited_retrieved_text = retrieved_text[:k]
for ref_text in reference_text:
for ret_text in limited_retrieved_text:
if ref_text in ret_text:
successful_retrievals += 1
break
recall = successful_retrievals / len(reference_text) if reference_text else 0
return recall
    def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
        """Calculates recall@2 and recall@5 over the retrieved texts."""
recall_values = {
'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
}
return recall_values['recall@2'], recall_values['recall@5']
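    # Illustrative example (not part of the original file): with
    #   retrieved_text = ["passage mentioning Berlin", "passage mentioning Paris", "passage mentioning Rome"]
    #   reference_text = ["Paris", "Rome"]
    # recall@2 = 0.5 (only "Paris" appears in the top-2 passages) and
    # recall@5 = 1.0 (both references appear somewhere in the retrieved list).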
if __name__ == "__main__":
argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("--file_path", type=str, required=True, help="Path to the JSON file containing results.")
args = argument_parser.parse_args()
# Initialize the QAJudger
llm_judge = QAJudger()
# Load results from the JSON file
result_list = []
with open(args.file_path, 'r') as file:
for line in file:
if line.strip(): # Make sure the line is not empty
try:
result = json.loads(line.strip())
result_list.append(result)
except json.JSONDecodeError as e:
print(f"Error decoding JSON: {e}")
# Debugging output to inspect the loaded data structure
# print("Loaded data structure:", result_list)
# Evaluate each entry in result_list
for result in tqdm(result_list):
if isinstance(result, dict): # Ensure each result is a dictionary
question = result["question"]
answer = result["answer"]
# Evaluate generated answers with Hippo and Hippo2
hippo_generated_answer = result["hippo_generated_answer"]
hippo2_generated_answer = result["hippo2_generated_answer"]
            # Judge the generated answers (judge() extracts the short answer internally)
            hippo_em, hippo_f1 = llm_judge.judge(hippo_generated_answer, answer)
            hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_generated_answer, answer)
# Store the scores back in the result dictionary
result["hippo_em"] = hippo_em
result["hippo_f1"] = hippo_f1
result["hippo2_em"] = hippo2_em
result["hippo2_f1"] = hippo2_f1
result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])
# Calculate averages
average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)
average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)
    average_recall2_with_hippo2 = sum(result['recall@2'] for result in result_list) / len(result_list)
    average_recall5_with_hippo2 = sum(result['recall@5'] for result in result_list) / len(result_list)
    average_recall2_with_hippo = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
    average_recall5_with_hippo = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)
# Output the averages
print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")
print(f"Average Recall@2: {average_recall2:.4f}")
print(f"Average Recall@5: {average_recall5:.4f}")
print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")