# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Perspective Transformer Layer Implementation. Transform the volume based on 4 x 4 perspective projection matrix. Reference: (1) "Perspective Transformer Nets: Perspective Transformer Nets: Learning Single-View 3D Object Reconstruction without 3D Supervision." Xinchen Yan, Jimei Yang, Ersin Yumer, Yijie Guo, Honglak Lee. In NIPS 2016 https://papers.nips.cc/paper/6206-perspective-transformer-nets-learning-single-view-3d-object-reconstruction-without-3d-supervision.pdf (2) Official implementation in Torch: https://github.com/xcyan/ptnbhwd (3) 2D Transformer implementation in TF: github.com/tensorflow/models/tree/master/research/transformer """ import tensorflow as tf def transformer(voxels, theta, out_size, z_near, z_far, name='PerspectiveTransformer'): """Perspective Transformer Layer. Args: voxels: A tensor of size [num_batch, depth, height, width, num_channels]. It is the output of a deconv/upsampling conv network (tf.float32). theta: A tensor of size [num_batch, 16]. It is the inverse camera transformation matrix (tf.float32). out_size: A tuple representing the size of output of transformer layer (float). z_near: A number representing the near clipping plane (float). z_far: A number representing the far clipping plane (float). Returns: A transformed tensor (tf.float32). """ def _repeat(x, n_repeats): with tf.variable_scope('_repeat'): rep = tf.transpose( tf.expand_dims(tf.ones(shape=tf.stack([ n_repeats, ])), 1), [1, 0]) rep = tf.to_int32(rep) x = tf.matmul(tf.reshape(x, (-1, 1)), rep) return tf.reshape(x, [-1]) def _interpolate(im, x, y, z, out_size): """Bilinear interploation layer. Args: im: A 5D tensor of size [num_batch, depth, height, width, num_channels]. It is the input volume for the transformation layer (tf.float32). x: A tensor of size [num_batch, out_depth, out_height, out_width] representing the inverse coordinate mapping for x (tf.float32). y: A tensor of size [num_batch, out_depth, out_height, out_width] representing the inverse coordinate mapping for y (tf.float32). z: A tensor of size [num_batch, out_depth, out_height, out_width] representing the inverse coordinate mapping for z (tf.float32). out_size: A tuple representing the output size of transformation layer (float). Returns: A transformed tensor (tf.float32). """ with tf.variable_scope('_interpolate'): num_batch = im.get_shape().as_list()[0] depth = im.get_shape().as_list()[1] height = im.get_shape().as_list()[2] width = im.get_shape().as_list()[3] channels = im.get_shape().as_list()[4] x = tf.to_float(x) y = tf.to_float(y) z = tf.to_float(z) depth_f = tf.to_float(depth) height_f = tf.to_float(height) width_f = tf.to_float(width) # Number of disparity interpolated. out_depth = out_size[0] out_height = out_size[1] out_width = out_size[2] zero = tf.zeros([], dtype='int32') # 0 <= z < depth, 0 <= y < height & 0 <= x < width. max_z = tf.to_int32(tf.shape(im)[1] - 1) max_y = tf.to_int32(tf.shape(im)[2] - 1) max_x = tf.to_int32(tf.shape(im)[3] - 1) # Converts scale indices from [-1, 1] to [0, width/height/depth]. x = (x + 1.0) * (width_f) / 2.0 y = (y + 1.0) * (height_f) / 2.0 z = (z + 1.0) * (depth_f) / 2.0 x0 = tf.to_int32(tf.floor(x)) x1 = x0 + 1 y0 = tf.to_int32(tf.floor(y)) y1 = y0 + 1 z0 = tf.to_int32(tf.floor(z)) z1 = z0 + 1 x0_clip = tf.clip_by_value(x0, zero, max_x) x1_clip = tf.clip_by_value(x1, zero, max_x) y0_clip = tf.clip_by_value(y0, zero, max_y) y1_clip = tf.clip_by_value(y1, zero, max_y) z0_clip = tf.clip_by_value(z0, zero, max_z) z1_clip = tf.clip_by_value(z1, zero, max_z) dim3 = width dim2 = width * height dim1 = width * height * depth base = _repeat( tf.range(num_batch) * dim1, out_depth * out_height * out_width) base_z0_y0 = base + z0_clip * dim2 + y0_clip * dim3 base_z0_y1 = base + z0_clip * dim2 + y1_clip * dim3 base_z1_y0 = base + z1_clip * dim2 + y0_clip * dim3 base_z1_y1 = base + z1_clip * dim2 + y1_clip * dim3 idx_z0_y0_x0 = base_z0_y0 + x0_clip idx_z0_y0_x1 = base_z0_y0 + x1_clip idx_z0_y1_x0 = base_z0_y1 + x0_clip idx_z0_y1_x1 = base_z0_y1 + x1_clip idx_z1_y0_x0 = base_z1_y0 + x0_clip idx_z1_y0_x1 = base_z1_y0 + x1_clip idx_z1_y1_x0 = base_z1_y1 + x0_clip idx_z1_y1_x1 = base_z1_y1 + x1_clip # Use indices to lookup pixels in the flat image and restore # channels dim im_flat = tf.reshape(im, tf.stack([-1, channels])) im_flat = tf.to_float(im_flat) i_z0_y0_x0 = tf.gather(im_flat, idx_z0_y0_x0) i_z0_y0_x1 = tf.gather(im_flat, idx_z0_y0_x1) i_z0_y1_x0 = tf.gather(im_flat, idx_z0_y1_x0) i_z0_y1_x1 = tf.gather(im_flat, idx_z0_y1_x1) i_z1_y0_x0 = tf.gather(im_flat, idx_z1_y0_x0) i_z1_y0_x1 = tf.gather(im_flat, idx_z1_y0_x1) i_z1_y1_x0 = tf.gather(im_flat, idx_z1_y1_x0) i_z1_y1_x1 = tf.gather(im_flat, idx_z1_y1_x1) # Finally calculate interpolated values. x0_f = tf.to_float(x0) x1_f = tf.to_float(x1) y0_f = tf.to_float(y0) y1_f = tf.to_float(y1) z0_f = tf.to_float(z0) z1_f = tf.to_float(z1) # Check the out-of-boundary case. x0_valid = tf.to_float( tf.less_equal(x0, max_x) & tf.greater_equal(x0, 0)) x1_valid = tf.to_float( tf.less_equal(x1, max_x) & tf.greater_equal(x1, 0)) y0_valid = tf.to_float( tf.less_equal(y0, max_y) & tf.greater_equal(y0, 0)) y1_valid = tf.to_float( tf.less_equal(y1, max_y) & tf.greater_equal(y1, 0)) z0_valid = tf.to_float( tf.less_equal(z0, max_z) & tf.greater_equal(z0, 0)) z1_valid = tf.to_float( tf.less_equal(z1, max_z) & tf.greater_equal(z1, 0)) w_z0_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) * (z1_f - z) * x1_valid * y1_valid * z1_valid), 1) w_z0_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) * (z1_f - z) * x0_valid * y1_valid * z1_valid), 1) w_z0_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) * (z1_f - z) * x1_valid * y0_valid * z1_valid), 1) w_z0_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) * (z1_f - z) * x0_valid * y0_valid * z1_valid), 1) w_z1_y0_x0 = tf.expand_dims(((x1_f - x) * (y1_f - y) * (z - z0_f) * x1_valid * y1_valid * z0_valid), 1) w_z1_y0_x1 = tf.expand_dims(((x - x0_f) * (y1_f - y) * (z - z0_f) * x0_valid * y1_valid * z0_valid), 1) w_z1_y1_x0 = tf.expand_dims(((x1_f - x) * (y - y0_f) * (z - z0_f) * x1_valid * y0_valid * z0_valid), 1) w_z1_y1_x1 = tf.expand_dims(((x - x0_f) * (y - y0_f) * (z - z0_f) * x0_valid * y0_valid * z0_valid), 1) output = tf.add_n([ w_z0_y0_x0 * i_z0_y0_x0, w_z0_y0_x1 * i_z0_y0_x1, w_z0_y1_x0 * i_z0_y1_x0, w_z0_y1_x1 * i_z0_y1_x1, w_z1_y0_x0 * i_z1_y0_x0, w_z1_y0_x1 * i_z1_y0_x1, w_z1_y1_x0 * i_z1_y1_x0, w_z1_y1_x1 * i_z1_y1_x1 ]) return output def _meshgrid(depth, height, width, z_near, z_far): with tf.variable_scope('_meshgrid'): x_t = tf.reshape( tf.tile(tf.linspace(-1.0, 1.0, width), [height * depth]), [depth, height, width]) y_t = tf.reshape( tf.tile(tf.linspace(-1.0, 1.0, height), [width * depth]), [depth, width, height]) y_t = tf.transpose(y_t, [0, 2, 1]) sample_grid = tf.tile( tf.linspace(float(z_near), float(z_far), depth), [width * height]) z_t = tf.reshape(sample_grid, [height, width, depth]) z_t = tf.transpose(z_t, [2, 0, 1]) z_t = 1 / z_t d_t = 1 / z_t x_t /= z_t y_t /= z_t x_t_flat = tf.reshape(x_t, (1, -1)) y_t_flat = tf.reshape(y_t, (1, -1)) d_t_flat = tf.reshape(d_t, (1, -1)) ones = tf.ones_like(x_t_flat) grid = tf.concat([d_t_flat, y_t_flat, x_t_flat, ones], 0) return grid def _transform(theta, input_dim, out_size, z_near, z_far): with tf.variable_scope('_transform'): num_batch = input_dim.get_shape().as_list()[0] num_channels = input_dim.get_shape().as_list()[4] theta = tf.reshape(theta, (-1, 4, 4)) theta = tf.cast(theta, 'float32') out_depth = out_size[0] out_height = out_size[1] out_width = out_size[2] grid = _meshgrid(out_depth, out_height, out_width, z_near, z_far) grid = tf.expand_dims(grid, 0) grid = tf.reshape(grid, [-1]) grid = tf.tile(grid, tf.stack([num_batch])) grid = tf.reshape(grid, tf.stack([num_batch, 4, -1])) # Transform A x (x_t', y_t', 1, d_t)^T -> (x_s, y_s, z_s, 1). t_g = tf.matmul(theta, grid) z_s = tf.slice(t_g, [0, 0, 0], [-1, 1, -1]) y_s = tf.slice(t_g, [0, 1, 0], [-1, 1, -1]) x_s = tf.slice(t_g, [0, 2, 0], [-1, 1, -1]) z_s_flat = tf.reshape(z_s, [-1]) y_s_flat = tf.reshape(y_s, [-1]) x_s_flat = tf.reshape(x_s, [-1]) input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, z_s_flat, out_size) output = tf.reshape( input_transformed, tf.stack([num_batch, out_depth, out_height, out_width, num_channels])) return output with tf.variable_scope(name): output = _transform(theta, voxels, out_size, z_near, z_far) return output