Spaces:
Runtime error
Runtime error
kevinwang676
commited on
Commit
•
dd6f15d
1
Parent(s):
bf21026
Upload 20 files
Browse files- speaker_encoder/__init__.py +1 -0
- speaker_encoder/audio.py +107 -0
- speaker_encoder/ckpt/__init__.py +1 -0
- speaker_encoder/compute_embed.py +40 -0
- speaker_encoder/config.py +45 -0
- speaker_encoder/data_objects/__init__.py +2 -0
- speaker_encoder/data_objects/random_cycler.py +37 -0
- speaker_encoder/data_objects/speaker.py +40 -0
- speaker_encoder/data_objects/speaker_batch.py +12 -0
- speaker_encoder/data_objects/speaker_verification_dataset.py +56 -0
- speaker_encoder/data_objects/utterance.py +26 -0
- speaker_encoder/hparams.py +31 -0
- speaker_encoder/inference.py +177 -0
- speaker_encoder/model.py +135 -0
- speaker_encoder/params_data.py +29 -0
- speaker_encoder/params_model.py +11 -0
- speaker_encoder/preprocess.py +285 -0
- speaker_encoder/train.py +125 -0
- speaker_encoder/visualizations.py +178 -0
- speaker_encoder/voice_encoder.py +173 -0
speaker_encoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
speaker_encoder/audio.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from scipy.ndimage.morphology import binary_dilation
|
2 |
+
from speaker_encoder.params_data import *
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
import numpy as np
|
6 |
+
import webrtcvad
|
7 |
+
import librosa
|
8 |
+
import struct
|
9 |
+
|
10 |
+
int16_max = (2 ** 15) - 1
|
11 |
+
|
12 |
+
|
13 |
+
def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
|
14 |
+
source_sr: Optional[int] = None):
|
15 |
+
"""
|
16 |
+
Applies the preprocessing operations used in training the Speaker Encoder to a waveform
|
17 |
+
either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
|
18 |
+
|
19 |
+
:param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
|
20 |
+
just .wav), either the waveform as a numpy array of floats.
|
21 |
+
:param source_sr: if passing an audio waveform, the sampling rate of the waveform before
|
22 |
+
preprocessing. After preprocessing, the waveform's sampling rate will match the data
|
23 |
+
hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
|
24 |
+
this argument will be ignored.
|
25 |
+
"""
|
26 |
+
# Load the wav from disk if needed
|
27 |
+
if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
|
28 |
+
wav, source_sr = librosa.load(fpath_or_wav, sr=None)
|
29 |
+
else:
|
30 |
+
wav = fpath_or_wav
|
31 |
+
|
32 |
+
# Resample the wav if needed
|
33 |
+
if source_sr is not None and source_sr != sampling_rate:
|
34 |
+
wav = librosa.resample(wav, source_sr, sampling_rate)
|
35 |
+
|
36 |
+
# Apply the preprocessing: normalize volume and shorten long silences
|
37 |
+
wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
|
38 |
+
wav = trim_long_silences(wav)
|
39 |
+
|
40 |
+
return wav
|
41 |
+
|
42 |
+
|
43 |
+
def wav_to_mel_spectrogram(wav):
|
44 |
+
"""
|
45 |
+
Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
|
46 |
+
Note: this not a log-mel spectrogram.
|
47 |
+
"""
|
48 |
+
frames = librosa.feature.melspectrogram(
|
49 |
+
y=wav,
|
50 |
+
sr=sampling_rate,
|
51 |
+
n_fft=int(sampling_rate * mel_window_length / 1000),
|
52 |
+
hop_length=int(sampling_rate * mel_window_step / 1000),
|
53 |
+
n_mels=mel_n_channels
|
54 |
+
)
|
55 |
+
return frames.astype(np.float32).T
|
56 |
+
|
57 |
+
|
58 |
+
def trim_long_silences(wav):
|
59 |
+
"""
|
60 |
+
Ensures that segments without voice in the waveform remain no longer than a
|
61 |
+
threshold determined by the VAD parameters in params.py.
|
62 |
+
|
63 |
+
:param wav: the raw waveform as a numpy array of floats
|
64 |
+
:return: the same waveform with silences trimmed away (length <= original wav length)
|
65 |
+
"""
|
66 |
+
# Compute the voice detection window size
|
67 |
+
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
68 |
+
|
69 |
+
# Trim the end of the audio to have a multiple of the window size
|
70 |
+
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
71 |
+
|
72 |
+
# Convert the float waveform to 16-bit mono PCM
|
73 |
+
pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
|
74 |
+
|
75 |
+
# Perform voice activation detection
|
76 |
+
voice_flags = []
|
77 |
+
vad = webrtcvad.Vad(mode=3)
|
78 |
+
for window_start in range(0, len(wav), samples_per_window):
|
79 |
+
window_end = window_start + samples_per_window
|
80 |
+
voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
|
81 |
+
sample_rate=sampling_rate))
|
82 |
+
voice_flags = np.array(voice_flags)
|
83 |
+
|
84 |
+
# Smooth the voice detection with a moving average
|
85 |
+
def moving_average(array, width):
|
86 |
+
array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
|
87 |
+
ret = np.cumsum(array_padded, dtype=float)
|
88 |
+
ret[width:] = ret[width:] - ret[:-width]
|
89 |
+
return ret[width - 1:] / width
|
90 |
+
|
91 |
+
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
92 |
+
audio_mask = np.round(audio_mask).astype(np.bool)
|
93 |
+
|
94 |
+
# Dilate the voiced regions
|
95 |
+
audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
|
96 |
+
audio_mask = np.repeat(audio_mask, samples_per_window)
|
97 |
+
|
98 |
+
return wav[audio_mask == True]
|
99 |
+
|
100 |
+
|
101 |
+
def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
|
102 |
+
if increase_only and decrease_only:
|
103 |
+
raise ValueError("Both increase only and decrease only are set")
|
104 |
+
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
|
105 |
+
if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
|
106 |
+
return wav
|
107 |
+
return wav * (10 ** (dBFS_change / 20))
|
speaker_encoder/ckpt/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
speaker_encoder/compute_embed.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder import inference as encoder
|
2 |
+
from multiprocessing.pool import Pool
|
3 |
+
from functools import partial
|
4 |
+
from pathlib import Path
|
5 |
+
# from utils import logmmse
|
6 |
+
# from tqdm import tqdm
|
7 |
+
# import numpy as np
|
8 |
+
# import librosa
|
9 |
+
|
10 |
+
|
11 |
+
def embed_utterance(fpaths, encoder_model_fpath):
|
12 |
+
if not encoder.is_loaded():
|
13 |
+
encoder.load_model(encoder_model_fpath)
|
14 |
+
|
15 |
+
# Compute the speaker embedding of the utterance
|
16 |
+
wav_fpath, embed_fpath = fpaths
|
17 |
+
wav = np.load(wav_fpath)
|
18 |
+
wav = encoder.preprocess_wav(wav)
|
19 |
+
embed = encoder.embed_utterance(wav)
|
20 |
+
np.save(embed_fpath, embed, allow_pickle=False)
|
21 |
+
|
22 |
+
|
23 |
+
def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
|
24 |
+
|
25 |
+
wav_dir = outdir_root.joinpath("audio")
|
26 |
+
metadata_fpath = synthesizer_root.joinpath("train.txt")
|
27 |
+
assert wav_dir.exists() and metadata_fpath.exists()
|
28 |
+
embed_dir = synthesizer_root.joinpath("embeds")
|
29 |
+
embed_dir.mkdir(exist_ok=True)
|
30 |
+
|
31 |
+
# Gather the input wave filepath and the target output embed filepath
|
32 |
+
with metadata_fpath.open("r") as metadata_file:
|
33 |
+
metadata = [line.split("|") for line in metadata_file]
|
34 |
+
fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
|
35 |
+
|
36 |
+
# TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
|
37 |
+
# Embed the utterances in separate threads
|
38 |
+
func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
|
39 |
+
job = Pool(n_processes).imap(func, fpaths)
|
40 |
+
list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
|
speaker_encoder/config.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
librispeech_datasets = {
|
2 |
+
"train": {
|
3 |
+
"clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
|
4 |
+
"other": ["LibriSpeech/train-other-500"]
|
5 |
+
},
|
6 |
+
"test": {
|
7 |
+
"clean": ["LibriSpeech/test-clean"],
|
8 |
+
"other": ["LibriSpeech/test-other"]
|
9 |
+
},
|
10 |
+
"dev": {
|
11 |
+
"clean": ["LibriSpeech/dev-clean"],
|
12 |
+
"other": ["LibriSpeech/dev-other"]
|
13 |
+
},
|
14 |
+
}
|
15 |
+
libritts_datasets = {
|
16 |
+
"train": {
|
17 |
+
"clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
|
18 |
+
"other": ["LibriTTS/train-other-500"]
|
19 |
+
},
|
20 |
+
"test": {
|
21 |
+
"clean": ["LibriTTS/test-clean"],
|
22 |
+
"other": ["LibriTTS/test-other"]
|
23 |
+
},
|
24 |
+
"dev": {
|
25 |
+
"clean": ["LibriTTS/dev-clean"],
|
26 |
+
"other": ["LibriTTS/dev-other"]
|
27 |
+
},
|
28 |
+
}
|
29 |
+
voxceleb_datasets = {
|
30 |
+
"voxceleb1" : {
|
31 |
+
"train": ["VoxCeleb1/wav"],
|
32 |
+
"test": ["VoxCeleb1/test_wav"]
|
33 |
+
},
|
34 |
+
"voxceleb2" : {
|
35 |
+
"train": ["VoxCeleb2/dev/aac"],
|
36 |
+
"test": ["VoxCeleb2/test_wav"]
|
37 |
+
}
|
38 |
+
}
|
39 |
+
|
40 |
+
other_datasets = [
|
41 |
+
"LJSpeech-1.1",
|
42 |
+
"VCTK-Corpus/wav48",
|
43 |
+
]
|
44 |
+
|
45 |
+
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
|
speaker_encoder/data_objects/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
2 |
+
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
|
speaker_encoder/data_objects/random_cycler.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
class RandomCycler:
|
4 |
+
"""
|
5 |
+
Creates an internal copy of a sequence and allows access to its items in a constrained random
|
6 |
+
order. For a source sequence of n items and one or several consecutive queries of a total
|
7 |
+
of m items, the following guarantees hold (one implies the other):
|
8 |
+
- Each item will be returned between m // n and ((m - 1) // n) + 1 times.
|
9 |
+
- Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, source):
|
13 |
+
if len(source) == 0:
|
14 |
+
raise Exception("Can't create RandomCycler from an empty collection")
|
15 |
+
self.all_items = list(source)
|
16 |
+
self.next_items = []
|
17 |
+
|
18 |
+
def sample(self, count: int):
|
19 |
+
shuffle = lambda l: random.sample(l, len(l))
|
20 |
+
|
21 |
+
out = []
|
22 |
+
while count > 0:
|
23 |
+
if count >= len(self.all_items):
|
24 |
+
out.extend(shuffle(list(self.all_items)))
|
25 |
+
count -= len(self.all_items)
|
26 |
+
continue
|
27 |
+
n = min(count, len(self.next_items))
|
28 |
+
out.extend(self.next_items[:n])
|
29 |
+
count -= n
|
30 |
+
self.next_items = self.next_items[n:]
|
31 |
+
if len(self.next_items) == 0:
|
32 |
+
self.next_items = shuffle(list(self.all_items))
|
33 |
+
return out
|
34 |
+
|
35 |
+
def __next__(self):
|
36 |
+
return self.sample(1)[0]
|
37 |
+
|
speaker_encoder/data_objects/speaker.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from speaker_encoder.data_objects.utterance import Utterance
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
# Contains the set of utterances of a single speaker
|
6 |
+
class Speaker:
|
7 |
+
def __init__(self, root: Path):
|
8 |
+
self.root = root
|
9 |
+
self.name = root.name
|
10 |
+
self.utterances = None
|
11 |
+
self.utterance_cycler = None
|
12 |
+
|
13 |
+
def _load_utterances(self):
|
14 |
+
with self.root.joinpath("_sources.txt").open("r") as sources_file:
|
15 |
+
sources = [l.split(",") for l in sources_file]
|
16 |
+
sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
|
17 |
+
self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
|
18 |
+
self.utterance_cycler = RandomCycler(self.utterances)
|
19 |
+
|
20 |
+
def random_partial(self, count, n_frames):
|
21 |
+
"""
|
22 |
+
Samples a batch of <count> unique partial utterances from the disk in a way that all
|
23 |
+
utterances come up at least once every two cycles and in a random order every time.
|
24 |
+
|
25 |
+
:param count: The number of partial utterances to sample from the set of utterances from
|
26 |
+
that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
|
27 |
+
the number of utterances available.
|
28 |
+
:param n_frames: The number of frames in the partial utterance.
|
29 |
+
:return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
|
30 |
+
frames are the frames of the partial utterances and range is the range of the partial
|
31 |
+
utterance with regard to the complete utterance.
|
32 |
+
"""
|
33 |
+
if self.utterances is None:
|
34 |
+
self._load_utterances()
|
35 |
+
|
36 |
+
utterances = self.utterance_cycler.sample(count)
|
37 |
+
|
38 |
+
a = [(u,) + u.random_partial(n_frames) for u in utterances]
|
39 |
+
|
40 |
+
return a
|
speaker_encoder/data_objects/speaker_batch.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import List
|
3 |
+
from speaker_encoder.data_objects.speaker import Speaker
|
4 |
+
|
5 |
+
class SpeakerBatch:
|
6 |
+
def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
|
7 |
+
self.speakers = speakers
|
8 |
+
self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
|
9 |
+
|
10 |
+
# Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
|
11 |
+
# 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
|
12 |
+
self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
|
speaker_encoder/data_objects/speaker_verification_dataset.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.data_objects.random_cycler import RandomCycler
|
2 |
+
from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
|
3 |
+
from speaker_encoder.data_objects.speaker import Speaker
|
4 |
+
from speaker_encoder.params_data import partials_n_frames
|
5 |
+
from torch.utils.data import Dataset, DataLoader
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
# TODO: improve with a pool of speakers for data efficiency
|
9 |
+
|
10 |
+
class SpeakerVerificationDataset(Dataset):
|
11 |
+
def __init__(self, datasets_root: Path):
|
12 |
+
self.root = datasets_root
|
13 |
+
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
14 |
+
if len(speaker_dirs) == 0:
|
15 |
+
raise Exception("No speakers found. Make sure you are pointing to the directory "
|
16 |
+
"containing all preprocessed speaker directories.")
|
17 |
+
self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
|
18 |
+
self.speaker_cycler = RandomCycler(self.speakers)
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return int(1e10)
|
22 |
+
|
23 |
+
def __getitem__(self, index):
|
24 |
+
return next(self.speaker_cycler)
|
25 |
+
|
26 |
+
def get_logs(self):
|
27 |
+
log_string = ""
|
28 |
+
for log_fpath in self.root.glob("*.txt"):
|
29 |
+
with log_fpath.open("r") as log_file:
|
30 |
+
log_string += "".join(log_file.readlines())
|
31 |
+
return log_string
|
32 |
+
|
33 |
+
|
34 |
+
class SpeakerVerificationDataLoader(DataLoader):
|
35 |
+
def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
|
36 |
+
batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
|
37 |
+
worker_init_fn=None):
|
38 |
+
self.utterances_per_speaker = utterances_per_speaker
|
39 |
+
|
40 |
+
super().__init__(
|
41 |
+
dataset=dataset,
|
42 |
+
batch_size=speakers_per_batch,
|
43 |
+
shuffle=False,
|
44 |
+
sampler=sampler,
|
45 |
+
batch_sampler=batch_sampler,
|
46 |
+
num_workers=num_workers,
|
47 |
+
collate_fn=self.collate,
|
48 |
+
pin_memory=pin_memory,
|
49 |
+
drop_last=False,
|
50 |
+
timeout=timeout,
|
51 |
+
worker_init_fn=worker_init_fn
|
52 |
+
)
|
53 |
+
|
54 |
+
def collate(self, speakers):
|
55 |
+
return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
|
56 |
+
|
speaker_encoder/data_objects/utterance.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class Utterance:
|
5 |
+
def __init__(self, frames_fpath, wave_fpath):
|
6 |
+
self.frames_fpath = frames_fpath
|
7 |
+
self.wave_fpath = wave_fpath
|
8 |
+
|
9 |
+
def get_frames(self):
|
10 |
+
return np.load(self.frames_fpath)
|
11 |
+
|
12 |
+
def random_partial(self, n_frames):
|
13 |
+
"""
|
14 |
+
Crops the frames into a partial utterance of n_frames
|
15 |
+
|
16 |
+
:param n_frames: The number of frames of the partial utterance
|
17 |
+
:return: the partial utterance frames and a tuple indicating the start and end of the
|
18 |
+
partial utterance in the complete utterance.
|
19 |
+
"""
|
20 |
+
frames = self.get_frames()
|
21 |
+
if frames.shape[0] == n_frames:
|
22 |
+
start = 0
|
23 |
+
else:
|
24 |
+
start = np.random.randint(0, frames.shape[0] - n_frames)
|
25 |
+
end = start + n_frames
|
26 |
+
return frames[start:end], (start, end)
|
speaker_encoder/hparams.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Mel-filterbank
|
2 |
+
mel_window_length = 25 # In milliseconds
|
3 |
+
mel_window_step = 10 # In milliseconds
|
4 |
+
mel_n_channels = 40
|
5 |
+
|
6 |
+
|
7 |
+
## Audio
|
8 |
+
sampling_rate = 16000
|
9 |
+
# Number of spectrogram frames in a partial utterance
|
10 |
+
partials_n_frames = 160 # 1600 ms
|
11 |
+
|
12 |
+
|
13 |
+
## Voice Activation Detection
|
14 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
15 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
16 |
+
vad_window_length = 30 # In milliseconds
|
17 |
+
# Number of frames to average together when performing the moving average smoothing.
|
18 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
19 |
+
vad_moving_average_width = 8
|
20 |
+
# Maximum number of consecutive silent frames a segment can have.
|
21 |
+
vad_max_silence_length = 6
|
22 |
+
|
23 |
+
|
24 |
+
## Audio volume normalization
|
25 |
+
audio_norm_target_dBFS = -30
|
26 |
+
|
27 |
+
|
28 |
+
## Model parameters
|
29 |
+
model_hidden_size = 256
|
30 |
+
model_embedding_size = 256
|
31 |
+
model_num_layers = 3
|
speaker_encoder/inference.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.params_data import *
|
2 |
+
from speaker_encoder.model import SpeakerEncoder
|
3 |
+
from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
|
4 |
+
from matplotlib import cm
|
5 |
+
from speaker_encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
_model = None # type: SpeakerEncoder
|
12 |
+
_device = None # type: torch.device
|
13 |
+
|
14 |
+
|
15 |
+
def load_model(weights_fpath: Path, device=None):
|
16 |
+
"""
|
17 |
+
Loads the model in memory. If this function is not explicitely called, it will be run on the
|
18 |
+
first call to embed_frames() with the default weights file.
|
19 |
+
|
20 |
+
:param weights_fpath: the path to saved model weights.
|
21 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
|
22 |
+
model will be loaded and will run on this device. Outputs will however always be on the cpu.
|
23 |
+
If None, will default to your GPU if it"s available, otherwise your CPU.
|
24 |
+
"""
|
25 |
+
# TODO: I think the slow loading of the encoder might have something to do with the device it
|
26 |
+
# was saved on. Worth investigating.
|
27 |
+
global _model, _device
|
28 |
+
if device is None:
|
29 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
elif isinstance(device, str):
|
31 |
+
_device = torch.device(device)
|
32 |
+
_model = SpeakerEncoder(_device, torch.device("cpu"))
|
33 |
+
checkpoint = torch.load(weights_fpath)
|
34 |
+
_model.load_state_dict(checkpoint["model_state"])
|
35 |
+
_model.eval()
|
36 |
+
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
37 |
+
|
38 |
+
|
39 |
+
def is_loaded():
|
40 |
+
return _model is not None
|
41 |
+
|
42 |
+
|
43 |
+
def embed_frames_batch(frames_batch):
|
44 |
+
"""
|
45 |
+
Computes embeddings for a batch of mel spectrogram.
|
46 |
+
|
47 |
+
:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
|
48 |
+
(batch_size, n_frames, n_channels)
|
49 |
+
:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
|
50 |
+
"""
|
51 |
+
if _model is None:
|
52 |
+
raise Exception("Model was not loaded. Call load_model() before inference.")
|
53 |
+
|
54 |
+
frames = torch.from_numpy(frames_batch).to(_device)
|
55 |
+
embed = _model.forward(frames).detach().cpu().numpy()
|
56 |
+
return embed
|
57 |
+
|
58 |
+
|
59 |
+
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
60 |
+
min_pad_coverage=0.75, overlap=0.5):
|
61 |
+
"""
|
62 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
63 |
+
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
64 |
+
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
65 |
+
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
66 |
+
defined in params_data.py.
|
67 |
+
|
68 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
69 |
+
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
70 |
+
|
71 |
+
:param n_samples: the number of samples in the waveform
|
72 |
+
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
73 |
+
utterance
|
74 |
+
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
75 |
+
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
76 |
+
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
77 |
+
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
78 |
+
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
79 |
+
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
80 |
+
utterances are entirely disjoint.
|
81 |
+
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
82 |
+
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
83 |
+
utterances.
|
84 |
+
"""
|
85 |
+
assert 0 <= overlap < 1
|
86 |
+
assert 0 < min_pad_coverage <= 1
|
87 |
+
|
88 |
+
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
89 |
+
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
90 |
+
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
91 |
+
|
92 |
+
# Compute the slices
|
93 |
+
wav_slices, mel_slices = [], []
|
94 |
+
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
95 |
+
for i in range(0, steps, frame_step):
|
96 |
+
mel_range = np.array([i, i + partial_utterance_n_frames])
|
97 |
+
wav_range = mel_range * samples_per_frame
|
98 |
+
mel_slices.append(slice(*mel_range))
|
99 |
+
wav_slices.append(slice(*wav_range))
|
100 |
+
|
101 |
+
# Evaluate whether extra padding is warranted or not
|
102 |
+
last_wav_range = wav_slices[-1]
|
103 |
+
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
104 |
+
if coverage < min_pad_coverage and len(mel_slices) > 1:
|
105 |
+
mel_slices = mel_slices[:-1]
|
106 |
+
wav_slices = wav_slices[:-1]
|
107 |
+
|
108 |
+
return wav_slices, mel_slices
|
109 |
+
|
110 |
+
|
111 |
+
def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
|
112 |
+
"""
|
113 |
+
Computes an embedding for a single utterance.
|
114 |
+
|
115 |
+
# TODO: handle multiple wavs to benefit from batching on GPU
|
116 |
+
:param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
|
117 |
+
:param using_partials: if True, then the utterance is split in partial utterances of
|
118 |
+
<partial_utterance_n_frames> frames and the utterance embedding is computed from their
|
119 |
+
normalized average. If False, the utterance is instead computed from feeding the entire
|
120 |
+
spectogram to the network.
|
121 |
+
:param return_partials: if True, the partial embeddings will also be returned along with the
|
122 |
+
wav slices that correspond to the partial embeddings.
|
123 |
+
:param kwargs: additional arguments to compute_partial_splits()
|
124 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
125 |
+
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
126 |
+
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
127 |
+
returned. If <using_partials> is simultaneously set to False, both these values will be None
|
128 |
+
instead.
|
129 |
+
"""
|
130 |
+
# Process the entire utterance if not using partials
|
131 |
+
if not using_partials:
|
132 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
133 |
+
embed = embed_frames_batch(frames[None, ...])[0]
|
134 |
+
if return_partials:
|
135 |
+
return embed, None, None
|
136 |
+
return embed
|
137 |
+
|
138 |
+
# Compute where to split the utterance into partials and pad if necessary
|
139 |
+
wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
|
140 |
+
max_wave_length = wave_slices[-1].stop
|
141 |
+
if max_wave_length >= len(wav):
|
142 |
+
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
143 |
+
|
144 |
+
# Split the utterance into partials
|
145 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
146 |
+
frames_batch = np.array([frames[s] for s in mel_slices])
|
147 |
+
partial_embeds = embed_frames_batch(frames_batch)
|
148 |
+
|
149 |
+
# Compute the utterance embedding from the partial embeddings
|
150 |
+
raw_embed = np.mean(partial_embeds, axis=0)
|
151 |
+
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
152 |
+
|
153 |
+
if return_partials:
|
154 |
+
return embed, partial_embeds, wave_slices
|
155 |
+
return embed
|
156 |
+
|
157 |
+
|
158 |
+
def embed_speaker(wavs, **kwargs):
|
159 |
+
raise NotImplemented()
|
160 |
+
|
161 |
+
|
162 |
+
def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
|
163 |
+
if ax is None:
|
164 |
+
ax = plt.gca()
|
165 |
+
|
166 |
+
if shape is None:
|
167 |
+
height = int(np.sqrt(len(embed)))
|
168 |
+
shape = (height, -1)
|
169 |
+
embed = embed.reshape(shape)
|
170 |
+
|
171 |
+
cmap = cm.get_cmap()
|
172 |
+
mappable = ax.imshow(embed, cmap=cmap)
|
173 |
+
cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
|
174 |
+
cbar.set_clim(*color_range)
|
175 |
+
|
176 |
+
ax.set_xticks([]), ax.set_yticks([])
|
177 |
+
ax.set_title(title)
|
speaker_encoder/model.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.params_model import *
|
2 |
+
from speaker_encoder.params_data import *
|
3 |
+
from scipy.interpolate import interp1d
|
4 |
+
from sklearn.metrics import roc_curve
|
5 |
+
from torch.nn.utils import clip_grad_norm_
|
6 |
+
from scipy.optimize import brentq
|
7 |
+
from torch import nn
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class SpeakerEncoder(nn.Module):
|
13 |
+
def __init__(self, device, loss_device):
|
14 |
+
super().__init__()
|
15 |
+
self.loss_device = loss_device
|
16 |
+
|
17 |
+
# Network defition
|
18 |
+
self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
|
19 |
+
hidden_size=model_hidden_size, # 256
|
20 |
+
num_layers=model_num_layers, # 3
|
21 |
+
batch_first=True).to(device)
|
22 |
+
self.linear = nn.Linear(in_features=model_hidden_size,
|
23 |
+
out_features=model_embedding_size).to(device)
|
24 |
+
self.relu = torch.nn.ReLU().to(device)
|
25 |
+
|
26 |
+
# Cosine similarity scaling (with fixed initial parameter values)
|
27 |
+
self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
|
28 |
+
self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
|
29 |
+
|
30 |
+
# Loss
|
31 |
+
self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
|
32 |
+
|
33 |
+
def do_gradient_ops(self):
|
34 |
+
# Gradient scale
|
35 |
+
self.similarity_weight.grad *= 0.01
|
36 |
+
self.similarity_bias.grad *= 0.01
|
37 |
+
|
38 |
+
# Gradient clipping
|
39 |
+
clip_grad_norm_(self.parameters(), 3, norm_type=2)
|
40 |
+
|
41 |
+
def forward(self, utterances, hidden_init=None):
|
42 |
+
"""
|
43 |
+
Computes the embeddings of a batch of utterance spectrograms.
|
44 |
+
|
45 |
+
:param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
|
46 |
+
(batch_size, n_frames, n_channels)
|
47 |
+
:param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
|
48 |
+
batch_size, hidden_size). Will default to a tensor of zeros if None.
|
49 |
+
:return: the embeddings as a tensor of shape (batch_size, embedding_size)
|
50 |
+
"""
|
51 |
+
# Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
|
52 |
+
# and the final cell state.
|
53 |
+
out, (hidden, cell) = self.lstm(utterances, hidden_init)
|
54 |
+
|
55 |
+
# We take only the hidden state of the last layer
|
56 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
57 |
+
|
58 |
+
# L2-normalize it
|
59 |
+
embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
60 |
+
|
61 |
+
return embeds
|
62 |
+
|
63 |
+
def similarity_matrix(self, embeds):
|
64 |
+
"""
|
65 |
+
Computes the similarity matrix according the section 2.1 of GE2E.
|
66 |
+
|
67 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
68 |
+
utterances_per_speaker, embedding_size)
|
69 |
+
:return: the similarity matrix as a tensor of shape (speakers_per_batch,
|
70 |
+
utterances_per_speaker, speakers_per_batch)
|
71 |
+
"""
|
72 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
73 |
+
|
74 |
+
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
|
75 |
+
centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
|
76 |
+
centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
|
77 |
+
|
78 |
+
# Exclusive centroids (1 per utterance)
|
79 |
+
centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
|
80 |
+
centroids_excl /= (utterances_per_speaker - 1)
|
81 |
+
centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
|
82 |
+
|
83 |
+
# Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
|
84 |
+
# product of these vectors (which is just an element-wise multiplication reduced by a sum).
|
85 |
+
# We vectorize the computation for efficiency.
|
86 |
+
sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
|
87 |
+
speakers_per_batch).to(self.loss_device)
|
88 |
+
mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
|
89 |
+
for j in range(speakers_per_batch):
|
90 |
+
mask = np.where(mask_matrix[j])[0]
|
91 |
+
sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
|
92 |
+
sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
|
93 |
+
|
94 |
+
## Even more vectorized version (slower maybe because of transpose)
|
95 |
+
# sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
|
96 |
+
# ).to(self.loss_device)
|
97 |
+
# eye = np.eye(speakers_per_batch, dtype=np.int)
|
98 |
+
# mask = np.where(1 - eye)
|
99 |
+
# sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
|
100 |
+
# mask = np.where(eye)
|
101 |
+
# sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
|
102 |
+
# sim_matrix2 = sim_matrix2.transpose(1, 2)
|
103 |
+
|
104 |
+
sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
|
105 |
+
return sim_matrix
|
106 |
+
|
107 |
+
def loss(self, embeds):
|
108 |
+
"""
|
109 |
+
Computes the softmax loss according the section 2.1 of GE2E.
|
110 |
+
|
111 |
+
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
112 |
+
utterances_per_speaker, embedding_size)
|
113 |
+
:return: the loss and the EER for this batch of embeddings.
|
114 |
+
"""
|
115 |
+
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
116 |
+
|
117 |
+
# Loss
|
118 |
+
sim_matrix = self.similarity_matrix(embeds)
|
119 |
+
sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
|
120 |
+
speakers_per_batch))
|
121 |
+
ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
|
122 |
+
target = torch.from_numpy(ground_truth).long().to(self.loss_device)
|
123 |
+
loss = self.loss_fn(sim_matrix, target)
|
124 |
+
|
125 |
+
# EER (not backpropagated)
|
126 |
+
with torch.no_grad():
|
127 |
+
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
|
128 |
+
labels = np.array([inv_argmax(i) for i in ground_truth])
|
129 |
+
preds = sim_matrix.detach().cpu().numpy()
|
130 |
+
|
131 |
+
# Snippet from https://yangcha.github.io/EER-ROC/
|
132 |
+
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
|
133 |
+
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
|
134 |
+
|
135 |
+
return loss, eer
|
speaker_encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Mel-filterbank
|
3 |
+
mel_window_length = 25 # In milliseconds
|
4 |
+
mel_window_step = 10 # In milliseconds
|
5 |
+
mel_n_channels = 40
|
6 |
+
|
7 |
+
|
8 |
+
## Audio
|
9 |
+
sampling_rate = 16000
|
10 |
+
# Number of spectrogram frames in a partial utterance
|
11 |
+
partials_n_frames = 160 # 1600 ms
|
12 |
+
# Number of spectrogram frames at inference
|
13 |
+
inference_n_frames = 80 # 800 ms
|
14 |
+
|
15 |
+
|
16 |
+
## Voice Activation Detection
|
17 |
+
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
18 |
+
# This sets the granularity of the VAD. Should not need to be changed.
|
19 |
+
vad_window_length = 30 # In milliseconds
|
20 |
+
# Number of frames to average together when performing the moving average smoothing.
|
21 |
+
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
22 |
+
vad_moving_average_width = 8
|
23 |
+
# Maximum number of consecutive silent frames a segment can have.
|
24 |
+
vad_max_silence_length = 6
|
25 |
+
|
26 |
+
|
27 |
+
## Audio volume normalization
|
28 |
+
audio_norm_target_dBFS = -30
|
29 |
+
|
speaker_encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
## Model parameters
|
3 |
+
model_hidden_size = 256
|
4 |
+
model_embedding_size = 256
|
5 |
+
model_num_layers = 3
|
6 |
+
|
7 |
+
|
8 |
+
## Training parameters
|
9 |
+
learning_rate_init = 1e-4
|
10 |
+
speakers_per_batch = 64
|
11 |
+
utterances_per_speaker = 10
|
speaker_encoder/preprocess.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocess.pool import ThreadPool
|
2 |
+
from speaker_encoder.params_data import *
|
3 |
+
from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
|
4 |
+
from datetime import datetime
|
5 |
+
from speaker_encoder import audio
|
6 |
+
from pathlib import Path
|
7 |
+
from tqdm import tqdm
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class DatasetLog:
|
12 |
+
"""
|
13 |
+
Registers metadata about the dataset in a text file.
|
14 |
+
"""
|
15 |
+
def __init__(self, root, name):
|
16 |
+
self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
|
17 |
+
self.sample_data = dict()
|
18 |
+
|
19 |
+
start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
20 |
+
self.write_line("Creating dataset %s on %s" % (name, start_time))
|
21 |
+
self.write_line("-----")
|
22 |
+
self._log_params()
|
23 |
+
|
24 |
+
def _log_params(self):
|
25 |
+
from speaker_encoder import params_data
|
26 |
+
self.write_line("Parameter values:")
|
27 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
28 |
+
value = getattr(params_data, param_name)
|
29 |
+
self.write_line("\t%s: %s" % (param_name, value))
|
30 |
+
self.write_line("-----")
|
31 |
+
|
32 |
+
def write_line(self, line):
|
33 |
+
self.text_file.write("%s\n" % line)
|
34 |
+
|
35 |
+
def add_sample(self, **kwargs):
|
36 |
+
for param_name, value in kwargs.items():
|
37 |
+
if not param_name in self.sample_data:
|
38 |
+
self.sample_data[param_name] = []
|
39 |
+
self.sample_data[param_name].append(value)
|
40 |
+
|
41 |
+
def finalize(self):
|
42 |
+
self.write_line("Statistics:")
|
43 |
+
for param_name, values in self.sample_data.items():
|
44 |
+
self.write_line("\t%s:" % param_name)
|
45 |
+
self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
|
46 |
+
self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
|
47 |
+
self.write_line("-----")
|
48 |
+
end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
|
49 |
+
self.write_line("Finished on %s" % end_time)
|
50 |
+
self.text_file.close()
|
51 |
+
|
52 |
+
|
53 |
+
def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
|
54 |
+
dataset_root = datasets_root.joinpath(dataset_name)
|
55 |
+
if not dataset_root.exists():
|
56 |
+
print("Couldn\'t find %s, skipping this dataset." % dataset_root)
|
57 |
+
return None, None
|
58 |
+
return dataset_root, DatasetLog(out_dir, dataset_name)
|
59 |
+
|
60 |
+
|
61 |
+
def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
|
62 |
+
skip_existing, logger):
|
63 |
+
print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
|
64 |
+
|
65 |
+
# Function to preprocess utterances for one speaker
|
66 |
+
def preprocess_speaker(speaker_dir: Path):
|
67 |
+
# Give a name to the speaker that includes its dataset
|
68 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
69 |
+
|
70 |
+
# Create an output directory with that name, as well as a txt file containing a
|
71 |
+
# reference to each source file.
|
72 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
73 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
74 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
75 |
+
|
76 |
+
# There's a possibility that the preprocessing was interrupted earlier, check if
|
77 |
+
# there already is a sources file.
|
78 |
+
if sources_fpath.exists():
|
79 |
+
try:
|
80 |
+
with sources_fpath.open("r") as sources_file:
|
81 |
+
existing_fnames = {line.split(",")[0] for line in sources_file}
|
82 |
+
except:
|
83 |
+
existing_fnames = {}
|
84 |
+
else:
|
85 |
+
existing_fnames = {}
|
86 |
+
|
87 |
+
# Gather all audio files for that speaker recursively
|
88 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
89 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
90 |
+
# Check if the target output file already exists
|
91 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
92 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
93 |
+
if skip_existing and out_fname in existing_fnames:
|
94 |
+
continue
|
95 |
+
|
96 |
+
# Load and preprocess the waveform
|
97 |
+
wav = audio.preprocess_wav(in_fpath)
|
98 |
+
if len(wav) == 0:
|
99 |
+
continue
|
100 |
+
|
101 |
+
# Create the mel spectrogram, discard those that are too short
|
102 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
103 |
+
if len(frames) < partials_n_frames:
|
104 |
+
continue
|
105 |
+
|
106 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
107 |
+
np.save(out_fpath, frames)
|
108 |
+
logger.add_sample(duration=len(wav) / sampling_rate)
|
109 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
110 |
+
|
111 |
+
sources_file.close()
|
112 |
+
|
113 |
+
# Process the utterances for each speaker
|
114 |
+
with ThreadPool(8) as pool:
|
115 |
+
list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
|
116 |
+
unit="speakers"))
|
117 |
+
logger.finalize()
|
118 |
+
print("Done preprocessing %s.\n" % dataset_name)
|
119 |
+
|
120 |
+
|
121 |
+
# Function to preprocess utterances for one speaker
|
122 |
+
def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
|
123 |
+
# Give a name to the speaker that includes its dataset
|
124 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
125 |
+
|
126 |
+
# Create an output directory with that name, as well as a txt file containing a
|
127 |
+
# reference to each source file.
|
128 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
129 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
130 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
131 |
+
|
132 |
+
# There's a possibility that the preprocessing was interrupted earlier, check if
|
133 |
+
# there already is a sources file.
|
134 |
+
# if sources_fpath.exists():
|
135 |
+
# try:
|
136 |
+
# with sources_fpath.open("r") as sources_file:
|
137 |
+
# existing_fnames = {line.split(",")[0] for line in sources_file}
|
138 |
+
# except:
|
139 |
+
# existing_fnames = {}
|
140 |
+
# else:
|
141 |
+
# existing_fnames = {}
|
142 |
+
existing_fnames = {}
|
143 |
+
# Gather all audio files for that speaker recursively
|
144 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
145 |
+
|
146 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
147 |
+
# Check if the target output file already exists
|
148 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
149 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
150 |
+
if skip_existing and out_fname in existing_fnames:
|
151 |
+
continue
|
152 |
+
|
153 |
+
# Load and preprocess the waveform
|
154 |
+
wav = audio.preprocess_wav(in_fpath)
|
155 |
+
if len(wav) == 0:
|
156 |
+
continue
|
157 |
+
|
158 |
+
# Create the mel spectrogram, discard those that are too short
|
159 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
160 |
+
if len(frames) < partials_n_frames:
|
161 |
+
continue
|
162 |
+
|
163 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
164 |
+
np.save(out_fpath, frames)
|
165 |
+
# logger.add_sample(duration=len(wav) / sampling_rate)
|
166 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
167 |
+
|
168 |
+
sources_file.close()
|
169 |
+
return len(wav)
|
170 |
+
|
171 |
+
def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
|
172 |
+
skip_existing, logger):
|
173 |
+
# from multiprocessing import Pool, cpu_count
|
174 |
+
from pathos.multiprocessing import ProcessingPool as Pool
|
175 |
+
# Function to preprocess utterances for one speaker
|
176 |
+
def __preprocess_speaker(speaker_dir: Path):
|
177 |
+
# Give a name to the speaker that includes its dataset
|
178 |
+
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
179 |
+
|
180 |
+
# Create an output directory with that name, as well as a txt file containing a
|
181 |
+
# reference to each source file.
|
182 |
+
speaker_out_dir = out_dir.joinpath(speaker_name)
|
183 |
+
speaker_out_dir.mkdir(exist_ok=True)
|
184 |
+
sources_fpath = speaker_out_dir.joinpath("_sources.txt")
|
185 |
+
|
186 |
+
existing_fnames = {}
|
187 |
+
# Gather all audio files for that speaker recursively
|
188 |
+
sources_file = sources_fpath.open("a" if skip_existing else "w")
|
189 |
+
wav_lens = []
|
190 |
+
for in_fpath in speaker_dir.glob("**/*.%s" % extension):
|
191 |
+
# Check if the target output file already exists
|
192 |
+
out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
|
193 |
+
out_fname = out_fname.replace(".%s" % extension, ".npy")
|
194 |
+
if skip_existing and out_fname in existing_fnames:
|
195 |
+
continue
|
196 |
+
|
197 |
+
# Load and preprocess the waveform
|
198 |
+
wav = audio.preprocess_wav(in_fpath)
|
199 |
+
if len(wav) == 0:
|
200 |
+
continue
|
201 |
+
|
202 |
+
# Create the mel spectrogram, discard those that are too short
|
203 |
+
frames = audio.wav_to_mel_spectrogram(wav)
|
204 |
+
if len(frames) < partials_n_frames:
|
205 |
+
continue
|
206 |
+
|
207 |
+
out_fpath = speaker_out_dir.joinpath(out_fname)
|
208 |
+
np.save(out_fpath, frames)
|
209 |
+
# logger.add_sample(duration=len(wav) / sampling_rate)
|
210 |
+
sources_file.write("%s,%s\n" % (out_fname, in_fpath))
|
211 |
+
wav_lens.append(len(wav))
|
212 |
+
sources_file.close()
|
213 |
+
return wav_lens
|
214 |
+
|
215 |
+
print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
|
216 |
+
# Process the utterances for each speaker
|
217 |
+
# with ThreadPool(8) as pool:
|
218 |
+
# list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
|
219 |
+
# unit="speakers"))
|
220 |
+
pool = Pool(processes=20)
|
221 |
+
for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
|
222 |
+
for wav_len in wav_lens:
|
223 |
+
logger.add_sample(duration=wav_len / sampling_rate)
|
224 |
+
print(f'{i}/{len(speaker_dirs)} \r')
|
225 |
+
|
226 |
+
logger.finalize()
|
227 |
+
print("Done preprocessing %s.\n" % dataset_name)
|
228 |
+
|
229 |
+
|
230 |
+
def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
|
231 |
+
for dataset_name in librispeech_datasets["train"]["other"]:
|
232 |
+
# Initialize the preprocessing
|
233 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
234 |
+
if not dataset_root:
|
235 |
+
return
|
236 |
+
|
237 |
+
# Preprocess all speakers
|
238 |
+
speaker_dirs = list(dataset_root.glob("*"))
|
239 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
|
240 |
+
skip_existing, logger)
|
241 |
+
|
242 |
+
|
243 |
+
def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
|
244 |
+
# Initialize the preprocessing
|
245 |
+
dataset_name = "VoxCeleb1"
|
246 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
247 |
+
if not dataset_root:
|
248 |
+
return
|
249 |
+
|
250 |
+
# Get the contents of the meta file
|
251 |
+
with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
|
252 |
+
metadata = [line.split("\t") for line in metafile][1:]
|
253 |
+
|
254 |
+
# Select the ID and the nationality, filter out non-anglophone speakers
|
255 |
+
nationalities = {line[0]: line[3] for line in metadata}
|
256 |
+
# keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
|
257 |
+
# nationality.lower() in anglophone_nationalites]
|
258 |
+
keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
|
259 |
+
print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
|
260 |
+
(len(keep_speaker_ids), len(nationalities)))
|
261 |
+
|
262 |
+
# Get the speaker directories for anglophone speakers only
|
263 |
+
speaker_dirs = dataset_root.joinpath("wav").glob("*")
|
264 |
+
speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
|
265 |
+
speaker_dir.name in keep_speaker_ids]
|
266 |
+
print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
|
267 |
+
(len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
|
268 |
+
|
269 |
+
# Preprocess all speakers
|
270 |
+
_preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
|
271 |
+
skip_existing, logger)
|
272 |
+
|
273 |
+
|
274 |
+
def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
|
275 |
+
# Initialize the preprocessing
|
276 |
+
dataset_name = "VoxCeleb2"
|
277 |
+
dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
|
278 |
+
if not dataset_root:
|
279 |
+
return
|
280 |
+
|
281 |
+
# Get the speaker directories
|
282 |
+
# Preprocess all speakers
|
283 |
+
speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
|
284 |
+
_preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
|
285 |
+
skip_existing, logger)
|
speaker_encoder/train.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.visualizations import Visualizations
|
2 |
+
from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
|
3 |
+
from speaker_encoder.params_model import *
|
4 |
+
from speaker_encoder.model import SpeakerEncoder
|
5 |
+
from utils.profiler import Profiler
|
6 |
+
from pathlib import Path
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def sync(device: torch.device):
|
10 |
+
# FIXME
|
11 |
+
return
|
12 |
+
# For correct profiling (cuda operations are async)
|
13 |
+
if device.type == "cuda":
|
14 |
+
torch.cuda.synchronize(device)
|
15 |
+
|
16 |
+
def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
|
17 |
+
backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
|
18 |
+
no_visdom: bool):
|
19 |
+
# Create a dataset and a dataloader
|
20 |
+
dataset = SpeakerVerificationDataset(clean_data_root)
|
21 |
+
loader = SpeakerVerificationDataLoader(
|
22 |
+
dataset,
|
23 |
+
speakers_per_batch, # 64
|
24 |
+
utterances_per_speaker, # 10
|
25 |
+
num_workers=8,
|
26 |
+
)
|
27 |
+
|
28 |
+
# Setup the device on which to run the forward pass and the loss. These can be different,
|
29 |
+
# because the forward pass is faster on the GPU whereas the loss is often (depending on your
|
30 |
+
# hyperparameters) faster on the CPU.
|
31 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
# FIXME: currently, the gradient is None if loss_device is cuda
|
33 |
+
loss_device = torch.device("cpu")
|
34 |
+
|
35 |
+
# Create the model and the optimizer
|
36 |
+
model = SpeakerEncoder(device, loss_device)
|
37 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
|
38 |
+
init_step = 1
|
39 |
+
|
40 |
+
# Configure file path for the model
|
41 |
+
state_fpath = models_dir.joinpath(run_id + ".pt")
|
42 |
+
backup_dir = models_dir.joinpath(run_id + "_backups")
|
43 |
+
|
44 |
+
# Load any existing model
|
45 |
+
if not force_restart:
|
46 |
+
if state_fpath.exists():
|
47 |
+
print("Found existing model \"%s\", loading it and resuming training." % run_id)
|
48 |
+
checkpoint = torch.load(state_fpath)
|
49 |
+
init_step = checkpoint["step"]
|
50 |
+
model.load_state_dict(checkpoint["model_state"])
|
51 |
+
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
52 |
+
optimizer.param_groups[0]["lr"] = learning_rate_init
|
53 |
+
else:
|
54 |
+
print("No model \"%s\" found, starting training from scratch." % run_id)
|
55 |
+
else:
|
56 |
+
print("Starting the training from scratch.")
|
57 |
+
model.train()
|
58 |
+
|
59 |
+
# Initialize the visualization environment
|
60 |
+
vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
|
61 |
+
vis.log_dataset(dataset)
|
62 |
+
vis.log_params()
|
63 |
+
device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
|
64 |
+
vis.log_implementation({"Device": device_name})
|
65 |
+
|
66 |
+
# Training loop
|
67 |
+
profiler = Profiler(summarize_every=10, disabled=False)
|
68 |
+
for step, speaker_batch in enumerate(loader, init_step):
|
69 |
+
profiler.tick("Blocking, waiting for batch (threaded)")
|
70 |
+
|
71 |
+
# Forward pass
|
72 |
+
inputs = torch.from_numpy(speaker_batch.data).to(device)
|
73 |
+
sync(device)
|
74 |
+
profiler.tick("Data to %s" % device)
|
75 |
+
embeds = model(inputs)
|
76 |
+
sync(device)
|
77 |
+
profiler.tick("Forward pass")
|
78 |
+
embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
|
79 |
+
loss, eer = model.loss(embeds_loss)
|
80 |
+
sync(loss_device)
|
81 |
+
profiler.tick("Loss")
|
82 |
+
|
83 |
+
# Backward pass
|
84 |
+
model.zero_grad()
|
85 |
+
loss.backward()
|
86 |
+
profiler.tick("Backward pass")
|
87 |
+
model.do_gradient_ops()
|
88 |
+
optimizer.step()
|
89 |
+
profiler.tick("Parameter update")
|
90 |
+
|
91 |
+
# Update visualizations
|
92 |
+
# learning_rate = optimizer.param_groups[0]["lr"]
|
93 |
+
vis.update(loss.item(), eer, step)
|
94 |
+
|
95 |
+
# Draw projections and save them to the backup folder
|
96 |
+
if umap_every != 0 and step % umap_every == 0:
|
97 |
+
print("Drawing and saving projections (step %d)" % step)
|
98 |
+
backup_dir.mkdir(exist_ok=True)
|
99 |
+
projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
|
100 |
+
embeds = embeds.detach().cpu().numpy()
|
101 |
+
vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
|
102 |
+
vis.save()
|
103 |
+
|
104 |
+
# Overwrite the latest version of the model
|
105 |
+
if save_every != 0 and step % save_every == 0:
|
106 |
+
print("Saving the model (step %d)" % step)
|
107 |
+
torch.save({
|
108 |
+
"step": step + 1,
|
109 |
+
"model_state": model.state_dict(),
|
110 |
+
"optimizer_state": optimizer.state_dict(),
|
111 |
+
}, state_fpath)
|
112 |
+
|
113 |
+
# Make a backup
|
114 |
+
if backup_every != 0 and step % backup_every == 0:
|
115 |
+
print("Making a backup (step %d)" % step)
|
116 |
+
backup_dir.mkdir(exist_ok=True)
|
117 |
+
backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
|
118 |
+
torch.save({
|
119 |
+
"step": step + 1,
|
120 |
+
"model_state": model.state_dict(),
|
121 |
+
"optimizer_state": optimizer.state_dict(),
|
122 |
+
}, backup_fpath)
|
123 |
+
|
124 |
+
profiler.tick("Extras (visualizations, saving)")
|
125 |
+
|
speaker_encoder/visualizations.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
|
2 |
+
from datetime import datetime
|
3 |
+
from time import perf_counter as timer
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
# import webbrowser
|
7 |
+
import visdom
|
8 |
+
import umap
|
9 |
+
|
10 |
+
colormap = np.array([
|
11 |
+
[76, 255, 0],
|
12 |
+
[0, 127, 70],
|
13 |
+
[255, 0, 0],
|
14 |
+
[255, 217, 38],
|
15 |
+
[0, 135, 255],
|
16 |
+
[165, 0, 165],
|
17 |
+
[255, 167, 255],
|
18 |
+
[0, 255, 255],
|
19 |
+
[255, 96, 38],
|
20 |
+
[142, 76, 0],
|
21 |
+
[33, 0, 127],
|
22 |
+
[0, 0, 0],
|
23 |
+
[183, 183, 183],
|
24 |
+
], dtype=np.float) / 255
|
25 |
+
|
26 |
+
|
27 |
+
class Visualizations:
|
28 |
+
def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
|
29 |
+
# Tracking data
|
30 |
+
self.last_update_timestamp = timer()
|
31 |
+
self.update_every = update_every
|
32 |
+
self.step_times = []
|
33 |
+
self.losses = []
|
34 |
+
self.eers = []
|
35 |
+
print("Updating the visualizations every %d steps." % update_every)
|
36 |
+
|
37 |
+
# If visdom is disabled TODO: use a better paradigm for that
|
38 |
+
self.disabled = disabled
|
39 |
+
if self.disabled:
|
40 |
+
return
|
41 |
+
|
42 |
+
# Set the environment name
|
43 |
+
now = str(datetime.now().strftime("%d-%m %Hh%M"))
|
44 |
+
if env_name is None:
|
45 |
+
self.env_name = now
|
46 |
+
else:
|
47 |
+
self.env_name = "%s (%s)" % (env_name, now)
|
48 |
+
|
49 |
+
# Connect to visdom and open the corresponding window in the browser
|
50 |
+
try:
|
51 |
+
self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
|
52 |
+
except ConnectionError:
|
53 |
+
raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
|
54 |
+
"start it.")
|
55 |
+
# webbrowser.open("http://localhost:8097/env/" + self.env_name)
|
56 |
+
|
57 |
+
# Create the windows
|
58 |
+
self.loss_win = None
|
59 |
+
self.eer_win = None
|
60 |
+
# self.lr_win = None
|
61 |
+
self.implementation_win = None
|
62 |
+
self.projection_win = None
|
63 |
+
self.implementation_string = ""
|
64 |
+
|
65 |
+
def log_params(self):
|
66 |
+
if self.disabled:
|
67 |
+
return
|
68 |
+
from speaker_encoder import params_data
|
69 |
+
from speaker_encoder import params_model
|
70 |
+
param_string = "<b>Model parameters</b>:<br>"
|
71 |
+
for param_name in (p for p in dir(params_model) if not p.startswith("__")):
|
72 |
+
value = getattr(params_model, param_name)
|
73 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
74 |
+
param_string += "<b>Data parameters</b>:<br>"
|
75 |
+
for param_name in (p for p in dir(params_data) if not p.startswith("__")):
|
76 |
+
value = getattr(params_data, param_name)
|
77 |
+
param_string += "\t%s: %s<br>" % (param_name, value)
|
78 |
+
self.vis.text(param_string, opts={"title": "Parameters"})
|
79 |
+
|
80 |
+
def log_dataset(self, dataset: SpeakerVerificationDataset):
|
81 |
+
if self.disabled:
|
82 |
+
return
|
83 |
+
dataset_string = ""
|
84 |
+
dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
|
85 |
+
dataset_string += "\n" + dataset.get_logs()
|
86 |
+
dataset_string = dataset_string.replace("\n", "<br>")
|
87 |
+
self.vis.text(dataset_string, opts={"title": "Dataset"})
|
88 |
+
|
89 |
+
def log_implementation(self, params):
|
90 |
+
if self.disabled:
|
91 |
+
return
|
92 |
+
implementation_string = ""
|
93 |
+
for param, value in params.items():
|
94 |
+
implementation_string += "<b>%s</b>: %s\n" % (param, value)
|
95 |
+
implementation_string = implementation_string.replace("\n", "<br>")
|
96 |
+
self.implementation_string = implementation_string
|
97 |
+
self.implementation_win = self.vis.text(
|
98 |
+
implementation_string,
|
99 |
+
opts={"title": "Training implementation"}
|
100 |
+
)
|
101 |
+
|
102 |
+
def update(self, loss, eer, step):
|
103 |
+
# Update the tracking data
|
104 |
+
now = timer()
|
105 |
+
self.step_times.append(1000 * (now - self.last_update_timestamp))
|
106 |
+
self.last_update_timestamp = now
|
107 |
+
self.losses.append(loss)
|
108 |
+
self.eers.append(eer)
|
109 |
+
print(".", end="")
|
110 |
+
|
111 |
+
# Update the plots every <update_every> steps
|
112 |
+
if step % self.update_every != 0:
|
113 |
+
return
|
114 |
+
time_string = "Step time: mean: %5dms std: %5dms" % \
|
115 |
+
(int(np.mean(self.step_times)), int(np.std(self.step_times)))
|
116 |
+
print("\nStep %6d Loss: %.4f EER: %.4f %s" %
|
117 |
+
(step, np.mean(self.losses), np.mean(self.eers), time_string))
|
118 |
+
if not self.disabled:
|
119 |
+
self.loss_win = self.vis.line(
|
120 |
+
[np.mean(self.losses)],
|
121 |
+
[step],
|
122 |
+
win=self.loss_win,
|
123 |
+
update="append" if self.loss_win else None,
|
124 |
+
opts=dict(
|
125 |
+
legend=["Avg. loss"],
|
126 |
+
xlabel="Step",
|
127 |
+
ylabel="Loss",
|
128 |
+
title="Loss",
|
129 |
+
)
|
130 |
+
)
|
131 |
+
self.eer_win = self.vis.line(
|
132 |
+
[np.mean(self.eers)],
|
133 |
+
[step],
|
134 |
+
win=self.eer_win,
|
135 |
+
update="append" if self.eer_win else None,
|
136 |
+
opts=dict(
|
137 |
+
legend=["Avg. EER"],
|
138 |
+
xlabel="Step",
|
139 |
+
ylabel="EER",
|
140 |
+
title="Equal error rate"
|
141 |
+
)
|
142 |
+
)
|
143 |
+
if self.implementation_win is not None:
|
144 |
+
self.vis.text(
|
145 |
+
self.implementation_string + ("<b>%s</b>" % time_string),
|
146 |
+
win=self.implementation_win,
|
147 |
+
opts={"title": "Training implementation"},
|
148 |
+
)
|
149 |
+
|
150 |
+
# Reset the tracking
|
151 |
+
self.losses.clear()
|
152 |
+
self.eers.clear()
|
153 |
+
self.step_times.clear()
|
154 |
+
|
155 |
+
def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
|
156 |
+
max_speakers=10):
|
157 |
+
max_speakers = min(max_speakers, len(colormap))
|
158 |
+
embeds = embeds[:max_speakers * utterances_per_speaker]
|
159 |
+
|
160 |
+
n_speakers = len(embeds) // utterances_per_speaker
|
161 |
+
ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
|
162 |
+
colors = [colormap[i] for i in ground_truth]
|
163 |
+
|
164 |
+
reducer = umap.UMAP()
|
165 |
+
projected = reducer.fit_transform(embeds)
|
166 |
+
plt.scatter(projected[:, 0], projected[:, 1], c=colors)
|
167 |
+
plt.gca().set_aspect("equal", "datalim")
|
168 |
+
plt.title("UMAP projection (step %d)" % step)
|
169 |
+
if not self.disabled:
|
170 |
+
self.projection_win = self.vis.matplot(plt, win=self.projection_win)
|
171 |
+
if out_fpath is not None:
|
172 |
+
plt.savefig(out_fpath)
|
173 |
+
plt.clf()
|
174 |
+
|
175 |
+
def save(self):
|
176 |
+
if not self.disabled:
|
177 |
+
self.vis.save([self.env_name])
|
178 |
+
|
speaker_encoder/voice_encoder.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from speaker_encoder.hparams import *
|
2 |
+
from speaker_encoder import audio
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Union, List
|
5 |
+
from torch import nn
|
6 |
+
from time import perf_counter as timer
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class SpeakerEncoder(nn.Module):
|
12 |
+
def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
|
13 |
+
"""
|
14 |
+
:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
|
15 |
+
If None, defaults to cuda if it is available on your machine, otherwise the model will
|
16 |
+
run on cpu. Outputs are always returned on the cpu, as numpy arrays.
|
17 |
+
"""
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
# Define the network
|
21 |
+
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
22 |
+
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
23 |
+
self.relu = nn.ReLU()
|
24 |
+
|
25 |
+
# Get the target device
|
26 |
+
if device is None:
|
27 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
28 |
+
elif isinstance(device, str):
|
29 |
+
device = torch.device(device)
|
30 |
+
self.device = device
|
31 |
+
|
32 |
+
# Load the pretrained model'speaker weights
|
33 |
+
# weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
|
34 |
+
# if not weights_fpath.exists():
|
35 |
+
# raise Exception("Couldn't find the voice encoder pretrained model at %s." %
|
36 |
+
# weights_fpath)
|
37 |
+
|
38 |
+
start = timer()
|
39 |
+
checkpoint = torch.load(weights_fpath, map_location="cpu")
|
40 |
+
|
41 |
+
self.load_state_dict(checkpoint["model_state"], strict=False)
|
42 |
+
self.to(device)
|
43 |
+
|
44 |
+
if verbose:
|
45 |
+
print("Loaded the voice encoder model on %s in %.2f seconds." %
|
46 |
+
(device.type, timer() - start))
|
47 |
+
|
48 |
+
def forward(self, mels: torch.FloatTensor):
|
49 |
+
"""
|
50 |
+
Computes the embeddings of a batch of utterance spectrograms.
|
51 |
+
:param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
|
52 |
+
(batch_size, n_frames, n_channels)
|
53 |
+
:return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
|
54 |
+
Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
|
55 |
+
"""
|
56 |
+
# Pass the input through the LSTM layers and retrieve the final hidden state of the last
|
57 |
+
# layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
|
58 |
+
_, (hidden, _) = self.lstm(mels)
|
59 |
+
embeds_raw = self.relu(self.linear(hidden[-1]))
|
60 |
+
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
61 |
+
|
62 |
+
@staticmethod
|
63 |
+
def compute_partial_slices(n_samples: int, rate, min_coverage):
|
64 |
+
"""
|
65 |
+
Computes where to split an utterance waveform and its corresponding mel spectrogram to
|
66 |
+
obtain partial utterances of <partials_n_frames> each. Both the waveform and the
|
67 |
+
mel spectrogram slices are returned, so as to make each partial utterance waveform
|
68 |
+
correspond to its spectrogram.
|
69 |
+
|
70 |
+
The returned ranges may be indexing further than the length of the waveform. It is
|
71 |
+
recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
|
72 |
+
|
73 |
+
:param n_samples: the number of samples in the waveform
|
74 |
+
:param rate: how many partial utterances should occur per second. Partial utterances must
|
75 |
+
cover the span of the entire utterance, thus the rate should not be lower than the inverse
|
76 |
+
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
|
77 |
+
the minimum rate is thus 0.625.
|
78 |
+
:param min_coverage: when reaching the last partial utterance, it may or may not have
|
79 |
+
enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
|
80 |
+
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
|
81 |
+
it will be discarded. If there aren't enough frames for one partial utterance,
|
82 |
+
this parameter is ignored so that the function always returns at least one slice.
|
83 |
+
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
84 |
+
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
85 |
+
utterances.
|
86 |
+
"""
|
87 |
+
assert 0 < min_coverage <= 1
|
88 |
+
|
89 |
+
# Compute how many frames separate two partial utterances
|
90 |
+
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
91 |
+
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
92 |
+
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
93 |
+
assert 0 < frame_step, "The rate is too high"
|
94 |
+
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
95 |
+
(sampling_rate / (samples_per_frame * partials_n_frames))
|
96 |
+
|
97 |
+
# Compute the slices
|
98 |
+
wav_slices, mel_slices = [], []
|
99 |
+
steps = max(1, n_frames - partials_n_frames + frame_step + 1)
|
100 |
+
for i in range(0, steps, frame_step):
|
101 |
+
mel_range = np.array([i, i + partials_n_frames])
|
102 |
+
wav_range = mel_range * samples_per_frame
|
103 |
+
mel_slices.append(slice(*mel_range))
|
104 |
+
wav_slices.append(slice(*wav_range))
|
105 |
+
|
106 |
+
# Evaluate whether extra padding is warranted or not
|
107 |
+
last_wav_range = wav_slices[-1]
|
108 |
+
coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
|
109 |
+
if coverage < min_coverage and len(mel_slices) > 1:
|
110 |
+
mel_slices = mel_slices[:-1]
|
111 |
+
wav_slices = wav_slices[:-1]
|
112 |
+
|
113 |
+
return wav_slices, mel_slices
|
114 |
+
|
115 |
+
def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
|
116 |
+
"""
|
117 |
+
Computes an embedding for a single utterance. The utterance is divided in partial
|
118 |
+
utterances and an embedding is computed for each. The complete utterance embedding is the
|
119 |
+
L2-normed average embedding of the partial utterances.
|
120 |
+
|
121 |
+
TODO: independent batched version of this function
|
122 |
+
|
123 |
+
:param wav: a preprocessed utterance waveform as a numpy array of float32
|
124 |
+
:param return_partials: if True, the partial embeddings will also be returned along with
|
125 |
+
the wav slices corresponding to each partial utterance.
|
126 |
+
:param rate: how many partial utterances should occur per second. Partial utterances must
|
127 |
+
cover the span of the entire utterance, thus the rate should not be lower than the inverse
|
128 |
+
of the duration of a partial utterance. By default, partial utterances are 1.6s long and
|
129 |
+
the minimum rate is thus 0.625.
|
130 |
+
:param min_coverage: when reaching the last partial utterance, it may or may not have
|
131 |
+
enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
|
132 |
+
then the last partial utterance will be considered by zero-padding the audio. Otherwise,
|
133 |
+
it will be discarded. If there aren't enough frames for one partial utterance,
|
134 |
+
this parameter is ignored so that the function always returns at least one slice.
|
135 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
|
136 |
+
<return_partials> is True, the partial utterances as a numpy array of float32 of shape
|
137 |
+
(n_partials, model_embedding_size) and the wav partials as a list of slices will also be
|
138 |
+
returned.
|
139 |
+
"""
|
140 |
+
# Compute where to split the utterance into partials and pad the waveform with zeros if
|
141 |
+
# the partial utterances cover a larger range.
|
142 |
+
wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
|
143 |
+
max_wave_length = wav_slices[-1].stop
|
144 |
+
if max_wave_length >= len(wav):
|
145 |
+
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
146 |
+
|
147 |
+
# Split the utterance into partials and forward them through the model
|
148 |
+
mel = audio.wav_to_mel_spectrogram(wav)
|
149 |
+
mels = np.array([mel[s] for s in mel_slices])
|
150 |
+
with torch.no_grad():
|
151 |
+
mels = torch.from_numpy(mels).to(self.device)
|
152 |
+
partial_embeds = self(mels).cpu().numpy()
|
153 |
+
|
154 |
+
# Compute the utterance embedding from the partial embeddings
|
155 |
+
raw_embed = np.mean(partial_embeds, axis=0)
|
156 |
+
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
157 |
+
|
158 |
+
if return_partials:
|
159 |
+
return embed, partial_embeds, wav_slices
|
160 |
+
return embed
|
161 |
+
|
162 |
+
def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
|
163 |
+
"""
|
164 |
+
Compute the embedding of a collection of wavs (presumably from the same speaker) by
|
165 |
+
averaging their embedding and L2-normalizing it.
|
166 |
+
|
167 |
+
:param wavs: list of wavs a numpy arrays of float32.
|
168 |
+
:param kwargs: extra arguments to embed_utterance()
|
169 |
+
:return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
|
170 |
+
"""
|
171 |
+
raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
|
172 |
+
for wav in wavs], axis=0)
|
173 |
+
return raw_embed / np.linalg.norm(raw_embed, 2)
|