import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6,7'
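# NOTE: CUDA_VISIBLE_DEVICES is set before torch is imported, so no CUDA context exists yet and
# torch.cuda.device_count() later in this file only sees the GPUs listed here ('6,7' -> 2 devices).
# Edit this list to match the GPUs available on your machine.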
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import itertools
from models.model import Vocos
from dataset import VocosDataset
from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from models.loss import feature_loss, generator_loss, discriminator_loss
from utils.audio import LogMelSpectrogram
from config import MelConfig, VocosConfig, TrainConfig
from utils.scheduler import get_cosine_schedule_with_warmup
from utils.load import continue_training
torch.backends.cudnn.benchmark = True
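
# setup()/cleanup() initialize and tear down the per-process distributed group.
# MASTER_ADDR/MASTER_PORT assume a single-node run; the NCCL backend is used on
# Linux and gloo on Windows (where NCCL is unavailable).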
def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig):
    if vocos_config.input_channels != mel_config.n_mels:
        raise ValueError("input_channels and n_mels must be equal.")
    if not os.path.exists(train_config.model_save_path):
        print(f'Creating {train_config.model_save_path}')
        os.makedirs(train_config.model_save_path, exist_ok=True)
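
# train() runs once per spawned process (one per GPU, with `rank` indexing the device).
# Each batch follows a HiFi-GAN-style GAN recipe: update the MPD/MRD discriminators on real
# vs. detached generated audio, then update the generator with adversarial, feature-matching,
# and L1 mel-reconstruction losses.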
def train(rank, world_size):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    vocos_config = VocosConfig()
    mel_config = MelConfig()
    train_config = TrainConfig()
    _init_config(vocos_config, mel_config, train_config)

    generator = Vocos(vocos_config, mel_config).to(rank)
    mpd = MultiPeriodDiscriminator().to(rank)
    mrd = MultiResolutionDiscriminator().to(rank)
    mel_extractor = LogMelSpectrogram(mel_config).to(rank)

    generator = DDP(generator, device_ids=[rank])
    mpd = DDP(mpd, device_ids=[rank])
    mrd = DDP(mrd, device_ids=[rank])

    train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False)

    if rank == 0:
        writer = SummaryWriter(train_config.log_dir)

    optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate)
    optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate)
    scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
    scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
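
    # The schedulers are stepped once per batch inside the training loop, so num_training_steps
    # is the total number of optimizer steps: num_epochs * len(train_dataloader).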

    # load the latest checkpoints if possible; continue_training restores the model and
    # optimizer states it finds in model_save_path and returns the epoch to resume from
    current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g)

    generator.train()
    mpd.train()
    mrd.train()

    for epoch in range(current_epoch, train_config.num_epochs):  # loop over the train_dataset multiple times
        train_dataloader.sampler.set_epoch(epoch)
        if rank == 0:
            dataloader = tqdm(train_dataloader)
        else:
            dataloader = train_dataloader

        for batch_idx, datas in enumerate(dataloader):
            datas = [data.to(rank, non_blocking=True) for data in datas]
            audios, mels = datas

            audios_fake = generator(mels).unsqueeze(1)  # shape: [batch_size, 1, segment_size]
            mels_fake = mel_extractor(audios_fake).squeeze(1)  # shape: [batch_size, n_mels, segment_size // hop_length]
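
            # Discriminator step: MPD and MRD score real audio against generated audio;
            # .detach() keeps these losses from backpropagating into the generator.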
            optimizer_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(audios, audios_fake.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MRD
            y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios, audios_fake.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f
            loss_disc_all.backward()
            grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000)
            grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000)
            optimizer_d.step()
            scheduler_d.step()
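
            # For the generator update, the L1 mel loss compares the input mels against
            # mels_fake, recomputed from the generated audio above; the x45 weight follows
            # the HiFi-GAN recipe.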
            # generator
            optimizer_g.zero_grad()
            loss_mel = torch.nn.functional.l1_loss(mels, mels_fake) * 45

            # MPD loss
            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios, audios_fake)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)

            # MRD loss
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios, audios_fake)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
            loss_gen_all.backward()
            grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000)
            optimizer_g.step()
            scheduler_g.step()

            if rank == 0 and batch_idx % train_config.log_interval == 0:
                steps = epoch * len(dataloader) + batch_idx
                writer.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
                writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
                writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
                writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
                writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
                writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
                writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
                writer.add_scalar("training/mel_loss", loss_mel.item(), steps)
                writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps)
                writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps)
                writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps)
                writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
                writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
        if rank == 0:
            torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt'))
            torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt'))
            torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt'))
            torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt'))
            torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt'))
            print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}")

    cleanup()
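
# Restrict each process to one intra-op and one inter-op CPU thread; with one training
# process per GPU plus DataLoader workers, this avoids CPU thread oversubscription.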
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
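
# Usage sketch (assuming this file is saved as train.py): run `python train.py`.
# torch.multiprocessing.spawn launches one train() process per visible GPU, as determined
# by the CUDA_VISIBLE_DEVICES setting at the top of this file.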