# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to train the Attention OCR model. | |
A simple usage example: | |
python train.py | |
""" | |
import collections
import logging

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow import app
from tensorflow.python.platform import flags
from tensorflow.contrib.tfprof import model_analyzer

import data_provider
import common_flags

FLAGS = flags.FLAGS
common_flags.define()

# yapf: disable
flags.DEFINE_integer('task', 0,
                     'The Task ID. This value is used when training with '
                     'multiple workers to identify each worker.')

flags.DEFINE_integer('ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, '
                     'then the parameters are handled locally by the worker.')

flags.DEFINE_integer('save_summaries_secs', 60,
                     'The frequency with which summaries are saved, in '
                     'seconds.')

flags.DEFINE_integer('save_interval_secs', 600,
                     'Frequency in seconds of saving the model.')

flags.DEFINE_integer('max_number_of_steps', int(1e10),
                     'The maximum number of gradient steps.')

flags.DEFINE_string('checkpoint_inception', '',
                    'Checkpoint to recover inception weights from.')

flags.DEFINE_float('clip_gradient_norm', 2.0,
                   'If greater than 0, gradients are clipped to this norm.')

flags.DEFINE_bool('sync_replicas', False,
                  'If True, synchronize replicas during training.')

flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'The number of gradient updates to aggregate before '
                     'updating the parameters.')

flags.DEFINE_integer('total_num_replicas', 1,
                     'Total number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replica startups.')

flags.DEFINE_boolean('reset_train_dir', False,
                     'If True, delete all files in train_log_dir before '
                     'training.')

flags.DEFINE_boolean('show_graph_stats', False,
                     'Output model size stats to stderr.')
# yapf: enable
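
# Illustrative sketch only: one way the distributed flags above might be
# combined for two synchronously trained workers. The --master targets are
# placeholders (bringing up the parameter-server and worker processes is
# outside the scope of this script), and --master itself is defined by
# common_flags.define().
#
#   # Chief worker (task 0):
#   python train.py --task=0 --ps_tasks=1 --sync_replicas=true \
#     --replicas_to_aggregate=2 --total_num_replicas=2 \
#     --master=grpc://worker0.example:2222
#   # Second worker (task 1):
#   python train.py --task=1 --ps_tasks=1 --sync_replicas=true \
#     --replicas_to_aggregate=2 --total_num_replicas=2 \
#     --master=grpc://worker1.example:2222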

TrainingHParams = collections.namedtuple('TrainingHParams', [
    'learning_rate',
    'optimizer',
    'momentum',
    'use_augment_input',
])


def get_training_hparams():
  return TrainingHParams(
      learning_rate=FLAGS.learning_rate,
      optimizer=FLAGS.optimizer,
      momentum=FLAGS.momentum,
      use_augment_input=FLAGS.use_augment_input)


def create_optimizer(hparams):
  """Creates an optimizer based on the given hyperparameters."""
  if hparams.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        hparams.learning_rate, momentum=hparams.momentum)
  elif hparams.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        hparams.learning_rate, momentum=hparams.momentum)
  else:
    raise ValueError('Unsupported optimizer: %s' % hparams.optimizer)
  return optimizer
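
# Minimal usage sketch for create_optimizer; the hyperparameter values below
# are hypothetical, not defaults read from the flags:
#
#   hparams = TrainingHParams(learning_rate=0.004, optimizer='momentum',
#                             momentum=0.9, use_augment_input=True)
#   optimizer = create_optimizer(hparams)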


def train(loss, init_fn, hparams):
  """Wraps slim.learning.train to run a training loop.

  Args:
    loss: a loss tensor.
    init_fn: a callable to be executed after all other initialization is done.
    hparams: a TrainingHParams tuple holding the model hyperparameters.
  """
  optimizer = create_optimizer(hparams)

  if FLAGS.sync_replicas:
    # Aggregate gradients from all replicas before applying a single update.
    optimizer = tf.train.SyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        total_num_replicas=FLAGS.total_num_replicas)
    sync_optimizer = optimizer
    startup_delay_steps = 0
  else:
    startup_delay_steps = 0
    sync_optimizer = None

  train_op = slim.learning.create_train_op(
      loss,
      optimizer,
      summarize_gradients=True,
      clip_gradient_norm=FLAGS.clip_gradient_norm)

  slim.learning.train(
      train_op=train_op,
      logdir=FLAGS.train_log_dir,
      graph=loss.graph,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      number_of_steps=FLAGS.max_number_of_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      startup_delay_steps=startup_delay_steps,
      sync_optimizer=sync_optimizer,
      init_fn=init_fn)


def prepare_training_dir():
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    logging.info('Create a new training directory %s', FLAGS.train_log_dir)
    tf.gfile.MakeDirs(FLAGS.train_log_dir)
  else:
    if FLAGS.reset_train_dir:
      logging.info('Reset the training directory %s', FLAGS.train_log_dir)
      tf.gfile.DeleteRecursively(FLAGS.train_log_dir)
      tf.gfile.MakeDirs(FLAGS.train_log_dir)
    else:
      logging.info('Use already existing training directory %s',
                   FLAGS.train_log_dir)


def calculate_graph_metrics():
  param_stats = model_analyzer.print_model_analysis(
      tf.get_default_graph(),
      tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  return param_stats.total_parameters


def main(_):
  prepare_training_dir()

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  hparams = get_training_hparams()

  # If ps_tasks is zero, the local device is used. When using multiple
  # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
  # across the different devices.
  device_setter = tf.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True)
  with tf.device(device_setter):
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size())
    endpoints = model.create_base(data.images, data.labels_one_hot)
    total_loss = model.create_loss(data, endpoints)
    model.create_summaries(data, endpoints, dataset.charset, is_training=True)
    init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                              FLAGS.checkpoint_inception)

    if FLAGS.show_graph_stats:
      logging.info('Total number of weights in the graph: %s',
                   calculate_graph_metrics())

    train(total_loss, init_fn, hparams)


if __name__ == '__main__':
  app.run()