|
"""Cutom encoder definition for transducer models.""" |
|
|
|
import torch |
|
|
|
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks |
|
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L |
|
|
|
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm |
|
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling |
|
|
|
|
|
class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim (int): input dim
        enc_arch (list): list of encoder blocks (type and parameters)
        input_layer (str): input layer type
        repeat_block (int): repeat provided block N times if N > 1
        self_attn_type (str): type of self-attention
        positional_encoding_type (str): positional encoding type
        positionwise_layer_type (str): linear
        positionwise_activation_type (str): positionwise activation type
        conv_mod_activation_type (str): convolutional module activation type
        normalize_before (bool): whether to use layer_norm before the first block
        aux_task_layer_list (list): list of layer ids for intermediate output
            (default: None, meaning no auxiliary outputs)
        padding_idx (int): padding_idx for embedding input layer (if specified)

    """

    def __init__(
        self,
        idim,
        enc_arch,
        input_layer="linear",
        repeat_block=0,
        self_attn_type="selfattn",
        positional_encoding_type="abs_pos",
        positionwise_layer_type="linear",
        positionwise_activation_type="relu",
        conv_mod_activation_type="relu",
        normalize_before=True,
        aux_task_layer_list=None,
        padding_idx=-1,
    ):
        """Construct a CustomEncoder object."""
        super().__init__()

        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            padding_idx=padding_idx,
        )

        self.normalize_before = normalize_before

        if self.normalize_before:
            self.after_norm = LayerNorm(self.enc_out)

        # The architecture is repeated only "if N > 1" (see class docstring),
        # so for repeat_block in {0, 1} the encoder holds len(enc_arch) blocks.
        # The previous `len(enc_arch) * repeat_block` evaluated to 0 for the
        # default repeat_block=0, making the auxiliary-output path in forward()
        # iterate zero times and silently skip every encoder layer.
        if repeat_block > 1:
            self.n_blocks = len(enc_arch) * repeat_block
        else:
            self.n_blocks = len(enc_arch)

        # Use a None sentinel instead of a mutable [] default: a shared list
        # default would be aliased across every CustomEncoder instance.
        self.aux_task_layer_list = (
            aux_task_layer_list if aux_task_layer_list is not None else []
        )

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): input tensor
            masks (torch.Tensor): input mask

        Returns:
            xs (torch.Tensor or tuple):
                position embedded output or
                (position embedded output, auxiliary outputs)
            mask (torch.Tensor): position embedded mask

        """
        # Convolutional front-ends also subsample the mask; other input
        # layers (e.g. linear/embed) transform the features only.
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.aux_task_layer_list:
            aux_xs_list = []

            # Run blocks one at a time so intermediate outputs can be
            # captured at the requested layer ids.
            for b in range(self.n_blocks):
                xs, masks = self.encoders[b](xs, masks)

                if b in self.aux_task_layer_list:
                    # Blocks with positional encoding return (output, pos_emb);
                    # keep only the output for the auxiliary branch.
                    if isinstance(xs, tuple):
                        aux_xs = xs[0]
                    else:
                        aux_xs = xs

                    if self.normalize_before:
                        aux_xs_list.append(self.after_norm(aux_xs))
                    else:
                        aux_xs_list.append(aux_xs)
        else:
            xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.aux_task_layer_list:
            return (xs, aux_xs_list), masks

        return xs, masks
|
|