from distutils.version import LooseVersion
import logging

import numpy as np
import six
import torch
import torch.nn.functional as F

from espnet.nets.pytorch_backend.nets_utils import to_device

class CTC(torch.nn.Module):
    """CTC module.

    :param int odim: dimension of outputs
    :param int eprojs: number of encoder projection units
    :param float dropout_rate: dropout rate (0.0 ~ 1.0)
    :param str ctc_type: type of CTC implementation: "builtin", "cudnnctc", "warpctc", or "gtnctc"
    :param bool reduce: reduce the CTC loss into a scalar
    """
    def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.probs = None  # for visualization
        # For PyTorch >= 1.7.0, always use the builtin CTC implementation.
        self.ctc_type = (
            ctc_type
            if LooseVersion(torch.__version__) < LooseVersion("1.7.0")
            else "builtin"
        )
        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(
                reduction=reduction_type, zero_infinity=True
            )
        elif self.ctc_type == "cudnnctc":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == "warpctc":
            import warpctc_pytorch as warp_ctc

            self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                'ctc_type must be "builtin", "cudnnctc", "warpctc", or "gtnctc": {}'.format(
                    self.ctc_type
                )
            )

        self.ignore_id = -1
        self.reduce = reduce
    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        if self.ctc_type in ["builtin", "cudnnctc"]:
            th_pred = th_pred.log_softmax(2)
            # Use the deterministic CuDNN implementation of CTC loss to avoid
            # [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
            with torch.backends.cudnn.flags(deterministic=True):
                loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            # Batch-size average
            loss = loss / th_pred.size(1)
            return loss
        elif self.ctc_type == "warpctc":
            return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
        elif self.ctc_type == "gtnctc":
            targets = [t.tolist() for t in th_target]
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
        else:
            raise NotImplementedError
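    # Input conventions for loss_fn (descriptive note added for clarity, not in
    # the original code):
    # - "builtin"/"cudnnctc": th_pred holds (Tmax, B, odim) activations and is
    #   converted to log-probabilities here; th_target is the 1-D concatenation
    #   of all label sequences; th_ilen/th_olen carry per-utterance input/label
    #   lengths, as required by torch.nn.CTCLoss.
    # - "warpctc": activations are passed through unchanged, since
    #   warpctc_pytorch applies its own softmax internally and expects CPU int
    #   targets.
    # - "gtnctc": targets are converted to plain Python lists for the GTN-based
    #   loss function.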
    def forward(self, hs_pad, hlens, ys_pad):
        """CTC forward.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        """
        # TODO(kan-bayashi): need a smarter way to do this
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # zero padding for hs
        ys_hat = self.ctc_lo(self.dropout(hs_pad))
        if self.ctc_type != "gtnctc":
            ys_hat = ys_hat.transpose(0, 1)

        if self.ctc_type == "builtin":
            olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
            hlens = hlens.long()
            ys_pad = torch.cat(ys)  # without this the code breaks for asr_mix
            self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
        else:
            self.loss = None
            hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
            olens = torch.from_numpy(
                np.fromiter((x.size(0) for x in ys), dtype=np.int32)
            )
            # zero padding for ys
            ys_true = torch.cat(ys).cpu().int()  # batch x olen
            # get ctc loss
            # expected shape of seqLength x batchSize x alphabet_size
            dtype = ys_hat.dtype
            if self.ctc_type == "warpctc" or dtype == torch.float16:
                # warpctc only supports float32
                # torch.ctc does not support float16 (#1751)
                ys_hat = ys_hat.to(dtype=torch.float32)
            if self.ctc_type == "cudnnctc":
                # use GPU when using the cuDNN implementation
                ys_true = to_device(hs_pad, ys_true)
            if self.ctc_type == "gtnctc":
                # keep as list for gtn
                ys_true = ys
            self.loss = to_device(
                hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
            ).to(dtype=dtype)
        # get length info
        logging.info(
            self.__class__.__name__
            + " input lengths: "
            + "".join(str(hlens).split("\n"))
        )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + "".join(str(olens).split("\n"))
        )
        if self.reduce:
            # NOTE: sum() is needed to keep consistency, since warpctc returns
            # a tensor with shape (1,) while the builtin loss returns a scalar
            # tensor (no shape).
            self.loss = self.loss.sum()
            logging.info("ctc loss:" + str(float(self.loss)))

        return self.loss
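    # Call convention for forward (illustrative sketch, values assumed):
    #     hs_pad = torch.randn(B, Tmax, eprojs)   # encoder outputs
    #     hlens = torch.tensor([...])             # true input lengths, one per utterance
    #     ys_pad = ...                            # label ids padded with -1 (ignore_id)
    #     loss = ctc(hs_pad, hlens, ys_pad)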
    def softmax(self, hs_pad):
        """softmax of frame activations.

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
        return self.probs
    def log_softmax(self, hs_pad):
        """log_softmax of frame activations.

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
    def argmax(self, hs_pad):
        """argmax of frame activations.

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: torch.Tensor
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
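    # Greedy CTC decoding sketch (not part of the original module): collapse
    # consecutive repeats in the frame-wise argmax, then drop blanks (index 0).
    #     import itertools
    #     best = ctc.argmax(hs_pad)[0].tolist()
    #     hyp = [t for t, _ in itertools.groupby(best) if t != 0]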
    def forced_align(self, h, y, blank_id=0):
        """forced alignment.

        :param torch.Tensor h: hidden state sequence, 3d tensor (1, T, D)
        :param torch.Tensor y: id sequence tensor 1d tensor (L)
        :param int blank_id: blank symbol index
        :return: best alignment results
        :rtype: list
        """
        def interpolate_blank(label, blank_id=0):
            """Insert a blank token between every two labels, and at the start and end."""
            label = np.expand_dims(label, 1)
            blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
            label = np.concatenate([blanks, label], axis=1)
            label = label.reshape(-1)
            label = np.append(label, label[0])
            return label
        lpz = self.log_softmax(h)
        lpz = lpz.squeeze(0)

        y_int = interpolate_blank(y, blank_id)

        logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0  # log of zero
        state_path = (
            np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
        )  # state path

        logdelta[0, 0] = lpz[0][y_int[0]]
        logdelta[0, 1] = lpz[0][y_int[1]]

        for t in six.moves.range(1, lpz.size(0)):
            for s in six.moves.range(len(y_int)):
                if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
                    candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
                    prev_state = [s, s - 1]
                else:
                    candidates = np.array(
                        [
                            logdelta[t - 1, s],
                            logdelta[t - 1, s - 1],
                            logdelta[t - 1, s - 2],
                        ]
                    )
                    prev_state = [s, s - 1, s - 2]
                logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
                state_path[t, s] = prev_state[np.argmax(candidates)]

        state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)

        candidates = np.array(
            [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
        )
        prev_state = [len(y_int) - 1, len(y_int) - 2]
        state_seq[-1] = prev_state[np.argmax(candidates)]
        for t in six.moves.range(lpz.size(0) - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

        output_state_seq = []
        for t in six.moves.range(0, lpz.size(0)):
            output_state_seq.append(y_int[state_seq[t, 0]])

        return output_state_seq
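    # Illustrative call (assumed variable names, not part of the original
    # module): the hidden states are passed as a batch of one so that
    # log_softmax can be applied over dim=2, and y is the raw label sequence
    # without blanks.
    #     h = encoder_output.unsqueeze(0)        # (1, T, D)
    #     alignment = ctc.forced_align(h, y)     # list of length T with token/blank ids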

def ctc_for(args, odim, reduce=True):
    """Returns the CTC module for the given args and output dimension.

    :param Namespace args: the program args
    :param int odim: the output dimension
    :param bool reduce: reduce the CTC loss into a scalar
    :return: the corresponding CTC module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return CTC(
            odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
        )
    elif num_encs >= 1:
        ctcs_list = torch.nn.ModuleList()
        if args.share_ctc:
            # use dropout_rate of the first encoder
            ctc = CTC(
                odim,
                args.eprojs,
                args.dropout_rate[0],
                ctc_type=args.ctc_type,
                reduce=reduce,
            )
            ctcs_list.append(ctc)
        else:
            for idx in range(num_encs):
                ctc = CTC(
                    odim,
                    args.eprojs,
                    args.dropout_rate[idx],
                    ctc_type=args.ctc_type,
                    reduce=reduce,
                )
                ctcs_list.append(ctc)
        return ctcs_list
    else:
        raise ValueError(
            "Number of encoders needs to be at least one: {}".format(num_encs)
        )
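

# Minimal smoke-test sketch (not part of the original module). It assumes the
# ESPnet import above is available and uses the builtin CTC backend; all
# dimensions and values below are illustrative, not defaults from the library.
if __name__ == "__main__":
    from argparse import Namespace

    demo_args = Namespace(eprojs=320, dropout_rate=0.0, ctc_type="builtin")
    ctc = ctc_for(demo_args, odim=52)

    batch, tmax, lmax = 2, 50, 10
    hs_pad = torch.randn(batch, tmax, demo_args.eprojs)        # fake encoder outputs
    hlens = torch.tensor([50, 40])                             # true input lengths
    ys_pad = torch.full((batch, lmax), -1, dtype=torch.long)   # pad labels with ignore_id
    ys_pad[0, :8] = torch.randint(1, 52, (8,))                 # labels avoid blank id 0
    ys_pad[1, :6] = torch.randint(1, 52, (6,))

    loss = ctc(hs_pad, hlens, ys_pad)
    print("ctc loss:", float(loss))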