File size: 5,318 Bytes
62b1ca5
 
 
 
 
 
 
 
 
d3c75c1
62b1ca5
 
 
 
 
 
 
 
 
 
 
ee0f30d
 
 
62b1ca5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e4adc1
 
62b1ca5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from dataset import ParallelTextReader
from torch.utils.data import DataLoader
from accelerate.memory_utils import find_executable_batch_size
from datasets import load_metric
from tqdm import tqdm
import torch
import json
import argparse
import numpy as np
import os


def get_dataloader(pred_path: str, gold_path: str, batch_size: int):
    """
    Build a DataLoader that iterates the prediction and gold files in parallel.

    The dataset is a ParallelTextReader constructed from the two paths. Each
    batch is transposed by the collate function, so a batch of per-sample
    tuples becomes one list per field (e.g. a list of predictions and a list
    of references when samples are pairs).
    """

    def _transpose(samples):
        # zip(*samples) groups the i-th element of every sample together.
        return [list(column) for column in zip(*samples)]

    return DataLoader(
        ParallelTextReader(pred_path=pred_path, gold_path=gold_path),
        batch_size=batch_size,
        collate_fn=_transpose,
        num_workers=0,
    )


def eval_files(
    pred_path: str,
    gold_path: str,
    bert_score_model: str,
    starting_batch_size: int = 128,
    output_path: str = None,
):
    """
    Score a predictions file against a gold file with several text metrics.

    The metrics computed are sacrebleu, rouge, bleu, meteor, ter and
    BertScore. All batches are first fed to every metric, then each metric is
    computed once over the accumulated data.

    Args:
        pred_path: Path to the file containing the predicted sentences.
        gold_path: Path to the file containing the gold sentences.
        bert_score_model: Model identifier passed to BertScore as
            ``model_type``.
        starting_batch_size: Batch size used to read the data, and the first
            batch size tried for BertScore. ``find_executable_batch_size``
            is relied on to retry BertScore with a smaller batch size if the
            current one does not fit in memory.
        output_path: Optional path of a JSON file to write the results to.
            Missing parent directories are created. When ``None`` the results
            are only printed.

    Returns:
        A dictionary mapping each metric name (``"sacrebleu"``, ``"rouge"``,
        ``"bleu"``, ``"meteor"``, ``"ter"``, ``"bert_score"``) to the value
        returned by that metric's ``compute()``. For ``"bert_score"`` the
        ``precision``, ``recall`` and ``f1`` entries are averaged to a single
        number each.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
        print("We will use a GPU to calculate BertScore.")
    else:
        device = "cpu"
        print(
            "We will use the CPU to calculate BertScore, this can be slow for large datasets."
        )

    dataloader = get_dataloader(pred_path, gold_path, starting_batch_size)
    print("Loading sacrebleu...")
    sacrebleu = load_metric("sacrebleu")
    print("Loading rouge...")
    rouge = load_metric("rouge")
    print("Loading bleu...")
    bleu = load_metric("bleu")
    print("Loading meteor...")
    meteor = load_metric("meteor")
    print("Loading ter...")
    ter = load_metric("ter")
    print("Loading BertScore...")
    bert_score = load_metric("bertscore")

    with tqdm(total=len(dataloader.dataset), desc="Loading data...") as pbar:
        for predictions, references in dataloader:
            sacrebleu.add_batch(predictions=predictions, references=references)
            rouge.add_batch(predictions=predictions, references=references)
            # bleu is fed whitespace-split tokens, and only the first
            # reference of each example (r[0]) is used.
            bleu.add_batch(
                predictions=[p.split() for p in predictions],
                references=[[r[0].split()] for r in references],
            )
            meteor.add_batch(predictions=predictions, references=references)
            ter.add_batch(predictions=predictions, references=references)
            bert_score.add_batch(predictions=predictions, references=references)
            pbar.update(len(predictions))

    result_dictionary = {}
    print("Computing sacrebleu")
    result_dictionary["sacrebleu"] = sacrebleu.compute()
    print("Computing rouge score")
    result_dictionary["rouge"] = rouge.compute()
    print("Computing bleu score")
    result_dictionary["bleu"] = bleu.compute()
    print("Computing meteor score")
    result_dictionary["meteor"] = meteor.compute()
    print("Computing ter score")
    result_dictionary["ter"] = ter.compute()

    # The decorator supplies `batch_size`; the closure only reads the
    # enclosing variables, so no `nonlocal` declaration is needed.
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def inference(batch_size):
        print(f"Computing bert score with batch size {batch_size} on {device}")
        results = bert_score.compute(
            model_type=bert_score_model,
            batch_size=batch_size,
            device=device,
            use_fast_tokenizer=True,
        )

        # Collapse the per-example scores into one corpus-level average each.
        for key in ("precision", "recall", "f1"):
            results[key] = np.average(results[key])

        return results

    result_dictionary["bert_score"] = inference()

    # Serialise once, before touching the filesystem, so a serialisation
    # error cannot leave a truncated results file behind.
    serialized = json.dumps(result_dictionary, indent=4)

    if output_path is not None:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(os.path.abspath(os.path.dirname(output_path)), exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(serialized)

    print(f"Results: {serialized}")

    return result_dictionary


if __name__ == "__main__":
    # Command-line entry point: parse the input/output paths and the
    # BertScore options, then run the evaluation.
    parser = argparse.ArgumentParser(
        description="Run the translation evaluation experiments"
    )
    parser.add_argument(
        "--pred_path",
        type=str,
        required=True,
        help="Path to a txt file containing the predicted sentences.",
    )

    parser.add_argument(
        "--gold_path",
        type=str,
        required=True,
        help="Path to a txt file containing the gold sentences.",
    )

    parser.add_argument(
        "--starting_batch_size",
        type=int,
        default=64,
        help="Starting batch size for BertScore, we will automatically reduce it if we find an OOM error.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to a json file to save the results. If not given, the results will be printed to the console.",
    )

    # The two help literals are concatenated; the trailing space keeps the
    # first URL from running into the word "and".
    parser.add_argument(
        "--bert_score_model",
        type=str,
        default="microsoft/deberta-xlarge-mnli",
        help="Model to use for BertScore. See: https://github.com/huggingface/datasets/tree/master/metrics/bertscore "
        "and https://github.com/Tiiiger/bert_score for more details.",
    )

    args = parser.parse_args()

    eval_files(
        pred_path=args.pred_path,
        gold_path=args.gold_path,
        starting_batch_size=args.starting_batch_size,
        output_path=args.output_path,
        bert_score_model=args.bert_score_model,
    )