NCTCMumbai's picture
Upload 2571 files
0b8359d
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the various loss functions in use by the PIXELDA model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
slim = tf.contrib.slim
def add_domain_classifier_losses(end_points, hparams):
  """Builds the domain-classifier (discriminator) loss.

  The domain classifier is trained to label target images as real (1) and
  transferred images as fake (0). The two sigmoid cross-entropy terms are
  summed and scaled by `hparams.domain_loss_weight`. This loss is meant to be
  minimized with respect to the domain classifier's variables only.

  Args:
    end_points: A map of network end point names to `Tensors`; the entries
      'transferred_domain_logits' and 'target_domain_logits' are read.
    hparams: The hyperparameters struct; `domain_loss_weight` is read.

  Returns:
    A `Tensor` holding the weighted total domain-classifier loss, or the
    Python integer 0 when `hparams.domain_loss_weight` is 0 (in which case no
    loss ops are created).
  """
  weight = hparams.domain_loss_weight
  if weight == 0:
    tf.logging.info(
        'Domain classifier loss weight is 0, so not creating losses.')
    return 0

  # Label convention: 1 = 'real image', 0 = 'fake image'.
  transferred_logits = end_points['transferred_domain_logits']
  transferred_loss = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=tf.zeros_like(transferred_logits),
      logits=transferred_logits)
  tf.summary.scalar('Domain_loss_transferred', transferred_loss)

  target_logits = end_points['target_domain_logits']
  target_loss = tf.losses.sigmoid_cross_entropy(
      multi_class_labels=tf.ones_like(target_logits),
      logits=target_logits)
  tf.summary.scalar('Domain_loss_target', target_loss)

  weighted_total = (transferred_loss + target_loss) * weight
  tf.summary.scalar('Domain_loss_total', weighted_total)
  return weighted_total
def log_quaternion_loss_batch(predictions, labels, params):
  """Computes a per-example log error between two batches of quaternions.

  For every row the error is log(1e-4 + 1 - |<prediction, label>|). Because
  the absolute value of the dot product is used, a label and its negation
  give the same error.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.
      Only 'use_logging' is read here. When it is true, unit-norm assertions
      on both inputs and a tf.Print of the dot products are added.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']

  def _unit_norm_check(quaternions, message):
    # Asserts every row's squared L2 norm is within 1e-4 of 1.
    squared_norms = tf.reduce_sum(tf.square(quaternions), [1])
    return tf.Assert(
        tf.reduce_all(tf.less(tf.abs(squared_norms - 1), 1e-4)), [message])

  checks = []
  if use_logging:
    checks = [
        _unit_norm_check(
            predictions,
            'The l2 norm of each prediction quaternion vector should be 1.'),
        _unit_norm_check(
            labels,
            'The l2 norm of each label quaternion vector should be 1.'),
    ]

  # The element-wise product waits on the (optional) norm checks.
  with tf.control_dependencies(checks):
    elementwise = tf.multiply(predictions, labels)
  dot_products = tf.reduce_sum(elementwise, [1])

  if use_logging:
    dot_products = tf.Print(
        dot_products, [dot_products, tf.shape(dot_products)],
        'internal_dot_products:')

  # The 1e-4 offset keeps the log finite when |dot| reaches 1.
  return tf.log(1e-4 + 1 - tf.abs(dot_products))
def log_quaternion_loss(predictions, labels, params):
  """Averages the per-example log quaternion error over a batch.

  The per-example errors from `log_quaternion_loss_batch` are summed and
  divided by `params['batch_size']` (the configured batch size, not the
  dynamic size of the batch dimension). The caller is expected to add the
  loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A scalar Tensor, denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  per_example_cost = log_quaternion_loss_batch(predictions, labels, params)
  summed_cost = tf.reduce_sum(per_example_cost, [0])
  mean_cost = tf.multiply(
      summed_cost, 1.0 / params['batch_size'], name='log_quaternion_loss')
  if not use_logging:
    return mean_cost
  return tf.Print(
      mean_cost, [mean_cost], '[logcost]', name='log_quaternion_loss_print')
