tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
17.2 kB
"""RNN encoder implementation for transducer-based models.
These classes are based on the ones in espnet.nets.pytorch_backend.rnn.encoders,
and modified to output intermediate layer representations based on a list of
layers given as input. These additional outputs are intended to be used with
auxiliary tasks.
Note that, here, the RNN class relies on a stack of 1-layer LSTMs instead
of a multi-layer LSTM for that purpose.
"""
import argparse
import logging
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.nets_utils import to_device
class RNNP(torch.nn.Module):
    """RNN with projection layer module.

    The encoder is a stack of ``elayers`` single-layer LSTM/GRU modules, each
    one followed by a linear projection to ``hdim``. The projected output of
    any layer listed in ``aux_task_layer_list`` is additionally returned.

    Args:
        idim: Dimension of inputs
        elayers: Number of encoder layers
        cdim: Number of units (results in cdim * 2 if bidirectional)
        hdim: Number of projection units
        subsample: List of subsampling numbers; entry ``layer + 1`` is the
            time subsampling factor applied to the output of layer ``layer``
        dropout: Dropout rate
        typ: RNN type
        aux_task_layer_list: List of layer ids for intermediate output
            (None is treated as an empty list)

    """

    def __init__(
        self,
        idim: int,
        elayers: int,
        cdim: int,
        hdim: int,
        subsample: np.ndarray,
        dropout: float,
        typ: str = "blstm",
        aux_task_layer_list: Optional[List] = None,
    ):
        """Initialize RNNP module."""
        super(RNNP, self).__init__()

        bidir = typ[0] == "b"
        rnn_class = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
        rnn_prefix = "birnn" if bidir else "rnn"
        rnn_odim = 2 * cdim if bidir else cdim

        for i in range(elayers):
            # Only the first layer sees the raw features; every following
            # layer consumes the projected output of the previous one.
            inputdim = idim if i == 0 else hdim

            rnn = rnn_class(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            setattr(self, "%s%d" % (rnn_prefix, i), rnn)
            setattr(self, "bt%d" % i, torch.nn.Linear(rnn_odim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir
        self.dropout = dropout

        # A fresh list per instance avoids sharing one mutable default object.
        self.aux_task_layer_list = (
            [] if aux_task_layer_list is None else aux_task_layer_list
        )

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_state: Optional[List] = None,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
        torch.Tensor,
        List,
    ]:
        """RNNP forward.

        Args:
            xs_pad: Batch of padded input sequences (B, Tmax, idim)
            ilens: Batch of lengths of input sequences (B)
            prev_state: Previous RNN states, indexed by layer

        Returns:
            : Batch of padded output sequences (B, Tmax, hdim)
                or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)])
            : Batch of lengths of output sequences (B)
            : List with the RNN state returned by each layer

        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)

        aux_xs_list = []
        elayer_states = []

        for layer in range(self.elayers):
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)

            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()

            # Reset the backward direction of this layer's own state only:
            # a per-layer state is a tensor (GRU) or a (h, c) tuple (LSTM),
            # whereas the full per-layer list cannot be zeroed in one call.
            layer_state = None if prev_state is None else prev_state[layer]
            if layer_state is not None and rnn.bidirectional:
                layer_state = reset_backward_rnn_state(layer_state)

            ys, states = rnn(xs_pack, hx=layer_state)
            elayer_states.append(states)

            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)

            sub = self.subsample[layer + 1]
            if sub > 1:
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])

            projection_layer = getattr(self, "bt%d" % layer)
            projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)

            # Auxiliary outputs are taken before dropout/tanh are applied.
            if layer in self.aux_task_layer_list:
                aux_xs_list.append(xs_pad)

            if layer < self.elayers - 1:
                # F.dropout defaults to training=True; tie it to the module
                # mode so that evaluation is deterministic.
                xs_pad = torch.tanh(
                    F.dropout(xs_pad, p=self.dropout, training=self.training)
                )

        if aux_xs_list:
            return (xs_pad, aux_xs_list), ilens, elayer_states

        return xs_pad, ilens, elayer_states
class RNN(torch.nn.Module):
    """RNN module.

    The encoder is a stack of ``elayers`` single-layer LSTM/GRU modules. For
    bidirectional types the forward and backward outputs are summed, so every
    layer outputs ``cdim`` features. A single linear layer followed by tanh
    projects the last layer (and any auxiliary layer) to ``hdim``.

    Args:
        idim: Dimension of inputs
        elayers: Number of encoder layers
        cdim: Number of rnn units
        hdim: Number of final projection units
        dropout: Dropout rate
        typ: The RNN type
        aux_task_layer_list: List of layer ids for intermediate output
            (None is treated as an empty list)

    """

    def __init__(
        self,
        idim: int,
        elayers: int,
        cdim: int,
        hdim: int,
        dropout: float,
        typ: str = "blstm",
        aux_task_layer_list: Optional[List] = None,
    ):
        """Initialize RNN module."""
        super(RNN, self).__init__()

        bidir = typ[0] == "b"
        layer_type = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
        rnn_prefix = "birnn" if bidir else "rnn"

        for i in range(elayers):
            # Layers after the first one read the cdim-sized output of the
            # previous layer (directions are summed in forward()).
            inputdim = idim if i == 0 else cdim

            rnn = layer_type(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            setattr(self, "%s%d" % (rnn_prefix, i), rnn)

        self.dropout = torch.nn.Dropout(p=dropout)

        self.elayers = elayers
        self.cdim = cdim
        self.hdim = hdim
        self.typ = typ
        self.bidir = bidir

        self.l_last = torch.nn.Linear(cdim, hdim)

        # A fresh list per instance avoids sharing one mutable default object.
        self.aux_task_layer_list = (
            [] if aux_task_layer_list is None else aux_task_layer_list
        )

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_state: Optional[List] = None,
    ) -> Tuple[
        Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]],
        torch.Tensor,
        List,
    ]:
        """RNN forward.

        Args:
            xs_pad: Batch of padded input sequences (B, Tmax, idim)
            ilens: Batch of lengths of input sequences (B)
            prev_state: Previous RNN states, indexed by layer

        Returns:
            : Batch of padded output sequences (B, Tmax, hdim)
                or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)])
            : Batch of lengths of output sequences (B)
            : List with the RNN state returned by each layer

        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)

        aux_xs_list = []
        elayer_states = []

        for layer in range(self.elayers):
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)

            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            rnn.flatten_parameters()

            # Reset the backward direction of this layer's own state only:
            # a per-layer state is a tensor (GRU) or a (h, c) tuple (LSTM),
            # whereas the full per-layer list cannot be zeroed in one call.
            layer_state = None if prev_state is None else prev_state[layer]
            if layer_state is not None and rnn.bidirectional:
                layer_state = reset_backward_rnn_state(layer_state)

            xs, states = rnn(xs_pack, hx=layer_state)
            elayer_states.append(states)

            xs_pad, ilens = pad_packed_sequence(xs, batch_first=True)

            if self.bidir:
                # Merge both directions by summation to keep cdim features.
                xs_pad = xs_pad[:, :, : self.cdim] + xs_pad[:, :, self.cdim :]

            if layer in self.aux_task_layer_list:
                # Auxiliary outputs share the final projection layer.
                aux_projected = torch.tanh(
                    self.l_last(xs_pad.contiguous().view(-1, xs_pad.size(2)))
                )
                aux_xs_pad = aux_projected.view(xs_pad.size(0), xs_pad.size(1), -1)
                aux_xs_list.append(aux_xs_pad)

            if layer < self.elayers - 1:
                xs_pad = self.dropout(xs_pad)

        projected = torch.tanh(
            self.l_last(xs_pad.contiguous().view(-1, xs_pad.size(2)))
        )
        xs_pad = projected.view(xs_pad.size(0), xs_pad.size(1), -1)

        if aux_xs_list:
            return (xs_pad, aux_xs_list), ilens, elayer_states

        return xs_pad, ilens, elayer_states
