esb
/

NeMo
English
esb
sanchit-gandhi's picture
Add training scripts and weights
320c039
raw
history blame
6.13 kB
from dataclasses import dataclass
from typing import Optional
import torch
from nemo.collections.asr.models import EncDecRNNTBPEModel
from omegaconf import DictConfig
from transformers.utils import ModelOutput
@dataclass
class RNNTOutput(ModelOutput):
    """
    Container for the values produced by one RNNT forward pass.

    Every field defaults to ``None`` so that a caller can populate only what
    was actually computed (e.g. the WER fields are left unset when no
    metric was evaluated).
    """

    # Transducer loss for the batch.
    loss: Optional[torch.FloatTensor] = None
    # Word error rate for the batch.
    wer: Optional[float] = None
    # Numerator / denominator that accompany ``wer`` -- presumably the error
    # count and the reference word count; confirm against the WER metric used.
    wer_num: Optional[float] = None
    wer_denom: Optional[float] = None
# Adapted from https://github.com/NVIDIA/NeMo/blob/66c7677cd4a68d78965d4905dd1febbf5385dff3/nemo/collections/asr/models/rnnt_bpe_models.py#L33
class RNNTBPEModel(EncDecRNNTBPEModel):
    """
    RNN-Transducer BPE model with a keyword-argument ``forward`` / ``transcribe`` API.

    Subclass of NeMo's ``EncDecRNNTBPEModel`` that is built without a trainer
    and whose ``forward`` computes the loss (and, in evaluation mode, the WER
    statistics) and returns them packed in an ``RNNTOutput``.
    """

    def __init__(self, cfg: DictConfig):
        """Instantiate the underlying NeMo model from ``cfg`` with no trainer attached."""
        super().__init__(cfg=cfg, trainer=None)

    def encoding(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
    ):
        """
        Forward pass of the acoustic model only.

        For RNNT models the full forward pass is a 3 step process; this method
        performs just the first step (preprocessing, optional spec-augment and
        the encoder). See ``forward`` for the full training step, which also
        runs the prediction network and the joint network and computes the
        loss, and ``transcribe`` for decoding.

        Exactly one of the two (signal, length) pairs must be supplied.

        Args:
            input_signal: Batch of raw audio signals. Upstream NeMo documents
                this as shape [B, T] with ``self.sample_rate`` values per
                second of audio -- not verifiable from this file.
            input_signal_length: Individual lengths of the audio sequences in
                ``input_signal`` (one entry per batch element).
            processed_signal: Batch of already pre-processed audio features.
                When given, ``self.preprocessor`` is skipped.
            processed_signal_length: Individual lengths of the sequences in
                ``processed_signal``.

        Returns:
            A tuple of 2 elements, exactly as returned by ``self.encoder``:
            1) The encoded acoustic features.
               NOTE(review): the upstream docstring called this "log
               probabilities of shape [B, T, D]"; the code returns the raw
               encoder output -- confirm the layout against the encoder used.
            2) The lengths of the acoustic sequences after the encoder.

        Raises:
            ValueError: If both pairs, or neither complete pair, are provided.
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None

        # Both flags equal means "both given" or "neither given"; either way
        # there is no single unambiguous input to encode.
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                "with ``processed_signal`` and ``processed_signal_length`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal, length=input_signal_length,
            )

        # Spec augment is not applied during evaluation/testing
        if self.spec_augmentation is not None and self.training:
            processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
        return encoded, encoded_len

    def forward(self, input_ids, input_lengths=None, labels=None, label_lengths=None):
        """
        Run encoder, prediction network and joint network, and compute the loss.

        Args:
            input_ids: Batch of raw audio signals; forwarded to ``encoding`` as
                ``input_signal``.
            input_lengths: Lengths of the audio sequences. Although it defaults
                to ``None``, ``encoding`` raises ``ValueError`` when it is
                missing.
            labels: Target token ids, fed to the prediction network and loss.
            label_lengths: Lengths of the target sequences.

        Returns:
            RNNTOutput: ``loss`` is always set; ``wer``, ``wer_num`` and
            ``wer_denom`` are populated only when the module is not in
            training mode on the non-fused path (on the fused path they are
            whatever ``self.joint`` returns with ``compute_wer=not self.training``).
        """
        # encoding() only performs encoder forward
        encoded, encoded_len = self.encoding(input_signal=input_ids, input_signal_length=input_lengths)
        # Drop the local reference to the raw audio; it is not needed past the encoder.
        del input_ids

        # During training, loss must be computed, so decoder forward is necessary
        decoder, target_length, states = self.decoder(targets=labels, target_length=label_lengths)

        # If experimental fused Joint-Loss-WER is not used
        if not self.joint.fuse_loss_wer:
            # Compute full joint and loss
            joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
            loss_value = self.loss(
                log_probs=joint, targets=labels, input_lengths=encoded_len, target_lengths=target_length
            )
            # Add auxiliary losses, if registered
            loss_value = self.add_auxiliary_losses(loss_value)

            wer = wer_num = wer_denom = None
            if not self.training:
                # The metric is reset right after compute() so each call
                # reports statistics for this batch only.
                self.wer.update(encoded, encoded_len, labels, target_length)
                wer, wer_num, wer_denom = self.wer.compute()
                self.wer.reset()
        else:
            # If experimental fused Joint-Loss-WER is used
            # Fused joint step
            loss_value, wer, wer_num, wer_denom = self.joint(
                encoder_outputs=encoded,
                decoder_outputs=decoder,
                encoder_lengths=encoded_len,
                transcripts=labels,
                transcript_lengths=label_lengths,
                compute_wer=not self.training,
            )
            # Add auxiliary losses, if registered
            loss_value = self.add_auxiliary_losses(loss_value)

        return RNNTOutput(loss=loss_value, wer=wer, wer_num=wer_num, wer_denom=wer_denom)

    def transcribe(
        self,
        input_ids,
        input_lengths=None,
        labels=None,
        label_lengths=None,
        return_hypotheses: bool = False,
        partial_hypothesis: Optional[list] = None,
    ):
        """
        Encode a batch of audio and decode it into hypotheses.

        Args:
            input_ids: Batch of raw audio signals; forwarded to ``encoding`` as
                ``input_signal``.
            input_lengths: Lengths of the audio sequences (required by
                ``encoding`` despite the ``None`` default).
            labels: Unused; accepted so the method can be called with the same
                keyword arguments as ``forward``.
            label_lengths: Unused; see ``labels``.
            return_hypotheses: Passed through to the decoding call.
            partial_hypothesis: Passed through to the decoding call as
                ``partial_hypotheses``. Presumably a list of NeMo hypothesis
                objects to resume from -- confirm against the NeMo decoding API.

        Returns:
            The ``(best_hyp, all_hyp)`` pair returned by
            ``self.decoding.rnnt_decoder_predictions_tensor``.
        """
        encoded, encoded_len = self.encoding(input_signal=input_ids, input_signal_length=input_lengths)
        # Drop the local reference to the raw audio; only the encoding is decoded.
        del input_ids

        best_hyp, all_hyp = self.decoding.rnnt_decoder_predictions_tensor(
            encoded,
            encoded_len,
            return_hypotheses=return_hypotheses,
            partial_hypotheses=partial_hypothesis,
        )
        return best_hyp, all_hyp