# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from icefall.utils import str2bool
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import (
    CutConcatenate,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SingleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from data.collation import get_text_token_collater
from data.dataset import SpeechSynthesisDataset
from data.fbank import get_fbank_extractor
from data.input_strategies import PromptedPrecomputedFeatures


class _SeedWorkers:
    """Gives each DataLoader worker a distinct, reproducible RNG seed."""

    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


def _get_input_strategy(input_strategy, dataset, cuts):
    if input_strategy == "PromptedPrecomputedFeatures":
        return PromptedPrecomputedFeatures(dataset, cuts)
    # Look the strategy class up by name rather than eval()-ing user input.
    strategies = {"AudioSamples": AudioSamples, "PrecomputedFeatures": PrecomputedFeatures}
    return strategies[input_strategy]()
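
# Illustrative use (assumed values; lhotse BatchIO instances are called on a
# CutSet and return padded features plus their lengths):
#   strategy = _get_input_strategy("PromptedPrecomputedFeatures", "ljspeech", cuts)
#   features, features_lens = strategy(cuts)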


class TtsDataModule:
    """
    DataModule for VALL-E TTS experiments.

    It assumes there is always one train and one valid dataloader.

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation [not used & tested yet],
    - augmentation [not used & tested yet],
    - on-the-fly feature extraction [not used & tested yet].

    This class should be derived for specific corpora used in TTS tasks.
    A minimal usage sketch appears at the bottom of this file.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSets -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=float,
            default=40.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=10,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=0.1,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=False,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for the training dataset.",
        )
        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 disables time warp.",
        )
        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples, PrecomputedFeatures, or "
            "PromptedPrecomputedFeatures.",
        )
        group.add_argument(
            "--dataset",
            type=str,
            default="ljspeech",
            help="Dataset name; required by --input-strategy "
            "PromptedPrecomputedFeatures to prepare prompts.",
        )
        group.add_argument(
            "--text-tokens",
            type=str,
            default="data/tokenized/unique_text_tokens.k2symbols",
            help="Path to the unique text tokens file.",
        )
        group.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="Audio sampling rate.",
        )
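
    # Example (assumed) command line exercising the options above; "train.py"
    # stands in for whatever entry point calls add_arguments() and is not
    # defined in this file:
    #
    #   python train.py \
    #       --manifest-dir data/tokenized \
    #       --max-duration 40 \
    #       --input-strategy PromptedPrecomputedFeatures \
    #       --dataset ljspeech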

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(
                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
            )
            # Set the value of num_frame_masks according to Lhotse's version;
            # different Lhotse versions use different defaults for it.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")
logging.info("About to create train dataset") | |
if self.args.on_the_fly_feats: | |
# NOTE: the PerturbSpeed transform should be added only if we | |
# remove it from data prep stage. | |
# Add on-the-fly speed perturbation; since originally it would | |
# have increased epoch size by 3, we will apply prob 2/3 and use | |
# 3x more epochs. | |
# Speed perturbation probably should come first before | |
# concatenation, but in principle the transforms order doesn't have | |
# to be strict (e.g. could be randomized) | |
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa | |
# Drop feats to be on the safe side. | |
train = SpeechSynthesisDataset( | |
get_text_token_collater(self.args.text_tokens), | |
cut_transforms=transforms, | |
feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), | |
feature_transforms=input_transforms, | |
) | |
else: | |
train = SpeechSynthesisDataset( | |
get_text_token_collater(self.args.text_tokens), | |
feature_input_strategy=_get_input_strategy( | |
self.args.input_strategy, self.args.dataset, cuts_train | |
), | |
cut_transforms=transforms, | |
feature_transforms=input_transforms, | |
) | |

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info(
                "Using SingleCutSampler and sorting by duration (ascending)."
            )
            cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
logging.info("About to create train dataloader") | |
if sampler_state_dict is not None: | |
logging.info("Loading sampler state dict") | |
train_sampler.load_state_dict(sampler_state_dict) | |
# 'seed' is derived from the current random state, which will have | |
# previously been set in the main process. | |
seed = torch.randint(0, 100000, ()).item() | |
worker_init_fn = _SeedWorkers(seed) | |
train_dl = DataLoader( | |
train, | |
sampler=train_sampler, | |
batch_size=None, | |
num_workers=self.args.num_workers, | |
persistent_workers=False, | |
worker_init_fn=worker_init_fn, | |
) | |
return train_dl | |

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                cut_transforms=[],
            )
        else:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_valid
                ),
                cut_transforms=[],
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=False,
        )
        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = SpeechSynthesisDataset(
            get_text_token_collater(self.args.text_tokens),
            feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
            if self.args.on_the_fly_feats
            else _get_input_strategy(
                self.args.input_strategy, self.args.dataset, cuts
            ),
            cut_transforms=[],
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )

    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
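

# ---------------------------------------------------------------------------
# Minimal usage sketch (referenced from the class docstring). This is an
# illustration, not the project's actual training entry point: it assumes the
# tokenized manifests exist under --manifest-dir and that lhotse/icefall and
# the local data package are importable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TtsDataModule smoke test")
    TtsDataModule.add_arguments(parser)
    args = parser.parse_args()

    data_module = TtsDataModule(args)
    train_dl = data_module.train_dataloaders(data_module.train_cuts())
    for batch in train_dl:
        # Each batch is a dict collated by SpeechSynthesisDataset.
        print({k: type(v) for k, v in batch.items()})
        break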