import re
from collections import Counter
from typing import Tuple
import argparse
import json
from tqdm import tqdm


class QAJudger:
    def __init__(self):
        pass

    def split_answer(self, generated_text):
        """Extract the text after the last 'Answer:' / 'answer:' marker."""
        if "Answer:" in generated_text:
            generated_text = generated_text.split("Answer:")[-1]
        elif "answer:" in generated_text:
            generated_text = generated_text.split("answer:")[-1]
        # Fall back to a sentinel if nothing follows the marker
        if not generated_text:
            return "none"
        return generated_text

    def normalize_answer(self, answer: str) -> str:
        """Direct copy of the normalization from QAExactMatch/QAF1Score."""
        # Lowercase
        answer = answer.lower()
        # Replace hyphens with spaces
        answer = answer.replace('-', ' ')
        # Remove all other punctuation
        answer = re.sub(r'[^\w\s]', '', answer)
        # Standardize whitespace
        return ' '.join(answer.split())

    def judge(self, generated_text: str, reference_text: str) -> Tuple[int, float]:
        """Direct port of the original scoring logic."""
        # Extract answer from generated text
        pred_answer = self.split_answer(generated_text)

        # Normalize both answers
        pred_norm = self.normalize_answer(pred_answer)
        ref_norm = self.normalize_answer(reference_text)

        # Exact match calculation
        em = 1 if pred_norm == ref_norm else 0

        # F1 calculation (direct port from QAF1Score)
        pred_tokens = pred_norm.split()
        ref_tokens = ref_norm.split()
        common = Counter(pred_tokens) & Counter(ref_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return em, 0.0

        precision = num_same / len(pred_tokens) if pred_tokens else 0.0
        recall = num_same / len(ref_tokens) if ref_tokens else 0.0
        if (precision + recall) == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)
        return em, f1

    def recall_at_k(self, retrieved_text: list, reference_text: list, k: int) -> float:
        """Calculates recall at k based on the top k retrieved texts."""
        successful_retrievals = 0
        # Limit the retrieved texts to the top k entries
        limited_retrieved_text = retrieved_text[:k]
        for ref_text in reference_text:
            for ret_text in limited_retrieved_text:
                if ref_text in ret_text:
                    successful_retrievals += 1
                    break
        recall = successful_retrievals / len(reference_text) if reference_text else 0.0
        return recall

    def recall(self, retrieved_text: list, reference_text: list) -> Tuple[float, float]:
        """Calculates recall@2 and recall@5 for a single example."""
        return (
            self.recall_at_k(retrieved_text, reference_text, 2),
            self.recall_at_k(retrieved_text, reference_text, 5),
        )


if __name__ == "__main__":
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument("--file_path", type=str, required=True,
                                 help="Path to the JSONL file containing results (one JSON object per line).")
    args = argument_parser.parse_args()

    # Initialize the QAJudger
    llm_judge = QAJudger()

    # Load results from the JSONL file
    result_list = []
    with open(args.file_path, 'r') as file:
        for line in file:
            if line.strip():  # Make sure the line is not empty
                try:
                    result = json.loads(line.strip())
                    result_list.append(result)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON: {e}")

    # Debugging output to inspect the loaded data structure
    # print("Loaded data structure:", result_list)

    # Evaluate each entry in result_list
    for result in tqdm(result_list):
        if isinstance(result, dict):  # Ensure each result is a dictionary
            question = result["question"]
            answer = result["answer"]

            # Generated answers from Hippo and Hippo2
            hippo_generated_answer = result["hippo_generated_answer"]
            hippo2_generated_answer = result["hippo2_generated_answer"]
result["hippo2_generated_answer"] # Split and judge the answers hippo_short_answer = llm_judge.split_answer(hippo_generated_answer) hippo_em, hippo_f1 = llm_judge.judge(hippo_short_answer, answer) hippo2_short_answer = llm_judge.split_answer(hippo2_generated_answer) hippo2_em, hippo2_f1 = llm_judge.judge(hippo2_short_answer, answer) # Store the scores back in the result dictionary result["hippo_em"] = hippo_em result["hippo_f1"] = hippo_f1 result["hippo2_em"] = hippo2_em result["hippo2_f1"] = hippo2_f1 result['recall@2'], result['recall@5'] = llm_judge.recall(result['hippo2_id'], result['gold_file_ids']) result['recall@2_hippo'], result['recall@5_hippo'] = llm_judge.recall(result['hippo_id'], result['gold_file_ids']) # Calculate averages average_em_with_hippo = sum(result["hippo_em"] for result in result_list) / len(result_list) average_em_with_hippo2 = sum(result["hippo2_em"] for result in result_list) / len(result_list) average_f1_with_hippo = sum(result["hippo_f1"] for result in result_list) / len(result_list) average_f1_with_hippo2 = sum(result["hippo2_f1"] for result in result_list) / len(result_list) average_recall2_with_hippo = sum(result['recall@2'] for result in result_list) / len(result_list) average_recall5_with_hippo = sum(result['recall@5'] for result in result_list) / len(result_list) average_recall2 = sum(result['recall@2_hippo'] for result in result_list) / len(result_list) average_recall5 = sum(result['recall@5_hippo'] for result in result_list) / len(result_list) # Output the averages print(f"Average EM with Hippo: {average_em_with_hippo:.4f}") print(f"Average EM with Hippo2: {average_em_with_hippo2:.4f}") print(f"Average F1 with Hippo: {average_f1_with_hippo:.4f}") print(f"Average F1 with Hippo2: {average_f1_with_hippo2:.4f}") print(f"Average Recall@2: {average_recall2:.4f}") print(f"Average Recall@5: {average_recall5:.4f}") print(f"Average Recall@2 with Hippo: {average_recall2_with_hippo:.4f}") print(f"Average Recall@5 with Hippo: {average_recall5_with_hippo:.4f}")