import os os.environ['CUDA_VISIBLE_DEVICES'] = '6,7' import torch import torch.nn as nn import torch.optim as optim import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import itertools from models.model import Vocos from dataset import VocosDataset from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator from models.loss import feature_loss, generator_loss, discriminator_loss from utils.audio import LogMelSpectrogram from config import MelConfig, VocosConfig, TrainConfig from utils.scheduler import get_cosine_schedule_with_warmup from utils.load import continue_training torch.backends.cudnn.benchmark = True def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size) def cleanup(): dist.destroy_process_group() def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig): if vocos_config.input_channels != mel_config.n_mels: raise ValueError("input_channels and n_mels must be equal.") if not os.path.exists(train_config.model_save_path): print(f'Creating {train_config.model_save_path}') os.makedirs(train_config.model_save_path, exist_ok=True) def train(rank, world_size): setup(rank, world_size) torch.cuda.set_device(rank) vocos_config = VocosConfig() mel_config = MelConfig() train_config = TrainConfig() _init_config(vocos_config, mel_config, train_config) generator = Vocos(vocos_config, mel_config).to(rank) mpd = MultiPeriodDiscriminator().to(rank) mrd = MultiResolutionDiscriminator().to(rank) mel_extractor = LogMelSpectrogram(mel_config).to(rank) generator = DDP(generator, device_ids=[rank]) mpd = DDP(mpd, device_ids=[rank]) mrd = DDP(mrd, device_ids=[rank]) train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config) train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False) if rank == 0: writer = SummaryWriter(train_config.log_dir) optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate) optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate) scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader)) scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader)) # load latest checkpoints if possible current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g) generator.train() mpd.train() mrd.train() for epoch in range(current_epoch, train_config.num_epochs): # loop over the train_dataset multiple times train_dataloader.sampler.set_epoch(epoch) if rank == 0: dataloader = tqdm(train_dataloader) else: dataloader = train_dataloader for batch_idx, datas in enumerate(dataloader): datas = [data.to(rank, non_blocking=True) for data in datas] audios, mels = datas audios_fake = generator(mels).unsqueeze(1) # shape: [batch_size, 1, segment_size] mels_fake = mel_extractor(audios_fake).squeeze(1) # shape: [batch_size, n_mels, segment_size // hop_length] optimizer_d.zero_grad() # MPD y_df_hat_r, y_df_hat_g, _, _ = mpd(audios,audios_fake.detach()) loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) # MRD y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios,audios_fake.detach()) loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) loss_disc_all = loss_disc_s + loss_disc_f loss_disc_all.backward() grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000) grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000) optimizer_d.step() scheduler_d.step() # generator optimizer_g.zero_grad() loss_mel = torch.nn.functional.l1_loss(mels, mels_fake) * 45 # MPD loss y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios,audios_fake) loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) # MRD loss y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios,audios_fake) loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel loss_gen_all.backward() grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000) optimizer_g.step() scheduler_g.step() if rank == 0 and batch_idx % train_config.log_interval == 0: steps = epoch * len(dataloader) + batch_idx writer.add_scalar("training/gen_loss_total", loss_gen_all, steps) writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps) writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps) writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps) writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) writer.add_scalar("training/mel_loss", loss_mel.item(), steps) writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps) writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps) writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps) writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps) writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps) if rank == 0: torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt')) torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt')) torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt')) torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt')) torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt')) print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}") cleanup() torch.set_num_threads(1) torch.set_num_interop_threads(1) if __name__ == "__main__": world_size = torch.cuda.device_count() torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)