NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
4.85 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define flags that are common to both the train.py and eval.py scripts."""
import sys
from tensorflow.python.platform import flags
import logging
import datasets
import model
# Module-level alias so helpers below can read parsed flag values.
FLAGS = flags.FLAGS

# Configure root logging at import time: DEBUG and above go to stderr.
# NOTE: LogRecord.msecs is the millisecond part of the timestamp (0-999), so it
# is zero-padded to 3 digits. A 6-digit pad would render e.g. 123 ms as
# ".000123", which reads as microseconds and misstates the time.
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stderr,
    format='%(levelname)s '
           '%(asctime)s.%(msecs)03d: '
           '%(filename)s: '
           '%(lineno)d '
           '%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
def define():
    """Register the flags shared by the training and evaluation scripts.

    Each entry in the table below is (definer, flag name, default, help text).
    The flags are registered in table order through the matching
    ``flags.DEFINE_*`` function. Calling this twice would attempt to register
    the same names again.
    """
    # yapf: disable
    flag_specs = [
        (flags.DEFINE_integer, 'batch_size', 32,
         'Batch size.'),
        (flags.DEFINE_integer, 'crop_width', None,
         'Width of the central crop for images.'),
        (flags.DEFINE_integer, 'crop_height', None,
         'Height of the central crop for images.'),
        (flags.DEFINE_string, 'train_log_dir', '/tmp/attention_ocr/train',
         'Directory where to write event logs.'),
        (flags.DEFINE_string, 'dataset_name', 'fsns',
         'Name of the dataset. Supported: fsns'),
        (flags.DEFINE_string, 'split_name', 'train',
         'Dataset split name to run evaluation for: test,train.'),
        (flags.DEFINE_string, 'dataset_dir', None,
         'Dataset root folder.'),
        (flags.DEFINE_string, 'checkpoint', '',
         'Path for checkpoint to restore weights from.'),
        (flags.DEFINE_string, 'master', '',
         'BNS name of the TensorFlow master to use.'),

        # Model hyper-parameters.
        (flags.DEFINE_float, 'learning_rate', 0.004,
         'learning rate'),
        (flags.DEFINE_string, 'optimizer', 'momentum',
         'the optimizer to use'),
        (flags.DEFINE_float, 'momentum', 0.9,
         'momentum value for the momentum optimizer if used'),
        (flags.DEFINE_bool, 'use_augment_input', True,
         'If True will use image augmentation'),

        # Method hyper-parameters, grouped by the model component that
        # consumes them.
        # -- conv_tower_fn
        (flags.DEFINE_string, 'final_endpoint', 'Mixed_5d',
         'Endpoint to cut inception tower'),
        # -- sequence_logit_fn
        (flags.DEFINE_bool, 'use_attention', True,
         'If True will use the attention mechanism'),
        (flags.DEFINE_bool, 'use_autoregression', True,
         'If True will use autoregression (a feedback link)'),
        (flags.DEFINE_integer, 'num_lstm_units', 256,
         'number of LSTM units for sequence LSTM'),
        (flags.DEFINE_float, 'weight_decay', 0.00004,
         'weight decay for char prediction FC layers'),
        (flags.DEFINE_float, 'lstm_state_clip_value', 10.0,
         'cell state is clipped by this value prior to the cell'
         ' output activation'),
        # -- sequence_loss_fn
        (flags.DEFINE_float, 'label_smoothing', 0.1,
         'weight for label smoothing'),
        (flags.DEFINE_bool, 'ignore_nulls', True,
         'ignore null characters for computing the loss'),
        (flags.DEFINE_bool, 'average_across_timesteps', False,
         'divide the returned cost by the total label weight'),
    ]
    # yapf: enable
    for define_fn, flag_name, default_value, help_text in flag_specs:
        define_fn(flag_name, default_value, help_text)
def get_crop_size():
    """Return the central-crop size configured via --crop_width/--crop_height.

    Returns:
      ``(crop_width, crop_height)`` when both flags hold truthy values,
      otherwise ``None`` (presumably meaning "do not crop"; verify against
      callers).
    """
    width = FLAGS.crop_width
    # Short-circuit: crop_height is only consulted when crop_width is set.
    if not (width and FLAGS.crop_height):
        return None
    return (width, FLAGS.crop_height)
def create_dataset(split_name):
    """Load one split of the dataset selected by --dataset_name.

    The dataset module is looked up by name as an attribute of the
    ``datasets`` package, and its ``get_split`` is called with --dataset_dir.

    Args:
      split_name: name of the split to load (e.g. 'train' or 'test').

    Returns:
      Whatever the selected module's ``get_split`` returns.
    """
    module_name = FLAGS.dataset_name
    dataset_module = getattr(datasets, module_name)
    get_split = dataset_module.get_split
    return get_split(split_name, dataset_dir=FLAGS.dataset_dir)
def create_mparams():
    """Build the per-component model hyper-parameters from the flags.

    Returns:
      A dict with keys 'conv_tower_fn', 'sequence_logit_fn' and
      'sequence_loss_fn'. Each value is the corresponding ``model.*Params``
      object filled in from the flags defined in ``define()``.
    """
    conv_tower = model.ConvTowerParams(final_endpoint=FLAGS.final_endpoint)

    sequence_logits = model.SequenceLogitsParams(
        use_attention=FLAGS.use_attention,
        use_autoregression=FLAGS.use_autoregression,
        num_lstm_units=FLAGS.num_lstm_units,
        weight_decay=FLAGS.weight_decay,
        lstm_state_clip_value=FLAGS.lstm_state_clip_value)

    sequence_loss = model.SequenceLossParams(
        label_smoothing=FLAGS.label_smoothing,
        ignore_nulls=FLAGS.ignore_nulls,
        average_across_timesteps=FLAGS.average_across_timesteps)

    mparams = dict()
    mparams['conv_tower_fn'] = conv_tower
    mparams['sequence_logit_fn'] = sequence_logits
    mparams['sequence_loss_fn'] = sequence_loss
    return mparams
def create_model(*args, **kwargs):
    """Construct a ``model.Model`` whose ``mparams`` come from the flags.

    All positional and keyword arguments are forwarded unchanged to
    ``model.Model``. ``mparams`` is always supplied from ``create_mparams()``,
    so passing ``mparams`` in ``kwargs`` as well raises TypeError (multiple
    values for one keyword argument).
    """
    model_cls = model.Model
    return model_cls(*args, mparams=create_mparams(), **kwargs)