# [Web-viewer residue, not part of the module] NCTCMumbai's picture;
# Upload 2571 files; 0b8359d; raw; history blame; 9.36 kB
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cross Convolutional Model.
https://arxiv.org/pdf/1607.02586v1.pdf
"""
import math
import sys
from six.moves import xrange
import tensorflow as tf
# Module-level alias for TF-Slim (`tf.contrib.slim`, TensorFlow 1.x only).
slim = tf.contrib.slim
class CrossConvModel(object):
  """Cross convolutional network that predicts a frame difference.

  `Build()` assembles the graph from four parts: a variational motion
  encoder that samples per-example convolution kernels from an
  (image, diff) pair, one image encoder per input scale, a "cross
  convolution" that applies each example's sampled kernels to its own
  feature maps, and a decoder that merges all scales into the predicted
  diff (`self.diff_output`).

  NOTE(review): the nesting of the `with` blocks in this class was
  reconstructed from a flattened copy of the file -- confirm the scopes
  against the upstream source.
  """

  def __init__(self, image_diff_list, params):
    """Constructor.

    Args:
      image_diff_list: A list of (image, diff) tuples, with shape
          [batch_size, image_size, image_size, 3] and image_sizes as
          [32, 64, 128, 256].
      params: Dict of parameters. Keys read by this class: 'scale',
          'is_training', 'learning_rate', 'batch_size', 'l2_loss',
          'reconstr_loss' and 'kl_loss'.
    """
    self.images = [image for (image, _) in image_diff_list]
    # Move the diff to the positive realm.
    self.diffs = [
        (diff + params['scale']) / 2 for (_, diff) in image_diff_list]
    self.params = params

  def Build(self):
    """Builds the model, the loss, the optional train op and the summaries.

    Sets `self.global_step` and `self.summary_op`, plus `self.train_op`
    when params['is_training'] is true.

    Returns:
      `self.loss`, the sum of the enabled loss terms.
    """
    with tf.device('/gpu:0'):
      batch_norm_params = {'is_training': self.params['is_training']}
      # Every slim.conv2d built below is ReLU-activated and batch-normalized.
      with slim.arg_scope([slim.conv2d],
                          activation_fn=tf.nn.relu,
                          normalizer_fn=slim.batch_norm,
                          normalizer_params=batch_norm_params):
        self._BuildMotionKernel()
        encoded_images = self._BuildImageEncoder()
        cross_conved_images = self._CrossConv(encoded_images)
        self._BuildImageDecoder(cross_conved_images)
        self._BuildLoss()

      image = self.images[1]
      diff = self.diffs[1]

      self.global_step = tf.Variable(0, name='global_step', trainable=False)

      if self.params['is_training']:
        self._BuildTrainOp()

      # Undo the (d + scale) / 2 shift applied in the constructor so the
      # summary shows diffs in their original range.
      diff = diff * 2.0 - self.params['scale']
      diff_output = self.diff_output * 2.0 - self.params['scale']
      # Stack along axis 1: input, prediction, ground truth, predicted diff.
      concat_image = tf.concat(
          axis=1,
          values=[image, image + diff_output, image + diff, diff_output])
      tf.summary.image('origin_predict_expect_predictdiff', concat_image)
      self.summary_op = tf.summary.merge_all()
      return self.loss

  def _BuildTrainOp(self):
    """Creates `self.train_op`: SGD with a decayed, floored learning rate.

    The rate starts at params['learning_rate'], is multiplied by 0.5 per
    10000 global steps (exponential decay) and is never below 0.01.
    """
    decayed_rate = tf.train.exponential_decay(
        self.params['learning_rate'], self.global_step, 10000, 0.5)
    lrn_rate = tf.maximum(0.01, decayed_rate)  # 0.01 is the minimum rate.
    tf.summary.scalar('learning rate', lrn_rate)
    optimizer = tf.train.GradientDescentOptimizer(lrn_rate)
    self.train_op = slim.learning.create_train_op(
        self.loss, optimizer, global_step=self.global_step)

  def _BuildLoss(self):
    """Sums the loss terms enabled in params into `self.loss`.

    Terms are switched by params['l2_loss'], params['reconstr_loss'] and
    params['kl_loss']; each enabled term also gets a scalar summary.
    """
    # 1. reconstr_loss doesn't seem to do better than l2 loss.
    # 2. Only works when using reduce_mean. reduce_sum doesn't work.
    # 3. It seems kl loss doesn't play an important role.
    self.loss = 0
    with tf.variable_scope('loss'):
      target = self.diffs[1]
      if self.params['l2_loss']:
        l2_loss = tf.reduce_mean(tf.square(self.diff_output - target))
        tf.summary.scalar('l2_loss', l2_loss)
        self.loss += l2_loss
      if self.params['reconstr_loss']:
        # Cross entropy between target and predicted diff. Both terms take
        # the log of the prediction; the first term previously multiplied
        # by the raw prediction instead. 1e-10 guards against log(0).
        reconstr_loss = -tf.reduce_mean(
            target * tf.log(1e-10 + self.diff_output) +
            (1 - target) * tf.log(1e-10 + 1 - self.diff_output))
        reconstr_loss = tf.check_numerics(reconstr_loss, 'reconstr_loss')
        tf.summary.scalar('reconstr_loss', reconstr_loss)
        self.loss += reconstr_loss
      if self.params['kl_loss']:
        # KL divergence of N(z_mean, z_stddev^2) from N(0, 1), averaged
        # (not summed) over all elements.
        kl_loss = 0.5 * tf.reduce_mean(
            tf.square(self.z_mean) + tf.square(self.z_stddev) -
            2 * self.z_stddev_log - 1)
        tf.summary.scalar('kl_loss', kl_loss)
        self.loss += kl_loss

      tf.summary.scalar('loss', self.loss)

  def _BuildMotionKernel(self):
    """Encodes an (image, diff) pair into the sampled motion kernels.

    Reads the second-to-last scale (asserted to be 128x128), encodes the
    channel-wise concatenation of image and diff, splits the flattened code
    into a mean half and a log-stddev half, draws a sample with the
    reparameterization trick and decodes it into `self.kernel`.

    Sets `self.z_mean`, `self.z_stddev_log`, `self.z_stddev` and
    `self.kernel`.
    """
    image = self.images[-2]
    diff = self.diffs[-2]
    shape = image.get_shape().as_list()
    assert shape[1] == shape[2] and shape[1] == 128
    batch_size = shape[0]

    net = tf.concat(axis=3, values=[image, diff])
    with tf.variable_scope('motion_encoder'):
      with slim.arg_scope([slim.conv2d], padding='VALID'):
        net = slim.conv2d(net, 96, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 96, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 128, [5, 5], stride=1)
        net = slim.conv2d(net, 128, [5, 5], stride=1)
        net = slim.max_pool2d(net, [2, 2])
        net = slim.conv2d(net, 256, [4, 4], stride=1)
        net = slim.conv2d(net, 256, [3, 3], stride=1)

        # First half of the flattened code is the mean, second half the
        # log of the standard deviation.
        z = tf.reshape(net, shape=[batch_size, -1])
        self.z_mean, self.z_stddev_log = tf.split(
            axis=1, num_or_size_splits=2, value=z)
        self.z_stddev = tf.exp(self.z_stddev_log)

        # Sample z = mean + stddev * epsilon with epsilon ~ N(0, 1).
        epsilon = tf.random_normal(
            self.z_mean.get_shape().as_list(), 0, 1, dtype=tf.float32)
        kernel = self.z_mean + tf.multiply(self.z_stddev, epsilon)

        # Fold the sample into a square spatial map with 128 channels.
        width = int(math.sqrt(kernel.get_shape().as_list()[1] // 128))
        kernel = tf.reshape(kernel, [batch_size, width, width, 128])
    with tf.variable_scope('kernel_decoder'):
      with slim.arg_scope([slim.conv2d], padding='SAME'):
        kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
        self.kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)

    sys.stderr.write('kernel shape: %s\n' % self.kernel.get_shape())

  def _BuildImageEncoder(self):
    """Encodes every input scale with its own convolutional stack.

    Returns:
      A list with one 32-channel feature map per entry of `self.images`,
      in the same order.
    """
    feature_maps = []
    for i, image in enumerate(self.images):
      with tf.variable_scope('image_encoder_%d' % i):
        with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'):
          net = slim.conv2d(image, 64, [5, 5], stride=1)
          net = slim.conv2d(net, 64, [5, 5], stride=1)
          net = slim.max_pool2d(net, [5, 5])
          net = slim.conv2d(net, 64, [5, 5], stride=1)
          net = slim.conv2d(net, 32, [5, 5], stride=1)
          net = slim.max_pool2d(net, [2, 2])
          sys.stderr.write('image_conv shape: %s\n' % net.get_shape())
          feature_maps.append(net)
    return feature_maps

  def _CrossConvHelper(self, encoded_image, kernel):
    """Cross convolution for a single example.

    Each channel of the feature map is convolved with the matching channel
    of the kernel (depthwise convolution, channel multiplier 1).

    Args:
      encoded_image: Feature map of one example, without the batch
          dimension: [image_size, image_size, channels].
      kernel: Sampled kernel of the same example:
          [kernel_size, kernel_size, channels].

    Returns:
      The convolved feature map with a leading batch dimension of 1.
    """
    # depthwise_conv2d wants a batched input and a
    # [height, width, in_channels, channel_multiplier] filter.
    images = tf.expand_dims(encoded_image, 0)
    kernels = tf.expand_dims(kernel, 3)
    return tf.nn.depthwise_conv2d(images, kernels, [1, 1, 1, 1], 'SAME')

  def _CrossConv(self, encoded_images):
    """Apply the motion kernel on the encoded_images.

    `self.kernel` is split into 4 equal channel groups, one per scale; the
    i-th group is applied, example by example, to the i-th feature map.

    Args:
      encoded_images: List of feature maps from `_BuildImageEncoder`.

    Returns:
      A list of cross-convolved feature maps, one per input feature map.
    """
    cross_conved_images = []
    kernels = tf.split(axis=3, num_or_size_splits=4, value=self.kernel)
    for i, encoded_image in enumerate(encoded_images):
      with tf.variable_scope('cross_conv_%d' % i):
        # Every example has its own kernel, so split along the batch axis.
        per_example_images = tf.unstack(encoded_image, axis=0)
        per_example_kernels = tf.unstack(kernels[i], axis=0)
        assert len(per_example_images) == len(per_example_kernels)
        assert len(per_example_images) == self.params['batch_size']
        conved_image = [
            self._CrossConvHelper(example_image, example_kernel)
            for example_image, example_kernel in zip(
                per_example_images, per_example_kernels)]
        cross_conved_images.append(tf.concat(axis=0, values=conved_image))
        sys.stderr.write('cross_conved shape: %s\n' %
                         cross_conved_images[-1].get_shape())
    return cross_conved_images

  def _Deconv(self, net, out_filters, kernel_size, stride):
    """Upsamples `net` with a transposed convolution plus batch norm.

    Creates a variable named 'weights' in the current variable scope, so
    it must be called at most once per scope.

    Args:
      net: Input tensor, [batch_size, height, width, in_filters], with a
          fully defined static shape.
      out_filters: Number of output channels.
      kernel_size: Height and width of the square filter.
      stride: Integer upsampling factor for height and width.

    Returns:
      Tensor of shape
      [batch_size, height * stride, width * stride, out_filters].
    """
    shape = net.get_shape().as_list()
    batch_size, in_height, in_width, in_filters = shape
    # conv2d_transpose filters are [height, width, out_channels, in_channels].
    kernel_shape = [kernel_size, kernel_size, out_filters, in_filters]

    weights = tf.get_variable(
        name='weights',
        shape=kernel_shape,
        dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(stddev=0.01))

    output_shape = [
        batch_size, in_height * stride, in_width * stride, out_filters]
    net = tf.nn.conv2d_transpose(net, weights, output_shape,
                                 [1, stride, stride, 1], padding='SAME')
    # slim.batch_norm returns the normalized tensor; the earlier bare call
    # discarded it, so the un-normalized tensor was returned.
    return slim.batch_norm(net, is_training=self.params['is_training'])

  def _BuildImageDecoder(self, cross_conved_images):
    """Decode the cross_conved feature maps into the predicted images.

    Every feature map is upsampled to 64x64, the results are concatenated
    along the channel axis and reduced to 3 channels. The result is stored
    in `self.diff_output`.

    Args:
      cross_conved_images: List of feature maps from `_CrossConv`.
    """
    nets = []
    for i, cross_conved_image in enumerate(cross_conved_images):
      with tf.variable_scope('image_decoder_%d' % i):
        feature_size = cross_conved_image.get_shape().as_list()[1]
        # Integer division: under Python 3 `/` yields a float, which is not
        # a valid stride or output dimension for conv2d_transpose.
        stride = 64 // feature_size
        # TODO(xpan): Alternative solution for upsampling?
        nets.append(self._Deconv(
            cross_conved_image, 64, kernel_size=3, stride=stride))

    net = tf.concat(axis=3, values=nets)
    net = slim.conv2d(net, 128, [9, 9], padding='SAME', stride=1)
    net = slim.conv2d(net, 128, [1, 1], padding='SAME', stride=1)
    net = slim.conv2d(net, 3, [1, 1], padding='SAME', stride=1)
    self.diff_output = net
    sys.stderr.write('diff_output shape: %s\n' % self.diff_output.get_shape())