# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for pretraining the rotator, as described in the PTN paper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
from six.moves import xrange
import tensorflow as tf

import input_generator
import losses
import metrics
import utils
from nets import deeprotator_factory

slim = tf.contrib.slim


def _get_data_from_provider(inputs, batch_size, split_name):
  """Returns a dictionary of batched input data built with tf.train.batch."""
  images, masks = tf.train.batch(
      [inputs['image'], inputs['mask']],
      batch_size=batch_size,
      num_threads=8,
      capacity=8 * batch_size,
      name='batching_queues/%s' % split_name)

  outputs = dict()
  outputs['images'] = images
  outputs['masks'] = masks
  outputs['num_samples'] = inputs['num_samples']
  return outputs


def get_inputs(dataset_dir, dataset_name, split_name, batch_size, image_size,
               is_training):
  """Loads the given dataset and split."""
  del image_size  # Unused.
  with tf.variable_scope('data_loading_%s/%s' % (dataset_name, split_name)):
    common_queue_min = 50
    common_queue_capacity = 256
    num_readers = 4
    inputs = input_generator.get(
        dataset_dir,
        dataset_name,
        split_name,
        shuffle=is_training,
        num_readers=num_readers,
        common_queue_min=common_queue_min,
        common_queue_capacity=common_queue_capacity)
  return _get_data_from_provider(inputs, batch_size, split_name)


def preprocess(raw_inputs, step_size):
  """Selects the subset of viewpoints to train on.

  For each sample, a starting view and a rotation direction are drawn at
  random, and the `step_size` consecutive views in that direction become the
  supervision targets. The direction is recorded as a one-hot action vector:
  index 0 for a +1 step, index 1 for no rotation, index 2 for a -1 step.
  """
  shp = raw_inputs['images'].get_shape().as_list()
  quantity = shp[0]
  num_views = shp[1]
  image_size = shp[2]
  del image_size  # Unused.

  batch_rot = np.zeros((quantity, 3), dtype=np.float32)
  inputs = dict()
  for n in xrange(step_size + 1):
    inputs['images_%d' % n] = []
    inputs['masks_%d' % n] = []

  for n in xrange(quantity):
    view_in = np.random.randint(0, num_views)
    # Sample a rotation direction; the 'no rotation' action is only
    # available when step_size == 1.
    rng_rot = np.random.randint(0, 2)
    if step_size == 1:
      rng_rot = np.random.randint(0, 3)
    delta = 0
    if rng_rot == 0:
      delta = -1
      batch_rot[n, 2] = 1
    elif rng_rot == 1:
      delta = 1
      batch_rot[n, 0] = 1
    else:
      delta = 0
      batch_rot[n, 1] = 1
    inputs['images_0'].append(raw_inputs['images'][n, view_in, :, :, :])
    inputs['masks_0'].append(raw_inputs['masks'][n, view_in, :, :, :])
    view_out = view_in
    for k in xrange(1, step_size + 1):
      # Advance one view per step, wrapping around at both ends.
      view_out += delta
      if view_out >= num_views:
        view_out = 0
      if view_out < 0:
        view_out = num_views - 1
      inputs['images_%d' % k].append(raw_inputs['images'][n, view_out, :, :, :])
      inputs['masks_%d' % k].append(raw_inputs['masks'][n, view_out, :, :, :])

  for n in xrange(step_size + 1):
    inputs['images_%d' % n] = tf.stack(inputs['images_%d' % n])
    inputs['masks_%d' % n] = tf.stack(inputs['masks_%d' % n])
  inputs['actions'] = tf.constant(batch_rot, dtype=tf.float32)
  return inputs
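

# Example (illustrative sketch, not part of the original pipeline): how
# `get_inputs` and `preprocess` would typically be chained in a training
# script. The dataset path, dataset name, and hyperparameter values below are
# hypothetical placeholders.
#
#   raw_inputs = get_inputs(
#       dataset_dir='/path/to/tfrecords',  # hypothetical location
#       dataset_name='shapenet_chair',     # hypothetical dataset name
#       split_name='train',
#       batch_size=32,
#       image_size=64,
#       is_training=True)
#   inputs = preprocess(raw_inputs, step_size=1)
#   # `inputs` now holds 'images_0'..'images_<step_size>' and the matching
#   # masks, plus the one-hot 'actions' tensor of shape [batch_size, 3].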
def get_init_fn(scopes, params):
  """Returns a function that initializes the given scopes from a checkpoint.

  The returned function restores the trainable model variables of `scopes`
  from the checkpoint in `params.init_model`, or None if no checkpoint is set.
  """
  if not params.init_model:
    return None

  is_trainable = lambda x: x in tf.trainable_variables()
  var_list = []
  for scope in scopes:
    var_list.extend(
        filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))

  init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
      params.init_model, var_list)

  def init_assign_function(sess):
    sess.run(init_assign_op, init_feed_dict)

  return init_assign_function


def get_model_fn(params, is_training, reuse=False):
  """Returns the rotator network function from the factory."""
  return deeprotator_factory.get(params, is_training, reuse)


def get_regularization_loss(scopes, params):
  """Returns the regularization loss for the given scopes."""
  return losses.regularization_loss(scopes, params)


def get_loss(inputs, outputs, params):
  """Computes the rotator loss (weighted image and mask terms)."""
  g_loss = tf.zeros(dtype=tf.float32, shape=[])
  if hasattr(params, 'image_weight'):
    g_loss += losses.add_rotator_image_loss(inputs, outputs, params.step_size,
                                            params.image_weight)
  if hasattr(params, 'mask_weight'):
    g_loss += losses.add_rotator_mask_loss(inputs, outputs, params.step_size,
                                           params.mask_weight)

  slim.summaries.add_scalar_summary(g_loss, 'rotator_loss', prefix='losses')
  return g_loss


def get_train_op_for_scope(loss, optimizer, scopes, params):
  """Returns the train op for the variables and update ops of the given scopes."""
  is_trainable = lambda x: x in tf.trainable_variables()
  var_list = []
  update_ops = []
  for scope in scopes:
    var_list.extend(
        filter(is_trainable, tf.contrib.framework.get_model_variables(scope)))
    update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
  return slim.learning.create_train_op(
      loss,
      optimizer,
      update_ops=update_ops,
      variables_to_train=var_list,
      clip_gradient_norm=params.clip_gradient_norm)
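

# Example (illustrative sketch): assembling the rotator loss and train op from
# the helpers above. The optimizer choice, learning rate, and scope names
# ('encoder', 'rotator', 'decoder') are assumptions for illustration; the
# actual values come from the training binary's hyperparameters.
#
#   model_fn = get_model_fn(params, is_training=True)
#   outputs = model_fn(inputs)
#   train_scopes = ['encoder', 'rotator', 'decoder']
#   loss = get_loss(inputs, outputs, params)
#   loss += get_regularization_loss(train_scopes, params)
#   optimizer = tf.train.AdamOptimizer(params.learning_rate)
#   train_op = get_train_op_for_scope(loss, optimizer, train_scopes, params)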
""" names_to_values = dict() names_to_updates = dict() tmp_values, tmp_updates = metrics.add_image_pred_metrics( inputs, outputs, params.num_views, 3*params.image_size**2) names_to_values.update(tmp_values) names_to_updates.update(tmp_updates) tmp_values, tmp_updates = metrics.add_mask_pred_metrics( inputs, outputs, params.num_views, params.image_size**2) names_to_values.update(tmp_values) names_to_updates.update(tmp_updates) for name, value in names_to_values.iteritems(): slim.summaries.add_scalar_summary( value, name, prefix='eval', print_summary=True) return names_to_values, names_to_updates def write_disk_grid(global_step, summary_freq, log_dir, input_images, output_images, pred_images, pred_masks): """Function called by TF to save the prediction periodically.""" def write_grid(grid, global_step): """Native python function to call for writing images to files.""" if global_step % summary_freq == 0: img_path = os.path.join(log_dir, '%s.jpg' % str(global_step)) utils.save_image(grid, img_path) return 0 grid = _build_image_grid(input_images, output_images, pred_images, pred_masks) slim.summaries.add_image_summary( tf.expand_dims(grid, axis=0), name='grid_vis') save_op = tf.py_func(write_grid, [grid, global_step], [tf.int64], 'write_grid')[0] return save_op def _build_image_grid(input_images, output_images, pred_images, pred_masks): """Builds a grid image by concatenating the input images.""" quantity = input_images.get_shape().as_list()[0] for row in xrange(int(quantity / 4)): for col in xrange(4): index = row * 4 + col input_img_ = input_images[index, :, :, :] output_img_ = output_images[index, :, :, :] pred_img_ = pred_images[index, :, :, :] pred_mask_ = tf.tile(pred_masks[index, :, :, :], [1, 1, 3]) if col == 0: tmp_ = tf.concat([input_img_, output_img_, pred_img_, pred_mask_], 1) ## to the right else: tmp_ = tf.concat([tmp_, input_img_, output_img_, pred_img_, pred_mask_], 1) if row == 0: out_grid = tmp_ else: out_grid = tf.concat([out_grid, tmp_], 0) return out_grid