# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import NATransformerModel


def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
    # s: input batch
    # V: vocabulary size
    rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device)
    choices = torch.rand(size=s.size(), device=s.device)
    # special tokens (pad/bos/eos) are always marked safe
    choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1)

    # split the corruption budget beta equally among replace / repeat / swap
    replace = choices < beta / 3
    repeat = (choices >= beta / 3) & (choices < beta * 2 / 3)
    swap = (choices >= beta * 2 / 3) & (choices < beta)
    safe = choices >= beta

    for i in range(s.size(1) - 1):
        rand_word = rand_words[:, i]
        next_word = s[:, i + 1]
        self_word = s[:, i]

        replace_i = replace[:, i]
        # do not repeat or swap into the eos position (eos defaults to 3)
        swap_i = swap[:, i] & (next_word != 3)
        repeat_i = repeat[:, i] & (next_word != 3)
        safe_i = safe[:, i] | ((next_word == 3) & (~replace_i))

        # replace: token i -> random word; repeat: token i duplicated into i+1;
        # swap: tokens i and i+1 exchanged; safe: left unchanged
        s[:, i] = (
            self_word * (safe_i | repeat_i).long()
            + next_word * swap_i.long()
            + rand_word * replace_i.long()
        )
        s[:, i + 1] = (
            next_word * (safe_i | replace_i).long()
            + self_word * (swap_i | repeat_i).long()
        )
    return s


def gumbel_noise(input, TINY=1e-8):
    return (
        input.new_zeros(*input.size())
        .uniform_()
        .add_(TINY)
        .log_()
        .neg_()
        .add_(TINY)
        .log_()
        .neg_()
    )
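
# Note: gumbel_noise returns standard Gumbel samples (-log(-log(U))), so
# argmax(logits + gumbel_noise(logits)) draws a sample from softmax(logits)
# (the Gumbel-max trick). forward() below relies on this when
# --stochastic-approx is set, so that the decoder input for the next
# refinement iteration is sampled rather than taken greedily.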


@register_model("iterative_nonautoregressive_transformer")
class IterNATransformerModel(NATransformerModel):
    @staticmethod
    def add_args(parser):
        NATransformerModel.add_args(parser)
        parser.add_argument(
            "--train-step",
            type=int,
            help="number of refinement iterations during training",
        )
        parser.add_argument(
            "--dae-ratio",
            type=float,
            help="the probability of switching to the denoising auto-encoder loss",
        )
        parser.add_argument(
            "--stochastic-approx",
            action="store_true",
            help="sampling from the decoder as the inputs for next iteration",
        )

    @classmethod
    def build_model(cls, args, task):
        model = super().build_model(args, task)
        model.train_step = getattr(args, "train_step", 4)
        model.dae_ratio = getattr(args, "dae_ratio", 0.5)
        model.stochastic_approx = getattr(args, "stochastic_approx", False)
        return model

    def forward(
        self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
    ):
        B, T = prev_output_tokens.size()

        # encoding
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)

        # length prediction
        length_out = self.decoder.forward_length(
            normalize=False, encoder_out=encoder_out
        )
        length_tgt = self.decoder.forward_length_prediction(
            length_out, encoder_out, tgt_tokens
        )

        # decoding: run the decoder for `train_step` refinement iterations and
        # collect the outputs, targets, and padding masks of every iteration
        word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
        for t in range(self.train_step):
            word_ins_out = self.decoder(
                normalize=False,
                prev_output_tokens=prev_output_tokens,
                encoder_out=encoder_out,
                step=t,
            )
            word_ins_tgt = tgt_tokens
            word_ins_mask = word_ins_tgt.ne(self.pad)

            word_ins_outs.append(word_ins_out)
            word_ins_tgts.append(word_ins_tgt)
            word_ins_masks.append(word_ins_mask)

            if t < (self.train_step - 1):
                # prediction for the next iteration: sample with Gumbel noise if
                # --stochastic-approx is set, otherwise take the greedy argmax
                if self.stochastic_approx:
                    word_ins_prediction = (
                        word_ins_out + gumbel_noise(word_ins_out)
                    ).max(-1)[1]
                else:
                    word_ins_prediction = word_ins_out.max(-1)[1]

                prev_output_tokens = prev_output_tokens.masked_scatter(
                    word_ins_mask, word_ins_prediction[word_ins_mask]
                )

                if self.dae_ratio > 0:
                    # for a random subset of the batch, feed a corrupted copy of
                    # the target to the next iteration instead (denoising
                    # auto-encoder loss); the first iteration is never corrupted
                    corrupted = (
                        torch.rand(size=(B,), device=prev_output_tokens.device)
                        < self.dae_ratio
                    )
                    corrupted_tokens = _sequential_poisoning(
                        tgt_tokens[corrupted],
                        len(self.tgt_dict),
                        0.33,
                        self.bos,
                        self.eos,
                        self.pad,
                    )
                    prev_output_tokens[corrupted] = corrupted_tokens

        # concatenate the predictions of all iterations along the batch dimension
        word_ins_out = torch.cat(word_ins_outs, 0)
        word_ins_tgt = torch.cat(word_ins_tgts, 0)
        word_ins_mask = torch.cat(word_ins_masks, 0)

        return {
            "word_ins": {
                "out": word_ins_out,
                "tgt": word_ins_tgt,
                "mask": word_ins_mask,
                "ls": self.args.label_smoothing,
                "nll_loss": True,
            },
            "length": {
                "out": length_out,
                "tgt": length_tgt,
                "factor": self.decoder.length_loss_factor,
            },
        }


@register_model_architecture(
    "iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer"
)
def inat_base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.apply_bert_init = getattr(args, "apply_bert_init", False)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    # --- special arguments ---
    args.sg_length_pred = getattr(args, "sg_length_pred", False)
    args.pred_length_offset = getattr(args, "pred_length_offset", False)
    args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
    args.ngram_predictor = getattr(args, "ngram_predictor", 1)
    args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
    args.train_step = getattr(args, "train_step", 4)
    args.dae_ratio = getattr(args, "dae_ratio", 0.5)
    args.stochastic_approx = getattr(args, "stochastic_approx", False)


@register_model_architecture(
    "iterative_nonautoregressive_transformer",
    "iterative_nonautoregressive_transformer_wmt_en_de",
)
def iter_nat_wmt_en_de(args):
    inat_base_architecture(args)
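

# --- Illustrative usage (not part of the original module) ---
# A minimal smoke test for the sequential noising used by the denoising
# auto-encoder loss above, assuming the default special tokens bos=2, eos=3,
# pad=1. With the registrations in place, the model itself would typically be
# trained through fairseq-train via --arch iterative_nonautoregressive_transformer
# together with the flags added in add_args (--train-step, --dae-ratio,
# --stochastic-approx).
if __name__ == "__main__":
    # toy batch: bos ... eos, padded to equal length
    toy = torch.tensor(
        [
            [2, 10, 11, 12, 13, 3, 1, 1],
            [2, 20, 21, 22, 23, 24, 25, 3],
        ]
    )
    # _sequential_poisoning mutates its input in place, so pass a copy
    noised = _sequential_poisoning(toy.clone(), V=30, beta=0.33)
    print("original:", toy.tolist())
    print("noised:  ", noised.tolist())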