# Copyright 2017 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Functions to build the Attention OCR model. Usage example: ocr_model = model.Model(num_char_classes, seq_length, num_of_views) data = ... # create namedtuple InputEndpoints endpoints = model.create_base(data.images, data.labels_one_hot) # endpoints.predicted_chars is a tensor with predicted character codes. total_loss = model.create_loss(data, endpoints) """ import sys import collections import logging import tensorflow as tf from tensorflow.contrib import slim from tensorflow.contrib.slim.nets import inception import metrics import sequence_layers import utils OutputEndpoints = collections.namedtuple('OutputEndpoints', [ 'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores', 'predicted_text' ]) # TODO(gorban): replace with tf.HParams when it is released. ModelParams = collections.namedtuple('ModelParams', [ 'num_char_classes', 'seq_length', 'num_views', 'null_code' ]) ConvTowerParams = collections.namedtuple('ConvTowerParams', ['final_endpoint']) SequenceLogitsParams = collections.namedtuple('SequenceLogitsParams', [ 'use_attention', 'use_autoregression', 'num_lstm_units', 'weight_decay', 'lstm_state_clip_value' ]) SequenceLossParams = collections.namedtuple('SequenceLossParams', [ 'label_smoothing', 'ignore_nulls', 'average_across_timesteps' ]) EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [ 'enabled' ]) def _dict_to_array(id_to_char, default_character): num_char_classes = max(id_to_char.keys()) + 1 array = [default_character] * num_char_classes for k, v in id_to_char.items(): array[k] = v return array class CharsetMapper(object): """A simple class to map tensor ids into strings. It works only when the character set is 1:1 mapping between individual characters and individual ids. Make sure you call tf.tables_initializer().run() as part of the init op. """ def __init__(self, charset, default_character='?'): """Creates a lookup table. Args: charset: a dictionary with id-to-character mapping. """ mapping_strings = tf.constant(_dict_to_array(charset, default_character)) self.table = tf.contrib.lookup.index_to_string_table_from_tensor( mapping=mapping_strings, default_value=default_character) def get_text(self, ids): """Returns a string corresponding to a sequence of character ids. Args: ids: a tensor with shape [batch_size, max_sequence_length] """ return tf.reduce_join( self.table.lookup(tf.to_int64(ids)), reduction_indices=1) def get_softmax_loss_fn(label_smoothing): """Returns sparse or dense loss function depending on the label_smoothing. Args: label_smoothing: weight for label smoothing Returns: a function which takes labels and predictions as arguments and returns a softmax loss for the selected type of labels (sparse or dense). """ if label_smoothing > 0: def loss_fn(labels, logits): return (tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=labels)) else: def loss_fn(labels, logits): return tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=labels) return loss_fn class Model(object): """Class to create the Attention OCR Model.""" def __init__(self, num_char_classes, seq_length, num_views, null_code, mparams=None, charset=None): """Initialized model parameters. Args: num_char_classes: size of character set. seq_length: number of characters in a sequence. num_views: Number of views (conv towers) to use. null_code: A character code corresponding to a character which indicates end of a sequence. mparams: a dictionary with hyper parameters for methods, keys - function names, values - corresponding namedtuples. charset: an optional dictionary with a mapping between character ids and utf8 strings. If specified the OutputEndpoints.predicted_text will utf8 encoded strings corresponding to the character ids returned by OutputEndpoints.predicted_chars (by default the predicted_text contains an empty vector). NOTE: Make sure you call tf.tables_initializer().run() if the charset specified. """ super(Model, self).__init__() self._params = ModelParams( num_char_classes=num_char_classes, seq_length=seq_length, num_views=num_views, null_code=null_code) self._mparams = self.default_mparams() if mparams: self._mparams.update(mparams) self._charset = charset def default_mparams(self): return { 'conv_tower_fn': ConvTowerParams(final_endpoint='Mixed_5d'), 'sequence_logit_fn': SequenceLogitsParams( use_attention=True, use_autoregression=True, num_lstm_units=256, weight_decay=0.00004, lstm_state_clip_value=10.0), 'sequence_loss_fn': SequenceLossParams( label_smoothing=0.1, ignore_nulls=True, average_across_timesteps=False), 'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False) } def set_mparam(self, function, **kwargs): self._mparams[function] = self._mparams[function]._replace(**kwargs) def conv_tower_fn(self, images, is_training=True, reuse=None): """Computes convolutional features using the InceptionV3 model. Args: images: A tensor of shape [batch_size, height, width, channels]. is_training: whether is training or not. reuse: whether or not the network and its variables should be reused. To be able to reuse 'scope' must be given. Returns: A tensor of shape [batch_size, OH, OW, N], where OWxOH is resolution of output feature map and N is number of output features (depends on the network architecture). """ mparams = self._mparams['conv_tower_fn'] logging.debug('Using final_endpoint=%s', mparams.final_endpoint) with tf.variable_scope('conv_tower_fn/INCE'): if reuse: tf.get_variable_scope().reuse_variables() with slim.arg_scope(inception.inception_v3_arg_scope()): with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training): net, _ = inception.inception_v3_base( images, final_endpoint=mparams.final_endpoint) return net def _create_lstm_inputs(self, net): """Splits an input tensor into a list of tensors (features). Args: net: A feature map of shape [batch_size, num_features, feature_size]. Raises: AssertionError: if num_features is less than seq_length. Returns: A list with seq_length tensors of shape [batch_size, feature_size] """ num_features = net.get_shape().dims[1].value if num_features < self._params.seq_length: raise AssertionError('Incorrect dimension #1 of input tensor' ' %d should be bigger than %d (shape=%s)' % (num_features, self._params.seq_length, net.get_shape())) elif num_features > self._params.seq_length: logging.warning('Ignoring some features: use %d of %d (shape=%s)', self._params.seq_length, num_features, net.get_shape()) net = tf.slice(net, [0, 0, 0], [-1, self._params.seq_length, -1]) return tf.unstack(net, axis=1) def sequence_logit_fn(self, net, labels_one_hot): mparams = self._mparams['sequence_logit_fn'] # TODO(gorban): remove /alias suffixes from the scopes. with tf.variable_scope('sequence_logit_fn/SQLR'): layer_class = sequence_layers.get_layer_class(mparams.use_attention, mparams.use_autoregression) layer = layer_class(net, labels_one_hot, self._params, mparams) return layer.create_logits() def max_pool_views(self, nets_list): """Max pool across all nets in spatial dimensions. Args: nets_list: A list of 4D tensors with identical size. Returns: A tensor with the same size as any input tensors. """ batch_size, height, width, num_features = [ d.value for d in nets_list[0].get_shape().dims ] xy_flat_shape = (batch_size, 1, height * width, num_features) nets_for_merge = [] with tf.variable_scope('max_pool_views', values=nets_list): for net in nets_list: nets_for_merge.append(tf.reshape(net, xy_flat_shape)) merged_net = tf.concat(nets_for_merge, 1) net = slim.max_pool2d( merged_net, kernel_size=[len(nets_list), 1], stride=1) net = tf.reshape(net, (batch_size, height, width, num_features)) return net def pool_views_fn(self, nets): """Combines output of multiple convolutional towers into a single tensor. It stacks towers one on top another (in height dim) in a 4x1 grid. The order is arbitrary design choice and shouldn't matter much. Args: nets: list of tensors of shape=[batch_size, height, width, num_features]. Returns: A tensor of shape [batch_size, seq_length, features_size]. """ with tf.variable_scope('pool_views_fn/STCK'): net = tf.concat(nets, 1) batch_size = net.get_shape().dims[0].value feature_size = net.get_shape().dims[3].value return tf.reshape(net, [batch_size, -1, feature_size]) def char_predictions(self, chars_logit): """Returns confidence scores (softmax values) for predicted characters. Args: chars_logit: chars logits, a tensor with shape [batch_size x seq_length x num_char_classes] Returns: A tuple (ids, log_prob, scores), where: ids - predicted characters, a int32 tensor with shape [batch_size x seq_length]; log_prob - a log probability of all characters, a float tensor with shape [batch_size, seq_length, num_char_classes]; scores - corresponding confidence scores for characters, a float tensor with shape [batch_size x seq_length]. """ log_prob = utils.logits_to_log_prob(chars_logit) ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars') mask = tf.cast( slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool) all_scores = tf.nn.softmax(chars_logit) selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores') scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length)) return ids, log_prob, scores def encode_coordinates_fn(self, net): """Adds one-hot encoding of coordinates to different views in the networks. For each "pixel" of a feature map it adds a onehot encoded x and y coordinates. Args: net: a tensor of shape=[batch_size, height, width, num_features] Returns: a tensor with the same height and width, but altered feature_size. """ mparams = self._mparams['encode_coordinates_fn'] if mparams.enabled: batch_size, h, w, _ = net.shape.as_list() x, y = tf.meshgrid(tf.range(w), tf.range(h)) w_loc = slim.one_hot_encoding(x, num_classes=w) h_loc = slim.one_hot_encoding(y, num_classes=h) loc = tf.concat([h_loc, w_loc], 2) loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1]) return tf.concat([net, loc], 3) else: return net def create_base(self, images, labels_one_hot, scope='AttentionOcr_v1', reuse=None): """Creates a base part of the Model (no gradients, losses or summaries). Args: images: A tensor of shape [batch_size, height, width, channels]. labels_one_hot: Optional (can be None) one-hot encoding for ground truth labels. If provided the function will create a model for training. scope: Optional variable_scope. reuse: whether or not the network and its variables should be reused. To be able to reuse 'scope' must be given. Returns: A named tuple OutputEndpoints. """ logging.debug('images: %s', images) is_training = labels_one_hot is not None with tf.variable_scope(scope, reuse=reuse): views = tf.split( value=images, num_or_size_splits=self._params.num_views, axis=2) logging.debug('Views=%d single view: %s', len(views), views[0]) nets = [ self.conv_tower_fn(v, is_training, reuse=(i != 0)) for i, v in enumerate(views) ] logging.debug('Conv tower: %s', nets[0]) nets = [self.encode_coordinates_fn(net) for net in nets] logging.debug('Conv tower w/ encoded coordinates: %s', nets[0]) net = self.pool_views_fn(nets) logging.debug('Pooled views: %s', net) chars_logit = self.sequence_logit_fn(net, labels_one_hot) logging.debug('chars_logit: %s', chars_logit) predicted_chars, chars_log_prob, predicted_scores = ( self.char_predictions(chars_logit)) if self._charset: character_mapper = CharsetMapper(self._charset) predicted_text = character_mapper.get_text(predicted_chars) else: predicted_text = tf.constant([]) return OutputEndpoints( chars_logit=chars_logit, chars_log_prob=chars_log_prob, predicted_chars=predicted_chars, predicted_scores=predicted_scores, predicted_text=predicted_text) def create_loss(self, data, endpoints): """Creates all losses required to train the model. Args: data: InputEndpoints namedtuple. endpoints: Model namedtuple. Returns: Total loss. """ # NOTE: the return value of ModelLoss is not used directly for the # gradient computation because under the hood it calls slim.losses.AddLoss, # which registers the loss in an internal collection and later returns it # as part of GetTotalLoss. We need to use total loss because model may have # multiple losses including regularization losses. self.sequence_loss_fn(endpoints.chars_logit, data.labels) total_loss = slim.losses.get_total_loss() tf.summary.scalar('TotalLoss', total_loss) return total_loss def label_smoothing_regularization(self, chars_labels, weight=0.1): """Applies a label smoothing regularization. Uses the same method as in https://arxiv.org/abs/1512.00567. Args: chars_labels: ground truth ids of charactes, shape=[batch_size, seq_length]; weight: label-smoothing regularization weight. Returns: A sensor with the same shape as the input. """ one_hot_labels = tf.one_hot( chars_labels, depth=self._params.num_char_classes, axis=-1) pos_weight = 1.0 - weight neg_weight = weight / self._params.num_char_classes return one_hot_labels * pos_weight + neg_weight def sequence_loss_fn(self, chars_logits, chars_labels): """Loss function for char sequence. Depending on values of hyper parameters it applies label smoothing and can also ignore all null chars after the first one. Args: chars_logits: logits for predicted characters, shape=[batch_size, seq_length, num_char_classes]; chars_labels: ground truth ids of characters, shape=[batch_size, seq_length]; mparams: method hyper parameters. Returns: A Tensor with shape [batch_size] - the log-perplexity for each sequence. """ mparams = self._mparams['sequence_loss_fn'] with tf.variable_scope('sequence_loss_fn/SLF'): if mparams.label_smoothing > 0: smoothed_one_hot_labels = self.label_smoothing_regularization( chars_labels, mparams.label_smoothing) labels_list = tf.unstack(smoothed_one_hot_labels, axis=1) else: # NOTE: in case of sparse softmax we are not using one-hot # encoding. labels_list = tf.unstack(chars_labels, axis=1) batch_size, seq_length, _ = chars_logits.shape.as_list() if mparams.ignore_nulls: weights = tf.ones((batch_size, seq_length), dtype=tf.float32) else: # Suppose that reject character is the last in the charset. reject_char = tf.constant( self._params.num_char_classes - 1, shape=(batch_size, seq_length), dtype=tf.int64) known_char = tf.not_equal(chars_labels, reject_char) weights = tf.to_float(known_char) logits_list = tf.unstack(chars_logits, axis=1) weights_list = tf.unstack(weights, axis=1) loss = tf.contrib.legacy_seq2seq.sequence_loss( logits_list, labels_list, weights_list, softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing), average_across_timesteps=mparams.average_across_timesteps) tf.losses.add_loss(loss) return loss def create_summaries(self, data, endpoints, charset, is_training): """Creates all summaries for the model. Args: data: InputEndpoints namedtuple. endpoints: OutputEndpoints namedtuple. charset: A dictionary with mapping between character codes and unicode characters. Use the one provided by a dataset.charset. is_training: If True will create summary prefixes for training job, otherwise - for evaluation. Returns: A list of evaluation ops """ def sname(label): prefix = 'train' if is_training else 'eval' return '%s/%s' % (prefix, label) max_outputs = 4 # TODO(gorban): uncomment, when tf.summary.text released. # charset_mapper = CharsetMapper(charset) # pr_text = charset_mapper.get_text( # endpoints.predicted_chars[:max_outputs,:]) # tf.summary.text(sname('text/pr'), pr_text) # gt_text = charset_mapper.get_text(data.labels[:max_outputs,:]) # tf.summary.text(sname('text/gt'), gt_text) tf.summary.image(sname('image'), data.images, max_outputs=max_outputs) if is_training: tf.summary.image( sname('image/orig'), data.images_orig, max_outputs=max_outputs) for var in tf.trainable_variables(): tf.summary.histogram(var.op.name, var) return None else: names_to_values = {} names_to_updates = {} def use_metric(name, value_update_tuple): names_to_values[name] = value_update_tuple[0] names_to_updates[name] = value_update_tuple[1] use_metric('CharacterAccuracy', metrics.char_accuracy( endpoints.predicted_chars, data.labels, streaming=True, rej_char=self._params.null_code)) # Sequence accuracy computed by cutting sequence at the first null char use_metric('SequenceAccuracy', metrics.sequence_accuracy( endpoints.predicted_chars, data.labels, streaming=True, rej_char=self._params.null_code)) for name, value in names_to_values.items(): summary_name = 'eval/' + name tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name)) return list(names_to_updates.values()) def create_init_fn_to_restore(self, master_checkpoint, inception_checkpoint=None): """Creates an init operations to restore weights from various checkpoints. Args: master_checkpoint: path to a checkpoint which contains all weights for the whole model. inception_checkpoint: path to a checkpoint which contains weights for the inception part only. Returns: a function to run initialization ops. """ all_assign_ops = [] all_feed_dict = {} def assign_from_checkpoint(variables, checkpoint): logging.info('Request to re-store %d weights from %s', len(variables), checkpoint) if not variables: logging.error('Can\'t find any variables to restore.') sys.exit(1) assign_op, feed_dict = slim.assign_from_checkpoint(checkpoint, variables) all_assign_ops.append(assign_op) all_feed_dict.update(feed_dict) logging.info('variables_to_restore:\n%s' % utils.variables_to_restore().keys()) logging.info('moving_average_variables:\n%s' % [v.op.name for v in tf.moving_average_variables()]) logging.info('trainable_variables:\n%s' % [v.op.name for v in tf.trainable_variables()]) if master_checkpoint: assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint) if inception_checkpoint: variables = utils.variables_to_restore( 'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True) assign_from_checkpoint(variables, inception_checkpoint) def init_assign_fn(sess): logging.info('Restoring checkpoint(s)') sess.run(all_assign_ops, all_feed_dict) return init_assign_fn