"""Cutom encoder definition for transducer models."""
import torch
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class CustomEncoder(torch.nn.Module):
"""Custom encoder module for transducer models.
Args:
idim (int): input dim
enc_arch (list): list of encoder blocks (type and parameters)
input_layer (str): input layer type
repeat_block (int): repeat provided block N times if N > 1
self_attn_type (str): type of self-attention
positional_encoding_type (str): positional encoding type
positionwise_layer_type (str): linear
positionwise_activation_type (str): positionwise activation type
conv_mod_activation_type (str): convolutional module activation type
normalize_before (bool): whether to use layer_norm before the first block
aux_task_layer_list (list): list of layer ids for intermediate output
padding_idx (int): padding_idx for embedding input layer (if specified)
"""
def __init__(
self,
idim,
enc_arch,
input_layer="linear",
repeat_block=0,
self_attn_type="selfattn",
positional_encoding_type="abs_pos",
positionwise_layer_type="linear",
positionwise_activation_type="relu",
conv_mod_activation_type="relu",
normalize_before=True,
aux_task_layer_list=[],
padding_idx=-1,
):
"""Construct an CustomEncoder object."""
super().__init__()
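        # build_blocks assembles the input layer (embed), the stack of encoder
        # blocks, the encoder output dimension and the subsampling factor
        # introduced by the input layer.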
        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            padding_idx=padding_idx,
        )
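        # With pre-norm blocks, a final LayerNorm is applied to the encoder
        # output (and to any auxiliary intermediate outputs).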
        self.normalize_before = normalize_before

        if self.normalize_before:
            self.after_norm = LayerNorm(self.enc_out)

        self.n_blocks = (
            len(enc_arch) * repeat_block if repeat_block > 1 else len(enc_arch)
        )
        self.aux_task_layer_list = aux_task_layer_list

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): input tensor
            masks (torch.Tensor): input mask

        Returns:
            xs (torch.Tensor or tuple):
                position embedded output or
                (position embedded output, auxiliary outputs)
            mask (torch.Tensor): position embedded mask

        """
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)
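        # If intermediate outputs were requested for auxiliary tasks, run the
        # blocks one by one and collect the output of every listed layer.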
        if self.aux_task_layer_list:
            aux_xs_list = []

            for b in range(self.n_blocks):
                xs, masks = self.encoders[b](xs, masks)

                if b in self.aux_task_layer_list:
                    if isinstance(xs, tuple):
                        aux_xs = xs[0]
                    else:
                        aux_xs = xs

                    if self.normalize_before:
                        aux_xs_list.append(self.after_norm(aux_xs))
                    else:
                        aux_xs_list.append(aux_xs)
        else:
            xs, masks = self.encoders(xs, masks)
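        # Blocks with positional encoding return (output, pos_emb) tuples;
        # keep only the output before the final normalization.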
        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.aux_task_layer_list:
            return (xs, aux_xs_list), masks

        return xs, masks
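
# Example usage (illustrative sketch, not part of the module): the exact keys
# expected in each ``enc_arch`` entry are defined by ``build_blocks`` and may
# differ across ESPnet versions, so the block configuration below is an
# assumption.
#
#     enc_arch = [
#         {
#             "type": "transformer",
#             "d_hidden": 256,
#             "d_ff": 1024,
#             "heads": 4,
#             "dropout-rate": 0.1,
#             "att-dropout-rate": 0.1,
#         }
#     ]
#     encoder = CustomEncoder(80, enc_arch, input_layer="conv2d", repeat_block=4)
#
#     xs = torch.randn(2, 200, 80)                    # (batch, time, idim)
#     masks = torch.ones(2, 1, 200, dtype=torch.bool)
#     xs_out, masks_out = encoder(xs, masks)          # time axis reduced by conv_subsampling_factor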