# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion

# Register the criterion so it can be selected on the command line via
# --criterion cross_entropy_acc.
@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def compute_loss(self, model, net_output, target, reduction, log_probs):
        # flatten targets: (N, T) -> (N * T,)
        target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether the net output is batch "
                "first; you need to set the batch_first attribute on the "
                "return value of model.get_normalized_probs. For now we "
                "assume it is true, but in the future we will raise an "
                "exception instead."
            )
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            lprobs = lprobs.transpose(0, 1)

        # flatten log-probabilities: (N, T, D) -> (N * T, D)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
        )
        return lprobs, loss
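
    # Shape walkthrough (illustrative numbers, not from the original source):
    # with N=2 sentences of T=3 steps over a D=5 vocabulary, lprobs is
    # reshaped (2, 3, 5) -> (6, 5) and target (2, 3) -> (6,), so F.nll_loss
    # scores one vocabulary row per output token, and padding positions are
    # excluded via ignore_index.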

    def get_logging_output(self, sample, target, lprobs, loss):
        target = target.view(-1)
        mask = target != self.padding_idx
        correct = torch.sum(
            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
        )
        total = torch.sum(mask)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }
        return sample_size, logging_output
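
    # Example logging_output for one toy batch (illustrative values only,
    # assuming sentence_avg=False so sample_size == ntokens):
    #   {"loss": 12.3, "ntokens": 40, "nsentences": 4, "sample_size": 40,
    #    "correct": 31, "total": 40, "nframes": 1280}
    # "correct"/"total" count non-padding tokens whose argmax prediction
    # matches the reference; accuracy is derived from them during aggregation.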

    def forward(self, model, sample, reduction="sum", log_probs=True):
        """Computes the cross entropy with accuracy metric for the given sample.

        This is similar to CrossEntropyCriterion in fairseq, but it also
        computes accuracy metrics as part of logging.

        Args:
            logprobs (torch.Tensor) of shape N, T, D, i.e.
                batch size, time steps, dimensions
            targets (torch.Tensor) of shape N, T, i.e. batch size, time steps

        Returns:
            tuple: with three elements:
            1) the loss
            2) the sample size, which is used as the denominator for the gradient
            3) logging outputs to display while training

        TODO:
            * Currently this criterion only works with LSTMEncoderModels or
              FairseqModels which have a decoder, or models which return a
              torch.Tensor as net_output.
              We need to make a change to support all FairseqEncoder models.
        """
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output)
        lprobs, loss = self.compute_loss(
            model, net_output, target, reduction, log_probs
        )
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, loss
        )
        return loss, sample_size, logging_output
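
    # Illustrative usage (an assumption about the surrounding train loop, not
    # code from this module): fairseq invokes the criterion roughly as
    #   loss, sample_size, logging_output = criterion(model, sample)
    # and uses sample_size as the gradient denominator, per the docstring.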

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
        total_sum = sum(log.get("total", 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
        agg_output = {
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            # if args.sentence_avg, then sample_size is nsentences and the
            # loss is a per-sentence loss; else sample_size is ntokens and
            # the loss becomes a per-output-token loss
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            "total": total_sum,
            # total is the number of validated (non-padding) tokens
        }
        if sample_size != ntokens:
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
            # here sample_size is nsentences, so "loss" above is the
            # per-sentence loss while "nll_loss" is the per-output-token loss
        return agg_output
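

# A minimal smoke test of the aggregation math (illustrative only; the toy
# numbers below are assumptions, not values from the original source). With
# sentence_avg disabled, sample_size == ntokens, so no "nll_loss" key is set.
if __name__ == "__main__":
    toy_logs = [
        {"loss": 6.0, "ntokens": 4, "nsentences": 1, "sample_size": 4,
         "correct": 3, "total": 4, "nframes": 100},
        {"loss": 10.0, "ntokens": 6, "nsentences": 2, "sample_size": 6,
         "correct": 4, "total": 6, "nframes": 150},
    ]
    agg = CrossEntropyWithAccCriterion.aggregate_logging_outputs(toy_logs)
    # loss is 16 / 10 / ln(2) ~= 2.31 bits per token; acc is 7/10 = 70.0%
    assert abs(agg["loss"] - 16.0 / 10 / math.log(2)) < 1e-6
    assert agg["acc"] == 70.0 and "nll_loss" not in agg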