wav2vec2-aed-macedonian-asr / custom_interface.py
import torch

from speechbrain.inference.interfaces import Pretrained


class ASR(Pretrained):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encode a batch of waveforms and decode them into word sequences."""
        wavs = wavs.to(self.device)
        self.wav_lens = wav_lens.to(self.device)

        # Forward pass through the wav2vec2 encoder
        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())

        # Single decoder step from the <bos> token (its output is not used below;
        # decoding is performed by the search module in hparams.test_search)
        tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
        embedded_tokens = self.mods.embedding(tokens_bos)
        decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, self.wav_lens)

        # Run the seq2seq search over the encoder states to get token hypotheses,
        # then detokenize each hypothesis into a list of words
        predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
        predicted_words = [
            self.hparams.tokenizer.decode_ids(prediction).split(" ")
            for prediction in predictions
        ]
        return predicted_words

    def classify_file(self, path):
        """Transcribe a single audio file given its path."""
        waveform = self.load_audio(path)
        # Fake a batch of size 1 with full relative length
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        outputs = self.encode_batch(batch, rel_length)
        return outputs

    # def forward(self, wavs, wav_lens=None):
    #     return self.encode_batch(wavs=wavs, wav_lens=wav_lens)
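
# Usage sketch (illustrative, kept commented out): SpeechBrain's `foreign_class`
# helper can instantiate this interface directly from the Hugging Face Hub.
# The repo id and audio path below are assumptions for illustration; adjust them
# to your own setup.
#
# from speechbrain.inference.interfaces import foreign_class
#
# asr_model = foreign_class(
#     source="Porjaz/wav2vec2-aed-macedonian-asr",  # assumed repo id
#     pymodule_file="custom_interface.py",
#     classname="ASR",
# )
#
# words = asr_model.classify_file("example.wav")  # hypothetical audio file
# print(" ".join(words[0]))  # words is a list of word lists, one per utterance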