HUBERT / fairseq /models /nat /iterative_nonautoregressive_transformer.py
osanseviero's picture
Add repo
fc67275
raw
history blame
8.65 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nat import NATransformerModel
def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1):
# s: input batch
# V: vocabulary size
rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device)
choices = torch.rand(size=s.size(), device=s.device)
choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1)
replace = choices < beta / 3
repeat = (choices >= beta / 3) & (choices < beta * 2 / 3)
swap = (choices >= beta * 2 / 3) & (choices < beta)
safe = choices >= beta
for i in range(s.size(1) - 1):
rand_word = rand_words[:, i]
next_word = s[:, i + 1]
self_word = s[:, i]
replace_i = replace[:, i]
swap_i = swap[:, i] & (next_word != 3)
repeat_i = repeat[:, i] & (next_word != 3)
safe_i = safe[:, i] | ((next_word == 3) & (~replace_i))
s[:, i] = (
self_word * (safe_i | repeat_i).long()
+ next_word * swap_i.long()
+ rand_word * replace_i.long()
)
s[:, i + 1] = (
next_word * (safe_i | replace_i).long()
+ self_word * (swap_i | repeat_i).long()
)
return s
def gumbel_noise(input, TINY=1e-8):
return (
input.new_zeros(*input.size())
.uniform_()
.add_(TINY)
.log_()
.neg_()
.add_(TINY)
.log_()
.neg_()
)
@register_model("iterative_nonautoregressive_transformer")
class IterNATransformerModel(NATransformerModel):
@staticmethod
def add_args(parser):
NATransformerModel.add_args(parser)
parser.add_argument(
"--train-step",
type=int,
help="number of refinement iterations during training",
)
parser.add_argument(
"--dae-ratio",
type=float,
help="the probability of switching to the denoising auto-encoder loss",
)
parser.add_argument(
"--stochastic-approx",
action="store_true",
help="sampling from the decoder as the inputs for next iteration",
)
@classmethod
def build_model(cls, args, task):
model = super().build_model(args, task)
model.train_step = getattr(args, "train_step", 4)
model.dae_ratio = getattr(args, "dae_ratio", 0.5)
model.stochastic_approx = getattr(args, "stochastic_approx", False)
return model
def forward(
self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs
):
B, T = prev_output_tokens.size()
# encoding
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
# length prediction
length_out = self.decoder.forward_length(
normalize=False, encoder_out=encoder_out
)
length_tgt = self.decoder.forward_length_prediction(
length_out, encoder_out, tgt_tokens
)
# decoding
word_ins_outs, word_ins_tgts, word_ins_masks = [], [], []
for t in range(self.train_step):
word_ins_out = self.decoder(
normalize=False,
prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out,
step=t,
)
word_ins_tgt = tgt_tokens
word_ins_mask = word_ins_tgt.ne(self.pad)
word_ins_outs.append(word_ins_out)
word_ins_tgts.append(word_ins_tgt)
word_ins_masks.append(word_ins_mask)
if t < (self.train_step - 1):
# prediction for next iteration
if self.stochastic_approx:
word_ins_prediction = (
word_ins_out + gumbel_noise(word_ins_out)
).max(-1)[1]
else:
word_ins_prediction = word_ins_out.max(-1)[1]
prev_output_tokens = prev_output_tokens.masked_scatter(
word_ins_mask, word_ins_prediction[word_ins_mask]
)
if self.dae_ratio > 0:
# we do not perform denoising for the first iteration
corrputed = (
torch.rand(size=(B,), device=prev_output_tokens.device)
< self.dae_ratio
)
corrputed_tokens = _sequential_poisoning(
tgt_tokens[corrputed],
len(self.tgt_dict),
0.33,
self.bos,
self.eos,
self.pad,
)
prev_output_tokens[corrputed] = corrputed_tokens
# concat everything
word_ins_out = torch.cat(word_ins_outs, 0)
word_ins_tgt = torch.cat(word_ins_tgts, 0)
word_ins_mask = torch.cat(word_ins_masks, 0)
return {
"word_ins": {
"out": word_ins_out,
"tgt": word_ins_tgt,
"mask": word_ins_mask,
"ls": self.args.label_smoothing,
"nll_loss": True,
},
"length": {
"out": length_out,
"tgt": length_tgt,
"factor": self.decoder.length_loss_factor,
},
}
@register_model_architecture(
"iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer"
)
def inat_base_architecture(args):
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 6)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.dropout = getattr(args, "dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.apply_bert_init = getattr(args, "apply_bert_init", False)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
# --- special arguments ---
args.sg_length_pred = getattr(args, "sg_length_pred", False)
args.pred_length_offset = getattr(args, "pred_length_offset", False)
args.length_loss_factor = getattr(args, "length_loss_factor", 0.1)
args.ngram_predictor = getattr(args, "ngram_predictor", 1)
args.src_embedding_copy = getattr(args, "src_embedding_copy", False)
args.train_step = getattr(args, "train_step", 4)
args.dae_ratio = getattr(args, "dae_ratio", 0.5)
args.stochastic_approx = getattr(args, "stochastic_approx", False)
@register_model_architecture(
"iterative_nonautoregressive_transformer",
"iterative_nonautoregressive_transformer_wmt_en_de",
)
def iter_nat_wmt_en_de(args):
inat_base_architecture(args)