Spaces:
Runtime error
Runtime error
File size: 7,162 Bytes
ee21b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import checkpoint_utils
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
ConvTransformerModel,
convtransformer_espnet,
ConvTransformerEncoder,
)
from fairseq.models.speech_to_text.modules.augmented_memory_attention import (
augmented_memory,
SequenceEncoder,
AugmentedMemoryConvTransformerEncoder,
)
from torch import nn, Tensor
from typing import Dict, List
from fairseq.models.speech_to_text.modules.emformer import NoSegAugmentedMemoryTransformerEncoderLayer
@register_model("convtransformer_simul_trans")
class SimulConvTransformerModel(ConvTransformerModel):
    """
    ConvTransformer model whose decoder is a ``TransformerMonotonicDecoder``.

    Implementation of the paper:

    SimulMT to SimulST: Adapting Simultaneous Text Translation to
    End-to-End Simultaneous Speech Translation

    https://www.aclweb.org/anthology/2020.aacl-main.58.pdf
    """

    @staticmethod
    def add_args(parser):
        """Add the parent model's options plus ``--train-monotonic-only``.

        Args:
            parser: argument parser the options are registered on.
        """
        # add_args is a staticmethod, so the explicit two-argument super()
        # form is used to reach the parent class's implementation.
        super(SimulConvTransformerModel, SimulConvTransformerModel).add_args(parser)
        parser.add_argument(
            "--train-monotonic-only",
            help="Only train monotonic attention",
            default=False,
            action="store_true",
        )

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        """Build the monotonic-attention decoder.

        Args:
            args: model arguments; ``load_pretrained_decoder_from`` is read
                if present.
            task: task object providing ``tgt_dict``.
            embed_tokens: embedding passed through to the decoder.

        Returns:
            A ``TransformerMonotonicDecoder``, or whatever
            ``checkpoint_utils.load_pretrained_component_from_model`` returns
            for it when ``args.load_pretrained_decoder_from`` is truthy.
        """
        target_dictionary = task.tgt_dict

        # Kept as a function-scope import, exactly where the original had it.
        from examples.simultaneous_translation.models.transformer_monotonic_attention import (
            TransformerMonotonicDecoder,
        )

        monotonic_decoder = TransformerMonotonicDecoder(
            args, target_dictionary, embed_tokens
        )

        pretrained_path = getattr(args, "load_pretrained_decoder_from", None)
        if not pretrained_path:
            return monotonic_decoder

        return checkpoint_utils.load_pretrained_component_from_model(
            component=monotonic_decoder, checkpoint=pretrained_path
        )
@register_model_architecture(
    "convtransformer_simul_trans",
    "convtransformer_simul_trans_espnet",
)
def convtransformer_simul_trans_espnet(args):
    """Architecture preset that delegates entirely to ``convtransformer_espnet``.

    Args:
        args: model arguments, handed unchanged to ``convtransformer_espnet``.
    """
    convtransformer_espnet(args)
@register_model("convtransformer_augmented_memory")
@augmented_memory
class AugmentedMemoryConvTransformerModel(SimulConvTransformerModel):
    """Simultaneous ConvTransformer model with an augmented-memory encoder.

    The encoder is an ``AugmentedMemoryConvTransformerEncoder`` wrapped in a
    ``SequenceEncoder``.
    """

    @classmethod
    def build_encoder(cls, args):
        """Build the encoder, optionally passing it through checkpoint loading.

        Args:
            args: model arguments; ``load_pretrained_encoder_from`` is read
                if present.

        Returns:
            The ``SequenceEncoder``, or whatever
            ``checkpoint_utils.load_pretrained_component_from_model`` returns
            for it when ``args.load_pretrained_encoder_from`` is truthy.
        """
        encoder = SequenceEncoder(args, AugmentedMemoryConvTransformerEncoder(args))

        # Truthiness check rather than ``is not None``: an empty path is
        # treated the same as an unset one instead of attempting a load.
        pretrained_path = getattr(args, "load_pretrained_encoder_from", None)
        if pretrained_path:
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=pretrained_path
            )
        return encoder
@register_model_architecture(
    "convtransformer_augmented_memory",
    "convtransformer_augmented_memory",
)
def augmented_memory_convtransformer_espnet(args):
    """Architecture preset that delegates entirely to ``convtransformer_espnet``.

    Args:
        args: model arguments, handed unchanged to ``convtransformer_espnet``.
    """
    convtransformer_espnet(args)
# ============================================================================ #
# Convtransformer
# with monotonic attention decoder
# with emformer encoder
# ============================================================================ #
class ConvTransformerEmformerEncoder(ConvTransformerEncoder):
    """ConvTransformer encoder variant configured with an emformer layer.

    ``__init__`` replaces ``self.transformer_layers`` with a single
    ``NoSegAugmentedMemoryTransformerEncoderLayer`` and also creates a
    separate ``ConvTransformerEncoder`` stored as
    ``self.conv_transformer_encoder``. ``forward`` only calls the latter.
    """

    def __init__(self, args):
        """Set up the emformer layer and the wrapped conv-transformer encoder.

        Args:
            args: model arguments; reads the encoder dimensions, dropout,
                activation, segment lengths/contexts, ``max_memory_size`` and
                ``amtrf_tanh_on_mem``.
        """
        super().__init__(args)

        # Left/right contexts are floor-divided by the convolution stride.
        # NOTE(review): presumably converts input-frame units to
        # post-convolution steps — confirm.
        downsampling = self.conv_layer_stride(args)
        context_config = [
            args.segment_left_context // downsampling,
            args.segment_right_context // downsampling,
        ]

        emformer_layer = NoSegAugmentedMemoryTransformerEncoderLayer(
            input_dim=args.encoder_embed_dim,
            num_heads=args.encoder_attention_heads,
            ffn_dim=args.encoder_ffn_embed_dim,
            num_layers=args.encoder_layers,
            dropout_in_attn=args.dropout,
            dropout_on_attn=args.dropout,
            dropout_on_fc1=args.dropout,
            dropout_on_fc2=args.dropout,
            activation_fn=args.activation_fn,
            context_config=context_config,
            segment_size=args.segment_length,
            max_memory_size=args.max_memory_size,
            scaled_init=True,  # TODO: use constant for now.
            tanh_on_mem=args.amtrf_tanh_on_mem,
        )
        self.transformer_layers = nn.ModuleList([emformer_layer])

        self.conv_transformer_encoder = ConvTransformerEncoder(args)

    def forward(self, src_tokens, src_lengths):
        """Run the wrapped conv-transformer encoder and repackage its output.

        Args:
            src_tokens: encoder input; its ``device`` is used to place
                ``src_lengths``.
            src_lengths: input lengths, moved to ``src_tokens.device``.

        Returns:
            A dict with ``encoder_out`` (one-element list),
            ``encoder_padding_mask`` (one-element list, or empty when the
            wrapped encoder returned no mask), and empty lists for
            ``encoder_embedding``, ``encoder_states``, ``src_tokens`` and
            ``src_lengths``.
        """
        lengths_on_device = src_lengths.to(src_tokens.device)
        conv_out: Dict[str, List[Tensor]] = self.conv_transformer_encoder(
            src_tokens, lengths_on_device
        )
        output = conv_out["encoder_out"][0]
        padding_masks = conv_out["encoder_padding_mask"]

        if len(padding_masks) > 0:
            # The mask is cut to output.size(0) columns because, in the
            # original implementation, the output did not treat the last
            # segment as right context.
            trimmed_masks = [padding_masks[0][:, : output.size(0)]]
        else:
            trimmed_masks = []

        return {
            "encoder_out": [output],
            "encoder_padding_mask": trimmed_masks,
            "encoder_embedding": [],
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }

    @staticmethod
    def conv_layer_stride(args):
        """Return the convolution stride, currently the fixed value 4.

        Args:
            args: model arguments (unused for now).
        """
        # TODO: make it configurable from the args
        return 4
@register_model("convtransformer_emformer")
class ConvtransformerEmformer(SimulConvTransformerModel):
    """Simultaneous ConvTransformer model using ``ConvTransformerEmformerEncoder``."""

    @staticmethod
    def add_args(parser):
        """Add the parent model's options plus the emformer segment options.

        Args:
            parser: argument parser the options are registered on.
        """
        # add_args is a staticmethod, so the explicit two-argument super()
        # form is used to reach the parent class's implementation.
        super(ConvtransformerEmformer, ConvtransformerEmformer).add_args(parser)

        parser.add_argument(
            "--segment-length",
            type=int,
            metavar="N",
            help="length of each segment (not including left context / right context)",
        )
        parser.add_argument(
            "--segment-left-context",
            type=int,
            help="length of left context in a segment",
        )
        parser.add_argument(
            "--segment-right-context",
            type=int,
            help="length of right context in a segment",
        )
        parser.add_argument(
            "--max-memory-size",
            type=int,
            default=-1,
            # The previous help text was a copy of the right-context
            # description and did not describe this option.
            help="maximum memory size of the emformer encoder layer",
        )
        parser.add_argument(
            "--amtrf-tanh-on-mem",
            default=False,
            action="store_true",
            help="whether to use tanh on memory vector",
        )

    @classmethod
    def build_encoder(cls, args):
        """Build the emformer encoder, optionally via checkpoint loading.

        Args:
            args: model arguments; ``load_pretrained_encoder_from`` is read
                if present.

        Returns:
            A ``ConvTransformerEmformerEncoder``, or whatever
            ``checkpoint_utils.load_pretrained_component_from_model`` returns
            for it when ``args.load_pretrained_encoder_from`` is truthy.
        """
        encoder = ConvTransformerEmformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder
@register_model_architecture("convtransformer_emformer", "convtransformer_emformer")
def convtransformer_emformer_base(args):
    """Architecture preset that delegates entirely to ``convtransformer_espnet``.

    Args:
        args: model arguments, handed unchanged to ``convtransformer_espnet``.
    """
    convtransformer_espnet(args)
|