# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Panoptic Quality metric.

Panoptic Quality is an instance-based metric for evaluating the task of
image parsing, aka panoptic segmentation. Please see the paper for details:
"Panoptic Segmentation", Alexander Kirillov, Kaiming He, Ross Girshick,
Carsten Rother and Piotr Dollar. arXiv:1801.00868, 2018.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
import prettytable
import six

from deeplab.evaluation import base_metric


def _ids_to_counts(id_array):
  """Given a numpy array, returns a mapping from unique entries to counts."""
  ids, counts = np.unique(id_array, return_counts=True)
  return dict(six.moves.zip(ids, counts))


class PanopticQuality(base_metric.SegmentationMetric):
  """Metric class for Panoptic Quality.

  "Panoptic Segmentation" by Alexander Kirillov, Kaiming He, Ross Girshick,
  Carsten Rother, Piotr Dollar.
  https://arxiv.org/abs/1801.00868
  """

  def compare_and_accumulate(
      self, groundtruth_category_array, groundtruth_instance_array,
      predicted_category_array, predicted_instance_array):
    """See base class."""
    # First, combine the category and instance labels so that every unique
    # value for (category, instance) is assigned a unique integer label.
    pred_segment_id = self._naively_combine_labels(predicted_category_array,
                                                   predicted_instance_array)
    gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
                                                 groundtruth_instance_array)

    # Pre-calculate areas for all groundtruth and predicted segments.
    gt_segment_areas = _ids_to_counts(gt_segment_id)
    pred_segment_areas = _ids_to_counts(pred_segment_id)

    # We assume there is only one void segment and it has instance id = 0.
    void_segment_id = self.ignored_label * self.max_instances_per_category

    # There may be other ignored groundtruth segments with instance id > 0;
    # find those ids using the unique segment ids extracted with the area
    # computation above.
    ignored_segment_ids = {
        gt_segment_id for gt_segment_id in six.iterkeys(gt_segment_areas)
        if (gt_segment_id //
            self.max_instances_per_category) == self.ignored_label
    }

    # Next, combine the groundtruth and predicted labels. Dividing up the
    # pixels based on which groundtruth segment and which predicted segment
    # they belong to, this will assign a different 32-bit integer label to
    # each choice of (groundtruth segment, predicted segment), encoded as
    # gt_segment_id * offset + pred_segment_id.
    intersection_id_array = (
        gt_segment_id.astype(np.uint32) * self.offset +
        pred_segment_id.astype(np.uint32))
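
    # For illustration only (these specific values are assumed, not fixed by
    # this module): with max_instances_per_category = 256 and offset = 256**2,
    # the groundtruth segment for (category=7, instance=2) has segment id
    # 7 * 256 + 2 = 1794, and its overlap with predicted segment id 1795 is
    # keyed by intersection id 1794 * 65536 + 1795. The matching loop below
    # recovers the pair with `// self.offset` and `% self.offset`.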

    # For every combination of (groundtruth segment, predicted segment) with a
    # non-empty intersection, this counts the number of pixels in that
    # intersection.
    intersection_areas = _ids_to_counts(intersection_id_array)

    # Helper function that computes the area of the overlap between a
    # predicted segment and the ground-truth void/ignored segment.
    def prediction_void_overlap(pred_segment_id):
      void_intersection_id = void_segment_id * self.offset + pred_segment_id
      return intersection_areas.get(void_intersection_id, 0)

    # Compute overall ignored overlap.
    def prediction_ignored_overlap(pred_segment_id):
      total_ignored_overlap = 0
      for ignored_segment_id in ignored_segment_ids:
        intersection_id = ignored_segment_id * self.offset + pred_segment_id
        total_ignored_overlap += intersection_areas.get(intersection_id, 0)
      return total_ignored_overlap

    # Sets that are populated with the groundtruth/predicted segments that
    # have been matched with an overlapping predicted/groundtruth segment,
    # respectively.
    gt_matched = set()
    pred_matched = set()

    # Calculate IoU per pair of intersecting segments of the same category.
    for intersection_id, intersection_area in six.iteritems(
        intersection_areas):
      gt_segment_id = intersection_id // self.offset
      pred_segment_id = intersection_id % self.offset

      gt_category = gt_segment_id // self.max_instances_per_category
      pred_category = pred_segment_id // self.max_instances_per_category
      if gt_category != pred_category:
        continue

      # Union between the groundtruth and predicted segments being compared
      # does not include the portion of the predicted segment that consists of
      # groundtruth "void" pixels.
      union = (
          gt_segment_areas[gt_segment_id] +
          pred_segment_areas[pred_segment_id] - intersection_area -
          prediction_void_overlap(pred_segment_id))
      iou = intersection_area / union
      if iou > 0.5:
        self.tp_per_class[gt_category] += 1
        self.iou_per_class[gt_category] += iou
        gt_matched.add(gt_segment_id)
        pred_matched.add(pred_segment_id)

    # Count false negatives for each category.
    for gt_segment_id in six.iterkeys(gt_segment_areas):
      if gt_segment_id in gt_matched:
        continue
      category = gt_segment_id // self.max_instances_per_category
      # Failing to detect a void segment is not a false negative.
      if category == self.ignored_label:
        continue
      self.fn_per_class[category] += 1

    # Count false positives for each category.
    for pred_segment_id in six.iterkeys(pred_segment_areas):
      if pred_segment_id in pred_matched:
        continue
      # A false positive is not penalized if it is mostly ignored in the
      # groundtruth.
      if (prediction_ignored_overlap(pred_segment_id) /
          pred_segment_areas[pred_segment_id]) > 0.5:
        continue
      category = pred_segment_id // self.max_instances_per_category
      self.fp_per_class[category] += 1

    return self.result()

  def _valid_categories(self):
    """Categories with a "valid" value for the metric, i.e. > 0 instances.

    We will ignore the `ignore_label` class and other classes which have
    `tp + fn + fp = 0`.

    Returns:
      Boolean array of shape `[num_categories]`.
    """
    valid_categories = np.not_equal(
        self.tp_per_class + self.fn_per_class + self.fp_per_class, 0)
    if self.ignored_label >= 0 and self.ignored_label < self.num_categories:
      valid_categories[self.ignored_label] = False
    return valid_categories
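
  # The per-class counters accumulated above decompose Panoptic Quality as
  #   SQ = (sum of IoU over matched pairs) / TP
  #   RQ = TP / (TP + 0.5 * FP + 0.5 * FN)
  #   PQ = SQ * RQ
  # which is what detailed_results() and result_per_category() compute below.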
  def detailed_results(self, is_thing=None):
    """See base class."""
    valid_categories = self._valid_categories()

    # If known, break down which categories are valid _and_ things/stuff.
    category_sets = collections.OrderedDict()
    category_sets['All'] = valid_categories
    if is_thing is not None:
      category_sets['Things'] = np.logical_and(valid_categories, is_thing)
      category_sets['Stuff'] = np.logical_and(valid_categories,
                                              np.logical_not(is_thing))

    # Compute individual per-class metrics that constitute factors of PQ.
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    pq = np.multiply(sq, rq)

    # Assemble detailed results dictionary.
    results = {}
    for category_set_name, in_category_set in six.iteritems(category_sets):
      if np.any(in_category_set):
        results[category_set_name] = {
            'pq': np.mean(pq[in_category_set]),
            'sq': np.mean(sq[in_category_set]),
            'rq': np.mean(rq[in_category_set]),
            # The number of categories in this subset.
            'n': np.sum(in_category_set.astype(np.int32)),
        }
      else:
        results[category_set_name] = {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0}

    return results

  def result_per_category(self):
    """See base class."""
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    return np.multiply(sq, rq)

  def print_detailed_results(self, is_thing=None, print_digits=3):
    """See base class."""
    results = self.detailed_results(is_thing=is_thing)

    tab = prettytable.PrettyTable()

    tab.add_column('', [], align='l')
    for fieldname in ['PQ', 'SQ', 'RQ', 'N']:
      tab.add_column(fieldname, [], align='r')

    for category_set, subset_results in six.iteritems(results):
      data_cols = [
          round(subset_results[col_key], print_digits) * 100
          for col_key in ['pq', 'sq', 'rq']
      ]
      data_cols += [subset_results['n']]
      tab.add_row([category_set] + data_cols)

    print(tab)

  def result(self):
    """See base class."""
    pq_per_class = self.result_per_category()
    valid_categories = self._valid_categories()
    if not np.any(valid_categories):
      return 0.
    return np.mean(pq_per_class[valid_categories])

  def merge(self, other_instance):
    """See base class."""
    self.iou_per_class += other_instance.iou_per_class
    self.tp_per_class += other_instance.tp_per_class
    self.fn_per_class += other_instance.fn_per_class
    self.fp_per_class += other_instance.fp_per_class

  def reset(self):
    """See base class."""
    self.iou_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.tp_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fn_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fp_per_class = np.zeros(self.num_categories, dtype=np.float64)
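

# A minimal usage sketch, kept as a comment because the constructor signature
# comes from base_metric.SegmentationMetric and is assumed here to be
# (num_categories, ignored_label, max_instances_per_category, offset); the
# array values are illustrative only:
#
#   import numpy as np
#   from deeplab.evaluation import panoptic_quality
#
#   pq = panoptic_quality.PanopticQuality(
#       num_categories=3, ignored_label=0, max_instances_per_category=256,
#       offset=256 * 256)
#   categories = np.array([[1, 1], [1, 2]], dtype=np.uint16)
#   instances = np.array([[0, 0], [0, 0]], dtype=np.uint16)
#   # Comparing a prediction to an identical groundtruth gives IoU = 1 for
#   # every segment, so PQ = SQ = RQ = 1 for the categories present.
#   pq.compare_and_accumulate(categories, instances, categories, instances)
#   pq.print_detailed_results()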