AIEC-new/AIEC-RAG/atlas_rag/evaluation/evaluation.py

import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm


class QAJudger:
    def __init__(self):
        pass

    def split_answer(self, generated_text):
        """Extract the short answer following the last 'Answer:' / 'answer:' marker."""
        if "Answer:" in generated_text:
            generated_text = generated_text.split("Answer:")[-1]
        elif "answer:" in generated_text:
            generated_text = generated_text.split("answer:")[-1]
        # If nothing remains after the marker, fall back to the literal string "none"
        if not generated_text:
            return "none"
        return generated_text

    def normalize_answer(self, answer: str) -> str:
        """Direct copy of the normalization from QAExactMatch/QAF1Score."""
        # Lowercase
        answer = answer.lower()
        # Replace hyphens with spaces
        answer = answer.replace('-', ' ')
        # Remove all other punctuation
        answer = re.sub(r'[^\w\s]', '', answer)
        # Standardize whitespace
        return ' '.join(answer.split())

    def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
        """Direct port of the original scoring logic."""
        # Extract answer from generated text
        pred_answer = self.split_answer(generated_text)
        # Normalize both answers
        pred_norm = self.normalize_answer(pred_answer)
        ref_norm = self.normalize_answer(reference_text)
        # Exact match calculation
        em = 1 if pred_norm == ref_norm else 0
        # F1 calculation (direct port from QAF1Score)
        pred_tokens = pred_norm.split()
        ref_tokens = ref_norm.split()
        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return em, 0.0
        precision = num_same / len(pred_tokens) if pred_tokens else 0.0
        recall = num_same / len(ref_tokens) if ref_tokens else 0.0
        if (precision + recall) == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)
        return em, f1
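
    # Worked example of the scoring above (illustrative strings; the values are
    # computed by hand from the logic in judge(), not taken from repository data):
    #   judge("Answer: the Eiffel Tower", "eiffel tower")
    #     pred_norm = "the eiffel tower" (3 tokens), ref_norm = "eiffel tower" (2 tokens)
    #     em = 0; overlap = 2 tokens -> precision = 2/3, recall = 2/2, f1 = 0.8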

    def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
        """Calculates recall@k based on the top k retrieved texts."""
        successful_retrievals = 0
        # Limit the retrieved texts to the top k entries
        limited_retrieved_text = retrieved_text[:k]
        for ref_text in reference_text:
            for ret_text in limited_retrieved_text:
                if ref_text in ret_text:
                    successful_retrievals += 1
                    break
        recall = successful_retrievals / len(reference_text) if reference_text else 0.0
        return recall

    # Recall for a single example, reported at k=2 and k=5
    def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
        """Calculates recall at different k levels and returns (recall@2, recall@5)."""
        recall_values = {
            'recall@2': self.recall_at_k(retrieved_text, reference_text, 2),
            'recall@5': self.recall_at_k(retrieved_text, reference_text, 5),
        }
        return recall_values['recall@2'], recall_values['recall@5']
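

# Minimal usage sketch (illustrative only; the strings and expected values below are
# made-up examples worked through the methods above, not data from this repository):
#
#     judge = QAJudger()
#     em, f1 = judge.judge("Answer: Paris", "paris")   # -> (1, 1.0)
#     r2, r5 = judge.recall(["... Paris is the capital of France ...", "unrelated text"],
#                           ["Paris"])
#     # A reference counts as retrieved when it appears as a substring of any of the
#     # top-k retrieved texts, so r2 and r5 are both 1.0 here.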


if __name__ == "__main__":
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--file_path", type=str, required=True,
                                 help="Path to the JSONL file containing results (one JSON object per line).")
    args = argument_parser.parse_args()

    # Initialize the QAJudger
    llm_judge = QAJudger()

    # Load results from the JSONL file
    result_list = []
    with open(args.file_path, 'r') as file:
        for line in file:
            if line.strip():  # Make sure the line is not empty
                try:
                    result = json.loads(line.strip())
                    result_list.append(result)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

    # Debugging output to inspect the loaded data structure
    # print("Loaded data structure:", result_list)
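
    # Each record is expected to carry at least the keys accessed in the loop below
    # (field names inferred from this script, not from a documented schema):
    # "question", "answer", "hippo_generated_answer", "hippo2_generated_answer",
    # "hippo_id", "hippo2_id" (lists of retrieved texts), and "gold_file_ids"
    # (list of gold reference strings).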
    # Evaluate each entry in result_list
    for result in tqdm(result_list):
        if isinstance(result, dict):  # Ensure each result is a dictionary
            question = result["question"]
            answer = result["answer"]

            # Generated answers from Hippo and Hippo2
            hippo_generated_answer = result["hippo_generated_answer"]
            hippo2_generated_answer = result["hippo2_generated_answer"]

            # Judge the answers (judge() extracts the short answer internally via split_answer())
            hippo_em, hippo_f1 = llm_judge.judge(hippo_generated_answer, answer)
            hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_generated_answer, answer)

            # Store the scores back in the result dictionary
            result["hippo_em"] = hippo_em
            result["hippo_f1"] = hippo_f1
            result["hippo2_em"] = hippo2_em
            result["hippo2_f1"] = hippo2_f1

            # Retrieval recall against the gold file ids
            result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids'])
            result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids'])

    # Calculate averages over all evaluated entries
    # ('recall@2'/'recall@5' were computed from hippo2_id, the *_hippo keys from hippo_id)
    average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list)
    average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list)
    average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list)
    average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list)
    average_recall2_with_hippo2 = sum(result['recall@2'] for result in result_list) / len(result_list)
    average_recall5_with_hippo2 = sum(result['recall@5'] for result in result_list) / len(result_list)
    average_recall2_with_hippo = sum(result['recall@2_hippo'] for result in result_list) / len(result_list)
    average_recall5_with_hippo = sum(result['recall@5_hippo'] for result in result_list) / len(result_list)

    # Output the averages
    print(f"Average EM with Hippo: {average_em_with_hippo:.4f}")
    print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}")
    print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}")
    print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}")
    print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}")
    print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")
    print(f"Average Recall@2 with Hippo2: {average_recall2_with_hippo2:.4f}")
    print(f"Average Recall@5 with Hippo2: {average_recall5_with_hippo2:.4f}")