r"""A script to run inference on a set of image files.

NOTE #1: The Attention OCR model was trained only using the FSNS train dataset
and it will work only for images which look more or less similar to French
street names. In order to apply it to images from a different distribution you
need to retrain (or at least fine-tune) it using images from that distribution.

NOTE #2: This script exists for demo purposes only. It is highly recommended
to use tools and mechanisms provided by the TensorFlow Serving system to run
inference on TensorFlow models in production:
https://www.tensorflow.org/serving/serving_basic

Usage:
python demo_inference.py --batch_size=32 \
  --checkpoint=model.ckpt-399731 \
  --image_path_pattern=./datasets/data/fsns/temp/fsns_train_%02d.png
"""
import numpy as np
import PIL.Image

import tensorflow as tf
from tensorflow.python.platform import flags
from tensorflow.python.training import monitored_session

import common_flags
import datasets
import data_provider

FLAGS = flags.FLAGS
common_flags.define()

# e.g. ./datasets/data/fsns/temp/fsns_train_%02d.png
flags.DEFINE_string('image_path_pattern', '',
                    'A file pattern with a placeholder for the image index.')


def get_dataset_image_size(dataset_name):
    """Return the (width, height) of images for the given dataset.

    Args:
      dataset_name: Name of an attribute (module) of the `datasets` package
        that exposes a DEFAULT_CONFIG dict with an 'image_shape' entry.

    Returns:
      A (width, height) tuple. Note the order is swapped relative to
      'image_shape', which is unpacked as (height, width, channels).
    """
    # Ideally this info should be exposed through the dataset interface itself.
    # But currently it is not available by other means.
    ds_module = getattr(datasets, dataset_name)
    height, width, _ = ds_module.DEFAULT_CONFIG['image_shape']
    return width, height


def load_images(file_pattern, batch_size, dataset_name):
    """Read `batch_size` images from disk into a single uint8 array.

    Args:
      file_pattern: A %-style pattern with one integer placeholder; image `i`
        is read from ``file_pattern % i`` for i in [0, batch_size).
      batch_size: Number of images to read.
      dataset_name: Dataset whose configured image size the files must match.

    Returns:
      A uint8 numpy array of shape (batch_size, height, width, 3).

    Raises:
      ValueError: If an image's width/height differs from the dataset's
        configured image size.
    """
    width, height = get_dataset_image_size(dataset_name)
    # Every slot is overwritten below, so uninitialized memory is fine.
    images_actual_data = np.empty((batch_size, height, width, 3),
                                  dtype=np.uint8)
    for i in range(batch_size):
        path = file_pattern % i
        print("Reading %s" % path)
        # Close the file handle once the pixels are decoded. PIL opens images
        # lazily, so the decode (convert + asarray) must happen inside `with`.
        with tf.gfile.GFile(path, 'rb') as image_file:
            pil_image = PIL.Image.open(image_file)
            # Normalize grayscale / RGBA / palette images to 3 channels so
            # they fit the (height, width, 3) slot.
            image_array = np.asarray(pil_image.convert('RGB'))
        if image_array.shape[:2] != (height, width):
            raise ValueError(
                'Image %s has size %dx%d (width x height), expected %dx%d.' %
                (path, image_array.shape[1], image_array.shape[0],
                 width, height))
        images_actual_data[i, ...] = image_array
    return images_actual_data


def create_model(batch_size, dataset_name):
    """Build the inference graph fed by a raw uint8 image placeholder.

    Args:
      batch_size: Static batch dimension of the input placeholder.
      dataset_name: Dataset whose configured image size sets the placeholder's
        height and width.

    Returns:
      A (raw_images, endpoints) tuple: the uint8 placeholder of shape
      [batch_size, height, width, 3] and the endpoints returned by
      `model.create_base`.
    """
    width, height = get_dataset_image_size(dataset_name)
    # NOTE(review): the dataset object itself is built from flags
    # (FLAGS.split_name, and presumably FLAGS.dataset_name inside
    # common_flags); `dataset_name` here only determines the image size.
    dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
    model = common_flags.create_model(
        num_char_classes=dataset.num_char_classes,
        seq_length=dataset.max_sequence_length,
        num_views=dataset.num_of_views,
        null_code=dataset.null_code,
        charset=dataset.charset)
    raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
    # Preprocess each image independently; the result is float32.
    images = tf.map_fn(data_provider.preprocess_image, raw_images,
                       dtype=tf.float32)
    endpoints = model.create_base(images, labels_one_hot=None)
    return raw_images, endpoints


def run(checkpoint, batch_size, dataset_name, image_path_pattern):
    """Restore `checkpoint` and return the predicted text for each image.

    Args:
      checkpoint: Path to the model checkpoint to restore.
      batch_size: Number of images to read and run in a single batch.
      dataset_name: Dataset whose configuration defines the image size.
      image_path_pattern: %-style file pattern with an image-index placeholder.

    Returns:
      A list of `batch_size` predicted strings, decoded as UTF-8.
    """
    images_placeholder, endpoints = create_model(batch_size, dataset_name)
    images_data = load_images(image_path_pattern, batch_size, dataset_name)
    session_creator = monitored_session.ChiefSessionCreator(
        checkpoint_filename_with_path=checkpoint)
    with monitored_session.MonitoredSession(
            session_creator=session_creator) as sess:
        predictions = sess.run(endpoints.predicted_text,
                               feed_dict={images_placeholder: images_data})
    return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]


def main(_):
    """Entry point for `tf.app.run`: print one prediction per input image."""
    print("Predicted strings:")
    predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
                      FLAGS.image_path_pattern)
    for line in predictions:
        print(line)


if __name__ == '__main__':
    tf.app.run()