Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
#coding:utf-8 | |
import os | |
import os.path as osp | |
import re | |
import sys | |
import yaml | |
import shutil | |
import numpy as np | |
import paddle | |
import click | |
import warnings | |
warnings.simplefilter('ignore') | |
from functools import reduce | |
from munch import Munch | |
from starganv2vc_paddle.meldataset import build_dataloader | |
from starganv2vc_paddle.optimizers import build_optimizer | |
from starganv2vc_paddle.models import build_model | |
from starganv2vc_paddle.trainer import Trainer | |
from visualdl import LogWriter | |
from starganv2vc_paddle.Utils.ASR.models import ASRCNN | |
from starganv2vc_paddle.Utils.JDC.model import JDCNet | |
import logging | |
from logging import StreamHandler | |
# Module-wide logger. Both the logger and its console handler are opened up to
# DEBUG so that every record emitted by this script reaches the stream.
logger = logging.getLogger(__name__)
handler = StreamHandler()
handler.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)
logger.addHandler(handler)
def _load_pretrained_models(config):
    """Build the auxiliary ASR and F0 networks from the checkpoints named in ``config``.

    Args:
        config: Parsed training configuration. Must contain non-empty
            ``ASR_config`` (YAML file with a ``model_params`` mapping),
            ``ASR_path`` (checkpoint holding a ``'model'`` entry) and
            ``F0_path`` (checkpoint holding a ``'net'`` entry).

    Returns:
        Tuple ``(F0_model, ASR_model)``; the ASR model is switched to eval mode.

    Raises:
        ValueError: If any of the three required entries is missing or empty.
    """
    # Fail loudly on missing paths. The previous default of ``False`` was
    # passed straight to ``open()``, which treats it as file descriptor 0
    # (stdin) instead of raising.
    missing = [key for key in ('ASR_config', 'ASR_path', 'F0_path')
               if not config.get(key)]
    if missing:
        raise ValueError('missing required config entries: %s' % ', '.join(missing))

    # NOTE(review): paddle.load presumably unpickles the checkpoint - only
    # point these paths at trusted files.
    with open(config['ASR_config']) as f:
        ASR_config = yaml.safe_load(f)
    ASR_model = ASRCNN(**ASR_config['model_params'])
    ASR_model.set_state_dict(paddle.load(config['ASR_path'])['model'])
    ASR_model.eval()

    F0_model = JDCNet(num_class=1, seq_len=192)
    F0_model.set_state_dict(paddle.load(config['F0_path'])['net'])
    return F0_model, ASR_model


def _log_epoch_results(writer, epoch, results):
    """Write one epoch's merged train/eval results to the logger and to VisualDL."""
    logger.info('--- epoch %d ---', epoch)
    for key, value in results.items():
        if isinstance(value, float):
            logger.info('%-15s: %.4f', key, value)
            writer.add_scalar(key, value, epoch)
        else:
            # Non-float results are treated as an iterable of items that are
            # all recorded under the single 'eval_spec' tag.
            for v in value:
                writer.add_histogram('eval_spec', v, epoch)


def _run_training(config, log_dir, writer):
    """Build data loaders, models, optimizer and trainer, then run the epoch loop."""
    batch_size = config.get('batch_size', 10)
    epochs = config.get('epochs', 1000)
    save_freq = config.get('save_freq', 20)
    train_path = config.get('train_data', None)
    val_path = config.get('val_data', None)
    fp16_run = config.get('fp16_run', False)

    # load data
    train_list, val_list = get_data_path_list(train_path, val_path)
    train_dataloader = build_dataloader(train_list,
                                        batch_size=batch_size,
                                        num_workers=4)
    val_dataloader = build_dataloader(val_list,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=2)

    # load pretrained ASR and F0 models
    F0_model, ASR_model = _load_pretrained_models(config)

    # build model
    model, model_ema = build_model(Munch(config['model_params']), F0_model, ASR_model)

    optimizer_params = config['optimizer_params']
    scheduler_params = {
        "max_lr": float(optimizer_params.get('lr', 2e-4)),
        "pct_start": float(optimizer_params.get('pct_start', 0.0)),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    # Every sub-network gets its own copy so the per-network override below
    # does not leak into the others.
    scheduler_params_dict = {key: scheduler_params.copy() for key in model}
    # The mapping network always uses this fixed peak LR, independent of 'lr'.
    scheduler_params_dict['mapping_network']['max_lr'] = 2e-6
    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict)

    trainer = Trainer(args=Munch(config['loss_params']), model=model,
                      model_ema=model_ema,
                      optimizer=optimizer,
                      train_dataloader=train_dataloader,
                      val_dataloader=val_dataloader,
                      logger=logger,
                      fp16_run=fp16_run)

    if config.get('pretrained_model', '') != '':
        trainer.load_checkpoint(config['pretrained_model'],
                                load_only_params=config.get('load_only_params', True))

    for _ in range(epochs):
        # The epoch number comes from the trainer (read before the epoch runs),
        # not from the loop counter.
        epoch = trainer.epochs
        train_results = trainer._train_epoch()
        eval_results = trainer._eval_epoch()
        results = train_results.copy()
        results.update(eval_results)
        _log_epoch_results(writer, epoch, results)
        if (epoch % save_freq) == 0:
            trainer.save_checkpoint(osp.join(log_dir, 'epoch_%05d.pd' % epoch))