def _quaternion_loss(labels, predictions, weight, batch_size, domain,
                     add_summaries):
  """Creates a weighted log quaternion loss.

  Args:
    labels: The true quaternions.
    predictions: The predicted quaternions.
    weight: A scalar weight.
    batch_size: The size of the batches.
    domain: The name of the domain from which the labels were taken; must be
      'Source' or 'Transferred'.
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A `Tensor` representing the loss.
  """
  assert domain in ['Source', 'Transferred']
  params = {'use_logging': False, 'batch_size': batch_size}
  # Pass arguments by keyword so each gets the role `log_quaternion_loss`
  # declares. The metric is symmetric in its two inputs, so the value is
  # unchanged, but the roles no longer depend on that coincidence.
  loss = weight * log_quaternion_loss(
      predictions=predictions, labels=labels, params=params)
  if add_summaries:
    # The finiteness check only gates the summary ops below; `loss` itself
    # does not depend on it, so it runs only when those summaries run.
    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      # NOTE(review): 'losses' is also the name of tf.GraphKeys.LOSSES, so
      # these summaries go into that collection rather than the default
      # summaries collection -- confirm this is intended.
      tf.summary.histogram(
          'Log_Quaternion_Loss_%s' % domain, loss, collections='losses')
      tf.summary.scalar(
          'Task_Quaternion_Loss_%s' % domain, loss, collections='losses')
  return loss
