# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_factory."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from absl.testing import parameterized
from official.vision.image_classification import optimizer_factory
from official.vision.image_classification.configs import base_configs


class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
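
  # Each tuple below is (test_case_name, optimizer_name,
  # moving_average_decay, lookahead), mirroring test_optimizer's signature.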
@parameterized.named_parameters(
('sgd', 'sgd', 0., False),
('momentum', 'momentum', 0., False),
('rmsprop', 'rmsprop', 0., False),
('adam', 'adam', 0., False),
('adamw', 'adamw', 0., False),
('momentum_lookahead', 'momentum', 0., True),
('sgd_ema', 'sgd', 0.999, False),
('momentum_ema', 'momentum', 0.999, False),
('rmsprop_ema', 'rmsprop', 0.999, False))
  def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
    """Smoke test: each config should yield a valid Keras optimizer."""
params = {
'learning_rate': 0.001,
'rho': 0.09,
'momentum': 0.,
'epsilon': 1e-07,
'moving_average_decay': moving_average_decay,
'lookahead': lookahead,
}
optimizer = optimizer_factory.build_optimizer(
optimizer_name=optimizer_name,
base_learning_rate=params['learning_rate'],
params=params)
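    # The EMA and lookahead cases presumably wrap the base optimizer, but the
    # result should still be a tf.keras.optimizers.Optimizer.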
    self.assertIsInstance(optimizer, tf.keras.optimizers.Optimizer)

  def test_unknown_optimizer(self):
    """An unrecognized optimizer name should raise a ValueError."""
with self.assertRaises(ValueError):
optimizer_factory.build_optimizer(
optimizer_name='this_optimizer_does_not_exist',
base_learning_rate=None,
params=None)

  def test_learning_rate_without_decay_or_warmups(self):
    """Builds a schedule with decay and warmup both disabled."""
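    # decay_epochs=None and warmup_epochs=None disable both features;
    # boundaries/multipliers should only matter for a piecewise schedule.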
params = base_configs.LearningRateConfig(
name='exponential',
initial_lr=0.01,
decay_rate=0.01,
decay_epochs=None,
warmup_epochs=None,
scale_by_batch_size=0.01,
examples_per_epoch=1,
boundaries=[0],
multipliers=[0, 1])
batch_size = 1
train_steps = 1
lr = optimizer_factory.build_learning_rate(
params=params,
batch_size=batch_size,
train_steps=train_steps)
    self.assertIsInstance(
        lr, tf.keras.optimizers.schedules.LearningRateSchedule)

  @parameterized.named_parameters(
      ('exponential', 'exponential'),
      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
      ('cosine_with_warmup', 'cosine_with_warmup'))
  def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
    """Smoke test: each decay type should yield a valid schedule."""
params = base_configs.LearningRateConfig(
name=lr_decay_type,
initial_lr=0.01,
decay_rate=0.01,
decay_epochs=1,
warmup_epochs=1,
scale_by_batch_size=0.01,
examples_per_epoch=1,
boundaries=[0],
multipliers=[0, 1])
batch_size = 1
train_epochs = 1
train_steps = 1
lr = optimizer_factory.build_learning_rate(
params=params,
batch_size=batch_size,
train_epochs=train_epochs,
train_steps=train_steps)
    self.assertIsInstance(
        lr, tf.keras.optimizers.schedules.LearningRateSchedule)
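
# A typical pairing of the two factories under test (a sketch only, not part
# of the original tests; lr_config and hyperparams are hypothetical, and it
# assumes build_optimizer also accepts a LearningRateSchedule as
# base_learning_rate):
#
#   lr = optimizer_factory.build_learning_rate(
#       params=lr_config, batch_size=32, train_epochs=90, train_steps=1000)
#   optimizer = optimizer_factory.build_optimizer(
#       optimizer_name='momentum', base_learning_rate=lr, params=hyperparams)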


if __name__ == '__main__':
tf.test.main()