Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.label_map_util.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import numpy as np | |
from six.moves import range | |
import tensorflow.compat.v1 as tf | |
from google.protobuf import text_format | |
from object_detection.protos import string_int_label_map_pb2 | |
from object_detection.utils import label_map_util | |
class LabelMapUtilTest(tf.test.TestCase):
  """Tests for the label map helpers exposed by `label_map_util`."""

  # ---------------------------------------------------------------------------
  # Fixtures
  # ---------------------------------------------------------------------------

  def _generate_label_map(self, num_classes):
    """Builds a label map proto with ids 1..num_classes.

    Args:
      num_classes: number of items to add to the label map.

    Returns:
      A StringIntLabelMap whose i-th item has id `i`, name 'label_<i>' and
      display_name '<i>'.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def _generate_label_map_with_hierarchy(self, num_classes, ancestors_dict,
                                         descendants_dict):
    """Builds a label map proto annotated with ancestor/descendant ids.

    Args:
      num_classes: number of items to add to the label map.
      ancestors_dict: maps an item id to the ids to store in `ancestor_ids`.
      descendants_dict: maps an item id to the ids to store in
        `descendant_ids`.

    Returns:
      A StringIntLabelMap like the one from `_generate_label_map`, with the
      hierarchy fields filled in for every id present in the two dicts.
    """
    label_map_proto = self._generate_label_map(num_classes)
    for item in label_map_proto.item:
      if item.id in ancestors_dict:
        item.ancestor_ids.extend(ancestors_dict[item.id])
      if item.id in descendants_dict:
        item.descendant_ids.extend(descendants_dict[item.id])
    return label_map_proto

  def _write_label_map_file(self, label_map_string):
    """Writes a text-format label map into the test temp directory.

    Args:
      label_map_string: text-format StringIntLabelMap contents.

    Returns:
      Path of the written 'label_map.pbtxt' file.
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    return label_map_path

  # ---------------------------------------------------------------------------
  # get_label_map_dict / load_labelmap
  # ---------------------------------------------------------------------------

  def test_get_label_map_dict(self):
    """A label map file is turned into a name -> id dictionary."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_from_proto(self):
    """An already-parsed proto is accepted in place of a file path."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_proto = text_format.Parse(
        label_map_string, string_int_label_map_pb2.StringIntLabelMap())

    label_map_dict = label_map_util.get_label_map_dict(label_map_proto)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    """With use_display_name=True the dictionary is keyed by display_name."""
    label_map_string = """
      item {
        id:2
        display_name:'cat'
      }
      item {
        id:1
        display_name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    """Loading a map whose id 0 item is not 'background' must raise."""
    label_map_string = """
      item {
        id:0
        name:'class that should not be indexed at zero'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_load_label_map_with_background(self):
    """An id 0 item named 'background' is accepted and mapped to 0."""
    label_map_string = """
      item {
        id:0
        name:'background'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
    """Missing ids and the background entry are filled in on request."""
    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, fill_in_gaps_and_background=True)

    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    # The gap at id 2 is expected to be keyed by the id rendered as a string.
    self.assertEqual(label_map_dict['2'], 2)
    self.assertEqual(label_map_dict['cat'], 3)
    # Ids must form a dense 0..max range once gaps are filled.
    self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)

  # ---------------------------------------------------------------------------
  # convert_label_map_to_categories
  # ---------------------------------------------------------------------------

  def test_keep_categories_with_unique_id(self):
    """Only the first item seen for a repeated id is kept."""
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'child'
      }
      item {
        id:1
        name:'person'
      }
      item {
        id:1
        name:'n00007846'
      }
    """
    text_format.Merge(label_map_string, label_map_proto)

    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    """Without a label map, placeholder 'category_<id>' entries are made."""
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories(self):
    """Items with an id above max_num_classes are left out."""
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_with_keypoints_to_categories(self):
    """Keypoints are exposed as a label -> id mapping on the category."""
    label_map_str = """
      item {
        id: 1
        name: 'person'
        keypoints: {
          id: 1
          label: 'nose'
        }
        keypoints: {
          id: 2
          label: 'ear'
        }
      }
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    text_format.Merge(label_map_str, label_map_proto)

    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=1)
    self.assertEqual('person', categories[0]['name'])
    self.assertEqual(1, categories[0]['id'])
    self.assertEqual(1, categories[0]['keypoints']['nose'])
    self.assertEqual(2, categories[0]['keypoints']['ear'])

  def test_disallow_duplicate_keypoint_ids(self):
    """Two keypoints sharing an id must be rejected with ValueError."""
    label_map_str = """
      item {
        id: 1
        name: 'person'
        keypoints: {
          id: 1
          label: 'right_elbow'
        }
        keypoints: {
          id: 1
          label: 'left_elbow'
        }
      }
      item {
        id: 2
        name: 'face'
        keypoints: {
          id: 3
          label: 'ear'
        }
      }
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    text_format.Merge(label_map_str, label_map_proto)

    with self.assertRaises(ValueError):
      label_map_util.convert_label_map_to_categories(
          label_map_proto, max_num_classes=2)

  def test_convert_label_map_to_categories_with_few_classes(self):
    """max_num_classes smaller than the map keeps only the lowest ids."""
    label_map_proto = self._generate_label_map(num_classes=4)
    cat_no_offset = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, cat_no_offset)

  # ---------------------------------------------------------------------------
  # Index / category construction
  # ---------------------------------------------------------------------------

  def test_get_max_label_map_index(self):
    """The largest id in the label map is reported."""
    num_classes = 4
    label_map_proto = self._generate_label_map(num_classes=num_classes)
    max_index = label_map_util.get_max_label_map_index(label_map_proto)
    self.assertEqual(num_classes, max_index)

  def test_create_category_index(self):
    """A category list is re-keyed by category id."""
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_categories_from_labelmap(self):
    """Categories are read straight from a label map file."""
    label_map_string = """
      item {
        id:1
        name:'dog'
      }
      item {
        id:2
        name:'cat'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    categories = label_map_util.create_categories_from_labelmap(label_map_path)
    self.assertListEqual([{
        'name': u'dog',
        'id': 1
    }, {
        'name': u'cat',
        'id': 2
    }], categories)

  def test_create_category_index_from_labelmap(self):
    """A category index is built directly from a label map file."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap_display(self):
    """The second positional argument toggles name vs. display_name."""
    label_map_string = """
      item {
        id:2
        name:'cat'
        display_name:'meow'
      }
      item {
        id:1
        name:'dog'
        display_name:'woof'
      }
    """
    label_map_path = self._write_label_map_file(label_map_string)

    # Passing False as the second argument: entries carry the `name` field.
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(
        label_map_path, False))

    # Default call: entries carry the `display_name` field instead.
    self.assertDictEqual({
        1: {
            'name': u'woof',
            'id': 1
        },
        2: {
            'name': u'meow',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(label_map_path))

  # ---------------------------------------------------------------------------
  # Hierarchy look-up tables
  # ---------------------------------------------------------------------------

  def test_get_label_map_hierarchy_lut(self):
    """Ancestor/descendant tables match the hierarchy set on the proto."""
    num_classes = 5
    ancestors = {2: [1, 3], 5: [1]}
    descendants = {1: [2], 5: [1, 2]}
    label_map = self._generate_label_map_with_hierarchy(num_classes, ancestors,
                                                        descendants)
    # Row r describes class id r + 1; each expected row has a 1 on the
    # diagonal plus a 1 in column c - 1 for every related class id c.
    gt_hierarchy_dict_lut = {
        'ancestors':
            np.array([
                [1, 0, 0, 0, 0],
                [1, 1, 1, 0, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 1, 0],
                [1, 0, 0, 0, 1],
            ]),
        'descendants':
            np.array([
                [1, 1, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 0, 1, 0, 0],
                [0, 0, 0, 1, 0],
                [1, 1, 0, 0, 1],
            ]),
    }
    ancestors_lut, descendants_lut = (
        label_map_util.get_label_map_hierarchy_lut(label_map, True))
    np.testing.assert_array_equal(gt_hierarchy_dict_lut['ancestors'],
                                  ancestors_lut)
    np.testing.assert_array_equal(gt_hierarchy_dict_lut['descendants'],
                                  descendants_lut)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()