# fairseq/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from fairseq.utils import safe_hasattr

from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss


@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
    """A task for multilingual translation with latent depth.

    See `"Deep Transformer with Latent Depth"
    (Li et al., 2020) <https://arxiv.org/pdf/2009.13102.pdf>`_.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder')
        parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder')
        parser.add_argument('--target-layers', default=-1, type=int,
                            help='number of effective layers to learn; -1 means no constraint')
        parser.add_argument('--sparsity-weight', default=0.0, type=float,
                            help='weight for sparsity loss')
        parser.add_argument('--share-weight', default=0.0, type=float,
                            help='weight for sharing loss')
        parser.add_argument('--soft-update', default=1, type=int,
                            help='number of updates with soft sampling')
        parser.add_argument('--anneal-updates', default=1, type=int,
                            help='number of updates to anneal the KL loss weight')
        parser.add_argument('--prior', default="uniform", type=str,
                            help='prior used for computing KL loss')
        # fmt: on
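
    # Usage sketch (illustrative, not part of this file): these flags are
    # consumed by fairseq-train together with the options inherited from
    # MultilingualTranslationTask. The data path, language pairs, and flag
    # values below are placeholders, not recommendations:
    #
    #   fairseq-train data-bin \
    #       --user-dir examples/latent_depth/latent_depth_src \
    #       --task multilingual_translation_latent_depth \
    #       --arch multilingual_transformer_iwslt_de_en \
    #       --lang-pairs "es-en,fr-en" \
    #       --share-decoders --decoder-latent-layer \
    #       --target-layers 12 --sparsity-weight 0.1 \
    #       --soft-update 2500 --anneal-updates 5000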

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
        self.src_langs, self.tgt_langs = zip(
            *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
        )
        if self.training and self.encoder_latent_layer:
            assert self.args.share_encoders
        if self.training and self.decoder_latent_layer:
            assert self.args.share_decoders
        if training or self.encoder_latent_layer or self.decoder_latent_layer:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        self.eval_lang_pairs = self.lang_pairs
        self.model_lang_pairs = self.lang_pairs
        if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
            self.kl_loss = LatentLayersKLLoss(self.args)
            self.sparsity_loss = LatentLayersSparsityLoss(self.args)
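
    # Worked example (illustrative values): with args.lang_pairs set to
    # ["es-en", "fr-en"], the zip above gives src_langs == ("es", "fr") and
    # tgt_langs == ("en", "en"), so src_lang_idx_dict == {"es": 0, "fr": 1}
    # while tgt_lang_idx_dict == {"en": 1}, because a language appearing in
    # several pairs keeps only its last enumeration index.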

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
            # switch from soft sampling to hard layer selection once the
            # warm-up period given by --soft-update has passed
            model.models[lang_pair].encoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
            model.models[lang_pair].decoder.layer_select.hard_select = (
                update_num > self.args.soft_update
            )
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        if self.encoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].encoder.layer_select.layer_samples
            )
            # the aggregated-posterior prior needs samples from every
            # language, so skip the KL term until all of them are populated
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].encoder.layer_select.layer_samples,
                    src_lang_idx,
                    update_num,
                    sample_size,
                )
        if self.decoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].decoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].decoder.layer_select.layer_samples,
                    tgt_lang_idx,
                    update_num,
                    sample_size,
                )
        if ignore_grad:
            loss *= 0

        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            # need to retain the graph if sparsity loss needs to be added
            loss.backward(retain_graph=True)
        else:
            optimizer.backward(loss)
        return loss, sample_size, logging_output
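
    # Schedule sketch (hypothetical numbers): with --soft-update 2500,
    # updates 0..2500 train with relaxed (soft) layer-selection weights, and
    # from update 2501 onward hard_select makes each layer's use a discrete
    # choice. The layer_select module itself lives outside this file; the
    # method above only assumes it exposes hard_select and layer_samples.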

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        agg_loss, agg_sample_size, agg_logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        # compute auxiliary loss from layer sparsity, based on all samples from all languages
        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            sparsity_loss = 0
            if self.encoder_latent_layer:
                # encoders are shared (asserted in __init__), so any per-pair
                # model exposes the same layer_select
                sparsity_loss += self.sparsity_loss(
                    next(
                        iter(model.models.values())
                    ).encoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if self.decoder_latent_layer:
                sparsity_loss += self.sparsity_loss(
                    next(
                        iter(model.models.values())
                    ).decoder.layer_select.layer_samples,
                    update_num,
                    agg_sample_size,
                )
            if sparsity_loss > 0:
                # this second backward pass relies on _per_lang_pair_train_loss
                # having retained the graph while the sparsity loss is active
                optimizer.backward(sparsity_loss)
        return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
        src, tgt = lang_pair.split("-")
        if self.encoder_latent_layer:
            src_lang_idx = self.src_lang_idx_dict[src]
            model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
        if self.decoder_latent_layer:
            tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
            model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
        loss, sample_size, logging_output = criterion(
            model.models[lang_pair], sample[lang_pair]
        )
        return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        if self.encoder_latent_layer or self.decoder_latent_layer:
            for model in models:
                if self.encoder_latent_layer:
                    assert model.encoder.layer_select is not None
                    src_lang_idx = self.src_lang_idx_dict[self.args.source_lang]
                    model.encoder.set_lang_idx(src_lang_idx)
                if self.decoder_latent_layer:
                    assert model.decoder.layer_select is not None
                    tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
                    model.decoder.set_lang_idx(tgt_lang_idx)
        return super().inference_step(
            generator, models, sample, prefix_tokens, constraints
        )

    @property
    def encoder_latent_layer(self):
        return (
            safe_hasattr(self.args, "encoder_latent_layer")
            and self.args.encoder_latent_layer
        )

    @property
    def decoder_latent_layer(self):
        return (
            safe_hasattr(self.args, "decoder_latent_layer")
            and self.args.decoder_latent_layer
        )

    @property
    def src_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.src_langs)}

    @property
    def tgt_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)}
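

# Programmatic usage sketch (illustrative; the data path, language pairs, and
# architecture name are placeholders, and argument-parsing details vary
# across fairseq versions). fairseq normally builds this task from the
# command line, but an in-code equivalent would look roughly like:
#
#   from fairseq import options, tasks
#
#   parser = options.get_training_parser()
#   args = options.parse_args_and_arch(parser, [
#       "data-bin",
#       "--task", "multilingual_translation_latent_depth",
#       "--arch", "multilingual_transformer_iwslt_de_en",
#       "--lang-pairs", "es-en,fr-en",
#       "--share-decoders", "--decoder-latent-layer",
#   ])
#   task = tasks.setup_task(args)
#
# Note: this module must be importable (e.g. via --user-dir) before parsing,
# since the @register_task decorator is what makes the task name resolvable.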