Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Contains evaluation plan for the Im2vox model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import tensorflow as tf | |
from tensorflow import app | |
import model_ptn | |
# Short aliases for the TensorFlow sub-modules used throughout this script.
slim = tf.contrib.slim
flags = tf.app.flags

# ---------------------------------------------------------------------------
# Input data
# ---------------------------------------------------------------------------
flags.DEFINE_string(
    'inp_dir', '',
    'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
    'dataset_name', 'shapenet_chair',
    'Dataset name that is to be used for training and evaluation.')

# ---------------------------------------------------------------------------
# Model configuration
# ---------------------------------------------------------------------------
# The three *_dim flags carry no help text; they are handed to
# model_ptn.model_PTN as part of FLAGS. Presumably layer/latent widths --
# TODO confirm against model_ptn.
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer(
    'num_views', 24,
    'Num of viewpoints in the input data.')
flags.DEFINE_integer(
    'image_size', 64,
    'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer(
    'vox_size', 32,
    'Voxel prediction dimension.')
# NOTE: main() overwrites step_size (with num_views) and batch_size (with 1)
# before the model is built, so the defaults below are not what is evaluated.
flags.DEFINE_integer('step_size', 24, '')
flags.DEFINE_integer(
    'batch_size', 1,
    'Batch size while training.')
# No help text for the two focal_* flags; presumably camera parameters used by
# the projector -- TODO confirm against model_ptn.
flags.DEFINE_float('focal_length', 0.866, '')
flags.DEFINE_float('focal_range', 1.732, '')
flags.DEFINE_string(
    'encoder_name', 'ptn_encoder',
    'Name of the encoder network being used.')
flags.DEFINE_string(
    'decoder_name', 'ptn_vox_decoder',
    'Name of the decoder network being used.')
flags.DEFINE_string(
    'projector_name', 'ptn_projector',
    'Name of the projector network being used.')

# ---------------------------------------------------------------------------
# Save options
# ---------------------------------------------------------------------------
flags.DEFINE_string(
    'checkpoint_dir', '/tmp/ptn/eval/',
    'Directory path for saving trained models and other data.')
flags.DEFINE_string(
    'model_name', 'ptn_proj',
    'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string(
    'eval_set', 'val',
    'Data partition to form evaluation on.')

# ---------------------------------------------------------------------------
# Optimization
# ---------------------------------------------------------------------------
flags.DEFINE_float(
    'proj_weight', 10,
    'Weighting factor for projection loss.')
flags.DEFINE_float(
    'volume_weight', 0,
    'Weighting factor for volume loss.')
flags.DEFINE_float(
    'viewpoint_weight', 1,
    'Weighting factor for viewpoint loss.')
flags.DEFINE_float(
    'learning_rate', 0.0001,
    'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, '')
flags.DEFINE_float('clip_gradient_norm', 0, '')

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
flags.DEFINE_integer('save_summaries_secs', 15, '')
# Five minutes, expressed in seconds; forwarded to the evaluation loop.
flags.DEFINE_integer('eval_interval_secs', 60 * 5, '')

# ---------------------------------------------------------------------------
# Distribution
# ---------------------------------------------------------------------------
flags.DEFINE_string('master', '', '')

# Parsed flag values, shared by the rest of the module.
FLAGS = flags.FLAGS
def main(argv=()):
  """Builds the PTN evaluation graph and runs the TF-Slim evaluation loop.

  Checkpoints are read from `<checkpoint_dir>/<model_name>/train` and
  evaluation logs are written to
  `<checkpoint_dir>/<model_name>/eval_<eval_set>`. Both directories are
  created if they do not exist yet.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  # NOTE: despite its name, `eval_dir` is the *training* directory; it is
  # where the evaluation loop looks for checkpoints to evaluate.
  eval_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
  log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name,
                         'eval_%s' % FLAGS.eval_set)
  for directory in (eval_dir, log_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  graph = tf.Graph()
  with graph.as_default():
    # `eval_params` is an alias of the global FLAGS object (not a copy), so
    # the overrides below are visible to every reader of FLAGS.
    eval_params = FLAGS
    eval_params.batch_size = 1
    eval_params.step_size = FLAGS.num_views

    ###########
    ## model ##
    ###########
    model = model_ptn.model_PTN(eval_params)

    ##########
    ## data ##
    ##########
    eval_data = model.get_inputs(
        FLAGS.inp_dir,
        FLAGS.dataset_name,
        eval_params.eval_set,
        eval_params.batch_size,
        eval_params.image_size,
        eval_params.vox_size,
        is_training=False)
    inputs = model.preprocess_with_all_views(eval_data)

    ##############
    ## model_fn ##
    ##############
    model_fn = model.get_model_fn(is_training=False, run_projection=False)
    outputs = model_fn(inputs)

    #############
    ## metrics ##
    #############
    # Only the update ops are needed to drive the evaluation; the value
    # tensors are discarded.
    _, names_to_updates = model.get_metrics(inputs, outputs)

    ################
    ## evaluation ##
    ################
    # batch_size is forced to 1 above, so the 'num_samples' entry is used
    # directly as the number of evaluation steps.
    num_batches = eval_data['num_samples']
    slim.evaluation.evaluation_loop(
        master=FLAGS.master,
        checkpoint_dir=eval_dir,
        logdir=log_dir,
        num_evals=num_batches,
        # Materialize the update ops as a list: under Python 3,
        # dict.values() is a view object that Session.run does not accept
        # as a fetch.
        eval_op=list(names_to_updates.values()),
        eval_interval_secs=FLAGS.eval_interval_secs)
# Script entry point: let the TensorFlow app runner parse the command-line
# flags and then invoke main().
if __name__ == '__main__':
  app.run(main=main)