Spaces:
Running
Running
import os | |
import time | |
import numpy as np | |
import torch | |
import librosa | |
from diffusion.logger.saver import Saver | |
from diffusion.logger import utils | |
from torch import autocast | |
from torch.cuda.amp import GradScaler | |
def test(args, model, vocoder, loader_test, saver):
    """Evaluate the model on the test set, reporting loss and real-time factor.

    For every batch the model is first run with ``infer=True`` to produce a
    mel spectrogram, which ``vocoder.infer`` turns into a waveform; the
    wall-clock time of that step divided by the audio duration is reported as
    the real-time factor (RTF). The training loss (``infer=False``) is then
    evaluated ``args.train.batch_size`` times on the same batch and
    accumulated. Predicted/ground-truth spectrograms and audio are logged
    through ``saver``.

    Args:
        args: config object; reads ``args.device``, ``args.infer.speedup``,
            ``args.infer.method``, ``args.data.sampling_rate`` and
            ``args.train.batch_size``.
        model: model under evaluation; switched to eval mode here (callers
            must switch it back to train mode themselves).
        vocoder: object whose ``infer(mel, f0)`` returns a waveform tensor.
        loader_test: sized iterable of batch dicts with keys ``name``,
            ``name_ext``, ``units``, ``f0``, ``volume``, ``spk_id`` and ``mel``.
            Only element ``[0]`` of ``name``/``name_ext`` is used, so it
            presumably yields batches of size 1 -- TODO confirm.
        saver: logger exposing ``log_spec`` and ``log_audio``.

    Returns:
        float: accumulated loss divided by ``args.train.batch_size`` and by the
        number of batches, or ``nan`` if ``loader_test`` is empty.
    """
    print(' [*] testing...')
    model.eval()

    test_loss = 0.
    num_batches = len(loader_test)
    rtf_all = []
    if num_batches == 0:
        # Nothing to evaluate: avoid the division by zero below instead of
        # crashing the surrounding training run.
        print(' [!] empty test loader, skipping testing')
        return float('nan')

    # CUDA kernels execute asynchronously, so without synchronizing the timer
    # would mostly measure kernel launches and understate the RTF.
    device = torch.device(args.device)
    sync_cuda = device.type == 'cuda'

    with torch.no_grad():
        for bidx, data in enumerate(loader_test):
            # assumes names are '/'-separated paths whose parent directory is
            # the speaker label -- TODO confirm against the dataset
            name_parts = data['name'][0].split("/")
            fn = name_parts[-1]
            speaker = name_parts[-2]
            print('--------')
            print('{}/{} - {}'.format(bidx, num_batches, fn))

            # unpack data: move every non-'name*' entry to the target device
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)
            print('>>', data['name'][0])

            # forward: inference + vocoding, timed for the real-time factor
            if sync_cuda:
                torch.cuda.synchronize(device)
            st_time = time.time()
            mel = model(
                data['units'],
                data['f0'],
                data['volume'],
                data['spk_id'],
                gt_spec=None,
                infer=True,
                infer_speedup=args.infer.speedup,
                method=args.infer.method)
            signal = vocoder.infer(mel, data['f0'])
            if sync_cuda:
                torch.cuda.synchronize(device)
            ed_time = time.time()

            # RTF = processing time / duration of the generated audio
            run_time = ed_time - st_time
            song_time = signal.shape[-1] / args.data.sampling_rate
            rtf = run_time / song_time
            print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
            rtf_all.append(rtf)

            # loss: evaluated args.train.batch_size times on the same batch
            # (presumably to average a stochastic loss -- TODO confirm)
            for _ in range(args.train.batch_size):
                loss = model(
                    data['units'],
                    data['f0'],
                    data['volume'],
                    data['spk_id'],
                    gt_spec=data['mel'],
                    infer=False)
                test_loss += loss.item()

            # log mel
            saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)

            # log audio: ground-truth file resampled to the model's rate
            path_audio = data['name_ext'][0]
            audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
            if len(audio.shape) > 1:
                audio = librosa.to_mono(audio)
            audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
            saver.log_audio({f"{speaker}_{fn}_gt.wav": audio, f"{speaker}_{fn}_pred.wav": signal})

    # report
    test_loss /= args.train.batch_size
    test_loss /= num_batches

    print(' [test_loss] test_loss:', test_loss)
    print(' Real Time Factor', np.mean(rtf_all))
    return test_loss
