# NCTCMumbai's picture
# Upload 2571 files
# 0b8359d
# raw
# history blame
# 8.57 kB
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.modeling import performance
from official.staging.training import grad_utils
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
class ResnetRunnable(standard_runnable.StandardTrainable,
                     standard_runnable.StandardEvaluable):
    """Implements the training and evaluation APIs for Resnet model.

    The constructor builds a ResNet-50 model, its optimizer with a
    piecewise-constant learning-rate schedule with warmup, the train/test
    metrics and a checkpoint object. The remaining methods implement the
    hooks of the `StandardTrainable` / `StandardEvaluable` base classes.
    """

    def __init__(self, flags_obj, time_callback, epoch_steps):
        """Creates the model, optimizer, metrics and input function.

        Args:
          flags_obj: Parsed flags object. The attributes read by this class
            are `use_tf_while_loop`, `use_tf_function`, `batch_size`,
            `use_synthetic_data`, `single_l2_loss_op`, `fp16_implementation`,
            `data_dir` and `datasets_num_private_threads`.
          time_callback: Object exposing `on_batch_begin`, `on_batch_end`,
            `on_epoch_begin` and `on_epoch_end`; it is notified from the
            train-loop hooks.
          epoch_steps: Forwarded to `utils.EpochHelper` (presumably the
            number of training steps per epoch -- confirm with the caller).

        Raises:
          ValueError: If `flags_obj.batch_size` is not divisible by the
            number of replicas in sync, or if
            `--fp16_implementation=graph_rewrite` is requested without
            `--use_tf_function`.
        """
        standard_runnable.StandardTrainable.__init__(
            self, flags_obj.use_tf_while_loop, flags_obj.use_tf_function)
        standard_runnable.StandardEvaluable.__init__(
            self, flags_obj.use_tf_function)

        self.strategy = tf.distribute.get_strategy()
        self.flags_obj = flags_obj
        self.dtype = flags_core.get_tf_dtype(flags_obj)
        self.time_callback = time_callback

        # Input pipeline related
        batch_size = flags_obj.batch_size
        num_replicas = self.strategy.num_replicas_in_sync
        if batch_size % num_replicas != 0:
            raise ValueError(
                'Batch size must be divisible by number of replicas : {}'
                .format(num_replicas))

        # As auto rebatching is not supported in
        # `experimental_distribute_datasets_from_function()` API, which is
        # required when cloning dataset to multiple workers in eager mode,
        # we use per-replica batch size.
        # Divisibility was verified above, so floor division is exact.
        self.batch_size = batch_size // num_replicas

        if self.flags_obj.use_synthetic_data:
            self.input_fn = common.get_synth_input_fn(
                height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
                num_channels=imagenet_preprocessing.NUM_CHANNELS,
                num_classes=imagenet_preprocessing.NUM_CLASSES,
                dtype=self.dtype,
                drop_remainder=True)
        else:
            self.input_fn = imagenet_preprocessing.input_fn

        # When a single fused L2 loss op is requested, the per-layer
        # regularizers are disabled and the L2 term is added in `train_step`.
        self.model = resnet_model.resnet50(
            num_classes=imagenet_preprocessing.NUM_CLASSES,
            use_l2_regularizer=not flags_obj.single_l2_loss_op)

        # The schedule is parameterized by the global (not per-replica)
        # batch size.
        lr_schedule = common.PiecewiseConstantDecayWithWarmup(
            batch_size=flags_obj.batch_size,
            epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
            warmup_epochs=common.LR_SCHEDULE[0][1],
            boundaries=[p[1] for p in common.LR_SCHEDULE[1:]],
            multipliers=[p[0] for p in common.LR_SCHEDULE],
            compute_lr_on_cpu=True)
        self.optimizer = common.get_optimizer(lr_schedule)
        # Make sure iterations variable is created inside scope.
        self.global_step = self.optimizer.iterations

        use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
        if use_graph_rewrite and not flags_obj.use_tf_function:
            raise ValueError('--fp16_implementation=graph_rewrite requires '
                             '--use_tf_function to be true')
        self.optimizer = performance.configure_optimizer(
            self.optimizer,
            use_float16=self.dtype == tf.float16,
            use_graph_rewrite=use_graph_rewrite,
            loss_scale=flags_core.get_loss_scale(
                flags_obj, default_for_fp16=128))

        self.train_loss = tf.keras.metrics.Mean(
            'train_loss', dtype=tf.float32)
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'train_accuracy', dtype=tf.float32)
        self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_accuracy', dtype=tf.float32)

        self.checkpoint = tf.train.Checkpoint(
            model=self.model, optimizer=self.optimizer)

        # Handling epochs.
        self.epoch_steps = epoch_steps
        self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)

    def build_train_dataset(self):
        """See base class.

        Returns:
          The result of `utils.make_distributed_dataset` for the training
          split, using the per-replica batch size and `drop_remainder=True`.
        """
        return utils.make_distributed_dataset(
            self.strategy,
            self.input_fn,
            is_training=True,
            data_dir=self.flags_obj.data_dir,
            batch_size=self.batch_size,
            parse_record_fn=imagenet_preprocessing.parse_record,
            datasets_num_private_threads=(
                self.flags_obj.datasets_num_private_threads),
            dtype=self.dtype,
            drop_remainder=True)

    def build_eval_dataset(self):
        """See base class.

        Returns:
          The result of `utils.make_distributed_dataset` for the evaluation
          split, using the per-replica batch size.
        """
        return utils.make_distributed_dataset(
            self.strategy,
            self.input_fn,
            is_training=False,
            data_dir=self.flags_obj.data_dir,
            batch_size=self.batch_size,
            parse_record_fn=imagenet_preprocessing.parse_record,
            dtype=self.dtype)

    def train_loop_begin(self):
        """See base class.

        Resets the training metrics and notifies the time callback that an
        epoch (if one is starting) and a batch are beginning.
        """
        # Reset all metrics
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()

        self._epoch_begin()
        self.time_callback.on_batch_begin(self.epoch_helper.batch_index)

    def train_step(self, iterator):
        """See base class.

        Pulls one element from `iterator` and runs a single optimization
        step on every replica of the current distribution strategy.

        Args:
          iterator: Iterator over the distributed training dataset; each
            element is unpacked as an `(images, labels)` pair.
        """

        def step_fn(inputs):
            """Function to run on the device."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                logits = self.model(images, training=True)

                prediction_loss = (
                    tf.keras.losses.sparse_categorical_crossentropy(
                        labels, logits))
                # Scale by the global batch size so that summing the
                # per-replica gradients yields the global-batch average.
                loss = tf.reduce_sum(prediction_loss) * (
                    1.0 / self.flags_obj.batch_size)
                num_replicas = self.strategy.num_replicas_in_sync
                l2_weight_decay = 1e-4
                if self.flags_obj.single_l2_loss_op:
                    # Variables whose name contains 'bn' are excluded
                    # (presumably batch-norm parameters -- confirm against
                    # resnet_model naming). `tf.nn.l2_loss` halves the sum
                    # of squares, hence the factor of 2.
                    l2_loss = l2_weight_decay * 2 * tf.add_n([
                        tf.nn.l2_loss(v)
                        for v in self.model.trainable_variables
                        if 'bn' not in v.name
                    ])
                    loss += (l2_loss / num_replicas)
                else:
                    loss += (tf.reduce_sum(self.model.losses) / num_replicas)

            grad_utils.minimize_using_explicit_allreduce(
                tape, self.optimizer, loss, self.model.trainable_variables)
            # NOTE(review): `loss` is this replica's share of the
            # global-batch loss; confirm that the cross-replica aggregation
            # of the Mean metric yields the value you expect to report.
            self.train_loss.update_state(loss)
            self.train_accuracy.update_state(labels, logits)

        self.strategy.run(step_fn, args=(next(iterator),))

    def train_loop_end(self):
        """See base class.

        Returns:
          A dict with the keys 'train_loss' and 'train_accuracy' holding the
          current results of the corresponding metrics.
        """
        metrics = {
            'train_loss': self.train_loss.result(),
            'train_accuracy': self.train_accuracy.result(),
        }
        # The metrics are read before the callbacks fire; the last finished
        # batch is reported as `batch_index - 1`.
        self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
        self._epoch_end()
        return metrics

    def eval_begin(self):
        """See base class.

        Resets the evaluation metrics.
        """
        self.test_loss.reset_states()
        self.test_accuracy.reset_states()

    def eval_step(self, iterator):
        """See base class.

        Pulls one element from `iterator` and updates the evaluation metrics
        on every replica of the current distribution strategy.

        Args:
          iterator: Iterator over the distributed evaluation dataset; each
            element is unpacked as an `(images, labels)` pair.
        """

        def step_fn(inputs):
            """Function to run on the device."""
            images, labels = inputs
            logits = self.model(images, training=False)
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits)
            # Same global-batch-size scaling as in training.
            loss = tf.reduce_sum(loss) * (1.0 / self.flags_obj.batch_size)
            self.test_loss.update_state(loss)
            self.test_accuracy.update_state(labels, logits)

        self.strategy.run(step_fn, args=(next(iterator),))

    def eval_end(self):
        """See base class.

        Returns:
          A dict with the keys 'test_loss' and 'test_accuracy' holding the
          current results of the corresponding metrics.
        """
        return {
            'test_loss': self.test_loss.result(),
            'test_accuracy': self.test_accuracy.result()
        }

    def _epoch_begin(self):
        """Notifies the time callback when the epoch helper starts an epoch."""
        if self.epoch_helper.epoch_begin():
            self.time_callback.on_epoch_begin(self.epoch_helper.current_epoch)

    def _epoch_end(self):
        """Notifies the time callback when the epoch helper ends an epoch."""
        if self.epoch_helper.epoch_end():
            self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)