Spaces:
Build error
Build error
File size: 9,199 Bytes
222619b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import glob
import importlib
import os
from resemblyzer import VoiceEncoder
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
import utils
from tasks.base_task import BaseDataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
from tqdm import tqdm
class EndlessDistributedSampler(DistributedSampler):
    """Distributed sampler that pre-builds a long, repeated stream of indices.

    The dataset indices are repeated ``num_repeats`` times (one fresh
    permutation per repetition when ``shuffle`` is true), truncated to a
    multiple of ``num_replicas`` and then strided so that each rank receives
    every ``num_replicas``-th index starting at ``rank``.

    ``DistributedSampler.__init__`` is not invoked; only the attributes
    assigned in ``__init__`` below are set on the instance.

    Args:
        dataset: Any object supporting ``len()``.
        num_replicas: Number of participating processes. Defaults to
            ``dist.get_world_size()``.
        rank: Rank of the current process. Defaults to ``dist.get_rank()``.
        shuffle: If true, every repetition is a random permutation drawn from
            a generator seeded with ``self.epoch`` (0 at construction time).
        num_repeats: How many times the dataset is repeated in the stream.

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but the
            distributed package is not available.
        ValueError: If ``num_repeats`` is smaller than 1.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats=1000):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if num_repeats < 1:
            raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        self.num_repeats = num_repeats

        num_items = len(self.dataset)
        if self.shuffle:
            # One generator shared by all repetitions: deterministic overall,
            # but each repetition gets a different permutation.
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = []
            for _ in range(num_repeats):
                indices.extend(torch.randperm(num_items, generator=g).tolist())
        else:
            indices = list(range(num_items)) * num_repeats

        # Drop the tail so every rank ends up with the same number of indices,
        # then take this rank's strided share.
        usable = len(indices) // self.num_replicas * self.num_replicas
        self.indices = indices[self.rank:usable:self.num_replicas]

    def __iter__(self):
        """Iterate over the pre-computed indices of this rank."""
        return iter(self.indices)

    def __len__(self):
        """Return the number of pre-computed indices of this rank."""
        return len(self.indices)
class VocoderDataset(BaseDataset):
    """Dataset producing mel / waveform (and optional pitch, f0) items for a vocoder.

    Items come from one of three places, chosen in ``__init__``:

    * ``hparams['test_input_dir']`` (audio files) when ``prefix == 'test'``
      and that setting is non-empty,
    * ``hparams['test_mel_dir']`` (``.npy`` mels) when ``prefix == 'test'``
      and that setting is non-empty,
    * otherwise the indexed dataset ``{binary_data_dir}/{prefix}``, opened
      lazily on first item access.
    """

    def __init__(self, prefix, shuffle=False):
        """Set up the dataset for split ``prefix``.

        Args:
            prefix: Split name; ``'test'`` switches the dataset to inference mode.
            shuffle: Forwarded to ``BaseDataset.__init__``.
        """
        super().__init__(shuffle)
        self.hparams = hparams
        self.prefix = prefix
        self.data_dir = hparams['binary_data_dir']
        self.is_infer = prefix == 'test'
        # 0 means "do not crop to a fixed length" (see ``collater``).
        self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size']
        self.aux_context_window = hparams['aux_context_window']
        self.hop_size = hparams['hop_size']
        if self.is_infer and hparams['test_input_dir'] != '':
            self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
            self.avail_idxs = list(range(len(self.sizes)))
        elif self.is_infer and hparams['test_mel_dir'] != '':
            self.indexed_ds, self.sizes = self.load_mel_inputs(hparams['test_mel_dir'])
            self.avail_idxs = list(range(len(self.sizes)))
        else:
            # Opened lazily in ``_get_item``.
            self.indexed_ds = None
            all_sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
            # Keep only items that are longer than the crop length once the
            # context frames on both sides are excluded.
            self.avail_idxs = [
                idx for idx, size in enumerate(all_sizes)
                if size - 2 * self.aux_context_window > self.batch_max_frames
            ]
            print(f"| {len(all_sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.")
            self.sizes = [all_sizes[idx] for idx in self.avail_idxs]

    def _get_item(self, index):
        """Return the raw item stored at ``index``, opening the indexed dataset if needed."""
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Return the sample at position ``index`` among the available items.

        ``index`` is first mapped through ``self.avail_idxs``; the returned
        ``"id"`` is that mapped index.
        """
        index = self.avail_idxs[index]
        item = self._get_item(index)
        sample = {
            "id": index,
            "item_name": item['item_name'],
            "mel": torch.FloatTensor(item['mel']),
            "wav": torch.FloatTensor(item['wav'].astype(np.float32)),
        }
        if 'pitch' in item:
            sample['pitch'] = torch.LongTensor(item['pitch'])
            sample['f0'] = torch.FloatTensor(item['f0'])
        if hparams.get('use_spk_embed', False):
            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
        if hparams.get('use_emo_embed', False):
            sample["emo_embed"] = torch.Tensor(item['emo_embed'])
        return sample

    def collater(self, batch):
        """Crop every sample to a random window and stack the results.

        Samples whose mel is not longer than ``batch_max_frames`` plus the
        context frames on both sides are dropped (with a printed notice).

        Args:
            batch: List of samples as produced by ``__getitem__``.

        Returns:
            ``{}`` when ``batch`` is empty or every sample was dropped;
            otherwise a dict with keys ``'z'``, ``'mels'``, ``'wavs'``,
            ``'pitches'``, ``'f0'`` and ``'item_name'``. When
            ``hparams['use_wav']`` is false, ``'z'`` and ``'wavs'`` are empty
            lists. ``'pitches'`` and ``'f0'`` are ``None`` when the first
            sample has no ``'pitch'`` entry.
        """
        if len(batch) == 0:
            return {}
        use_wav = self.hparams['use_wav']
        have_pitch = 'pitch' in batch[0]
        context = self.aux_context_window
        y_batch, c_batch, p_batch, f0_batch = [], [], [], []
        item_name = []
        for sample in batch:
            x = sample['wav'] if use_wav else None
            c = sample['mel'].squeeze(0)
            if have_pitch:
                p = sample['pitch']
                f0 = sample['f0']
            if use_wav:
                self._assert_ready_for_upsampling(x, c, self.hop_size, 0)

            if len(c) - 2 * context <= self.batch_max_frames:
                # ``x`` is None without waveforms, so report the mel length then.
                dropped_len = len(x) if use_wav else len(c)
                print(f"Removed short sample from batch (length={dropped_len}).")
                continue

            # Randomly pick a window of ``batch_max_frames`` frames; with a
            # crop length of 0 the window is everything between the context
            # frames minus one frame.
            batch_max_frames = (
                self.batch_max_frames if self.batch_max_frames != 0
                else len(c) - 2 * context - 1
            )
            batch_max_steps = batch_max_frames * self.hop_size
            interval_start = context
            interval_end = len(c) - batch_max_frames - context
            start_frame = np.random.randint(interval_start, interval_end)
            start_step = start_frame * self.hop_size
            if use_wav:
                y = x[start_step: start_step + batch_max_steps]
            # Frame-level features keep ``context`` extra frames on both sides.
            frames = slice(start_frame - context, start_frame + context + batch_max_frames)
            c = c[frames]
            if have_pitch:
                p = p[frames]
                f0 = f0[frames]
            if use_wav:
                self._assert_ready_for_upsampling(y, c, self.hop_size, context)
                y_batch.append(y.reshape(-1, 1))  # [(T, 1), (T, 1), ...]

            # Record the name only for samples that are kept, so that
            # ``item_name`` stays aligned with the stacked tensors.
            item_name.append(sample['item_name'])
            c_batch.append(c)  # [(T' C), (T' C), ...]
            if have_pitch:
                p_batch.append(p)  # [(T' C), (T' C), ...]
                f0_batch.append(f0)  # [(T' C), (T' C), ...]

        if not c_batch:
            # Every sample was too short; behave like an empty batch.
            return {}

        # Convert each list to a tensor, assuming that each item in the batch
        # has the same length.
        if use_wav:
            y_batch = utils.collate_2d(y_batch, 0).transpose(2, 1)  # (B, 1, T)
        c_batch = utils.collate_2d(c_batch, 0).transpose(2, 1)  # (B, C, T')
        if have_pitch:
            p_batch = utils.collate_1d(p_batch, 0)  # (B, T')
            f0_batch = utils.collate_1d(f0_batch, 0)  # (B, T')
        else:
            p_batch, f0_batch = None, None

        # Input noise signal with the same size as the waveform batch.
        z_batch = torch.randn(y_batch.size()) if use_wav else []  # (B, 1, T)
        return {
            'z': z_batch,
            'mels': c_batch,
            'wavs': y_batch,
            'pitches': p_batch,
            'f0': f0_batch,
            'item_name': item_name
        }

    @staticmethod
    def _assert_ready_for_upsampling(x, c, hop_size, context_window):
        """Assert the audio and feature lengths are correctly adjusted for upsampling."""
        assert len(x) == (len(c) - 2 * context_window) * hop_size

    @staticmethod
    def _resolve_binarizer_cls():
        """Import and return the class named by ``hparams['binarizer_cls']``."""
        binarizer_path = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
        pkg, _, cls_name = binarizer_path.rpartition(".")
        return getattr(importlib.import_module(pkg), cls_name)

    @staticmethod
    def _item_name_from_path(path, root_dir):
        """Return ``path`` relative to ``root_dir`` with path separators replaced by ``_``."""
        # ``relpath`` (rather than slicing by ``len(root_dir) + 1``) stays
        # correct when ``root_dir`` carries a trailing slash.
        rel_path = os.path.relpath(path, root_dir)
        return rel_path.replace(os.sep, "_").replace("/", "_")

    def load_test_inputs(self, test_input_dir, spk_id=0):
        """Binarize the audio files found under ``test_input_dir``.

        Args:
            test_input_dir: Directory searched for ``*.wav`` and ``**/*.mp3``.
            spk_id: Unused; kept for interface compatibility.

        Returns:
            ``(items, sizes)`` where ``items`` are the results of the
            binarizer's ``process_item`` and ``sizes`` their ``'len'`` entries.
        """
        # NOTE(review): without ``recursive=True`` the ``**`` pattern matches a
        # single directory level only - confirm this is the intended layout.
        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/**/*.mp3'))
        binarizer_cls = self._resolve_binarizer_cls()
        binarization_args = hparams['binarization_args']
        items = []
        sizes = []
        for wav_fn in inp_wav_paths:
            item_name = self._item_name_from_path(wav_fn, test_input_dir)
            item = binarizer_cls.process_item(item_name, wav_fn, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes

    def load_mel_inputs(self, test_input_dir, spk_id=0):
        """Binarize the ``.npy`` mel files found directly in ``test_input_dir``.

        Args:
            test_input_dir: Directory searched for ``*.npy``.
            spk_id: Unused; kept for interface compatibility.

        Returns:
            ``(items, sizes)`` where ``items`` are the results of the
            binarizer's ``process_mel_item`` and ``sizes`` their ``'len'`` entries.
        """
        inp_mel_paths = sorted(glob.glob(f'{test_input_dir}/*.npy'))
        binarizer_cls = self._resolve_binarizer_cls()
        binarization_args = hparams['binarization_args']
        items = []
        sizes = []
        for mel_fn in inp_mel_paths:
            mel_input = torch.FloatTensor(np.load(mel_fn))
            item_name = self._item_name_from_path(mel_fn, test_input_dir)
            item = binarizer_cls.process_mel_item(item_name, mel_input, None, binarization_args)
            items.append(item)
            sizes.append(item['len'])
        return items, sizes
|