|
import functools
|
|
import random
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from TTS.config.shared_configs import BaseDatasetConfig
|
|
from TTS.tts.datasets import load_tts_samples
|
|
from TTS.tts.utils.data import get_length_balancer_weights
|
|
from TTS.tts.utils.languages import get_language_balancer_weights
|
|
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
|
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
|
|
|
|
|
|
# Seed torch's global RNG before any torch-based sampler is created.
torch.manual_seed(0)

# Both configs point at the same LJSpeech fixture directory and metadata file;
# only the language tag differs.
dataset_config_en = BaseDatasetConfig(
    path="tests/data/ljspeech",
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    language="en",
)

dataset_config_pt = BaseDatasetConfig(
    path="tests/data/ljspeech",
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_val="metadata.csv",
    language="pt-br",
)

# The English config is listed twice and the Portuguese one once, presumably so
# that English is over-represented in the loaded samples (verify against
# `load_tts_samples`).
train_samples, eval_samples = load_tts_samples(
    [dataset_config_en, dataset_config_en, dataset_config_pt],
    eval_split=True,
)

# Overwrite the speaker label: the first five training samples become
# "ljspeech-0", every other sample becomes "ljspeech-1".
for i, sample in enumerate(train_samples):
    sample["speaker_name"] = "ljspeech-0" if i < 5 else "ljspeech-1"
|
|
|
|
|
|
def is_balanced(lang_1, lang_2):
    """Tell whether two class counts are close enough to be called balanced.

    The counts are balanced when ``lang_1 / lang_2`` lies strictly between
    0.85 and 1.2.

    Args:
        lang_1: Number of draws that fell into the first class.
        lang_2: Number of draws that fell into the second class.

    Returns:
        bool: ``True`` when the ratio is inside the tolerance band, ``False``
        otherwise. A zero ``lang_2`` is reported as unbalanced.
    """
    # The ratio is undefined when the second class was never drawn; report
    # that as "not balanced" instead of raising ZeroDivisionError.
    if lang_2 == 0:
        return False
    return 0.85 < lang_1 / lang_2 < 1.2
