# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library with common functions for training and eval."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from tensorflow.contrib.slim.nets import resnet_v2


def default_hparams():
  """Returns default hyperparameters."""
  return tf.contrib.training.HParams(
      # Batch size for training and evaluation.
      batch_size=32,
      eval_batch_size=50,
      # General training parameters.
      weight_decay=0.0001,
      label_smoothing=0.1,
      # Parameters of the adversarial training.
      train_adv_method='clean',  # Adversarial training method.
      train_lp_weight=0.0,  # Weight of the adversarial logit pairing loss.
      # Parameters of the optimizer.
      optimizer='rms',  # Possible values are: 'rms', 'momentum', 'adam'.
      momentum=0.9,  # Momentum.
      rmsprop_decay=0.9,  # Decay term for RMSProp.
      rmsprop_epsilon=1.0,  # Epsilon term for RMSProp.
      # Parameters of the learning rate schedule.
      lr_schedule='exp_decay',  # Possible values: 'exp_decay', 'step', 'fixed'.
      learning_rate=0.045,
      lr_decay_factor=0.94,  # Learning rate exponential decay factor.
      lr_num_epochs_per_decay=2.0,  # Number of epochs per learning rate decay.
      lr_list=[1.0 / 6, 2.0 / 6, 3.0 / 6,
               4.0 / 6, 5.0 / 6, 1.0, 0.1, 0.01,
               0.001, 0.0001],
      lr_decay_epochs=[1, 2, 3, 4, 5, 30, 60, 80, 90])
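

# A minimal usage sketch (assumes TF 1.x with tf.contrib available). The
# returned HParams object can be overridden from a comma-separated
# 'name=value' string, e.g. a flag passed to a training binary:
#
#   hparams = default_hparams()
#   hparams.parse('batch_size=64,optimizer=adam,lr_schedule=fixed')
#   print(hparams.batch_size)  # -> 64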


def get_lr_schedule(hparams, examples_per_epoch, replicas_to_aggregate=1):
  """Returns a TensorFlow op which computes the learning rate.

  Args:
    hparams: hyperparameters.
    examples_per_epoch: number of training examples per epoch.
    replicas_to_aggregate: number of training replicas running in parallel.

  Raises:
    ValueError: if the learning rate schedule specified in hparams is invalid.

  Returns:
    learning_rate: tensor with the learning rate.
    steps_per_epoch: number of training steps per epoch.
  """
  global_step = tf.train.get_or_create_global_step()
  steps_per_epoch = float(examples_per_epoch) / float(hparams.batch_size)
  if replicas_to_aggregate > 0:
    steps_per_epoch /= replicas_to_aggregate

  if hparams.lr_schedule == 'exp_decay':
    decay_steps = int(steps_per_epoch * hparams.lr_num_epochs_per_decay)
    learning_rate = tf.train.exponential_decay(
        hparams.learning_rate,
        global_step,
        decay_steps,
        hparams.lr_decay_factor,
        staircase=True)
  elif hparams.lr_schedule == 'step':
    lr_decay_steps = [int(epoch * steps_per_epoch)
                      for epoch in hparams.lr_decay_epochs]
    learning_rate = tf.train.piecewise_constant(
        global_step, lr_decay_steps, hparams.lr_list)
  elif hparams.lr_schedule == 'fixed':
    learning_rate = hparams.learning_rate
  else:
    raise ValueError('Invalid value of lr_schedule: %s' % hparams.lr_schedule)

  if replicas_to_aggregate > 0:
    learning_rate *= replicas_to_aggregate

  return learning_rate, steps_per_epoch
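

# A minimal usage sketch (assumes a TF 1.x graph; the dataset size below is
# an assumption, roughly the size of the ImageNet training set):
#
#   hparams = default_hparams()
#   learning_rate, steps_per_epoch = get_lr_schedule(
#       hparams, examples_per_epoch=1281167)
#   tf.summary.scalar('learning_rate', learning_rate)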


def get_optimizer(hparams, learning_rate):
  """Returns the optimizer.

  Args:
    hparams: hyperparameters.
    learning_rate: learning rate tensor.

  Raises:
    ValueError: if the optimizer specified in hparams is invalid.

  Returns:
    Instance of the optimizer class.
  """
  if hparams.optimizer == 'rms':
    optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                          hparams.rmsprop_decay,
                                          hparams.momentum,
                                          hparams.rmsprop_epsilon)
  elif hparams.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           hparams.momentum)
  elif hparams.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(learning_rate)
  else:
    raise ValueError('Invalid value of optimizer: %s' % hparams.optimizer)
  return optimizer
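

# A minimal usage sketch continuing the learning rate example above
# (assumes TF 1.x; `total_loss` is a hypothetical scalar loss tensor built
# elsewhere in the training graph):
#
#   optimizer = get_optimizer(hparams, learning_rate)
#   train_op = optimizer.minimize(
#       total_loss, global_step=tf.train.get_or_create_global_step())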


RESNET_MODELS = {'resnet_v2_50': resnet_v2.resnet_v2_50}


def get_model(model_name, num_classes):
  """Returns a function which creates the model.

  Args:
    model_name: name of the model.
    num_classes: number of classes.

  Raises:
    ValueError: if model_name is invalid.

  Returns:
    Function which creates the model when called.
  """
  if model_name.startswith('resnet'):
    def resnet_model(images, is_training, reuse=tf.AUTO_REUSE):
      with tf.contrib.framework.arg_scope(resnet_v2.resnet_arg_scope()):
        resnet_fn = RESNET_MODELS[model_name]
        logits, _ = resnet_fn(images, num_classes, is_training=is_training,
                              reuse=reuse)
        logits = tf.reshape(logits, [-1, num_classes])
      return logits
    return resnet_model
  else:
    raise ValueError('Invalid model: %s' % model_name)
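

# A minimal usage sketch (assumes a TF 1.x graph; `images` is a hypothetical
# batch of inputs, e.g. a [batch, 224, 224, 3] float tensor from an input
# pipeline):
#
#   model_fn = get_model('resnet_v2_50', num_classes=1001)
#   logits = model_fn(images, is_training=True)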


def filter_trainable_variables(trainable_scopes):
  """Keeps only trainable variables which are prefixed with the given scopes.

  This function removes from the collection of trainable variables all
  variables which are not prefixed with one of the given trainable_scopes.
  It is useful during network fine tuning, when you only need to train a
  subset of the variables.

  Args:
    trainable_scopes: either a list of trainable scopes or a string with a
      comma-separated list of trainable scopes.
  """
  if not trainable_scopes:
    return
  if isinstance(trainable_scopes, six.string_types):
    trainable_scopes = [scope.strip() for scope in trainable_scopes.split(',')]
  trainable_scopes = {scope for scope in trainable_scopes if scope}
  if not trainable_scopes:
    return
  trainable_collection = tf.get_collection_ref(
      tf.GraphKeys.TRAINABLE_VARIABLES)
  non_trainable_vars = [
      v for v in trainable_collection
      if not any([v.op.name.startswith(s) for s in trainable_scopes])
  ]
  for v in non_trainable_vars:
    trainable_collection.remove(v)
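

# A minimal usage sketch (assumes the model has already been built in the
# default graph; the scope names below are assumptions and depend on the
# model being fine tuned):
#
#   filter_trainable_variables('resnet_v2_50/logits,resnet_v2_50/postnorm')
#   # After this call, optimizer.minimize(...) only updates variables whose
#   # names start with one of the given scopes.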