# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Label map utility functions.""" import logging import tensorflow as tf from google.protobuf import text_format import string_int_label_map_pb2 def _validate_label_map(label_map): """Checks if a label map is valid. Args: label_map: StringIntLabelMap to validate. Raises: ValueError: if label map is invalid. """ for item in label_map.item: if item.id < 0: raise ValueError('Label map ids should be >= 0.') if (item.id == 0 and item.name != 'background' and item.display_name != 'background'): raise ValueError('Label map id 0 is reserved for the background label') def create_category_index(categories): """Creates dictionary of COCO compatible categories keyed by category id. Args: categories: a list of dicts, each of which has the following keys: 'id': (required) an integer id uniquely identifying this category. 'name': (required) string representing category name e.g., 'cat', 'dog', 'pizza'. Returns: category_index: a dict containing the same entries as categories, but keyed by the 'id' field of each category. """ category_index = {} for cat in categories: category_index[cat['id']] = cat return category_index def get_max_label_map_index(label_map): """Get maximum index in label map. Args: label_map: a StringIntLabelMapProto Returns: an integer """ return max([item.id for item in label_map.item]) def convert_label_map_to_categories(label_map, max_num_classes, use_display_name=True): """Loads label map proto and returns categories list compatible with eval. This function loads a label map and returns a list of dicts, each of which has the following keys: 'id': (required) an integer id uniquely identifying this category. 'name': (required) string representing category name e.g., 'cat', 'dog', 'pizza'. We only allow class into the list if its id-label_id_offset is between 0 (inclusive) and max_num_classes (exclusive). If there are several items mapping to the same id in the label map, we will only keep the first one in the categories list. Args: label_map: a StringIntLabelMapProto or None. If None, a default categories list is created with max_num_classes categories. max_num_classes: maximum number of (consecutive) label indices to include. use_display_name: (boolean) choose whether to load 'display_name' field as category name. If False or if the display_name field does not exist, uses 'name' field as category names instead. Returns: categories: a list of dictionaries representing all possible categories. """ categories = [] list_of_ids_already_added = [] if not label_map: label_id_offset = 1 for class_id in range(max_num_classes): categories.append({ 'id': class_id + label_id_offset, 'name': 'category_{}'.format(class_id + label_id_offset) }) return categories for item in label_map.item: if not 0 < item.id <= max_num_classes: logging.info('Ignore item %d since it falls outside of requested ' 'label range.', item.id) continue if use_display_name and item.HasField('display_name'): name = item.display_name else: name = item.name if item.id not in list_of_ids_already_added: list_of_ids_already_added.append(item.id) categories.append({'id': item.id, 'name': name}) return categories def load_labelmap(path): """Loads label map proto. Args: path: path to StringIntLabelMap proto text file. Returns: a StringIntLabelMapProto """ with tf.gfile.GFile(path, 'r') as fid: label_map_string = fid.read() label_map = string_int_label_map_pb2.StringIntLabelMap() try: text_format.Merge(label_map_string, label_map) except text_format.ParseError: label_map.ParseFromString(label_map_string) _validate_label_map(label_map) return label_map def get_label_map_dict(label_map_path, use_display_name=False): """Reads a label map and returns a dictionary of label names to id. Args: label_map_path: path to label_map. use_display_name: whether to use the label map items' display names as keys. Returns: A dictionary mapping label names to id. """ label_map = load_labelmap(label_map_path) label_map_dict = {} for item in label_map.item: if use_display_name: label_map_dict[item.display_name] = item.id else: label_map_dict[item.name] = item.id return label_map_dict def create_category_index_from_labelmap(label_map_path): """Reads a label map and returns a category index. Args: label_map_path: Path to `StringIntLabelMap` proto text file. Returns: A category index, which is a dictionary that maps integer ids to dicts containing categories, e.g. {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} """ label_map = load_labelmap(label_map_path) max_num_classes = max(item.id for item in label_map.item) categories = convert_label_map_to_categories(label_map, max_num_classes) return create_category_index(categories) def create_class_agnostic_category_index(): """Creates a category index with a single `object` class.""" return {1: {'id': 1, 'name': 'object'}}