|
|
|
|
|
"""Transducer loss module.""" |
|
|
|
import torch |
|
|
|
|
|
class TransLoss(torch.nn.Module):
    """Transducer loss module.

    Wraps one of two external transducer-loss backends behind a single
    ``forward`` signature.

    Args:
        trans_type (str): type of transducer implementation to calculate loss,
            either "warp-transducer" or "warp-rnnt".
        blank_id (int): blank symbol id

    Raises:
        ImportError: if the package backing the requested implementation
            cannot be imported.
        ValueError: if "warp-rnnt" is requested while CUDA is unavailable,
            or if ``trans_type`` is not one of the supported values.

    """

    def __init__(self, trans_type, blank_id):
        """Construct a TransLoss object."""
        super().__init__()

        if trans_type == "warp-transducer":
            # Imported lazily so the package is only required when selected.
            from warprnnt_pytorch import RNNTLoss

            self.trans_loss = RNNTLoss(blank=blank_id)
        elif trans_type == "warp-rnnt":
            if not torch.cuda.is_available():
                raise ValueError("warp-rnnt is not supported in CPU mode")

            try:
                from warp_rnnt import rnnt_loss
            except ImportError as err:
                # Keep the original failure attached as the cause.
                raise ImportError(
                    "warp-rnnt is not installed. Please re-setup"
                    " espnet or use 'warp-transducer'"
                ) from err

            self.trans_loss = rnnt_loss
        else:
            # Fail at construction time: otherwise ``self.trans_loss`` is
            # never set and the first ``forward`` call dies with an
            # AttributeError that does not mention the real problem.
            raise ValueError(
                f"unsupported trans_type {trans_type!r}; "
                "expected 'warp-transducer' or 'warp-rnnt'"
            )

        self.trans_type = trans_type
        self.blank_id = blank_id

    def forward(self, pred_pad, target, pred_len, target_len):
        """Compute transducer loss.

        Args:
            pred_pad (torch.Tensor): Batch of predicted sequences
                (batch, maxlen_in, maxlen_out+1, odim)
            target (torch.Tensor): Batch of target sequences (batch, maxlen_out)
            pred_len (torch.Tensor): batch of lengths of predicted sequences (batch)
            target_len (torch.Tensor): batch of lengths of target sequences (batch)

        Returns:
            loss (torch.Tensor): transducer loss, cast back to the dtype
                ``pred_pad`` had on entry

        """
        dtype = pred_pad.dtype

        # The loss backends are always called with float32 input; the result
        # is converted back to the caller's dtype before returning.
        if dtype != torch.float32:
            pred_pad = pred_pad.to(dtype=torch.float32)

        if self.trans_type == "warp-rnnt":
            # This backend is handed log-probabilities over the last axis.
            log_probs = torch.log_softmax(pred_pad, dim=-1)

            loss = self.trans_loss(
                log_probs,
                target,
                pred_len,
                target_len,
                reduction="mean",
                blank=self.blank_id,
                gather=True,
            )
        else:
            # This backend is handed ``pred_pad`` without normalization here.
            loss = self.trans_loss(pred_pad, target, pred_len, target_len)

        return loss.to(dtype=dtype)
|
|