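"""
Compare a file of predicted translations against a file of gold references and
report SacreBLEU, ROUGE, BLEU, METEOR, TER and BERTScore, computed with the
Hugging Face `datasets` metrics.
"""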
from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate.memory_utils import find_executable_batch_size
from datasets import load_metric
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np


def get_dataloader(pred_path: str, gold_path: str, batch_size: int) -> DataLoader:
    """
    Returns a DataLoader that yields batches of (predictions, references)
    read in parallel from the given prediction and gold files.
    """

    def collate_fn(batch):
        # Transpose a batch of (prediction, reference) samples into
        # ([predictions], [references]).
        return list(map(list, zip(*batch)))

    reader = ParallelTextReader(pred_path=pred_path, gold_path=gold_path)
    dataloader = DataLoader(reader, batch_size=batch_size, collate_fn=collate_fn)
    return dataloader


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """
    Evaluates the predictions in `pred_path` against the references in
    `gold_path` and returns a dictionary with the computed metrics.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            "We will use the CPU to calculate BertScore; this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)
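    # The metrics below are loaded through datasets.load_metric. Note that newer
    # versions of the `datasets` library deprecate load_metric in favour of the
    # separate `evaluate` package, so this script assumes a datasets version
    # that still ships the metric scripts.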
    print("Loading sacrebleu...")
    sacrebleu = load_metric("sacrebleu")
    print("Loading rouge...")
    rouge = load_metric("rouge")
    print("Loading bleu...")
    bleu = load_metric("bleu")
    print("Loading meteor...")
    meteor = load_metric("meteor")
    print("Loading ter...")
    ter = load_metric("ter")
    print("Loading BertScore...")
    bert_score = load_metric("bertscore")

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            sacrebleu.add_batch(predictions=predictions, references=references)
            rouge.add_batch(predictions=predictions, references=references)
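            # The "bleu" metric expects pre-tokenized input (lists of tokens for
            # the predictions and lists of tokenized references per prediction),
            # so we split on whitespace here; the other metrics take raw strings.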
            bleu.add_batch(
                predictions=[p.split() for p in predictions],
                references=[[r[0].split()] for r in references],
            )
            meteor.add_batch(predictions=predictions, references=references)
            ter.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))

    result_dictionary = {}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = rouge.compute()
    print("Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print("Computing ter score")
    result_dictionary["ter"] = ter.compute()
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        nonlocal bert_score, bert_score_model
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )
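
        # BERTScore returns one precision/recall/F1 value per sentence pair;
        # average them to report corpus-level scores.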
        results["precision"] = np.average(results["precision"])
        results["recall"] = np.average(results["recall"])
        results["f1"] = np.average(results["f1"])

        return results

    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        with open(output_path, "w") as f:
            json.dump(result_dictionary, f, indent=4)

    print(f"Results: {json.dumps(result_dictionary, indent=4)}")

    return result_dictionary
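

# Example invocation (the file name below is illustrative; adapt it to wherever
# this script lives in your project):
#   python evaluate_translations.py --pred_path predictions.txt --gold_path gold.txt \
#       --output_path results.json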
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )

    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore; we will automatically reduce it if we find an OOM error.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results will only be printed to the console.",
    )

    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )