File size: 11,947 Bytes
ee21b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Any, Dict, List, Optional
from torch import Tensor

import torch
import torch.nn as nn

from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    Embedding,
    TransformerModel,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.modules import (
    TransformerDecoderLayer,
)

logger = logging.getLogger(__name__)


@register_model("laser_transformer")
class LaserTransformerModel(FairseqEncoderDecoderModel):
    """Train Transformer for LASER task

    Pairs a :class:`LaserTransformerEncoder`, which produces one sentence
    embedding per input, with a :class:`LaserTransformerDecoder` that is
    conditioned on that embedding and on a target language id.

    Requires --task laser
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens=None,
        tgt_tokens=None,
        tgt_lengths=None,
        target_language_id=-1,
        dataset_name="",
    ):
        """Encode the source batch and decode against its sentence embedding.

        Args:
            src_tokens: source tokens, forwarded to the encoder.
            src_lengths: source lengths, forwarded to the encoder.
            prev_output_tokens: previous decoder outputs, forwarded to the
                decoder.
            tgt_tokens: accepted for interface compatibility; not used here.
            tgt_lengths: accepted for interface compatibility; not used here.
            target_language_id: passed to the decoder as ``lang_id``.
            dataset_name: accepted for interface compatibility; not used here.

        Returns:
            Whatever the decoder's ``forward`` returns.
        """
        laser_encoder_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(
            prev_output_tokens, laser_encoder_out, lang_id=target_language_id
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # Start from the standard Transformer options, then add the
        # LASER-specific language embedding size.
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--decoder-lang-embed-dim",
            type=int,
            metavar="N",
            help="decoder language embedding dimension",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance from parsed ``args`` and a ``task``.

        The task must expose ``source_dictionary`` and ``target_dictionary``.
        If it also has a ``num_tasks`` attribute, that value is used as the
        number of languages for the decoder's language embedding; otherwise
        0 is used.
        """
        # Make sure every architecture hyper-parameter has a value before
        # any of them is read below.
        base_laser_transformer_architecture(args)

        def load_embed_tokens(dictionary, embed_dim):
            # One embedding row per dictionary entry, using the dictionary's
            # own padding index.
            return Embedding(len(dictionary), embed_dim, dictionary.pad())

        encoder_embed_tokens = load_embed_tokens(
            task.source_dictionary, args.encoder_embed_dim
        )
        decoder_embed_tokens = load_embed_tokens(
            task.target_dictionary, args.decoder_embed_dim
        )
        num_langs = getattr(task, "num_tasks", 0)

        encoder = LaserTransformerEncoder(
            args, task.source_dictionary, encoder_embed_tokens
        )

        decoder = LaserTransformerDecoder(
            args,
            task.target_dictionary,
            decoder_embed_tokens,
            num_langs=num_langs,
            lang_embed_dim=args.decoder_lang_embed_dim,
        )

        return cls(encoder, decoder)


