Spaces:
Running
Running
# Copyright 2017 Google Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
r"""Evaluates the PIXELDA model. | |
-- Compiles the model for CPU. | |
$ bazel build -c opt third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval | |
-- Compile the model for GPU. | |
$ bazel build -c opt --copt=-mavx --config=cuda \ | |
third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation:pixelda_eval | |
-- Runs the training. | |
$ ./bazel-bin/third_party/tensorflow_models/domain_adaptation/pixel_domain_adaptation/pixelda_eval \ | |
--source_dataset=mnist \ | |
--target_dataset=mnist_m \ | |
--dataset_dir=/tmp/datasets/ \ | |
--alsologtostderr | |
-- Visualize the results. | |
$ bash learning/brain/tensorboard/tensorboard.sh \ | |
--port 2222 --logdir=/tmp/pixelda/ | |
""" | |
from functools import partial | |
import math | |
# Dependency imports | |
import tensorflow as tf | |
from domain_adaptation.datasets import dataset_factory | |
from domain_adaptation.pixel_domain_adaptation import pixelda_model | |
from domain_adaptation.pixel_domain_adaptation import pixelda_preprocess | |
from domain_adaptation.pixel_domain_adaptation import pixelda_utils | |
from domain_adaptation.pixel_domain_adaptation import pixelda_losses | |
from domain_adaptation.pixel_domain_adaptation.hparams import create_hparams | |
slim = tf.contrib.slim | |
flags = tf.app.flags | |
FLAGS = flags.FLAGS | |
flags.DEFINE_string('master', '', 'BNS name of the TensorFlow master to use.') | |
flags.DEFINE_string('checkpoint_dir', '/tmp/pixelda/', | |
'Directory where the model was written to.') | |
flags.DEFINE_string('eval_dir', '/tmp/pixelda/', | |
'Directory where the results are saved to.') | |
flags.DEFINE_integer('eval_interval_secs', 60, | |
'The frequency, in seconds, with which evaluation is run.') | |
flags.DEFINE_string('target_split_name', 'test', | |
'The name of the train/test split.') | |
flags.DEFINE_string('source_split_name', 'train', 'Split for source dataset.' | |
' Defaults to train.') | |
flags.DEFINE_string('source_dataset', 'mnist', | |
'The name of the source dataset.') | |
flags.DEFINE_string('target_dataset', 'mnist_m', | |
'The name of the target dataset.') | |
flags.DEFINE_string( | |
'dataset_dir', | |
'', # None, | |
'The directory where the datasets can be found.') | |
flags.DEFINE_integer( | |
'num_readers', 4, | |
'The number of parallel readers that read data from the dataset.') | |
flags.DEFINE_integer('num_preprocessing_threads', 4, | |
'The number of threads used to create the batches.') | |
# HParams | |
flags.DEFINE_string('hparams', '', 'Comma separated hyperparameter values') | |
def run_eval(run_dir, checkpoint_dir, hparams): | |
"""Runs the eval loop. | |
Args: | |
run_dir: The directory where eval specific logs are placed | |
checkpoint_dir: The directory where the checkpoints are stored | |
hparams: The hyperparameters struct. | |
Raises: | |
ValueError: if hparams.arch is not recognized. | |
""" | |
for checkpoint_path in slim.evaluation.checkpoints_iterator( | |
checkpoint_dir, FLAGS.eval_interval_secs): | |
with tf.Graph().as_default(): | |
######################### | |
# Preprocess the inputs # | |
######################### | |
target_dataset = dataset_factory.get_dataset( | |
FLAGS.target_dataset, | |
split_name=FLAGS.target_split_name, | |
dataset_dir=FLAGS.dataset_dir) | |
target_images, target_labels = dataset_factory.provide_batch( | |
FLAGS.target_dataset, FLAGS.target_split_name, FLAGS.dataset_dir, | |
FLAGS.num_readers, hparams.batch_size, | |
FLAGS.num_preprocessing_threads) | |
num_target_classes = target_dataset.num_classes | |
target_labels['class'] = tf.argmax(target_labels['classes'], 1) | |
del target_labels['classes'] | |
if hparams.arch not in ['dcgan']: | |
source_dataset = dataset_factory.get_dataset( | |
FLAGS.source_dataset, | |
split_name=FLAGS.source_split_name, | |
dataset_dir=FLAGS.dataset_dir) | |
num_source_classes = source_dataset.num_classes | |
source_images, source_labels = dataset_factory.provide_batch( | |
FLAGS.source_dataset, FLAGS.source_split_name, FLAGS.dataset_dir, | |
FLAGS.num_readers, hparams.batch_size, | |
FLAGS.num_preprocessing_threads) | |
source_labels['class'] = tf.argmax(source_labels['classes'], 1) | |
del source_labels['classes'] | |
if num_source_classes != num_target_classes: | |
raise ValueError( | |
'Input and output datasets must have same number of classes') | |
else: | |
source_images = None | |
source_labels = None | |
#################### | |
# Define the model # | |
#################### | |
end_points = pixelda_model.create_model( | |
hparams, | |
target_images, | |
source_images=source_images, | |
source_labels=source_labels, | |
is_training=False, | |
num_classes=num_target_classes) | |
####################### | |
# Metrics & Summaries # | |
####################### | |
names_to_values, names_to_updates = create_metrics(end_points, | |
source_labels, | |
target_labels, hparams) | |
pixelda_utils.summarize_model(end_points) | |
pixelda_utils.summarize_transferred_grid( | |
end_points['transferred_images'], source_images, name='Transferred') | |
if 'source_images_recon' in end_points: | |
pixelda_utils.summarize_transferred_grid( | |
end_points['source_images_recon'], | |
source_images, | |
name='Source Reconstruction') | |
pixelda_utils.summarize_images(target_images, 'Target') | |
for name, value in names_to_values.iteritems(): | |
tf.summary.scalar(name, value) | |
# Use the entire split by default | |
num_examples = target_dataset.num_samples | |
num_batches = math.ceil(num_examples / float(hparams.batch_size)) | |
global_step = slim.get_or_create_global_step() | |
result = slim.evaluation.evaluate_once( | |
master=FLAGS.master, | |
checkpoint_path=checkpoint_path, | |
logdir=run_dir, | |
num_evals=num_batches, | |
eval_op=names_to_updates.values(), | |
final_op=names_to_values) | |
def to_degrees(log_quaternion_loss): | |
"""Converts a log quaternion distance to an angle. | |
Args: | |
log_quaternion_loss: The log quaternion distance between two | |
unit quaternions (or a batch of pairs of quaternions). | |
Returns: | |
The angle in degrees of the implied angle-axis representation. | |
""" | |
return tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi | |
def create_metrics(end_points, source_labels, target_labels, hparams): | |
"""Create metrics for the model. | |
Args: | |
end_points: A dictionary of end point name to tensor | |
source_labels: Labels for source images. batch_size x 1 | |
target_labels: Labels for target images. batch_size x 1 | |
hparams: The hyperparameters struct. | |
Returns: | |
Tuple of (names_to_values, names_to_updates), dictionaries that map a metric | |
name to its value and update op, respectively | |
""" | |
########################################### | |
# Evaluate the Domain Prediction Accuracy # | |
########################################### | |
batch_size = hparams.batch_size | |
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ | |
('eval/Domain_Accuracy-Transferred'): | |
tf.contrib.metrics.streaming_accuracy( | |
tf.to_int32( | |
tf.round(tf.sigmoid(end_points[ | |
'transferred_domain_logits']))), | |
tf.zeros(batch_size, dtype=tf.int32)), | |
('eval/Domain_Accuracy-Target'): | |
tf.contrib.metrics.streaming_accuracy( | |
tf.to_int32( | |
tf.round(tf.sigmoid(end_points['target_domain_logits']))), | |
tf.ones(batch_size, dtype=tf.int32)) | |
}) | |
################################ | |
# Evaluate the task classifier # | |
################################ | |
if 'source_task_logits' in end_points: | |
metric_name = 'eval/Task_Accuracy-Source' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = tf.contrib.metrics.streaming_accuracy( | |
tf.argmax(end_points['source_task_logits'], 1), | |
source_labels['class']) | |
if 'transferred_task_logits' in end_points: | |
metric_name = 'eval/Task_Accuracy-Transferred' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = tf.contrib.metrics.streaming_accuracy( | |
tf.argmax(end_points['transferred_task_logits'], 1), | |
source_labels['class']) | |
if 'target_task_logits' in end_points: | |
metric_name = 'eval/Task_Accuracy-Target' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = tf.contrib.metrics.streaming_accuracy( | |
tf.argmax(end_points['target_task_logits'], 1), | |
target_labels['class']) | |
########################################################################## | |
# Pose data-specific losses. | |
########################################################################## | |
if 'quaternion' in source_labels.keys(): | |
params = {} | |
params['use_logging'] = False | |
params['batch_size'] = batch_size | |
angle_loss_source = to_degrees( | |
pixelda_losses.log_quaternion_loss_batch(end_points[ | |
'source_quaternion'], source_labels['quaternion'], params)) | |
angle_loss_transferred = to_degrees( | |
pixelda_losses.log_quaternion_loss_batch(end_points[ | |
'transferred_quaternion'], source_labels['quaternion'], params)) | |
angle_loss_target = to_degrees( | |
pixelda_losses.log_quaternion_loss_batch(end_points[ | |
'target_quaternion'], target_labels['quaternion'], params)) | |
metric_name = 'eval/Angle_Loss-Source' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = slim.metrics.mean(angle_loss_source) | |
metric_name = 'eval/Angle_Loss-Transferred' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = slim.metrics.mean(angle_loss_transferred) | |
metric_name = 'eval/Angle_Loss-Target' | |
names_to_values[metric_name], names_to_updates[ | |
metric_name] = slim.metrics.mean(angle_loss_target) | |
return names_to_values, names_to_updates | |
def main(_): | |
tf.logging.set_verbosity(tf.logging.INFO) | |
hparams = create_hparams(FLAGS.hparams) | |
run_eval( | |
run_dir=FLAGS.eval_dir, | |
checkpoint_dir=FLAGS.checkpoint_dir, | |
hparams=hparams) | |
if __name__ == '__main__': | |
tf.app.run() | |