#!/usr/bin/env python
# -*- coding: utf-8 -*-
''' Utilities file
This file contains utility functions for bookkeeping, logging, and data loading.
Methods which directly affect training should either go in layers, the model,
or train_fns.py.
'''
from __future__ import print_function
import sys
import os
import numpy as np
import time
import datetime
import json
import pickle
from argparse import ArgumentParser
import animal_hash

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import datasets as dset
def prepare_parser():
  usage = 'Parser for all scripts.'
  parser = ArgumentParser(description=usage)

  ### Dataset/Dataloader stuff ###
  parser.add_argument(
    '--dataset', type=str, default='I128_hdf5',
    help='Which Dataset to train on, out of I128, I256, C10, C100; '
         'append "_hdf5" to use the hdf5 version for ILSVRC '
         '(default: %(default)s)')
  parser.add_argument(
    '--augment', action='store_true', default=False,
    help='Augment with random crops and flips (default: %(default)s)')
  parser.add_argument(
    '--num_workers', type=int, default=8,
    help='Number of dataloader workers; consider using less for HDF5 '
         '(default: %(default)s)')
  parser.add_argument(
    '--no_pin_memory', action='store_false', dest='pin_memory', default=True,
    help='Pin data into memory through dataloader? (default: %(default)s)')
  parser.add_argument(
    '--shuffle', action='store_true', default=False,
    help='Shuffle the data (strongly recommended)? (default: %(default)s)')
  parser.add_argument(
    '--load_in_mem', action='store_true', default=False,
    help='Load all data into memory? (default: %(default)s)')
  parser.add_argument(
    '--use_multiepoch_sampler', action='store_true', default=False,
    help='Use the multi-epoch sampler for dataloader? (default: %(default)s)')

  ### Model stuff ###
  parser.add_argument(
    '--model', type=str, default='BigGAN',
    help='Name of the model module (default: %(default)s)')
  parser.add_argument(
    '--G_param', type=str, default='SN',
    help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--D_param', type=str, default='SN',
    help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
         ' or None (default: %(default)s)')
  parser.add_argument(
    '--G_ch', type=int, default=64,
    help='Channel multiplier for G (default: %(default)s)')
  parser.add_argument(
    '--D_ch', type=int, default=64,
    help='Channel multiplier for D (default: %(default)s)')
  parser.add_argument(
    '--G_depth', type=int, default=1,
    help='Number of resblocks per stage in G? (default: %(default)s)')
  parser.add_argument(
    '--D_depth', type=int, default=1,
    help='Number of resblocks per stage in D? (default: %(default)s)')
  parser.add_argument(
    '--D_thin', action='store_false', dest='D_wide', default=True,
    help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
  parser.add_argument(
    '--G_shared', action='store_true', default=False,
    help='Use shared embeddings in G? (default: %(default)s)')
  parser.add_argument(
    '--shared_dim', type=int, default=0,
    help='G\'s shared embedding dimensionality; if 0, will be equal to dim_z '
         '(default: %(default)s)')
  parser.add_argument(
    '--dim_z', type=int, default=128,
    help='Noise dimensionality (default: %(default)s)')
  parser.add_argument(
    '--z_var', type=float, default=1.0,
    help='Noise variance (default: %(default)s)')
  parser.add_argument(
    '--hier', action='store_true', default=False,
    help='Use hierarchical z in G? (default: %(default)s)')
  parser.add_argument(
    '--cross_replica', action='store_true', default=False,
    help='Cross_replica batchnorm in G? (default: %(default)s)')
  parser.add_argument(
    '--mybn', action='store_true', default=False,
    help='Use my batchnorm (which supports standing stats)? '
         '(default: %(default)s)')
  parser.add_argument(
    '--G_nl', type=str, default='relu',
    help='Activation function for G (default: %(default)s)')
  parser.add_argument(
    '--D_nl', type=str, default='relu',
    help='Activation function for D (default: %(default)s)')
  parser.add_argument(
    '--G_attn', type=str, default='64',
    help='What resolutions to use attention on for G (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--D_attn', type=str, default='64',
    help='What resolutions to use attention on for D (underscore separated) '
         '(default: %(default)s)')
  parser.add_argument(
    '--norm_style', type=str, default='bn',
    help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
         'ln [layernorm], gn [groupnorm] (default: %(default)s)')

  ### Model init stuff ###
  parser.add_argument(
    '--seed', type=int, default=0,
    help='Random seed to use; affects both initialization and '
         'dataloading (default: %(default)s)')
  parser.add_argument(
    '--G_init', type=str, default='ortho',
    help='Init style to use for G (default: %(default)s)')
  parser.add_argument(
    '--D_init', type=str, default='ortho',
    help='Init style to use for D (default: %(default)s)')
  parser.add_argument(
    '--skip_init', action='store_true', default=False,
    help='Skip initialization, ideal for testing when ortho init was used '
         '(default: %(default)s)')

  ### Optimizer stuff ###
  parser.add_argument(
    '--G_lr', type=float, default=5e-5,
    help='Learning rate to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_lr', type=float, default=2e-4,
    help='Learning rate to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B1', type=float, default=0.0,
    help='Beta1 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B1', type=float, default=0.0,
    help='Beta1 to use for Discriminator (default: %(default)s)')
  parser.add_argument(
    '--G_B2', type=float, default=0.999,
    help='Beta2 to use for Generator (default: %(default)s)')
  parser.add_argument(
    '--D_B2', type=float, default=0.999,
    help='Beta2 to use for Discriminator (default: %(default)s)')

  ### Batch size, parallel, and precision stuff ###
  parser.add_argument(
    '--batch_size', type=int, default=64,
    help='Default overall batchsize (default: %(default)s)')
  parser.add_argument(
    '--G_batch_size', type=int, default=0,
    help='Batch size to use for G; if 0, same as D (default: %(default)s)')
  parser.add_argument(
    '--num_G_accumulations', type=int, default=1,
    help='Number of passes to accumulate G\'s gradients over '
         '(default: %(default)s)')
  parser.add_argument(
    '--num_D_steps', type=int, default=2,
    help='Number of D steps per G step (default: %(default)s)')
  parser.add_argument(
    '--num_D_accumulations', type=int, default=1,
    help='Number of passes to accumulate D\'s gradients over '
         '(default: %(default)s)')
  parser.add_argument(
    '--split_D', action='store_true', default=False,
    help='Run D twice rather than concatenating inputs? (default: %(default)s)')
  parser.add_argument(
    '--num_epochs', type=int, default=100,
    help='Number of epochs to train for (default: %(default)s)')
  parser.add_argument(
    '--parallel', action='store_true', default=False,
    help='Train with multiple GPUs (default: %(default)s)')
  parser.add_argument(
    '--G_fp16', action='store_true', default=False,
    help='Train with half-precision in G? (default: %(default)s)')
  parser.add_argument(
    '--D_fp16', action='store_true', default=False,
    help='Train with half-precision in D? (default: %(default)s)')
  parser.add_argument(
    '--D_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--G_mixed_precision', action='store_true', default=False,
    help='Train with half-precision activations but fp32 params in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--accumulate_stats', action='store_true', default=False,
    help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
  parser.add_argument(
    '--num_standing_accumulations', type=int, default=16,
    help='Number of forward passes to use in accumulating standing stats? '
         '(default: %(default)s)')

  ### Bookkeeping stuff ###
  parser.add_argument(
    '--G_eval_mode', action='store_true', default=False,
    help='Run G in eval mode (running/standing stats?) at sample/test time? '
         '(default: %(default)s)')
  parser.add_argument(
    '--save_every', type=int, default=2000,
    help='Save every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_save_copies', type=int, default=2,
    help='How many copies to save (default: %(default)s)')
  parser.add_argument(
    '--num_best_copies', type=int, default=2,
    help='How many previous best checkpoints to save (default: %(default)s)')
  parser.add_argument(
    '--which_best', type=str, default='IS',
    help='Which metric to use to determine when to save new "best" '
         'checkpoints, one of IS or FID (default: %(default)s)')
  parser.add_argument(
    '--no_fid', action='store_true', default=False,
    help='Calculate IS only, not FID? (default: %(default)s)')
  parser.add_argument(
    '--test_every', type=int, default=5000,
    help='Test every X iterations (default: %(default)s)')
  parser.add_argument(
    '--num_inception_images', type=int, default=50000,
    help='Number of samples to compute inception metrics with '
         '(default: %(default)s)')
  parser.add_argument(
    '--hashname', action='store_true', default=False,
    help='Use a hash of the experiment name instead of the full config '
         '(default: %(default)s)')
  parser.add_argument(
    '--base_root', type=str, default='',
    help='Default location to store all weights, samples, data, and logs '
         '(default: %(default)s)')
  parser.add_argument(
    '--data_root', type=str, default='data',
    help='Default location where data is stored (default: %(default)s)')
  parser.add_argument(
    '--weights_root', type=str, default='weights',
    help='Default location to store weights (default: %(default)s)')
  parser.add_argument(
    '--logs_root', type=str, default='logs',
    help='Default location to store logs (default: %(default)s)')
  parser.add_argument(
    '--samples_root', type=str, default='samples',
    help='Default location to store samples (default: %(default)s)')
  parser.add_argument(
    '--pbar', type=str, default='mine',
    help='Type of progressbar to use; one of "mine" or "tqdm" '
         '(default: %(default)s)')
  parser.add_argument(
    '--name_suffix', type=str, default='',
    help='Suffix for experiment name for loading weights for sampling '
         '(consider "best0") (default: %(default)s)')
  parser.add_argument(
    '--experiment_name', type=str, default='',
    help='Optionally override the automatic experiment naming with this arg '
         '(default: %(default)s)')
  parser.add_argument(
    '--config_from_name', action='store_true', default=False,
    help='Use a hash of the experiment name instead of the full config '
         '(default: %(default)s)')

  ### EMA Stuff ###
  parser.add_argument(
    '--ema', action='store_true', default=False,
    help='Keep an ema of G\'s weights? (default: %(default)s)')
  parser.add_argument(
    '--ema_decay', type=float, default=0.9999,
    help='EMA decay rate (default: %(default)s)')
  parser.add_argument(
    '--use_ema', action='store_true', default=False,
    help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
  parser.add_argument(
    '--ema_start', type=int, default=0,
    help='When to start updating the EMA weights (default: %(default)s)')

  ### Numerical precision and SV stuff ###
  parser.add_argument(
    '--adam_eps', type=float, default=1e-8,
    help='epsilon value to use for Adam (default: %(default)s)')
  parser.add_argument(
    '--BN_eps', type=float, default=1e-5,
    help='epsilon value to use for BatchNorm (default: %(default)s)')
  parser.add_argument(
    '--SN_eps', type=float, default=1e-8,
    help='epsilon value to use for Spectral Norm (default: %(default)s)')
  parser.add_argument(
    '--num_G_SVs', type=int, default=1,
    help='Number of SVs to track in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SVs', type=int, default=1,
    help='Number of SVs to track in D (default: %(default)s)')
  parser.add_argument(
    '--num_G_SV_itrs', type=int, default=1,
    help='Number of SV itrs in G (default: %(default)s)')
  parser.add_argument(
    '--num_D_SV_itrs', type=int, default=1,
    help='Number of SV itrs in D (default: %(default)s)')

  ### Ortho reg stuff ###
  parser.add_argument(
    '--G_ortho', type=float, default=0.0,  # 1e-4 is default for BigGAN
    help='Modified ortho reg coefficient in G (default: %(default)s)')
  parser.add_argument(
    '--D_ortho', type=float, default=0.0,
    help='Modified ortho reg coefficient in D (default: %(default)s)')
  parser.add_argument(
    '--toggle_grads', action='store_true', default=True,
    help='Toggle D and G\'s "requires_grad" settings when not training them? '
         '(default: %(default)s)')

  ### Which train function ###
  parser.add_argument(
    '--which_train_fn', type=str, default='GAN',
    help='How2trainyourbois (default: %(default)s)')

  ### Resume training stuff ###
  parser.add_argument(
    '--load_weights', type=str, default='',
    help='Suffix for which weights to load (e.g. best0, copy0) '
         '(default: %(default)s)')
  parser.add_argument(
    '--resume', action='store_true', default=False,
    help='Resume training? (default: %(default)s)')

  ### Log stuff ###
  parser.add_argument(
    '--logstyle', type=str, default='%3.3e',
    help='What style to use when logging training metrics? '
         'One of: %%#.#f/ %%#.#e (float/exp, text), '
         'pickle (python pickle), '
         'npz (numpy zip), '
         'mat (MATLAB .mat file) (default: %(default)s)')
  parser.add_argument(
    '--log_G_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in G? '
         '(default: %(default)s)')
  parser.add_argument(
    '--log_D_spectra', action='store_true', default=False,
    help='Log the top 3 singular values in each SN layer in D? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sv_log_interval', type=int, default=10,
    help='Iteration interval for logging singular values '
         '(default: %(default)s)')
  return parser
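
# Example usage (illustrative only, not executed on import): a train.py-style
# script is expected to turn the parsed arguments into a plain config dict:
#   parser = prepare_parser()
#   config = vars(parser.parse_args())
#   print(config['dataset'], config['batch_size'])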
# Arguments for sample.py; not presently used in train.py
def add_sample_parser(parser):
  parser.add_argument(
    '--sample_npz', action='store_true', default=False,
    help='Sample "sample_num_npz" images and save to npz? '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_num_npz', type=int, default=50000,
    help='Number of images to sample when sampling NPZs '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_sheets', action='store_true', default=False,
    help='Produce class-conditional sample sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_interps', action='store_true', default=False,
    help='Produce interpolation sheets and stick them in '
         'the samples root? (default: %(default)s)')
  parser.add_argument(
    '--sample_sheet_folder_num', type=int, default=-1,
    help='Number to use for the folder for these sample sheets '
         '(default: %(default)s)')
  parser.add_argument(
    '--sample_random', action='store_true', default=False,
    help='Produce a single random sheet? (default: %(default)s)')
  parser.add_argument(
    '--sample_trunc_curves', type=str, default='',
    help='Get inception metrics with a range of variances? '
         'To use this, specify a startpoint, step, and endpoint, e.g. '
         '--sample_trunc_curves 0.2_0.1_1.0 for a startpoint of 0.2, '
         'endpoint of 1.0, and stepsize of 0.1. Note that this is '
         'not exactly identical to using tf.truncated_normal, but should '
         'have approximately the same effect. (default: %(default)s)')
  parser.add_argument(
    '--sample_inception_metrics', action='store_true', default=False,
    help='Calculate Inception metrics with sample.py? (default: %(default)s)')
  return parser
# Convenience dicts
dset_dict = {'I32': dset.ImageFolder, 'I64': dset.ImageFolder,
             'I128': dset.ImageFolder, 'I256': dset.ImageFolder,
             'I32_hdf5': dset.ILSVRC_HDF5, 'I64_hdf5': dset.ILSVRC_HDF5,
             'I128_hdf5': dset.ILSVRC_HDF5, 'I256_hdf5': dset.ILSVRC_HDF5,
             'C10': dset.CIFAR10, 'C100': dset.CIFAR100}
imsize_dict = {'I32': 32, 'I32_hdf5': 32,
               'I64': 64, 'I64_hdf5': 64,
               'I128': 128, 'I128_hdf5': 128,
               'I256': 256, 'I256_hdf5': 256,
               'C10': 32, 'C100': 32}
root_dict = {'I32': 'ImageNet', 'I32_hdf5': 'ILSVRC32.hdf5',
             'I64': 'ImageNet', 'I64_hdf5': 'ILSVRC64.hdf5',
             'I128': 'ImageNet', 'I128_hdf5': 'ILSVRC128.hdf5',
             'I256': 'ImageNet', 'I256_hdf5': 'ILSVRC256.hdf5',
             'C10': 'cifar', 'C100': 'cifar'}
nclass_dict = {'I32': 1000, 'I32_hdf5': 1000,
               'I64': 1000, 'I64_hdf5': 1000,
               'I128': 1000, 'I128_hdf5': 1000,
               'I256': 1000, 'I256_hdf5': 1000,
               'C10': 10, 'C100': 100}
# Number of classes to put per sample sheet
classes_per_sheet_dict = {'I32': 50, 'I32_hdf5': 50,
                          'I64': 50, 'I64_hdf5': 50,
                          'I128': 20, 'I128_hdf5': 20,
                          'I256': 20, 'I256_hdf5': 20,
                          'C10': 10, 'C100': 100}
activation_dict = {'inplace_relu': nn.ReLU(inplace=True),
                   'relu': nn.ReLU(inplace=False),
                   'ir': nn.ReLU(inplace=True)}
class CenterCropLongEdge(object):
  """Crops the given PIL Image along its long edge, producing a square crop
  whose side equals the short edge of the image.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    return transforms.functional.center_crop(img, min(img.size))

  def __repr__(self):
    return self.__class__.__name__


class RandomCropLongEdge(object):
  """Crops the given PIL Image along its long edge with a random start point,
  producing a square crop whose side equals the short edge of the image.
  """
  def __call__(self, img):
    """
    Args:
      img (PIL Image): Image to be cropped.
    Returns:
      PIL Image: Cropped image.
    """
    size = (min(img.size), min(img.size))
    # Only step forward along this edge if it's the long edge
    i = (0 if size[0] == img.size[0]
         else np.random.randint(low=0, high=img.size[0] - size[0]))
    j = (0 if size[1] == img.size[1]
         else np.random.randint(low=0, high=img.size[1] - size[1]))
    return transforms.functional.crop(img, i, j, size[0], size[1])

  def __repr__(self):
    return self.__class__.__name__
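
# Example (illustrative): these crops are ordinary callables, so they compose
# with torchvision transforms exactly as get_data_loaders below does, e.g.
#   transform = transforms.Compose([RandomCropLongEdge(),
#                                   transforms.Resize(128),
#                                   transforms.ToTensor()])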
# Multi-epoch Dataset sampler to avoid memory leakage and enable resumption of
# training from the same sample regardless of whether we stop mid-epoch
class MultiEpochSampler(torch.utils.data.Sampler):
  r"""Samples elements randomly over multiple epochs

  Arguments:
      data_source (Dataset): dataset to sample from
      num_epochs (int): number of times to loop over the dataset
      start_itr (int): which iteration to begin from
  """
  def __init__(self, data_source, num_epochs, start_itr=0, batch_size=128):
    self.data_source = data_source
    self.num_samples = len(self.data_source)
    self.num_epochs = num_epochs
    self.start_itr = start_itr
    self.batch_size = batch_size
    if not isinstance(self.num_samples, int) or self.num_samples <= 0:
      raise ValueError("num_samples should be a positive integer "
                       "value, but got num_samples={}".format(self.num_samples))

  def __iter__(self):
    n = len(self.data_source)
    # Determine number of remaining epochs
    num_epochs = int(np.ceil((n * self.num_epochs
                              - (self.start_itr * self.batch_size)) / float(n)))
    # Sample all the indices, and then grab the last num_epochs index sets;
    # this ensures that if we're starting at epoch 4, we're still grabbing
    # epoch 4's indices
    out = [torch.randperm(n) for epoch in range(self.num_epochs)][-num_epochs:]
    # Ignore the first start_itr * batch_size % n indices of the first epoch
    out[0] = out[0][(self.start_itr * self.batch_size % n):]
    output = torch.cat(out).tolist()
    print('Length dataset output is %d' % len(output))
    return iter(output)

  def __len__(self):
    return len(self.data_source) * self.num_epochs - self.start_itr * self.batch_size
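
# Example (illustrative): for a 1,280,000-image dataset with batch_size=256,
# num_epochs=100 and start_itr=10000, len(sampler) is
# 1,280,000 * 100 - 10,000 * 256 = 125,440,000 indices, so a single DataLoader
# can cover the whole run without being re-created at epoch boundaries.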
# Convenience function to centralize all data loaders
def get_data_loaders(dataset, data_root=None, augment=False, batch_size=64,
                     num_workers=8, shuffle=True, load_in_mem=False, hdf5=False,
                     pin_memory=True, drop_last=True, start_itr=0,
                     num_epochs=500, use_multiepoch_sampler=False,
                     **kwargs):
  # Append the per-dataset folder (or hdf5 filename) to the data root
  data_root += '/%s' % root_dict[dataset]
  print('Using dataset root location %s' % data_root)

  which_dataset = dset_dict[dataset]
  norm_mean = [0.5, 0.5, 0.5]
  norm_std = [0.5, 0.5, 0.5]
  image_size = imsize_dict[dataset]
  # For image folder datasets, name of the file where we store the precomputed
  # image locations to avoid having to walk the dirs every time we load.
  dataset_kwargs = {'index_filename': '%s_imgs.npz' % dataset}

  # HDF5 datasets have their own inbuilt transform, no need for a train_transform
  if 'hdf5' in dataset:
    train_transform = None
  else:
    if augment:
      print('Data will be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = [transforms.RandomCrop(32, padding=4),
                           transforms.RandomHorizontalFlip()]
      else:
        train_transform = [RandomCropLongEdge(),
                           transforms.Resize(image_size),
                           transforms.RandomHorizontalFlip()]
    else:
      print('Data will not be augmented...')
      if dataset in ['C10', 'C100']:
        train_transform = []
      else:
        train_transform = [CenterCropLongEdge(), transforms.Resize(image_size)]
      # train_transform = [transforms.Resize(image_size), transforms.CenterCrop]
    train_transform = transforms.Compose(train_transform + [
                        transforms.ToTensor(),
                        transforms.Normalize(norm_mean, norm_std)])
  train_set = which_dataset(root=data_root, transform=train_transform,
                            load_in_mem=load_in_mem, **dataset_kwargs)

  # Prepare loader; the loaders list is for forward compatibility with
  # using validation / test splits.
  loaders = []
  if use_multiepoch_sampler:
    print('Using multiepoch sampler from start_itr %d...' % start_itr)
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
    sampler = MultiEpochSampler(train_set, num_epochs, start_itr, batch_size)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              sampler=sampler, **loader_kwargs)
  else:
    loader_kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory,
                     'drop_last': drop_last}  # Default, drop last incomplete batch
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              shuffle=shuffle, **loader_kwargs)
  loaders.append(train_loader)
  return loaders
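
# Example call (illustrative; the arguments shown are assumptions, not a
# prescribed configuration):
#   loaders = get_data_loaders(dataset='C10', data_root='data', augment=True,
#                              batch_size=64, shuffle=True)
#   x, y = next(iter(loaders[0]))  # images normalized to [-1, 1], int labels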
# Utility function to seed rngs
def seed_rng(seed):
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  np.random.seed(seed)


# Utility to peg all roots to a base root:
# if a base root folder is provided, peg all other root folders to it.
def update_config_roots(config):
  if config['base_root']:
    print('Pegging all root folders to base root %s' % config['base_root'])
    for key in ['data', 'weights', 'logs', 'samples']:
      config['%s_root' % key] = '%s/%s' % (config['base_root'], key)
  return config


# Utility to prepare root folders if they don't exist; parent folder must exist
def prepare_root(config):
  for key in ['weights_root', 'logs_root', 'samples_root']:
    if not os.path.exists(config[key]):
      print('Making directory %s for %s...' % (config[key], key))
      os.mkdir(config[key])
# Simple wrapper that applies EMA to a model. Could be better done in 1.0 using
# the parameters() and buffers() module functions, but for now this works
# with state_dicts using .copy_
class ema(object):
  def __init__(self, source, target, decay=0.9999, start_itr=0):
    self.source = source
    self.target = target
    self.decay = decay
    # Optional parameter indicating what iteration to start the decay at
    self.start_itr = start_itr
    # Initialize target's params to be source's
    self.source_dict = self.source.state_dict()
    self.target_dict = self.target.state_dict()
    print('Initializing EMA parameters to be source parameters...')
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.source_dict[key].data)
        # target_dict[key].data = source_dict[key].data # Doesn't work!

  def update(self, itr=None):
    # If an iteration counter is provided and itr is less than the start itr,
    # peg the ema weights to the underlying weights.
    if itr and itr < self.start_itr:
      decay = 0.0
    else:
      decay = self.decay
    with torch.no_grad():
      for key in self.source_dict:
        self.target_dict[key].data.copy_(self.target_dict[key].data * decay
                                         + self.source_dict[key].data * (1 - decay))
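
# Example (illustrative): a training loop would keep a second copy of G
# (called G_ema here for illustration) and update it after every G step:
#   ema_ = ema(G, G_ema, decay=config['ema_decay'], start_itr=config['ema_start'])
#   ...
#   G.optim.step()
#   ema_.update(itr)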
# Apply modified ortho reg to a model
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes, and not in the blacklist
      if len(param.shape) < 2 or any([param is item for item in blacklist]):
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
                           * (1. - torch.eye(w.shape[0], device=w.device)), w))
      param.grad.data += strength * grad.view(param.shape)


# Default ortho reg
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def default_ortho(model, strength=1e-4, blacklist=[]):
  with torch.no_grad():
    for param in model.parameters():
      # Only apply this to parameters with at least 2 axes & not in the blacklist
      if len(param.shape) < 2 or param in blacklist:
        continue
      w = param.view(param.shape[0], -1)
      grad = (2 * torch.mm(torch.mm(w, w.t())
                           - torch.eye(w.shape[0], device=w.device), w))
      param.grad.data += strength * grad.view(param.shape)
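
# Example (illustrative): both regularizers add directly to .grad, so they are
# intended to run after backward() and before the optimizer step, e.g.
#   loss.backward()
#   if config['G_ortho'] > 0.0:
#     ortho(G, config['G_ortho'])
#   G.optim.step()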
# Convenience utility to switch off requires_grad
def toggle_grad(model, on_or_off):
  for param in model.parameters():
    param.requires_grad = on_or_off


# Function to join strings or ignore them
# Base string is the string to link "strings," while strings
# is a list of strings or Nones.
def join_strings(base_string, strings):
  return base_string.join([item for item in strings if item])
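
# Example (illustrative): join_strings('_', ['G_ema', None]) == 'G_ema' and
# join_strings('_', ['G_ema', 'best0']) == 'G_ema_best0', which is how the
# checkpoint filenames below are assembled.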
# Save a model's weights, optimizer, and the state_dict
def save_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None):
  root = '/'.join([weights_root, experiment_name])
  if not os.path.exists(root):
    os.mkdir(root)
  if name_suffix:
    print('Saving weights to %s/%s...' % (root, name_suffix))
  else:
    print('Saving weights to %s...' % root)
  torch.save(G.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G', name_suffix])))
  torch.save(G.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix])))
  torch.save(D.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D', name_suffix])))
  torch.save(D.optim.state_dict(),
             '%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix])))
  torch.save(state_dict,
             '%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))
  if G_ema is not None:
    torch.save(G_ema.state_dict(),
               '%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix])))
# Load a model's weights, optimizer, and the state_dict
def load_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None, strict=True, load_optim=True):
  root = '/'.join([weights_root, experiment_name])
  if name_suffix:
    print('Loading %s weights from %s...' % (name_suffix, root))
  else:
    print('Loading weights from %s...' % root)
  if G is not None:
    G.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G', name_suffix]))),
      strict=strict)
    if load_optim:
      G.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['G_optim', name_suffix]))))
  if D is not None:
    D.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['D', name_suffix]))),
      strict=strict)
    if load_optim:
      D.optim.load_state_dict(
        torch.load('%s/%s.pth' % (root, join_strings('_', ['D_optim', name_suffix]))))
  # Load state dict
  for item in state_dict:
    state_dict[item] = torch.load('%s/%s.pth' % (root, join_strings('_', ['state_dict', name_suffix])))[item]
  if G_ema is not None:
    G_ema.load_state_dict(
      torch.load('%s/%s.pth' % (root, join_strings('_', ['G_ema', name_suffix]))),
      strict=strict)
''' MetricsLogger originally stolen from VoxNet source code.
    Used for logging inception metrics'''
class MetricsLogger(object):
  def __init__(self, fname, reinitialize=False):
    self.fname = fname
    self.reinitialize = reinitialize
    if os.path.exists(self.fname):
      if self.reinitialize:
        print('{} exists, deleting...'.format(self.fname))
        os.remove(self.fname)

  def log(self, record=None, **kwargs):
    """
    Assumption: no newlines in the input.
    """
    if record is None:
      record = {}
    record.update(kwargs)
    record['_stamp'] = time.time()
    with open(self.fname, 'a') as f:
      f.write(json.dumps(record, ensure_ascii=True) + '\n')
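
# Example (illustrative, file name assumed): each call appends one JSON line,
#   test_log = MetricsLogger('logs/exp_inception.jsonl', reinitialize=True)
#   test_log.log(itr=1000, IS_mean=24.3, FID=18.7)
# which writes {"itr": 1000, "IS_mean": 24.3, "FID": 18.7, "_stamp": ...}.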
# Logstyle is either:
# '%#.#f' for floating point representation in text
# '%#.#e' for exponent representation in text
# 'npz' for output to npz # NOT YET SUPPORTED
# 'pickle' for output to a python pickle # NOT YET SUPPORTED
# 'mat' for output to a MATLAB .mat file # NOT YET SUPPORTED
class MyLogger(object):
  def __init__(self, fname, reinitialize=False, logstyle='%3.3f'):
    self.root = fname
    if not os.path.exists(self.root):
      os.mkdir(self.root)
    self.reinitialize = reinitialize
    self.metrics = []
    self.logstyle = logstyle  # One of '%3.3f' or like '%3.3e'

  # Delete log if re-starting and log already exists
  def reinit(self, item):
    if os.path.exists('%s/%s.log' % (self.root, item)):
      if self.reinitialize:
        # Only print the removal message once for the singular value logs
        if 'sv' in item:
          if not any('sv' in item for item in self.metrics):
            print('Deleting singular value logs...')
        else:
          print('{} exists, deleting...'.format('%s/%s.log' % (self.root, item)))
        os.remove('%s/%s.log' % (self.root, item))

  # Log in plaintext; this is designed for being read in MATLAB (sorry not sorry)
  def log(self, itr, **kwargs):
    for arg in kwargs:
      if arg not in self.metrics:
        if self.reinitialize:
          self.reinit(arg)
        self.metrics += [arg]
      if self.logstyle == 'pickle':
        print('Pickle not currently supported...')
        # with open('%s/%s.log' % (self.root, arg), 'a') as f:
        #   pickle.dump(kwargs[arg], f)
      elif self.logstyle == 'mat':
        print('.mat logstyle not currently supported...')
      else:
        with open('%s/%s.log' % (self.root, arg), 'a') as f:
          f.write('%d: %s\n' % (itr, self.logstyle % kwargs[arg]))
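
# Example (illustrative): with logstyle '%3.3e',
#   train_log = MyLogger('logs/experiment', logstyle='%3.3e')
#   train_log.log(itr=500, G_loss=0.8213)
# appends the line "500: 8.213e-01" to logs/experiment/G_loss.log.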
# Write some metadata to the logs directory
def write_metadata(logs_root, experiment_name, config, state_dict):
  with open(('%s/%s/metalog.txt' %
             (logs_root, experiment_name)), 'w') as writefile:
    writefile.write('datetime: %s\n' % str(datetime.datetime.now()))
    writefile.write('config: %s\n' % str(config))
    writefile.write('state: %s\n' % str(state_dict))


"""
Very basic progress indicator to wrap an iterable in.

Author: Jan Schlüter
Andy's adds: time elapsed in addition to ETA, makes it possible to add
estimated time to 1k iters instead of estimated time to completion.
"""
def progress(items, desc='', total=None, min_delay=0.1, displaytype='s1k'):
  """
  Returns a generator over `items`, printing the number and percentage of
  items processed and the estimated remaining processing time before yielding
  the next item. `total` gives the total number of items (required if `items`
  has no length), and `min_delay` gives the minimum time in seconds between
  subsequent prints. `desc` gives an optional prefix text (end with a space).
  """
  total = total or len(items)
  t_start = time.time()
  t_last = 0
  for n, item in enumerate(items):
    t_now = time.time()
    if t_now - t_last > min_delay:
      print("\r%s%d/%d (%6.2f%%)" % (
              desc, n + 1, total, n / float(total) * 100), end=" ")
      if n > 0:
        if displaytype == 's1k':  # minutes/seconds for 1000 iters
          next_1000 = n + (1000 - n % 1000)
          t_done = t_now - t_start
          t_1k = t_done / n * next_1000
          outlist = list(divmod(t_done, 60)) + list(divmod(t_1k - t_done, 60))
          print("(TE/ET1k: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
        else:  # displaytype == 'eta'
          t_done = t_now - t_start
          t_total = t_done / n * total
          outlist = list(divmod(t_done, 60)) + list(divmod(t_total - t_done, 60))
          print("(TE/ETA: %d:%02d / %d:%02d)" % tuple(outlist), end=" ")
      sys.stdout.flush()
      t_last = t_now
    yield item
  t_total = time.time() - t_start
  print("\r%s%d/%d (100.00%%) (took %d:%02d)" % ((desc, total, total) +
                                                 divmod(t_total, 60)))
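
# Example (illustrative): wrapping any iterable reports progress in place, e.g.
#   for x, y in progress(loaders[0], desc='Epoch 0: ', displaytype='eta'):
#     train_step(x, y)  # hypothetical per-batch training function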
# Sample function for use with inception metrics
def sample(G, z_, y_, config):
  with torch.no_grad():
    z_.sample_()
    y_.sample_()
    if config['parallel']:
      G_z = nn.parallel.data_parallel(G, (z_, G.shared(y_)))
    else:
      G_z = G(z_, G.shared(y_))
    return G_z, y_
# Sample function for sample sheets
def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel,
                 samples_root, experiment_name, folder_number, z_=None):
  # Prepare sample directory
  if not os.path.isdir('%s/%s' % (samples_root, experiment_name)):
    os.mkdir('%s/%s' % (samples_root, experiment_name))
  if not os.path.isdir('%s/%s/%d' % (samples_root, experiment_name, folder_number)):
    os.mkdir('%s/%s/%d' % (samples_root, experiment_name, folder_number))
  # Loop over total number of sheets
  for i in range(num_classes // classes_per_sheet):
    ims = []
    y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet, device='cuda')
    for j in range(samples_per_class):
      if (z_ is not None) and hasattr(z_, 'sample_') and classes_per_sheet <= z_.size(0):
        z_.sample_()
      else:
        z_ = torch.randn(classes_per_sheet, G.dim_z, device='cuda')
      with torch.no_grad():
        if parallel:
          o = nn.parallel.data_parallel(G, (z_[:classes_per_sheet], G.shared(y)))
        else:
          o = G(z_[:classes_per_sheet], G.shared(y))
      ims += [o.data.cpu()]
    # This line should properly unroll the images
    out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                       ims[0].shape[3]).data.float().cpu()
    # The path for the samples
    image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name,
                                                 folder_number, i)
    torchvision.utils.save_image(out_ims, image_filename,
                                 nrow=samples_per_class, normalize=True)
# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
def interp(x0, x1, num_midpoints):
  lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype)
  return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1)))
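
# Example (illustrative): for x0 and x1 of shape (N, 1, D) and num_midpoints=8,
# interp returns an (N, 10, D) tensor whose second axis walks linearly from x0
# to x1 (endpoints included); interp_sheet below then flattens it into rows.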
# Interp sheet function
# Supports full, class-wise and intra-class interpolation
def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel,
                 samples_root, experiment_name, folder_number, sheet_number=0,
                 fix_z=False, fix_y=False, device='cuda'):
  # Prepare zs and ys
  if fix_z:  # If fix Z, only sample 1 z per row
    zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device)
    zs = zs.repeat(1, num_midpoints + 2, 1).view(-1, G.dim_z)
  else:
    zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                num_midpoints).view(-1, G.dim_z)
  if fix_y:  # If fix y, only sample 1 y per row
    ys = sample_1hot(num_per_sheet, num_classes)
    ys = G.shared(ys).view(num_per_sheet, 1, -1)
    ys = ys.repeat(1, num_midpoints + 2, 1).view(num_per_sheet * (num_midpoints + 2), -1)
  else:
    ys = interp(G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                G.shared(sample_1hot(num_per_sheet, num_classes)).view(num_per_sheet, 1, -1),
                num_midpoints).view(num_per_sheet * (num_midpoints + 2), -1)
  # Run the net--note that we've already passed y through G.shared.
  if G.fp16:
    zs = zs.half()
  with torch.no_grad():
    if parallel:
      out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu()
    else:
      out_ims = G(zs, ys).data.cpu()
  interp_style = '' + ('Z' if not fix_z else '') + ('Y' if not fix_y else '')
  image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name,
                                                folder_number, interp_style,
                                                sheet_number)
  torchvision.utils.save_image(out_ims, image_filename,
                               nrow=num_midpoints + 2, normalize=True)
# Convenience debugging function to print out gradnorms and shape from each layer
# May need to rewrite this so we can actually see which parameter is which
def print_grad_norms(net):
  gradsums = [[float(torch.norm(param.grad).item()),
               float(torch.norm(param).item()), param.shape]
              for param in net.parameters()]
  order = np.argsort([item[0] for item in gradsums])
  print(['%3.3e,%3.3e, %s' % (gradsums[item_index][0],
                              gradsums[item_index][1],
                              str(gradsums[item_index][2]))
         for item_index in order])


# Get singular values to log. This will use the state dict to find them
# and substitute underscores for dots.
def get_SVs(net, prefix):
  d = net.state_dict()
  return {('%s_%s' % (prefix, key)).replace('.', '_'):
            float(d[key].item())
          for key in d if 'sv' in key}
# Name an experiment based on its config
def name_from_config(config):
  name = '_'.join([
    item for item in [
      'Big%s' % config['which_train_fn'],
      config['dataset'],
      config['model'] if config['model'] != 'BigGAN' else None,
      'seed%d' % config['seed'],
      'Gch%d' % config['G_ch'],
      'Dch%d' % config['D_ch'],
      'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None,
      'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None,
      'bs%d' % config['batch_size'],
      'Gfp16' if config['G_fp16'] else None,
      'Dfp16' if config['D_fp16'] else None,
      'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None,
      'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None,
      'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None,
      'Glr%2.1e' % config['G_lr'],
      'Dlr%2.1e' % config['D_lr'],
      'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None,
      'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None,
      'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None,
      'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None,
      'Gnl%s' % config['G_nl'],
      'Dnl%s' % config['D_nl'],
      'Ginit%s' % config['G_init'],
      'Dinit%s' % config['D_init'],
      'G%s' % config['G_param'] if config['G_param'] != 'SN' else None,
      'D%s' % config['D_param'] if config['D_param'] != 'SN' else None,
      'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None,
      'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None,
      'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None,
      'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None,
      config['norm_style'] if config['norm_style'] != 'bn' else None,
      'cr' if config['cross_replica'] else None,
      'Gshared' if config['G_shared'] else None,
      'hier' if config['hier'] else None,
      'ema' if config['ema'] else None,
      config['name_suffix'] if config['name_suffix'] else None,
    ]
    if item is not None])
  # dogball
  if config['hashname']:
    return hashname(name)
  else:
    return name
# A simple function to produce a unique experiment name from the animal hashes.
def hashname(name):
  h = hash(name)
  a = h % len(animal_hash.a)
  h = h // len(animal_hash.a)
  b = h % len(animal_hash.b)
  h = h // len(animal_hash.b)
  c = h % len(animal_hash.c)
  return animal_hash.a[a] + animal_hash.b[b] + animal_hash.c[c]
# Get free GPU memory; indices is an iterable of GPU indices to query
def query_gpu(indices):
  os.system('nvidia-smi -i %s --query-gpu=memory.free --format=csv'
            % ','.join(str(i) for i in indices))


# Convenience function to count the number of parameters in a module
def count_parameters(module):
  print('Number of parameters: {}'.format(
    sum([p.data.nelement() for p in module.parameters()])))


# Convenience function to sample an index, not actually a 1-hot
def sample_1hot(batch_size, num_classes, device='cuda'):
  return torch.randint(low=0, high=num_classes, size=(batch_size,),
                       device=device, dtype=torch.int64, requires_grad=False)
# A highly simplified convenience class for sampling from distributions
# One could also use PyTorch's inbuilt distributions package.
# Note that this class requires initialization to proceed as
#   x = Distribution(torch.randn(size))
#   x.init_distribution(dist_type, **dist_kwargs)
#   x = x.to(device, dtype)
# This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2
class Distribution(torch.Tensor):
  # Init the params of the distribution
  def init_distribution(self, dist_type, **kwargs):
    self.dist_type = dist_type
    self.dist_kwargs = kwargs
    if self.dist_type == 'normal':
      self.mean, self.var = kwargs['mean'], kwargs['var']
    elif self.dist_type == 'categorical':
      self.num_categories = kwargs['num_categories']

  def sample_(self):
    if self.dist_type == 'normal':
      self.normal_(self.mean, self.var)
    elif self.dist_type == 'categorical':
      self.random_(0, self.num_categories)
    # return self.variable

  # Silly hack: overwrite the to() method to wrap the new object
  # in a distribution as well
  def to(self, *args, **kwargs):
    new_obj = Distribution(self)
    new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
    new_obj.data = super().to(*args, **kwargs)
    return new_obj
# Convenience function to prepare a z and y vector
def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda',
                fp16=False, z_var=1.0):
  z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
  z_.init_distribution('normal', mean=0, var=z_var)
  z_ = z_.to(device, torch.float16 if fp16 else torch.float32)
  if fp16:
    z_ = z_.half()
  y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
  y_.init_distribution('categorical', num_categories=nclasses)
  y_ = y_.to(device, torch.int64)
  return z_, y_
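
# Example (illustrative): sample() and the training loop refresh these tensors
# in place before each forward pass:
#   z_, y_ = prepare_z_y(G_batch_size=64, dim_z=128, nclasses=1000)
#   z_.sample_()  # refill z_ with N(0, z_var) noise
#   y_.sample_()  # refill y_ with uniform class indices in [0, nclasses)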
def initiate_standing_stats(net):
  for module in net.modules():
    if hasattr(module, 'accumulate_standing'):
      module.reset_stats()
      module.accumulate_standing = True


def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16):
  initiate_standing_stats(net)
  net.train()
  for i in range(num_accumulations):
    with torch.no_grad():
      z.normal_()
      y.random_(0, nclasses)
      x = net(z, net.shared(y))  # No need to parallelize here unless using syncbn
  # Set to eval mode
  net.eval()
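
# Example (illustrative, config keys assumed): this mirrors how
# --accumulate_stats is meant to be used before sampling or testing:
#   if config['accumulate_stats']:
#     accumulate_standing_stats(G, z_, y_, nclasses,
#                               num_accumulations=config['num_standing_accumulations'])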
# This version of Adam keeps an fp32 copy of the parameters and
# does all of the parameter updates in fp32, while still doing the
# forwards and backwards passes using fp16 (i.e. fp16 copies of the
# parameters and fp16 activations).
#
# Note that this calls .float().cuda() on the params.
import math
from torch.optim.optimizer import Optimizer


class Adam16(Optimizer):
  def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    defaults = dict(lr=lr, betas=betas, eps=eps,
                    weight_decay=weight_decay)
    params = list(params)
    super(Adam16, self).__init__(params, defaults)

  # Safety modification to make sure we floatify our state
  def load_state_dict(self, state_dict):
    super(Adam16, self).load_state_dict(state_dict)
    for group in self.param_groups:
      for p in group['params']:
        self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float()
        self.state[p]['exp_avg_sq'] = self.state[p]['exp_avg_sq'].float()
        self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float()

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue

        grad = p.grad.data.float()
        state = self.state[p]

        # State initialization
        if len(state) == 0:
          state['step'] = 0
          # Exponential moving average of gradient values
          state['exp_avg'] = grad.new().resize_as_(grad).zero_()
          # Exponential moving average of squared gradient values
          state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
          # Fp32 copy of the weights
          state['fp32_p'] = p.data.float()

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        if group['weight_decay'] != 0:
          grad = grad.add(group['weight_decay'], state['fp32_p'])

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

        denom = exp_avg_sq.sqrt().add_(group['eps'])

        bias_correction1 = 1 - beta1 ** state['step']
        bias_correction2 = 1 - beta2 ** state['step']
        step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

        state['fp32_p'].addcdiv_(-step_size, exp_avg, denom)
        p.data = state['fp32_p'].half()

    return loss
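
# Example (illustrative): Adam16 is the drop-in Adam replacement for models
# built with --G_fp16/--D_fp16, roughly along the lines of
#   G = G.half().cuda()
#   G.optim = Adam16(G.parameters(), lr=config['G_lr'],
#                    betas=(config['G_B1'], config['G_B2']),
#                    eps=config['adam_eps'])
# so the fp16 parameters are stepped from an internal fp32 master copy.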