Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Define flags are common for both train.py and eval.py scripts.""" | |
import sys | |
from tensorflow.python.platform import flags | |
import logging | |
import datasets | |
import model | |
# Shared flag-values object; the flags themselves are registered by define().
FLAGS = flags.FLAGS

# Root logging is configured at import time: DEBUG and above go to stderr in
# the format below (basicConfig is a no-op if the root logger already has
# handlers).
logging.basicConfig(
    level=logging.DEBUG,
    stream=sys.stderr,
    # `msecs` is the millisecond part of the record's timestamp (0-999), so it
    # is zero-padded to three digits, e.g. "12:34:56.045". A six-digit pad
    # would print 45 ms as ".000045", which reads as 45 microseconds.
    format='%(levelname)s '
    '%(asctime)s.%(msecs)03d: '
    '%(filename)s: '
    '%(lineno)d '
    '%(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
def define():
  """Register the command-line flags shared by train.py and eval.py.

  Each flag is registered through the `flags` module with the default value
  and help text listed below, in the listed order. The helpers in this module
  (e.g. `get_crop_size`, `create_mparams`) read these flags through `FLAGS`,
  so this is expected to run before they are used.
  """
  # Each row is (definer, flag name, default value, help text).
  # yapf: disable
  flag_table = (
      (flags.DEFINE_integer, 'batch_size', 32,
       'Batch size.'),
      (flags.DEFINE_integer, 'crop_width', None,
       'Width of the central crop for images.'),
      (flags.DEFINE_integer, 'crop_height', None,
       'Height of the central crop for images.'),
      (flags.DEFINE_string, 'train_log_dir', '/tmp/attention_ocr/train',
       'Directory where to write event logs.'),
      (flags.DEFINE_string, 'dataset_name', 'fsns',
       'Name of the dataset. Supported: fsns'),
      (flags.DEFINE_string, 'split_name', 'train',
       'Dataset split name to run evaluation for: test,train.'),
      (flags.DEFINE_string, 'dataset_dir', None,
       'Dataset root folder.'),
      (flags.DEFINE_string, 'checkpoint', '',
       'Path for checkpoint to restore weights from.'),
      (flags.DEFINE_string, 'master', '',
       'BNS name of the TensorFlow master to use.'),

      # Model hyper parameters.
      (flags.DEFINE_float, 'learning_rate', 0.004,
       'learning rate'),
      (flags.DEFINE_string, 'optimizer', 'momentum',
       'the optimizer to use'),
      (flags.DEFINE_float, 'momentum', 0.9,
       'momentum value for the momentum optimizer if used'),
      (flags.DEFINE_bool, 'use_augment_input', True,
       'If True will use image augmentation'),

      # Method hyper parameters -- conv_tower_fn.
      (flags.DEFINE_string, 'final_endpoint', 'Mixed_5d',
       'Endpoint to cut inception tower'),

      # Method hyper parameters -- sequence_logit_fn.
      (flags.DEFINE_bool, 'use_attention', True,
       'If True will use the attention mechanism'),
      (flags.DEFINE_bool, 'use_autoregression', True,
       'If True will use autoregression (a feedback link)'),
      (flags.DEFINE_integer, 'num_lstm_units', 256,
       'number of LSTM units for sequence LSTM'),
      (flags.DEFINE_float, 'weight_decay', 0.00004,
       'weight decay for char prediction FC layers'),
      (flags.DEFINE_float, 'lstm_state_clip_value', 10.0,
       'cell state is clipped by this value prior to the cell'
       ' output activation'),

      # Method hyper parameters -- sequence_loss_fn.
      (flags.DEFINE_float, 'label_smoothing', 0.1,
       'weight for label smoothing'),
      (flags.DEFINE_bool, 'ignore_nulls', True,
       'ignore null characters for computing the loss'),
      (flags.DEFINE_bool, 'average_across_timesteps', False,
       'divide the returned cost by the total label weight'),
  )
  # yapf: enable
  for define_flag, flag_name, default_value, help_text in flag_table:
    define_flag(flag_name, default_value, help_text)
def get_crop_size():
  """Return the requested central-crop size, if any.

  Returns:
    A `(crop_width, crop_height)` tuple taken from the `--crop_width` and
    `--crop_height` flags when both are truthy (set and non-zero); otherwise
    None. Setting only one of the two flags therefore yields None.
  """
  if not FLAGS.crop_width or not FLAGS.crop_height:
    return None
  return FLAGS.crop_width, FLAGS.crop_height
def create_dataset(split_name):
  """Load one split of the dataset selected by `--dataset_name`.

  The dataset module is looked up by name on the `datasets` package, so an
  unknown name surfaces as an AttributeError from `getattr`.

  Args:
    split_name: name of the split to load; forwarded to the module's
      `get_split`.

  Returns:
    Whatever the selected module's `get_split` returns for `split_name`,
    read from `--dataset_dir`.
  """
  dataset_module = getattr(datasets, FLAGS.dataset_name)
  get_split = dataset_module.get_split
  return get_split(split_name, dataset_dir=FLAGS.dataset_dir)
def create_mparams():
  """Build the per-component model hyper-parameters from the flags.

  Returns:
    A dict with three entries, in this order:
      'conv_tower_fn': `model.ConvTowerParams` (from `--final_endpoint`),
      'sequence_logit_fn': `model.SequenceLogitsParams` (attention,
        autoregression, LSTM size, weight decay, LSTM state clipping),
      'sequence_loss_fn': `model.SequenceLossParams` (label smoothing,
        null handling, timestep averaging).
    `create_model` passes this dict to `model.Model` as `mparams`.
  """
  conv_tower_params = model.ConvTowerParams(
      final_endpoint=FLAGS.final_endpoint)
  sequence_logit_params = model.SequenceLogitsParams(
      use_attention=FLAGS.use_attention,
      use_autoregression=FLAGS.use_autoregression,
      num_lstm_units=FLAGS.num_lstm_units,
      weight_decay=FLAGS.weight_decay,
      lstm_state_clip_value=FLAGS.lstm_state_clip_value)
  sequence_loss_params = model.SequenceLossParams(
      label_smoothing=FLAGS.label_smoothing,
      ignore_nulls=FLAGS.ignore_nulls,
      average_across_timesteps=FLAGS.average_across_timesteps)
  return {
      'conv_tower_fn': conv_tower_params,
      'sequence_logit_fn': sequence_logit_params,
      'sequence_loss_fn': sequence_loss_params,
  }
def create_model(*args, **kwargs):
  """Construct a `model.Model` whose hyper-parameters come from the flags.

  `mparams` is always supplied from `create_mparams()`; every other argument
  is forwarded to `model.Model` unchanged.

  Args:
    *args: positional arguments forwarded to `model.Model`.
    **kwargs: keyword arguments forwarded to `model.Model`. Must not contain
      'mparams': that keyword is already supplied here, so Python raises
      TypeError for the duplicate.

  Returns:
    The newly constructed `model.Model` instance.
  """
  # Unpack the positional arguments first, matching the order in which Python
  # binds call arguments. A keyword written before `*args` only obscures that.
  return model.Model(*args, mparams=create_mparams(), **kwargs)