def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
    """Train ``model`` for ``args.train.epochs`` epochs with logging, checkpointing and validation.

    Each batch runs a forward pass, optionally under ``torch.autocast`` with
    the dtype selected by ``args.train.amp_dtype``. It then runs backward, an
    optimizer step and a scheduler step.

    Every ``args.train.interval_log`` steps the loss and learning rate are
    logged.

    Every ``args.train.interval_val`` steps:
      * a checkpoint is saved;
      * the checkpoint from ``interval_val`` steps earlier is deleted, unless
        its step is a multiple of ``args.train.interval_force_save``;
      * ``test`` is run on ``loader_test``.

    Args:
        args: config object; reads ``args.device``, ``args.env.expdir`` and
            ``args.train.{amp_dtype, epochs, interval_log, interval_val,
            interval_force_save, save_opt}`` here, plus the fields ``test`` reads.
        initial_global_step: step counter to resume from, passed to ``Saver``.
        model: called as ``model(units, f0, volume, spk_id, aug_shift=...,
            gt_spec=..., infer=False)`` and must return a single-element loss
            tensor.
        optimizer: stepped once per batch; its first param group's ``lr`` is logged.
        scheduler: stepped once per batch.
        vocoder: forwarded to ``test``.
        loader_train: sized iterable of batch dicts with keys ``units``, ``f0``,
            ``volume``, ``spk_id``, ``aug_shift`` and ``mel`` (``name*`` keys
            are left on their original device).
        loader_test: forwarded to ``test``.

    Raises:
        ValueError: if ``args.train.amp_dtype`` is not ``'fp32'``, ``'fp16'``
            or ``'bf16'``, or if the training loss becomes NaN.
    """
    # saver
    saver = Saver(args, initial_global_step=initial_global_step)

    # model size
    params_count = utils.get_network_paras_amount({'model': model})
    saver.log_info('--- model size ---')
    saver.log_info(params_count)

    # run
    num_batches = len(loader_train)
    model.train()
    saver.log_info('======= start training =======')

    amp_dtypes = {
        'fp32': torch.float32,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }
    dtype = amp_dtypes.get(args.train.amp_dtype)
    if dtype is None:
        raise ValueError(f' [x] Unknown amp_dtype: {args.train.amp_dtype}')

    # The scaler is used only for the reduced-precision paths, as before; when
    # disabled, scale/step/update are plain pass-throughs (loss.backward(),
    # optimizer.step()), so one backward path serves every dtype.
    scaler = GradScaler(enabled=(dtype != torch.float32))
    # torch.autocast wants a device *type* ('cuda', 'cpu'), not an indexed
    # device string such as 'cuda:0'.
    autocast_device = torch.device(args.device).type

    saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
    for epoch in range(args.train.epochs):
        for batch_idx, data in enumerate(loader_train):
            saver.global_step_increment()
            optimizer.zero_grad()

            # unpack data: move every non-'name*' entry to the target device
            for k in data.keys():
                if not k.startswith('name'):
                    data[k] = data[k].to(args.device)

            # forward
            if dtype == torch.float32:
                loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
                             aug_shift=data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
            else:
                with autocast(device_type=autocast_device, dtype=dtype):
                    loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
                                 aug_shift=data['aug_shift'], gt_spec=data['mel'], infer=False)

            # handle nan loss
            if torch.isnan(loss):
                raise ValueError(' [x] nan loss ')

            # backpropagate
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # log loss
            if saver.global_step % args.train.interval_log == 0:
                current_lr = optimizer.param_groups[0]['lr']
                loss_value = loss.item()
                saver.log_info(
                    'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
                        epoch,
                        batch_idx,
                        num_batches,
                        args.env.expdir,
                        args.train.interval_log / saver.get_interval_time(),
                        current_lr,
                        loss_value,
                        saver.get_total_time(),
                        saver.global_step,
                    )
                )
                saver.log_value({
                    'train/loss': loss_value
                })
                saver.log_value({
                    'train/lr': current_lr
                })

            # validation
            if saver.global_step % args.train.interval_val == 0:
                optimizer_save = optimizer if args.train.save_opt else None

                # save latest
                saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
                # Drop the checkpoint from the previous validation unless it is a
                # forced-save step. A non-positive step is not one this loop saved
                # at (assumes the step counter starts at >= 0 -- TODO confirm).
                last_val_step = saver.global_step - args.train.interval_val
                if last_val_step > 0 and last_val_step % args.train.interval_force_save != 0:
                    saver.delete_model(postfix=f'{last_val_step}')

                # run testing set
                test_loss = test(args, model, vocoder, loader_test, saver)

                # log loss
                saver.log_info(
                    ' --- <validation> --- \nloss: {:.3f}. '.format(
                        test_loss,
                    )
                )
                saver.log_value({
                    'validation/loss': test_loss
                })

                # test() leaves the model in eval mode
                model.train()