import os

import torch
import torchaudio
from torch import nn

from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.processing.features import InputNormalization

class Extractor(nn.Module):
    # Sub-modules to restore from `model_path`; each one is loaded from
    # `<name>.ckpt` if that file exists.
    model_dict = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        super().__init__()
        self.device = device
        # Filterbank features with per-sentence mean normalization,
        # followed by the x-vector embedding model
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(norm_type="sentence", std_norm=False)
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(norm_type="global", std_norm=False)
        # Restore each sub-module from its checkpoint, using SpeechBrain's
        # `_load` hook when the module provides one, and a plain
        # `load_state_dict` otherwise
        for mod_name in self.model_dict:
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            module = getattr(self, mod_name)
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load State Dict: {filename}")
                    module.load_state_dict(torch.load(filename))
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        # Handle a single waveform given as a 1-D tensor
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        # Assume full-length (unpadded) waveforms if wav_lens is not given
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        # Move the waveforms to the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        # Compute features and embeddings
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings, torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings
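
# A minimal batching sketch, assuming SpeechBrain interprets `wav_lens` as
# *relative* lengths in [0, 1] (each clip's fraction of the padded batch
# length). `batch_embed` is a hypothetical helper, not part of the original
# script.
def batch_embed(extractor, waveforms):
    # waveforms: list of 1-D tensors of possibly different lengths
    lens = torch.tensor([w.shape[-1] for w in waveforms], dtype=torch.float32)
    batch = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    return extractor(batch, wav_lens=lens / lens.max())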

MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"

signal, fs = torchaudio.load("audio.wav")

device = "cuda"
extractor = Extractor(MODEL_PATH, device=device)
# Freeze all parameters and switch to inference mode before tracing
for p in extractor.parameters():
    p.requires_grad = False
extractor.eval()

embeddings_x = extractor(signal).cpu().squeeze()
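
# Sanity check (a sketch): with lin_neurons=512 above, a single waveform
# should yield a 512-dimensional x-vector
assert embeddings_x.shape[-1] == 512, embeddings_x.shape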

# Tracing: torch.jit.trace records the operations executed for this example
# input. Data-dependent branches in forward() (the 1-D unsqueeze, the
# default wav_lens, the `normalize` flag) are baked in, so the traced model
# is only valid for inputs shaped like `signal`.
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")
embeddings_t = traced_model(signal).squeeze()
print(embeddings_t)

# Reload the traced model from disk and run it again
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(emb_m)
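
# Consistency check (a minimal sketch): the traced and reloaded modules
# should reproduce the eager embeddings up to floating-point noise
print(torch.allclose(embeddings_t.cpu(), embeddings_x, atol=1e-5))
print(torch.allclose(emb_m.cpu(), embeddings_x, atol=1e-5))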