Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Detection model trainer. | |
This file provides a generic training method that can be used to train a | |
DetectionModel. | |
""" | |
import functools | |
import tensorflow.compat.v1 as tf | |
import tf_slim as slim | |
from object_detection.builders import optimizer_builder | |
from object_detection.builders import preprocessor_builder | |
from object_detection.core import batcher | |
from object_detection.core import preprocessor | |
from object_detection.core import standard_fields as fields | |
from object_detection.utils import ops as util_ops | |
from object_detection.utils import variables_helper | |
from deployment import model_deploy | |
def create_input_queue(batch_size_per_clone, create_tensor_dict_fn,
                       batch_queue_capacity, num_batch_queue_threads,
                       prefetch_queue_capacity, data_augmentation_options):
  """Decodes one example, optionally augments it, and wraps it in a queue.

  Args:
    batch_size_per_clone: number of examples per clone in each dequeued batch.
    create_tensor_dict_fn: zero-argument callable returning a dict of tensors
      keyed by `fields.InputDataFields` names; it must contain the image.
    batch_queue_capacity: maximum number of elements to store within a queue.
    num_batch_queue_threads: number of threads to use for batching.
    prefetch_queue_capacity: maximum capacity of the queue used to prefetch
      assembled batches.
    data_augmentation_options: a list of tuples, each holding a data
      augmentation function and a dict of its arguments and their values
      (see preprocessor.py). When empty, no augmentation is applied.

  Returns:
    A `batcher.BatchQueue` holding the enqueued tensor dict (images, boxes and
    targets). Call its `dequeue()` method to get a batch of tensor dicts.
  """
  image_field = fields.InputDataFields.image
  tensor_dict = create_tensor_dict_fn()

  # Give the image a leading axis of size 1 and convert it to float32 before
  # any augmentation runs.
  tensor_dict[image_field] = tf.cast(
      tf.expand_dims(tensor_dict[image_field], 0), dtype=tf.float32)

  if data_augmentation_options:
    # The include_* flags mirror which optional groundtruth fields are
    # present in this tensor dict; label weights are always included.
    arg_map = preprocessor.get_default_func_arg_map(
        include_label_weights=True,
        include_multiclass_scores=(
            fields.InputDataFields.multiclass_scores in tensor_dict),
        include_instance_masks=(
            fields.InputDataFields.groundtruth_instance_masks in tensor_dict),
        include_keypoints=(
            fields.InputDataFields.groundtruth_keypoints in tensor_dict))
    tensor_dict = preprocessor.preprocess(
        tensor_dict, data_augmentation_options, func_arg_map=arg_map)

  return batcher.BatchQueue(
      tensor_dict,
      batch_size=batch_size_per_clone,
      batch_queue_capacity=batch_queue_capacity,
      num_batch_queue_threads=num_batch_queue_threads,
      prefetch_queue_capacity=prefetch_queue_capacity)
def get_inputs(input_queue,
               num_classes,
               merge_multiple_label_boxes=False,
               use_multiclass_scores=False):
  """Dequeues a batch and constructs inputs to the object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.
    use_multiclass_scores: Whether to use multiclass scores instead of
      groundtruth_classes.

  Returns:
    A `zip` iterator yielding seven tuples, each holding one entry per
    dequeued example, in this order:
      images: image tensors as dequeued (`create_input_queue` adds a leading
        axis of size 1 and casts to float32 before enqueueing).
      image_keys: each example's source_id, or '' when it has none.
      locations_list: tensors of shape [num_boxes, 4] containing the corners
        of the groundtruth boxes (merged when merge_multiple_label_boxes).
      classes_list: float32 class targets. Groundtruth class ids are first
        shifted down by one. They are then padded one-hot encoded by default,
        K-hot after merging when merge_multiple_label_boxes is set, or
        replaced by the multiclass_scores field when use_multiclass_scores is
        set.
      masks_list: groundtruth instance masks, or None for examples without
        them.
      keypoints_list: groundtruth keypoints, or None for examples without
        them.
      weights_list: groundtruth box weights, or None for examples without
        them.

  Raises:
    ValueError: if merge_multiple_label_boxes and use_multiclass_scores are
      both set.
    NotImplementedError: if merge_multiple_label_boxes is set and a dequeued
      example carries instance masks or keypoints.
  """
  # The two options are mutually exclusive. This depends only on the
  # arguments, so check it before touching the queue.
  if merge_multiple_label_boxes and use_multiclass_scores:
    raise ValueError(
        'Using both merge_multiple_label_boxes and use_multiclass_scores is '
        'not supported')

  read_data_list = input_queue.dequeue()
  # Groundtruth class ids are shifted down by this offset before encoding.
  label_id_offset = 1

  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    image = read_data[fields.InputDataFields.image]
    key = read_data.get(fields.InputDataFields.source_id, '')
    location_gt = read_data[fields.InputDataFields.groundtruth_boxes]
    classes_gt = tf.cast(read_data[fields.InputDataFields.groundtruth_classes],
                         tf.int32)
    classes_gt -= label_id_offset
    if merge_multiple_label_boxes:
      location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
          location_gt, classes_gt, num_classes)
      classes_gt = tf.cast(classes_gt, tf.float32)
    elif use_multiclass_scores:
      classes_gt = tf.cast(read_data[fields.InputDataFields.multiclass_scores],
                           tf.float32)
    else:
      classes_gt = util_ops.padded_one_hot_encoding(
          indices=classes_gt, depth=num_classes, left_pad=0)
    masks_gt = read_data.get(fields.InputDataFields.groundtruth_instance_masks)
    keypoints_gt = read_data.get(fields.InputDataFields.groundtruth_keypoints)
    if (merge_multiple_label_boxes and
        (masks_gt is not None or keypoints_gt is not None)):
      raise NotImplementedError('Multi-label support is only for boxes.')
    weights_gt = read_data.get(fields.InputDataFields.groundtruth_weights)
    return (image, key, location_gt, classes_gt, masks_gt, keypoints_gt,
            weights_gt)

  # Transpose the per-example 7-tuples into seven per-field tuples.
  return zip(*map(extract_images_and_targets, read_data_list))
def _create_losses(input_queue, create_model_fn, train_config):
  """Builds a DetectionModel on one dequeued batch and registers its losses.

  A fresh model is created with `create_model_fn`, fed the inputs produced by
  `get_inputs`, and every tensor in the dict returned by its `loss` method is
  added via `tf.losses.add_loss`. Nothing is returned.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf; its
      merge_multiple_label_boxes and use_multiclass_scores fields are
      forwarded to `get_inputs`.
  """
  model = create_model_fn()
  (images, _, boxes_list, classes_list, masks_list, keypoints_list,
   weights_list) = get_inputs(
       input_queue,
       model.num_classes,
       train_config.merge_multiple_label_boxes,
       train_config.use_multiclass_scores)

  # Run the model's own preprocessing on every image, then concatenate the
  # resized images and their true shapes along axis 0.
  preprocess_outputs = [model.preprocess(image) for image in images]
  batched_images = tf.concat(
      [resized for resized, _ in preprocess_outputs], 0)
  true_image_shapes = tf.concat(
      [shape for _, shape in preprocess_outputs], 0)

  # Masks and keypoints are only forwarded when every example provides them.
  if not all(mask is not None for mask in masks_list):
    masks_list = None
  if not all(points is not None for points in keypoints_list):
    keypoints_list = None

  model.provide_groundtruth(
      boxes_list,
      classes_list,
      masks_list,
      keypoints_list,
      groundtruth_weights_list=weights_list)
  prediction_dict = model.predict(batched_images, true_image_shapes)
  for loss_tensor in model.loss(prediction_dict, true_image_shapes).values():
    tf.losses.add_loss(loss_tensor)
def _create_fine_tune_init_fn(detection_model, train_config):
  """Builds a session initializer that restores a fine-tune checkpoint.

  Called from `train` inside the training graph, after the clones are built.

  Args:
    detection_model: the DetectionModel whose `restore_map` selects the
      variables to restore.
    train_config: a train_pb2.TrainConfig protobuf. It is only read, never
      modified.

  Returns:
    None when `train_config.fine_tune_checkpoint` is empty. Otherwise a
    function taking a session that restores, from that checkpoint, the
    variables in `restore_map` kept by
    `variables_helper.get_variables_available_in_checkpoint`.
  """
  if not train_config.fine_tune_checkpoint:
    return None

  checkpoint_type = train_config.fine_tune_checkpoint_type
  if not checkpoint_type:
    # train_config.from_detection_checkpoint is deprecated. For backward
    # compatibility the checkpoint type is derived from it, into a local so
    # the caller's config is left untouched.
    if train_config.from_detection_checkpoint:
      checkpoint_type = 'detection'
    else:
      checkpoint_type = 'classification'

  var_map = detection_model.restore_map(
      fine_tune_checkpoint_type=checkpoint_type,
      load_all_detection_checkpoint_vars=(
          train_config.load_all_detection_checkpoint_vars))
  available_var_map = variables_helper.get_variables_available_in_checkpoint(
      var_map, train_config.fine_tune_checkpoint, include_global_step=False)
  init_saver = tf.train.Saver(available_var_map)

  def initializer_fn(sess):
    init_saver.restore(sess, train_config.fine_tune_checkpoint)

  return initializer_fn


def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the inference graph
      is built (before optimization). This is helpful to perform additional
      changes to the training graph such as adding FakeQuant ops. The function
      should modify the default graph.

  Raises:
    ValueError: If train_config.sync_replicas is true and num_clones is not 1.
  """
  # Reject incompatible settings before building the model or any graph.
  if num_clones != 1 and train_config.sync_replicas:
    raise ValueError('In Synchronous SGD mode num_clones must be 1. '
                     'Found num_clones: {}'.format(num_clones))

  detection_model = create_model_fn()
  data_augmentation_options = [
      preprocessor_builder.build(step)
      for step in train_config.data_augmentation_options]

  with tf.Graph().as_default():
    # Build a configuration specifying multi-GPU and multi-replicas.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step on the device storing the variables.
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    # Per-clone batch size: train_config.batch_size divided across clones
    # and, in sync mode, across the replicas being aggregated.
    batch_size = train_config.batch_size // num_clones
    if train_config.sync_replicas:
      batch_size //= train_config.replicas_to_aggregate

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(
          batch_size, create_tensor_dict_fn,
          train_config.batch_queue_capacity,
          train_config.num_batch_queue_threads,
          train_config.prefetch_queue_capacity, data_augmentation_options)

    # Gather initial summaries.
    # TODO(rathodv): See if summaries can be added/extracted from global tf
    # collections so that they don't have to be passed around.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set()

    model_fn = functools.partial(_create_losses,
                                 create_model_fn=create_model_fn,
                                 train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var, family='LearningRate')

    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=worker_replicas)
      sync_optimizer = training_optimizer

    with tf.device(deploy_config.optimizer_device()):
      # None lets optimize_clones pick up regularization losses; [] disables
      # them.
      regularization_losses = (None if train_config.add_regularization_loss
                               else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones, training_optimizer,
          regularization_losses=regularization_losses)
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally multiply bias gradients by train_config.bias_grad_multiplier.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally freeze some layers by setting their gradients to be zero.
      if train_config.freeze_variables:
        grads_and_vars = variables_helper.freeze_gradients_matching_regex(
            grads_and_vars, train_config.freeze_variables)

      # Optionally clip gradients.
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Create gradient updates; train_op only yields the loss once all
      # update ops (including the gradient step) have run.
      grad_updates = training_optimizer.apply_gradients(grads_and_vars,
                                                        global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Add summaries.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram('ModelVars/' +
                                                model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar('Losses/' + loss_tensor.op.name,
                                             loss_tensor))
    global_summaries.add(
        tf.summary.scalar('Losses/TotalLoss', tf.losses.get_total_loss()))

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))
    summaries |= global_summaries

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement allows placing on CPU ops without GPU implementation.
    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)

    # Save checkpoints regularly.
    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Create ops required to initialize the model from a given checkpoint.
    init_fn = _create_fine_tune_init_fn(detection_model, train_config)

    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        # A num_steps of 0 means "train indefinitely".
        number_of_steps=train_config.num_steps or None,
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)