fairseq/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
import itertools
import logging
import os

import numpy as np
import torch

from fairseq import metrics
from fairseq.data import (
    ConcatDataset,
    ConcatSentencesDataset,
    data_utils,
    Dictionary,
    IdDataset,
    indexed_dataset,
    NestedDictionaryDataset,
    NumSamplesDataset,
    NumelDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    TruncateDataset,
    TokenBlockDataset,
)
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING


EVAL_BLEU_ORDER = 4
TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])

logger = logging.getLogger(__name__)


@dataclass
class DiscriminativeRerankingNMTConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    num_data_splits: int = field(
        default=1, metadata={"help": "total number of data splits"}
    )
    no_shuffle: bool = field(
        default=False, metadata={"help": "do not shuffle training data"}
    )
    max_positions: int = field(
        default=512, metadata={"help": "number of positional embeddings to learn"}
    )
    include_src: bool = field(
        default=False, metadata={"help": "include source sentence"}
    )
    mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
    eval_target_metric: bool = field(
        default=False,
        metadata={"help": "evaluation with the target metric during validation"},
    )
    target_metric: TARGET_METRIC_CHOICES = field(
        default="bleu", metadata={"help": "name of the target metric to optimize for"}
    )
    train_subset: str = field(
        default=II("dataset.train_subset"),
        metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
    )
    seed: int = field(
        default=II("common.seed"),
        metadata={"help": "pseudo random number generator seed"},
    )
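

# A minimal illustrative sketch of filling in this config by hand. The values
# below (data path, beam size) are made-up placeholders; in an actual fairseq
# run the config is composed via Hydra, which also resolves the II("...")
# interpolations above.
def _example_config():
    cfg = DiscriminativeRerankingNMTConfig(
        data="/path/to/reranking/data",  # hypothetical path
        mt_beam=5,  # must match the beam size used to generate the hypotheses
        include_src=True,
        target_metric="bleu",
    )
    return cfg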


class RerankerScorer(object):
    """Scores the target for a given (source (optional), target) input."""

    def __init__(self, args, mt_beam):
        self.mt_beam = mt_beam

    def generate(self, models, sample, **kwargs):
        """Score a batch of translations."""
        net_input = sample["net_input"]
        assert len(models) == 1, "does not support model ensemble"
        model = models[0]

        bs = net_input["src_tokens"].shape[0]
        assert (
            model.joint_classification == "none" or bs % self.mt_beam == 0
        ), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"

        model.eval()
        logits = model(**net_input)

        batch_out = model.sentence_forward(logits, net_input["src_tokens"])
        if model.joint_classification == "sent":
            batch_out = model.joint_forward(
                batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
            )
        scores = model.classification_forward(
            batch_out.view(bs, 1, -1)
        )  # input: B x T x C

        return scores
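

# A minimal illustrative sketch of how RerankerScorer.generate() is driven.
# The stub model below is hypothetical and only mimics the interface the
# scorer expects (a forward pass plus sentence_forward/classification_forward
# and a joint_classification attribute); real models come from the
# discriminative reranking model code in this example directory.
def _example_reranker_scorer():
    import torch.nn as nn

    class _StubRerankerModel(nn.Module):
        joint_classification = "none"  # hypothetical: no joint scoring over the beam

        def forward(self, src_tokens, src_lengths=None):
            # fake encoder output of shape B x T x C
            return torch.randn(src_tokens.size(0), src_tokens.size(1), 8)

        def sentence_forward(self, logits, src_tokens):
            # pool over the time dimension: B x C
            return logits.mean(dim=1)

        def classification_forward(self, x):
            # B x 1 x C -> one score per hypothesis (B x 1)
            return x.sum(dim=-1)

    scorer = RerankerScorer(args=None, mt_beam=5)
    sample = {
        "net_input": {
            "src_tokens": torch.ones(10, 7, dtype=torch.long),
            "src_lengths": torch.full((10,), 7, dtype=torch.long),
        }
    }
    scores = scorer.generate([_StubRerankerModel()], sample)
    return scores.shape  # torch.Size([10, 1]): one score per input hypothesis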


@register_task(
    "discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
)
class DiscriminativeRerankingNMTTask(FairseqTask):
    """
    Translation rerank task.
    The input can be either (src, tgt) sentence pairs or tgt sentence only.
    """

    cfg: DiscriminativeRerankingNMTConfig

    def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
        super().__init__(cfg)
        self.dictionary = data_dictionary
        self._max_positions = cfg.max_positions
        # args.tokens_per_sample = self._max_positions
        # self.num_classes = 1  # for model

    @classmethod
    def load_dictionary(cls, cfg, filename):
        """Load the dictionary from the filename"""
        dictionary = Dictionary.load(filename)
        dictionary.add_symbol("<mask>")  # for loading pretrained XLMR model
        return dictionary

    @classmethod
    def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
        # load data dictionary (assume joint dictionary)
        data_path = cfg.data
        data_dict = cls.load_dictionary(
            cfg, os.path.join(data_path, "input_src/dict.txt")
        )

        logger.info("[input] src dictionary: {} types".format(len(data_dict)))

        return DiscriminativeRerankingNMTTask(cfg, data_dict)

    def load_dataset(self, split, epoch=0, combine=False, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""
        if self.cfg.data.endswith("1"):
            data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
            data_path = self.cfg.data[:-1] + str(data_shard)
        else:
            data_path = self.cfg.data

        def get_path(type, data_split):
            return os.path.join(data_path, str(type), data_split)

        def make_dataset(type, dictionary, data_split, combine):
            split_path = get_path(type, data_split)

            dataset = data_utils.load_indexed_dataset(
                split_path,
                dictionary,
                combine=combine,
            )
            return dataset

        def load_split(data_split, metric):
            input_src = None
            if self.cfg.include_src:
                input_src = make_dataset(
                    "input_src", self.dictionary, data_split, combine=False
                )
                assert input_src is not None, "could not find dataset: {}".format(
                    get_path("input_src", data_split)
                )

            input_tgt = make_dataset(
                "input_tgt", self.dictionary, data_split, combine=False
            )
            assert input_tgt is not None, "could not find dataset: {}".format(
                get_path("input_tgt", data_split)
            )

            label_path = f"{get_path(metric, data_split)}.{metric}"
            assert os.path.exists(label_path), f"could not find dataset: {label_path}"

            np_labels = np.loadtxt(label_path)
            if self.cfg.target_metric == "ter":
                np_labels = -np_labels
            label = RawLabelDataset(np_labels)

            return input_src, input_tgt, label

        src_datasets = []
        tgt_datasets = []
        label_datasets = []

        if split == self.cfg.train_subset:
            for k in itertools.count():
                split_k = "train" + (str(k) if k > 0 else "")
                prefix = os.path.join(data_path, "input_tgt", split_k)
                if not indexed_dataset.dataset_exists(prefix, impl=None):
                    if k > 0:
                        break
                    else:
                        raise FileNotFoundError(f"Dataset not found: {prefix}")
                input_src, input_tgt, label = load_split(
                    split_k, self.cfg.target_metric
                )
                src_datasets.append(input_src)
                tgt_datasets.append(input_tgt)
                label_datasets.append(label)
        else:
            input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
            src_datasets.append(input_src)
            tgt_datasets.append(input_tgt)
            label_datasets.append(label)

        if len(tgt_datasets) == 1:
            input_tgt, label = tgt_datasets[0], label_datasets[0]
            if self.cfg.include_src:
                input_src = src_datasets[0]
        else:
            input_tgt = ConcatDataset(tgt_datasets)
            label = ConcatDataset(label_datasets)
            if self.cfg.include_src:
                input_src = ConcatDataset(src_datasets)

        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)
            src_lengths = NumelDataset(input_src, reduce=False)
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
        else:
            src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
            "target": label,
        }

        dataset = NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

        assert len(dataset) % self.cfg.mt_beam == 0, (
            "dataset size (%d) is not a multiple of beam size (%d)"
            % (len(dataset), self.cfg.mt_beam)
        )

        # no need to shuffle valid/test sets
        if not self.cfg.no_shuffle and split == self.cfg.train_subset:

            # need to keep all hypotheses of the same sentence together
            start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
            with data_utils.numpy_seed(self.cfg.seed + epoch):
                np.random.shuffle(start_idx)

            idx = np.arange(0, self.cfg.mt_beam)
            shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
                start_idx, (self.cfg.mt_beam, 1)
            ).transpose().reshape(-1)

            dataset = SortDataset(
                dataset,
                sort_order=[shuffle],
            )

        logger.info(f"Loaded {split} with #samples: {len(dataset)}")

        self.datasets[split] = dataset
        return self.datasets[split]

    def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
        assert not self.cfg.include_src or len(src_tokens[0]) == 2
        input_src = None
        if self.cfg.include_src:
            input_src = TokenBlockDataset(
                [t[0] for t in src_tokens],
                [l[0] for l in src_lengths],
                block_size=None,  # ignored for "eos" break mode
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode="eos",
            )
            input_src = PrependTokenDataset(input_src, self.dictionary.bos())
            input_src = TruncateDataset(input_src, self.cfg.max_positions)

        input_tgt = TokenBlockDataset(
            [t[-1] for t in src_tokens],
            [l[-1] for l in src_lengths],
            block_size=None,  # ignored for "eos" break mode
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode="eos",
        )
        input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
        if self.cfg.include_src:
            src_tokens = ConcatSentencesDataset(input_src, input_tgt)
            src_lengths = NumelDataset(input_src, reduce=False)
        else:
            input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
            src_tokens = input_tgt
            src_lengths = NumelDataset(src_tokens, reduce=False)

        dataset = {
            "id": IdDataset(),
            "net_input": {
                "src_tokens": RightPadDataset(
                    src_tokens,
                    pad_idx=self.source_dictionary.pad(),
                ),
                "src_lengths": src_lengths,
            },
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(src_tokens, reduce=True),
        }

        return NestedDictionaryDataset(
            dataset,
            sizes=[src_tokens.sizes],
        )

    def build_model(self, cfg: FairseqDataclass):
        return super().build_model(cfg)

    def build_generator(self, args):
        return RerankerScorer(args, mt_beam=self.cfg.mt_beam)

    def max_positions(self):
        return self._max_positions

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    def create_dummy_batch(self, device):
        dummy_target = (
            # BLEU targets carry sys/ref lengths plus n-gram counts and totals
            torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device)
            if self.cfg.target_metric == "bleu"
            else torch.zeros(self.cfg.mt_beam, 3).long().to(device)
        )
        return {
            "id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
            "net_input": {
                "src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
                "src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
            },
            "nsentences": 0,
            "ntokens": 0,
            "target": dummy_target,
        }

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        if ignore_grad and sample is None:
            sample = self.create_dummy_batch(model.device)

        return super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )

    def valid_step(self, sample, model, criterion):
        if sample is None:
            sample = self.create_dummy_batch(model.device)

        loss, sample_size, logging_output = super().valid_step(
            sample, model, criterion
        )

        if not self.cfg.eval_target_metric:
            return loss, sample_size, logging_output

        scores = logging_output["scores"]

        if self.cfg.target_metric == "bleu":
            assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, (
                "target does not contain enough information ("
                + str(sample["target"].shape[1])
                + " columns) for evaluating BLEU"
            )

            max_id = torch.argmax(scores, dim=1)
            select_id = max_id + torch.arange(
                0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
            ).to(max_id.device)
            bleu_data = sample["target"][select_id, 1:].sum(0).data

            logging_output["_bleu_sys_len"] = bleu_data[0]
            logging_output["_bleu_ref_len"] = bleu_data[1]

            for i in range(EVAL_BLEU_ORDER):
                logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
                logging_output["_bleu_totals_" + str(i)] = bleu_data[
                    2 + EVAL_BLEU_ORDER + i
                ]

        elif self.cfg.target_metric == "ter":
            assert sample["target"].shape[1] == 3, (
                "target does not contain enough information ("
                + str(sample["target"].shape[1])
                + " columns) for evaluating TER"
            )

            max_id = torch.argmax(scores, dim=1)
            select_id = max_id + torch.arange(
                0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
            ).to(max_id.device)
            ter_data = sample["target"][select_id, 1:].sum(0).data

            logging_output["_ter_num_edits"] = -ter_data[0]
            logging_output["_ter_ref_len"] = -ter_data[1]

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        if not self.cfg.eval_target_metric:
            return

        def sum_logs(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        if self.cfg.target_metric == "bleu":
            counts, totals = [], []
            for i in range(EVAL_BLEU_ORDER):
                counts.append(sum_logs("_bleu_counts_" + str(i)))
                totals.append(sum_logs("_bleu_totals_" + str(i)))

            if max(totals) > 0:
                # log counts as numpy arrays -- log_scalar will sum them correctly
                metrics.log_scalar("_bleu_counts", np.array(counts))
                metrics.log_scalar("_bleu_totals", np.array(totals))
                metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
                metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))

                def compute_bleu(meters):
                    import inspect
                    import sacrebleu

                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
                    if "smooth_method" in fn_sig:
                        smooth = {"smooth_method": "exp"}
                    else:
                        smooth = {"smooth": "exp"}
                    bleu = sacrebleu.compute_bleu(
                        correct=meters["_bleu_counts"].sum,
                        total=meters["_bleu_totals"].sum,
                        sys_len=meters["_bleu_sys_len"].sum,
                        ref_len=meters["_bleu_ref_len"].sum,
                        **smooth,
                    )
                    return round(bleu.score, 2)

                metrics.log_derived("bleu", compute_bleu)
        elif self.cfg.target_metric == "ter":
            num_edits = sum_logs("_ter_num_edits")
            ref_len = sum_logs("_ter_ref_len")

            if ref_len > 0:
                metrics.log_scalar("_ter_num_edits", num_edits)
                metrics.log_scalar("_ter_ref_len", ref_len)

                def compute_ter(meters):
                    score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
                    return round(score.item(), 2)

                metrics.log_derived("ter", compute_ter)