# easy-translate/eval.py
import argparse
import json
from typing import Optional

import numpy as np
import torch
from datasets import load_metric
from torch.utils.data import DataLoader
from tqdm import tqdm

# NOTE: in newer versions of accelerate, find_executable_batch_size lives in
# accelerate.utils instead of accelerate.memory_utils.
from accelerate.memory_utils import find_executable_batch_size

from dataset import ParallelTextReader

def get_dataloader(pred_path: str, gold_path: str, batch_size: int) -> DataLoader:
    """
    Returns a DataLoader that yields batches of (predictions, references)
    read from the given files.
    """

    def collate_fn(batch):
        # Transpose a batch of (prediction, references) pairs into a
        # [predictions, references] pair of lists.
        return list(map(list, zip(*batch)))

    reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
    dataloader = DataLoader(reader, batch_size=batch_size, collate_fn=collate_fn)
    return dataloader
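
# A minimal sketch of what one batch looks like after collate_fn, assuming
# ParallelTextReader (defined in the local `dataset` module) yields
# (prediction, [reference, ...]) pairs, which is what the metric calls in
# eval_files expect:
#
#   batch = [("hola", ["hello"]), ("adios", ["goodbye"])]
#   predictions, references = collate_fn(batch)
#   predictions == ["hola", "adios"]
#   references == [["hello"], ["goodbye"]]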

def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: Optional[str] = None,
):
    """
    Evaluates the given files and returns a dictionary with the results of
    every metric.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            "We will use the CPU to calculate BertScore; this can be slow for "
            "large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)

    print("Loading sacrebleu...")
    sacrebleu = load_metric("sacrebleu")
    print("Loading rouge...")
    rouge = load_metric("rouge")
    print("Loading bleu...")
    bleu = load_metric("bleu")
    print("Loading meteor...")
    meteor = load_metric("meteor")
    print("Loading ter...")
    ter = load_metric("ter")
    print("Loading BertScore...")
    bert_score = load_metric("bertscore")
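
    # Each load_metric call fetches and caches the metric's loading script the
    # first time it runs, so the first run may need internet access.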

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            sacrebleu.add_batch(predictions=predictions, references=references)
            rouge.add_batch(predictions=predictions, references=references)
            # The "bleu" metric expects pre-tokenized inputs, so we split on
            # whitespace and keep only the first reference of each example.
            bleu.add_batch(
                predictions=[p.split() for p in predictions],
                references=[[r[0].split()] for r in references],
            )
            meteor.add_batch(predictions=predictions, references=references)
            ter.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))
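
    # Note: add_batch only stores the examples inside each metric; the scores
    # themselves are computed below, one metric at a time, by .compute().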

    result_dictionary = {}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = rouge.compute()
    print("Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print("Computing ter score")
    result_dictionary["ter"] = ter.compute()

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        nonlocal bert_score, bert_score_model
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )
        # BertScore returns one score per sentence; average them to report
        # corpus-level precision, recall and F1.
        results["precision"] = np.average(results["precision"])
        results["recall"] = np.average(results["recall"])
        results["f1"] = np.average(results["f1"])
        return results

    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        with open(output_path, "w") as f:
            json.dump(result_dictionary, f, indent=4)

    print(f"Results: {json.dumps(result_dictionary, indent=4)}")

    return result_dictionary
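
# A minimal usage sketch (hypothetical file names; assumes one prediction per
# line of pred_path and the matching reference on the same line of gold_path):
#
#   results = eval_files(
#       pred_path="predictions.txt",
#       gold_path="references.txt",
#       bert_score_model="microsoft/deberta-xlarge-mnli",
#       output_path="results.json",
#   )
#   print(results["sacrebleu"]["score"])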

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )
    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )
    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore; we will automatically reduce it "
        "if we find an OOM error.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results "
        "will only be printed to the console.",
    )
    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        help="Model to use for BertScore. See "
        "https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )
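
# Example invocation (hypothetical file names):
#
#   python eval.py \
#       --pred_path predictions.en.txt \
#       --gold_path references.en.txt \
#       --starting_batch_size 64 \
#       --output_path results.json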