NCTC / models /research /ptn /model_ptn.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
8.29 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementations for Im2Vox PTN (NIPS16) model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from six.moves import xrange
import tensorflow as tf
import losses
import metrics
import model_voxel_generation
import utils
from nets import im2vox_factory
slim = tf.contrib.slim
class model_PTN(model_voxel_generation.Im2Vox): # pylint:disable=invalid-name
"""Inherits the generic Im2Vox model class and implements the functions."""
def __init__(self, params):
super(model_PTN, self).__init__(params)
# For testing, this selects all views in input
def preprocess_with_all_views(self, raw_inputs):
(quantity, num_views) = raw_inputs['images'].get_shape().as_list()[:2]
inputs = dict()
inputs['voxels'] = []
inputs['images_1'] = []
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = []
inputs['matrix_1'] = []
for n in xrange(quantity):
for k in xrange(num_views):
inputs['images_1'].append(raw_inputs['images'][n, k, :, :, :])
inputs['voxels'].append(raw_inputs['voxels'][n, :, :, :, :])
tf_matrix = self.get_transform_matrix(k)
inputs['matrix_%d' % (k + 1)].append(tf_matrix)
inputs['images_1'] = tf.stack(inputs['images_1'])
inputs['voxels'] = tf.stack(inputs['voxels'])
for k in xrange(num_views):
inputs['matrix_%d' % (k + 1)] = tf.stack(inputs['matrix_%d' % (k + 1)])
return inputs
def get_model_fn(self, is_training=True, reuse=False, run_projection=True):
return im2vox_factory.get(self._params, is_training, reuse, run_projection)
def get_regularization_loss(self, scopes):
return losses.regularization_loss(scopes, self._params)
def get_loss(self, inputs, outputs):
"""Computes the loss used for PTN paper (projection + volume loss)."""
g_loss = tf.zeros(dtype=tf.float32, shape=[])
if self._params.proj_weight:
g_loss += losses.add_volume_proj_loss(
inputs, outputs, self._params.step_size, self._params.proj_weight)
if self._params.volume_weight:
g_loss += losses.add_volume_loss(inputs, outputs, 1,
self._params.volume_weight)
slim.summaries.add_scalar_summary(g_loss, 'im2vox_loss', prefix='losses')
return g_loss
def get_metrics(self, inputs, outputs):
"""Aggregate the metrics for voxel generation model.
Args:
inputs: Input dictionary of the voxel generation model.
outputs: Output dictionary returned by the voxel generation model.
Returns:
names_to_values: metrics->values (dict).
names_to_updates: metrics->ops (dict).
"""
names_to_values = dict()
names_to_updates = dict()
tmp_values, tmp_updates = metrics.add_volume_iou_metrics(inputs, outputs)
names_to_values.update(tmp_values)
names_to_updates.update(tmp_updates)
for name, value in names_to_values.iteritems():
slim.summaries.add_scalar_summary(
value, name, prefix='eval', print_summary=True)
return names_to_values, names_to_updates
def write_disk_grid(self,
global_step,
log_dir,
input_images,
gt_projs,
pred_projs,
input_voxels=None,
output_voxels=None):
"""Function called by TF to save the prediction periodically."""
summary_freq = self._params.save_every
def write_grid(input_images, gt_projs, pred_projs, global_step,
input_voxels, output_voxels):
"""Native python function to call for writing images to files."""
grid = _build_image_grid(
input_images,
gt_projs,
pred_projs,
input_voxels=input_voxels,
output_voxels=output_voxels)
if global_step % summary_freq == 0:
img_path = os.path.join(log_dir, '%s.jpg' % str(global_step))
utils.save_image(grid, img_path)
return grid
save_op = tf.py_func(write_grid, [
input_images, gt_projs, pred_projs, global_step, input_voxels,
output_voxels
], [tf.uint8], 'write_grid')[0]
slim.summaries.add_image_summary(
tf.expand_dims(save_op, axis=0), name='grid_vis')
return save_op
def get_transform_matrix(self, view_out):
"""Get the 4x4 Perspective Transfromation matrix used for PTN."""
num_views = self._params.num_views
focal_length = self._params.focal_length
focal_range = self._params.focal_range
phi = 30
theta_interval = 360.0 / num_views
theta = theta_interval * view_out
# pylint: disable=invalid-name
camera_matrix = np.zeros((4, 4), dtype=np.float32)
intrinsic_matrix = np.eye(4, dtype=np.float32)
extrinsic_matrix = np.eye(4, dtype=np.float32)
sin_phi = np.sin(float(phi) / 180.0 * np.pi)
cos_phi = np.cos(float(phi) / 180.0 * np.pi)
sin_theta = np.sin(float(-theta) / 180.0 * np.pi)
cos_theta = np.cos(float(-theta) / 180.0 * np.pi)
rotation_azimuth = np.zeros((3, 3), dtype=np.float32)
rotation_azimuth[0, 0] = cos_theta
rotation_azimuth[2, 2] = cos_theta
rotation_azimuth[0, 2] = -sin_theta
rotation_azimuth[2, 0] = sin_theta
rotation_azimuth[1, 1] = 1.0
## rotation axis -- x
rotation_elevation = np.zeros((3, 3), dtype=np.float32)
rotation_elevation[0, 0] = cos_phi
rotation_elevation[0, 1] = sin_phi
rotation_elevation[1, 0] = -sin_phi
rotation_elevation[1, 1] = cos_phi
rotation_elevation[2, 2] = 1.0
rotation_matrix = np.matmul(rotation_azimuth, rotation_elevation)
displacement = np.zeros((3, 1), dtype=np.float32)
displacement[0, 0] = float(focal_length) + float(focal_range) / 2.0
displacement = np.matmul(rotation_matrix, displacement)
extrinsic_matrix[0:3, 0:3] = rotation_matrix
extrinsic_matrix[0:3, 3:4] = -displacement
intrinsic_matrix[2, 2] = 1.0 / float(focal_length)
intrinsic_matrix[1, 1] = 1.0 / float(focal_length)
camera_matrix = np.matmul(extrinsic_matrix, intrinsic_matrix)
return camera_matrix
def _build_image_grid(input_images,
gt_projs,
pred_projs,
input_voxels,
output_voxels,
vis_size=128):
"""Builds a grid image by concatenating the input images."""
quantity = input_images.shape[0]
for row in xrange(int(quantity / 3)):
for col in xrange(3):
index = row * 3 + col
input_img_ = utils.resize_image(input_images[index, :, :, :], vis_size,
vis_size)
gt_proj_ = utils.resize_image(gt_projs[index, :, :, :], vis_size,
vis_size)
pred_proj_ = utils.resize_image(pred_projs[index, :, :, :], vis_size,
vis_size)
gt_voxel_vis = utils.resize_image(
utils.display_voxel(input_voxels[index, :, :, :, 0]), vis_size,
vis_size)
pred_voxel_vis = utils.resize_image(
utils.display_voxel(output_voxels[index, :, :, :, 0]), vis_size,
vis_size)
if col == 0:
tmp_ = np.concatenate(
[input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis], 1)
else:
tmp_ = np.concatenate([
tmp_, input_img_, gt_proj_, pred_proj_, gt_voxel_vis, pred_voxel_vis
], 1)
if row == 0:
out_grid = tmp_
else:
out_grid = np.concatenate([out_grid, tmp_], 0)
return out_grid