Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tests for detection_inference.py.""" | |
import os | |
import unittest | |
import numpy as np | |
from PIL import Image | |
import six | |
import tensorflow.compat.v1 as tf | |
from google.protobuf import text_format | |
from object_detection.core import standard_fields | |
from object_detection.inference import detection_inference | |
from object_detection.utils import dataset_util | |
from object_detection.utils import tf_version | |
def get_mock_tfrecord_path():
  """Returns the location of the mock TFRecord file in the test temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock.tfrec')
def create_mock_tfrecord():
  """Writes a single-example TFRecord to `get_mock_tfrecord_path()`.

  The example contains a 1x1 RGB PNG whose only pixel is [123, 0, 0], stored
  under the standard encoded-image key, plus a float list [1, 2, 3, 4] under
  'test_field'.

  Returns:
    The PNG-encoded bytes of the image stored in the example.
  """
  pixels = np.array([[[123, 0, 0]]], dtype=np.uint8)
  with six.BytesIO() as png_buffer:
    Image.fromarray(pixels, 'RGB').save(png_buffer, format='png')
    png_bytes = png_buffer.getvalue()

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'test_field':
                  dataset_util.float_list_feature([1, 2, 3, 4]),
              standard_fields.TfExampleFields.image_encoded:
                  dataset_util.bytes_feature(png_bytes),
          }))

  with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as record_writer:
    record_writer.write(example.SerializeToString())
  return png_bytes
def get_mock_graph_path():
  """Returns the location of the serialized mock GraphDef in the test temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock_graph.pb')
def create_mock_graph():
  """Serializes a mock detection GraphDef to `get_mock_graph_path()`.

  The graph has a uint8 placeholder named 'image_tensor' with shape
  [1, None, None, 3] and these outputs:
    * 'num_detections': constant [2.0].
    * 'detection_boxes': a constant with three boxes.
    * 'detection_scores': constant [[0.1, 0.2, 0.3]].
    * 'detection_classes': [[1.0, 2.0, 3.0]] multiplied by the sum of all
      input pixel values, so this output depends on the fed image.
  """
  g = tf.Graph()
  with g.as_default():
    in_image_tensor = tf.placeholder(
        tf.uint8, shape=[1, None, None, 3], name='image_tensor')
    tf.constant([2.0], name='num_detections')
    tf.constant(
        [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
        name='detection_boxes')
    tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')
    tf.identity(
        tf.constant([[1.0, 2.0, 3.0]]) *
        tf.reduce_sum(tf.cast(in_image_tensor, dtype=tf.float32)),
        name='detection_classes')
  graph_def = g.as_graph_def()
  # SerializeToString() returns binary protobuf bytes, so open the file in
  # binary mode rather than relying on GFile to coerce text-mode writes.
  with tf.gfile.Open(get_mock_graph_path(), 'wb') as fl:
    fl.write(graph_def.SerializeToString())
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class InferDetectionsTests(tf.test.TestCase):
  """Tests for detection_inference using graph-mode (TF1) queue-based input.

  Skipped under TF2 because the pipeline relies on placeholders, queue
  runners, `tf.python_io` and `test_session`.
  """

  # Detection features expected from running the mock graph on the mock
  # record. The mock graph reports num_detections=2.0, so only the first two
  # boxes/scores/classes appear. The labels are [1, 2] scaled by the mock
  # image's pixel sum (123).
  _EXPECTED_DETECTIONS_TEXT = r"""
      features {
        feature {
          key: "image/detection/bbox/ymin"
          value { float_list { value: [0.0, 0.1] } } }
        feature {
          key: "image/detection/bbox/xmin"
          value { float_list { value: [0.8, 0.2] } } }
        feature {
          key: "image/detection/bbox/ymax"
          value { float_list { value: [0.7, 0.8] } } }
        feature {
          key: "image/detection/bbox/xmax"
          value { float_list { value: [1.0, 0.9] } } }
        feature {
          key: "image/detection/label"
          value { int64_list { value: [123, 246] } } }
        feature {
          key: "image/detection/score"
          value { float_list { value: [0.1, 0.2] } } }
        feature {
          key: "test_field"
          value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
      """

  def _run_inference(self, discard_image_pixels):
    """Runs the inference pipeline on the mock record and graph.

    Args:
      discard_image_pixels: Passed as the last argument of
        `infer_detections_and_add_to_example`. With True the encoded image is
        absent from the output example; with False it is kept.

    Returns:
      A (tf_example, encoded_image) tuple: the example produced by inference
      and the PNG bytes written into the mock record.
    """
    create_mock_graph()
    encoded_image = create_mock_tfrecord()
    serialized_example_tensor, image_tensor = detection_inference.build_input(
        [get_mock_tfrecord_path()])
    self.assertAllEqual(image_tensor.get_shape().as_list(), [1, None, None, 3])
    (detected_boxes_tensor, detected_scores_tensor,
     detected_labels_tensor) = detection_inference.build_inference_graph(
         image_tensor, get_mock_graph_path())
    with self.test_session(use_gpu=False) as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      # build_input reads through queues, which need running queue threads.
      tf.train.start_queue_runners()
      tf_example = detection_inference.infer_detections_and_add_to_example(
          serialized_example_tensor, detected_boxes_tensor,
          detected_scores_tensor, detected_labels_tensor,
          discard_image_pixels)
    return tf_example, encoded_image

  def test_simple(self):
    tf_example, encoded_image = self._run_inference(discard_image_pixels=False)
    expected_example = tf.train.Example()
    text_format.Merge(self._EXPECTED_DETECTIONS_TEXT, expected_example)
    # With pixels kept, the original encoded image is carried through.
    expected_example.features.feature[
        standard_fields.TfExampleFields
        .image_encoded].CopyFrom(dataset_util.bytes_feature(encoded_image))
    self.assertProtoEquals(expected_example, tf_example)

  def test_discard_image(self):
    tf_example, _ = self._run_inference(discard_image_pixels=True)
    self.assertProtoEquals(self._EXPECTED_DETECTIONS_TEXT, tf_example)
if __name__ == '__main__':
  # Run this module's tests with TensorFlow's test runner.
  tf.test.main()