|
import torch |
|
from speechbrain.inference.interfaces import Pretrained |
|
|
|
|
|
class CustomSLUDecoder(Pretrained):
    """An end-to-end SLU model using a HuBERT self-supervised encoder.

    The class can be used either to run only the encoder (encode_batch()) to
    extract features, or to run the entire model (decode_batch() /
    decode_file()) to map the speech to its semantics.

    Example
    -------
    >>> from speechbrain.inference.interfaces import foreign_class
    >>> slu_model = foreign_class(source="speechbrain/slu-timers-and-such-direct-librispeech-asr",
    ...     pymodule_file="custom_interface.py", classname="CustomSLUDecoder")
    >>> slu_model.decode_file("samples/audio_samples/example6.wav")
    "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}"
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tokenizer from the loaded hyperparameters, used to map the beam
        # searcher's token ids back to text in decode_batch().
        self.tokenizer = self.hparams.tokenizer

    def decode_file(self, path):
        """Maps the given audio file to a string representing the
        semantic dictionary for the utterance.

        Arguments
        ---------
        path : str
            Path to audio file to decode.

        Returns
        -------
        str
            The predicted semantics.
        """
        waveform = self.load_audio(path)
        waveform = waveform.to(self.device)
        # Fake a batch of size one; the single utterance trivially has
        # relative length 1.0 (decode_batch moves it to self.device).
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.decode_batch(batch, rel_length)
        return predicted_words[0]

    def encode_batch(self, wavs):
        """Encodes the input audio into a sequence of hidden states.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        # detach() blocks gradients into the waveforms even when this is
        # called outside the no_grad() context of decode_batch().
        encoder_out = self.mods.hubert(wavs.detach())
        return encoder_out

    def decode_batch(self, wavs, wav_lens):
        """Maps the input audio to its semantics.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list of str
            Each waveform in the batch decoded to text.
        list
            The predicted token-id sequences, one per waveform, as
            produced by the beam searcher.
        """
        with torch.no_grad():
            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            predicted_tokens, scores, _, _ = self.mods.beam_searcher(
                encoder_out, wav_lens
            )
            predicted_words = [
                self.tokenizer.decode_ids(token_seq)
                for token_seq in predicted_tokens
            ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full decoding - note: no gradients through decoding"""
        return self.decode_batch(wavs, wav_lens)