# Spaces: Running
# (Page-header text captured along with this file; it is not part of the module.)
# coding: utf-8 | |
import os | |
import os.path as osp | |
import time | |
import random | |
import numpy as np | |
import random | |
import soundfile as sf | |
import librosa | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import torchaudio | |
from torch.utils.data import DataLoader | |
import logging | |
logger = logging.getLogger(__name__) | |
logger.setLevel(logging.DEBUG) | |
import pandas as pd | |
# Symbol inventory: one pad symbol, punctuation, ASCII letters and IPA letters.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Export all symbols:
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)

# Map every symbol to its position in `symbols`. A symbol that appears more
# than once keeps the index of its last occurrence, exactly as an index loop
# that overwrites earlier entries would leave it.
dicts = {symbol: index for index, symbol in enumerate(symbols)}
class TextCleaner:
    """Turn a string into a list of symbol indices, one character at a time."""

    def __init__(self, dummy=None):
        # `dummy` is accepted but never read.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the table index of each character of `text`, in order.

        A character missing from the table contributes nothing to the result;
        the whole `text` is printed once for every such character.
        """
        table = self.word_index_dictionary
        indexes = []
        for char in text:
            try:
                index = table[char]
            except KeyError:
                print(text)
            else:
                indexes.append(index)
        return indexes
# Seed NumPy's and the standard library's global random generators on import.
np.random.seed(1)
random.seed(1)

# Spectrogram settings and mel filterbank size.
SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300}
MEL_PARAMS = {
    "n_mels": 80,
}

# Module-wide mel-spectrogram transform, built from the two dicts above
# (n_mels=80, n_fft=2048, win_length=1200, hop_length=300).
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=MEL_PARAMS["n_mels"], **SPECT_PARAMS
)

# Offset and scale used to standardize log-mel values.
mean = -4
std = 4
def preprocess(wave):
    """Convert a NumPy waveform into a standardized log-mel tensor.

    The waveform is turned into a float tensor and passed through the
    module-level `to_mel` transform. The result gets a leading axis of size 1,
    is log-compressed with a 1e-5 floor, and is then shifted and scaled with
    the module-level `mean` and `std`.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std
