# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Evaluates the PIXELDA model.
-- Compiles the model for CPU.
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Compile the model for GPU.
$ bazel build -c opt --copt=-mavx --config=cuda \
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval
-- Runs the training.
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \
--source_dataset=mnist \
--target_dataset=mnist_m \
--dataset_dir=/tmp/datasets/ \
--alsologtostderr
-- Visualize the results.
$ bash learning/brain/tensorboard/tensorboard.sh \
--port 2222 --logdir=/tmp/pixelda/
"""

import math

# Dependency imports

import tensorflow as tf

from domain_adaptation.datasets import dataset_factory
from domain_adaptation.pixel_domain_adaptation import pixelda_losses
from domain_adaptation.pixel_domain_adaptation import pixelda_model
from domain_adaptation.pixel_domain_adaptation import pixelda_utils
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams

slim = tf.contrib.slim

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.')
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/',
'Directory where the model was written to.')
flags.DEFINE_string('eval_dir', '/tmp/pixelda/',
'Directory where the results are saved to.')
flags.DEFINE_integer('eval_interval_secs', 60,
'The frequency, in seconds, with which evaluation is run.')
flags.DEFINE_string('target_split_name', 'test',
'The name of the train/test split.')
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.'
' Defaults to train.')
flags.DEFINE_string('source_dataset', 'mnist',
'The name of the source dataset.')
flags.DEFINE_string('target_dataset', 'mnist_m',
'The name of the target dataset.')
flags.DEFINE_string(
    'dataset_dir',
    '',
    'The directory where the datasets can be found.')
flags.DEFINE_integer(
'num_readers', 4,
'The number of parallel readers that read data from the dataset.')
flags.DEFINE_integer('num_preprocessing_threads', 4,
'The number of threads used to create the batches.')
# HParams
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values')
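
# Example hparams override (a sketch: `batch_size` and `arch` are hparams
# consumed in this file; other valid names depend on `create_hparams`):
#   --hparams="batch_size=32,arch=dcgan"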


def run_eval(run_dir, checkpoint_dir, hparams):
  """Runs the eval loop.

  Args:
    run_dir: The directory where eval-specific logs are placed.
    checkpoint_dir: The directory where the checkpoints are stored.
    hparams: The hyperparameters struct.

  Raises:
    ValueError: if hparams.arch is not recognized.
  """
for checkpoint_path in slim.evaluation.checkpoints_iterator(
checkpoint_dir, FLAGS.eval_interval_secs):
with tf.Graph().as_default():
#########################
# Preprocess the inputs #
#########################
target_dataset = dataset_factory.get_dataset(
FLAGS.target_dataset,
split_name=FLAGS.target_split_name,
dataset_dir=FLAGS.dataset_dir)
target_images, target_labels = dataset_factory.provide_batch(
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
num_target_classes = target_dataset.num_classes
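      # `provide_batch` returns labels one-hot under the 'classes' key; convert
      # them to sparse class indices, which the streaming metrics below expect.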
target_labels['class'] = tf.argmax(target_labels['classes'], 1)
del target_labels['classes']
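      # The 'dcgan' architecture is unconditional: it consumes no source images
      # or labels, so the source dataset is only loaded for the other archs.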
if hparams.arch not in ['dcgan']:
source_dataset = dataset_factory.get_dataset(
FLAGS.source_dataset,
split_name=FLAGS.source_split_name,
dataset_dir=FLAGS.dataset_dir)
num_source_classes = source_dataset.num_classes
source_images, source_labels = dataset_factory.provide_batch(
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir,
FLAGS.num_readers, hparams.batch_size,
FLAGS.num_preprocessing_threads)
source_labels['class'] = tf.argmax(source_labels['classes'], 1)
del source_labels['classes']
        if num_source_classes != num_target_classes:
          raise ValueError(
              'Source and target datasets must have the same number of '
              'classes.')
else:
source_images = None
source_labels = None
####################
# Define the model #
####################
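      # `create_model` returns a dictionary of end points: named tensors (e.g.
      # domain/task logits and transferred images) consumed by the metrics and
      # summaries below.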
end_points = pixelda_model.create_model(
hparams,
target_images,
source_images=source_images,
source_labels=source_labels,
is_training=False,
num_classes=num_target_classes)
#######################
# Metrics & Summaries #
#######################
names_to_values, names_to_updates = create_metrics(end_points,
source_labels,
target_labels, hparams)
pixelda_utils.summarize_model(end_points)
pixelda_utils.summarize_transferred_grid(
end_points['transferred_images'], source_images, name='Transferred')
if 'source_images_recon' in end_points:
pixelda_utils.summarize_transferred_grid(
end_points['source_images_recon'],
source_images,
name='Source Reconstruction')
pixelda_utils.summarize_images(target_images, 'Target')
      for name, value in names_to_values.items():
        tf.summary.scalar(name, value)
      # Use the entire split by default.
      num_examples = target_dataset.num_samples
      # Evaluate enough batches to cover every example at least once; the
      # count must be an integer for `num_evals`.
      num_batches = int(math.ceil(num_examples / float(hparams.batch_size)))

      # Ensure a global step tensor exists in the graph before evaluating.
      slim.get_or_create_global_step()

      slim.evaluation.evaluate_once(
          master=FLAGS.master,
          checkpoint_path=checkpoint_path,
          logdir=run_dir,
          num_evals=num_batches,
          eval_op=list(names_to_updates.values()),
          final_op=names_to_values)


def to_degrees(log_quaternion_loss):
  """Converts a log quaternion distance to an angle.

  Args:
    log_quaternion_loss: The log quaternion distance between two
      unit quaternions (or a batch of pairs of quaternions).

  Returns:
    The angle in degrees of the implied angle-axis representation.
  """
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi


def create_metrics(end_points, source_labels, target_labels, hparams):
  """Create metrics for the model.

  Args:
    end_points: A dictionary mapping end point names to tensors.
    source_labels: Labels for source images, batch_size x 1.
    target_labels: Labels for target images, batch_size x 1.
    hparams: The hyperparameters struct.

  Returns:
    Tuple of (names_to_values, names_to_updates), dictionaries that map a
    metric name to its value and update op, respectively.
  """
###########################################
# Evaluate the Domain Prediction Accuracy #
###########################################
batch_size = hparams.batch_size
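  # By convention here, the domain discriminator should output 1 for real
  # target images and 0 for generated (transferred) images, so domain accuracy
  # compares the rounded sigmoid of the logits against those constants.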
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
('eval/Domain_Accuracy-Transferred'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points[
'transferred_domain_logits']))),
tf.zeros(batch_size, dtype=tf.int32)),
('eval/Domain_Accuracy-Target'):
tf.contrib.metrics.streaming_accuracy(
tf.to_int32(
tf.round(tf.sigmoid(end_points['target_domain_logits']))),
tf.ones(batch_size, dtype=tf.int32))
})
################################
# Evaluate the task classifier #
################################
if 'source_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['source_task_logits'], 1),
source_labels['class'])
if 'transferred_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['transferred_task_logits'], 1),
source_labels['class'])
if 'target_task_logits' in end_points:
metric_name = 'eval/Task_Accuracy-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = tf.contrib.metrics.streaming_accuracy(
tf.argmax(end_points['target_task_logits'], 1),
target_labels['class'])
##########################################################################
# Pose data-specific losses.
##########################################################################
  if source_labels is not None and 'quaternion' in source_labels:
    params = {'use_logging': False, 'batch_size': batch_size}
angle_loss_source = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'source_quaternion'], source_labels['quaternion'], params))
angle_loss_transferred = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'transferred_quaternion'], source_labels['quaternion'], params))
angle_loss_target = to_degrees(
pixelda_losses.log_quaternion_loss_batch(end_points[
'target_quaternion'], target_labels['quaternion'], params))
metric_name = 'eval/Angle_Loss-Source'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_source)
metric_name = 'eval/Angle_Loss-Transferred'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_transferred)
metric_name = 'eval/Angle_Loss-Target'
names_to_values[metric_name], names_to_updates[
metric_name] = slim.metrics.mean(angle_loss_target)
return names_to_values, names_to_updates


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  hparams = create_hparams(FLAGS.hparams)
  run_eval(
      run_dir=FLAGS.eval_dir,
      checkpoint_dir=FLAGS.checkpoint_dir,
      hparams=hparams)


if __name__ == '__main__':
  tf.app.run()