def main(config_path):
    """Train a model as described by the YAML configuration at ``config_path``.

    The configuration is copied into ``config['log_dir']``; a ``train.log``
    file and a ``visualdl`` sub-directory are written there as well, along
    with an ``epoch_XXXXX.pd`` checkpoint every ``save_freq`` epochs.

    Recognised keys (defaults in parentheses): ``log_dir`` (required),
    ``batch_size`` (10), ``epochs`` (1000), ``save_freq`` (20), ``train_data``
    and ``val_data`` (fall back to the defaults of ``get_data_path_list``),
    ``fp16_run`` (False), ``ASR_config`` / ``ASR_path`` / ``F0_path``
    (required), ``model_params``, ``optimizer_params`` (``lr`` 2e-4,
    ``pct_start`` 0.0), ``loss_params``, ``pretrained_model`` (''),
    ``load_only_params`` (True).

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        0 once all epochs have run.

    Raises:
        KeyError: If a required section such as ``log_dir`` is absent.
        ValueError: If ``ASR_config``, ``ASR_path`` or ``F0_path`` is missing.
    """
    # Context manager so the config file handle is closed deterministically.
    with open(config_path) as f:
        config = yaml.safe_load(f)

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)

    # Keep a copy of the config next to the logs. Skip the copy when the config
    # already *is* that file, which would otherwise raise shutil.SameFileError.
    config_copy = osp.join(log_dir, osp.basename(config_path))
    if not (osp.exists(config_copy) and osp.samefile(config_path, config_copy)):
        shutil.copy(config_path, config_copy)

    writer = LogWriter(log_dir + "/visualdl")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.addHandler(file_handler)

    try:
        _run_training(config, log_dir, writer)
    finally:
        # Flush and release the log sinks even when training aborts, and detach
        # the file handler so repeated calls do not duplicate log lines.
        writer.close()
        logger.removeHandler(file_handler)
        file_handler.close()

    return 0
def get_data_path_list(train_path=None, val_path=None):
    """Read the training and validation list files into lists of raw lines.

    Args:
        train_path: Path to the training list; ``None`` selects
            ``Data/train_list.txt``.
        val_path: Path to the validation list; ``None`` selects
            ``Data/val_list.txt``.

    Returns:
        Tuple ``(train_list, val_list)``. Each element is the list of lines of
        the corresponding file, with line endings left in place.
    """
    def _read_lines(path):
        # Iterating a text file yields exactly the lines readlines() returns.
        with open(path, 'r') as manifest:
            return list(manifest)

    train_file = "Data/train_list.txt" if train_path is None else train_path
    val_file = "Data/val_list.txt" if val_path is None else val_path

    # Tuple elements are evaluated left to right: train file first, then val.
    return _read_lines(train_file), _read_lines(val_file)
if __name__ == "__main__":
    # main() takes a mandatory config path, so the former bare ``main()`` call
    # always failed with TypeError. Build the command-line wrapper here rather
    # than decorating main, so main stays a plain function for importers.
    cli = click.command()(
        click.option('-p', '--config_path', required=True, type=str,
                     help='Path to the YAML training configuration.')(main)
    )
    cli()