Spaces:
Runtime error
Runtime error
File size: 3,157 Bytes
fc67275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
# Megatron lives in an optional git submodule. Record whether its
# vocab-parallel cross entropy could be imported so that the criterion below
# can raise an actionable error when it is constructed without it.
try:
    from fairseq.model_parallel.megatron.mpu.cross_entropy import (
        vocab_parallel_cross_entropy,
    )
except ImportError:  # also covers ModuleNotFoundError, which subclasses it
    has_megatron_submodule = False
else:
    has_megatron_submodule = True
@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    """Cross-entropy criterion built on megatron's vocab-parallel loss.

    Registered under the name ``"vocab_parallel_cross_entropy"``. The loss is
    computed with ``vocab_parallel_cross_entropy``, positions whose target
    equals ``self.padding_idx`` are zeroed out, and the remainder is summed.
    """

    def __init__(self, task, sentence_avg):
        """Store the averaging mode and verify megatron is importable.

        Args:
            task: forwarded unchanged to ``FairseqCriterion.__init__``.
            sentence_avg: when truthy, ``forward`` reports the number of
                target rows as the sample size instead of ``ntokens``.

        Raises:
            ImportError: if the megatron submodule could not be imported.
        """
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if has_megatron_submodule:
            return
        raise ImportError(
            "\n\nPlease install the megatron submodule:"
            "\n\n git submodule update --init "
            "fairseq/model_parallel/megatron"
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for ``sample``.

        Args:
            model: callable invoked as ``model(**sample["net_input"])``; the
                first element of its result is cast to float and scored.
            sample: mapping providing ``"net_input"``, ``"target"`` and
                ``"ntokens"``.
            reduce: only affects how the loss is logged. When truthy the
                logged value goes through ``utils.item``; otherwise the raw
                ``loss.data`` is logged. The returned loss is summed either
                way.

        Returns:
            A 3-tuple of:
            1) the summed loss,
            2) the sample size used as the gradient denominator,
            3) a dict of logging outputs to display while training.
        """
        model_out = model(**sample["net_input"])
        targets = sample["target"]

        per_position = vocab_parallel_cross_entropy(model_out[0].float(), targets)
        # Zero the contribution of padding positions before summing.
        keep_mask = targets != self.padding_idx
        loss = (per_position * keep_mask).sum()

        num_sentences = targets.size(0)
        num_tokens = sample["ntokens"]
        if self.sentence_avg:
            sample_size = num_sentences
        else:
            sample_size = num_tokens

        if reduce:
            logged_loss = utils.item(loss.data)
        else:
            logged_loss = loss.data

        logging_output = {
            "loss": logged_loss,
            "ntokens": num_tokens,
            "nsentences": num_sentences,
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Sums ``"loss"``, ``"ntokens"`` and ``"sample_size"`` over
        ``logging_outputs`` (missing keys count as 0) and logs ``"loss"`` in
        base 2, normalised by the total sample size. When the total sample
        size differs from the token count, a per-token ``"nll_loss"`` is also
        logged and perplexity is derived from it; otherwise perplexity is
        derived from ``"loss"``.
        """

        def total(key):
            # One pass over logging_outputs per key.
            return sum(entry.get(key, 0) for entry in logging_outputs)

        loss_sum = total("loss")
        ntokens = total("ntokens")
        sample_size = total("sample_size")
        ln2 = math.log(2)  # dividing by ln(2) converts nats to bits

        metrics.log_scalar("loss", loss_sum / sample_size / ln2, sample_size, round=3)

        ppl_source = "loss"
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / ln2, ntokens, round=3)
            ppl_source = "nll_loss"

        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters[ppl_source].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Report whether the logging outputs returned by `forward` can be
        summed across workers before `reduce_metrics` is called. Returning
        True improves distributed training speed.
        """
        return True
|