import torch
from speechbrain.inference.interfaces import Pretrained


class ASR(Pretrained):
    """Ready-to-use speech recognition interface built on ``Pretrained``.

    The class expects the hyperparameter yaml file to provide:

    * a module named ``encoder_w2v2`` that maps a batch of waveforms to
      encoder states (presumably a wav2vec 2.0-style encoder -- the name
      is all this file shows);
    * an hparam named ``test_search`` that, given the encoder states and
      the relative lengths, returns a 4-tuple whose first element holds
      the decoded hypotheses.

    The class can be used on an already loaded batch of waveforms
    (``encode_batch()`` / ``forward()``) or directly on an audio file
    (``classify_file()``).

    Example
    -------
    Illustrative only; replace the placeholders with a real model source
    and audio file::

        asr = ASR.from_hparams(source="<model-source>", savedir="<dir>")
        tokens = asr.classify_file("<path/to/audio.wav>")
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio and decodes it into token hypotheses.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms, ``[batch, time]`` or
            ``[batch, time, channels]`` depending on the model. A single
            un-batched waveform of shape ``[time]`` is also accepted and
            is treated as a batch of one.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape ``[batch]``. The longest one should
            have relative length 1.0 and others
            ``len(waveform) / max_length``. Used for ignoring padding.
            If None, every waveform is assumed to have full length.
        normalize : bool
            Accepted for interface compatibility; this implementation
            does not use it.

        Returns
        -------
        dict
            Dictionary with a single key ``"tokens"`` holding the first
            element returned by ``self.hparams.test_search``.
        """
        # Manage single waveforms in input.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full length if wav_lens is not given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Store inputs on the device the model lives on.
        wavs = wavs.to(self.device)
        wav_lens = wav_lens.to(self.device)

        # Loaded modules are exposed by ``Pretrained`` under ``self.mods``;
        # ``self.modules`` is the ``torch.nn.Module.modules`` method.
        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())

        # Teacher-forced log-probabilities need the reference tokens
        # (``tokens_bos``), which do not exist at inference time, so only
        # the search hypotheses are produced here.
        hyps, _, _, _ = self.hparams.test_search(encoded_outputs, wav_lens)
        return {"tokens": hyps}

    def classify_file(self, path):
        """Transcribes the given audio file into token hypotheses.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        tokens
            The ``"tokens"`` entry produced by ``encode_batch`` for a
            batch that contains only this file.
        """
        waveform = self.load_audio(path)
        # Fake a batch of size one; its only item has full relative length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length)["tokens"]

    def forward(self, wavs, wav_lens=None):
        """Runs ``encode_batch`` on the given waveforms.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms, see ``encode_batch``.
        wav_lens : torch.Tensor
            Relative lengths of the waveforms, see ``encode_batch``.

        Returns
        -------
        dict
            The dictionary returned by ``encode_batch``.
        """
        return self.encode_batch(wavs=wavs, wav_lens=wav_lens)