from distutils.version import LooseVersion import logging import numpy as np import six import torch import torch.nn.functional as F from espnet.nets.pytorch_backend.nets_utils import to_device class CTC(torch.nn.Module): """CTC module :param int odim: dimension of outputs :param int eprojs: number of encoder projection units :param float dropout_rate: dropout rate (0.0 ~ 1.0) :param str ctc_type: builtin or warpctc :param bool reduce: reduce the CTC loss into a scalar """ def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True): super().__init__() self.dropout_rate = dropout_rate self.loss = None self.ctc_lo = torch.nn.Linear(eprojs, odim) self.dropout = torch.nn.Dropout(dropout_rate) self.probs = None # for visualization # In case of Pytorch >= 1.7.0, CTC will be always builtin self.ctc_type = ( ctc_type if LooseVersion(torch.__version__) < LooseVersion("1.7.0") else "builtin" ) if self.ctc_type == "builtin": reduction_type = "sum" if reduce else "none" self.ctc_loss = torch.nn.CTCLoss( reduction=reduction_type, zero_infinity=True ) elif self.ctc_type == "cudnnctc": reduction_type = "sum" if reduce else "none" self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) elif self.ctc_type == "warpctc": import warpctc_pytorch as warp_ctc self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) elif self.ctc_type == "gtnctc": from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction self.ctc_loss = GTNCTCLossFunction.apply else: raise ValueError( 'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type) ) self.ignore_id = -1 self.reduce = reduce def loss_fn(self, th_pred, th_target, th_ilen, th_olen): if self.ctc_type in ["builtin", "cudnnctc"]: th_pred = th_pred.log_softmax(2) # Use the deterministic CuDNN implementation of CTC loss to avoid # [issue#17798](https://github.com/pytorch/pytorch/issues/17798) with torch.backends.cudnn.flags(deterministic=True): loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) # Batch-size average loss = loss / th_pred.size(1) return loss elif self.ctc_type == "warpctc": return self.ctc_loss(th_pred, th_target, th_ilen, th_olen) elif self.ctc_type == "gtnctc": targets = [t.tolist() for t in th_target] log_probs = torch.nn.functional.log_softmax(th_pred, dim=2) return self.ctc_loss(log_probs, targets, th_ilen, 0, "none") else: raise NotImplementedError def forward(self, hs_pad, hlens, ys_pad): """CTC forward :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax) :return: ctc loss value :rtype: torch.Tensor """ # TODO(kan-bayashi): need to make more smart way ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys # zero padding for hs ys_hat = self.ctc_lo(self.dropout(hs_pad)) if self.ctc_type != "gtnctc": ys_hat = ys_hat.transpose(0, 1) if self.ctc_type == "builtin": olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys])) hlens = hlens.long() ys_pad = torch.cat(ys) # without this the code breaks for asr_mix self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens) else: self.loss = None hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32)) olens = torch.from_numpy( np.fromiter((x.size(0) for x in ys), dtype=np.int32) ) # zero padding for ys ys_true = torch.cat(ys).cpu().int() # batch x olen # get ctc loss # expected shape of seqLength x batchSize x alphabet_size dtype = ys_hat.dtype if self.ctc_type == "warpctc" or dtype == torch.float16: # warpctc only supports float32 # torch.ctc does not support float16 (#1751) ys_hat = ys_hat.to(dtype=torch.float32) if self.ctc_type == "cudnnctc": # use GPU when using the cuDNN implementation ys_true = to_device(hs_pad, ys_true) if self.ctc_type == "gtnctc": # keep as list for gtn ys_true = ys self.loss = to_device( hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens) ).to(dtype=dtype) # get length info logging.info( self.__class__.__name__ + " input lengths: " + "".join(str(hlens).split("\n")) ) logging.info( self.__class__.__name__ + " output lengths: " + "".join(str(olens).split("\n")) ) if self.reduce: # NOTE: sum() is needed to keep consistency # since warpctc return as tensor w/ shape (1,) # but builtin return as tensor w/o shape (scalar). self.loss = self.loss.sum() logging.info("ctc loss:" + str(float(self.loss))) return self.loss def softmax(self, hs_pad): """softmax of frame activations :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) :return: log softmax applied 3d tensor (B, Tmax, odim) :rtype: torch.Tensor """ self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2) return self.probs def log_softmax(self, hs_pad): """log_softmax of frame activations :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) :return: log softmax applied 3d tensor (B, Tmax, odim) :rtype: torch.Tensor """ return F.log_softmax(self.ctc_lo(hs_pad), dim=2) def argmax(self, hs_pad): """argmax of frame activations :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) :return: argmax applied 2d tensor (B, Tmax) :rtype: torch.Tensor """ return torch.argmax(self.ctc_lo(hs_pad), dim=2) def forced_align(self, h, y, blank_id=0): """forced alignment. :param torch.Tensor h: hidden state sequence, 2d tensor (T, D) :param torch.Tensor y: id sequence tensor 1d tensor (L) :param int y: blank symbol index :return: best alignment results :rtype: list """ def interpolate_blank(label, blank_id=0): """Insert blank token between every two label token.""" label = np.expand_dims(label, 1) blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) label = label.reshape(-1) label = np.append(label, label[0]) return label lpz = self.log_softmax(h) lpz = lpz.squeeze(0) y_int = interpolate_blank(y, blank_id) logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero state_path = ( np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1 ) # state path logdelta[0, 0] = lpz[0][y_int[0]] logdelta[0, 1] = lpz[0][y_int[1]] for t in six.moves.range(1, lpz.size(0)): for s in six.moves.range(len(y_int)): if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]: candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]]) prev_state = [s, s - 1] else: candidates = np.array( [ logdelta[t - 1, s], logdelta[t - 1, s - 1], logdelta[t - 1, s - 2], ] ) prev_state = [s, s - 1, s - 2] logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]] state_path[t, s] = prev_state[np.argmax(candidates)] state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16) candidates = np.array( [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]] ) prev_state = [len(y_int) - 1, len(y_int) - 2] state_seq[-1] = prev_state[np.argmax(candidates)] for t in six.moves.range(lpz.size(0) - 2, -1, -1): state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] output_state_seq = [] for t in six.moves.range(0, lpz.size(0)): output_state_seq.append(y_int[state_seq[t, 0]]) return output_state_seq def ctc_for(args, odim, reduce=True): """Returns the CTC module for the given args and output dimension :param Namespace args: the program args :param int odim : The output dimension :param bool reduce : return the CTC loss in a scalar :return: the corresponding CTC module """ num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility if num_encs == 1: # compatible with single encoder asr mode return CTC( odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce ) elif num_encs >= 1: ctcs_list = torch.nn.ModuleList() if args.share_ctc: # use dropout_rate of the first encoder ctc = CTC( odim, args.eprojs, args.dropout_rate[0], ctc_type=args.ctc_type, reduce=reduce, ) ctcs_list.append(ctc) else: for idx in range(num_encs): ctc = CTC( odim, args.eprojs, args.dropout_rate[idx], ctc_type=args.ctc_type, reduce=reduce, ) ctcs_list.append(ctc) return ctcs_list else: raise ValueError( "Number of encoders needs to be more than one. {}".format(num_encs) )