"""Transducer speech recognition model (pytorch)."""

from argparse import Namespace
from collections import Counter
from dataclasses import asdict
import logging
import math
import numpy

import chainer
import torch

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.nets_utils import get_subsample
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.transducer.arguments import (
    add_encoder_general_arguments,
    add_rnn_encoder_arguments,
    add_custom_encoder_arguments,
    add_decoder_general_arguments,
    add_rnn_decoder_arguments,
    add_custom_decoder_arguments,
    add_custom_training_arguments,
    add_transducer_arguments,
    add_auxiliary_task_arguments,
)
from espnet.nets.pytorch_backend.transducer.auxiliary_task import AuxiliaryTask
from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.custom_encoder import CustomEncoder
from espnet.nets.pytorch_backend.transducer.error_calculator import ErrorCalculator
from espnet.nets.pytorch_backend.transducer.initializer import initializer
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transducer.loss import TransLoss
from espnet.nets.pytorch_backend.transducer.rnn_decoder import DecoderRNNT
from espnet.nets.pytorch_backend.transducer.rnn_encoder import encoder_for
from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs
from espnet.nets.pytorch_backend.transducer.utils import valid_aux_task_layer_list
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention,
    RelPositionMultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import (
    LabelSmoothingLoss,
)
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.utils.fill_missing_args import fill_missing_args


class Reporter(chainer.Chain):
    """A chainer reporter wrapper for transducer models."""

    def report(
        self,
        loss,
        loss_trans,
        loss_ctc,
        loss_lm,
        loss_aux_trans,
        loss_aux_symm_kl,
        cer,
        wer,
    ):
        """Report loss values and error rates for the current iteration."""
        chainer.reporter.report({"loss": loss}, self)
        chainer.reporter.report({"loss_trans": loss_trans}, self)
        chainer.reporter.report({"loss_ctc": loss_ctc}, self)
        chainer.reporter.report({"loss_lm": loss_lm}, self)
        chainer.reporter.report({"loss_aux_trans": loss_aux_trans}, self)
        chainer.reporter.report({"loss_aux_symm_kl": loss_aux_symm_kl}, self)
        chainer.reporter.report({"cer": cer}, self)
        chainer.reporter.report({"wer": wer}, self)

        logging.info("loss:" + str(loss))


class E2E(ASRInterface, torch.nn.Module):
    """E2E module for transducer models.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options
        ignore_id (int): padding symbol id
        blank_id (int): blank symbol id

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments for transducer model."""
        E2E.encoder_add_general_arguments(parser)
        E2E.encoder_add_rnn_arguments(parser)
        E2E.encoder_add_custom_arguments(parser)

        E2E.decoder_add_general_arguments(parser)
        E2E.decoder_add_rnn_arguments(parser)
        E2E.decoder_add_custom_arguments(parser)

        E2E.training_add_custom_arguments(parser)
        E2E.transducer_add_arguments(parser)
        E2E.auxiliary_task_add_arguments(parser)

        return parser
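
    # Illustrative usage sketch (an assumption, not part of this module): the
    # static helpers above and below register all transducer options on an
    # argparse parser, after which the parsed Namespace can build the model.
    #
    #     import argparse
    #
    #     parser = argparse.ArgumentParser()
    #     E2E.add_arguments(parser)
    #     args = parser.parse_args([])  # defaults; real configs come from YAML
    #     model = E2E(idim=83, odim=52, args=args)  # dims are example values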

    @staticmethod
    def encoder_add_general_arguments(parser):
        """Add general arguments for encoder."""
        group = parser.add_argument_group("Encoder general arguments")
        group = add_encoder_general_arguments(group)

        return parser

    @staticmethod
    def encoder_add_rnn_arguments(parser):
        """Add arguments for RNN encoder."""
        group = parser.add_argument_group("RNN encoder arguments")
        group = add_rnn_encoder_arguments(group)

        return parser

    @staticmethod
    def encoder_add_custom_arguments(parser):
        """Add arguments for Custom encoder."""
        group = parser.add_argument_group("Custom encoder arguments")
        group = add_custom_encoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_general_arguments(parser):
        """Add general arguments for decoder."""
        group = parser.add_argument_group("Decoder general arguments")
        group = add_decoder_general_arguments(group)

        return parser

    @staticmethod
    def decoder_add_rnn_arguments(parser):
        """Add arguments for RNN decoder."""
        group = parser.add_argument_group("RNN decoder arguments")
        group = add_rnn_decoder_arguments(group)

        return parser

    @staticmethod
    def decoder_add_custom_arguments(parser):
        """Add arguments for Custom decoder."""
        group = parser.add_argument_group("Custom decoder arguments")
        group = add_custom_decoder_arguments(group)

        return parser

    @staticmethod
    def training_add_custom_arguments(parser):
        """Add arguments for Custom architecture training."""
        group = parser.add_argument_group("Training arguments for custom architecture")
        group = add_custom_training_arguments(group)

        return parser

    @staticmethod
    def transducer_add_arguments(parser):
        """Add arguments for transducer model."""
        group = parser.add_argument_group("Transducer model arguments")
        group = add_transducer_arguments(group)

        return parser

    @staticmethod
    def auxiliary_task_add_arguments(parser):
        """Add arguments for auxiliary task."""
        group = parser.add_argument_group("Auxiliary task arguments")
        group = add_auxiliary_task_arguments(group)

        return parser

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        return PlotAttentionReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
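        # The total factor combines the front-end convolutional subsampling
        # with per-layer frame skipping: e.g., conv_subsampling_factor=4 and
        # subsample=[1, 2, 1] give 4 * 2 = 8 (illustrative numbers, not taken
        # from a particular config).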
        if self.etype == "custom":
            return self.encoder.conv_subsampling_factor * int(
                numpy.prod(self.subsample)
            )
        else:
            return self.enc.conv_subsampling_factor * int(numpy.prod(self.subsample))

    def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True):
        """Construct an E2E object for transducer model."""
        torch.nn.Module.__init__(self)

        args = fill_missing_args(args, self.add_arguments)

        self.is_rnnt = True
        self.transducer_weight = args.transducer_weight

        self.use_aux_task = bool(args.aux_task_type is not None and training)

        self.use_aux_ctc = args.aux_ctc and training
        self.aux_ctc_weight = args.aux_ctc_weight

        self.use_aux_cross_entropy = args.aux_cross_entropy and training
        self.aux_cross_entropy_weight = args.aux_cross_entropy_weight

        if self.use_aux_task:
            n_layers = (
                (len(args.enc_block_arch) * args.enc_block_repeat - 1)
                if args.enc_block_arch is not None
                else (args.elayers - 1)
            )

            aux_task_layer_list = valid_aux_task_layer_list(
                args.aux_task_layer_list,
                n_layers,
            )
        else:
            aux_task_layer_list = []
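
        # Encoder: either a custom block architecture (blocks defined via
        # --enc-block-arch) or the classic ESPnet RNN encoder.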
        if "custom" in args.etype:
            if args.enc_block_arch is None:
                raise ValueError(
                    "When specifying custom encoder type, --enc-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.subsample = get_subsample(args, mode="asr", arch="transformer")

            self.encoder = CustomEncoder(
                idim,
                args.enc_block_arch,
                input_layer=args.custom_enc_input_layer,
                repeat_block=args.enc_block_repeat,
                self_attn_type=args.custom_enc_self_attn_type,
                positional_encoding_type=args.custom_enc_positional_encoding_type,
                positionwise_activation_type=args.custom_enc_pw_activation_type,
                conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = self.encoder.enc_out

            self.most_dom_list = args.enc_block_arch[:]
        else:
            self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

            self.enc = encoder_for(
                args,
                idim,
                self.subsample,
                aux_task_layer_list=aux_task_layer_list,
            )
            encoder_out = args.eprojs

        if "custom" in args.dtype:
            if args.dec_block_arch is None:
                raise ValueError(
                    "When specifying custom decoder type, --dec-block-arch "
                    "should also be specified in training config. See "
                    "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
                )

            self.decoder = CustomDecoder(
                odim,
                args.dec_block_arch,
                input_layer=args.custom_dec_input_layer,
                repeat_block=args.dec_block_repeat,
                positionwise_activation_type=args.custom_dec_pw_activation_type,
                dropout_rate_embed=args.dropout_rate_embed_decoder,
            )
            decoder_out = self.decoder.dunits

            if "custom" in args.etype:
                self.most_dom_list += args.dec_block_arch[:]
            else:
                self.most_dom_list = args.dec_block_arch[:]
        else:
            self.dec = DecoderRNNT(
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
            decoder_out = args.dunits

        self.joint_network = JointNetwork(
            odim, encoder_out, decoder_out, args.joint_dim, args.joint_activation_type
        )
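
        # Most common hidden dimension among the custom encoder/decoder
        # blocks, with ties resolved toward the larger value.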
        if hasattr(self, "most_dom_list"):
            self.most_dom_dim = sorted(
                Counter(
                    d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d
                ).most_common(),
                key=lambda x: x[0],
                reverse=True,
            )[0][0]

        self.etype = args.etype
        self.dtype = args.dtype

        self.sos = odim - 1
        self.eos = odim - 1
        self.blank_id = blank_id
        self.ignore_id = ignore_id

        self.space = args.sym_space
        self.blank = args.sym_blank

        self.odim = odim

        self.reporter = Reporter()

        self.error_calculator = None

        self.default_parameters(args)

        if training:
            self.criterion = TransLoss(args.trans_type, self.blank_id)

            decoder = self.decoder if self.dtype == "custom" else self.dec

            if args.report_cer or args.report_wer:
                self.error_calculator = ErrorCalculator(
                    decoder,
                    self.joint_network,
                    args.char_list,
                    args.sym_space,
                    args.sym_blank,
                    args.report_cer,
                    args.report_wer,
                )

            if self.use_aux_task:
                self.auxiliary_task = AuxiliaryTask(
                    decoder,
                    self.joint_network,
                    self.criterion,
                    args.aux_task_type,
                    args.aux_task_weight,
                    encoder_out,
                    args.joint_dim,
                )

            if self.use_aux_ctc:
                self.aux_ctc = ctc_for(
                    Namespace(
                        num_encs=1,
                        eprojs=encoder_out,
                        dropout_rate=args.aux_ctc_dropout_rate,
                        ctc_type="warpctc",
                    ),
                    odim,
                )

            if self.use_aux_cross_entropy:
                self.aux_decoder_output = torch.nn.Linear(decoder_out, odim)

                self.aux_cross_entropy = LabelSmoothingLoss(
                    odim, ignore_id, args.aux_cross_entropy_smoothing
                )

        self.loss = None
        self.rnnlm = None

    def default_parameters(self, args):
        """Initialize/reset parameters for transducer.

        Args:
            args (Namespace): argument Namespace containing options

        """
        initializer(self, args)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        Args:
            xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)

        Returns:
            loss (torch.Tensor): transducer loss value

        """
        xs_pad = xs_pad[:, : max(ilens)]

        if "custom" in self.etype:
            src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)

            _hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        else:
            _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens)
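
        # When an auxiliary task is enabled, the encoder also returns the
        # intermediate hidden states selected by aux_task_layer_list.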
        if self.use_aux_task:
            hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1]
        else:
            hs_pad, aux_hs_pad = _hs_pad, None

        ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs(
            ys_pad, hs_mask
        )

        if "custom" in self.dtype:
            ys_mask = target_mask(ys_in_pad, self.blank_id)
            pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad)
        else:
            pred_pad = self.dec(hs_pad, ys_in_pad)
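
        # Joint network: encoder frames (B, T, 1, D_enc) are broadcast against
        # prediction network states (B, 1, U + 1, D_dec), yielding the output
        # lattice z of shape (B, T, U + 1, odim).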
        z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1))

        loss_trans = self.criterion(z, target, pred_len, target_len)

        if self.use_aux_task and aux_hs_pad is not None:
            loss_aux_trans, loss_aux_symm_kl = self.auxiliary_task(
                aux_hs_pad, pred_pad, z, target, pred_len, target_len
            )
        else:
            loss_aux_trans, loss_aux_symm_kl = 0.0, 0.0

        if self.use_aux_ctc:
            if "custom" in self.etype:
                hs_mask = torch.IntTensor(
                    [h.size(1) for h in hs_mask],
                ).to(hs_mask.device)

            loss_ctc = self.aux_ctc_weight * self.aux_ctc(hs_pad, hs_mask, ys_pad)
        else:
            loss_ctc = 0.0

        if self.use_aux_cross_entropy:
            loss_lm = self.aux_cross_entropy_weight * self.aux_cross_entropy(
                self.aux_decoder_output(pred_pad), ys_out_pad
            )
        else:
            loss_lm = 0.0
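
        # Total objective: main transducer loss, plus the auxiliary transducer
        # and symmetric KL terms scaled by the transducer weight, plus the CTC
        # and label-smoothing terms (already weighted above).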
        loss = (
            loss_trans
            + self.transducer_weight * (loss_aux_trans + loss_aux_symm_kl)
            + loss_ctc
            + loss_lm
        )

        self.loss = loss
        loss_data = float(loss)

        if self.training or self.error_calculator is None:
            cer, wer = None, None
        else:
            cer, wer = self.error_calculator(hs_pad, ys_pad)

        if not math.isnan(loss_data):
            self.reporter.report(
                loss_data,
                float(loss_trans),
                float(loss_ctc),
                float(loss_lm),
                float(loss_aux_trans),
                float(loss_aux_symm_kl),
                cer,
                wer,
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)

        return self.loss

    def encode_custom(self, x):
        """Encode acoustic features.

        Args:
            x (ndarray): input acoustic feature (T, D)

        Returns:
            x (torch.Tensor): encoded features (T, D_enc)

        """
        x = torch.as_tensor(x).unsqueeze(0)
        enc_output, _ = self.encoder(x, None)

        return enc_output.squeeze(0)

    def encode_rnn(self, x):
        """Encode acoustic features.

        Args:
            x (ndarray): input acoustic feature (T, D)

        Returns:
            x (torch.Tensor): encoded features (T, D_enc)

        """
        p = next(self.parameters())

        ilens = [x.shape[0]]
        x = x[:: self.subsample[0], :]

        h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
        hs = h.contiguous().unsqueeze(0)

        hs, _, _ = self.enc(hs, ilens)

        return hs.squeeze(0)

    def recognize(self, x, beam_search):
        """Recognize input features.

        Args:
            x (ndarray): input acoustic feature (T, D)
            beam_search (class): beam search class

        Returns:
            nbest_hyps (list): n-best decoding results

        """
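        # Illustrative call pattern (a sketch; the BeamSearchTransducer
        # arguments shown, besides the model's own modules, are assumptions):
        #
        #     from espnet.nets.beam_search_transducer import BeamSearchTransducer
        #
        #     beam_search = BeamSearchTransducer(
        #         decoder=model.decoder if "custom" in model.dtype else model.dec,
        #         joint_network=model.joint_network,
        #         beam_size=5,
        #     )
        #     nbest_hyps = model.recognize(feats, beam_search)  # feats: (T, D)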

        self.eval()

        if "custom" in self.etype:
            h = self.encode_custom(x)
        else:
            h = self.encode_rnn(x)

        nbest_hyps = beam_search(h)

        return [asdict(n) for n in nbest_hyps]

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        Args:
            xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim)
            ilens (torch.Tensor): batch of lengths of input sequences (B)
            ys_pad (torch.Tensor):
                batch of padded character id sequence tensor (B, Lmax)

        Returns:
            ret (ndarray): attention weights with the following shape,
                1) multi-head case => attention weights (B, H, Lmax, Tmax),
                2) other case => attention weights (B, Lmax, Tmax).

        """
        self.eval()

        if "custom" not in self.etype and "custom" not in self.dtype:
            return []
        else:
            with torch.no_grad():
                self.forward(xs_pad, ilens, ys_pad)

                ret = dict()
                for name, m in self.named_modules():
                    if isinstance(
                        m, (MultiHeadedAttention, RelPositionMultiHeadedAttention)
                    ):
                        ret[name] = m.attn.cpu().numpy()

        self.train()

        return ret