# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
from torch import Tensor

from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer

from . import build_monotonic_attention


class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
    def forward(self, x, encoder_padding_mask):
        # Build a causal (upper-triangular) additive mask so that position i
        # can only attend to positions <= i. This keeps every encoder state
        # valid as a prefix representation, which streaming decoding relies on.
        seq_len, _, _ = x.size()
        attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
        attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
        return super().forward(x, encoder_padding_mask, attn_mask)
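

# A minimal standalone sketch (plain torch, not part of fairseq) of the mask
# built above, for seq_len = 4:
#
#     import torch
#     m = torch.ones(4, 4).triu(1)
#     m = m.masked_fill(m.bool(), float("-inf"))
#     # tensor([[0., -inf, -inf, -inf],
#     #         [0.,   0., -inf, -inf],
#     #         [0.,   0.,   0., -inf],
#     #         [0.,   0.,   0.,   0.]])
#
# Adding this mask to the attention logits zeroes out the probability of
# attending to future positions, so each encoder state depends only on the
# source prefix read so far.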


class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
    def __init__(self, args):
        super().__init__(args)

        assert args.simul_type is not None, "A --simul-type is needed."
        # Replace the standard encoder-decoder attention with the monotonic
        # attention variant selected by --simul-type.
        self.encoder_attn = build_monotonic_attention(args)

    def prune_incremental_state(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ):
        # Drop the most recent time step from the cached self-attention
        # keys/values. If only one step is cached, reset the buffer entirely.
        input_buffer = self.self_attn._get_input_buffer(incremental_state)
        for key in ["prev_key", "prev_value"]:
            input_buffer_key = input_buffer[key]
            assert input_buffer_key is not None
            if input_buffer_key.size(2) > 1:
                input_buffer[key] = input_buffer_key[:, :, :-1, :]
            else:
                typed_empty_dict: Dict[str, Optional[Tensor]] = {}
                input_buffer = typed_empty_dict
                break
        assert incremental_state is not None
        self.self_attn._set_input_buffer(incremental_state, input_buffer)
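
    # Standalone sketch (plain torch, hypothetical sizes) of the pruning
    # above. The cached tensors are (batch, heads, time, head_dim), so
    # slicing off the last time step "un-reads" the newest decoder step:
    #
    #     import torch
    #     prev_key = torch.randn(2, 4, 5, 64)   # 5 cached steps
    #     pruned = prev_key[:, :, :-1, :]       # -> (2, 4, 4, 64)
    #
    # A simultaneous decoder can use this, e.g., to recompute its latest
    # step after deciding to read more source input instead.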

    def forward(
        self,
        x,
        encoder_out: Optional[Tensor] = None,
        encoder_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[Tensor]] = None,
        prev_attn_state: Optional[List[Tensor]] = None,
        self_attn_mask: Optional[Tensor] = None,
        self_attn_padding_mask: Optional[Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)

        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            # Cross-self-attention: let the target-side self-attention also
            # attend over the encoder states by prepending them to the
            # key/value sequence, growing the masks to match.
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x
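        # Shape sketch (hypothetical sizes): with src_len=7, tgt_len=3,
        # batch=2, prepending encoder_out gives y of shape (10, 2, embed_dim),
        # and self_attn_mask grows from (3, 3) to (3, 10), with zeros over the
        # source block so source positions stay visible to every target step.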
        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)
        assert self.encoder_attn is not None
        residual = x
        if self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
        if prev_attn_state is not None:
            prev_key, prev_value = prev_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_attn_state[2]
            assert incremental_state is not None
            self.encoder_attn._set_input_buffer(incremental_state, saved_state)

        # Encoder-decoder attention; self.encoder_attn is the monotonic
        # attention module built in __init__.
        x, attn = self.encoder_attn(
            query=x,
            key=encoder_out,
            value=encoder_out,
            key_padding_mask=encoder_padding_mask,
            incremental_state=incremental_state,
            static_kv=True,
            need_weights=need_attn or (not self.training and self.need_attn),
            need_head_weights=need_head_weights,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None
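

# Hedged usage sketch (hypothetical, kept in comments so it is not executed
# on import): within a simultaneous-translation model these layers are built
# from the parsed training args, e.g.
#
#     # layer = TransformerMonotonicDecoderLayer(args)  # args.simul_type set
#     # x, attn, _ = layer(
#     #     x, encoder_out, encoder_padding_mask,
#     #     incremental_state=incremental_state,
#     # )
#
# where x is (tgt_len, batch, embed_dim) and encoder_out is
# (src_len, batch, embed_dim), matching the docstring above.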