# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utility functions for Real NVP.
"""
# pylint: disable=dangerous-default-value

import numpy
from six.moves import xrange
import tensorflow as tf
from tensorflow.python.framework import ops

DEFAULT_BN_LAG = .0


def stable_var(input_, mean=None, axes=[0]):
  """Numerically more stable variance computation."""
  if mean is None:
    mean = tf.reduce_mean(input_, axes)
  res = tf.square(input_ - mean)
  max_sqr = tf.reduce_max(res, axes)
  # rescale the squared deviations by their per-axis maximum before averaging
  # so the intermediate mean stays in [0, 1], then undo the rescaling
  res /= max_sqr
  res = tf.reduce_mean(res, axes)
  res *= max_sqr
  return res
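

# Illustrative usage sketch (added for exposition, not part of the original
# module): the max-rescaling above helps when inputs have a large dynamic
# range. The tensor sizes below are arbitrary example choices.
def _example_stable_var():
  activations = tf.random_normal([128, 10]) * 1e4
  return stable_var(activations, axes=[0])  # per-column variance, shape [10]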


def variable_on_cpu(name, shape, initializer, trainable=True):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable
    trainable: boolean defining if the variable is for training

  Returns:
    Variable Tensor
  """
  var = tf.get_variable(
      name, shape, initializer=initializer, trainable=trainable)
  return var


# layers
def conv_layer(input_,
               filter_size,
               dim_in,
               dim_out,
               name,
               stddev=1e-2,
               strides=[1, 1, 1, 1],
               padding="SAME",
               nonlinearity=None,
               bias=False,
               weight_norm=False,
               scale=False):
  """Convolutional layer."""
  with tf.variable_scope(name) as scope:
    weights = variable_on_cpu(
        "weights",
        filter_size + [dim_in, dim_out],
        tf.random_uniform_initializer(
            minval=-stddev, maxval=stddev))
    # weight normalization
    if weight_norm:
      weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
      if scale:
        magnitude = variable_on_cpu(
            "magnitude", [dim_out],
            tf.constant_initializer(
                stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
        weights *= magnitude
    res = input_
    # handling filter size bigger than image size
    if hasattr(input_, "shape"):
      if input_.get_shape().as_list()[1] < filter_size[0]:
        pad_1 = tf.zeros([
            input_.get_shape().as_list()[0],
            filter_size[0] - input_.get_shape().as_list()[1],
            input_.get_shape().as_list()[2],
            input_.get_shape().as_list()[3]
        ])
        pad_2 = tf.zeros([
            input_.get_shape().as_list()[0],
            filter_size[0],
            filter_size[1] - input_.get_shape().as_list()[2],
            input_.get_shape().as_list()[3]
        ])
        res = tf.concat(axis=1, values=[pad_1, res])
        res = tf.concat(axis=2, values=[pad_2, res])
    res = tf.nn.conv2d(
        input=res,
        filter=weights,
        strides=strides,
        padding=padding,
        name=scope.name)
    # remove the zero padding that was added above
    if hasattr(input_, "shape"):
      if input_.get_shape().as_list()[1] < filter_size[0]:
        res = tf.slice(res, [
            0, filter_size[0] - input_.get_shape().as_list()[1],
            filter_size[1] - input_.get_shape().as_list()[2], 0
        ], [-1, -1, -1, -1])
    if bias:
      biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
      res = tf.nn.bias_add(res, biases)
    if nonlinearity is not None:
      res = nonlinearity(res)
  return res
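

# Illustrative usage sketch (added for exposition, not part of the original
# module): a 3x3 convolution with weight normalization and a ReLU; the shapes
# and the scope name "conv_example" are arbitrary choices for the example.
def _example_conv_layer(images):
  # images: [batch, height, width, 3]
  return conv_layer(
      input_=images,
      filter_size=[3, 3],
      dim_in=3,
      dim_out=16,
      name="conv_example",
      weight_norm=True,
      scale=True,
      bias=True,
      nonlinearity=tf.nn.relu)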


def max_pool_2x2(input_):
  """Max pooling."""
  return tf.nn.max_pool(
      input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")


def depool_2x2(input_, stride=2):
  """Depooling: upsample by interleaving zeros after each pixel."""
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
  res = tf.concat(
      axis=2,
      values=[res,
              tf.zeros([batch_size, height, stride - 1, width, 1, channels])])
  res = tf.concat(
      axis=4,
      values=[res,
              tf.zeros(
                  [batch_size, height, stride, width, stride - 1, channels])])
  res = tf.reshape(res, [batch_size, stride * height, stride * width, channels])
  return res
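

# Illustrative sketch (added for exposition, not part of the original module):
# depool_2x2 upsamples by zero-interleaving, so each input pixel ends up in the
# top-left corner of its stride x stride output block.
def _example_depool_2x2_shape():
  x = tf.ones([4, 8, 8, 3])
  return depool_2x2(x).get_shape().as_list()  # [4, 16, 16, 3]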


# random flip on a batch of images
def batch_random_flip(input_):
  """Simultaneous horizontal random flip."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  res = tf.split(axis=0, num_or_size_splits=batch_size, value=input_)
  res = [elem[0, :, :, :] for elem in res]
  res = [tf.image.random_flip_left_right(elem) for elem in res]
  res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
  res = tf.concat(axis=0, values=res)
  return res


# build a one hot representation corresponding to the integer tensor
# the one-hot dimension is appended to the integer tensor shape
def as_one_hot(input_, n_indices):
  """Convert indices to one-hot."""
  shape = input_.get_shape().as_list()
  n_elem = numpy.prod(shape)
  indices = tf.range(n_elem)
  indices = tf.cast(indices, tf.int64)
  indices_input = tf.concat(axis=0, values=[indices, tf.reshape(input_, [-1])])
  indices_input = tf.reshape(indices_input, [2, -1])
  indices_input = tf.transpose(indices_input)
  res = tf.sparse_to_dense(
      indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
  res = tf.reshape(res, [elem for elem in shape] + [n_indices])
  return res
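

# Illustrative sketch (added for exposition, not part of the original module):
# the one-hot dimension is appended to the input shape, so a [3] int64 tensor
# of labels becomes a [3, n_indices] tensor.
def _example_as_one_hot():
  labels = tf.constant([0, 2, 1], dtype=tf.int64)
  return as_one_hot(labels, n_indices=4)  # shape [3, 4]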


def squeeze_2x2(input_):
  """Squeezing operation: reshape to convert space to channels."""
  return squeeze_nxn(input_, n_factor=2)


def squeeze_nxn(input_, n_factor=2):
  """Squeezing operation: reshape to convert space to channels."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if height % n_factor != 0:
    raise ValueError("Height not divisible by %d." % n_factor)
  if width % n_factor != 0:
    raise ValueError("Width not divisible by %d." % n_factor)
  res = tf.reshape(
      input_,
      [batch_size,
       height // n_factor,
       n_factor,
       width // n_factor,
       n_factor,
       channels])
  res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
  res = tf.reshape(
      res,
      [batch_size,
       height // n_factor,
       width // n_factor,
       channels * n_factor * n_factor])
  return res
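

# Illustrative sketch (added for exposition, not part of the original module):
# squeezing trades spatial resolution for channels, mapping
# [batch, height, width, channels] to
# [batch, height / n, width / n, channels * n * n].
def _example_squeeze_nxn_shape():
  x = tf.zeros([8, 32, 32, 3])
  return squeeze_nxn(x, n_factor=2).get_shape().as_list()  # [8, 16, 16, 12]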


def unsqueeze_2x2(input_):
  """Unsqueezing operation: reshape to convert channels into space."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if channels % 4 != 0:
    raise ValueError("Number of channels not divisible by 4.")
  res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
  res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
  res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])
  return res


# batch norm
def batch_norm(input_,
               dim,
               name,
               scale=True,
               train=True,
               epsilon=1e-8,
               decay=.1,
               axes=[0],
               bn_lag=DEFAULT_BN_LAG):
  """Batch normalization."""
  # create variables
  with tf.variable_scope(name):
    var = variable_on_cpu(
        "var", [dim], tf.constant_initializer(1.), trainable=False)
    mean = variable_on_cpu(
        "mean", [dim], tf.constant_initializer(0.), trainable=False)
    step = variable_on_cpu(
        "step", [], tf.constant_initializer(0.), trainable=False)
    if scale:
      gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
    beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
  # choose the appropriate moments
  if train:
    used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
    cur_mean, cur_var = used_mean, used_var
    if bn_lag > 0.:
      # blend the batch moments with the lagged moving statistics,
      # with a bias correction for the first steps
      used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
      used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
      used_mean /= (1. - bn_lag**(step + 1))
      used_var /= (1. - bn_lag**(step + 1))
  else:
    used_mean, used_var = mean, var
    cur_mean, cur_var = used_mean, used_var

  # normalize
  res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
  # de-normalize
  if scale:
    res *= gamma
  res += beta

  # update variables
  if train:
    with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
      with ops.colocate_with(mean):
        new_mean = tf.assign_sub(
            mean,
            tf.check_numerics(
                decay * (mean - cur_mean), "NaN in moving mean."))
    with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
      with ops.colocate_with(var):
        new_var = tf.assign_sub(
            var,
            tf.check_numerics(decay * (var - cur_var),
                              "NaN in moving variance."))
    with tf.name_scope(name, "IncrementTime", [step]):
      with ops.colocate_with(step):
        new_step = tf.assign_add(step, 1.)
    # make the output depend on the update ops so they run
    # without an explicit control dependency
    res += 0. * new_mean * new_var * new_step

  return res
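

# Illustrative usage sketch (added for exposition, not part of the original
# module): per-channel normalization of a convolutional feature map; the scope
# name "bn_example" and axes=[0, 1, 2] (batch and spatial axes) are example
# choices.
def _example_batch_norm(features, dim):
  return batch_norm(
      input_=features, dim=dim, name="bn_example", scale=True, train=True,
      axes=[0, 1, 2])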


# batch normalization taking into account the volume transformation
def batch_norm_log_diff(input_,
                        dim,
                        name,
                        train=True,
                        epsilon=1e-8,
                        decay=.1,
                        axes=[0],
                        reuse=None,
                        bn_lag=DEFAULT_BN_LAG):
  """Batch normalization with corresponding log determinant Jacobian."""
  if reuse is None:
    reuse = not train
  # create variables
  with tf.variable_scope(name) as scope:
    if reuse:
      scope.reuse_variables()
    var = variable_on_cpu(
        "var", [dim], tf.constant_initializer(1.), trainable=False)
    mean = variable_on_cpu(
        "mean", [dim], tf.constant_initializer(0.), trainable=False)
    step = variable_on_cpu(
        "step", [], tf.constant_initializer(0.), trainable=False)
  # choose the appropriate moments
  if train:
    used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
    cur_mean, cur_var = used_mean, used_var
    if bn_lag > 0.:
      used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
      cur_var = used_var
      # blend the batch moments with the lagged moving statistics,
      # with a bias correction for the first steps
      used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
      used_mean /= (1. - bn_lag**(step + 1))
      used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
      used_var /= (1. - bn_lag**(step + 1))
  else:
    used_mean, used_var = mean, var
    cur_mean, cur_var = used_mean, used_var

  # update variables
  if train:
    with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
      with ops.colocate_with(mean):
        new_mean = tf.assign_sub(
            mean,
            tf.check_numerics(
                decay * (mean - cur_mean), "NaN in moving mean."))
    with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
      with ops.colocate_with(var):
        new_var = tf.assign_sub(
            var,
            tf.check_numerics(decay * (var - cur_var),
                              "NaN in moving variance."))
    with tf.name_scope(name, "IncrementTime", [step]):
      with ops.colocate_with(step):
        new_step = tf.assign_add(step, 1.)
    # make the outputs depend on the update ops so they run
    # without an explicit control dependency
    used_var += 0. * new_mean * new_var * new_step
  used_var += epsilon

  return used_mean, used_var
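

# Illustrative sketch (added for exposition; an assumption about how callers
# use this function, not part of the original module): the returned moments can
# normalize the input, and the per-example log-determinant of the Jacobian of
# y = (x - mean) / sqrt(var) is -0.5 * sum(log(var)).
def _example_batch_norm_log_det(input_, dim):
  mean, var = batch_norm_log_diff(
      input_=input_, dim=dim, name="bn_logdet_example", train=True,
      axes=[0, 1, 2])
  normalized = (input_ - mean) / tf.sqrt(var)
  log_det_jacobian = -.5 * tf.reduce_sum(tf.log(var))
  return normalized, log_det_jacobian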


def convnet(input_,
            dim_in,
            dim_hid,
            filter_sizes,
            dim_out,
            name,
            use_batch_norm=True,
            train=True,
            nonlinearity=tf.nn.relu):
  """Chaining of convolutional layers."""
  dims_in = [dim_in] + dim_hid[:-1]
  dims_out = dim_hid
  res = input_
  bias = (not use_batch_norm)
  with tf.variable_scope(name):
    for layer_idx in xrange(len(dim_hid)):
      res = conv_layer(
          input_=res,
          filter_size=filter_sizes[layer_idx],
          dim_in=dims_in[layer_idx],
          dim_out=dims_out[layer_idx],
          name="h_%d" % layer_idx,
          stddev=1e-2,
          nonlinearity=None,
          bias=bias)
      if use_batch_norm:
        res = batch_norm(
            input_=res,
            dim=dims_out[layer_idx],
            name="bn_%d" % layer_idx,
            scale=(nonlinearity == tf.nn.relu),
            train=train,
            epsilon=1e-8,
            axes=[0, 1, 2])
      if nonlinearity is not None:
        res = nonlinearity(res)
    res = conv_layer(
        input_=res,
        filter_size=filter_sizes[-1],
        dim_in=dims_out[-1],
        dim_out=dim_out,
        name="out",
        stddev=1e-2,
        nonlinearity=None)
  return res
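

# Illustrative usage sketch (added for exposition, not part of the original
# module): two hidden layers of 32 channels followed by an output projection;
# filter_sizes needs one entry per hidden layer plus one for the output layer.
def _example_convnet(images):
  # images: [batch, height, width, 3]
  return convnet(
      input_=images,
      dim_in=3,
      dim_hid=[32, 32],
      filter_sizes=[[3, 3], [3, 3], [3, 3]],
      dim_out=6,
      name="convnet_example",
      use_batch_norm=True,
      train=True,
      nonlinearity=tf.nn.relu)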


# distributions
# log-likelihood estimation
def standard_normal_ll(input_):
  """Log-likelihood of standard Gaussian distribution."""
  res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))
  return res
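

# Illustrative sketch (added for exposition, not part of the original module):
# standard_normal_ll is elementwise; summing over the non-batch axes gives the
# per-example log-likelihood under an isotropic standard Gaussian prior.
def _example_prior_log_likelihood(z):
  # z: [batch, height, width, channels] latent tensor
  return tf.reduce_sum(standard_normal_ll(z), [1, 2, 3])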


def standard_normal_sample(shape):
  """Samples from standard Gaussian distribution."""
  return tf.random_normal(shape)


SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
                              [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])


def squeeze_2x2_ordered(input_, reverse=False):
  """Squeezing operation with a controlled ordering."""
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if reverse:
    if channels % 4 != 0:
      raise ValueError("Number of channels not divisible by 4.")
    channels //= 4
  else:
    if height % 2 != 0:
      raise ValueError("Height not divisible by 2.")
    if width % 2 != 0:
      raise ValueError("Width not divisible by 2.")
  weights = numpy.zeros((2, 2, channels, 4 * channels))
  for idx_ch in xrange(channels):
    slice_2 = slice(idx_ch, (idx_ch + 1))
    slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
    weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
  shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)]
  shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)]
  shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)]
  shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)]
  shuffle_channels = numpy.array(shuffle_channels)
  weights = weights[:, :, :, shuffle_channels].astype("float32")
  if reverse:
    res = tf.nn.conv2d_transpose(
        value=input_,
        filter=weights,
        output_shape=[batch_size, height * 2, width * 2, channels],
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="unsqueeze_2x2")
  else:
    res = tf.nn.conv2d(
        input=input_,
        filter=weights,
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="squeeze_2x2")
  return res
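

# Illustrative sketch (added for exposition, not part of the original module):
# the forward and reverse modes are intended as an inverse pair, implemented
# with fixed conv2d / conv2d_transpose filters built from SQUEEZE_MATRIX.
def _example_ordered_squeeze_round_trip():
  x = tf.random_normal([2, 8, 8, 3])
  squeezed = squeeze_2x2_ordered(x)                        # [2, 4, 4, 12]
  restored = squeeze_2x2_ordered(squeezed, reverse=True)   # [2, 8, 8, 3]
  return squeezed, restored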