Spaces:
Running
Running
File size: 5,185 Bytes
0b8359d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os
import demo_inference
import tensorflow as tf
from tensorflow.python.training import monitored_session
_CHECKPOINT = 'model.ckpt-399731'
_CHECKPOINT_URL = 'http://download.tensorflow.org/models/attention_ocr_2017_08_09.tar.gz'
class DemoInferenceTest(tf.test.TestCase):
def setUp(self):
super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
self.assertTrue(tf.gfile.Exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32
tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
dataset_name = 'fsns'
images_placeholder, endpoints = demo_inference.create_model(batch_size,
dataset_name)
image_path_pattern = 'testdata/fsns_train_%02d.png'
images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
reader = tf.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data})
self.assertAllEqual(moving_mean_expected, moving_mean_np)
def test_correct_results_on_test_data(self):
image_path_pattern = 'testdata/fsns_train_%02d.png'
predictions = demo_inference.run(_CHECKPOINT, self._batch_size,
'fsns',
image_path_pattern)
self.assertEqual([
u'Boulevard de Lunelβββββββββββββββββββ',
'Rue de Provenceββββββββββββββββββββββ',
'Rue de Port Mariaββββββββββββββββββββ',
'Avenue Charles Gounodββββββββββββββββ',
'Rue de lβAuroreββββββββββββββββββββββ',
'Rue de Beuzevilleββββββββββββββββββββ',
'Rue dβOrbeyββββββββββββββββββββββββββ',
'Rue Victor Schoulcherββββββββββββββββ',
'Rue de la Gareβββββββββββββββββββββββ',
'Rue des Tulipesββββββββββββββββββββββ',
'Rue AndrΓ© Maginotββββββββββββββββββββ',
'Route de Pringyββββββββββββββββββββββ',
'Rue des Landellesββββββββββββββββββββ',
'Rue des Ilettesββββββββββββββββββββββ',
'Avenue de Maurinβββββββββββββββββββββ',
'Rue ThΓ©resaββββββββββββββββββββββββββ', # GT='Rue ThΓ©rΓ©sa'
'Route de la Balmeββββββββββββββββββββ',
'Rue HΓ©lΓ¨ne Roedererββββββββββββββββββ',
'Rue Emile Bernardββββββββββββββββββββ',
'Place de la Mairieβββββββββββββββββββ',
'Rue des Perrotsββββββββββββββββββββββ',
'Rue de la LibΓ©rationβββββββββββββββββ',
'Impasse du Capcirββββββββββββββββββββ',
'Avenue de la Grand Mareββββββββββββββ',
'Rue Pierre Brossoletteβββββββββββββββ',
'Rue de Provenceββββββββββββββββββββββ',
'Rue du Docteur Mourreββββββββββββββββ',
'Rue dβOrtheuilβββββββββββββββββββββββ',
'Rue des Sarmentsβββββββββββββββββββββ',
'Rue du Centreββββββββββββββββββββββββ',
'Impasse Pierre Mourguesββββββββββββββ',
'Rue Marcel Dassaultββββββββββββββββββ'
], predictions)
if __name__ == '__main__':
tf.test.main()
|