def _add_task_specific_losses(end_points, source_labels, num_classes, hparams,
add_summaries=False):
"""Adds losses related to the task-classifier.
Args:
end_points: A map of network end point names to `Tensors`.
source_labels: A dictionary of output labels to `Tensors`.
num_classes: The number of classes used by the classifier.
hparams: The hyperparameters struct.
add_summaries: Whether or not to add the summaries.
Returns:
loss: A `Tensor` representing the total task-classifier loss.
"""
# TODO(ddohan): Make sure the l2 regularization is added to the loss
one_hot_labels = slim.one_hot_encoding(source_labels['class'], num_classes)
total_loss = 0
if 'source_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['source_task_logits'],
weights=hparams.source_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Source', loss)
total_loss += loss
if 'transferred_task_logits' in end_points:
loss = tf.losses.softmax_cross_entropy(
onehot_labels=one_hot_labels,
logits=end_points['transferred_task_logits'],
weights=hparams.transferred_task_loss_weight)
if add_summaries:
tf.summary.scalar('Task_Classifier_Loss_Transferred', loss)
total_loss += loss
#########################
# Pose specific losses. #
#########################
if 'quaternion' in source_labels:
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['source_quaternion'],
hparams.source_pose_weight,
hparams.batch_size,
'Source',
add_summaries)
total_loss += _quaternion_loss(
source_labels['quaternion'],
end_points['transferred_quaternion'],
hparams.transferred_pose_weight,
hparams.batch_size,
'Transferred',
add_summaries)
if add_summaries:
tf.summary.scalar('Task_Loss_Total', total_loss)
return total_loss
def _transferred_similarity_loss(reconstructions,
source_images,
weight=1.0,
method='mse',
max_diff=0.4,
name='similarity'):
"""Computes a loss encouraging similarity between source and transferred.
Args:
reconstructions: A `Tensor` of shape [batch_size, height, width, channels]
source_images: A `Tensor` of shape [batch_size, height, width, channels].
weight: Multiple similarity loss by this weight before returning
method: One of:
mpse = Mean Pairwise Squared Error
mse = Mean Squared Error
hinged_mse = Computes the mean squared error using squared differences
greater than hparams.transferred_similarity_max_diff
hinged_mae = Computes the mean absolute error using absolute
differences greater than hparams.transferred_similarity_max_diff.
max_diff: Maximum unpenalized difference for hinged losses
name: Identifying name to use for creating summaries
Returns:
A `Tensor` representing the transferred similarity loss.
Raises:
ValueError: if `method` is not recognized.
"""
if weight == 0:
return 0
source_channels = source_images.shape.as_list()[-1]
reconstruction_channels = reconstructions.shape.as_list()[-1]
# Convert grayscale source to RGB if target is RGB
if source_channels == 1 and reconstruction_channels != 1:
source_images = tf.tile(source_images, [1, 1, 1, reconstruction_channels])
if reconstruction_channels == 1 and source_channels != 1:
reconstructions = tf.tile(reconstructions, [1, 1, 1, source_channels])
if method == 'mpse':
reconstruction_similarity_loss_fn = (
tf.contrib.losses.mean_pairwise_squared_error)
elif method == 'masked_mpse':
def masked_mpse(predictions, labels, weight):
"""Masked mpse assuming we have a depth to create a mask from."""
assert labels.shape.as_list()[-1] == 4
mask = tf.to_float(tf.less(labels[:, :, :, 3:4], 0.99))
mask = tf.tile(mask, [1, 1, 1, 4])
predictions *= mask
labels *= mask
tf.image_summary('masked_pred', predictions)
tf.image_summary('masked_label', labels)
return tf.contrib.losses.mean_pairwise_squared_error(
predictions, labels, weight)
reconstruction_similarity_loss_fn = masked_mpse
elif method == 'mse':
reconstruction_similarity_loss_fn = tf.contrib.losses.mean_squared_error
elif method == 'hinged_mse':
def hinged_mse(predictions, labels, weight):
diffs = tf.square(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mse
elif method == 'hinged_mae':
def hinged_mae(predictions, labels, weight):
diffs = tf.abs(predictions - labels)
diffs = tf.maximum(0.0, diffs - max_diff)
return tf.reduce_mean(diffs) * weight
reconstruction_similarity_loss_fn = hinged_mae
else:
raise ValueError('Unknown reconstruction loss %s' % method)
reconstruction_similarity_loss = reconstruction_similarity_loss_fn(
reconstructions, source_images, weight)
name = '%s_Similarity_(%s)' % (name, method)
tf.summary.scalar(name, reconstruction_similarity_loss)
return reconstruction_similarity_loss
def g_step_loss(source_images, source_labels, end_points, hparams, num_classes):
  """Builds the loss minimized during the generator (g) step.

  The loss is the sum of:
    * an adversarial style-transfer term that rewards the domain classifier
      labelling transferred images as real,
    * a similarity term between transferred and source images, and
    * the task losses scaled by `hparams.task_loss_in_g_weight`, only when
      `source_labels` is not None and `hparams.task_tower_in_g_step` is set.

  Args:
    source_images: A `Tensor` of shape [batch_size, height, width, channels].
    source_labels: A dictionary of label `Tensors`, or None. Valid keys are
      'class' (shape [batch_size]) and 'quaternion' (shape [batch_size, 4]).
    end_points: A map of the network end points.
    hparams: The hyperparameters struct.
    num_classes: Number of classes for classifier loss.

  Returns:
    A `Tensor` representing a loss function.

  Raises:
    ValueError: if hparams.transferred_similarity_loss_weight is non-zero but
      hparams.transferred_similarity_loss is invalid.
  """
  generator_loss = 0

  # Adversarial term. As in the GAN paper, instead of minimizing
  # log(1 - probs) we maximize log(probs); since this is a minimization, that
  # means minimizing -log(probs), i.e. cross-entropy against all-ones labels.
  fake_logits = end_points['transferred_domain_logits']
  adversarial_loss = tf.losses.sigmoid_cross_entropy(
      logits=fake_logits,
      multi_class_labels=tf.ones_like(fake_logits),
      weights=hparams.style_transfer_loss_weight)
  tf.summary.scalar('Style_transfer_loss', adversarial_loss)
  generator_loss += adversarial_loss

  # Keep the transferred images close to the source images they came from.
  generator_loss += _transferred_similarity_loss(
      end_points['transferred_images'],
      source_images,
      weight=hparams.transferred_similarity_loss_weight,
      method=hparams.transferred_similarity_loss,
      name='transferred_similarity')

  # Optionally let the task loss drive the generator toward images the task
  # classifier can classify correctly.
  use_task_loss = (source_labels is not None and
                   hparams.task_tower_in_g_step)
  if use_task_loss:
    task_loss = _add_task_specific_losses(end_points, source_labels,
                                          num_classes, hparams)
    generator_loss += task_loss * hparams.task_loss_in_g_weight

  return generator_loss
def d_step_loss(end_points, source_labels, num_classes, hparams):
  """Builds the loss minimized during the discriminator (D) step.

  The D-step trains both the binary domain classifier and, when labels are
  available, the task classifier (with summaries enabled).

  Args:
    end_points: A map of the network end points.
    source_labels: A dictionary of output labels to `Tensors`, or None to
      skip the task-classifier loss.
    num_classes: The number of classes used by the classifier.
    hparams: The hyperparameters struct.

  Returns:
    A `Tensor` representing the value of the D-step loss.
  """
  domain_loss = add_domain_classifier_losses(end_points, hparams)
  if source_labels is None:
    task_loss = 0
  else:
    task_loss = _add_task_specific_losses(
        end_points, source_labels, num_classes, hparams, add_summaries=True)
  return domain_loss + task_loss