|
|
|
|
|
|
class TestSamplers(unittest.TestCase):
    """Checks for the balancing and batching samplers.

    All tests sample indices into the module-level ``train_samples`` list and
    use ``is_balanced`` to compare how often two groups were drawn.
    """

    @staticmethod
    def _draw(sampler, passes=100):
        """Iterate ``sampler`` several times and concatenate everything it yields.

        Args:
            sampler: A re-iterable sampler yielding indices or batches.
            passes: How many complete passes to make over ``sampler``.

        Returns:
            list: The items of every pass, in the order they were drawn.
        """
        return functools.reduce(lambda acc, chunk: acc + chunk, (list(sampler) for _ in range(passes)), [])

    @staticmethod
    def _split_counts(indices, predicate):
        """Count the sampled items that satisfy ``predicate`` and those that do not.

        Args:
            indices: Iterable of positions into ``train_samples``.
            predicate: Callable receiving one sample and returning a truthy
                value when the sample belongs to the first group.

        Returns:
            tuple: ``(matching, others)``.
        """
        matching, others = 0, 0
        for index in indices:
            if predicate(train_samples[index]):
                matching += 1
            else:
                others += 1
        return matching, others

    @staticmethod
    def _is_first_speaker(sample):
        """Return whether ``sample`` carries the ``"ljspeech-0"`` speaker label."""
        return sample["speaker_name"] == "ljspeech-0"

    def _check_perfect_batches(self, shuffle, drop_last):
        """Assert that every batch of a ``PerfectBatchSampler`` has equal speaker counts.

        Args:
            shuffle: Forwarded to ``PerfectBatchSampler``.
            drop_last: Forwarded to ``PerfectBatchSampler``.
        """
        classes = {item["speaker_name"] for item in train_samples}
        sampler = PerfectBatchSampler(
            train_samples,
            classes,
            batch_size=2 * 3,
            num_classes_in_batch=2,
            label_key="speaker_name",
            shuffle=shuffle,
            drop_last=drop_last,
        )
        # The balance is checked per batch, not over the whole draw.
        for batch in self._draw(sampler):
            spk1, spk2 = self._split_counts(batch, self._is_first_speaker)
            self.assertEqual(spk1, spk2, "PerfectBatchSampler is supposed to be perfectly balanced")

    def test_language_random_sampler(self):
        """A plain random sampler is expected to draw the two languages unevenly."""
        random_sampler = torch.utils.data.RandomSampler(train_samples)
        ids = self._draw(random_sampler)
        en, pt = self._split_counts(ids, lambda sample: sample["language"] == "en")
        self.assertFalse(is_balanced(en, pt), "Random sampler is supposed to be unbalanced")

    def test_language_weighted_random_sampler(self):
        """Language balancer weights must even out the language draw counts."""
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_language_balancer_weights(train_samples), len(train_samples)
        )
        ids = self._draw(weighted_sampler)
        en, pt = self._split_counts(ids, lambda sample: sample["language"] == "en")
        self.assertTrue(is_balanced(en, pt), "Language Weighted sampler is supposed to be balanced")

    def test_speaker_weighted_random_sampler(self):
        """Speaker balancer weights must even out the speaker draw counts."""
        weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            get_speaker_balancer_weights(train_samples), len(train_samples)
        )
        ids = self._draw(weighted_sampler)
        spk1, spk2 = self._split_counts(ids, self._is_first_speaker)
        self.assertTrue(is_balanced(spk1, spk2), "Speaker Weighted sampler is supposed to be balanced")

    def test_perfect_sampler(self):
        """Ordered sampling with incomplete batches dropped stays balanced per batch."""
        self._check_perfect_batches(shuffle=False, drop_last=True)

    def test_perfect_sampler_shuffle(self):
        """Shuffled sampling that keeps the last batch stays balanced per batch."""
        self._check_perfect_batches(shuffle=True, drop_last=False)

    def test_length_weighted_random_sampler(self):
        """Length balancer weights must even out short and long samples.

        Note: this test writes an ``"audio_length"`` entry into every item of
        the shared ``train_samples`` list.
        """
        for _ in range(1000):
            # Build a length-unbalanced dataset with fresh random bounds.
            min_audio = random.randrange(1, 22050)
            max_audio = random.randrange(44100, 220500)
            for idx, item in enumerate(train_samples):
                # Jitter each length so the samples are not all identical.
                random_increase = random.randrange(100, 1000)
                # The first five samples are short, all remaining ones long.
                base_length = min_audio if idx < 5 else max_audio
                item["audio_length"] = base_length + random_increase

            weighted_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                get_length_balancer_weights(train_samples, num_buckets=2), len(train_samples)
            )
            ids = self._draw(weighted_sampler)
            # Short samples stay below `max_audio`; long ones are at least `max_audio + 100`.
            len1, len2 = self._split_counts(ids, lambda sample, limit=max_audio: sample["audio_length"] < limit)
            self.assertTrue(is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced")

    def test_bucket_batch_sampler(self):
        """Check text-length ordering of bucketed batches and the sampler length."""
        batch_size = 7
        bucket_size_multiplier = 2
        sampler = BucketBatchSampler(
            range(len(train_samples)),
            data=train_samples,
            batch_size=batch_size,
            drop_last=True,
            sort_key=lambda x: len(x["text"]),
            bucket_size_multiplier=bucket_size_multiplier,
        )

        # Collected indices must be in non-decreasing text-length order.
        # NOTE(review): the batch that triggers the check (every
        # `bucket_size_multiplier`-th one) is never appended to `bucket_items`,
        # so only the batches collected before it are verified.
        min_text_len_in_bucket = 0
        bucket_items = []
        for batch_idx, batch in enumerate(sampler):
            if (batch_idx + 1) % bucket_size_multiplier == 0:
                for bucket_item in bucket_items:
                    text_len = len(train_samples[bucket_item]["text"])
                    self.assertLessEqual(min_text_len_in_bucket, text_len)
                    min_text_len_in_bucket = text_len
                min_text_len_in_bucket = 0
                bucket_items = []
            else:
                bucket_items += batch

        # With `drop_last=True` only full batches are counted.
        self.assertEqual(len(sampler), len(train_samples) // batch_size)
|
|
|