styletts2 / meldataset.py
jonluca's picture
add precomputed voices, reformat code, remove unused code
1373f78 unverified
# coding: utf-8
import os.path as osp
import random
import numpy as np
import random
import soundfile as sf
import librosa
import torch
import torchaudio
from torch.utils.data import DataLoader
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import pandas as pd
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
for i in range(len((symbols))):
dicts[symbols[i]] = i
class TextCleaner:
def __init__(self, dummy=None):
self.word_index_dictionary = dicts
def __call__(self, text):
indexes = []
for char in text:
try:
indexes.append(self.word_index_dictionary[char])
except KeyError:
print(text)
return indexes
np.random.seed(1)
random.seed(1)
SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300}
MEL_PARAMS = {
"n_mels": 80,
}
to_mel = torchaudio.transforms.MelSpectrogram(
n_mels=80, n_fft=2048, win_length=1200, hop_length=300
)
mean, std = -4, 4
def preprocess(wave):
wave_tensor = torch.from_numpy(wave).float()
mel_tensor = to_mel(wave_tensor)
mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
return mel_tensor
class FilePathDataset(torch.utils.data.Dataset):
def __init__(
self,
data_list,
root_path,
sr=24000,
data_augmentation=False,
validation=False,
OOD_data="Data/OOD_texts.txt",
min_length=50,
):
_data_list = [l[:-1].split("|") for l in data_list]
self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
self.text_cleaner = TextCleaner()
self.sr = sr
self.df = pd.DataFrame(self.data_list)
self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
self.mean, self.std = -4, 4
self.data_augmentation = data_augmentation and (not validation)
self.max_mel_length = 192
self.min_length = min_length
with open(OOD_data, "r") as f:
tl = f.readlines()
idx = 1 if ".wav" in tl[0].split("|")[0] else 0
self.ptexts = [t.split("|")[idx] for t in tl]
self.root_path = root_path
def __len__(self):
return len(self.data_list)
def __getitem__(self, idx):
data = self.data_list[idx]
path = data[0]
wave, text_tensor, speaker_id = self._load_tensor(data)
mel_tensor = preprocess(wave).squeeze()
acoustic_feature = mel_tensor.squeeze()
length_feature = acoustic_feature.size(1)
acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]
# get reference sample
ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
# get OOD text
ps = ""
while len(ps) < self.min_length:
rand_idx = np.random.randint(0, len(self.ptexts) - 1)
ps = self.ptexts[rand_idx]
text = self.text_cleaner(ps)
text.insert(0, 0)
text.append(0)
ref_text = torch.LongTensor(text)
return (
speaker_id,
acoustic_feature,
text_tensor,
ref_text,
ref_mel_tensor,
ref_label,
path,
wave,
)
def _load_tensor(self, data):
wave_path, text, speaker_id = data
speaker_id = int(speaker_id)
wave, sr = sf.read(osp.join(self.root_path, wave_path))
if wave.shape[-1] == 2:
wave = wave[:, 0].squeeze()
if sr != 24000:
wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
print(wave_path, sr)
wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
text = self.text_cleaner(text)
text.insert(0, 0)
text.append(0)
text = torch.LongTensor(text)
return wave, text, speaker_id
def _load_data(self, data):
wave, text_tensor, speaker_id = self._load_tensor(data)
mel_tensor = preprocess(wave).squeeze()
mel_length = mel_tensor.size(1)
if mel_length > self.max_mel_length:
random_start = np.random.randint(0, mel_length - self.max_mel_length)
mel_tensor = mel_tensor[
:, random_start : random_start + self.max_mel_length
]
return mel_tensor, speaker_id
class Collater(object):
"""
Args:
adaptive_batch_size (bool): if true, decrease batch size when long data comes.
"""
def __init__(self, return_wave=False):
self.text_pad_index = 0
self.min_mel_length = 192
self.max_mel_length = 192
self.return_wave = return_wave
def __call__(self, batch):
# batch[0] = wave, mel, text, f0, speakerid
batch_size = len(batch)
# sort by mel length
lengths = [b[1].shape[1] for b in batch]
batch_indexes = np.argsort(lengths)[::-1]
batch = [batch[bid] for bid in batch_indexes]
nmels = batch[0][1].size(0)
max_mel_length = max([b[1].shape[1] for b in batch])
max_text_length = max([b[2].shape[0] for b in batch])
max_rtext_length = max([b[3].shape[0] for b in batch])
labels = torch.zeros((batch_size)).long()
mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
texts = torch.zeros((batch_size, max_text_length)).long()
ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
input_lengths = torch.zeros(batch_size).long()
ref_lengths = torch.zeros(batch_size).long()
output_lengths = torch.zeros(batch_size).long()
ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
ref_labels = torch.zeros((batch_size)).long()
paths = ["" for _ in range(batch_size)]
waves = [None for _ in range(batch_size)]
for bid, (
label,
mel,
text,
ref_text,
ref_mel,
ref_label,
path,
wave,
) in enumerate(batch):
mel_size = mel.size(1)
text_size = text.size(0)
rtext_size = ref_text.size(0)
labels[bid] = label
mels[bid, :, :mel_size] = mel
texts[bid, :text_size] = text
ref_texts[bid, :rtext_size] = ref_text
input_lengths[bid] = text_size
ref_lengths[bid] = rtext_size
output_lengths[bid] = mel_size
paths[bid] = path
ref_mel_size = ref_mel.size(1)
ref_mels[bid, :, :ref_mel_size] = ref_mel
ref_labels[bid] = ref_label
waves[bid] = wave
return (
waves,
texts,
input_lengths,
ref_texts,
ref_lengths,
mels,
output_lengths,
ref_mels,
)
def build_dataloader(
path_list,
root_path,
validation=False,
OOD_data="Data/OOD_texts.txt",
min_length=50,
batch_size=4,
num_workers=1,
device="cpu",
collate_config={},
dataset_config={},
):
dataset = FilePathDataset(
path_list,
root_path,
OOD_data=OOD_data,
min_length=min_length,
validation=validation,
**dataset_config
)
collate_fn = Collater(**collate_config)
data_loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=(not validation),
num_workers=num_workers,
drop_last=(not validation),
collate_fn=collate_fn,
pin_memory=(device != "cpu"),
)
return data_loader