Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Helper functions used for training AutoAugment models.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import tensorflow as tf | |
def setup_loss(logits, labels):
  """Builds softmax predictions and the cross-entropy loss for `logits`.

  Args:
    logits: Unnormalized class scores produced by the model.
    labels: Labels passed to `tf.losses.softmax_cross_entropy` as
      `onehot_labels`.

  Returns:
    A `(predictions, cost)` tuple: `predictions` is `tf.nn.softmax(logits)`
    and `cost` is the softmax cross-entropy loss.
  """
  predictions = tf.nn.softmax(logits)
  return predictions, tf.losses.softmax_cross_entropy(
      onehot_labels=labels, logits=logits)
def decay_weights(cost, weight_decay_rate):
  """Adds an L2 weight-decay penalty over all trainable variables to `cost`.

  The penalty is `weight_decay_rate` times the sum of `tf.nn.l2_loss(var)`
  over every variable returned by `tf.trainable_variables()`.

  Args:
    cost: The loss to which the penalty is added.
    weight_decay_rate: Scalar multiplier applied to the summed L2 losses.

  Returns:
    `cost` plus the weight-decay penalty, or `cost` unchanged when there are
    no trainable variables.
  """
  l2_losses = [tf.nn.l2_loss(var) for var in tf.trainable_variables()]
  if not l2_losses:
    # tf.add_n rejects an empty list; with nothing to decay the penalty is 0.
    return cost
  return cost + tf.multiply(weight_decay_rate, tf.add_n(l2_losses))
def eval_child_model(session, model, data_loader, mode):
  """Feeds a held-out split through `model.eval_op` and returns its accuracy.

  Args:
    session: TensorFlow session used to execute the model's ops.
    model: TensorFlow model to evaluate. It must expose `images`, `labels`,
      `eval_op`, `accuracy` and `batch_size`.
    data_loader: DataSet object holding `val_images`/`val_labels` and
      `test_images`/`test_labels`.
    mode: Which split to evaluate, either 'val' or 'test'.

  Returns:
    The value of `model.accuracy` after every batch of the chosen split has
    been run through `model.eval_op`.

  Raises:
    ValueError: if `mode` is neither 'val' nor 'test'.
  """
  if mode not in ('val', 'test'):
    raise ValueError('Not valid eval mode')
  if mode == 'val':
    images, labels = data_loader.val_images, data_loader.val_labels
  else:
    images, labels = data_loader.test_images, data_loader.test_labels

  assert len(images) == len(labels)
  batch_size = model.batch_size
  tf.logging.info('model.batch_size is {}'.format(batch_size))
  # The split must divide evenly into batches; no partial batch is fed.
  assert len(images) % batch_size == 0
  num_batches = int(len(images) / batch_size)

  for start in range(0, num_batches * batch_size, batch_size):
    stop = start + batch_size
    session.run(
        model.eval_op,
        feed_dict={
            model.images: images[start:stop],
            model.labels: labels[start:stop],
        })
  return session.run(model.accuracy)
def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
  """Cosine-decay learning-rate schedule.

  Computes `0.5 * learning_rate * (1 + cos(pi * t_cur / t_total))` where
  `t_cur = epoch * batches_per_epoch + iteration` and
  `t_total = total_epochs * batches_per_epoch`. The rate therefore equals
  `learning_rate` when `t_cur == 0` and falls to 0 when `t_cur == t_total`.

  Args:
    learning_rate: Initial (peak) learning rate.
    epoch: Number of epochs already completed, i.e. the current epoch counted
      from zero. Epoch 0 with iteration 0 yields `learning_rate`.
    iteration: Current batch within this epoch.
    batches_per_epoch: Batches per epoch.
    total_epochs: Total epochs you are training for.

  Returns:
    The learning rate to be used for the current batch.
  """
  t_total = total_epochs * batches_per_epoch
  t_cur = float(epoch * batches_per_epoch + iteration)
  return 0.5 * learning_rate * (1 + np.cos(np.pi * t_cur / t_total))
def get_lr(curr_epoch, hparams, iteration=None):
  """Returns the cosine-scheduled learning rate for one training step.

  Args:
    curr_epoch: Epoch index forwarded to `cosine_lr` as `epoch`.
    hparams: Hyperparameters providing `train_size`, `batch_size`, `lr` and
      `num_epochs`.
    iteration: Batch index within the current epoch. Despite the `None`
      default it is required; omitting it fails an assertion.

  Returns:
    The learning rate computed by `cosine_lr` for this step.
  """
  assert iteration is not None
  steps_in_epoch = int(hparams.train_size / hparams.batch_size)
  return cosine_lr(
      hparams.lr,
      curr_epoch,
      iteration,
      steps_in_epoch,
      hparams.num_epochs)
def run_epoch_training(session, model, data_loader, curr_epoch):
  """Runs one epoch of training for the model passed in.

  Before every batch the learning rate is recomputed with `get_lr` and
  loaded into `model.lr_rate_ph`. `model.eval_op` runs alongside
  `model.train_op`, and `model.accuracy` is read once at the end.

  Args:
    session: TensorFlow session the model will be run with.
    model: TensorFlow model to train. It must expose `hparams`,
      `global_step`, `lr_rate_ph`, `train_op`, `eval_op`, `accuracy`,
      `images` and `labels`.
    data_loader: DataSet object whose `next_batch()` returns an
      `(images, labels)` training batch.
    curr_epoch: How many epochs of training have been done so far.

  Returns:
    The accuracy of `model` on the training set, read from `model.accuracy`
    after the epoch.
  """
  hparams = model.hparams
  steps_per_epoch = int(hparams.train_size / hparams.batch_size)
  tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))
  curr_step = session.run(model.global_step)
  # The global step must sit on an epoch boundary when an epoch starts.
  assert curr_step % steps_per_epoch == 0

  # Learning rate at the very start of this epoch, logged for reference.
  curr_lr = get_lr(curr_epoch, hparams, iteration=0)
  tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))

  # `range` works on both Python 2 and 3; `xrange` does not exist on 3.
  for step in range(steps_per_epoch):
    curr_lr = get_lr(curr_epoch, hparams, iteration=(step + 1))
    # Update the lr rate variable to the current LR.
    model.lr_rate_ph.load(curr_lr, session=session)
    if step % 20 == 0:
      tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))

    train_images, train_labels = data_loader.next_batch()
    # Only the ops' side effects matter; the fetched global step is unused
    # and is deliberately not bound to `step`, so the loop variable is not
    # shadowed.
    session.run(
        [model.train_op, model.global_step, model.eval_op],
        feed_dict={
            model.images: train_images,
            model.labels: train_labels,
        })

  train_accuracy = session.run(model.accuracy)
  tf.logging.info('Train accuracy: {}'.format(train_accuracy))
  return train_accuracy