# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import namedtuple

import torch
import torch.nn as nn

from fairseq import checkpoint_utils, utils
from fairseq.models import (
    FairseqEncoder,
    FairseqDecoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.speech_to_text import (
    TransformerDecoder,
    S2TTransformerEncoder,
)
from fairseq.models.transformer import TransformerEncoder
from fairseq.modules import (
    TransformerEncoderLayer,
    GradMultiply,
    LayerNorm,
)

logger = logging.getLogger(__name__)


class SpeechEoSEncoder(FairseqEncoder):
    def __init__(self, encoder, eos_num, feat_dim, adapter_type="None", adapter_dim=0):
        super().__init__(None)
        self.encoder = encoder
        # number of EOS frames appended to the speech features
        # (set to the input-feature downsampling rate)
        self.eos_num = eos_num
        self.eos_emb = (
            nn.Parameter(torch.zeros(1, feat_dim), requires_grad=True)
            if eos_num > 0
            else None
        )
        self.adapter = self.add_adapter(adapter_type, adapter_dim)

    def add_adapter(self, adapter_type, adapter_dim):
        def _make_identity(linear, eps=1e-5):
            assert isinstance(linear, nn.Linear)
            linear.weight.data.mul_(eps)
            linear.weight.data.fill_diagonal_(1.0)
            if linear.bias is not None:
                linear.bias.data.mul_(eps)

        adapter = None
        if adapter_type == "Linear":
            assert adapter_dim > 0
            adapter = nn.Sequential(
                nn.Linear(adapter_dim, adapter_dim), LayerNorm(adapter_dim)
            )
            # initialize the adapter as an (approximate) identity mapping first
            _make_identity(adapter[0])
        elif adapter_type == "MLP":
            assert adapter_dim > 0
            # assume the model is a pre-norm model
            adapter = nn.Sequential(
                nn.Linear(adapter_dim, 2 * adapter_dim),
                nn.ReLU(),
                nn.Linear(2 * adapter_dim, adapter_dim),
                LayerNorm(adapter_dim),
            )
            _make_identity(adapter[0])
            _make_identity(adapter[2])
        return adapter

    def add_eos(self, src_tokens, src_lengths):
        bsz, max_seq_len, fdim = src_tokens.size()
        if self.eos_num > 0:
            src_token_eos = torch.zeros(
                [bsz, max_seq_len + self.eos_num, fdim],
                dtype=src_tokens.dtype,
                device=src_tokens.device,
            )
            src_token_eos[:, :max_seq_len] = src_tokens
            for bi in range(bsz):
                src_token_eos[bi][
                    src_lengths[bi] : src_lengths[bi] + self.eos_num
                ] = self.eos_emb.expand(self.eos_num, fdim)
            src_lengths = src_lengths + self.eos_num
            src_tokens = src_token_eos
        return src_tokens, src_lengths
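
    # Illustrative note (added, not part of the original code): with eos_num=4
    # and a batch of (B=2, T=100, feat=80) filterbank features, add_eos returns
    # a (2, 104, 80) tensor in which frames src_lengths[b] .. src_lengths[b]+3
    # of each utterance hold the learned EOS embedding, and src_lengths grows
    # by 4, so the appended frames survive the convolutional subsampling.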

    def apply_adapter(self, enc_out):
        if self.adapter is None:
            return enc_out
        rst = self.adapter(enc_out.encoder_out)
        if enc_out.encoder_padding_mask is not None:
            rst.masked_fill_(
                enc_out.encoder_padding_mask.transpose(0, 1).unsqueeze(-1), 0
            )
        return EncoderOut(
            encoder_out=rst,
            encoder_padding_mask=enc_out.encoder_padding_mask,
            encoder_embedding=enc_out.encoder_embedding,
            encoder_states=enc_out.encoder_states,
            src_tokens=enc_out.src_tokens,
            src_lengths=enc_out.src_lengths,
        )

    def forward(self, src_tokens, src_lengths=None, return_all_hiddens=False, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)
        """
        src_tokens, src_lengths = self.add_eos(src_tokens, src_lengths)
        enc_out = self.encoder(src_tokens, src_lengths, return_all_hiddens)
        enc_out = self.apply_adapter(enc_out)
        return enc_out

    def reorder_encoder_out(self, encoder_out, new_order):
        return self.encoder.reorder_encoder_out(encoder_out, new_order)


class DualInputEncoder(FairseqEncoder):
    def __init__(
        self,
        args,
        spch_encoder,
        text_encoder,
        dictionary,
        cross_attentive_loss_before_last_layer=-1,
    ):
        super().__init__(dictionary)
        self.spch_encoder = spch_encoder
        self.text_encoder = text_encoder
        self.enc_grad_mult = args.enc_grad_mult
        self.cross_attentive_loss_before_last_layer = (
            cross_attentive_loss_before_last_layer
        )
        self.use_cross_attentive_loss = cross_attentive_loss_before_last_layer > -1
        self.enc2_along_grad_mult = args.enc2_along_grad_mult

    @classmethod
    def set_shared_layer(cls, share_level, src_layer, tgt_layer):
        """
        Share parameters from tgt_layer to src_layer.
        share_level:
            0: share the whole layer object
            1: share all sub-modules, but keep separate layer instances
            2: share weights only (biases and layer norms stay separate)
        """
        if share_level == 0:
            return tgt_layer
        if isinstance(src_layer, nn.Linear):
            return tgt_layer
        if isinstance(src_layer, TransformerEncoderLayer):
            assert src_layer.embed_dim == tgt_layer.embed_dim
            assert src_layer.normalize_before == tgt_layer.normalize_before
            if share_level == 1:
                src_layer.fc1 = tgt_layer.fc1
                src_layer.fc2 = tgt_layer.fc2
                src_layer.self_attn = tgt_layer.self_attn
                src_layer.final_layer_norm = tgt_layer.final_layer_norm
                src_layer.self_attn_layer_norm = tgt_layer.self_attn_layer_norm
                src_layer.layernorm_embedding = tgt_layer.layernorm_embedding
            else:
                src_layer.fc1.weight = tgt_layer.fc1.weight
                src_layer.fc2.weight = tgt_layer.fc2.weight
                src_layer.self_attn.k_proj.weight = tgt_layer.self_attn.k_proj.weight
                src_layer.self_attn.v_proj.weight = tgt_layer.self_attn.v_proj.weight
                src_layer.self_attn.q_proj.weight = tgt_layer.self_attn.q_proj.weight
                src_layer.self_attn.out_proj.weight = (
                    tgt_layer.self_attn.out_proj.weight
                )
        else:
            if share_level == 1:
                return tgt_layer
        return src_layer
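
    # Illustrative note (added, not in the original source): with
    # --encoder-shared-layer-level 2, a shared TransformerEncoderLayer ends up
    # with text_layer.fc1.weight is speech_layer.fc1.weight == True, while
    # text_layer.fc1.bias and the layer norms remain independent tensors.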

    @classmethod
    def build_spch_encoder(cls, args):
        cfg = {
            "input_feat_per_channel": args.input_feat_per_channel,
            "input_channels": args.input_channels,
            "conv_kernel_sizes": args.conv_kernel_sizes,
            "conv_channels": args.conv_channels,
            "encoder_embed_dim": args.encoder_embed_dim,
            "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
            "encoder_layers": args.speech_encoder_layers,
            "encoder_layerdrop": args.encoder_layerdrop,
            "encoder_attention_heads": args.encoder_attention_heads,
            "max_source_positions": args.max_source_positions,
            "dropout": args.dropout,
            "encoder_normalize_before": args.encoder_normalize_before,
            "activation_dropout": args.activation_dropout,
            "attention_dropout": args.attention_dropout,
            "activation_fn": args.activation_fn,
            "layernorm_embedding": args.layernorm_embedding,
            "no_token_positional_embeddings": args.no_token_positional_embeddings,
            "no_scale_embedding": args.no_scale_embedding,
            "quant_noise_pq": args.quant_noise_pq,
            "encoder_freezing_updates": 0,
        }
        model_args = namedtuple("args", cfg.keys())(*cfg.values())
        spch_encoder = S2TTransformerEncoder(model_args)
        if args.add_speech_eos:
            spch_encoder = SpeechEoSEncoder(
                spch_encoder,
                2 * len(args.conv_kernel_sizes.split(",")),
                args.input_feat_per_channel,
                adapter_type=getattr(args, "speech_encoder_adapter_type", "None"),
                adapter_dim=args.encoder_embed_dim,
            )
        return spch_encoder

    @classmethod
    def build_text_encoder(cls, args, src_dictionary, spch_encoder):
        if args.encoder_shared_layers > 0:
            mx_shared_layers = min(
                args.speech_encoder_layers, args.text_encoder_layers
            )
            args.encoder_shared_layers = min(
                args.encoder_shared_layers, mx_shared_layers
            )
        cfg = {
            "encoder_embed_dim": args.encoder_text_embed_dim,
            "encoder_ffn_embed_dim": args.encoder_ffn_embed_dim,
            "encoder_layers": args.text_encoder_layers,
            "encoder_layerdrop": args.encoder_layerdrop,
            "encoder_attention_heads": args.encoder_attention_heads,
            "encoder_learned_pos": args.encoder_learned_pos,
            "max_source_positions": args.max_source_positions,
            "dropout": args.dropout,
            "encoder_normalize_before": args.encoder_normalize_before,
            "activation_dropout": args.activation_dropout,
            "attention_dropout": args.attention_dropout,
            "activation_fn": args.activation_fn,
            "adaptive_input": args.adaptive_input,
            "no_token_positional_embeddings": args.no_token_positional_embeddings,
            "no_scale_embedding": args.no_scale_embedding,
            "quant_noise_pq": args.quant_noise_pq,
        }
        model_args = namedtuple("args", cfg.keys())(*cfg.values())
        enc_emb = nn.Embedding(
            len(src_dictionary), model_args.encoder_embed_dim, src_dictionary.pad()
        )
        text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
        if args.add_speech_eos:
            spch_encoder = spch_encoder.encoder
        if args.encoder_shared_layers > 0:
            text_encoder.layer_norm = cls.set_shared_layer(
                args.encoder_shared_layer_level,
                text_encoder.layer_norm,
                spch_encoder.layer_norm,
            )
            for i, ly in enumerate(
                spch_encoder.transformer_layers[-args.encoder_shared_layers :]
            ):
                ly_id = i + args.text_encoder_layers - args.encoder_shared_layers
                assert isinstance(text_encoder.layers[ly_id], type(ly))
                text_encoder.layers[ly_id] = cls.set_shared_layer(
                    args.encoder_shared_layer_level,
                    text_encoder.layers[ly_id],
                    ly,
                )
        return text_encoder

    def mult_rst_grad(self, rst, ratio):
        assert isinstance(rst, dict)  # instead of EncoderOut
        assert len(rst["encoder_out"]) == 1
        rst["encoder_out"][0] = GradMultiply.apply(rst["encoder_out"][0], ratio)
        return rst

    def process_attentive_loss_states(self, rst, interstates):
        assert isinstance(rst, dict)  # instead of EncoderOut
        rst["encoder_states"] = interstates
        return rst
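
    # Descriptive note (added): GradMultiply.apply, used in mult_rst_grad above,
    # is an identity in the forward pass and scales gradients by `ratio` in the
    # backward pass, so it only rebalances how strongly an encoder branch is
    # updated, without changing its outputs.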

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        src_txt_tokens=None,
        src_txt_lengths=None,
        **kwargs
    ):
        """
        Args:
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (speech) (B,)
            src_txt_tokens: padded tensor (B, T)
            src_txt_lengths: tensor of original lengths of input utterances (text) (B,)
        """
        # src_tokens only: inference
        # src_tokens, src_lengths: speech-only training
        # src_txt_tokens, src_txt_lengths: text-only training
        # all valid: speech + text training
        if src_tokens is None and src_txt_tokens is None:
            raise ValueError(
                "src_tokens and src_txt_tokens cannot be None at the same time"
            )
        ret1 = None
        ret2 = None
        return_all_hiddens = False
        if src_tokens is not None:
            if (
                self.use_cross_attentive_loss and src_txt_tokens is not None
            ):  # no self.training check, so attention scores are also available during validation
                return_all_hiddens = True
            ret1 = self.spch_encoder(
                src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
            )
            if self.use_cross_attentive_loss and src_txt_tokens is not None:
                assert self.cross_attentive_loss_before_last_layer < len(
                    ret1["encoder_states"]
                )
                ret1 = self.process_attentive_loss_states(
                    ret1,
                    ret1["encoder_states"][
                        -self.cross_attentive_loss_before_last_layer - 1
                    ],
                )
        if src_txt_tokens is not None:
            ret2 = self.text_encoder(
                src_txt_tokens, src_txt_lengths, return_all_hiddens=return_all_hiddens
            )
            if return_all_hiddens:
                if self.cross_attentive_loss_before_last_layer == len(
                    self.text_encoder.layers
                ):
                    text_embedding, _ = self.text_encoder.forward_embedding(
                        src_txt_tokens
                    )
                    text_embedding = text_embedding.transpose(0, 1)
                    ret2 = self.process_attentive_loss_states(ret2, text_embedding)
                else:
                    assert self.cross_attentive_loss_before_last_layer < len(
                        self.text_encoder.layers
                    )
                    ret2 = self.process_attentive_loss_states(
                        ret2,
                        ret2["encoder_states"][
                            -self.cross_attentive_loss_before_last_layer - 1
                        ],
                    )

        def merge_output(rst1, rst2):
            if rst1 is None:
                if not (self.enc2_along_grad_mult == 1.0 or self.training):
                    rst2 = self.mult_rst_grad(rst2, self.enc2_along_grad_mult)
                return rst2
            if rst2 is None:
                return rst1
            if self.enc_grad_mult != 1.0 and self.training:
                rst1 = self.mult_rst_grad(rst1, self.enc_grad_mult)
                rst2 = self.mult_rst_grad(rst2, self.enc_grad_mult)
            rst = (rst1, rst2)
            return rst

        return merge_output(ret1, ret2)

    def reorder_encoder_out(self, encoder_out, new_order):
        assert self.training is False  # used for inference only
        return self.spch_encoder.reorder_encoder_out(encoder_out, new_order)
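

# Illustrative note (added, not part of the original source): during joint
# training DualInputEncoder.forward returns a tuple (speech_out, text_out);
# with speech-only or text-only batches it returns a single encoder dict, and
# at inference time only the speech branch is used (see reorder_encoder_out).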


# TransformerMultiInputDecoder: takes one or two encoder inputs
class TransformerMultiInputDecoder(FairseqDecoder):
    def __init__(
        self,
        dictionary,
        spch_decoder,
        text_decoder,
        compute_cross_attentive_loss=False,
        cross_attentive_loss_with_norm=True,
        cross_attentive_loss_reverse=False,
    ):
        super().__init__(dictionary)
        self.spch_decoder = spch_decoder
        self.text_decoder = text_decoder
        self.compute_cross_attentive_loss = compute_cross_attentive_loss
        self.cross_attentive_loss_with_norm = cross_attentive_loss_with_norm
        self.cross_attentive_loss_reverse = cross_attentive_loss_reverse

    @classmethod
    def share_spchdecoder(cls, task_args, text_decoder, spch_decoder):
        if task_args.decoder_shared_layer_level == 0:
            return text_decoder
        assert text_decoder.embed_tokens == spch_decoder.embed_tokens
        spch_decoder.project_in_dim = text_decoder.project_in_dim
        spch_decoder.embed_positions = text_decoder.embed_positions
        spch_decoder.layernorm_embedding = text_decoder.layernorm_embedding
        spch_decoder.project_out_dim = text_decoder.project_out_dim
        spch_decoder.adaptive_softmax = text_decoder.adaptive_softmax
        if task_args.decoder_shared_layer_level == 1:
            spch_decoder.output_projection = text_decoder.output_projection
            spch_decoder.layer_norm = text_decoder.layer_norm
        else:  # decoder_shared_layer_level == 2
            spch_decoder.output_projection.weight = (
                text_decoder.output_projection.weight
            )
        for i, ly in enumerate(text_decoder.layers):
            sly = spch_decoder.layers[i]
            sly.self_attn = ly.self_attn
            sly.self_attn_layer_norm = ly.self_attn_layer_norm
            # sly.encoder_attn = ly.encoder_attn
            if (
                task_args.decoder_shared_layer_level == 1
            ):  # share everything, but under different models
                sly.encoder_attn = ly.encoder_attn
                sly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
                sly.fc1 = ly.fc1
                sly.fc2 = ly.fc2
                sly.final_layer_norm = ly.final_layer_norm
            else:  # decoder_shared_layer_level == 2: separate encoder_attn_layer_norm and biases
                sly.encoder_attn.k_proj.weight = ly.encoder_attn.k_proj.weight
                sly.encoder_attn.v_proj.weight = ly.encoder_attn.v_proj.weight
                sly.encoder_attn.q_proj.weight = ly.encoder_attn.q_proj.weight
                sly.encoder_attn.out_proj.weight = ly.encoder_attn.out_proj.weight
                sly.fc1.weight = ly.fc1.weight
                sly.fc2.weight = ly.fc2.weight
        return spch_decoder

    def cross_attentive_loss(
        self, teacher_states, student_states, teacher_masking, student_masking, eps=1e-6
    ):
        x = teacher_states.transpose(0, 1)  # from T x B x D to B x T x D
        y = student_states.transpose(0, 1)
        if self.cross_attentive_loss_with_norm:
            x = x / (x.norm(dim=2, keepdim=True) + eps)
            y = y / (y.norm(dim=2, keepdim=True) + eps)
        dim = x.size(-1)
        # lengths: batch x seqLen
        sim_scores_xy = torch.bmm(x, y.transpose(1, 2))  # batch x lenx x leny
        if y.dtype == torch.float16:
            sim_scores_xy = sim_scores_xy.float()
            y = y.float()
            x = x.float()
        if teacher_masking != []:
            assert len(teacher_masking) == 1
            sim_scores_xy = sim_scores_xy.masked_fill(
                teacher_masking[0].unsqueeze(-1), float("-inf")
            )
        if student_masking != []:
            sim_scores_xy = sim_scores_xy.masked_fill(
                student_masking[0].unsqueeze(1), float("-inf")
            )
        # do masking
        y_weights = utils.softmax(sim_scores_xy, dim=-1)
        if teacher_masking != []:
            y_weights = y_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
        x_reconstruct_from_y = torch.bmm(y_weights, y)

        sim_scores_xx = torch.bmm(x, x.transpose(1, 2))  # batch x lenx x lenx
        x_weights = utils.softmax(sim_scores_xx, dim=-1)
        if teacher_masking != []:
            x_weights = x_weights.masked_fill(teacher_masking[0].unsqueeze(-1), 0)
        # no gradient for teacher state
        x_reconstruct_from_x = torch.bmm(x_weights, x).detach()
        cost = (x_reconstruct_from_x - x_reconstruct_from_y).norm(dim=2)
        if teacher_masking != []:
            cost = cost.masked_fill(teacher_masking[0], 0)
        if not self.cross_attentive_loss_with_norm:
            cost = cost / dim
        return cost
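
    # Descriptive note (added): the loss above reconstructs each teacher frame
    # twice, once by attending over the (detached) teacher states and once by
    # attending over the student states, using similarity over the normalized
    # representations, and penalizes the L2 distance between the two
    # reconstructions on non-padded teacher positions.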

    def forward(
        self,
        prev_output_tokens,
        encoder_out,
        incremental_state=None,
        has_txt_input=False,
        **kwargs
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing. If there are
                two or more inputs during training, they share the same prev_output_tokens.
            encoder_out (tuple[Tensor]): output from the encoder, used for
                encoder-side attention. It is a tuple if there is more than one input,
                and a single encoder output otherwise.
            incremental_state ([dict]): dictionary used for storing state during
                :ref:`Incremental decoding`. It is only valid for inference with a
                single input.
        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len,
                  vocab)`. If there are N inputs, the batch is N times larger than
                  for a single input
                - the last decoder layer's attention weights of shape `(batch,
                  tgt_len, src_len)`
        """
        assert not isinstance(encoder_out, EncoderOut)
        if isinstance(encoder_out, tuple):  # training with multiple inputs
            rst = []
            assert len(encoder_out) == 2
            for i, eo in enumerate(encoder_out):
                assert incremental_state is None
                if i == 0:
                    rst.append(
                        self.spch_decoder(prev_output_tokens, eo, incremental_state)
                    )
                else:
                    rst.append(
                        self.text_decoder(prev_output_tokens, eo, incremental_state)
                    )
            dec_out = torch.cat([r[0] for r in rst], dim=0)
            attn_cost = None
            if self.compute_cross_attentive_loss:
                assert isinstance(encoder_out[0], dict)
                if self.cross_attentive_loss_reverse:
                    attn_cost = self.cross_attentive_loss(
                        teacher_states=encoder_out[1]["encoder_states"],  # text states
                        student_states=encoder_out[0]["encoder_states"],  # speech states
                        teacher_masking=encoder_out[1]["encoder_padding_mask"],
                        student_masking=encoder_out[0]["encoder_padding_mask"],
                    )
                else:
                    attn_cost = self.cross_attentive_loss(
                        teacher_states=encoder_out[0]["encoder_states"],  # speech states
                        student_states=encoder_out[1]["encoder_states"],  # text states
                        teacher_masking=encoder_out[0]["encoder_padding_mask"],
                        student_masking=encoder_out[1]["encoder_padding_mask"],
                    )
            return (dec_out, {"attn_cost": attn_cost})
        else:  # inference or training with one input
            if has_txt_input:
                return self.text_decoder(
                    prev_output_tokens, encoder_out, incremental_state
                )
            return self.spch_decoder(prev_output_tokens, encoder_out, incremental_state)


# Note:
# dual input transformer:
#   encoder: S2TTransformerEncoder for speech + TransformerEncoder for text
#   decoder: TransformerDecoder for text
@register_model("dual_input_s2t_transformer")
class DualInputS2TTransformerModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
        self.num_updates = 0

    def max_positions(self):
        return None  # it is provided in task

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # encoder 1: S2TTransformerEncoder for speech
        parser.add_argument(
            "--conv-kernel-sizes",
            type=str,
            metavar="N",
            help="kernel sizes of Conv1d subsampling layers",
        )
        parser.add_argument(
            "--conv-channels",
            type=int,
            metavar="N",
            help="# of channels in Conv1d subsampling layers",
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
                encoder output dimension, can be None. If specified, project the
                transformer output to the specified dimension""",
        )
        # standard Transformer
        parser.add_argument(
            "--activation-fn",
            type=str,
            default="relu",
            choices=utils.get_available_activation_fns(),
            help="activation function to use",
        )
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation-dropout",
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after activation in FFN",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-text-embed-dim",
            type=int,
            metavar="N",
            help="encoder text embedding dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads",
        )
        parser.add_argument(
            "--layernorm-embedding",
            action="store_true",
            help="add layernorm to embedding",
        )
        parser.add_argument(
            "--no-scale-embedding",
            action="store_true",
            help="if True, don't scale embeddings",
        )
        # non-standard transformer parameters
        parser.add_argument(
            "--speech-encoder-layers",
            type=int,
            metavar="N",
            help="num speech encoder layers",
        )
        parser.add_argument(
            "--text-encoder-layers",
            type=int,
            metavar="N",
            help="num text encoder layers",
        )
        parser.add_argument(
            "--encoder-shared-layers",
            type=int,
            metavar="N",
            help="num shared encoder layers",
        )
        parser.add_argument(
            "--encoder-shared-layer-level",
            type=int,
            metavar="N",
            default=0,
            choices=[0, 1, 2],
            help="encoder sharing level; 0: share whole layers; "
            "1: share sub-modules but keep separate layers; "
            "2: share weights but not biases and layer norms",
        )
        parser.add_argument(
            "--decoder-shared-layer-level",
            default=0,
            choices=[0, 1, 2],
            type=int,
            metavar="N",
            help="decoder sharing level; 0: share everything; "
            "1: share sub-modules but keep separate layers; "
            "2: share weights but not biases and layer norms",
        )
        ###
        parser.add_argument(
            "--text-input-cost-ratio",
            type=float,
            default=1.0,
            metavar="V",
            help="text input cost ratio relative to speech input cost",
        )
        parser.add_argument(
            "--init-scale",
            type=float,
            default=1.0,
            metavar="V",
            help="scale the initial weights by the given factor",
        )
        parser.add_argument(
            "--enc-grad-mult",
            type=float,
            metavar="V",
            default=1.0,
            help="multiply enc1 and enc2 gradients by V",
        )
        parser.add_argument(
            "--enc2-along-grad-mult",
            type=float,
            metavar="V",
            default=1.0,
            help="multiply enc2 gradient by V if only enc2 is used",
        )
        parser.add_argument(
            "--load-pretrain-encoder",
            type=str,
            default="",
            metavar="EXPR",
            help="path to the pretrained encoder",
        )
        parser.add_argument(
            "--load-pretrain-speech-encoder",
            type=str,
            default="",
            metavar="EXPR",
            help="path to the pretrained speech encoder",
        )
        parser.add_argument(
            "--load-pretrain-text-encoder",
            type=str,
            default="",
            metavar="EXPR",
            help="path to the pretrained text encoder",
        )
        parser.add_argument(
            "--load-pretrain-text-encoder-last",
            type=str,
            default="",
            metavar="EXPR",
            help="path to the pretrained text encoder (loaded last)",
        )
        parser.add_argument(
            "--load-pretrain-decoder",
            type=str,
            metavar="EXPR",
            default="",
            help="path to the pretrained decoder",
        )
        parser.add_argument(
            "--add-speech-eos",
            action="store_true",
            help="add an eos token at the end of the input features",
        )
        parser.add_argument(
            "--speech-encoder-adapter-type",
            type=str,
            metavar="EXPR",
            default="None",
            choices=["None", "Linear", "MLP"],
            help="add a speech encoder adapter",
        )

    @classmethod
    def build_encoder(cls, args, task):
        spch_encoder = DualInputEncoder.build_spch_encoder(args)
        text_encoder = DualInputEncoder.build_text_encoder(
            args, task.src_dict, spch_encoder
        )
        cross_attentive_loss_before_last_layer = (
            0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
        )
        encoder = DualInputEncoder(
            args,
            spch_encoder,
            text_encoder,
            task.src_dict,
            cross_attentive_loss_before_last_layer,
        )
        if args.init_scale != 1.0:
            with torch.no_grad():
                for param in encoder.parameters():
                    param.data.mul_(args.init_scale)
        if args.load_pretrain_text_encoder != "":
            checkpoint_utils.load_pretrained_component_from_model(
                text_encoder, args.load_pretrain_text_encoder
            )
        if args.load_pretrain_speech_encoder != "":
            if hasattr(spch_encoder, "encoder"):
                checkpoint_utils.load_pretrained_component_from_model(
                    spch_encoder.encoder, args.load_pretrain_speech_encoder
                )
            else:
                checkpoint_utils.load_pretrained_component_from_model(
                    spch_encoder, args.load_pretrain_speech_encoder
                )
        if (
            args.load_pretrain_text_encoder_last != ""
        ):  # if the encoder is shared, the speech encoder parameters would be used;
            # loading the text encoder last gives a chance to use a pre-trained MT encoder instead
            checkpoint_utils.load_pretrained_component_from_model(
                text_encoder, args.load_pretrain_text_encoder_last
            )
        if args.load_pretrain_encoder != "":
            checkpoint_utils.load_pretrained_component_from_model(
                encoder, args.load_pretrain_encoder
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task):
        dec_cfg = {
            "decoder_layerdrop": args.decoder_layerdrop,
            "share_decoder_input_output_embed": args.share_decoder_input_output_embed,
            "decoder_embed_dim": args.decoder_embed_dim,
            "max_target_positions": args.max_target_positions,
            "dropout": args.dropout,
            "encoder_learned_pos": args.encoder_learned_pos,
            "decoder_learned_pos": args.decoder_learned_pos,
            "layernorm_embedding": args.layernorm_embedding,
            "decoder_normalize_before": args.decoder_normalize_before,
            "activation_dropout": args.activation_dropout,
            "attention_dropout": args.attention_dropout,
            "decoder_ffn_embed_dim": args.decoder_ffn_embed_dim,
            "decoder_layers": args.decoder_layers,
            "decoder_attention_heads": args.decoder_attention_heads,
            "decoder_output_dim": args.decoder_embed_dim,
            "no_scale_embedding": args.no_scale_embedding,
            "adaptive_input": args.adaptive_input,
            "quant_noise_pq": args.quant_noise_pq,
            "adaptive_softmax_cutoff": args.adaptive_softmax_cutoff,
            "tie_adaptive_weights": args.tie_adaptive_weights,
            "no_token_positional_embeddings": args.no_token_positional_embeddings,
        }
        dec_cfg = namedtuple("args", dec_cfg.keys())(*dec_cfg.values())
        dec_emb = nn.Embedding(
            len(task.target_dictionary),
            args.decoder_embed_dim,
            task.target_dictionary.pad(),
        )
        compute_cross_attentive_loss = (
            getattr(args, "attentive_cost_regularization", 0.0) > 0.0
        )
        cross_attentive_loss_without_norm = getattr(
            args, "attentive_cost_without_normalize", False
        )
        cross_attentive_loss_reverse = (
            False  # getattr(args, "attentive_cost_reverse", False)
        )

        text_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
        spch_decoder = TransformerDecoder(dec_cfg, task.target_dictionary, dec_emb)
        spch_decoder = TransformerMultiInputDecoder.share_spchdecoder(
            args, text_decoder, spch_decoder
        )
        decoder = TransformerMultiInputDecoder(
            dictionary=task.target_dictionary,
            spch_decoder=spch_decoder,
            text_decoder=text_decoder,
            compute_cross_attentive_loss=compute_cross_attentive_loss,
            cross_attentive_loss_with_norm=not cross_attentive_loss_without_norm,
            cross_attentive_loss_reverse=cross_attentive_loss_reverse,
        )
        if args.init_scale != 1.0:
            with torch.no_grad():
                for param in decoder.parameters():
                    param.data.mul_(args.init_scale)
        if args.load_pretrain_decoder != "":
            try:
                checkpoint_utils.load_pretrained_component_from_model(
                    decoder, args.load_pretrain_decoder
                )
            except RuntimeError:
                checkpoint_utils.load_pretrained_component_from_model(
                    decoder.text_decoder, args.load_pretrain_decoder
                )
                if args.decoder_shared_layer_level > 0:
                    checkpoint_utils.load_pretrained_component_from_model(
                        decoder.spch_decoder, args.load_pretrain_decoder
                    )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        dualinputs2ttransformer_base(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        use_encoder_outputs=False,
        src_txt_tokens=None,
        src_txt_lengths=None,
        mode="sup_speech",
        **kwargs
    ):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            mode: 'sup_speech' or 'text'

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if mode == "text":
            assert src_txt_tokens is None
            src_txt_tokens = src_tokens
            src_txt_lengths = src_lengths
            src_tokens = None
            src_lengths = None
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            src_txt_tokens=src_txt_tokens,
            src_txt_lengths=src_txt_lengths,
            **kwargs
        )
        has_txt_input = src_txt_tokens is not None
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            has_txt_input=has_txt_input,
            **kwargs
        )
        if use_encoder_outputs:
            return decoder_out, encoder_out
        return decoder_out
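
    # Illustrative note (added): with mode="text", the tensors passed as
    # src_tokens/src_lengths are re-labeled as text inputs, so text-only batches
    # reuse the same forward signature; "sup_speech" batches go through the
    # speech branch, plus the text branch whenever src_txt_tokens is provided
    # for joint training.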


@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_base")
def dualinputs2ttransformer_base(args):
    args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0)
    # Convolutional subsampler
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
    args.conv_channels = getattr(args, "conv_channels", 1024)
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_text_embed_dim = getattr(
        args, "encoder_text_embed_dim", args.encoder_embed_dim
    )
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)

    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)

    args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.encoder_shared_layers = getattr(args, "encoder_shared_layers", 0)
    args.decoder_layers = getattr(args, "decoder_layers", 6)

    args.add_speech_eos = getattr(args, "add_speech_eos", False)


@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_s")
def dualinputs2ttransformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 7)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 7)
    args.decoder_layers = getattr(args, "decoder_layers", 7)
    dualinputs2ttransformer_base(args)


@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_m")
def dualinputs2ttransformer_m(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.dropout = getattr(args, "dropout", 0.15)
    args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 10)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    dualinputs2ttransformer_base(args)


@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_b")
def dualinputs2ttransformer_b(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    args.dropout = getattr(args, "dropout", 0.15)
    args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    dualinputs2ttransformer_base(args)


@register_model_architecture("dual_input_s2t_transformer", "dualinputs2ttransformer_l")
def dualinputs2ttransformer_l(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.2)
    args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
    args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    dualinputs2ttransformer_base(args)
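

# Hedged usage sketch (added; the data path, task name, and exact flag
# combination are assumptions based on the fairseq speech_text_joint_to_text
# example, not taken from this file):
#
#   fairseq-train ${DATA_DIR} \
#       --task speech_text_joint_to_text \
#       --arch dualinputs2ttransformer_m \
#       --encoder-shared-layers 4 --encoder-shared-layer-level 2 \
#       --decoder-shared-layer-level 1 \
#       --add-speech-eos --speech-encoder-adapter-type MLP
#
# The --arch value selects one of the register_model_architecture entries
# above; the sharing flags map directly onto set_shared_layer and
# share_spchdecoder.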