# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Script for training, evaluation and sampling for Real NVP. | |
$ python real_nvp_multiscale_dataset.py \ | |
--alsologtostderr \ | |
--image_size 64 \ | |
--hpconfig=n_scale=5,base_dim=8 \ | |
--dataset imnet \ | |
--data_path [DATA_PATH] | |
""" | |

from __future__ import print_function

import os
import time
from datetime import datetime

import numpy
from six.moves import xrange
import tensorflow as tf
from tensorflow import gfile

from real_nvp_utils import (
    batch_norm, batch_norm_log_diff, conv_layer,
    squeeze_2x2, squeeze_2x2_ordered, standard_normal_ll,
    standard_normal_sample, unsqueeze_2x2, variable_on_cpu)
tf.flags.DEFINE_string("master", "local", | |
"BNS name of the TensorFlow master, or local.") | |
tf.flags.DEFINE_string("logdir", "/tmp/real_nvp_multiscale", | |
"Directory to which writes logs.") | |
tf.flags.DEFINE_string("traindir", "/tmp/real_nvp_multiscale", | |
"Directory to which writes logs.") | |
tf.flags.DEFINE_integer("train_steps", 1000000000000000000, | |
"Number of steps to train for.") | |
tf.flags.DEFINE_string("data_path", "", "Path to the data.") | |
tf.flags.DEFINE_string("mode", "train", | |
"Mode of execution. Must be 'train', " | |
"'sample' or 'eval'.") | |
tf.flags.DEFINE_string("dataset", "imnet", | |
"Dataset used. Must be 'imnet', " | |
"'celeba' or 'lsun'.") | |
tf.flags.DEFINE_integer("recursion_type", 2, | |
"Type of the recursion.") | |
tf.flags.DEFINE_integer("image_size", 64, | |
"Size of the input image.") | |
tf.flags.DEFINE_integer("eval_set_size", 0, | |
"Size of evaluation dataset.") | |
tf.flags.DEFINE_string( | |
"hpconfig", "", | |
"A comma separated list of hyperparameters for the model. Format is " | |
"hp1=value1,hp2=value2,etc. If this FLAG is set, the model will be trained " | |
"with the specified hyperparameters, filling in missing hyperparameters " | |
"from the default_values in |hyper_params|.") | |
FLAGS = tf.flags.FLAGS | |


class HParams(object):
    """Dictionary of hyperparameters."""

    def __init__(self, **kwargs):
        self.dict_ = kwargs
        self.__dict__.update(self.dict_)

    def update_config(self, in_string):
        """Update the dictionary with a comma separated list."""
        if not in_string:
            # Nothing to override, e.g. when --hpconfig is left empty.
            return self
        pairs = in_string.split(",")
        pairs = [pair.split("=") for pair in pairs]
        for key, val in pairs:
            # Cast each value to the type of the corresponding default.
            self.dict_[key] = type(self.dict_[key])(val)
        self.__dict__.update(self.dict_)
        return self

    def __getitem__(self, key):
        return self.dict_[key]

    def __setitem__(self, key, val):
        self.dict_[key] = val
        self.__dict__.update(self.dict_)


def get_default_hparams():
    """Get the default hyperparameters."""
    return HParams(
        batch_size=64,
        residual_blocks=2,
        n_couplings=2,
        n_scale=4,
        learning_rate=0.001,
        momentum=1e-1,
        decay=1e-3,
        l2_coeff=0.00005,
        clip_gradient=100.,
        optimizer="adam",
        dropout_mask=0,
        base_dim=32,
        bottleneck=0,
        use_batch_norm=1,
        alternate=1,
        use_aff=1,
        skip=1,
        data_constraint=.9,
        n_opt=0)
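
# Example (a minimal sketch): overriding defaults from an --hpconfig string.
# update_config() casts each value to the type of the matching default, so
#     hps = get_default_hparams().update_config("n_scale=5,base_dim=8")
# yields hps.n_scale == 5 and hps.base_dim == 8 while every other
# hyperparameter keeps its default value.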


# RESNET UTILS
def residual_block(input_, dim, name, use_batch_norm=True,
                   train=True, weight_norm=True, bottleneck=False):
    """Residual convolutional block."""
    with tf.variable_scope(name):
        res = input_
        if use_batch_norm:
            res = batch_norm(
                input_=res, dim=dim, name="bn_in", scale=False,
                train=train, epsilon=1e-4, axes=[0, 1, 2])
        res = tf.nn.relu(res)
        if bottleneck:
            res = conv_layer(
                input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
                name="h_0", stddev=numpy.sqrt(2. / (dim)),
                strides=[1, 1, 1, 1], padding="SAME",
                nonlinearity=None, bias=(not use_batch_norm),
                weight_norm=weight_norm, scale=False)
            if use_batch_norm:
                res = batch_norm(
                    input_=res, dim=dim,
                    name="bn_0", scale=False, train=train,
                    epsilon=1e-4, axes=[0, 1, 2])
            res = tf.nn.relu(res)
            res = conv_layer(
                input_=res, filter_size=[3, 3], dim_in=dim,
                dim_out=dim, name="h_1", stddev=numpy.sqrt(2. / (1. * dim)),
                strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
                bias=(not use_batch_norm),
                weight_norm=weight_norm, scale=False)
            if use_batch_norm:
                res = batch_norm(
                    input_=res, dim=dim, name="bn_1", scale=False,
                    train=train, epsilon=1e-4, axes=[0, 1, 2])
            res = tf.nn.relu(res)
            res = conv_layer(
                input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
                name="out", stddev=numpy.sqrt(2. / (1. * dim)),
                strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
                bias=True, weight_norm=weight_norm, scale=True)
        else:
            res = conv_layer(
                input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
                name="h_0", stddev=numpy.sqrt(2. / (dim)),
                strides=[1, 1, 1, 1], padding="SAME",
                nonlinearity=None, bias=(not use_batch_norm),
                weight_norm=weight_norm, scale=False)
            if use_batch_norm:
                res = batch_norm(
                    input_=res, dim=dim, name="bn_0", scale=False,
                    train=train, epsilon=1e-4, axes=[0, 1, 2])
            res = tf.nn.relu(res)
            res = conv_layer(
                input_=res, filter_size=[3, 3], dim_in=dim, dim_out=dim,
                name="out", stddev=numpy.sqrt(2. / (1. * dim)),
                strides=[1, 1, 1, 1], padding="SAME", nonlinearity=None,
                bias=True, weight_norm=weight_norm, scale=True)
        res += input_

    return res
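
# Note: the block above follows the "pre-activation" residual ordering
# (normalize -> ReLU -> convolve) with an identity skip connection
# (`res += input_`), so each convolution sees a normalized, rectified
# version of the previous activations.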


def resnet(input_, dim_in, dim, dim_out, name, use_batch_norm=True,
           train=True, weight_norm=True, residual_blocks=5,
           bottleneck=False, skip=True):
    """Residual convolutional network."""
    with tf.variable_scope(name):
        res = input_
        if residual_blocks != 0:
            res = conv_layer(
                input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
                name="h_in", stddev=numpy.sqrt(2. / (dim_in)),
                strides=[1, 1, 1, 1], padding="SAME",
                nonlinearity=None, bias=True,
                weight_norm=weight_norm, scale=False)
            if skip:
                out = conv_layer(
                    input_=res, filter_size=[1, 1], dim_in=dim, dim_out=dim,
                    name="skip_in", stddev=numpy.sqrt(2. / (dim)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None, bias=True,
                    weight_norm=weight_norm, scale=True)
            # residual blocks
            for idx_block in xrange(residual_blocks):
                res = residual_block(res, dim, "block_%d" % idx_block,
                                     use_batch_norm=use_batch_norm,
                                     train=train,
                                     weight_norm=weight_norm,
                                     bottleneck=bottleneck)
                if skip:
                    out += conv_layer(
                        input_=res, filter_size=[1, 1], dim_in=dim,
                        dim_out=dim, name="skip_%d" % idx_block,
                        stddev=numpy.sqrt(2. / (dim)),
                        strides=[1, 1, 1, 1], padding="SAME",
                        nonlinearity=None, bias=True,
                        weight_norm=weight_norm, scale=True)
            # outputs
            if skip:
                res = out
            if use_batch_norm:
                res = batch_norm(
                    input_=res, dim=dim, name="bn_pre_out", scale=False,
                    train=train, epsilon=1e-4, axes=[0, 1, 2])
            res = tf.nn.relu(res)
            res = conv_layer(
                input_=res, filter_size=[1, 1], dim_in=dim,
                dim_out=dim_out,
                name="out", stddev=numpy.sqrt(2. / (1. * dim)),
                strides=[1, 1, 1, 1], padding="SAME",
                nonlinearity=None, bias=True,
                weight_norm=weight_norm, scale=True)
        else:
            if bottleneck:
                res = conv_layer(
                    input_=res, filter_size=[1, 1], dim_in=dim_in, dim_out=dim,
                    name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None, bias=(not use_batch_norm),
                    weight_norm=weight_norm, scale=False)
                if use_batch_norm:
                    res = batch_norm(
                        input_=res, dim=dim, name="bn_0", scale=False,
                        train=train, epsilon=1e-4, axes=[0, 1, 2])
                res = tf.nn.relu(res)
                res = conv_layer(
                    input_=res, filter_size=[3, 3], dim_in=dim,
                    dim_out=dim, name="h_1",
                    stddev=numpy.sqrt(2. / (1. * dim)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None,
                    bias=(not use_batch_norm),
                    weight_norm=weight_norm, scale=False)
                if use_batch_norm:
                    res = batch_norm(
                        input_=res, dim=dim, name="bn_1", scale=False,
                        train=train, epsilon=1e-4, axes=[0, 1, 2])
                res = tf.nn.relu(res)
                res = conv_layer(
                    input_=res, filter_size=[1, 1], dim_in=dim,
                    dim_out=dim_out,
                    name="out", stddev=numpy.sqrt(2. / (1. * dim)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None, bias=True,
                    weight_norm=weight_norm, scale=True)
            else:
                res = conv_layer(
                    input_=res, filter_size=[3, 3], dim_in=dim_in, dim_out=dim,
                    name="h_0", stddev=numpy.sqrt(2. / (dim_in)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None, bias=(not use_batch_norm),
                    weight_norm=weight_norm, scale=False)
                if use_batch_norm:
                    res = batch_norm(
                        input_=res, dim=dim, name="bn_0", scale=False,
                        train=train, epsilon=1e-4, axes=[0, 1, 2])
                res = tf.nn.relu(res)
                res = conv_layer(
                    input_=res, filter_size=[3, 3], dim_in=dim,
                    dim_out=dim_out,
                    name="out", stddev=numpy.sqrt(2. / (1. * dim)),
                    strides=[1, 1, 1, 1], padding="SAME",
                    nonlinearity=None, bias=True,
                    weight_norm=weight_norm, scale=True)

    return res


# COUPLING LAYERS
# masked convolution implementations
def masked_conv_aff_coupling(input_, mask_in, dim, name,
                             use_batch_norm=True, train=True, weight_norm=True,
                             reverse=False, residual_blocks=5,
                             bottleneck=False, use_width=1., use_height=1.,
                             mask_channel=0., skip=True):
    """Affine coupling with masked convolution."""
    with tf.variable_scope(name) as scope:
        if reverse or (not train):
            scope.reuse_variables()
        shape = input_.get_shape().as_list()
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        channels = shape[3]

        # build mask
        mask = use_width * numpy.arange(width)
        mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
        mask = mask.astype("float32")
        mask = tf.mod(mask_in + mask, 2)
        mask = tf.reshape(mask, [-1, height, width, 1])
        if mask.get_shape().as_list()[0] == 1:
            mask = tf.tile(mask, [batch_size, 1, 1, 1])
        res = input_ * tf.mod(mask_channel + mask, 2)

        # initial input
        if use_batch_norm:
            res = batch_norm(
                input_=res, dim=channels, name="bn_in", scale=False,
                train=train, epsilon=1e-4, axes=[0, 1, 2])
            res *= 2.
        res = tf.concat([res, -res], 3)
        res = tf.concat([res, mask], 3)
        dim_in = 2 * channels + 1
        res = tf.nn.relu(res)
        res = resnet(input_=res, dim_in=dim_in, dim=dim,
                     dim_out=2 * channels,
                     name="resnet", use_batch_norm=use_batch_norm,
                     train=train, weight_norm=weight_norm,
                     residual_blocks=residual_blocks,
                     bottleneck=bottleneck, skip=skip)
        mask = tf.mod(mask_channel + mask, 2)
        res = tf.split(axis=3, num_or_size_splits=2, value=res)
        shift, log_rescaling = res[-2], res[-1]
        scale = variable_on_cpu(
            "rescaling_scale", [],
            tf.constant_initializer(0.))
        shift = tf.reshape(
            shift, [batch_size, height, width, channels])
        log_rescaling = tf.reshape(
            log_rescaling, [batch_size, height, width, channels])
        log_rescaling = scale * tf.tanh(log_rescaling)
        if not use_batch_norm:
            scale_shift = variable_on_cpu(
                "scale_shift", [],
                tf.constant_initializer(0.))
            log_rescaling += scale_shift
        shift *= (1. - mask)
        log_rescaling *= (1. - mask)
        if reverse:
            res = input_
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res * (1. - mask), dim=channels, name="bn_out",
                    train=False, epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res *= tf.exp(.5 * log_var * (1. - mask))
                res += mean * (1. - mask)
            res *= tf.exp(-log_rescaling)
            res -= shift
            log_diff = -log_rescaling
            if use_batch_norm:
                log_diff += .5 * log_var * (1. - mask)
        else:
            res = input_
            res += shift
            res *= tf.exp(log_rescaling)
            log_diff = log_rescaling
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res * (1. - mask), dim=channels, name="bn_out",
                    train=train, epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res -= mean * (1. - mask)
                res *= tf.exp(-.5 * log_var * (1. - mask))
                log_diff -= .5 * log_var * (1. - mask)

    return res, log_diff
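
# For the unmasked coordinates above, the forward map is
#     y = (x + shift) * exp(log_rescaling),
# contributing sum(log_rescaling) to the log-determinant of the Jacobian
# (plus batch-norm correction terms when use_batch_norm is set); the masked
# coordinates pass through unchanged, which is what makes the closed-form
# inverse in the `reverse` branch possible.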


def masked_conv_add_coupling(input_, mask_in, dim, name,
                             use_batch_norm=True, train=True, weight_norm=True,
                             reverse=False, residual_blocks=5,
                             bottleneck=False, use_width=1., use_height=1.,
                             mask_channel=0., skip=True):
    """Additive coupling with masked convolution."""
    with tf.variable_scope(name) as scope:
        if reverse or (not train):
            scope.reuse_variables()
        shape = input_.get_shape().as_list()
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        channels = shape[3]

        # build mask
        mask = use_width * numpy.arange(width)
        mask = use_height * numpy.arange(height).reshape((-1, 1)) + mask
        mask = mask.astype("float32")
        mask = tf.mod(mask_in + mask, 2)
        mask = tf.reshape(mask, [-1, height, width, 1])
        if mask.get_shape().as_list()[0] == 1:
            mask = tf.tile(mask, [batch_size, 1, 1, 1])
        res = input_ * tf.mod(mask_channel + mask, 2)

        # initial input
        if use_batch_norm:
            res = batch_norm(
                input_=res, dim=channels, name="bn_in", scale=False,
                train=train, epsilon=1e-4, axes=[0, 1, 2])
            res *= 2.
        res = tf.concat([res, -res], 3)
        res = tf.concat([res, mask], 3)
        dim_in = 2 * channels + 1
        res = tf.nn.relu(res)
        shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
                       name="resnet", use_batch_norm=use_batch_norm,
                       train=train, weight_norm=weight_norm,
                       residual_blocks=residual_blocks,
                       bottleneck=bottleneck, skip=skip)
        mask = tf.mod(mask_channel + mask, 2)
        shift *= (1. - mask)
        if reverse:
            res = input_
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res * (1. - mask),
                    dim=channels, name="bn_out", train=False, epsilon=1e-4,
                    axes=[0, 1, 2])
                log_var = tf.log(var)
                res *= tf.exp(.5 * log_var * (1. - mask))
                res += mean * (1. - mask)
            res -= shift
            log_diff = tf.zeros_like(res)
            if use_batch_norm:
                log_diff += .5 * log_var * (1. - mask)
        else:
            res = input_
            res += shift
            log_diff = tf.zeros_like(res)
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res * (1. - mask), dim=channels,
                    name="bn_out", train=train, epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res -= mean * (1. - mask)
                res *= tf.exp(-.5 * log_var * (1. - mask))
                log_diff -= .5 * log_var * (1. - mask)

    return res, log_diff
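
# Additive coupling is volume preserving: with no rescaling factor, the
# Jacobian of y = x + shift is the identity on the unmasked coordinates,
# so log_diff stays zero apart from the batch-norm correction terms.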


def masked_conv_coupling(input_, mask_in, dim, name,
                         use_batch_norm=True, train=True, weight_norm=True,
                         reverse=False, residual_blocks=5,
                         bottleneck=False, use_aff=True,
                         use_width=1., use_height=1.,
                         mask_channel=0., skip=True):
    """Coupling with masked convolution."""
    if use_aff:
        return masked_conv_aff_coupling(
            input_=input_, mask_in=mask_in, dim=dim, name=name,
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=reverse, residual_blocks=residual_blocks,
            bottleneck=bottleneck, use_width=use_width, use_height=use_height,
            mask_channel=mask_channel, skip=skip)
    else:
        return masked_conv_add_coupling(
            input_=input_, mask_in=mask_in, dim=dim, name=name,
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=reverse, residual_blocks=residual_blocks,
            bottleneck=bottleneck, use_width=use_width, use_height=use_height,
            mask_channel=mask_channel, skip=skip)


# channel-axis splitting implementations
def conv_ch_aff_coupling(input_, dim, name,
                         use_batch_norm=True, train=True, weight_norm=True,
                         reverse=False, residual_blocks=5,
                         bottleneck=False, change_bottom=True, skip=True):
    """Affine coupling with channel-wise splitting."""
    with tf.variable_scope(name) as scope:
        if reverse or (not train):
            scope.reuse_variables()
        if change_bottom:
            input_, canvas = tf.split(axis=3, num_or_size_splits=2,
                                      value=input_)
        else:
            canvas, input_ = tf.split(axis=3, num_or_size_splits=2,
                                      value=input_)
        shape = input_.get_shape().as_list()
        batch_size = shape[0]
        height = shape[1]
        width = shape[2]
        channels = shape[3]
        res = input_

        # initial input
        if use_batch_norm:
            res = batch_norm(
                input_=res, dim=channels, name="bn_in", scale=False,
                train=train, epsilon=1e-4, axes=[0, 1, 2])
        res = tf.concat([res, -res], 3)
        dim_in = 2 * channels
        res = tf.nn.relu(res)
        res = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=2 * channels,
                     name="resnet", use_batch_norm=use_batch_norm,
                     train=train, weight_norm=weight_norm,
                     residual_blocks=residual_blocks,
                     bottleneck=bottleneck, skip=skip)
        shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2,
                                        value=res)
        scale = variable_on_cpu(
            "scale", [],
            tf.constant_initializer(1.))
        shift = tf.reshape(
            shift, [batch_size, height, width, channels])
        log_rescaling = tf.reshape(
            log_rescaling, [batch_size, height, width, channels])
        log_rescaling = scale * tf.tanh(log_rescaling)
        if not use_batch_norm:
            scale_shift = variable_on_cpu(
                "scale_shift", [],
                tf.constant_initializer(0.))
            log_rescaling += scale_shift
        if reverse:
            res = canvas
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res, dim=channels, name="bn_out", train=False,
                    epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res *= tf.exp(.5 * log_var)
                res += mean
            res *= tf.exp(-log_rescaling)
            res -= shift
            log_diff = -log_rescaling
            if use_batch_norm:
                log_diff += .5 * log_var
        else:
            res = canvas
            res += shift
            res *= tf.exp(log_rescaling)
            log_diff = log_rescaling
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res, dim=channels, name="bn_out", train=train,
                    epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res -= mean
                res *= tf.exp(-.5 * log_var)
                log_diff -= .5 * log_var
        if change_bottom:
            res = tf.concat([input_, res], 3)
            log_diff = tf.concat([tf.zeros_like(log_diff), log_diff], 3)
        else:
            res = tf.concat([res, input_], 3)
            log_diff = tf.concat([log_diff, tf.zeros_like(log_diff)], 3)

    return res, log_diff


def conv_ch_add_coupling(input_, dim, name,
                         use_batch_norm=True, train=True, weight_norm=True,
                         reverse=False, residual_blocks=5,
                         bottleneck=False, change_bottom=True, skip=True):
    """Additive coupling with channel-wise splitting."""
    with tf.variable_scope(name) as scope:
        if reverse or (not train):
            scope.reuse_variables()
        if change_bottom:
            input_, canvas = tf.split(axis=3, num_or_size_splits=2,
                                      value=input_)
        else:
            canvas, input_ = tf.split(axis=3, num_or_size_splits=2,
                                      value=input_)
        shape = input_.get_shape().as_list()
        channels = shape[3]
        res = input_

        # initial input
        if use_batch_norm:
            res = batch_norm(
                input_=res, dim=channels, name="bn_in", scale=False,
                train=train, epsilon=1e-4, axes=[0, 1, 2])
        res = tf.concat([res, -res], 3)
        dim_in = 2 * channels
        res = tf.nn.relu(res)
        shift = resnet(input_=res, dim_in=dim_in, dim=dim, dim_out=channels,
                       name="resnet", use_batch_norm=use_batch_norm,
                       train=train, weight_norm=weight_norm,
                       residual_blocks=residual_blocks,
                       bottleneck=bottleneck, skip=skip)
        if reverse:
            res = canvas
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res, dim=channels, name="bn_out", train=False,
                    epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res *= tf.exp(.5 * log_var)
                res += mean
            res -= shift
            log_diff = tf.zeros_like(res)
            if use_batch_norm:
                log_diff += .5 * log_var
        else:
            res = canvas
            res += shift
            log_diff = tf.zeros_like(res)
            if use_batch_norm:
                mean, var = batch_norm_log_diff(
                    input_=res, dim=channels, name="bn_out", train=train,
                    epsilon=1e-4, axes=[0, 1, 2])
                log_var = tf.log(var)
                res -= mean
                res *= tf.exp(-.5 * log_var)
                log_diff -= .5 * log_var
        if change_bottom:
            res = tf.concat([input_, res], 3)
            log_diff = tf.concat([tf.zeros_like(log_diff), log_diff], 3)
        else:
            res = tf.concat([res, input_], 3)
            log_diff = tf.concat([log_diff, tf.zeros_like(log_diff)], 3)

    return res, log_diff


def conv_ch_coupling(input_, dim, name,
                     use_batch_norm=True, train=True, weight_norm=True,
                     reverse=False, residual_blocks=5,
                     bottleneck=False, use_aff=True, change_bottom=True,
                     skip=True):
    """Coupling with channel-wise splitting."""
    if use_aff:
        return conv_ch_aff_coupling(
            input_=input_, dim=dim, name=name,
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=reverse, residual_blocks=residual_blocks,
            bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)
    else:
        return conv_ch_add_coupling(
            input_=input_, dim=dim, name=name,
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=reverse, residual_blocks=residual_blocks,
            bottleneck=bottleneck, change_bottom=change_bottom, skip=skip)


# RECURSIVE USE OF COUPLING LAYERS
def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
                             use_batch_norm=True, weight_norm=True,
                             train=True):
    """Recursion on coupling layers."""
    shape = input_.get_shape().as_list()
    channels = shape[3]
    residual_blocks = hps.residual_blocks
    base_dim = hps.base_dim
    mask = 1.
    use_aff = hps.use_aff
    res = input_
    skip = hps.skip
    log_diff = tf.zeros_like(input_)
    dim = base_dim
    if FLAGS.recursion_type < 4:
        dim *= 2 ** scale_idx
    with tf.variable_scope("scale_%d" % scale_idx):
        # initial coupling layers
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=mask, dim=dim,
            name="coupling_0",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=False, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=use_aff,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=1. - mask, dim=dim,
            name="coupling_1",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=False, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=use_aff,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=mask, dim=dim,
            name="coupling_2",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=False, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=True,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff
    if scale_idx < (n_scale - 1):
        with tf.variable_scope("scale_%d" % scale_idx):
            res = squeeze_2x2(res)
            log_diff = squeeze_2x2(log_diff)
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=True, dim=2 * dim,
                name="coupling_4",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=False, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
            log_diff += inc_log_diff
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=False, dim=2 * dim,
                name="coupling_5",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=False, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
            log_diff += inc_log_diff
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=True, dim=2 * dim,
                name="coupling_6",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=False, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=True, skip=skip)
            log_diff += inc_log_diff
            res = unsqueeze_2x2(res)
            log_diff = unsqueeze_2x2(log_diff)
        if FLAGS.recursion_type > 1:
            res = squeeze_2x2_ordered(res)
            log_diff = squeeze_2x2_ordered(log_diff)
            if FLAGS.recursion_type > 2:
                res_1 = res[:, :, :, :channels]
                res_2 = res[:, :, :, channels:]
                log_diff_1 = log_diff[:, :, :, :channels]
                log_diff_2 = log_diff[:, :, :, channels:]
            else:
                res_1, res_2 = tf.split(axis=3, num_or_size_splits=2,
                                        value=res)
                log_diff_1, log_diff_2 = tf.split(axis=3,
                                                  num_or_size_splits=2,
                                                  value=log_diff)
            res_1, inc_log_diff = rec_masked_conv_coupling(
                input_=res_1, hps=hps, scale_idx=scale_idx + 1,
                n_scale=n_scale,
                use_batch_norm=use_batch_norm, weight_norm=weight_norm,
                train=train)
            res = tf.concat([res_1, res_2], 3)
            log_diff_1 += inc_log_diff
            log_diff = tf.concat([log_diff_1, log_diff_2], 3)
            res = squeeze_2x2_ordered(res, reverse=True)
            log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
        else:
            res = squeeze_2x2_ordered(res)
            log_diff = squeeze_2x2_ordered(log_diff)
            res, inc_log_diff = rec_masked_conv_coupling(
                input_=res, hps=hps, scale_idx=scale_idx + 1,
                n_scale=n_scale,
                use_batch_norm=use_batch_norm, weight_norm=weight_norm,
                train=train)
            log_diff += inc_log_diff
            res = squeeze_2x2_ordered(res, reverse=True)
            log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
    else:
        with tf.variable_scope("scale_%d" % scale_idx):
            res, inc_log_diff = masked_conv_coupling(
                input_=res,
                mask_in=1. - mask, dim=dim,
                name="coupling_3",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=False, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=True,
                use_width=1., use_height=1., skip=skip)
            log_diff += inc_log_diff

    return res, log_diff
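
# At each scale the recursion factors out half of the dimensions: after the
# checkerboard and channel-wise couplings, a squeeze reshapes s x s x c
# activations into (s/2) x (s/2) x 4c, half of the channels (res_2) are
# passed on unchanged, and only the other half (res_1) is processed by the
# next scale. This is the multi-scale architecture of Real NVP.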


def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
                               use_batch_norm=True, weight_norm=True,
                               train=True):
    """Recursion on inverting coupling layers."""
    shape = input_.get_shape().as_list()
    channels = shape[3]
    residual_blocks = hps.residual_blocks
    base_dim = hps.base_dim
    mask = 1.
    use_aff = hps.use_aff
    res = input_
    log_diff = tf.zeros_like(input_)
    skip = hps.skip
    dim = base_dim
    if FLAGS.recursion_type < 4:
        dim *= 2 ** scale_idx
    if scale_idx < (n_scale - 1):
        if FLAGS.recursion_type > 1:
            res = squeeze_2x2_ordered(res)
            log_diff = squeeze_2x2_ordered(log_diff)
            if FLAGS.recursion_type > 2:
                res_1 = res[:, :, :, :channels]
                res_2 = res[:, :, :, channels:]
                log_diff_1 = log_diff[:, :, :, :channels]
                log_diff_2 = log_diff[:, :, :, channels:]
            else:
                res_1, res_2 = tf.split(axis=3, num_or_size_splits=2,
                                        value=res)
                log_diff_1, log_diff_2 = tf.split(axis=3,
                                                  num_or_size_splits=2,
                                                  value=log_diff)
            res_1, log_diff_1 = rec_masked_deconv_coupling(
                input_=res_1, hps=hps,
                scale_idx=scale_idx + 1, n_scale=n_scale,
                use_batch_norm=use_batch_norm, weight_norm=weight_norm,
                train=train)
            res = tf.concat([res_1, res_2], 3)
            log_diff = tf.concat([log_diff_1, log_diff_2], 3)
            res = squeeze_2x2_ordered(res, reverse=True)
            log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
        else:
            res = squeeze_2x2_ordered(res)
            log_diff = squeeze_2x2_ordered(log_diff)
            res, log_diff = rec_masked_deconv_coupling(
                input_=res, hps=hps,
                scale_idx=scale_idx + 1, n_scale=n_scale,
                use_batch_norm=use_batch_norm, weight_norm=weight_norm,
                train=train)
            res = squeeze_2x2_ordered(res, reverse=True)
            log_diff = squeeze_2x2_ordered(log_diff, reverse=True)
        with tf.variable_scope("scale_%d" % scale_idx):
            res = squeeze_2x2(res)
            log_diff = squeeze_2x2(log_diff)
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=True, dim=2 * dim,
                name="coupling_6",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=True, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=True, skip=skip)
            log_diff += inc_log_diff
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=False, dim=2 * dim,
                name="coupling_5",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=True, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
            log_diff += inc_log_diff
            res, inc_log_diff = conv_ch_coupling(
                input_=res,
                change_bottom=True, dim=2 * dim,
                name="coupling_4",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=True, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=use_aff, skip=skip)
            log_diff += inc_log_diff
            res = unsqueeze_2x2(res)
            log_diff = unsqueeze_2x2(log_diff)
    else:
        with tf.variable_scope("scale_%d" % scale_idx):
            res, inc_log_diff = masked_conv_coupling(
                input_=res,
                mask_in=1. - mask, dim=dim,
                name="coupling_3",
                use_batch_norm=use_batch_norm, train=train,
                weight_norm=weight_norm,
                reverse=True, residual_blocks=residual_blocks,
                bottleneck=hps.bottleneck, use_aff=True,
                use_width=1., use_height=1., skip=skip)
            log_diff += inc_log_diff
    with tf.variable_scope("scale_%d" % scale_idx):
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=mask, dim=dim,
            name="coupling_2",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=True, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=True,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=1. - mask, dim=dim,
            name="coupling_1",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=True, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=use_aff,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff
        res, inc_log_diff = masked_conv_coupling(
            input_=res,
            mask_in=mask, dim=dim,
            name="coupling_0",
            use_batch_norm=use_batch_norm, train=train,
            weight_norm=weight_norm,
            reverse=True, residual_blocks=residual_blocks,
            bottleneck=hps.bottleneck, use_aff=use_aff,
            use_width=1., use_height=1., skip=skip)
        log_diff += inc_log_diff

    return res, log_diff


# ENCODER AND DECODER IMPLEMENTATIONS
# start the recursions
def encoder(input_, hps, n_scale, use_batch_norm=True,
            weight_norm=True, train=True):
    """Encoding/gaussianization function."""
    res = input_
    log_diff = tf.zeros_like(input_)

    res, inc_log_diff = rec_masked_conv_coupling(
        input_=res, hps=hps, scale_idx=0, n_scale=n_scale,
        use_batch_norm=use_batch_norm, weight_norm=weight_norm,
        train=train)
    log_diff += inc_log_diff

    return res, log_diff


def decoder(input_, hps, n_scale, use_batch_norm=True,
            weight_norm=True, train=True):
    """Decoding/generator function."""
    res, log_diff = rec_masked_deconv_coupling(
        input_=input_, hps=hps, scale_idx=0, n_scale=n_scale,
        use_batch_norm=use_batch_norm, weight_norm=weight_norm,
        train=train)

    return res, log_diff
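
# A minimal invertibility check (a sketch, not executed by this script, and
# assuming the variables were first created under a shared scope with
# train=True):
#     z, log_det = encoder(x, hps, hps.n_scale, train=False)
#     x_rec, inv_log_det = decoder(z, hps, hps.n_scale, train=False)
#     # x_rec ~= x and log_det + inv_log_det ~= 0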


class RealNVP(object):
    """Real NVP model."""

    def __init__(self, hps, sampling=False):
        # DATA TENSOR INSTANTIATION
        device = "/cpu:0"
        if FLAGS.dataset == "imnet":
            with tf.device(
                    tf.train.replica_device_setter(0, worker_device=device)):
                filename_queue = tf.train.string_input_producer(
                    gfile.Glob(FLAGS.data_path), num_epochs=None)
                reader = tf.TFRecordReader()
                _, serialized_example = reader.read(filename_queue)
                features = tf.parse_single_example(
                    serialized_example,
                    features={
                        "image_raw": tf.FixedLenFeature([], tf.string),
                    })
                image = tf.decode_raw(features["image_raw"], tf.uint8)
                image.set_shape([FLAGS.image_size * FLAGS.image_size * 3])
                image = tf.cast(image, tf.float32)
                if FLAGS.mode == "train":
                    images = tf.train.shuffle_batch(
                        [image], batch_size=hps.batch_size, num_threads=1,
                        capacity=1000 + 3 * hps.batch_size,
                        # Ensures a minimum amount of shuffling of examples.
                        min_after_dequeue=1000)
                else:
                    images = tf.train.batch(
                        [image], batch_size=hps.batch_size, num_threads=1,
                        capacity=1000 + 3 * hps.batch_size)
            self.x_orig = x_orig = images
            image_size = FLAGS.image_size
            x_in = tf.reshape(
                x_orig,
                [hps.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
            x_in = tf.clip_by_value(x_in, 0, 255)
            x_in = (tf.cast(x_in, tf.float32)
                    + tf.random_uniform(tf.shape(x_in))) / 256.
elif FLAGS.dataset == "celeba": | |
with tf.device( | |
tf.train.replica_device_setter(0, worker_device=device)): | |
filename_queue = tf.train.string_input_producer( | |
gfile.Glob(FLAGS.data_path), num_epochs=None) | |
reader = tf.TFRecordReader() | |
_, serialized_example = reader.read(filename_queue) | |
features = tf.parse_single_example( | |
serialized_example, | |
features={ | |
"image_raw": tf.FixedLenFeature([], tf.string), | |
}) | |
image = tf.decode_raw(features["image_raw"], tf.uint8) | |
image.set_shape([218 * 178 * 3]) # 218, 178 | |
image = tf.cast(image, tf.float32) | |
image = tf.reshape(image, [218, 178, 3]) | |
image = image[40:188, 15:163, :] | |
if FLAGS.mode == "train": | |
image = tf.image.random_flip_left_right(image) | |
images = tf.train.shuffle_batch( | |
[image], batch_size=hps.batch_size, num_threads=1, | |
capacity=1000 + 3 * hps.batch_size, | |
min_after_dequeue=1000) | |
else: | |
images = tf.train.batch( | |
[image], batch_size=hps.batch_size, num_threads=1, | |
capacity=1000 + 3 * hps.batch_size) | |
self.x_orig = x_orig = images | |
image_size = 64 | |
x_in = tf.reshape(x_orig, [hps.batch_size, 148, 148, 3]) | |
x_in = tf.image.resize_images( | |
x_in, [64, 64], method=0, align_corners=False) | |
x_in = (tf.cast(x_in, tf.float32) | |
+ tf.random_uniform(tf.shape(x_in))) / 256. | |
elif FLAGS.dataset == "lsun": | |
with tf.device( | |
tf.train.replica_device_setter(0, worker_device=device)): | |
filename_queue = tf.train.string_input_producer( | |
gfile.Glob(FLAGS.data_path), num_epochs=None) | |
reader = tf.TFRecordReader() | |
_, serialized_example = reader.read(filename_queue) | |
features = tf.parse_single_example( | |
serialized_example, | |
features={ | |
"image_raw": tf.FixedLenFeature([], tf.string), | |
"height": tf.FixedLenFeature([], tf.int64), | |
"width": tf.FixedLenFeature([], tf.int64), | |
"depth": tf.FixedLenFeature([], tf.int64) | |
}) | |
image = tf.decode_raw(features["image_raw"], tf.uint8) | |
height = tf.reshape((features["height"], tf.int64)[0], [1]) | |
height = tf.cast(height, tf.int32) | |
width = tf.reshape((features["width"], tf.int64)[0], [1]) | |
width = tf.cast(width, tf.int32) | |
depth = tf.reshape((features["depth"], tf.int64)[0], [1]) | |
depth = tf.cast(depth, tf.int32) | |
image = tf.reshape(image, tf.concat([height, width, depth], 0)) | |
image = tf.random_crop(image, [64, 64, 3]) | |
if FLAGS.mode == "train": | |
image = tf.image.random_flip_left_right(image) | |
images = tf.train.shuffle_batch( | |
[image], batch_size=hps.batch_size, num_threads=1, | |
capacity=1000 + 3 * hps.batch_size, | |
# Ensures a minimum amount of shuffling of examples. | |
min_after_dequeue=1000) | |
else: | |
images = tf.train.batch( | |
[image], batch_size=hps.batch_size, num_threads=1, | |
capacity=1000 + 3 * hps.batch_size) | |
self.x_orig = x_orig = images | |
image_size = 64 | |
x_in = tf.reshape(x_orig, [hps.batch_size, 64, 64, 3]) | |
x_in = (tf.cast(x_in, tf.float32) | |
+ tf.random_uniform(tf.shape(x_in))) / 256. | |
else: | |
raise ValueError("Unknown dataset.") | |

        x_in = tf.reshape(x_in, [hps.batch_size, image_size, image_size, 3])
        side_shown = int(numpy.sqrt(hps.batch_size))
        shown_x = tf.transpose(
            tf.reshape(
                x_in[:(side_shown * side_shown), :, :, :],
                [side_shown, image_size * side_shown, image_size, 3]),
            [0, 2, 1, 3])
        shown_x = tf.transpose(
            tf.reshape(
                shown_x,
                [1, image_size * side_shown, image_size * side_shown, 3]),
            [0, 2, 1, 3]) * 255.
        tf.summary.image(
            "inputs",
            tf.cast(shown_x, tf.uint8),
            max_outputs=1)

        # restrict the data
        FLAGS.image_size = image_size
        data_constraint = hps.data_constraint
        pre_logit_scale = numpy.log(data_constraint)
        pre_logit_scale -= numpy.log(1. - data_constraint)
        pre_logit_scale = tf.cast(pre_logit_scale, tf.float32)
        logit_x_in = 2. * x_in  # [0, 2]
        logit_x_in -= 1.  # [-1, 1]
        logit_x_in *= data_constraint  # [-.9, .9]
        logit_x_in += 1.  # [.1, 1.9]
        logit_x_in /= 2.  # [.05, .95]
        # logit the data
        logit_x_in = tf.log(logit_x_in) - tf.log(1. - logit_x_in)
        transform_cost = tf.reduce_sum(
            tf.nn.softplus(logit_x_in) + tf.nn.softplus(-logit_x_in)
            - tf.nn.softplus(-pre_logit_scale),
            [1, 2, 3])
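
        # Why softplus: for y = logit(p), the inverse-sigmoid Jacobian
        # satisfies log|dy/dp| = -log(p) - log(1 - p)
        #                      = softplus(y) + softplus(-y),
        # so transform_cost accumulates the log-determinant of the logit
        # pre-processing (the -softplus(-pre_logit_scale) term folds in the
        # constant log-Jacobian of the data_constraint rescaling), to be
        # added to log_diff below.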

        # INFERENCE AND COSTS
        z_out, log_diff = encoder(
            input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
            use_batch_norm=hps.use_batch_norm, weight_norm=True,
            train=True)
        if FLAGS.mode != "train":
            z_out, log_diff = encoder(
                input_=logit_x_in, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
        final_shape = [image_size, image_size, 3]
        prior_ll = standard_normal_ll(z_out)
        prior_ll = tf.reduce_sum(prior_ll, [1, 2, 3])
        log_diff = tf.reduce_sum(log_diff, [1, 2, 3])
        log_diff += transform_cost
        cost = -(prior_ll + log_diff)
        self.x_in = x_in
        self.z_out = z_out
        self.cost = cost = tf.reduce_mean(cost)

        l2_reg = sum(
            [tf.reduce_sum(tf.square(v)) for v in tf.trainable_variables()
             if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])

        bit_per_dim = ((cost + numpy.log(256.) * image_size * image_size * 3.)
                       / (image_size * image_size * 3. * numpy.log(2.)))
        self.bit_per_dim = bit_per_dim
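
        # bits/dim: `cost` is the negative log-likelihood (in nats) of the
        # dequantized image under the change of variables. Adding
        # D * log(256) converts the density over [0, 1)^D back to a discrete
        # likelihood over {0, ..., 255}^D, and dividing by D * log(2)
        # (D = image_size * image_size * 3) expresses it in bits per
        # dimension.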

        # OPTIMIZATION
        momentum = 1. - hps.momentum
        decay = 1. - hps.decay
        if hps.optimizer == "adam":
            optimizer = tf.train.AdamOptimizer(
                learning_rate=hps.learning_rate,
                beta1=momentum, beta2=decay, epsilon=1e-08,
                use_locking=False, name="Adam")
        elif hps.optimizer == "rmsprop":
            optimizer = tf.train.RMSPropOptimizer(
                learning_rate=hps.learning_rate, decay=decay,
                momentum=momentum, epsilon=1e-04,
                use_locking=False, name="RMSProp")
        else:
            optimizer = tf.train.MomentumOptimizer(hps.learning_rate,
                                                   momentum=momentum)

        step = tf.get_variable(
            "global_step", [], tf.int64,
            tf.zeros_initializer(),
            trainable=False)
        self.step = step
        grads_and_vars = optimizer.compute_gradients(
            cost + hps.l2_coeff * l2_reg,
            tf.trainable_variables())
        grads, vars_ = zip(*grads_and_vars)
        capped_grads, gradient_norm = tf.clip_by_global_norm(
            grads, clip_norm=hps.clip_gradient)
        gradient_norm = tf.check_numerics(gradient_norm,
                                          "Gradient norm is NaN or Inf.")

        l2_z = tf.reduce_sum(tf.square(z_out), [1, 2, 3])
        if not sampling:
            tf.summary.scalar("negative_log_likelihood",
                              tf.reshape(cost, []))
            tf.summary.scalar("gradient_norm", tf.reshape(gradient_norm, []))
            tf.summary.scalar("bit_per_dim", tf.reshape(bit_per_dim, []))
            tf.summary.scalar("log_diff",
                              tf.reshape(tf.reduce_mean(log_diff), []))
            tf.summary.scalar("prior_ll",
                              tf.reshape(tf.reduce_mean(prior_ll), []))
            tf.summary.scalar(
                "log_diff_var",
                tf.reshape(tf.reduce_mean(tf.square(log_diff))
                           - tf.square(tf.reduce_mean(log_diff)), []))
            tf.summary.scalar(
                "prior_ll_var",
                tf.reshape(tf.reduce_mean(tf.square(prior_ll))
                           - tf.square(tf.reduce_mean(prior_ll)), []))
            tf.summary.scalar("l2_z_mean",
                              tf.reshape(tf.reduce_mean(l2_z), []))
            tf.summary.scalar(
                "l2_z_var",
                tf.reshape(tf.reduce_mean(tf.square(l2_z))
                           - tf.square(tf.reduce_mean(l2_z)), []))
        # list() so the pairing also works when zip() returns an iterator
        # (Python 3).
        capped_grads_and_vars = list(zip(capped_grads, vars_))
        self.train_step = optimizer.apply_gradients(
            capped_grads_and_vars, global_step=step)

        # SAMPLING AND VISUALIZATION
        if sampling:
            # SAMPLES
            sample = standard_normal_sample([100] + final_shape)
            sample, _ = decoder(
                input_=sample, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=True)
            sample = tf.nn.sigmoid(sample)
            sample = tf.clip_by_value(sample, 0, 1) * 255.
            sample = tf.reshape(sample, [100, image_size, image_size, 3])
            sample = tf.transpose(
                tf.reshape(sample, [10, image_size * 10, image_size, 3]),
                [0, 2, 1, 3])
            sample = tf.transpose(
                tf.reshape(sample, [1, image_size * 10, image_size * 10, 3]),
                [0, 2, 1, 3])
            tf.summary.image(
                "samples",
                tf.cast(sample, tf.uint8),
                max_outputs=1)

            # CONCATENATION
            concatenation, _ = encoder(
                input_=logit_x_in, hps=hps,
                n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            concatenation = tf.reshape(
                concatenation,
                [(side_shown * side_shown), image_size, image_size, 3])
            concatenation = tf.transpose(
                tf.reshape(
                    concatenation,
                    [side_shown, image_size * side_shown, image_size, 3]),
                [0, 2, 1, 3])
            concatenation = tf.transpose(
                tf.reshape(
                    concatenation,
                    [1, image_size * side_shown, image_size * side_shown, 3]),
                [0, 2, 1, 3])
            concatenation, _ = decoder(
                input_=concatenation, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            concatenation = tf.nn.sigmoid(concatenation) * 255.
            tf.summary.image(
                "concatenation",
                tf.cast(concatenation, tf.uint8),
                max_outputs=1)

            # MANIFOLD
            # Data basis
            z_u, _ = encoder(
                input_=logit_x_in[:8, :, :, :], hps=hps,
                n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            u_1 = tf.reshape(z_u[0, :, :, :], [-1])
            u_2 = tf.reshape(z_u[1, :, :, :], [-1])
            u_3 = tf.reshape(z_u[2, :, :, :], [-1])
            u_4 = tf.reshape(z_u[3, :, :, :], [-1])
            u_5 = tf.reshape(z_u[4, :, :, :], [-1])
            u_6 = tf.reshape(z_u[5, :, :, :], [-1])
            u_7 = tf.reshape(z_u[6, :, :, :], [-1])
            u_8 = tf.reshape(z_u[7, :, :, :], [-1])

            # 3D dome
            manifold_side = 8
            angle_1 = numpy.arange(manifold_side) * 1. / manifold_side
            angle_2 = numpy.arange(manifold_side) * 1. / manifold_side
            angle_1 *= 2. * numpy.pi
            angle_2 *= 2. * numpy.pi
            angle_1 = angle_1.astype("float32")
            angle_2 = angle_2.astype("float32")
            angle_1 = tf.reshape(angle_1, [1, -1, 1])
            angle_1 += tf.zeros([manifold_side, manifold_side, 1])
            angle_2 = tf.reshape(angle_2, [-1, 1, 1])
            angle_2 += tf.zeros([manifold_side, manifold_side, 1])
            n_angle_3 = 40
            angle_3 = numpy.arange(n_angle_3) * 1. / n_angle_3
            angle_3 *= 2 * numpy.pi
            angle_3 = angle_3.astype("float32")
            angle_3 = tf.reshape(angle_3, [-1, 1, 1, 1])
            angle_3 += tf.zeros([n_angle_3, manifold_side, manifold_side, 1])
            manifold = tf.cos(angle_1) * (
                tf.cos(angle_2) * (
                    tf.cos(angle_3) * u_1 + tf.sin(angle_3) * u_2)
                + tf.sin(angle_2) * (
                    tf.cos(angle_3) * u_3 + tf.sin(angle_3) * u_4))
            manifold += tf.sin(angle_1) * (
                tf.cos(angle_2) * (
                    tf.cos(angle_3) * u_5 + tf.sin(angle_3) * u_6)
                + tf.sin(angle_2) * (
                    tf.cos(angle_3) * u_7 + tf.sin(angle_3) * u_8))
            manifold = tf.reshape(
                manifold,
                [n_angle_3 * manifold_side * manifold_side] + final_shape)
            manifold, _ = decoder(
                input_=manifold, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            manifold = tf.nn.sigmoid(manifold)
            manifold = tf.clip_by_value(manifold, 0, 1) * 255.
            manifold = tf.reshape(
                manifold,
                [n_angle_3,
                 manifold_side * manifold_side,
                 image_size,
                 image_size,
                 3])
            manifold = tf.transpose(
                tf.reshape(
                    manifold,
                    [n_angle_3, manifold_side,
                     image_size * manifold_side, image_size, 3]),
                [0, 1, 3, 2, 4])
            manifold = tf.transpose(
                tf.reshape(
                    manifold,
                    [n_angle_3, image_size * manifold_side,
                     image_size * manifold_side, 3]),
                [0, 2, 1, 3])
            manifold = tf.transpose(manifold, [1, 2, 0, 3])
            manifold = tf.reshape(
                manifold,
                [1, image_size * manifold_side,
                 image_size * manifold_side, 3 * n_angle_3])
            tf.summary.image(
                "manifold",
                tf.cast(manifold[:, :, :, :3], tf.uint8),
                max_outputs=1)

            # COMPRESSION
            z_complete, _ = encoder(
                input_=logit_x_in[:hps.n_scale, :, :, :], hps=hps,
                n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            z_compressed_list = [z_complete]
            z_noisy_list = [z_complete]
            z_lost = z_complete
            for scale_idx in xrange(hps.n_scale - 1):
                z_lost = squeeze_2x2_ordered(z_lost)
                z_lost, _ = tf.split(axis=3, num_or_size_splits=2,
                                     value=z_lost)
                z_compressed = z_lost
                z_noisy = z_lost
                for _ in xrange(scale_idx + 1):
                    z_compressed = tf.concat(
                        [z_compressed, tf.zeros_like(z_compressed)], 3)
                    z_compressed = squeeze_2x2_ordered(
                        z_compressed, reverse=True)
                    z_noisy = tf.concat(
                        [z_noisy, tf.random_normal(
                            z_noisy.get_shape().as_list())], 3)
                    z_noisy = squeeze_2x2_ordered(z_noisy, reverse=True)
                z_compressed_list.append(z_compressed)
                z_noisy_list.append(z_noisy)
            self.z_reduced = z_lost
            z_compressed = tf.concat(z_compressed_list, 0)
            z_noisy = tf.concat(z_noisy_list, 0)
            noisy_images, _ = decoder(
                input_=z_noisy, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            compressed_images, _ = decoder(
                input_=z_compressed, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=False)
            noisy_images = tf.nn.sigmoid(noisy_images)
            compressed_images = tf.nn.sigmoid(compressed_images)
            noisy_images = tf.clip_by_value(noisy_images, 0, 1) * 255.
            noisy_images = tf.reshape(
                noisy_images,
                [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
            noisy_images = tf.transpose(
                tf.reshape(
                    noisy_images,
                    [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
                [0, 2, 1, 3])
            noisy_images = tf.transpose(
                tf.reshape(
                    noisy_images,
                    [1, image_size * hps.n_scale,
                     image_size * hps.n_scale, 3]),
                [0, 2, 1, 3])
            tf.summary.image(
                "noise",
                tf.cast(noisy_images, tf.uint8),
                max_outputs=1)
            compressed_images = tf.clip_by_value(
                compressed_images, 0, 1) * 255.
            compressed_images = tf.reshape(
                compressed_images,
                [(hps.n_scale * hps.n_scale), image_size, image_size, 3])
            compressed_images = tf.transpose(
                tf.reshape(
                    compressed_images,
                    [hps.n_scale, image_size * hps.n_scale, image_size, 3]),
                [0, 2, 1, 3])
            compressed_images = tf.transpose(
                tf.reshape(
                    compressed_images,
                    [1, image_size * hps.n_scale,
                     image_size * hps.n_scale, 3]),
                [0, 2, 1, 3])
            tf.summary.image(
                "compression",
                tf.cast(compressed_images, tf.uint8),
                max_outputs=1)

            # SAMPLES x2
            final_shape[0] *= 2
            final_shape[1] *= 2
            big_sample = standard_normal_sample([25] + final_shape)
            big_sample, _ = decoder(
                input_=big_sample, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=True)
            big_sample = tf.nn.sigmoid(big_sample)
            big_sample = tf.clip_by_value(big_sample, 0, 1) * 255.
            big_sample = tf.reshape(
                big_sample,
                [25, image_size * 2, image_size * 2, 3])
            big_sample = tf.transpose(
                tf.reshape(
                    big_sample,
                    [5, image_size * 10, image_size * 2, 3]), [0, 2, 1, 3])
            big_sample = tf.transpose(
                tf.reshape(
                    big_sample,
                    [1, image_size * 10, image_size * 10, 3]),
                [0, 2, 1, 3])
            tf.summary.image(
                "big_sample",
                tf.cast(big_sample, tf.uint8),
                max_outputs=1)

            # SAMPLES x10
            final_shape[0] *= 5
            final_shape[1] *= 5
            extra_large = standard_normal_sample([1] + final_shape)
            extra_large, _ = decoder(
                input_=extra_large, hps=hps, n_scale=hps.n_scale,
                use_batch_norm=hps.use_batch_norm, weight_norm=True,
                train=True)
            extra_large = tf.nn.sigmoid(extra_large)
            extra_large = tf.clip_by_value(extra_large, 0, 1) * 255.
            tf.summary.image(
                "extra_large",
                tf.cast(extra_large, tf.uint8),
                max_outputs=1)

    def eval_epoch(self, hps):
        """Evaluate bits/dim."""
        n_eval_dict = {
            "imnet": 50000,
            "lsun": 300,
            "celeba": 19962,
            "svhn": 26032,
        }
        if FLAGS.eval_set_size == 0:
            num_examples_eval = n_eval_dict[FLAGS.dataset]
        else:
            num_examples_eval = FLAGS.eval_set_size
        # Integer division: xrange() requires an int, and this keeps the
        # Python 2 and Python 3 behaviors identical.
        n_epoch = num_examples_eval // hps.batch_size
        eval_costs = []
        bar_len = 70
        for epoch_idx in xrange(n_epoch):
            n_equal = epoch_idx * bar_len * 1. / n_epoch
            n_equal = numpy.ceil(n_equal)
            n_equal = int(n_equal)
            n_dash = bar_len - n_equal
            progress_bar = "[" + "=" * n_equal + "-" * n_dash + "]\r"
            print(progress_bar, end=' ')
            cost = self.bit_per_dim.eval()
            eval_costs.append(cost)
        print("")
        return float(numpy.mean(eval_costs))


def train_model(hps, logdir):
    """Training."""
    with tf.Graph().as_default():
        with tf.device(tf.train.replica_device_setter(0)):
            with tf.variable_scope("model"):
                model = RealNVP(hps)
            saver = tf.train.Saver(tf.global_variables())

            # Build the summary operation from the last tower summaries.
            summary_op = tf.summary.merge_all()

            # Build an initialization operation to run below.
            init = tf.global_variables_initializer()

            # Start running operations on the Graph. allow_soft_placement
            # must be set to True to build towers on GPU, as some of the ops
            # do not have GPU implementations.
            sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=True))
            sess.run(init)
            ckpt_state = tf.train.get_checkpoint_state(logdir)
            if ckpt_state and ckpt_state.model_checkpoint_path:
                print("Loading file %s" % ckpt_state.model_checkpoint_path)
                saver.restore(sess, ckpt_state.model_checkpoint_path)

            # Start the queue runners.
            tf.train.start_queue_runners(sess=sess)

            summary_writer = tf.summary.FileWriter(
                logdir,
                graph=sess.graph)
            local_step = 0
            while True:
                fetches = [model.step, model.bit_per_dim, model.train_step]
                # The chief worker evaluates the summaries every 100 steps.
                should_eval_summaries = local_step % 100 == 0
                if should_eval_summaries:
                    fetches += [summary_op]
                start_time = time.time()
                outputs = sess.run(fetches)
                global_step_val = outputs[0]
                loss = outputs[1]
                duration = time.time() - start_time
                assert not numpy.isnan(loss), "Model diverged with loss = NaN"
                if local_step % 10 == 0:
                    examples_per_sec = hps.batch_size / float(duration)
                    format_str = ("%s: step %d, loss = %.2f "
                                  "(%.1f examples/sec; %.3f sec/batch)")
                    print(format_str % (datetime.now(), global_step_val, loss,
                                        examples_per_sec, duration))
                if should_eval_summaries:
                    summary_str = outputs[-1]
                    summary_writer.add_summary(summary_str, global_step_val)
                # Save the model checkpoint periodically.
                if (local_step % 1000 == 0
                        or (local_step + 1) == FLAGS.train_steps):
                    checkpoint_path = os.path.join(logdir, "model.ckpt")
                    saver.save(
                        sess,
                        checkpoint_path,
                        global_step=global_step_val)
                if outputs[0] >= FLAGS.train_steps:
                    break
                local_step += 1


def evaluate(hps, logdir, traindir, subset="valid", return_val=False):
    """Evaluation."""
    hps.batch_size = 100
    with tf.Graph().as_default():
        with tf.device("/cpu:0"):
            with tf.variable_scope("model") as var_scope:
                eval_model = RealNVP(hps)
                summary_writer = tf.summary.FileWriter(logdir)
                var_scope.reuse_variables()
                saver = tf.train.Saver()
                sess = tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=True))
                tf.train.start_queue_runners(sess)
                previous_global_step = 0  # don't run eval for step = 0
                with sess.as_default():
                    while True:
                        ckpt_state = tf.train.get_checkpoint_state(traindir)
                        if not (ckpt_state
                                and ckpt_state.model_checkpoint_path):
                            print("No model to eval yet at %s" % traindir)
                            time.sleep(30)
                            continue
                        print("Loading file %s"
                              % ckpt_state.model_checkpoint_path)
                        saver.restore(sess, ckpt_state.model_checkpoint_path)
                        current_step = tf.train.global_step(
                            sess, eval_model.step)
                        if current_step == previous_global_step:
                            print("Waiting for the checkpoint to be updated.")
                            time.sleep(30)
                            continue
                        previous_global_step = current_step
                        print("Evaluating...")
                        bit_per_dim = eval_model.eval_epoch(hps)
                        print("Epoch: %d, %s -> %.3f bits/dim"
                              % (current_step, subset, bit_per_dim))
                        print("Writing summary...")
                        summary = tf.Summary()
                        summary.value.extend(
                            [tf.Summary.Value(
                                tag="bit_per_dim",
                                simple_value=bit_per_dim)])
                        summary_writer.add_summary(summary, current_step)
                        if return_val:
                            return current_step, bit_per_dim


def sample_from_model(hps, logdir, traindir):
    """Sampling."""
    hps.batch_size = 100
    with tf.Graph().as_default():
        with tf.device("/cpu:0"):
            with tf.variable_scope("model") as var_scope:
                eval_model = RealNVP(hps, sampling=True)
                summary_writer = tf.summary.FileWriter(logdir)
                var_scope.reuse_variables()
                summary_op = tf.summary.merge_all()
                saver = tf.train.Saver()
                sess = tf.Session(config=tf.ConfigProto(
                    allow_soft_placement=True,
                    log_device_placement=True))
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                previous_global_step = 0  # don't run eval for step = 0
                initialized = False
                with sess.as_default():
                    while True:
                        ckpt_state = tf.train.get_checkpoint_state(traindir)
                        if not (ckpt_state
                                and ckpt_state.model_checkpoint_path):
                            if not initialized:
                                print("No model to eval yet at %s" % traindir)
                                time.sleep(30)
                                continue
                        else:
                            print("Loading file %s"
                                  % ckpt_state.model_checkpoint_path)
                            saver.restore(
                                sess, ckpt_state.model_checkpoint_path)
                        current_step = tf.train.global_step(
                            sess, eval_model.step)
                        if current_step == previous_global_step:
                            print("Waiting for the checkpoint to be updated.")
                            time.sleep(30)
                            continue
                        previous_global_step = current_step
                        fetches = [summary_op]
                        outputs = sess.run(fetches)
                        summary_writer.add_summary(outputs[0], current_step)
                    coord.request_stop()
                    coord.join(threads)


def main(unused_argv):
    hps = get_default_hparams().update_config(FLAGS.hpconfig)
    if FLAGS.mode == "train":
        train_model(hps=hps, logdir=FLAGS.logdir)
    elif FLAGS.mode == "sample":
        sample_from_model(hps=hps, logdir=FLAGS.logdir,
                          traindir=FLAGS.traindir)
    else:
        hps.batch_size = 100
        evaluate(hps=hps, logdir=FLAGS.logdir,
                 traindir=FLAGS.traindir, subset=FLAGS.mode)


if __name__ == "__main__":
    tf.app.run()