def reset_backward_rnn_state(
    states: Union[torch.Tensor, Tuple, List]
) -> Union[torch.Tensor, Tuple, List]:
    """Set backward BRNN states to zeroes.

    Every odd index along the first dimension of each state tensor is zeroed
    in place. Containers are traversed recursively, so a single tensor, a
    (h, c) tuple or a per-layer list of tensors / (h, c) tuples are all
    supported. ``None`` entries are left untouched.

    Args:
        states: RNN states

    Returns:
        states: RNN states with backward set to zeroes (same object as input)

    """
    if states is None:
        return states

    if isinstance(states, (list, tuple)):
        # Recurse instead of slice-assigning: an element may itself be a
        # (h, c) tuple, which does not support item assignment.
        for state in states:
            reset_backward_rnn_state(state)
    else:
        states[1::2] = 0.0

    return states
class VGG2L(torch.nn.Module):
    """VGG-like module.

    Two blocks of two 3x3 convolutions (64 then 128 channels), each block
    followed by a 2x2 max-pooling with ceil mode.

    Args:
        in_channel: number of input channels

    """

    def __init__(self, in_channel: int = 1):
        """Initialize VGG-like module."""
        super().__init__()

        conv_opts = {"stride": 1, "padding": 1}

        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, **conv_opts)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, **conv_opts)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, **conv_opts)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, **conv_opts)

        self.in_channel = in_channel

    def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor, **kwargs):
        """VGG2L forward.

        Args:
            xs_pad: Batch of padded input sequences (B, Tmax, D)
            ilens: Batch of lengths of input sequences (B)

        Returns:
            : Batch of padded output sequences (B, Tmax // 4, 128 * D // 4)
            : Batch of lengths of output sequences (B), as a list
            : None (this module keeps no state)

        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        n_batch, n_frames, n_feats = xs_pad.size(0), xs_pad.size(1), xs_pad.size(2)

        # (B, Tmax, D) -> (B, in_channel, Tmax, D // in_channel)
        hs = xs_pad.view(
            n_batch, n_frames, self.in_channel, n_feats // self.in_channel
        ).transpose(1, 2)

        conv_blocks = (
            (self.conv1_1, self.conv1_2),
            (self.conv2_1, self.conv2_2),
        )
        for block in conv_blocks:
            for conv in block:
                hs = F.relu(conv(hs))
            hs = F.max_pool2d(hs, 2, stride=2, ceil_mode=True)

        # Lengths are halved (rounding up) once per pooling layer.
        if torch.is_tensor(ilens):
            olens = ilens.cpu().numpy()
        else:
            olens = np.array(ilens, dtype=np.float32)
        olens = np.array(np.ceil(olens / 2), dtype=np.int64)
        olens = np.array(
            np.ceil(np.array(olens, dtype=np.float32) / 2), dtype=np.int64
        )

        # (B, C, T', D') -> (B, T', C * D')
        hs = hs.transpose(1, 2)
        hs = hs.contiguous().view(hs.size(0), hs.size(1), hs.size(2) * hs.size(3))

        return hs, olens.tolist(), None
class Encoder(torch.nn.Module):
    """Encoder module.

    Builds an optional VGG2L front-end followed by either an RNNP (etype
    ending with "p") or an RNN module, and stores them in ``self.enc``.

    Args:
        etype: Type of encoder network
            (optional "vgg" prefix + lstm/gru/blstm/bgru + optional "p" suffix)
        idim: Number of dimensions of encoder network
        elayers: Number of layers of encoder network
        eunits: Number of RNN units of encoder network
        eprojs: Number of projection units of encoder network
        subsample: List of subsampling numbers
        dropout: Dropout rate
        in_channel: Number of input channels
        aux_task_layer_list: List of layer ids for intermediate output
            (None is treated as an empty list)

    """

    def __init__(
        self,
        etype: str,
        idim: int,
        elayers: int,
        eunits: int,
        eprojs: int,
        subsample: np.ndarray,
        dropout: float,
        in_channel: int = 1,
        aux_task_layer_list: Optional[List] = None,
    ):
        """Initialize Encoder module."""
        super(Encoder, self).__init__()

        # Sub-modules always receive a real list, never None.
        if aux_task_layer_list is None:
            aux_task_layer_list = []

        use_vgg = etype.startswith("vgg")
        use_projection = etype.endswith("p")

        # Remove the exact "vgg" prefix / "p" suffix. str.lstrip("vgg") would
        # strip any leading 'v'/'g' characters and turn "gru" into "ru".
        typ = etype[len("vgg") :] if use_vgg else etype
        if use_projection:
            typ = typ[:-1]

        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error("Error: need to specify an appropriate encoder architecture")

        modules = []
        if use_vgg:
            # Built before the RNN so parameter initialization order is kept.
            modules.append(VGG2L(in_channel))
            rnn_idim = get_vgg2l_odim(idim, in_channel=in_channel)
        else:
            rnn_idim = idim

        if use_projection:
            modules.append(
                RNNP(
                    rnn_idim,
                    elayers,
                    eunits,
                    eprojs,
                    subsample,
                    dropout,
                    typ=typ,
                    aux_task_layer_list=aux_task_layer_list,
                )
            )
        else:
            modules.append(
                RNN(
                    rnn_idim,
                    elayers,
                    eunits,
                    eprojs,
                    dropout,
                    typ=typ,
                    aux_task_layer_list=aux_task_layer_list,
                )
            )

        self.enc = torch.nn.ModuleList(modules)

        if use_vgg:
            if use_projection:
                logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder")
            else:
                logging.info("Use CNN-VGG + " + typ.upper() + " for encoder")
            # VGG2L applies two 2x2 max-poolings.
            self.conv_subsampling_factor = 4
        else:
            if use_projection:
                logging.info(typ.upper() + " with every-layer projection for encoder")
            else:
                logging.info(typ.upper() + " without projection for encoder")
            self.conv_subsampling_factor = 1

    def forward(self, xs_pad, ilens, prev_states=None):
        """Forward encoder.

        Args:
            xs_pad: Batch of padded input sequences (B, Tmax, idim)
            ilens: Batch of lengths of input sequences (B)
            prev_states: Previous states, one entry per module in ``self.enc``

        Returns:
            : Batch of padded output sequences (B, Tmax, hdim)
                or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)])
            : Batch of lengths of output sequences (B)
            : List with the states returned by each module in ``self.enc``

        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(
                xs_pad,
                ilens,
                prev_state=prev_state,
            )
            current_states.append(states)

        if isinstance(xs_pad, tuple):
            # The last module also returned auxiliary layer outputs.
            final_xs_pad, aux_xs_list = xs_pad[0], xs_pad[1]

            # Zero out padded frames in the main and auxiliary outputs.
            mask = to_device(final_xs_pad, make_pad_mask(ilens).unsqueeze(-1))
            aux_xs_list = [layer.masked_fill(mask, 0.0) for layer in aux_xs_list]

            return (
                (
                    final_xs_pad.masked_fill(mask, 0.0),
                    aux_xs_list,
                ),
                ilens,
                current_states,
            )

        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))

        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
def encoder_for(
    args: argparse.Namespace,
    idim: Union[int, List],
    subsample: np.ndarray,
    aux_task_layer_list: Optional[List] = None,
) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    """Instantiate an encoder module given the program arguments.

    Args:
        args: The model arguments
            (etype, elayers, eunits, eprojs and dropout_rate are read)
        idim: Dimension of inputs
        subsample: Subsample factors
        aux_task_layer_list: List of layer ids for intermediate output
            (None is treated as an empty list)

    Returns:
        : The encoder module (a single Encoder instance)

    """
    # Use a fresh list per call instead of a shared mutable default.
    if aux_task_layer_list is None:
        aux_task_layer_list = []

    return Encoder(
        args.etype,
        idim,
        args.elayers,
        args.eunits,
        args.eprojs,
        subsample,
        args.dropout_rate,
        aux_task_layer_list=aux_task_layer_list,
    )