# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Runs a ResNet model on the ImageNet dataset using custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from official.modeling import performance
from official.staging.training import grad_utils
from official.staging.training import standard_runnable
from official.staging.training import utils
from official.utils.flags import core as flags_core
from official.vision.image_classification.resnet import common
from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


class ResnetRunnable(standard_runnable.StandardTrainable,
                     standard_runnable.StandardEvaluable):
  """Implements the training and evaluation APIs for the ResNet model."""

  def __init__(self, flags_obj, time_callback, epoch_steps):
    standard_runnable.StandardTrainable.__init__(self,
                                                 flags_obj.use_tf_while_loop,
                                                 flags_obj.use_tf_function)
    standard_runnable.StandardEvaluable.__init__(self,
                                                 flags_obj.use_tf_function)

    self.strategy = tf.distribute.get_strategy()
    self.flags_obj = flags_obj
    self.dtype = flags_core.get_tf_dtype(flags_obj)
    self.time_callback = time_callback

    # Input pipeline related.
    batch_size = flags_obj.batch_size
    if batch_size % self.strategy.num_replicas_in_sync != 0:
      raise ValueError(
          'Batch size must be divisible by the number of replicas: {}'.format(
              self.strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in the
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning the dataset to multiple workers in eager mode,
    # we use the per-replica batch size.
    self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)

    if self.flags_obj.use_synthetic_data:
      self.input_fn = common.get_synth_input_fn(
          height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
          num_channels=imagenet_preprocessing.NUM_CHANNELS,
          num_classes=imagenet_preprocessing.NUM_CLASSES,
          dtype=self.dtype,
          drop_remainder=True)
    else:
      self.input_fn = imagenet_preprocessing.input_fn

    self.model = resnet_model.resnet50(
        num_classes=imagenet_preprocessing.NUM_CLASSES,
        use_l2_regularizer=not flags_obj.single_l2_loss_op)

    # `common.LR_SCHEDULE` is a list of (multiplier, epoch) pairs: the first
    # entry supplies the warmup length, and the remaining epochs are the
    # decay boundaries.
    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=common.LR_SCHEDULE[0][1],
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
    self.optimizer = common.get_optimizer(lr_schedule)
    # Make sure the iterations variable is created inside the strategy scope.
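    # What the schedule above computes, roughly (an illustrative sketch; the
    # actual implementation is `PiecewiseConstantDecayWithWarmup` in
    # `common.py`, which also rescales the base rate by batch_size / 256):
    #
    #   if step < warmup_steps:
    #     lr = rescaled_lr * step / warmup_steps   # linear warmup
    #   else:
    #     lr = rescaled_lr * multiplier(step)      # piecewise-constant decay
    #
    # `optimizer.iterations` doubles as the global step: the checkpoint below
    # tracks the optimizer, and with it this counter, so restoring a
    # checkpoint also restores the step count that `EpochHelper` consumes.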
    self.global_step = self.optimizer.iterations

    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
    if use_graph_rewrite and not flags_obj.use_tf_function:
      raise ValueError('--fp16_implementation=graph_rewrite requires '
                       '--use_tf_function to be true')
    self.optimizer = performance.configure_optimizer(
        self.optimizer,
        use_float16=self.dtype == tf.float16,
        use_graph_rewrite=use_graph_rewrite,
        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

    self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'train_accuracy', dtype=tf.float32)
    self.test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
    self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'test_accuracy', dtype=tf.float32)

    self.checkpoint = tf.train.Checkpoint(
        model=self.model, optimizer=self.optimizer)

    # Handling epochs.
    self.epoch_steps = epoch_steps
    self.epoch_helper = utils.EpochHelper(epoch_steps, self.global_step)

  def build_train_dataset(self):
    """See base class."""
    return utils.make_distributed_dataset(
        self.strategy,
        self.input_fn,
        is_training=True,
        data_dir=self.flags_obj.data_dir,
        batch_size=self.batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        datasets_num_private_threads=self.flags_obj
        .datasets_num_private_threads,
        dtype=self.dtype,
        drop_remainder=True)

  def build_eval_dataset(self):
    """See base class."""
    return utils.make_distributed_dataset(
        self.strategy,
        self.input_fn,
        is_training=False,
        data_dir=self.flags_obj.data_dir,
        batch_size=self.batch_size,
        parse_record_fn=imagenet_preprocessing.parse_record,
        dtype=self.dtype)

  def train_loop_begin(self):
    """See base class."""
    # Reset all metrics.
    self.train_loss.reset_states()
    self.train_accuracy.reset_states()

    self._epoch_begin()
    self.time_callback.on_batch_begin(self.epoch_helper.batch_index)

  def train_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      """Function to run on the device."""
      images, labels = inputs
      with tf.GradientTape() as tape:
        logits = self.model(images, training=True)

        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits)
        # Normalize by the *global* batch size: each replica sums the loss
        # over its local shard, so summing gradients across replicas via
        # all-reduce yields the gradient of the global mean loss.
        loss = tf.reduce_sum(prediction_loss) * (1.0 /
                                                 self.flags_obj.batch_size)
        num_replicas = self.strategy.num_replicas_in_sync
        l2_weight_decay = 1e-4

        if self.flags_obj.single_l2_loss_op:
          # `tf.nn.l2_loss` computes sum(v ** 2) / 2, hence the factor of 2.
          # Batch-norm variables (matched by 'bn' in the name) are excluded
          # from weight decay.
          l2_loss = l2_weight_decay * 2 * tf.add_n([
              tf.nn.l2_loss(v)
              for v in self.model.trainable_variables
              if 'bn' not in v.name
          ])

          loss += (l2_loss / num_replicas)
        else:
          loss += (tf.reduce_sum(self.model.losses) / num_replicas)

      grad_utils.minimize_using_explicit_allreduce(
          tape, self.optimizer, loss, self.model.trainable_variables)
      self.train_loss.update_state(loss)
      self.train_accuracy.update_state(labels, logits)

    self.strategy.run(step_fn, args=(next(iterator),))

  def train_loop_end(self):
    """See base class."""
    metrics = {
        'train_loss': self.train_loss.result(),
        'train_accuracy': self.train_accuracy.result(),
    }
    self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
    self._epoch_end()
    return metrics

  def eval_begin(self):
    """See base class."""
    self.test_loss.reset_states()
    self.test_accuracy.reset_states()

  def eval_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      """Function to run on the device."""
      images, labels = inputs
      logits = self.model(images, training=False)
      loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits)
      loss = tf.reduce_sum(loss) * (1.0 / self.flags_obj.batch_size)
      self.test_loss.update_state(loss)
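      # As in `train_step`, the loss sum above is normalized by the global
      # batch size rather than the per-replica one, so per-replica updates to
      # `test_loss` combine into the global mean; Keras metric variables are
      # SUM-aggregated across replicas when `result()` is read.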
      self.test_accuracy.update_state(labels, logits)

    self.strategy.run(step_fn, args=(next(iterator),))

  def eval_end(self):
    """See base class."""
    return {
        'test_loss': self.test_loss.result(),
        'test_accuracy': self.test_accuracy.result()
    }

  def _epoch_begin(self):
    if self.epoch_helper.epoch_begin():
      self.time_callback.on_epoch_begin(self.epoch_helper.current_epoch)

  def _epoch_end(self):
    if self.epoch_helper.epoch_end():
      self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
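
# Example wiring (a sketch only: the canonical driver is
# `resnet_ctl_imagenet_main.py`, and the `Controller` keyword arguments below
# are illustrative assumptions rather than its exact signature):
#
#   from official.staging.training import controller
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#     runnable = ResnetRunnable(flags_obj, time_callback, epoch_steps)
#   train_controller = controller.Controller(
#       strategy=strategy,
#       train_fn=runnable.train,
#       eval_fn=runnable.evaluate,
#       global_step=runnable.global_step,
#       train_steps=epoch_steps * flags_obj.train_epochs)
#   train_controller.train(evaluate=True)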