# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
# Dependency imports | |
import tensorflow as tf | |
slim = tf.contrib.slim | |
def add_domain_classifier_losses(end_points, hparams):
  """Adds losses related to the domain-classifier.

  Args:
    end_points: A map of network end point names to `Tensors`.
    hparams: The hyperparameters struct.

  Returns:
    loss: A `Tensor` representing the total domain-classifier loss.
  """
  if hparams.domain_loss_weight == 0:
    tf.logging.info(
        'Domain classifier loss weight is 0, so not creating losses.')
    return 0

  # The domain prediction loss is minimized with respect to the domain
  # classifier features only. Its aim is to predict the domain of the images.
  # Note: 1 = 'real image' label, 0 = 'fake image' label
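  # With these labels, minimizing the two sigmoid cross-entropy terms below is
  # equivalent to minimizing
  #   -log(1 - sigmoid(D(x_transferred))) - log(sigmoid(D(x_target))),
  # i.e. the standard discriminator objective, where D is the domain
  # classifier producing the domain logits.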
  transferred_domain_loss = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=tf.zeros_like(end_points['transferred_domain_logits']),
      logits=end_points['transferred_domain_logits'])
  tf.summary.scalar('Domain_loss_transferred', transferred_domain_loss)

  target_domain_loss = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=tf.ones_like(end_points['target_domain_logits']),
      logits=end_points['target_domain_logits'])
  tf.summary.scalar('Domain_loss_target', target_domain_loss)

  # Compute the total domain loss:
  total_domain_loss = transferred_domain_loss + target_domain_loss
  total_domain_loss *= hparams.domain_loss_weight
  tf.summary.scalar('Domain_loss_total', total_domain_loss)

  return total_domain_loss


def log_quaternion_loss_batch(predictions, labels, params):
  """A helper function to compute the error between quaternions.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']
  assertions = []
  if use_logging:
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
                    1e-4)),
            ['The l2 norm of each prediction quaternion vector should be 1.']))
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
            ['The l2 norm of each label quaternion vector should be 1.']))

  with tf.control_dependencies(assertions):
    product = tf.multiply(predictions, labels)
  internal_dot_products = tf.reduce_sum(product, [1])

  if use_logging:
    internal_dot_products = tf.Print(internal_dot_products, [
        internal_dot_products,
        tf.shape(internal_dot_products)
    ], 'internal_dot_products:')
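
  # A quaternion q and its negation -q describe the same rotation, so the
  # distance uses the absolute dot product. For unit quaternions the value
  # below is log(1e-4) when the rotations match and grows toward log(1 + 1e-4)
  # as they diverge; the 1e-4 term keeps the log finite.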
  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return logcost


def log_quaternion_loss(predictions, labels, params):
  """A helper function to compute the mean error between batches of quaternions.

  The caller is expected to add the loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size 1, denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  logcost = log_quaternion_loss_batch(predictions, labels, params)
  logcost = tf.reduce_sum(logcost, [0])
  batch_size = params['batch_size']
  logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
  if use_logging:
    logcost = tf.Print(
        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
  return logcost


def _quaternion_loss(labels, predictions, weight, batch_size, domain,
                     add_summaries):
  """Creates a Quaternion Loss.

  Args:
    labels: The true quaternions.
    predictions: The predicted quaternions.
    weight: A scalar weight.
    batch_size: The size of the batches.
    domain: The name of the domain from which the labels were taken.
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A `Tensor` representing the loss.
  """
  assert domain in ['Source', 'Transferred']
  params = {'use_logging': False, 'batch_size': batch_size}
  loss = weight * log_quaternion_loss(labels, predictions, params)

  if add_summaries:
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      # `collections` expects a list of collection names.
      tf.summary.histogram(
          'Log_Quaternion_Loss_%s' % domain, loss, collections=['losses'])
      tf.summary.scalar(
          'Task_Quaternion_Loss_%s' % domain, loss, collections=['losses'])

  return loss


def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
                              add_summaries=False):
  """Adds losses related to the task-classifier.

  Args:
    end_points: A map of network end point names to `Tensors`.
    source_labels: A dictionary of output labels to `Tensors`.
    num_classes: The number of classes used by the classifier.
    hparams: The hyperparameters struct.
    add_summaries: Whether or not to add the summaries.

  Returns:
    loss: A `Tensor` representing the total task-classifier loss.
  """
  # TODO(ddohan): Make sure the l2 regularization is added to the loss
  one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
  total_loss = 0

  if 'source_task_logits' in end_points:
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=one_hot_labels,
        logits=end_points['source_task_logits'],
        weights=hparams.source_task_loss_weight)
    if add_summaries:
      tf.summary.scalar('Task_Classifier_Loss_Source', loss)
    total_loss += loss
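
  # The transferred images are the source images mapped into the target style,
  # so they retain their source annotations; the same one-hot labels are
  # reused for the transferred-task logits below.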
  if 'transferred_task_logits' in end_points:
    loss = tf.losses.softmax_cross_entropy(
        onehot_labels=one_hot_labels,
        logits=end_points['transferred_task_logits'],
        weights=hparams.transferred_task_loss_weight)
    if add_summaries:
      tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
    total_loss += loss

  #########################
  # Pose specific losses. #
  #########################
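  # When quaternion pose labels are available, both the source and the
  # transferred pose predictions are penalized against the same source
  # quaternions, each with its own weight from hparams.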
  if 'quaternion' in source_labels:
    total_loss += _quaternion_loss(
        source_labels['quaternion'],
        end_points['source_quaternion'],
        hparams.source_pose_weight,
        hparams.batch_size,
        'Source',
        add_summaries)

    total_loss += _quaternion_loss(
        source_labels['quaternion'],
        end_points['transferred_quaternion'],
        hparams.transferred_pose_weight,
        hparams.batch_size,
        'Transferred',
        add_summaries)

  if add_summaries:
    tf.summary.scalar('Task_Loss_Total', total_loss)

  return total_loss


def _transferred_similarity_loss(reconstructions,
                                 source_images,
                                 weight=1.0,
                                 method='mse',
                                 max_diff=0.4,
                                 name='similarity'):
  """Computes a loss encouraging similarity between source and transferred.

  Args:
    reconstructions: A `Tensor` of shape [batch_size, height, width, channels].
    source_images: A `Tensor` of shape [batch_size, height, width, channels].
    weight: Multiply the similarity loss by this weight before returning.
    method: One of:
      mpse = Mean Pairwise Squared Error
      masked_mpse = MPSE computed only where the fourth channel of the source
        images is below 0.99 (assumes an input with a depth channel)
      mse = Mean Squared Error
      hinged_mse = Computes the mean squared error using squared differences
        greater than `max_diff`
      hinged_mae = Computes the mean absolute error using absolute
        differences greater than `max_diff`.
    max_diff: Maximum unpenalized difference for hinged losses.
    name: Identifying name to use for creating summaries.

  Returns:
    A `Tensor` representing the transferred similarity loss.

  Raises:
    ValueError: if `method` is not recognized.
  """
  if weight == 0:
    return 0
  source_channels = source_images.shape.as_list()[-1]
  reconstruction_channels = reconstructions.shape.as_list()[-1]

  # If one tensor is grayscale and the other is multi-channel, tile the
  # grayscale tensor so the channel counts match.
  if source_channels == 1 and reconstruction_channels != 1:
    source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
  if reconstruction_channels == 1 and source_channels != 1:
    reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])

  if method == 'mpse':
    reconstruction_similarity_loss_fn = (
        tf.contrib.losses.mean_pairwise_squared_error)
  elif method == 'masked_mpse':

    def masked_mpse(predictions, labels, weight):
      """Masked mpse assuming we have a depth to create a mask from."""
      assert labels.shape.as_list()[-1] == 4
      mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
      mask = tf.tile(mask, [1, 1, 1, 4])
      predictions *= mask
      labels *= mask
      tf.summary.image('masked_pred', predictions)
      tf.summary.image('masked_label', labels)
      return tf.contrib.losses.mean_pairwise_squared_error(
          predictions, labels, weight)

    reconstruction_similarity_loss_fn = masked_mpse
  elif method == 'mse':
    reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
  elif method == 'hinged_mse':

    def hinged_mse(predictions, labels, weight):
      diffs = tf.square(predictions - labels)
      diffs = tf.maximum(0.0, diffs - max_diff)
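      # Squared differences smaller than max_diff are clipped to zero, so
      # small per-pixel changes go unpenalized; only the excess over max_diff
      # contributes to the mean.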
      return tf.reduce_mean(diffs) * weight

    reconstruction_similarity_loss_fn = hinged_mse
  elif method == 'hinged_mae':

    def hinged_mae(predictions, labels, weight):
      diffs = tf.abs(predictions - labels)
      diffs = tf.maximum(0.0, diffs - max_diff)
      return tf.reduce_mean(diffs) * weight

    reconstruction_similarity_loss_fn = hinged_mae
  else:
    raise ValueError('Unknown reconstruction loss %s' % method)

  reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
      reconstructions, source_images, weight)

  name = '%s_Similarity_(%s)' % (name, method)
  tf.summary.scalar(name, reconstruction_similarity_loss)

  return reconstruction_similarity_loss


def g_step_loss(source_images, source_labels, end_points, hparams,
                num_classes):
  """Configures the loss function which runs during the g-step.

  Args:
    source_images: A `Tensor` of shape [batch_size, height, width, channels].
    source_labels: A dictionary of `Tensors` of shape [batch_size]. Valid keys
      are 'class' and 'quaternion'.
    end_points: A map of the network end points.
    hparams: The hyperparameters struct.
    num_classes: Number of classes for the classifier loss.

  Returns:
    A `Tensor` representing a loss function.

  Raises:
    ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
      hparams.transferred_similarity_loss is invalid.
  """
  generator_loss = 0

  #####################################################################
  # Adds a loss which encourages the discriminator's probability that #
  # the transferred images are real to be high (near one).            #
  #####################################################################

  # As per the GAN paper, maximize the log probs, instead of minimizing
  # log(1-probs). Since we're minimizing, we'll minimize -log(probs) which is
  # the same thing.
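  # Using sigmoid cross-entropy with an all-ones label against the transferred
  # domain logits yields exactly -log(sigmoid(D(G(x)))), i.e. the
  # non-saturating generator loss, scaled by style_transfer_loss_weight.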
  style_transfer_loss = tf.losses.sigmoid_cross_entropy(
      logits=end_points['transferred_domain_logits'],
      multi_class_labels=tf.ones_like(end_points['transferred_domain_logits']),
      weights=hparams.style_transfer_loss_weight)
  tf.summary.scalar('Style_transfer_loss', style_transfer_loss)
  generator_loss += style_transfer_loss

  # Optimizes the style transfer network to produce transferred images similar
  # to the source images.
  generator_loss += _transferred_similarity_loss(
      end_points['transferred_images'],
      source_images,
      weight=hparams.transferred_similarity_loss_weight,
      method=hparams.transferred_similarity_loss,
      name='transferred_similarity')

  # Optimizes the style transfer network to maximize classification accuracy.
  if source_labels is not None and hparams.task_tower_in_g_step:
    generator_loss += _add_task_specific_losses(
        end_points, source_labels, num_classes,
        hparams) * hparams.task_loss_in_g_weight

  return generator_loss


def d_step_loss(end_points, source_labels, num_classes, hparams):
  """Configures the losses during the D-Step.

  Note that during the D-step, the model optimizes both the domain (binary)
  classifier and the task classifier.

  Args:
    end_points: A map of the network end points.
    source_labels: A dictionary of output labels to `Tensors`.
    num_classes: The number of classes used by the classifier.
    hparams: The hyperparameters struct.

  Returns:
    A `Tensor` representing the value of the D-step loss.
  """
  domain_classifier_loss = add_domain_classifier_losses(end_points, hparams)

  task_classifier_loss = 0
  if source_labels is not None:
    task_classifier_loss = _add_task_specific_losses(
        end_points, source_labels, num_classes, hparams, add_summaries=True)

  return domain_classifier_loss + task_classifier_loss
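

# A minimal usage sketch (not part of this module): the two losses above are
# typically minimized in alternation with separate optimizers, the d-step
# updating the discriminator/task variables and the g-step updating the
# generator variables. The variable-scope names below are illustrative
# assumptions, not names defined in this file.
#
#   d_loss = d_step_loss(end_points, source_labels, num_classes, hparams)
#   g_loss = g_step_loss(source_images, source_labels, end_points, hparams,
#                        num_classes)
#   d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
#   g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
#   d_train_op = tf.train.AdamOptimizer(1e-4).minimize(d_loss, var_list=d_vars)
#   g_train_op = tf.train.AdamOptimizer(1e-4).minimize(g_loss, var_list=g_vars)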