# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection model evaluator. | |
This file provides a generic evaluation method that can be used to evaluate a | |
DetectionModel. | |
""" | |
import logging | |
import tensorflow.compat.v1 as tf | |
from object_detection import eval_util | |
from object_detection.core import prefetcher | |
from object_detection.core import standard_fields as fields | |
from object_detection.metrics import coco_evaluation | |
from object_detection.utils import object_detection_evaluation | |

# A dictionary of metric names to classes that implement the metric. The
# classes in the dictionary must implement the
# utils.object_detection_evaluation.DetectionEvaluator interface.
EVAL_METRICS_CLASS_DICT = {
    'pascal_voc_detection_metrics':
        object_detection_evaluation.PascalDetectionEvaluator,
    'weighted_pascal_voc_detection_metrics':
        object_detection_evaluation.WeightedPascalDetectionEvaluator,
    'pascal_voc_instance_segmentation_metrics':
        object_detection_evaluation.PascalInstanceSegmentationEvaluator,
    'weighted_pascal_voc_instance_segmentation_metrics':
        object_detection_evaluation.WeightedPascalInstanceSegmentationEvaluator,
    'oid_V2_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionEvaluator,
    # DEPRECATED: please use oid_V2_detection_metrics instead
    'open_images_V2_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionEvaluator,
    'coco_detection_metrics':
        coco_evaluation.CocoDetectionEvaluator,
    'coco_mask_metrics':
        coco_evaluation.CocoMaskEvaluator,
    'oid_challenge_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
    # DEPRECATED: please use oid_challenge_detection_metrics instead
    'oid_challenge_object_detection_metrics':
        object_detection_evaluation.OpenImagesDetectionChallengeEvaluator,
    'oid_challenge_segmentation_metrics':
        object_detection_evaluation
        .OpenImagesInstanceSegmentationChallengeEvaluator,
}

EVAL_DEFAULT_METRIC = 'pascal_voc_detection_metrics'
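
# Illustrative sketch (not part of the upstream module): an evaluator is built
# by looking up a metric name in EVAL_METRICS_CLASS_DICT and instantiating the
# class with the category list. The `categories` value below is a made-up
# example; get_evaluators() further down performs this lookup for real.
#
#   categories = [{'id': 1, 'name': 'dog'}, {'id': 2, 'name': 'cat'}]
#   evaluator = EVAL_METRICS_CLASS_DICT['coco_detection_metrics'](
#       categories=categories)
#   # `evaluator` is a DetectionEvaluator: it accumulates groundtruth and
#   # detections per image and produces metrics via its evaluate() method.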


def _extract_predictions_and_losses(model,
                                    create_input_dict_fn,
                                    ignore_groundtruth=False):
  """Constructs tensorflow detection graph and returns output tensors.

  Args:
    model: model to perform predictions with.
    create_input_dict_fn: function to create input tensor dictionaries.
    ignore_groundtruth: whether groundtruth should be ignored.

  Returns:
    prediction_groundtruth_dict: A dictionary with postprocessed tensors (keyed
      by standard_fields.DetectionResultFields) and optional groundtruth
      tensors (keyed by standard_fields.InputDataFields).
    losses_dict: A dictionary containing detection losses. This is empty when
      ignore_groundtruth is true.
  """
  input_dict = create_input_dict_fn()
  prefetch_queue = prefetcher.prefetch(input_dict, capacity=500)
  input_dict = prefetch_queue.dequeue()
  original_image = tf.expand_dims(input_dict[fields.InputDataFields.image], 0)
  preprocessed_image, true_image_shapes = model.preprocess(
      tf.cast(original_image, dtype=tf.float32))
  prediction_dict = model.predict(preprocessed_image, true_image_shapes)
  detections = model.postprocess(prediction_dict, true_image_shapes)

  groundtruth = None
  losses_dict = {}
  if not ignore_groundtruth:
    groundtruth = {
        fields.InputDataFields.groundtruth_boxes:
            input_dict[fields.InputDataFields.groundtruth_boxes],
        fields.InputDataFields.groundtruth_classes:
            input_dict[fields.InputDataFields.groundtruth_classes],
        fields.InputDataFields.groundtruth_area:
            input_dict[fields.InputDataFields.groundtruth_area],
        fields.InputDataFields.groundtruth_is_crowd:
            input_dict[fields.InputDataFields.groundtruth_is_crowd],
        fields.InputDataFields.groundtruth_difficult:
            input_dict[fields.InputDataFields.groundtruth_difficult]
    }
    if fields.InputDataFields.groundtruth_group_of in input_dict:
      groundtruth[fields.InputDataFields.groundtruth_group_of] = (
          input_dict[fields.InputDataFields.groundtruth_group_of])
    groundtruth_masks_list = None
    if fields.DetectionResultFields.detection_masks in detections:
      groundtruth[fields.InputDataFields.groundtruth_instance_masks] = (
          input_dict[fields.InputDataFields.groundtruth_instance_masks])
      groundtruth_masks_list = [
          input_dict[fields.InputDataFields.groundtruth_instance_masks]]
    groundtruth_keypoints_list = None
    if fields.DetectionResultFields.detection_keypoints in detections:
      groundtruth[fields.InputDataFields.groundtruth_keypoints] = (
          input_dict[fields.InputDataFields.groundtruth_keypoints])
      groundtruth_keypoints_list = [
          input_dict[fields.InputDataFields.groundtruth_keypoints]]
    label_id_offset = 1
    model.provide_groundtruth(
        [input_dict[fields.InputDataFields.groundtruth_boxes]],
        [tf.one_hot(input_dict[fields.InputDataFields.groundtruth_classes]
                    - label_id_offset, depth=model.num_classes)],
        groundtruth_masks_list, groundtruth_keypoints_list)
    losses_dict.update(model.loss(prediction_dict, true_image_shapes))

  result_dict = eval_util.result_dict_for_single_example(
      original_image,
      input_dict[fields.InputDataFields.source_id],
      detections,
      groundtruth,
      class_agnostic=(
          fields.DetectionResultFields.detection_classes not in detections),
      scale_to_absolute=True)
  return result_dict, losses_dict
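
# Illustrative sketch (assumptions noted in the comments): the two dictionaries
# returned above hold *tensors*, not values. A caller is expected to run them
# in a session once per example and hand the resulting numpy arrays to the
# evaluators, which is what _process_batch in evaluate() below does. Roughly:
#
#   tensor_dict, losses_dict = _extract_predictions_and_losses(
#       model=model, create_input_dict_fn=create_input_dict_fn)
#   with tf.Session() as sess:
#     result_dict, result_losses = sess.run([tensor_dict, losses_dict])
#
# Checkpoint restoring, queue runners and error handling are omitted here; the
# real loop lives in eval_util.repeated_checkpoint_run.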


def get_evaluators(eval_config, categories):
  """Returns evaluators, based on eval_config, valid for the given categories.

  Args:
    eval_config: evaluation configurations.
    categories: a list of categories to evaluate.

  Returns:
    A list of instances of DetectionEvaluator.

  Raises:
    ValueError: if metric is not in the metric class dictionary.
  """
  eval_metric_fn_keys = eval_config.metrics_set
  if not eval_metric_fn_keys:
    eval_metric_fn_keys = [EVAL_DEFAULT_METRIC]
  evaluators_list = []
  for eval_metric_fn_key in eval_metric_fn_keys:
    if eval_metric_fn_key not in EVAL_METRICS_CLASS_DICT:
      raise ValueError('Metric not found: {}'.format(eval_metric_fn_key))
    if eval_metric_fn_key == 'oid_challenge_object_detection_metrics':
      logging.warning(
          'oid_challenge_object_detection_metrics is deprecated; '
          'use oid_challenge_detection_metrics instead')
    if eval_metric_fn_key == 'open_images_V2_detection_metrics':
      logging.warning(
          'open_images_V2_detection_metrics is deprecated; '
          'use oid_V2_detection_metrics instead')
    evaluators_list.append(
        EVAL_METRICS_CLASS_DICT[eval_metric_fn_key](categories=categories))
  return evaluators_list
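
# Illustrative sketch (assumes the EvalConfig message from
# object_detection/protos/eval.proto; the values below are made up):
#
#   from object_detection.protos import eval_pb2
#
#   eval_config = eval_pb2.EvalConfig()
#   eval_config.metrics_set.append('coco_detection_metrics')
#   categories = [{'id': 1, 'name': 'dog'}, {'id': 2, 'name': 'cat'}]
#   evaluators = get_evaluators(eval_config, categories)
#
# With an empty metrics_set, get_evaluators() falls back to
# EVAL_DEFAULT_METRIC ('pascal_voc_detection_metrics').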


def evaluate(create_input_dict_fn, create_model_fn, eval_config, categories,
             checkpoint_dir, eval_dir, graph_hook_fn=None, evaluator_list=None):
  """Evaluation function for detection models.

  Args:
    create_input_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel.
    eval_config: an eval_pb2.EvalConfig protobuf.
    categories: a list of category dictionaries. Each dict in the list should
      have an integer 'id' field and string 'name' field.
    checkpoint_dir: directory to load the checkpoints to evaluate from.
    eval_dir: directory to write evaluation metrics summary to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
    evaluator_list: Optional list of instances of DetectionEvaluator. If not
      given, this list of metrics is created according to the eval_config.

  Returns:
    metrics: A dictionary containing metric names and values from the latest
      run.
  """
  model = create_model_fn()

  if eval_config.ignore_groundtruth and not eval_config.export_path:
    logging.fatal('If ignore_groundtruth=True then an export_path is '
                  'required. Aborting!!!')

  tensor_dict, losses_dict = _extract_predictions_and_losses(
      model=model,
      create_input_dict_fn=create_input_dict_fn,
      ignore_groundtruth=eval_config.ignore_groundtruth)

  def _process_batch(tensor_dict, sess, batch_index, counters,
                     losses_dict=None):
    """Evaluates tensors in tensor_dict, losses_dict and visualizes examples.

    This function calls sess.run on tensor_dict, evaluating the original_image
    tensor only on the first K examples and visualizing detections overlaid
    on this original_image.

    Args:
      tensor_dict: a dictionary of tensors
      sess: tensorflow session
      batch_index: the index of the batch amongst all batches in the run.
      counters: a dictionary holding 'success' and 'skipped' fields which can
        be updated to keep track of number of successful and failed runs,
        respectively. If these fields are not updated, then the success/skipped
        counter values shown at the end of evaluation will be incorrect.
      losses_dict: Optional dictionary of scalar loss tensors.

    Returns:
      result_dict: a dictionary of numpy arrays
      result_losses_dict: a dictionary of scalar losses. This is empty if input
        losses_dict is None.
    """
    try:
      if not losses_dict:
        losses_dict = {}
      result_dict, result_losses_dict = sess.run([tensor_dict, losses_dict])
      counters['success'] += 1
    except tf.errors.InvalidArgumentError:
      logging.info('Skipping image')
      counters['skipped'] += 1
      return {}, {}
    global_step = tf.train.global_step(sess, tf.train.get_global_step())
    if batch_index < eval_config.num_visualizations:
      tag = 'image-{}'.format(batch_index)
      eval_util.visualize_detection_results(
          result_dict,
          tag,
          global_step,
          categories=categories,
          summary_dir=eval_dir,
          export_dir=eval_config.visualization_export_dir,
          show_groundtruth=eval_config.visualize_groundtruth_boxes,
          groundtruth_box_visualization_color=eval_config.
          groundtruth_box_visualization_color,
          min_score_thresh=eval_config.min_score_threshold,
          max_num_predictions=eval_config.max_num_boxes_to_visualize,
          skip_scores=eval_config.skip_scores,
          skip_labels=eval_config.skip_labels,
          keep_image_id_for_visualization_export=eval_config.
          keep_image_id_for_visualization_export)
    return result_dict, result_losses_dict
  if graph_hook_fn:
    graph_hook_fn()

  variables_to_restore = tf.global_variables()
  global_step = tf.train.get_or_create_global_step()
  variables_to_restore.append(global_step)

  if eval_config.use_moving_averages:
    variable_averages = tf.train.ExponentialMovingAverage(0.0)
    variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  def _restore_latest_checkpoint(sess):
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    saver.restore(sess, latest_checkpoint)

  if not evaluator_list:
    evaluator_list = get_evaluators(eval_config, categories)

  metrics = eval_util.repeated_checkpoint_run(
      tensor_dict=tensor_dict,
      summary_dir=eval_dir,
      evaluators=evaluator_list,
      batch_processor=_process_batch,
      checkpoint_dirs=[checkpoint_dir],
      variables_to_restore=None,
      restore_fn=_restore_latest_checkpoint,
      num_batches=eval_config.num_examples,
      eval_interval_secs=eval_config.eval_interval_secs,
      max_number_of_evaluations=(1 if eval_config.ignore_groundtruth else
                                 eval_config.max_evals
                                 if eval_config.max_evals else None),
      master=eval_config.eval_master,
      save_graph=eval_config.save_graph,
      save_graph_dir=(eval_dir if eval_config.save_graph else ''),
      losses_dict=losses_dict,
      eval_export_path=eval_config.export_path)

  return metrics
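
# Illustrative sketch of how this module is typically driven (the legacy
# eval.py binary does roughly this). `model_config`, `get_next_fn`,
# `eval_config`, `categories` and the directory paths are placeholders, not
# values defined in this file:
#
#   import functools
#   from object_detection.builders import model_builder
#
#   create_model_fn = functools.partial(
#       model_builder.build, model_config=model_config, is_training=False)
#   metrics = evaluate(
#       create_input_dict_fn=get_next_fn,  # yields one input tensor dict
#       create_model_fn=create_model_fn,
#       eval_config=eval_config,
#       categories=categories,
#       checkpoint_dir='/path/to/train_dir',
#       eval_dir='/path/to/eval_dir')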