# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library with common functions for training and eval."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow as tf
from tensorflow.contrib.slim.nets import resnet_v2


def default_hparams():
  """Returns default hyperparameters."""
  return tf.contrib.training.HParams(
      # Batch size for training and evaluation.
      batch_size=32,
      eval_batch_size=50,
      # General training parameters.
      weight_decay=0.0001,
      label_smoothing=0.1,
      # Parameters of the adversarial training.
      train_adv_method='clean',  # Adversarial training method.
      train_lp_weight=0.0,  # Weight of the adversarial logit pairing loss.
      # Parameters of the optimizer.
      optimizer='rms',  # Possible values are: 'rms', 'momentum', 'adam'.
      momentum=0.9,  # Momentum.
      rmsprop_decay=0.9,  # Decay term for RMSProp.
      rmsprop_epsilon=1.0,  # Epsilon term for RMSProp.
      # Parameters of the learning rate schedule.
      lr_schedule='exp_decay',  # Possible values: 'exp_decay', 'step', 'fixed'.
      learning_rate=0.045,
      lr_decay_factor=0.94,  # Factor of the exponential learning rate decay.
      lr_num_epochs_per_decay=2.0,  # Number of epochs per learning rate decay.
      lr_list=[1.0 / 6, 2.0 / 6, 3.0 / 6,
               4.0 / 6, 5.0 / 6, 1.0, 0.1, 0.01,
               0.001, 0.0001],
      lr_decay_epochs=[1, 2, 3, 4, 5, 30, 60, 80, 90])
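
# A minimal usage sketch (the override string below is illustrative, not a
# recommended configuration): individual hyperparameters can be overridden
# from a comma-separated string via HParams.parse, e.g. when wiring
# command-line flags into a training script.
#
#   hparams = default_hparams()
#   hparams.parse('batch_size=64,optimizer=adam,lr_schedule=fixed')
#   assert hparams.batch_size == 64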


def get_lr_schedule(hparams, examples_per_epoch, replicas_to_aggregate=1):
  """Returns a TensorFlow op which computes the learning rate.

  Args:
    hparams: hyperparameters.
    examples_per_epoch: number of training examples per epoch.
    replicas_to_aggregate: number of training replicas running in parallel.

  Raises:
    ValueError: if the learning rate schedule specified in hparams is invalid.

  Returns:
    learning_rate: tensor with the learning rate.
    steps_per_epoch: number of training steps per epoch.
  """
  global_step = tf.train.get_or_create_global_step()
  steps_per_epoch = float(examples_per_epoch) / float(hparams.batch_size)
  if replicas_to_aggregate > 0:
    steps_per_epoch /= replicas_to_aggregate
  if hparams.lr_schedule == 'exp_decay':
    # int() rather than Python 2's long() keeps this compatible with Python 3.
    decay_steps = int(steps_per_epoch * hparams.lr_num_epochs_per_decay)
    learning_rate = tf.train.exponential_decay(
        hparams.learning_rate,
        global_step,
        decay_steps,
        hparams.lr_decay_factor,
        staircase=True)
  elif hparams.lr_schedule == 'step':
    lr_decay_steps = [int(epoch * steps_per_epoch)
                      for epoch in hparams.lr_decay_epochs]
    learning_rate = tf.train.piecewise_constant(
        global_step, lr_decay_steps, hparams.lr_list)
  elif hparams.lr_schedule == 'fixed':
    learning_rate = hparams.learning_rate
  else:
    raise ValueError('Invalid value of lr_schedule: %s' % hparams.lr_schedule)
  if replicas_to_aggregate > 0:
    learning_rate *= replicas_to_aggregate
  return learning_rate, steps_per_epoch
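
# Usage sketch (the example count below assumes the ImageNet training set of
# 1,281,167 images; adjust for your dataset):
#
#   hparams = default_hparams()
#   learning_rate, steps_per_epoch = get_lr_schedule(
#       hparams, examples_per_epoch=1281167, replicas_to_aggregate=8)
#
# With the default 'exp_decay' schedule this returns a learning-rate tensor
# that decays by lr_decay_factor every lr_num_epochs_per_decay epochs and is
# scaled up by the number of replicas.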


def get_optimizer(hparams, learning_rate):
  """Returns the optimizer.

  Args:
    hparams: hyperparameters.
    learning_rate: learning rate tensor.

  Raises:
    ValueError: if the type of optimizer specified in hparams is invalid.

  Returns:
    Instance of the optimizer class.
  """
  if hparams.optimizer == 'rms':
    optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                          hparams.rmsprop_decay,
                                          hparams.momentum,
                                          hparams.rmsprop_epsilon)
  elif hparams.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           hparams.momentum)
  elif hparams.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(learning_rate)
  else:
    raise ValueError('Invalid value of optimizer: %s' % hparams.optimizer)
  return optimizer
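
# Usage sketch: combine the schedule and the optimizer to build a train op.
# `total_loss` is a placeholder name for the model's scalar loss and is not
# defined in this file.
#
#   learning_rate, _ = get_lr_schedule(hparams, examples_per_epoch)
#   optimizer = get_optimizer(hparams, learning_rate)
#   train_op = optimizer.minimize(
#       total_loss, global_step=tf.train.get_or_create_global_step())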
RESNET_MODELS = {'resnet_v2_50': resnet_v2.resnet_v2_50}


def get_model(model_name, num_classes):
  """Returns a function which creates the model.

  Args:
    model_name: Name of the model.
    num_classes: Number of classes.

  Raises:
    ValueError: If model_name is invalid.

  Returns:
    Function which creates the model when called.
  """
  if model_name.startswith('resnet'):
    def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
      with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
        resnet_fn = RESNET_MODELS[model_name]
        logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                              reuse=reuse)
        logits = tf.reshape(logits, [-1, num_classes])
      return logits
    return resnet_model
  else:
    raise ValueError('Invalid model: %s' % model_name)
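
# Usage sketch: build the model function once and call it for both the
# training and evaluation graphs. `train_images` / `eval_images` stand for
# image tensors of shape [batch, height, width, 3] and are not defined here.
#
#   model_fn = get_model('resnet_v2_50', num_classes=1001)
#   train_logits = model_fn(train_images, is_training=True)
#   eval_logits = model_fn(eval_images, is_training=False)
#
# Because resnet_model defaults to reuse=tf.AUTO_REUSE, the second call
# shares variables with the first instead of creating a new set.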


def filter_trainable_variables(trainable_scopes):
  """Keeps only trainable variables which are prefixed with the given scopes.

  This function removes from the collection of trainable variables all
  variables which are not prefixed with one of the given trainable_scopes.
  It is useful during network fine-tuning, when you only need to train a
  subset of variables.

  Args:
    trainable_scopes: either a list of trainable scopes or a string with a
      comma-separated list of trainable scopes.
  """
  if not trainable_scopes:
    return
  if isinstance(trainable_scopes, six.string_types):
    trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
  trainable_scopes = {scope for scope in trainable_scopes if scope}
  if not trainable_scopes:
    return
  trainable_collection = tf.get_collection_ref(
      tf.GraphKeys.TRAINABLE_VARIABLES)
  non_trainable_vars = [
      v for v in trainable_collection
      if not any([v.op.name.startswith(s) for s in trainable_scopes])
  ]
  for v in non_trainable_vars:
    trainable_collection.remove(v)
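
# Usage sketch: fine-tune only the last ResNet block and the logits layer.
# The scope names below follow the usual tf.contrib.slim resnet_v2 naming and
# are given for illustration; verify them against your checkpoint's variable
# names before relying on them.
#
#   filter_trainable_variables('resnet_v2_50/block4,resnet_v2_50/logits')
#   # After this call, optimizer.minimize(...) only updates variables whose
#   # op names start with one of these prefixes.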