class LaserTransformerEncoder(TransformerEncoder):
    """Transformer encoder that reduces its output to one vector per sentence.

    The sentence vector is the element-wise maximum over the time dimension
    of the parent encoder's output, with padded positions masked out.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, src_tokens, *args, **kwargs):
        """Run the parent encoder and max-pool its states over time.

        Args:
            src_tokens: source tokens; also used to locate padding via
                ``self.padding_idx``.
            *args, **kwargs: forwarded unchanged to the parent ``forward``.

        Returns:
            dict: ``{"sentemb": [tensor]}`` where the tensor is B x C.
        """
        parent_out = super().forward(src_tokens, *args, **kwargs)
        states = parent_out["encoder_out"][0]  # T x B x C

        # Transpose the padding mask to time-major and add a trailing axis so
        # that it broadcasts against the T x B x C states.
        pad_positions = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)

        if pad_positions.any():
            # Padded steps are set to -inf so they can never win the max.
            masked = states.float().masked_fill_(pad_positions, float("-inf"))
            states = masked.type_as(states)

        # Sentence embedding = max over the time dimension.
        pooled, _ = states.max(dim=0)

        # The PyTorch Mobile lite interpreter does not support returning a
        # NamedTuple from `forward`, so a dictionary is used instead.
        # TorchScript does not support mixed value types, so every value is a
        # list; an empty list is equivalent to None.
        return {"sentemb": [pooled]}  # B x C

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """Reorder the sentence embeddings along the batch dimension.

        Same idea as the version in transformer.py, but the only entry
        handled is ``"sentemb"``.

        Args:
            encoder_out: dictionary as returned by :meth:`forward`.
            new_order: index tensor used with ``index_select`` on dim 0.

        Returns:
            dict: ``{"sentemb": [...]}`` with the reordered tensor, or an
            empty list when the input list was empty.
        """
        reordered: List[Tensor] = []
        if len(encoder_out["sentemb"]) > 0:
            reordered.append(encoder_out["sentemb"][0].index_select(0, new_order))

        return {
            "sentemb": reordered,  # B x C
        }


class LaserTransformerDecoder(TransformerDecoder):
    """Transformer decoder conditioned on a sentence embedding and a language id.

    Instead of attending over encoder states, every decoder time step has the
    encoder's sentence embedding (and, when enabled, a learned language
    embedding) concatenated to its input features. The decoder layers and the
    output projection are therefore sized for
    ``decoder_embed_dim + lang_embed_dim + encoder_embed_dim``.
    """

    def __init__(self, args, dictionary, *kargs, **kwargs):
        """
        Args:
            args: parsed model arguments; ``encoder_embed_dim`` is read here.
            dictionary: target dictionary; its length sizes the output
                projection.
            *kargs, **kwargs: forwarded to the parent constructor, except for
                the two keyword arguments below, which are consumed here.
            num_langs (int, keyword): number of rows of the language
                embedding table (default: 1).
            lang_embed_dim (int, keyword): language embedding size; 0
                disables the language embedding (default: 0).
        """
        # Consume the LASER-specific keywords so they are not passed on to the
        # parent constructor. They are stored before calling it because
        # build_decoder_layer() reads self.lang_embed_dim.
        # NOTE(review): presumably the parent constructor is what calls
        # build_decoder_layer() - confirm against TransformerDecoder.
        self.num_langs = kwargs.pop("num_langs", 1)
        self.lang_embed_dim = kwargs.pop("lang_embed_dim", 0)

        super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True)

        if self.lang_embed_dim == 0:
            self.embed_lang = None
        else:
            self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim)
            nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)

        if self.output_projection is not None:
            # Replace the parent's projection with one whose input size
            # includes the concatenated language and sentence embeddings.
            laser_output_embed_dim = (
                self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
            )
            self.output_projection = nn.Linear(
                laser_output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight,
                mean=0,
                std=laser_output_embed_dim ** -0.5,
            )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        """Build one decoder layer sized for the concatenated input features.

        ``args.decoder_embed_dim`` is temporarily widened by the language and
        encoder embedding sizes while the layer is constructed, then restored.
        The layer is always built without encoder attention; the
        ``no_encoder_attn`` argument is accepted for signature compatibility
        but ignored.
        """
        decoder_embed_dim = args.decoder_embed_dim
        args.decoder_embed_dim = (
            decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
        )
        try:
            return TransformerDecoderLayer(args, no_encoder_attn=True)
        finally:
            # Restore the shared args object even if construction fails, so a
            # caller that handles the error does not see the widened value.
            args.decoder_embed_dim = decoder_embed_dim

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        lang_id: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`
            encoder_out (dict): encoder output; ``encoder_out["sentemb"][0]``
                is concatenated to the features of every time step.
            incremental_state (dict, optional): when given, only the last
                token (and position) of *prev_output_tokens* is processed.
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).
            lang_id (int, optional): language index looked up in the language
                embedding. Must not be None when the language embedding is
                enabled.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, C)`; the
                  layer input has the language and sentence embeddings
                  concatenated on the last dimension
                - a dictionary with ``"attn"`` and ``"inner_states"``
        """
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            # Incremental decoding: only the newest step is fed through.
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Only the batch size is needed; the unpacking also requires the
        # tokens to be 2-D.
        bsz, _ = prev_output_tokens.size()

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.embed_lang is not None:
            # One copy of the language id per batch element, embedded and then
            # broadcast over the time dimension before concatenation.
            lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
            langemb = self.embed_lang(lang_ids)
            langemb = langemb.unsqueeze(0)
            repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * (
                len(langemb.shape) - 1
            )
            x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1)

        # Broadcast the sentence embedding over time and append it as well.
        sentemb = encoder_out["sentemb"][0]
        sentemb = sentemb.unsqueeze(0)

        repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1)
        x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            # The two None arguments are the encoder output and encoder
            # padding mask slots; no encoder states are passed to the layers.
            x, layer_attn, _ = layer(
                x,
                None,
                None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool(idx == alignment_layer),
                need_head_weights=bool(idx == alignment_layer),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
        lang_id: Optional[int] = None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder; its ``"sentemb"``
                entry is concatenated to the decoder input features
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            alignment_layer (int, optional): forwarded to
                :meth:`extract_features`.
            alignment_heads (int, optional): forwarded to
                :meth:`extract_features`.
            src_lengths (optional): accepted for interface compatibility;
                not used here.
            return_all_hiddens (bool, optional): accepted for interface
                compatibility; not used here.
            lang_id (int): target language index; required (asserted below).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        assert lang_id is not None

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            lang_id=lang_id,
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra


@register_model_architecture("laser_transformer", "laser_transformer")
def base_laser_transformer_architecture(args):
    """Fill in default hyper-parameters for the ``laser_transformer`` architecture.

    Applies the base Transformer defaults first, then makes sure
    ``args.decoder_lang_embed_dim`` is set, using 0 when it is absent.
    """
    base_architecture(args)
    lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
    args.decoder_lang_embed_dim = lang_embed_dim