Update custom_interface.py
custom_interface.py: CHANGED (+6 −3)
@@ -1,5 +1,6 @@
 import torch
 from speechbrain.inference.interfaces import Pretrained
+import librosa


 class ASR(Pretrained):
@@ -20,13 +21,15 @@ class ASR(Pretrained):
         # Output layer for seq2seq log-probabilities
         predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
         predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
-        print(predicted_words)

         return predicted_words


     def classify_file(self, path):
-        waveform = self.load_audio(path)
+        # waveform = self.load_audio(path)
+        waveform, sr = librosa.load(path, sr=16000)
+        waveform = torch.tensor(waveform).unsqueeze(0)
+
         # Fake a batch:
         batch = waveform.unsqueeze(0)
         rel_length = torch.tensor([1.0])
@@ -35,4 +38,4 @@
         return outputs

     # def forward(self, wavs, wav_lens=None):
-    #     return self.encode_batch(wavs=wavs, wav_lens=wav_lens)
+    #     return self.encode_batch(wavs=wavs, wav_lens=wav_lens)
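For context, a minimal usage sketch of the updated classify_file path. The repository id, savedir, and audio filename below are placeholders, not part of this commit, and loading the custom interface through SpeechBrain's inherited Pretrained.from_hparams is an assumption about how this class is meant to be used.

from custom_interface import ASR

# Placeholder identifiers; substitute the actual model repo and paths.
asr_model = ASR.from_hparams(
    source="your-username/your-asr-model",  # hypothetical repo id
    savedir="pretrained_asr",               # hypothetical cache directory
)

# classify_file now reads audio with librosa at 16 kHz instead of
# Pretrained.load_audio, so any format librosa can decode should work.
outputs = asr_model.classify_file("sample.wav")  # hypothetical local file
print(outputs)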