NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
1.98 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for unittest_utils."""
import numpy as np
import io
from PIL import Image as PILImage
import tensorflow as tf
from datasets import unittest_utils
class UnittestUtilsTest(tf.test.TestCase):
    """Exercises `create_random_image` and `create_serialized_example`."""

    def test_creates_an_image_of_specified_shape(self):
        # The array returned first must carry exactly the shape that was
        # passed in; the second return value is not inspected here.
        requested_shape = (10, 20, 3)
        pixels, _ = unittest_utils.create_random_image('PNG', requested_shape)
        self.assertEqual(pixels.shape, requested_shape)

    def test_encoded_image_corresponds_to_numpy_array(self):
        # Decode the second return value with PIL and compare it,
        # element-wise, against the array returned alongside it.
        pixels, encoded_bytes = unittest_utils.create_random_image(
            'PNG', (20, 10, 3))
        decoded = np.array(PILImage.open(io.BytesIO(encoded_bytes)))
        self.assertAllEqual(pixels, decoded)

    def test_created_example_has_correct_values(self):
        # Round-trip: serialize a dict of feature values, parse the bytes
        # back into a tf.train.Example, and compare with the expected proto.
        feature_values = {
            'labels': [1, 2, 3],
            'data': [b'FAKE'],
        }
        serialized = unittest_utils.create_serialized_example(feature_values)

        parsed = tf.train.Example()
        parsed.ParseFromString(serialized)

        # The literal below is kept flush-left so its contents stay exactly
        # as they were written.
        expected_text = """
features {
feature {
key: "labels"
value { int64_list {
value: 1
value: 2
value: 3
}}
}
feature {
key: "data"
value { bytes_list {
value: "FAKE"
}}
}
}
"""
        self.assertProtoEquals(expected_text, parsed)
# Only when this file is executed directly (not imported) is control handed
# to tf.test.main().
if __name__ == '__main__':
    tf.test.main()