Spaces:
Running
Running
#!/usr/bin/python | |
# -*- coding: UTF-8 -*- | |
import os | |
import demo_inference | |
import tensorflow as tf | |
from tensorflow.python.training import monitored_session | |
_CHECKPOINT = 'model.ckpt-399731' | |
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz' | |
class DemoInferenceTest(tf.test.TestCase): | |
def setUp(self): | |
super(DemoInferenceTest, self).setUp() | |
for suffix in ['.meta', '.index', '.data-00000-of-00001']: | |
filename = _CHECKPOINT + suffix | |
self.assertTrue(tf.gfile.Exists(filename), | |
msg='Missing checkpoint file %s. ' | |
'Please download and extract it from %s' % | |
(filename, _CHECKPOINT_URL)) | |
self._batch_size = 32 | |
tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns') | |
def test_moving_variables_properly_loaded_from_a_checkpoint(self): | |
batch_size = 32 | |
dataset_name = 'fsns' | |
images_placeholder, endpoints = demo_inference.create_model(batch_size, | |
dataset_name) | |
image_path_pattern = 'testdata/fsns_train_%02d.png' | |
images_data = demo_inference.load_images(image_path_pattern, batch_size, | |
dataset_name) | |
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean' | |
moving_mean_tf = tf.get_default_graph().get_tensor_by_name( | |
tensor_name + ':0') | |
reader = tf.train.NewCheckpointReader(_CHECKPOINT) | |
moving_mean_expected = reader.get_tensor(tensor_name) | |
session_creator = monitored_session.ChiefSessionCreator( | |
checkpoint_filename_with_path=_CHECKPOINT) | |
with monitored_session.MonitoredSession( | |
session_creator=session_creator) as sess: | |
moving_mean_np = sess.run(moving_mean_tf, | |
feed_dict={images_placeholder: images_data}) | |
self.assertAllEqual(moving_mean_expected, moving_mean_np) | |
def test_correct_results_on_test_data(self): | |
image_path_pattern = 'testdata/fsns_train_%02d.png' | |
predictions = demo_inference.run(_CHECKPOINT, self._batch_size, | |
'fsns', | |
image_path_pattern) | |
self.assertEqual([ | |
u'Boulevard de Lunelβββββββββββββββββββ', | |
'Rue de Provenceββββββββββββββββββββββ', | |
'Rue de Port Mariaββββββββββββββββββββ', | |
'Avenue Charles Gounodββββββββββββββββ', | |
'Rue de lβAuroreββββββββββββββββββββββ', | |
'Rue de Beuzevilleββββββββββββββββββββ', | |
'Rue dβOrbeyββββββββββββββββββββββββββ', | |
'Rue Victor Schoulcherββββββββββββββββ', | |
'Rue de la Gareβββββββββββββββββββββββ', | |
'Rue des Tulipesββββββββββββββββββββββ', | |
'Rue AndrΓ© Maginotββββββββββββββββββββ', | |
'Route de Pringyββββββββββββββββββββββ', | |
'Rue des Landellesββββββββββββββββββββ', | |
'Rue des Ilettesββββββββββββββββββββββ', | |
'Avenue de Maurinβββββββββββββββββββββ', | |
'Rue ThΓ©resaββββββββββββββββββββββββββ', # GT='Rue ThΓ©rΓ©sa' | |
'Route de la Balmeββββββββββββββββββββ', | |
'Rue HΓ©lΓ¨ne Roedererββββββββββββββββββ', | |
'Rue Emile Bernardββββββββββββββββββββ', | |
'Place de la Mairieβββββββββββββββββββ', | |
'Rue des Perrotsββββββββββββββββββββββ', | |
'Rue de la LibΓ©rationβββββββββββββββββ', | |
'Impasse du Capcirββββββββββββββββββββ', | |
'Avenue de la Grand Mareββββββββββββββ', | |
'Rue Pierre Brossoletteβββββββββββββββ', | |
'Rue de Provenceββββββββββββββββββββββ', | |
'Rue du Docteur Mourreββββββββββββββββ', | |
'Rue dβOrtheuilβββββββββββββββββββββββ', | |
'Rue des Sarmentsβββββββββββββββββββββ', | |
'Rue du Centreββββββββββββββββββββββββ', | |
'Impasse Pierre Mourguesββββββββββββββ', | |
'Rue Marcel Dassaultββββββββββββββββββ' | |
], predictions) | |
if __name__ == '__main__': | |
tf.test.main() | |