Spaces:
Running
Running
from dataset import ParallelTextReader | |
from torch.utils.data import DataLoader | |
from accelerate import find_executable_batch_size | |
from evaluate import load | |
from tqdm import tqdm | |
import torch | |
import json | |
import argparse | |
import numpy as np | |
import os | |
def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """Build a DataLoader over the aligned prediction / gold sentence files.

    Args:
        pred_path: Path to the file with the predicted sentences.
        gold_path: Path to the file with the gold (reference) sentences.
        batch_size: Number of dataset items per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` reading ``ParallelTextReader`` items
        in the main process (``num_workers=0``). Each batch is transposed into
        a list of lists. Presumably each item is a ``(prediction, reference)``
        pair, making a batch ``[predictions, references]``; confirm against
        ``ParallelTextReader``.
    """

    def _transpose(items):
        # [(a0, b0), (a1, b1), ...] -> [[a0, a1, ...], [b0, b1, ...]]
        return [list(column) for column in zip(*items)]

    return DataLoader(
        ParallelTextReader(pred_path=pred_path, gold_path=gold_path),
        batch_size=batch_size,
        collate_fn=_transpose,
        num_workers=0,
    )
def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """Score a file of predicted sentences against a file of gold sentences.

    Computes SacreBLEU, ROUGE (rouge1, rouge2, rougeL, rougeLsum), BLEU,
    METEOR, TER and BERTScore with the Hugging Face ``evaluate`` library.
    BERTScore runs on ``cuda:0`` when CUDA is available, otherwise on the CPU.

    Args:
        pred_path: Path to the file with the predicted sentences.
        gold_path: Path to the file with the gold (reference) sentences.
        bert_score_model: Model name passed as ``model_type`` to BERTScore.
        starting_batch_size: Batch size used to read the files and the first
            BERTScore batch size tried. The BERTScore batch size is reduced
            automatically when the computation hits an out-of-memory error.
        output_path: Optional path of a JSON file the results are written to.
            Missing parent directories are created.

    Returns:
        A dict holding the prediction path under ``"path"`` and one entry per
        metric (``"sacrebleu"``, ``"rouge"``, ``"bleu"``, ``"meteor"``,
        ``"ter"``, ``"bert_score"``). The BERTScore precision, recall and f1
        are averaged over all sentence pairs.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            "We will use the CPU to calculate BertScore, this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)

    print("Loading sacrebleu...")
    sacrebleu = load("sacrebleu")
    print("Loading rouge...")
    rouge = load("rouge")
    print("Loading bleu...")
    bleu = load("bleu")
    print("Loading meteor...")
    meteor = load("meteor")
    print("Loading ter...")
    ter = load("ter")
    print("Loading BertScore...")
    bert_score = load("bertscore")

    # Feed every batch to every metric; results are computed once at the end.
    metrics = (sacrebleu, rouge, bleu, meteor, ter, bert_score)
    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            for metric in metrics:
                metric.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))

    result_dictionary = {"path": pred_path}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = rouge.compute(
        use_aggregator=True, rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"]
    )
    print("Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print("Computing ter score")
    result_dictionary["ter"] = ter.compute()

    # The decorator supplies `batch_size` itself: it starts at
    # `starting_batch_size` and retries with a smaller value when the call
    # raises an out-of-memory error. Without it, `inference()` below would
    # fail with a TypeError for the missing argument.
    # NOTE(review): a retry re-runs `bert_score.compute()` after a failed
    # compute on the same module; confirm `evaluate` supports that.
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )
        # BERTScore returns per-sentence scores; report their mean.
        for key in ("precision", "recall", "f1"):
            results[key] = np.average(results[key])
        return results

    # Called with no arguments on purpose: the decorator injects `batch_size`.
    result_dictionary["bert_score"] = inference()

    if output_path is not None:
        os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result_dictionary, f, indent=4)

    print(f"Results: {json.dumps(result_dictionary,indent=4)}")
    return result_dictionary
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the translation evaluation script.

    Returns:
        An ``argparse.ArgumentParser`` with the options:
        ``--pred_path`` and ``--gold_path`` (required),
        ``--starting_batch_size`` (default 64), ``--output_path``
        (default None) and ``--bert_score_model``
        (default "microsoft/deberta-xlarge-mnli").
    """
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )
    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )
    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore, we will automatically reduce it if we find an OOM error.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. The results are always printed to the console as well.",
    )
    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        # The trailing space matters: the two literals are concatenated.
        help="Model to use for BertScore. See: https://github.com/huggingface/evaluate/tree/main/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )
    return parser


def main(argv=None):
    """Parse command-line arguments and run ``eval_files`` with them.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The results dictionary returned by ``eval_files``.
    """
    args = build_parser().parse_args(argv)
    return eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )


if __name__ == "__main__":
    main()