tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
2.5 kB
#!/usr/bin/env python3
"""Transducer loss module."""
import torch
class TransLoss(torch.nn.Module):
    """Transducer loss module.

    Wraps one of two external transducer loss implementations behind a
    single ``forward`` interface.

    Args:
        trans_type (str): type of transducer implementation to calculate loss,
            either "warp-transducer" or "warp-rnnt".
        blank_id (int): blank symbol id

    Raises:
        ImportError: if the package backing the requested implementation
            cannot be imported.
        ValueError: if "warp-rnnt" is requested while CUDA is unavailable,
            or if ``trans_type`` is not one of the supported implementations.

    """

    def __init__(self, trans_type, blank_id):
        """Construct a TransLoss object."""
        super().__init__()

        if trans_type == "warp-transducer":
            # Imported lazily so the package is only required when selected.
            from warprnnt_pytorch import RNNTLoss

            self.trans_loss = RNNTLoss(blank=blank_id)
        elif trans_type == "warp-rnnt":
            if not torch.cuda.is_available():
                raise ValueError("warp-rnnt is not supported in CPU mode")

            try:
                from warp_rnnt import rnnt_loss
            except ImportError as err:
                # Keep the original failure attached as the cause.
                raise ImportError(
                    "warp-rnnt is not installed. Please re-setup"
                    " espnet or use 'warp-transducer'"
                ) from err

            self.trans_loss = rnnt_loss
        else:
            # Fail at construction instead of leaving ``trans_loss`` unset
            # and crashing with an AttributeError on the first forward call.
            raise ValueError(
                f"unsupported trans_type: {trans_type!r}"
                " (expected 'warp-transducer' or 'warp-rnnt')"
            )

        self.trans_type = trans_type
        self.blank_id = blank_id

    def forward(self, pred_pad, target, pred_len, target_len):
        """Compute transducer loss with the selected implementation.

        Args:
            pred_pad (torch.Tensor): Batch of predicted sequences
                (batch, maxlen_in, maxlen_out+1, odim)
            target (torch.Tensor): Batch of target sequences (batch, maxlen_out)
            pred_len (torch.Tensor): batch of lengths of predicted sequences (batch)
            target_len (torch.Tensor): batch of lengths of target sequences (batch)

        Returns:
            loss (torch.Tensor): transducer loss, cast back to the dtype
                ``pred_pad`` had on entry

        """
        dtype = pred_pad.dtype
        if dtype != torch.float32:
            # warp-transducer and warp-rnnt only support float32
            pred_pad = pred_pad.to(dtype=torch.float32)

        if self.trans_type == "warp-rnnt":
            # The warp-rnnt function is fed log-probabilities, so normalize
            # over the last dimension before calling it.
            log_probs = torch.log_softmax(pred_pad, dim=-1)

            loss = self.trans_loss(
                log_probs,
                target,
                pred_len,
                target_len,
                reduction="mean",
                blank=self.blank_id,
                gather=True,
            )
        else:
            # The warp-transducer loss object receives pred_pad unchanged.
            loss = self.trans_loss(pred_pad, target, pred_len, target_len)

        # Hand the loss back in the caller's original precision.
        loss = loss.to(dtype=dtype)

        return loss