# Source path: GenerSpeech/tasks/vocoder/dataset_utils.py
# Web-page residue kept for provenance: "Rongjiehuang's picture", "update", "222619b"
import glob
import importlib
import os
from resemblyzer import VoiceEncoder
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
import utils
from tasks.base_task import BaseDataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
from tqdm import tqdm
class EndlessDistributedSampler(DistributedSampler):
    """Sampler that serves one replica's share of 1000 precomputed dataset passes.

    The full index list is built once in the constructor: 1000 passes over
    the dataset are concatenated (each pass is a new permutation drawn from a
    generator seeded with ``self.epoch`` == 0 when ``shuffle`` is true, and
    the natural order otherwise), trimmed to a multiple of ``num_replicas``
    and then strided so that replica ``rank`` keeps every
    ``num_replicas``-th entry.

    Args:
        dataset: Any object supporting ``len()``.
        num_replicas: Number of replicas; read from ``torch.distributed``
            when ``None``.
        rank: Index of this replica; read from ``torch.distributed`` when
            ``None``.
        shuffle: Whether each pass is randomly permuted.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be looked up but
            the distributed package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        def _require_distributed():
            # Both fallbacks below need torch.distributed to be usable.
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")

        if num_replicas is None:
            _require_distributed()
            num_replicas = dist.get_world_size()
        if rank is None:
            _require_distributed()
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle

        rng = torch.Generator()
        rng.manual_seed(self.epoch)

        # Concatenate 1000 passes; the generator is advanced once per pass so
        # every shuffled pass is a different permutation.
        repeated = []
        for _ in range(1000):
            if self.shuffle:
                one_pass = torch.randperm(len(self.dataset), generator=rng).tolist()
            else:
                one_pass = range(len(self.dataset))
            repeated.extend(one_pass)

        # Drop the tail so every replica ends up with the same number of indices.
        usable = len(repeated) // self.num_replicas * self.num_replicas
        self.indices = repeated[:usable][self.rank::self.num_replicas]

    def __iter__(self):
        """Iterate over the precomputed indices of this replica."""
        return iter(self.indices)

    def __len__(self):
        """Return how many indices this replica was assigned."""
        return len(self.indices)
class VocoderDataset(BaseDataset):
    """Dataset of mel / waveform items for vocoder training and inference.

    Items are read from the indexed dataset ``{binary_data_dir}/{prefix}``.
    When ``prefix == 'test'`` and ``hparams['test_input_dir']`` or
    ``hparams['test_mel_dir']`` is non-empty, items are instead built up front
    from the files in that directory via the configured binarizer class.
    """

    def __init__(self, prefix, shuffle=False):
        """Read the configuration and decide which items are usable.

        Args:
            prefix: Split name; ``'test'`` switches the dataset to inference
                mode (``batch_max_frames`` becomes 0).
            shuffle: Forwarded to ``BaseDataset.__init__``.
        """
        super().__init__(shuffle)
        self.hparams = hparams
        self.prefix = prefix
        self.data_dir = hparams['binary_data_dir']
        self.is_infer = prefix == 'test'
        self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
        self.aux_context_window = hparams['aux_context_window']
        self.hop_size = hparams['hop_size']
        if self.is_infer and hparams['test_input_dir'] != '':
            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            self.avail_idxs = list(range(len(self.sizes)))
        elif self.is_infer and hparams['test_mel_dir'] != '':
            self.indexed_ds, self.sizes = self.load_mel_inputs(hparams['test_mel_dir'])
            self.avail_idxs = list(range(len(self.sizes)))
        else:
            # The indexed dataset itself is opened lazily in ``_get_item``.
            self.indexed_ds = None
            all_sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
            # Single pass: keep only items long enough to cut a full segment
            # (plus context on both sides) from.
            kept = [(idx, s) for idx, s in enumerate(all_sizes)
                    if s - 2 * self.aux_context_window > self.batch_max_frames]
            self.avail_idxs = [idx for idx, _ in kept]
            print(f"| {len(all_sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
            self.sizes = [s for _, s in kept]

    def _get_item(self, index):
        """Return the raw item at ``index``, opening the indexed dataset on first use."""
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build the tensor sample for position ``index`` among the usable items.

        Returns:
            dict with ``id`` (index into the underlying storage), ``item_name``,
            ``mel`` and ``wav``; plus ``pitch`` / ``f0`` when the item has a
            ``'pitch'`` entry, and ``spk_embed`` / ``emo_embed`` when the
            corresponding ``use_*_embed`` hparam is truthy.
        """
        index = self.avail_idxs[index]
        item = self._get_item(index)
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "mel": torch.FloatTensor(item['mel']),
            "wav": torch.FloatTensor(item['wav'].astype(np.float32)),
        }
        if 'pitch' in item:
            sample['pitch'] = torch.LongTensor(item['pitch'])
            sample['f0'] = torch.FloatTensor(item['f0'])
        if hparams.get('use_spk_embed', False):
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if hparams.get('use_emo_embed', False):
            sample["emo_embed"] = torch.Tensor(item['emo_embed'])
        return sample

    def collater(self, batch):
        """Cut a random segment from every sample and collate the batch.

        Samples whose mel is not longer than ``batch_max_frames`` plus the
        context on both sides are dropped (a message is printed for each).

        Args:
            batch: List of samples as produced by ``__getitem__``.

        Returns:
            ``{}`` when ``batch`` is empty or every sample was dropped,
            otherwise a dict with keys ``z``, ``mels``, ``wavs``, ``pitches``,
            ``f0`` and ``item_name``. ``item_name`` lists only the samples
            that were kept, in batch order. ``z`` and ``wavs`` are empty lists
            when ``hparams['use_wav']`` is falsy; ``pitches`` and ``f0`` are
            ``None`` when the first sample has no ``'pitch'`` entry.
        """
        if len(batch) == 0:
            return {}
        use_wav = self.hparams['use_wav']
        context = self.aux_context_window
        y_batch, c_batch, p_batch, f0_batch = [], [], [], []
        item_name = []
        have_pitch = 'pitch' in batch[0]
        for sample in batch:
            x = sample['wav'] if use_wav else None
            c = sample['mel'].squeeze(0)
            if have_pitch:
                p = sample['pitch']
                f0 = sample['f0']
            if use_wav:
                self._assert_ready_for_upsampling(x, c, self.hop_size, 0)

            if len(c) - 2 * context <= self.batch_max_frames:
                # ``x`` is None without wavs, so report the mel length instead
                # of crashing on ``len(None)``.
                length = len(x) if use_wav else len(c)
                print(f"Removed short sample from batch (length={length}).")
                continue

            # randomly pick up a part of batch_max_steps length; in inference
            # mode (batch_max_frames == 0) all frames but one are taken
            batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 \
                else len(c) - 2 * context - 1
            batch_max_steps = batch_max_frames * self.hop_size
            interval_start = context
            interval_end = len(c) - batch_max_frames - context
            start_frame = np.random.randint(interval_start, interval_end)
            start_step = start_frame * self.hop_size
            # Frame-level features keep ``context`` extra frames on both sides.
            frames = slice(start_frame - context, start_frame + context + batch_max_frames)
            if use_wav:
                y = x[start_step: start_step + batch_max_steps]
            c = c[frames]
            if have_pitch:
                p = p[frames]
                f0 = f0[frames]
            if use_wav:
                self._assert_ready_for_upsampling(y, c, self.hop_size, context)
                y_batch.append(y.reshape(-1, 1))  # [(T, 1), (T, 1), ...]

            # Record the name only for kept samples so names stay aligned
            # with the collated tensors.
            item_name.append(sample['item_name'])
            c_batch.append(c)  # [(T' C), (T' C), ...]
            if have_pitch:
                p_batch.append(p)  # [(T' C), (T' C), ...]
                f0_batch.append(f0)  # [(T' C), (T' C), ...]

        if not c_batch:
            # Every sample was too short: same result as for an empty batch.
            return {}

        # convert each batch to tensor, assume that each item in batch has the same length
        if use_wav:
            y_batch = utils.collate_2d(y_batch, 0).transpose(2, 1)  # (B, 1, T)
        c_batch = utils.collate_2d(c_batch, 0).transpose(2, 1)  # (B, C, T')
        if have_pitch:
            p_batch = utils.collate_1d(p_batch, 0)  # (B, T')
            f0_batch = utils.collate_1d(f0_batch, 0)  # (B, T')
        else:
            p_batch, f0_batch = None, None

        # make input noise signal batch tensor
        z_batch = torch.randn(y_batch.size()) if use_wav else []  # (B, 1, T)
        return {
            'z': z_batch,
            'mels': c_batch,
            'wavs': y_batch,
            'pitches': p_batch,
            'f0': f0_batch,
            'item_name': item_name
        }

    @staticmethod
    def _assert_ready_for_upsampling(x, c, hop_size, context_window):
        """Assert the audio and feature lengths are correctly adjusted for upsampling."""
        assert len(x) == (len(c) - 2 * context_window) * hop_size

    @staticmethod
    def _resolve_binarizer_cls():
        """Import and return the class named by ``hparams['binarizer_cls']``."""
        binarizer_path = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg, _, cls_name = binarizer_path.rpartition(".")
        return getattr(importlib.import_module(pkg), cls_name)

    @staticmethod
    def _item_name_from_path(path, root_dir):
        """Name an item after its path relative to ``root_dir``, with ``/`` replaced by ``_``.

        ``os.path.relpath`` is used rather than slicing off ``len(root_dir) + 1``
        characters, which eats part of the file name when ``root_dir`` ends
        with a path separator.
        """
        return os.path.relpath(path, root_dir).replace("/", "_")

    def load_test_inputs(self, test_input_dir, spk_id=0):
        """Binarize the audio files found under ``test_input_dir``.

        Args:
            test_input_dir: Directory searched for ``*.wav`` and ``**/*.mp3``.
            spk_id: Unused; kept for interface compatibility.

        Returns:
            ``(items, sizes)`` where ``sizes[i]`` is ``items[i]['len']``.
        """
        # NOTE(review): glob is called without ``recursive=True``, so ``**``
        # behaves like ``*`` and mp3 files are only found exactly one
        # directory below ``test_input_dir`` - confirm this is intended.
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/**/*.mp3'))
        binarizer_cls = self._resolve_binarizer_cls()
        binarization_args = hparams['binarization_args']
        items = []
        sizes = []
        for wav_fn in inp_wav_paths:
            item_name = self._item_name_from_path(wav_fn, test_input_dir)
            item = binarizer_cls.process_item(
                item_name, wav_fn, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes

    def load_mel_inputs(self, test_input_dir, spk_id=0):
        """Build items from the ``*.npy`` mel files found in ``test_input_dir``.

        Args:
            test_input_dir: Directory searched for ``*.npy`` files.
            spk_id: Unused; kept for interface compatibility.

        Returns:
            ``(items, sizes)`` where ``sizes[i]`` is ``items[i]['len']``.
        """
        inp_mel_paths = sorted(glob.glob(f'{test_input_dir}/*.npy'))
        binarizer_cls = self._resolve_binarizer_cls()
        binarization_args = hparams['binarization_args']
        items = []
        sizes = []
        for mel in inp_mel_paths:
            mel_input = torch.FloatTensor(np.load(mel))
            item_name = self._item_name_from_path(mel, test_input_dir)
            item = binarizer_cls.process_mel_item(item_name, mel_input, None, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes