# NOTE: the hosting page's status banner ("Spaces: Runtime error") was captured here during extraction; it is not part of this module.
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DistributedSampler

from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
from utils.audio.io import save_wav
from utils.commons.base_task import BaseTask
from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars
class VocoderBaseTask(BaseTask):
    """Shared task logic for GAN-style vocoders (generator + discriminator).

    Provides the data loaders, a pair of AdamW optimizers with StepLR
    schedulers, audio logging during validation and waveform dumping at test
    time.

    This class uses ``self.model_gen``, ``self.model_disc``,
    ``self._training_step``, ``self.logger``, ``self.global_step``,
    ``self.trainer`` and ``self.scheduler`` but does not define them; they are
    expected to come from ``BaseTask``, the trainer, or a subclass.
    """

    def __init__(self):
        super().__init__()
        self.max_sentences = hparams['max_sentences']
        self.max_valid_sentences = hparams['max_valid_sentences']
        # -1 means "reuse the training batch size"; the resolved value is also
        # written back into hparams so other readers see the same number.
        if self.max_valid_sentences == -1:
            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
        self.dataset_cls = VocoderDataset

    def train_dataloader(self):
        """Return the shuffled loader over the ``'train'`` split."""
        train_dataset = self.dataset_cls('train', shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])

    def val_dataloader(self):
        """Return the validation loader.

        Note that validation reads the ``'test'`` split, exactly like
        ``test_dataloader``.
        """
        valid_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)

    def test_dataloader(self):
        """Return the unshuffled loader over the ``'test'`` split."""
        test_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a ``DataLoader`` driven by a distributed sampler.

        Args:
            dataset: Dataset exposing ``collater`` and ``num_workers``.
            shuffle: Whether the sampler shuffles the indices.
            max_sentences: Batch size handed to the ``DataLoader``.
            endless: Use ``EndlessDistributedSampler`` instead of
                ``DistributedSampler``.

        Returns:
            The configured ``torch.utils.data.DataLoader``.
        """
        # Fall back to a single-replica setup when torch.distributed is not
        # initialised, so the same sampler classes work in both modes.
        if dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        sampler_cls = EndlessDistributedSampler if endless else DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        return torch.utils.data.DataLoader(
            dataset=dataset,
            # Shuffling is delegated to the sampler; DataLoader does not allow
            # shuffle=True together with an explicit sampler.
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def build_optimizer(self, model):
        """Create the generator and discriminator AdamW optimizers.

        ``model`` is accepted for interface compatibility but not used; the
        parameters come from ``self.model_gen`` and ``self.model_disc``.

        Returns:
            ``[optimizer_gen, optimizer_disc]`` -- index 0 is the generator.
        """
        betas = [hparams['adam_b1'], hparams['adam_b2']]
        optimizer_gen = torch.optim.AdamW(
            self.model_gen.parameters(), lr=hparams['lr'], betas=betas)
        optimizer_disc = torch.optim.AdamW(
            self.model_disc.parameters(), lr=hparams['lr'], betas=betas)
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Create one StepLR scheduler per optimizer.

        Args:
            optimizer: The list returned by ``build_optimizer``.

        Returns:
            Dict with keys ``"gen"`` and ``"disc"``.
        """
        return {
            "gen": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[0],
                **hparams["generator_scheduler_params"]),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[1],
                **hparams["discriminator_scheduler_params"]),
        }

    @staticmethod
    def _peak_normalize(wav):
        """Scale ``wav`` so its largest absolute sample is 1.

        An all-zero waveform is returned unchanged instead of being divided
        by zero (which would turn every sample into NaN).
        """
        peak = wav.abs().max()
        if peak > 0:
            return wav / peak
        return wav

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and periodically log audio samples.

        Args:
            sample: Batch dict; reads ``'mels'``, ``'wavs'``, ``'f0'`` and
                ``'item_name'`` when audio is logged.
            batch_idx: Index of the batch within the validation run.

        Returns:
            Dict with ``'losses'`` and ``'total_loss'`` converted to scalars.
        """
        # The third argument is presumably the optimizer index (0 = generator)
        # -- TODO confirm against _training_step.
        total_loss, loss_output = self._training_step(sample, batch_idx, 0)
        outputs = {
            'losses': tensors_to_scalars(loss_output),
            'total_loss': tensors_to_scalars(total_loss),
        }

        # Audio is only logged every `valid_infer_interval` steps, and only
        # for the first 10 batches.
        log_audio = self.global_step % hparams['valid_infer_interval'] == 0 and batch_idx < 10
        if not log_audio:
            return outputs

        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        y_ = self.model_gen(mels, f0)
        sample_rate = hparams['audio_sample_rate']
        for idx, (wav_pred, wav_gt, _item_name) in enumerate(zip(y_, y, sample["item_name"])):
            wav_pred = self._peak_normalize(wav_pred)
            # NOTE(review): indentation was lost in the extracted source; the
            # ground-truth logging is assumed to sit inside this branch so the
            # reference audio is written once, at step 0 -- confirm.
            if self.global_step == 0:
                wav_gt = self._peak_normalize(wav_gt)
                self.logger.add_audio(f'wav_{batch_idx}_{idx}_gt', wav_gt, self.global_step,
                                      sample_rate)
            self.logger.add_audio(f'wav_{batch_idx}_{idx}_pred', wav_pred, self.global_step,
                                  sample_rate)
        return outputs

    def _make_gen_dir(self):
        """Return the directory for generated audio, creating it if missing."""
        gen_dir = os.path.join(hparams['work_dir'],
                               f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(gen_dir, exist_ok=True)
        return gen_dir

    def test_start(self):
        """Create the output directory and remember it as ``self.gen_dir``."""
        self.gen_dir = self._make_gen_dir()

    def test_step(self, sample, batch_idx):
        """Synthesize one batch and save predicted / ground-truth waveforms.

        Writes ``<item_name>_gt.wav`` and ``<item_name>_pred.wav`` into the
        generated-audio directory.

        Returns:
            An empty dict (no losses are computed at test time).
        """
        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        loss_output = {}
        y_ = self.model_gen(mels, f0)
        # Recomputed here rather than read from self.gen_dir so the step does
        # not depend on test_start having been called.
        gen_dir = self._make_gen_dir()
        sample_rate = hparams['audio_sample_rate']
        for wav_pred, wav_gt, item_name in zip(y_, y, sample["item_name"]):
            # Clamp to [-1, 1] before writing the waveforms.
            wav_gt = wav_gt.clamp(-1, 1)
            wav_pred = wav_pred.clamp(-1, 1)
            save_wav(
                wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav',
                sample_rate)
            save_wav(
                wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
                sample_rate)
        return loss_output

    def test_end(self, outputs):
        """Return an empty result dict; ``outputs`` is not aggregated."""
        return {}

    def on_before_optimization(self, opt_idx):
        """Clip gradients of the model owned by optimizer ``opt_idx``.

        Index 0 clips the generator; any other index clips the discriminator.
        """
        if opt_idx == 0:
            params, max_norm = self.model_gen.parameters(), hparams['generator_grad_norm']
        else:
            params, max_norm = self.model_disc.parameters(), hparams["discriminator_grad_norm"]
        nn.utils.clip_grad_norm_(params, max_norm)

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the scheduler that matches ``optimizer_idx``.

        The scheduler is stepped to an explicit position derived from the
        global step, so its state follows the number of optimizer updates
        rather than the number of calls to this hook.
        """
        scheduler_key = 'gen' if optimizer_idx == 0 else 'disc'
        # NOTE(review): passing an explicit epoch to step() is deprecated in
        # recent PyTorch releases; kept as-is to preserve the LR schedule.
        self.scheduler[scheduler_key].step(self.global_step // hparams['accumulate_grad_batches'])