Spaces:
Build error
Build error
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |
# | |
# This work is licensed under the Creative Commons Attribution-NonCommercial | |
# 4.0 International License. To view a copy of this license, visit | |
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
"""Main training script.""" | |
import os | |
import numpy as np | |
import tensorflow as tf | |
import dnnlib | |
import dnnlib.tflib as tflib | |
from dnnlib.tflib.autosummary import autosummary | |
import config | |
import train | |
from training import dataset | |
from training import misc | |
from metrics import metric_base | |
#---------------------------------------------------------------------------- | |
# Just-in-time processing of training images before feeding them to the networks. | |
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """Build the TensorFlow ops that pre-process real images for the networks.

    The ops are created under the 'ProcessReals' name scope and applied in
    this order:

    1. DynamicRange:  cast to float32 and remap values from `drange_data`
                      to `drange_net` via misc.adjust_dynamic_range().
    2. MirrorAugment: (only if `mirror_augment` is truthy) reverse axis 3 of
                      each image with probability 0.5, chosen per image.
    3. FadeLOD:       blend each image with a copy that has been averaged
                      over 2x2 blocks of axes 2/3, weighted by the
                      fractional part of `lod`.
    4. UpscaleLOD:    replicate every element 2**floor(lod) times along
                      axes 2 and 3.

    Args:
        x:              4-D image batch tensor (rank is fixed by the
                        indexing below; presumably NCHW -- TODO confirm
                        against the dataset loader).
        lod:            Scalar level-of-detail tensor.
        mirror_augment: Python bool, evaluated at graph-construction time.
        drange_data:    Dynamic range of the incoming data.
        drange_net:     Dynamic range expected by the networks.

    Returns:
        The processed float32 image tensor.
    """
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            images = tf.cast(x, tf.float32)
            images = misc.adjust_dynamic_range(images, drange_data, drange_net)

        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                dims = tf.shape(images)
                # One uniform sample per image, expanded to the full image
                # shape so that tf.where() selects whole images at a time.
                coin = tf.random_uniform([dims[0], 1, 1, 1], 0.0, 1.0)
                coin = tf.tile(coin, [1, dims[1], dims[2], dims[3]])
                images = tf.where(coin < 0.5, images, tf.reverse(images, axis=[3]))

        with tf.name_scope('FadeLOD'):  # Smooth crossfade between consecutive levels-of-detail.
            dims = tf.shape(images)
            # Split axes 2 and 3 into (half, 2) pairs, average over the
            # size-2 axes, then tile back: a 2x2 box filter followed by
            # nearest-neighbour replication to the original size.
            coarse = tf.reshape(images, [-1, dims[1], dims[2]//2, 2, dims[3]//2, 2])
            coarse = tf.reduce_mean(coarse, axis=[3, 5], keepdims=True)
            coarse = tf.tile(coarse, [1, 1, 1, 2, 1, 2])
            coarse = tf.reshape(coarse, [-1, dims[1], dims[2], dims[3]])
            # Blend weight is the fractional part of lod (presumably
            # tflib.lerp(a, b, t) == a + (b - a) * t -- TODO confirm).
            images = tflib.lerp(images, coarse, lod - tf.floor(lod))

        with tf.name_scope('UpscaleLOD'):  # Upscale to match the expected input/output size of the networks.
            dims = tf.shape(images)
            scale = tf.cast(2 ** tf.floor(lod), tf.int32)
            # Insert singleton axes after axes 2 and 3 and tile them by
            # `scale`, i.e. repeat every element scale x scale times.
            images = tf.reshape(images, [-1, dims[1], dims[2], 1, dims[3], 1])
            images = tf.tile(images, [1, 1, 1, scale, 1, scale])
            images = tf.reshape(images, [-1, dims[1], dims[2] * scale, dims[3] * scale])

    return images
#---------------------------------------------------------------------------- | |
# Evaluate time-varying training parameters. | |
def training_schedule(
    cur_nimg,
    training_set,
    num_gpus,
    lod_initial_resolution  = 4,        # Image resolution used at the beginning.
    lod_training_kimg       = 600,      # Thousands of real images to show before doubling the resolution.
    lod_transition_kimg     = 600,      # Thousands of real images to show when fading in new layers.
    minibatch_base          = 16,       # Maximum minibatch size, divided evenly among GPUs.
    minibatch_dict          = {},       # Resolution-specific overrides.
    max_minibatch_per_gpu   = {},       # Resolution-specific maximum minibatch size per GPU.
    G_lrate_base            = 0.001,    # Learning rate for the generator.
    G_lrate_dict            = {},       # Resolution-specific overrides.
    D_lrate_base            = 0.001,    # Learning rate for the discriminator.
    D_lrate_dict            = {},       # Resolution-specific overrides.
    lrate_rampup_kimg       = 0,        # Duration of learning rate ramp-up.
    tick_kimg_base          = 160,      # Default interval of progress snapshots.
    tick_kimg_dict          = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:30, 1024:20}): # Resolution-specific overrides.
    """Evaluate the time-varying training parameters at a given point in training.

    Training is divided into phases of `lod_training_kimg + lod_transition_kimg`
    thousand images.  Each completed phase lowers the level-of-detail (lod)
    by one; during the transition part of a phase the lod additionally
    decreases linearly.  The lod is clamped at 0, and the current resolution
    is 2 ** (training_set.resolution_log2 - floor(lod)).

    The dict-valued defaults are only read here, never modified.

    Args:
        cur_nimg:     Number of real images shown so far.
        training_set: Object exposing `resolution_log2`.
        num_gpus:     Number of GPUs; the minibatch is rounded down to a
                      multiple of this.
        (remaining parameters are described inline in the signature)

    Returns:
        A dnnlib.EasyDict with the fields kimg, lod, resolution, minibatch,
        G_lrate, D_lrate and tick_kimg.
    """
    kimg = cur_nimg / 1000.0

    # Training phase: which phase are we in, and how far into it?
    phase_len = lod_training_kimg + lod_transition_kimg
    if phase_len > 0:
        phase_no = int(np.floor(kimg / phase_len))
    else:
        phase_no = 0
    kimg_in_phase = kimg - phase_no * phase_len

    # Level-of-detail and resolution.
    full_log2 = training_set.resolution_log2
    lod = full_log2 - np.floor(np.log2(lod_initial_resolution)) - phase_no
    if lod_transition_kimg > 0:
        # Fractional fade-in progress once the stable part of the phase is over.
        lod -= max(kimg_in_phase - lod_training_kimg, 0.0) / lod_transition_kimg
    lod = max(lod, 0.0)
    resolution = 2 ** (full_log2 - int(np.floor(lod)))

    # Minibatch size: rounded down to a multiple of num_gpus, then capped per GPU.
    minibatch = minibatch_dict.get(resolution, minibatch_base)
    minibatch -= minibatch % num_gpus
    if resolution in max_minibatch_per_gpu:
        minibatch = min(minibatch, max_minibatch_per_gpu[resolution] * num_gpus)

    # Learning rate, with optional linear ramp-up over the first lrate_rampup_kimg.
    G_lrate = G_lrate_dict.get(resolution, G_lrate_base)
    D_lrate = D_lrate_dict.get(resolution, D_lrate_base)
    if lrate_rampup_kimg > 0:
        ramp = min(kimg / lrate_rampup_kimg, 1.0)
        G_lrate *= ramp
        D_lrate *= ramp

    # Assemble the result in a fixed field order.
    s = dnnlib.EasyDict()
    s.kimg = kimg
    s.lod = lod
    s.resolution = resolution
    s.minibatch = minibatch
    s.G_lrate = G_lrate
    s.D_lrate = D_lrate
    s.tick_kimg = tick_kimg_dict.get(resolution, tick_kimg_base)
    return s
#---------------------------------------------------------------------------- | |
# Main training script. | |
def training_loop(
    submit_config,
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for training_schedule().
    grid_args               = {},       # Options for misc.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Run the full training: build the graph, train, and export snapshots.

    Steps, in order:
      1. Create the dnnlib RunContext, initialise TensorFlow, load the dataset.
      2. Load (when `resume_run_id` is given) or construct the G, D and Gs networks.
      3. Build per-GPU loss/gradient ops, the optimizer update ops and the
         moving-average update op for Gs.
      4. Save the initial real/fake image grids into `submit_config.run_dir`.
      5. Loop until `total_kimg` thousand images have been shown or
         `ctx.should_stop()` is true, re-evaluating training_schedule() every
         `minibatch_repeats` iterations and doing reporting/snapshots once
         per tick.
      6. Save 'network-final.pkl', close the summary writer and the RunContext.

    The summary writer is closed in a `finally` clause, so event files are
    flushed even when training aborts with an exception; the exception itself
    still propagates to the caller.

    The dict/list defaults in the signature are only forwarded, never modified
    in this function.

    Returns:
        None.
    """
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            # NOTE(review): misc.load_pkl() presumably unpickles arbitrary
            # objects -- only resume from trusted snapshots.
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Decay per update so that the average has a half-life of
        # G_smoothing_kimg thousand images; 0.0 disables the averaging.
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; the other GPUs use clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            # The control dependencies make sure the networks' 'lod'
            # variables are assigned before the losses are evaluated.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory statistics are unavailable; report 0 instead.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # Use the end-of-training schedule to pick the minibatch size for the grid.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    try:
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms(); D.setup_weight_histograms()
        metrics = metric_base.MetricGroup(metric_arg_list)

        print('Training...\n')
        ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
        maintenance_time = ctx.get_last_update_interval()
        cur_nimg = int(resume_kimg * 1000)
        cur_tick = 0
        tick_start_nimg = cur_nimg
        # Negative sentinel: training_schedule() never returns lod < 0, so the
        # first iteration always counts as a lod change.
        prev_lod = -1.0
        while cur_nimg < total_kimg * 1000:
            if ctx.should_stop(): break

            # Choose training parameters and configure training ops.
            sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
            training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
            if reset_opt_for_new_lod:
                # floor/ceil change exactly when lod crosses an integer, i.e.
                # when a fade-in starts or finishes.
                if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                    G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
            prev_lod = sched.lod

            # Run training ops.
            for _mb_repeat in range(minibatch_repeats):
                for _D_repeat in range(D_repeats):
                    # Gs is updated together with every discriminator step;
                    # progress is counted in images shown to D.
                    tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                    cur_nimg += sched.minibatch
                tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

            # Perform maintenance tasks once per tick.
            done = (cur_nimg >= total_kimg * 1000)
            if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
                cur_tick += 1
                tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
                tick_start_nimg = cur_nimg
                tick_time = ctx.get_time_since_last_update()
                total_time = ctx.get_time_since_start() + resume_time

                # Report progress.
                print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/lod', sched.lod),
                    autosummary('Progress/minibatch', sched.minibatch),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time),
                    autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

                # Save snapshots.
                if cur_tick % image_snapshot_ticks == 0 or done:
                    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                # A network snapshot is also written after the very first tick.
                if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                    pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                    misc.save_pkl((G, D, Gs), pkl)
                    metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

                # Update summaries and RunContext.
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
                ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
                # Time spent on reporting/snapshots, excluding the training itself.
                maintenance_time = ctx.get_last_update_interval() - tick_time

        # Write final results.
        misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    finally:
        # Always flush and close the event file, even if training failed.
        summary_log.close()

    ctx.close()
#---------------------------------------------------------------------------- | |