class FilePathDataset(torch.utils.data.Dataset):
    """Dataset pairing each utterance with a same-speaker reference and an OOD text.

    Args:
        data_list: iterable of ``"wave_path|text|speaker_id"`` lines. A line
            that does not split into exactly three fields gets the integer
            ``0`` appended as its speaker id.
        root_path: directory that every ``wave_path`` is joined onto.
        sr: stored as ``self.sr``; the resampling target used when loading
            audio is the literal 24000, not this value.
        data_augmentation: stored (and forced to False when ``validation`` is
            true); not read by the methods of this class.
        validation: see ``data_augmentation``.
        OOD_data: path of a ``|``-separated text file. When the first field of
            its first line contains ``".wav"`` the texts are taken from the
            second field, otherwise from the first.
        min_length: minimum number of characters an OOD text must have to be
            returned by ``__getitem__``.
    """

    def __init__(
        self,
        data_list,
        root_path,
        sr=24000,
        data_augmentation=False,
        validation=False,
        OOD_data="Data/OOD_texts.txt",
        min_length=50,
    ):
        self.data_list = self._parse_data_list(data_list)
        self.text_cleaner = TextCleaner()
        self.sr = sr

        # Tabular copy of data_list; column 2 holds the speaker id and is
        # used to look up reference utterances.
        self.df = pd.DataFrame(self.data_list)

        self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
        self.mean, self.std = -4, 4
        self.data_augmentation = data_augmentation and (not validation)
        self.max_mel_length = 192
        self.min_length = min_length

        # The texts contain IPA symbols, so do not rely on the platform's
        # default encoding.
        with open(OOD_data, "r", encoding="utf-8") as f:
            tl = f.readlines()
        idx = 1 if ".wav" in tl[0].split("|")[0] else 0
        self.ptexts = [t.split("|")[idx] for t in tl]

        self.root_path = root_path

    @staticmethod
    def _parse_data_list(data_list):
        """Split raw ``a|b|c`` lines into rows, defaulting the speaker id to 0."""
        rows = []
        for line in data_list:
            # Strip only the line terminator: slicing off the last character
            # unconditionally would eat real data on a line without one.
            fields = line.rstrip("\r\n").split("|")
            rows.append(fields if len(fields) == 3 else (*fields, 0))
        return rows

    def _sample_reference(self, speaker_id):
        """Return one randomly chosen data row whose speaker id is `speaker_id`."""
        # Speaker ids are strings when read from a line but the integer 0 when
        # defaulted, so compare on the string form of the column.
        same_speaker = self.df[self.df[2].astype(str) == str(speaker_id)]
        return same_speaker.sample(n=1).iloc[0].tolist()

    def _sample_ood_text(self):
        """Return a random OOD text of at least `min_length` characters.

        Raises:
            ValueError: if no OOD text is long enough (rejection sampling
                would otherwise never terminate).
        """
        if self.min_length > 0 and not any(
            len(t) >= self.min_length for t in self.ptexts
        ):
            raise ValueError(
                "no OOD text has at least min_length=%d characters" % self.min_length
            )
        ps = ""
        while len(ps) < self.min_length:
            # randint's upper bound is exclusive, so len(...) covers every index.
            rand_idx = np.random.randint(0, len(self.ptexts))
            ps = self.ptexts[rand_idx]
        return ps

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Load item `idx`.

        Returns:
            tuple: ``(speaker_id, acoustic_feature, text_tensor, ref_text,
            ref_mel_tensor, ref_label, path, wave)`` where ``acoustic_feature``
            is the utterance mel trimmed to an even number of frames,
            ``ref_text`` is the encoded OOD text, and ``ref_mel_tensor`` /
            ``ref_label`` come from a random utterance of the same speaker.
        """
        data = self.data_list[idx]
        path = data[0]
        wave, text_tensor, speaker_id = self._load_tensor(data)

        acoustic_feature = preprocess(wave).squeeze()
        length_feature = acoustic_feature.size(1)
        # Drop the last frame when the frame count is odd.
        acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]

        # get reference sample
        ref_data = self._sample_reference(speaker_id)
        ref_mel_tensor, ref_label = self._load_data(ref_data[:3])

        # get OOD text, wrapped in index 0 at both ends
        text = self.text_cleaner(self._sample_ood_text())
        text.insert(0, 0)
        text.append(0)
        ref_text = torch.LongTensor(text)

        return (
            speaker_id,
            acoustic_feature,
            text_tensor,
            ref_text,
            ref_mel_tensor,
            ref_label,
            path,
            wave,
        )

    def _load_tensor(self, data):
        """Read one row's audio and text.

        Returns:
            tuple: ``(wave, text, speaker_id)`` with ``wave`` a 1-D NumPy array
            padded with 5000 zeros on each side, ``text`` a LongTensor of
            symbol indices wrapped in index 0, and ``speaker_id`` an int.
        """
        wave_path, text, speaker_id = data
        speaker_id = int(speaker_id)
        wave, sr = sf.read(osp.join(self.root_path, wave_path))
        # Keep only the first column of a multi-dimensional read (presumably
        # the first channel). Testing ndim rather than `shape[-1] == 2` avoids
        # misreading a two-sample mono clip and also covers >2 channels.
        if wave.ndim > 1:
            wave = wave[:, 0]
        if sr != 24000:
            wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
            print(wave_path, sr)

        wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)

        text = self.text_cleaner(text)
        text.insert(0, 0)
        text.append(0)
        text = torch.LongTensor(text)

        return wave, text, speaker_id

    def _load_data(self, data):
        """Return ``(mel_tensor, speaker_id)`` for a row, cropped to `max_mel_length`.

        A mel longer than `max_mel_length` frames is cut to a randomly placed
        window of exactly that many frames.
        """
        wave, text_tensor, speaker_id = self._load_tensor(data)
        mel_tensor = preprocess(wave).squeeze()

        mel_length = mel_tensor.size(1)
        if mel_length > self.max_mel_length:
            # +1 because randint's upper bound is exclusive and the window may
            # start as late as mel_length - max_mel_length.
            random_start = np.random.randint(
                0, mel_length - self.max_mel_length + 1
            )
            mel_tensor = mel_tensor[
                :, random_start : random_start + self.max_mel_length
            ]

        return mel_tensor, speaker_id
class Collater(object):
    """Collate dataset items into zero-padded batch tensors.

    Args:
        return_wave (bool): stored on the instance; `__call__` does not read it
            and always returns the waves.
    """

    def __init__(self, return_wave=False):
        self.text_pad_index = 0
        self.min_mel_length = 192
        self.max_mel_length = 192
        self.return_wave = return_wave

    def __call__(self, batch):
        """Pad and stack a list of items.

        Each item is the 8-tuple ``(label, mel, text, ref_text, ref_mel,
        ref_label, path, wave)``. Items are reordered by decreasing mel length
        before stacking.

        Returns:
            tuple: ``(waves, texts, input_lengths, ref_texts, ref_lengths,
            mels, output_lengths, ref_mels)``. ``waves`` is a list; the rest
            are tensors whose first axis is the batch. ``ref_mels`` always has
            ``self.max_mel_length`` frames; ``mels`` has as many frames as the
            longest mel in the batch.
        """
        batch_size = len(batch)

        # sort by mel length, longest first
        lengths = [b[1].shape[1] for b in batch]
        batch_indexes = np.argsort(lengths)[::-1]
        batch = [batch[bid] for bid in batch_indexes]

        nmels = batch[0][1].size(0)
        max_mel_length = max(b[1].shape[1] for b in batch)
        max_text_length = max(b[2].shape[0] for b in batch)
        max_rtext_length = max(b[3].shape[0] for b in batch)

        # Zero-filled buffers: index 0 pads the texts, 0.0 pads the mels.
        mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
        texts = torch.zeros((batch_size, max_text_length)).long()
        ref_texts = torch.zeros((batch_size, max_rtext_length)).long()

        input_lengths = torch.zeros(batch_size).long()
        ref_lengths = torch.zeros(batch_size).long()
        output_lengths = torch.zeros(batch_size).long()
        ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
        waves = [None for _ in range(batch_size)]

        # Speaker labels and paths are part of each item but not of the
        # returned batch, so they are unpacked and ignored.
        for bid, (
            _label,
            mel,
            text,
            ref_text,
            ref_mel,
            _ref_label,
            _path,
            wave,
        ) in enumerate(batch):
            mel_size = mel.size(1)
            text_size = text.size(0)
            rtext_size = ref_text.size(0)

            mels[bid, :, :mel_size] = mel
            texts[bid, :text_size] = text
            ref_texts[bid, :rtext_size] = ref_text
            input_lengths[bid] = text_size
            ref_lengths[bid] = rtext_size
            output_lengths[bid] = mel_size

            ref_mel_size = ref_mel.size(1)
            ref_mels[bid, :, :ref_mel_size] = ref_mel

            waves[bid] = wave

        return (
            waves,
            texts,
            input_lengths,
            ref_texts,
            ref_lengths,
            mels,
            output_lengths,
            ref_mels,
        )
def build_dataloader(
    path_list,
    root_path,
    validation=False,
    OOD_data="Data/OOD_texts.txt",
    min_length=50,
    batch_size=4,
    num_workers=1,
    device="cpu",
    collate_config=None,
    dataset_config=None,
):
    """Build a `DataLoader` over a `FilePathDataset` collated by `Collater`.

    Args:
        path_list: data lines passed to `FilePathDataset` as its data list.
        root_path: root directory passed to `FilePathDataset`.
        validation: when true, the loader neither shuffles nor drops the last
            incomplete batch; also forwarded to the dataset.
        OOD_data: forwarded to the dataset.
        min_length: forwarded to the dataset.
        batch_size: forwarded to the loader.
        num_workers: forwarded to the loader.
        device: memory pinning is enabled for any value other than ``"cpu"``.
        collate_config: extra keyword arguments for `Collater`; ``None`` means
            no extra arguments.
        dataset_config: extra keyword arguments for `FilePathDataset`; ``None``
            means no extra arguments.

    Returns:
        The configured `DataLoader`.
    """
    # None sentinels instead of `{}` defaults, so no dict object is shared
    # between calls.
    collate_config = {} if collate_config is None else collate_config
    dataset_config = {} if dataset_config is None else dataset_config

    dataset = FilePathDataset(
        path_list,
        root_path,
        OOD_data=OOD_data,
        min_length=min_length,
        validation=validation,
        **dataset_config
    )
    collate_fn = Collater(**collate_config)
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(not validation),
        num_workers=num_workers,
        drop_last=(not validation),
        collate_fn=collate_fn,
        pin_memory=(device != "cpu"),
    )

    return data_loader