# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cross Convolutional Model.

https://arxiv.org/pdf/1607.02586v1.pdf
"""
import math
import sys

from six.moves import xrange
import tensorflow as tf

slim = tf.contrib.slim
class CrossConvModel(object):
  """Cross convolutional network that predicts a frame difference.

  The graph built by `Build` has four stages:
    1. A motion encoder turns one (image, diff) pair into sampled motion
       kernels (`_BuildMotionKernel`).
    2. Every input image is encoded into a feature map
       (`_BuildImageEncoder`).
    3. Each feature map is convolved, sample by sample, with its slice of
       the motion kernels (`_CrossConv`).
    4. The results are upsampled, concatenated and decoded into the
       predicted diff, `self.diff_output` (`_BuildImageDecoder`).
  """

  def __init__(self, image_diff_list, params):
    """Constructor.

    Args:
      image_diff_list: A list of (image, diff) tuples, with shape
          [batch_size, image_size, image_size, 3] and image_sizes as
          [32, 64, 128, 256].
      params: Dict of parameters. Keys read by this class: 'scale',
          'is_training', 'batch_size', 'learning_rate', 'l2_loss',
          'reconstr_loss' and 'kl_loss'.
    """
    self.images = [image for (image, _) in image_diff_list]
    # Move the diff to the positive realm.
    self.diffs = [(d + params['scale']) / 2 for (_, d) in image_diff_list]
    self.params = params

  def Build(self):
    """Builds the whole graph: model, loss, optional train op and summaries.

    Sets `self.global_step` and `self.summary_op`, and (through the helper
    methods) `self.kernel`, `self.diff_output`, `self.loss` and, when
    params['is_training'] is true, `self.train_op`.

    Returns:
      The total loss, `self.loss`.
    """
    with tf.device('/gpu:0'):
      # Every slim.conv2d below is followed by batch norm and a ReLU.
      with slim.arg_scope([slim.conv2d],
                          activation_fn=tf.nn.relu,
                          normalizer_fn=slim.batch_norm,
                          normalizer_params={'is_training':
                                             self.params['is_training']}):
        self._BuildMotionKernel()
        encoded_images = self._BuildImageEncoder()
        cross_conved_images = self._CrossConv(encoded_images)
        self._BuildImageDecoder(cross_conved_images)
        self._BuildLoss()

      # The second (image, diff) pair is the prediction target; it is the
      # same pair `_BuildLoss` compares `self.diff_output` against.
      image = self.images[1]
      diff = self.diffs[1]

      self.global_step = tf.Variable(0, name='global_step', trainable=False)

      if self.params['is_training']:
        self._BuildTrainOp()

      # Undo the (d + scale) / 2 shift applied in the constructor so the
      # summary shows diffs in their original range.
      diff = diff * 2.0 - self.params['scale']
      diff_output = self.diff_output * 2.0 - self.params['scale']
      concat_image = tf.concat(
          axis=1,
          values=[image, image + diff_output, image + diff, diff_output])
      tf.summary.image('origin_predict_expect_predictdiff', concat_image)
      self.summary_op = tf.summary.merge_all()
      return self.loss

  def _BuildTrainOp(self):
    """Creates `self.train_op`: SGD with a floored, exponentially decayed rate.

    The learning rate starts at params['learning_rate'], halves every 10000
    steps and never drops below 0.01.
    """
    lrn_rate = tf.maximum(
        0.01,  # min_lr_rate.
        tf.train.exponential_decay(
            self.params['learning_rate'], self.global_step, 10000, 0.5))
    tf.summary.scalar('learning rate', lrn_rate)
    optimizer = tf.train.GradientDescentOptimizer(lrn_rate)
    self.train_op = slim.learning.create_train_op(
        self.loss, optimizer, global_step=self.global_step)

  def _BuildLoss(self):
    """Sums the loss terms enabled in params into `self.loss`.

    The terms are switched on by params['l2_loss'], params['reconstr_loss']
    and params['kl_loss']. If none is enabled, `self.loss` stays the Python
    int 0.
    """
    # 1. reconstr_loss doesn't seem to do better than l2 loss.
    # 2. Only works when using reduce_mean. reduce_sum doesn't work.
    # 3. It seems kl loss doesn't play an important role.
    self.loss = 0
    target = self.diffs[1]
    with tf.variable_scope('loss'):
      if self.params['l2_loss']:
        l2_loss = tf.reduce_mean(tf.square(self.diff_output - target))
        tf.summary.scalar('l2_loss', l2_loss)
        self.loss += l2_loss
      if self.params['reconstr_loss']:
        # Bernoulli cross entropy. Both terms take the log of the
        # prediction; the 1e-10 offset keeps log() away from zero.
        reconstr_loss = -tf.reduce_mean(
            target * tf.log(1e-10 + self.diff_output) +
            (1 - target) * tf.log(1e-10 + 1 - self.diff_output))
        reconstr_loss = tf.check_numerics(reconstr_loss, 'reconstr_loss')
        tf.summary.scalar('reconstr_loss', reconstr_loss)
        self.loss += reconstr_loss
      if self.params['kl_loss']:
        # KL divergence between N(z_mean, z_stddev^2) and N(0, 1), averaged
        # over all latent entries.
        kl_loss = 0.5 * tf.reduce_mean(
            tf.square(self.z_mean) + tf.square(self.z_stddev) -
            2 * self.z_stddev_log - 1)
        tf.summary.scalar('kl_loss', kl_loss)
        self.loss += kl_loss
      tf.summary.scalar('loss', self.loss)

  def _BuildMotionKernel(self):
    """Encodes the 128x128 (image, diff) pair into sampled motion kernels.

    Sets `self.z_mean`, `self.z_stddev_log`, `self.z_stddev` and
    `self.kernel`. The kernel is drawn as z_mean + z_stddev * epsilon with
    epsilon ~ N(0, 1), reshaped to [batch_size, width, width, 128] and passed
    through two more conv layers.
    """
    image = self.images[-2]
    diff = self.diffs[-2]
    shape = image.get_shape().as_list()
    assert shape[1] == shape[2] and shape[1] == 128
    batch_size = shape[0]

    net = tf.concat(axis=3, values=[image, diff])
    with tf.variable_scope('motion_encoder'):
      with slim.arg_scope([slim.conv2d], padding='VALID'):
        net = slim.conv2d(net, 96, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 96, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 128, [5, 5], stride=1)
        net = slim.conv2d(net, 128, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 256, [4, 4], stride=1)
        net = slim.conv2d(net, 256, [3, 3], stride=1)

      # The flattened encoding is split in half: the first half is the mean
      # and the second half the log of the standard deviation.
      z = tf.reshape(net, shape=[batch_size, -1])
      self.z_mean, self.z_stddev_log = tf.split(
          axis=1, num_or_size_splits=2, value=z)
      self.z_stddev = tf.exp(self.z_stddev_log)

      epsilon = tf.random_normal(
          self.z_mean.get_shape().as_list(), 0, 1, dtype=tf.float32)
      kernel = self.z_mean + tf.multiply(self.z_stddev, epsilon)

      # Fold the latent vector into a square map with 128 channels.
      width = int(math.sqrt(kernel.get_shape().as_list()[1] // 128))
      kernel = tf.reshape(kernel, [batch_size, width, width, 128])

    with tf.variable_scope('kernel_decoder'):
      with slim.arg_scope([slim.conv2d], padding='SAME'):
        kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
        self.kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)

    sys.stderr.write('kernel shape: %s\n' % kernel.get_shape())

  def _BuildImageEncoder(self):
    """Encodes every input image into a 32-channel feature map.

    Each image gets its own variable scope, 'image_encoder_<i>'.

    Returns:
      A list with one feature-map tensor per entry of `self.images`.
    """
    feature_maps = []
    for (i, image) in enumerate(self.images):
      with tf.variable_scope('image_encoder_%d' % i):
        with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'):
          net = slim.conv2d(image, 64, [5, 5], stride=1)
          net = slim.conv2d(net, 64, [5, 5], stride=1)
          net = slim.max_pool2d(net, [5, 5])
          net = slim.conv2d(net, 64, [5, 5], stride=1)
          net = slim.conv2d(net, 32, [5, 5], stride=1)
          net = slim.max_pool2d(net, [2, 2])
          sys.stderr.write('image_conv shape: %s\n' % net.get_shape())
          feature_maps.append(net)
    return feature_maps

  def _CrossConvHelper(self, encoded_image, kernel):
    """Cross convolution of one sample's feature map with its own kernel.

    `_CrossConv` calls this with per-sample tensors (the batch axis already
    unstacked). A leading batch axis of size 1 is added to the feature map
    and a trailing channel-multiplier axis of size 1 to the kernel, so that
    the depthwise convolution filters every channel of the feature map with
    the matching channel of the kernel.

    Args:
      encoded_image: Feature map of a single sample.
      kernel: Motion kernel of the same sample, with as many channels as
          `encoded_image`.

    Returns:
      The convolved feature map with a leading batch axis of size 1.
    """
    images = tf.expand_dims(encoded_image, 0)
    kernels = tf.expand_dims(kernel, 3)
    return tf.nn.depthwise_conv2d(images, kernels, [1, 1, 1, 1], 'SAME')

  def _CrossConv(self, encoded_images):
    """Apply the motion kernel on the encoded_images.

    `self.kernel` is split into four equal channel groups, and the i-th
    encoded image is convolved with the i-th group, one sample at a time.

    Args:
      encoded_images: List of feature maps from `_BuildImageEncoder`.

    Returns:
      A list of cross-convolved feature maps, one per encoded image.
    """
    cross_conved_images = []
    kernels = tf.split(axis=3, num_or_size_splits=4, value=self.kernel)
    for (i, encoded_image) in enumerate(encoded_images):
      with tf.variable_scope('cross_conv_%d' % i):
        # Unstack along the batch axis: each sample has its own kernel.
        image_samples = tf.unstack(encoded_image, axis=0)
        kernel_samples = tf.unstack(kernels[i], axis=0)
        assert len(image_samples) == len(kernel_samples)
        assert len(image_samples) == self.params['batch_size']
        conved_samples = [
            self._CrossConvHelper(image_sample, kernel_sample)
            for image_sample, kernel_sample
            in zip(image_samples, kernel_samples)]
        cross_conved_images.append(tf.concat(axis=0, values=conved_samples))
        sys.stderr.write('cross_conved shape: %s\n' %
                         cross_conved_images[-1].get_shape())
    return cross_conved_images

  def _Deconv(self, net, out_filters, kernel_size, stride):
    """Upsamples `net` with a transposed convolution followed by batch norm.

    Args:
      net: Input tensor with a fully defined [batch, height, width, channels]
          static shape.
      out_filters: Number of output channels.
      kernel_size: Side length of the square filter.
      stride: Integer upsampling factor applied to height and width.

    Returns:
      The batch-normalised tensor of shape
      [batch, height * stride, width * stride, out_filters].
    """
    shape = net.get_shape().as_list()
    batch_size, in_height, in_width, in_filters = shape
    # conv2d_transpose filters are laid out [h, w, out_channels, in_channels].
    kernel_shape = [kernel_size, kernel_size, out_filters, in_filters]

    weights = tf.get_variable(
        name='weights',
        shape=kernel_shape,
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.01))

    output_shape = [
        batch_size, in_height * stride, in_width * stride, out_filters]
    net = tf.nn.conv2d_transpose(net, weights, output_shape,
                                 [1, stride, stride, 1], padding='SAME')
    # The normalised tensor must be kept: slim.batch_norm returns a new
    # tensor and does not modify its input. `is_training` follows the same
    # flag as the batch norm attached to the conv layers.
    # NOTE: this changes the decoder output relative to graphs that dropped
    # the result, so checkpoints trained on such graphs need retraining.
    net = slim.batch_norm(net, is_training=self.params['is_training'])
    return net

  def _BuildImageDecoder(self, cross_conved_images):
    """Decode the cross_conved feature maps into the predicted images.

    Every feature map is upsampled to 64x64 in its own 'image_decoder_<i>'
    scope. The results are concatenated along the channel axis and reduced
    to 3 channels. Sets `self.diff_output`.

    Args:
      cross_conved_images: List of feature maps from `_CrossConv`.
    """
    nets = []
    for i, cross_conved_image in enumerate(cross_conved_images):
      with tf.variable_scope('image_decoder_%d' % i):
        # Integer division: the stride goes into the transposed convolution's
        # output shape and strides, which must be ints (plain `/` yields a
        # float on Python 3).
        stride = 64 // cross_conved_image.get_shape().as_list()[1]
        # TODO(xpan): Alternative solution for upsampling?
        nets.append(self._Deconv(
            cross_conved_image, 64, kernel_size=3, stride=stride))

    net = tf.concat(axis=3, values=nets)
    net = slim.conv2d(net, 128, [9, 9], padding='SAME', stride=1)
    net = slim.conv2d(net, 128, [1, 1], padding='SAME', stride=1)
    net = slim.conv2d(net, 3, [1, 1], padding='SAME', stride=1)
    self.diff_output = net
    sys.stderr.write('diff_output shape: %s\n' % self.diff_output.get_shape())