# FcF-Inpainting: training/training_loop.py
import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix
import legacy
import warnings
warnings.filterwarnings("ignore")
from colorama import init
from colorama import Fore, Style
from icecream import ic
init(autoreset=True)
from etaprogress.progress import ProgressBar
import sys
import matplotlib.pyplot as plt
from evaluate import save_gen, create_folders
from metrics.evaluation.data import PrecomputedInpaintingResultsDataset
from metrics.evaluation.evaluator import InpaintingEvaluator
from metrics.evaluation.losses.base_loss import FIDScore
from metrics.evaluation.utils import load_yaml
#----------------------------------------------------------------------------
def setup_snapshot_image_grid(training_set, random_seed=0):
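    """Pick a gw x gh grid of dataset indices (grouped by label when labels
    exist) and return the grid size plus stacked images, masks, and labels."""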
rnd = np.random.RandomState(random_seed)
gw = np.clip(5120 // training_set.image_shape[2], 0, 1)
gh = np.clip(5120 // training_set.image_shape[1], 10, 30)
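    # With the clip bounds above, gw is at most 1 column and gh lands between
    # 10 and 30 rows (e.g. 5120 // 256 = 20 at 256x256), so each snapshot is
    # effectively a single tall column of samples.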
# No labels => show random subset of training samples.
if not training_set.has_labels:
all_indices = list(range(len(training_set)))
rnd.shuffle(all_indices)
grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
else:
# Group training samples by label.
label_groups = dict() # label => [idx, ...]
for idx in range(len(training_set)):
label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
if label not in label_groups:
label_groups[label] = []
label_groups[label].append(idx)
# Reorder.
label_order = sorted(label_groups.keys())
for label in label_order:
rnd.shuffle(label_groups[label])
# Organize into grid.
grid_indices = []
for y in range(gh):
label = label_order[y % len(label_order)]
indices = label_groups[label]
grid_indices += [indices[x % len(indices)] for x in range(gw)]
label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
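            # Advance each label group by gw so the next snapshot grid starts
            # from fresh samples within the same label.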
# Load data.
images, masks, labels = zip(*[training_set[i] for i in grid_indices])
return (gw, gh), np.stack(images), np.stack(masks), np.stack(labels)
#----------------------------------------------------------------------------
def save_image_grid(img, erased_img, inv_mask, pred_img, fname, drange, grid_size):
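    """Tile [real | mask | erased | predicted | composite] panels for every
    grid cell into one large image and save it as <fname>.png."""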
lo, hi = (0, 255)
model_lo, model_hi = drange
img = np.asarray(img, dtype=np.float32)
img = (img - lo) * (255 / (hi - lo))
img = np.rint(img).clip(0, 255).astype(np.uint8)
inv_mask = np.squeeze(np.stack([inv_mask]*3, axis=1))
inv_mask = np.asarray(inv_mask, dtype=np.float32)
inv_mask = np.rint(inv_mask).clip(0, 1).astype(np.uint8)
erased_img = np.asarray(erased_img, dtype=np.float32)
erased_img = (erased_img - lo) * (255 / (hi - lo))
erased_img = np.rint(erased_img).clip(0, 255).astype(np.uint8)
pred_img = np.asarray(pred_img, dtype=np.float32)
pred_img = (pred_img - model_lo) * (255 / (model_hi - model_lo))
pred_img = np.rint(pred_img).clip(0, 255).astype(np.uint8)
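    # Composite: keep the real pixels where the image is visible (inv_mask == 0)
    # and paste the generator's prediction inside the hole (inv_mask == 1).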
comp_img = img * (1 - inv_mask) + pred_img * inv_mask
f_img = np.concatenate((img, inv_mask * 255, erased_img, pred_img, comp_img), axis=1)
gw, gh = grid_size
gw *= f_img.shape[1] // 3
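    # For 3-channel images, the five panels were stacked along the channel axis
    # above, so widening the grid 5x unfolds them into adjacent columns of the
    # final tile.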
_N, C, H, W = img.shape
f_img = f_img.reshape(gh, gw, C, H, W)
f_img = f_img.transpose(0, 3, 1, 4, 2)
f_img = f_img.reshape(gh * H, gw * W, C)
assert C in [1, 3]
if C == 1:
PIL.Image.fromarray(f_img[:, :, 0], 'L').save(fname + '.png')
if C == 3:
PIL.Image.fromarray(f_img, 'RGB').save(fname + '.png')
#----------------------------------------------------------------------------
def training_loop(
run_dir = '.', # Output directory.
    eval_img_data = None, # Evaluation image data.
    resolution = 256, # Resolution of evaluation images.
training_set_kwargs = {}, # Options for training set.
data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
G_kwargs = {}, # Options for generator network.
D_kwargs = {}, # Options for discriminator network.
G_opt_kwargs = {}, # Options for generator optimizer.
D_opt_kwargs = {}, # Options for discriminator optimizer.
augment_kwargs = None, # Options for augmentation pipeline. None = disable.
loss_kwargs = {}, # Options for loss function.
metrics = [], # Metrics to evaluate during training.
random_seed = 0, # Global random seed.
num_gpus = 1, # Number of GPUs participating in the training.
    rank = 0, # Rank of the current process in [0, num_gpus).
batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
batch_gpu = 4, # Number of samples processed at a time by one GPU.
ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
ema_rampup = None, # EMA ramp-up coefficient.
G_reg_interval = None, # How often to perform regularization for G? None = disable lazy regularization.
D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
augment_p = 0, # Initial value of augmentation probability.
ada_target = None, # ADA target value. None = fixed p.
ada_interval = 4, # How often to perform ADA adjustment?
ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
total_kimg = 25000, # Total length of the training, measured in thousands of real images.
kimg_per_tick = 4, # Progress snapshot interval.
image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
resume_pkl = None, # Network pickle to resume training from.
cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
progress_fn = None, # Callback function for updating training progress. Called for all ranks.
):
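    """Main training loop: construct the dataset, G/D networks, optimizers, and
    loss from the given kwargs, then alternate G/D phases with lazy
    regularization, G_ema tracking, optional ADA, and periodic image/network
    snapshots plus FID evaluation until total_kimg is reached."""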
# Initialize.
start_time = time.time()
device = torch.device('cuda', rank)
np.random.seed(random_seed * num_gpus + rank)
torch.manual_seed(random_seed * num_gpus + rank)
torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul
torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions
conv2d_gradfix.enabled = True # Improves training speed.
grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
eval_config = load_yaml('metrics/configs/eval2_gpu.yaml')
# Load training set.
if rank == 0:
print(Fore.GREEN + 'Loading training set...')
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
training_loader = torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)
training_set_iterator = iter(training_loader)
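    # Each of the num_gpus processes draws batch_size // num_gpus samples per
    # iteration from an infinite, rank-aware, seeded sampler, so one iteration
    # consumes batch_size images globally.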
if rank == 0:
print()
print(Fore.GREEN + 'Num images: ', len(training_set))
print(Fore.GREEN + 'Image shape:', training_set.image_shape)
print(Fore.GREEN + 'Label shape:', training_set.label_shape)
print()
# Construct networks.
if rank == 0:
print('Constructing networks...')
common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
G_ema = copy.deepcopy(G).eval()
# Resume from existing pickle.
if (resume_pkl is not None) and (rank == 0):
print(f'Resuming from "{resume_pkl}"')
with dnnlib.util.open_url(resume_pkl) as f:
resume_data = legacy.load_network_pkl(f)
for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
# Print network parameters
if rank == 0:
        netG_params = sum(p.numel() for p in G.parameters())
        print(Fore.GREEN + "Generator Params: {:.2f} M".format(netG_params / 1e6))
        netD_params = sum(p.numel() for p in D.parameters())
        print(Fore.GREEN + "Discriminator Params: {:.2f} M".format(netD_params / 1e6))
# Setup augmentation.
if rank == 0:
print(Fore.YELLOW + 'Setting up augmentation...')
augment_pipe = None
ada_stats = None
if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
augment_pipe.p.copy_(torch.as_tensor(augment_p))
if ada_target is not None:
ada_stats = training_stats.Collector(regex='Loss/signs/real')
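            # Tracks the sign of D's outputs on real images; its running mean
            # is compared against ada_target in the ADA heuristic below.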
# Distribute across GPUs.
if rank == 0:
print(Fore.CYAN + f'Distributing across {num_gpus} GPUs...')
ddp_modules = dict()
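    # DDP requires at least one parameter that requires grad at wrap time, so
    # grads are enabled just for construction and switched back off; each
    # training phase re-enables them for its own module only.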
for name, module in [('G_encoder', G.encoder), ('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
module.requires_grad_(True)
module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False, find_unused_parameters=True)
module.requires_grad_(False)
if name is not None:
ddp_modules[name] = module
# Setup training phases.
if rank == 0:
print('Setting up training phases...')
loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.losses.loss.Loss
phases = []
for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
if reg_interval is None:
opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
else: # Lazy regularization.
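            # Running the regularizer only every reg_interval minibatches is
            # compensated by rescaling the optimizer hyperparameters
            # (StyleGAN2): lr' = lr * c and beta' = beta ** c with
            # c = reg_interval / (reg_interval + 1).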
mb_ratio = reg_interval / (reg_interval + 1)
opt_kwargs = dnnlib.EasyDict(opt_kwargs)
opt_kwargs.lr = opt_kwargs.lr * mb_ratio
opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
for phase in phases:
phase.start_event = None
phase.end_event = None
if rank == 0:
phase.start_event = torch.cuda.Event(enable_timing=True)
phase.end_event = torch.cuda.Event(enable_timing=True)
# Export sample images.
grid_size = None
grid_c = None
if rank == 0:
print('Exporting sample images...')
grid_size, images, masks, labels = setup_snapshot_image_grid(training_set=training_set)
erased_images = images * (1 - masks)
grid_img = (torch.from_numpy(images).to(torch.float32) / 127.5 - 1).to(device)
grid_mask = torch.from_numpy(masks).to(torch.float32).to(device)
grid_erased_img = grid_img * (1 - grid_mask)
grid_img = grid_img.split(batch_gpu)
grid_mask = grid_mask.split(batch_gpu)
grid_erased_img = grid_erased_img.split(batch_gpu)
grid_c = torch.from_numpy(labels).to(torch.float32).to(device).split(batch_gpu)
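        # The generator sees a 4-channel input: a mask plane encoded as
        # 0.5 - mask (visible pixels -> +0.5, holes -> -0.5) concatenated with
        # the 3-channel erased image.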
pred_images = torch.cat([G_ema(img=torch.cat([0.5 - mask, erased_img], dim=1), c=c, noise_mode='const').cpu() for erased_img, mask, c in zip(grid_erased_img, grid_mask, grid_c)])
save_image_grid(images, erased_images, masks, pred_images.detach().numpy(), os.path.join(run_dir, 'run_init'), drange=[-1,1], grid_size=grid_size)
# Initialize logs.
if rank == 0:
print('Initializing logs...')
stats_collector = training_stats.Collector(regex='.*')
stats_metrics = dict()
stats_jsonl = None
stats_tfevents = None
if rank == 0:
stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
try:
import torch.utils.tensorboard as tensorboard
stats_tfevents = tensorboard.SummaryWriter(run_dir)
except ImportError as err:
print('Skipping tfevents export:', err)
# Train.
if rank == 0:
print(Fore.GREEN + Style.BRIGHT + f'Training for {total_kimg} kimg...')
print()
total = total_kimg * 1000
bar = ProgressBar(total, max_width=80)
cur_nimg = 0
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
batch_idx = 0
if progress_fn is not None:
progress_fn(0, total_kimg)
while True:
# Fetch training data.
with torch.autograd.profiler.record_function('data_fetch'):
phase_real_imgs, phase_masks, phase_real_cs = next(training_set_iterator)
# phase_erased_img = ((phase_real_imgs * (1 - phase_masks)).to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
phase_real_img = (phase_real_imgs.to(device).to(torch.float32) / 127.5 - 1)
phase_inv_mask = (phase_masks.to(device).to(torch.float32))
phase_erased_img = phase_real_img * (1 - phase_inv_mask)
phase_erased_img = phase_erased_img.split(batch_gpu)
phase_real_img = phase_real_img.split(batch_gpu)
phase_inv_mask = phase_inv_mask.split(batch_gpu)
phase_real_c = phase_real_cs.to(device).split(batch_gpu)
all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]
# Execute training phases.
for phase, phase_gen_c in zip(phases, all_gen_c):
if batch_idx % phase.interval != 0:
continue
# Initialize gradient accumulation.
if phase.start_event is not None:
phase.start_event.record(torch.cuda.current_stream(device))
phase.opt.zero_grad(set_to_none=True)
phase.module.requires_grad_(True)
# Accumulate gradients over multiple rounds.
for round_idx, (erased_img, real_img, mask, real_c, gen_c) in enumerate(zip(phase_erased_img, phase_real_img, phase_inv_mask, phase_real_c, phase_gen_c)):
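                # Synchronize DDP gradients only on the last accumulation
                # round; earlier rounds accumulate locally.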
sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
gain = phase.interval
loss.accumulate_gradients(phase=phase.name, erased_img=erased_img, real_img=real_img, mask=mask, real_c=real_c, gen_c=gen_c, sync=sync, gain=gain)
# Update weights.
phase.module.requires_grad_(False)
with torch.autograd.profiler.record_function(phase.name + '_opt'):
for param in phase.module.parameters():
if param.grad is not None:
misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
phase.opt.step()
if phase.end_event is not None:
phase.end_event.record(torch.cuda.current_stream(device))
# Update G_ema.
with torch.autograd.profiler.record_function('Gema'):
ema_nimg = ema_kimg * 1000
if ema_rampup is not None:
ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
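            # Pick beta so an old weight's contribution halves after ema_nimg
            # images: beta = 0.5 ** (batch_size / ema_nimg).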
ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
for p_ema, p in zip(G_ema.parameters(), G.parameters()):
p_ema.copy_(p.lerp(p_ema, ema_beta))
for b_ema, b in zip(G_ema.buffers(), G.buffers()):
b_ema.copy_(b)
# Update state.
cur_nimg += batch_size
batch_idx += 1
if rank == 0:
bar.numerator = cur_nimg
print(bar, end='\r')
# Execute ADA heuristic.
if (ada_stats is not None) and (batch_idx % ada_interval == 0):
ada_stats.update()
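            # Nudge p toward ada_target by (images this interval) /
            # (ada_kimg * 1000), i.e. p moves one full unit per ada_kimg
            # thousand images, clamped at 0 from below.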
adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))
# Perform maintenance tasks once per tick.
done = (cur_nimg >= total_kimg * 1000)
if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
continue
# Print status line, accumulating the same information in stats_collector.
tick_end_time = time.time()
fields = []
fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
fields += [f"cpumem GB {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
fields += [f"gpumem GB {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
torch.cuda.reset_peak_memory_stats()
fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.4f}"]
training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
if rank == 0:
print(Fore.CYAN + Style.BRIGHT + ' '.join(fields))
# Check for abort.
if (not done) and (abort_fn is not None) and abort_fn():
done = True
if rank == 0:
print()
print(Fore.RED + 'Aborting...')
# Save network snapshot.
snapshot_pkl = None
snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0) and cur_tick != 0:
snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
if module is not None:
if num_gpus > 1:
misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
snapshot_data[name] = module
del module # conserve memory
snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
if rank == 0:
with open(snapshot_pkl, 'wb') as f:
pickle.dump(snapshot_data, f)
        if (snapshot_data is not None) and metrics and (done or cur_tick % network_snapshot_ticks == 0) and cur_tick != 0:
msk_type = eval_img_data.split('/')[-1]
if rank == 0:
create_folders(msk_type)
label = torch.zeros([1, snapshot_data['G_ema'].c_dim]).to(device)
save_gen(snapshot_data['G_ema'], rank, num_gpus, device, eval_img_data, resolution, label, 1, msk_type)
if rank == 0:
                eval_dataset = PrecomputedInpaintingResultsDataset(eval_img_data, f'fid_gens/{msk_type}', **eval_config.dataset_kwargs)
                # Use a separate name so the `metrics` argument is not shadowed.
                eval_scores = {'fid': FIDScore()}
                evaluator = InpaintingEvaluator(eval_dataset, scores=eval_scores, area_grouping=False,
                                                integral_title='lpips_fid100_f1', integral_func=None,
                                                **eval_config.evaluator_kwargs)
results = evaluator.dist_evaluate(device, num_gpus=1, rank=0)
fid_score = round(results[('fid', 'total')]['mean'], 5)
stats_metrics.update({'fid': fid_score})
print(Fore.GREEN + Style.BRIGHT + f' FID Score: {fid_score}')
del snapshot_data # conserve memory
# Save image snapshot.
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
pred_images = torch.cat([G_ema(img=torch.cat([0.5 - mask, erased_img], dim=1), c=c, noise_mode='const').cpu() for erased_img, mask, c in zip(grid_erased_img, grid_mask, grid_c)])
save_image_grid(images, erased_images, masks, pred_images.detach().numpy(), os.path.join(run_dir, f'run_{cur_nimg//1000:06d}'), drange=[-1,1], grid_size=grid_size)
# Collect statistics.
for phase in phases:
value = []
if (phase.start_event is not None) and (phase.end_event is not None):
phase.end_event.synchronize()
value = phase.start_event.elapsed_time(phase.end_event)
training_stats.report0('Timing/' + phase.name, value)
stats_collector.update()
stats_dict = stats_collector.as_dict()
if rank == 0:
losses = []
for key in stats_dict.keys():
if 'Loss/D' in key or 'Loss/G' in key:
losses += [f"{key}: {(stats_dict[key]['mean']):<.4f}"]
print(Fore.MAGENTA + Style.BRIGHT + ' '.join(losses))
# Update logs.
timestamp = time.time()
if stats_jsonl is not None:
fields = dict(stats_dict, timestamp=timestamp)
stats_jsonl.write(json.dumps(fields) + '\n')
stats_jsonl.flush()
if stats_tfevents is not None:
global_step = int(cur_nimg / 1e3)
walltime = timestamp - start_time
for name, value in stats_dict.items():
stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
for name, value in stats_metrics.items():
stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
stats_tfevents.flush()
if progress_fn is not None:
progress_fn(cur_nimg // 1000, total_kimg)
# Update state.
cur_tick += 1
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - tick_end_time
if rank == 0:
sys.stdout.flush()
if done:
break
# Done.
if rank == 0:
print()
print(Fore.YELLOW + 'Exiting...')
#----------------------------------------------------------------------------
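# Minimal invocation sketch (hypothetical values shown; in this repo the
# kwargs are typically assembled by the train.py entry point from CLI
# options):
#
#   training_loop(
#       run_dir='runs/example',
#       eval_img_data='datasets/val/random_masks',   # hypothetical path
#       resolution=256,
#       training_set_kwargs=dict(class_name='training.dataset.Dataset', path='datasets/places2'),
#       G_kwargs=dict(class_name='training.networks.Generator'),
#       D_kwargs=dict(class_name='training.networks.Discriminator'),
#       G_opt_kwargs=dict(class_name='torch.optim.Adam', lr=0.002, betas=[0, 0.99]),
#       D_opt_kwargs=dict(class_name='torch.optim.Adam', lr=0.002, betas=[0, 0.99]),
#       loss_kwargs=dict(class_name='training.losses.loss.Loss'),
#   )
#----------------------------------------------------------------------------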