# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for reading and updating configuration files.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import tensorflow.compat.v1 as tf | |
from google.protobuf import text_format | |
from tensorflow.python.lib.io import file_io | |
from object_detection.protos import eval_pb2 | |
from object_detection.protos import graph_rewriter_pb2 | |
from object_detection.protos import input_reader_pb2 | |
from object_detection.protos import model_pb2 | |
from object_detection.protos import pipeline_pb2 | |
from object_detection.protos import train_pb2 | |


def get_image_resizer_config(model_config):
  """Returns the image resizer config from a model config.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    An image_resizer_pb2.ImageResizer.

  Raises:
    ValueError: If the model type is not recognized.
  """
  meta_architecture = model_config.WhichOneof("model")
  meta_architecture_config = getattr(model_config, meta_architecture)

  if hasattr(meta_architecture_config, "image_resizer"):
    return getattr(meta_architecture_config, "image_resizer")
  else:
    raise ValueError("{} has no image_resizer_config".format(
        meta_architecture))


def get_spatial_image_size(image_resizer_config):
  """Returns expected spatial size of the output image from a given config.

  Args:
    image_resizer_config: An image_resizer_pb2.ImageResizer.

  Returns:
    A list of two integers of the form [height, width]. `height` and `width`
    are set to -1 if they cannot be determined during graph construction.

  Raises:
    ValueError: If the model type is not recognized.
  """
  if image_resizer_config.HasField("fixed_shape_resizer"):
    return [
        image_resizer_config.fixed_shape_resizer.height,
        image_resizer_config.fixed_shape_resizer.width
    ]
  if image_resizer_config.HasField("keep_aspect_ratio_resizer"):
    if image_resizer_config.keep_aspect_ratio_resizer.pad_to_max_dimension:
      return [image_resizer_config.keep_aspect_ratio_resizer.max_dimension] * 2
    else:
      return [-1, -1]
  if image_resizer_config.HasField(
      "identity_resizer") or image_resizer_config.HasField(
          "conditional_shape_resizer"):
    return [-1, -1]
  raise ValueError("Unknown image resizer type.")


def get_max_num_context_features(model_config):
  """Returns maximum number of context features from a given config.

  Args:
    model_config: A model config file.

  Returns:
    An integer specifying the max number of context features if the model
    config contains context_config, None otherwise.
  """
  meta_architecture = model_config.WhichOneof("model")
  meta_architecture_config = getattr(model_config, meta_architecture)

  if hasattr(meta_architecture_config, "context_config"):
    return meta_architecture_config.context_config.max_num_context_features


def get_context_feature_length(model_config):
  """Returns context feature length from a given config.

  Args:
    model_config: A model config file.

  Returns:
    An integer specifying the fixed length of each feature in context_features.
  """
  meta_architecture = model_config.WhichOneof("model")
  meta_architecture_config = getattr(model_config, meta_architecture)

  if hasattr(meta_architecture_config, "context_config"):
    return meta_architecture_config.context_config.context_feature_length


def get_configs_from_pipeline_file(pipeline_config_path, config_override=None):
  """Reads config from a file containing pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config_path: Path to pipeline_pb2.TrainEvalPipelineConfig text
      proto.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override pipeline_config_path.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_config`. Values are the
      corresponding config objects.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(pipeline_config_path, "r") as f:
    proto_str = f.read()
    text_format.Merge(proto_str, pipeline_config)
  if config_override:
    text_format.Merge(config_override, pipeline_config)
  return create_configs_from_pipeline_proto(pipeline_config)
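
# Illustrative usage (a sketch, not part of the library API; the path and the
# override string below are hypothetical):
#
#   configs = get_configs_from_pipeline_file(
#       "/tmp/pipeline.config",
#       config_override="train_config { batch_size: 8 }")
#   model_config = configs["model"]
#   train_config = configs["train_config"]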


def create_configs_from_pipeline_proto(pipeline_config):
  """Creates a configs dictionary from pipeline_pb2.TrainEvalPipelineConfig.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_configs`. Values are
      the corresponding config objects or list of config objects (only for
      eval_input_configs).
  """
  configs = {}
  configs["model"] = pipeline_config.model
  configs["train_config"] = pipeline_config.train_config
  configs["train_input_config"] = pipeline_config.train_input_reader
  configs["eval_config"] = pipeline_config.eval_config
  configs["eval_input_configs"] = pipeline_config.eval_input_reader
  # Keeps eval_input_config only for backwards compatibility. All clients
  # should read eval_input_configs instead.
  if configs["eval_input_configs"]:
    configs["eval_input_config"] = configs["eval_input_configs"][0]
  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs


def get_graph_rewriter_config_from_file(graph_rewriter_config_file):
  """Parses config for graph rewriter.

  Args:
    graph_rewriter_config_file: file path to the graph rewriter config.

  Returns:
    graph_rewriter_pb2.GraphRewriter proto.
  """
  graph_rewriter_config = graph_rewriter_pb2.GraphRewriter()
  with tf.gfile.GFile(graph_rewriter_config_file, "r") as f:
    text_format.Merge(f.read(), graph_rewriter_config)
  return graph_rewriter_config


def create_pipeline_proto_from_configs(configs):
  """Creates a pipeline_pb2.TrainEvalPipelineConfig from configs dictionary.

  This function performs the inverse operation of
  create_configs_from_pipeline_proto().

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    A fully populated pipeline_pb2.TrainEvalPipelineConfig.
  """
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  pipeline_config.model.CopyFrom(configs["model"])
  pipeline_config.train_config.CopyFrom(configs["train_config"])
  pipeline_config.train_input_reader.CopyFrom(configs["train_input_config"])
  pipeline_config.eval_config.CopyFrom(configs["eval_config"])
  pipeline_config.eval_input_reader.extend(configs["eval_input_configs"])
  if "graph_rewriter_config" in configs:
    pipeline_config.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
  return pipeline_config


def save_pipeline_config(pipeline_config, directory):
  """Saves a pipeline config text file to disk.

  Args:
    pipeline_config: A pipeline_pb2.TrainEvalPipelineConfig.
    directory: The model directory into which the pipeline config file will be
      saved.
  """
  if not file_io.file_exists(directory):
    file_io.recursive_create_dir(directory)
  pipeline_config_path = os.path.join(directory, "pipeline.config")
  config_text = text_format.MessageToString(pipeline_config)
  with tf.gfile.Open(pipeline_config_path, "wb") as f:
    tf.logging.info("Writing pipeline config file to %s",
                    pipeline_config_path)
    f.write(config_text)
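
# Illustrative round trip (a sketch under assumed, hypothetical paths): read a
# pipeline config, rebuild the proto from the configs dictionary, and save it
# into a model directory as `pipeline.config`.
#
#   configs = get_configs_from_pipeline_file("/tmp/pipeline.config")
#   configs["train_config"].batch_size = 16
#   pipeline_proto = create_pipeline_proto_from_configs(configs)
#   save_pipeline_config(pipeline_proto, "/tmp/model_dir")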


def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from multiple config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Keys are `model`, `train_config`,
      `train_input_config`, `eval_config`, `eval_input_config`. Key/value
      pairs are returned only for valid (non-empty) path strings.
  """
  configs = {}
  if model_config_path:
    model_config = model_pb2.DetectionModel()
    with tf.gfile.GFile(model_config_path, "r") as f:
      text_format.Merge(f.read(), model_config)
    configs["model"] = model_config

  if train_config_path:
    train_config = train_pb2.TrainConfig()
    with tf.gfile.GFile(train_config_path, "r") as f:
      text_format.Merge(f.read(), train_config)
    configs["train_config"] = train_config

  if train_input_config_path:
    train_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(train_input_config_path, "r") as f:
      text_format.Merge(f.read(), train_input_config)
    configs["train_input_config"] = train_input_config

  if eval_config_path:
    eval_config = eval_pb2.EvalConfig()
    with tf.gfile.GFile(eval_config_path, "r") as f:
      text_format.Merge(f.read(), eval_config)
    configs["eval_config"] = eval_config

  if eval_input_config_path:
    eval_input_config = input_reader_pb2.InputReader()
    with tf.gfile.GFile(eval_input_config_path, "r") as f:
      text_format.Merge(f.read(), eval_input_config)
    configs["eval_input_configs"] = [eval_input_config]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
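
# Illustrative usage (a sketch; the file paths are hypothetical): assemble the
# configs dictionary from separately maintained config files instead of a
# single pipeline file. Only keys for non-empty paths appear in the result.
#
#   configs = get_configs_from_multiple_files(
#       model_config_path="/tmp/model.config",
#       train_config_path="/tmp/train.config",
#       train_input_config_path="/tmp/train_input.config")
#   sorted(configs)  # -> ['model', 'train_config', 'train_input_config']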


def get_number_of_classes(model_config):
  """Returns the number of classes for a detection model.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    Number of classes.

  Raises:
    ValueError: If the model type is not recognized.
  """
  meta_architecture = model_config.WhichOneof("model")
  meta_architecture_config = getattr(model_config, meta_architecture)

  if hasattr(meta_architecture_config, "num_classes"):
    return meta_architecture_config.num_classes
  else:
    raise ValueError("{} does not have num_classes.".format(meta_architecture))


def get_optimizer_type(train_config):
  """Returns the optimizer type for training.

  Args:
    train_config: A train_pb2.TrainConfig.

  Returns:
    The type of the optimizer.
  """
  return train_config.optimizer.WhichOneof("optimizer")


def get_learning_rate_type(optimizer_config):
  """Returns the learning rate type for training.

  Args:
    optimizer_config: An optimizer_pb2.Optimizer.

  Returns:
    The type of the learning rate.
  """
  return optimizer_config.learning_rate.WhichOneof("learning_rate")


def _is_generic_key(key):
  """Determines whether the key starts with a generic config dictionary key."""
  for prefix in [
      "graph_rewriter_config",
      "model",
      "train_input_config",
      "train_config",
      "eval_config"]:
    if key.startswith(prefix + "."):
      return True
  return False


def _check_and_convert_legacy_input_config_key(key):
  """Checks key and converts legacy input config update to specific update.

  Args:
    key: string indicating the target of the update operation.

  Returns:
    is_valid_input_config_key: A boolean indicating whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: always returns None since legacy input config key never
      specifies the target input config. Keeping this output only to match the
      output form defined for input config update.
    field_name: the field name in input config. `key` itself if
      is_valid_input_config_key is false.
  """
  key_name = None
  input_name = None
  field_name = key
  is_valid_input_config_key = True
  if field_name == "train_shuffle":
    key_name = "train_input_config"
    field_name = "shuffle"
  elif field_name == "eval_shuffle":
    key_name = "eval_input_configs"
    field_name = "shuffle"
  elif field_name == "train_input_path":
    key_name = "train_input_config"
    field_name = "input_path"
  elif field_name == "eval_input_path":
    key_name = "eval_input_configs"
    field_name = "input_path"
  elif field_name == "append_train_input_path":
    key_name = "train_input_config"
    field_name = "input_path"
  elif field_name == "append_eval_input_path":
    key_name = "eval_input_configs"
    field_name = "input_path"
  else:
    is_valid_input_config_key = False

  return is_valid_input_config_key, key_name, input_name, field_name
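
# Illustrative examples of the legacy-key conversion above (a sketch, not a
# test from the library):
#
#   _check_and_convert_legacy_input_config_key("train_input_path")
#   # -> (True, "train_input_config", None, "input_path")
#
#   _check_and_convert_legacy_input_config_key("learning_rate")
#   # -> (False, None, None, "learning_rate")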


def check_and_parse_input_config_key(configs, key):
  """Checks key and returns specific fields if key is valid input config update.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: string indicating the target of the update operation.

  Returns:
    is_valid_input_config_key: A boolean indicating whether the input key is to
      update input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' string if
      is_valid_input_config_key is true. None if is_valid_input_config_key is
      false.
    input_name: the name of the input config to be updated. None if
      is_valid_input_config_key is false.
    field_name: the field name in input config. `key` itself if
      is_valid_input_config_key is false.

  Raises:
    ValueError: when the input key format doesn't match any known formats.
    ValueError: if key_name doesn't match 'eval_input_configs' or
      'train_input_config'.
    ValueError: if input_name doesn't match any name in train or eval input
      configs.
    ValueError: if field_name doesn't match any supported fields.
  """
  key_name = None
  input_name = None
  field_name = None
  fields = key.split(":")
  if len(fields) == 1:
    field_name = key
    return _check_and_convert_legacy_input_config_key(key)
  elif len(fields) == 3:
    key_name = fields[0]
    input_name = fields[1]
    field_name = fields[2]
  else:
    raise ValueError("Invalid key format when overriding configs.")

  # Checks if key_name is valid for specific update.
  if key_name not in ["eval_input_configs", "train_input_config"]:
    raise ValueError("Invalid key_name when overriding input config.")

  # Checks if input_name is valid for specific update. For the train input
  # config it should match configs[key_name].name, for eval input configs it
  # should match the name field of one of the eval_input_configs.
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    is_valid_input_name = configs[key_name].name == input_name
  else:
    is_valid_input_name = input_name in [
        eval_input_config.name for eval_input_config in configs[key_name]
    ]
  if not is_valid_input_name:
    raise ValueError("Invalid input_name when overriding input config.")

  # Checks if field_name is valid for specific update.
  if field_name not in [
      "input_path", "label_map_path", "shuffle", "mask_type",
      "sample_1_of_n_examples"
  ]:
    raise ValueError("Invalid field_name when overriding input config.")

  return True, key_name, input_name, field_name
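
# Illustrative example of the colon-separated "specific update" key format (a
# sketch; the eval input name "eval_coco" is hypothetical):
#
#   configs = get_configs_from_pipeline_file("/tmp/pipeline.config")
#   check_and_parse_input_config_key(
#       configs, "eval_input_configs:eval_coco:input_path")
#   # -> (True, "eval_input_configs", "eval_coco", "input_path"),
#   # provided one of configs["eval_input_configs"] has name "eval_coco".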


def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
  """Updates `configs` dictionary based on supplied parameters.

  This utility is for modifying specific fields in the object detection
  configs. Say that one would like to experiment with different learning
  rates, momentum values, or batch sizes. Rather than creating a new config
  text file for each experiment, one can use a single base config file, and
  update particular values.

  There are two types of field overrides:

  1. Strategy-based overrides, which update multiple relevant configuration
  options. For example, updating `learning_rate` will update both the warmup
  and final learning rates.
  In this case key can be one of the following formats:
      1. legacy update: single string that indicates the attribute to be
        updated. E.g. 'label_map_path', 'eval_input_path', 'shuffle'.
        Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
        eval_input_configs, the override will only be applied when
        eval_input_configs has exactly 1 element.
      2. specific update: colon-separated string that indicates which field in
        which input_config to update. It should have 3 fields:
        - key_name: Name of the input config we should update, either
          'train_input_config' or 'eval_input_configs'
        - input_name: a 'name' that can be used to identify elements,
          especially when configs[key_name] is a repeated field.
        - field_name: name of the field that you want to override.
        For example, given configs dict as below:
          configs = {
            'model': {...}
            'train_config': {...}
            'train_input_config': {...}
            'eval_config': {...}
            'eval_input_configs': [{ name:"eval_coco", ...},
                                   { name:"eval_voc", ... }]
          }
        Assume we want to update the input_path of the eval_input_config
        whose name is 'eval_coco'. The `key` would then be:
          'eval_input_configs:eval_coco:input_path'

  2. Generic key/value, which updates a specific parameter based on namespaced
  configuration keys. For example,
  `model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update
  the hard example miner configuration for an SSD model config. Generic
  overrides are automatically detected based on the namespaced keys.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    hparams: A `HParams`.
    kwargs_dict: Extra keyword arguments that are treated the same way as
      attribute/value pairs in `hparams`. Note that hyperparameters with the
      same names will override keyword arguments.

  Returns:
    `configs` dictionary.

  Raises:
    ValueError: when the key string doesn't match any of its allowed formats.
  """
  if kwargs_dict is None:
    kwargs_dict = {}
  if hparams:
    kwargs_dict.update(hparams.values())
  for key, value in kwargs_dict.items():
    tf.logging.info("Maybe overwriting %s: %s", key, value)
    # pylint: disable=g-explicit-bool-comparison
    if value == "" or value is None:
      continue
    # pylint: enable=g-explicit-bool-comparison
    elif _maybe_update_config_with_key_value(configs, key, value):
      continue
    elif _is_generic_key(key):
      _update_generic(configs, key, value)
    else:
      tf.logging.info("Ignoring config override key: %s", key)
  return configs
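
# Illustrative usage (a sketch, not library documentation; the values and the
# eval input name "eval_coco" are hypothetical, and the pipeline is assumed to
# use a supported optimizer and learning rate schedule):
#
#   configs = get_configs_from_pipeline_file("/tmp/pipeline.config")
#   configs = merge_external_params_with_configs(
#       configs,
#       kwargs_dict={
#           "learning_rate": 0.004,  # strategy-based override
#           "train_steps": 100000,   # strategy-based override
#           "eval_input_configs:eval_coco:input_path": "/tmp/eval.record",
#       })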


def _maybe_update_config_with_key_value(configs, key, value):
  """Checks key type and updates `configs` with the key value pair accordingly.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: String indicating the field(s) to be updated.
    value: Value used to override existing field value.

  Returns:
    A boolean value that indicates whether the override succeeds.

  Raises:
    ValueError: when the key string doesn't match any of the formats above.
  """
  is_valid_input_config_key, key_name, input_name, field_name = (
      check_and_parse_input_config_key(configs, key))
  if is_valid_input_config_key:
    update_input_reader_config(
        configs,
        key_name=key_name,
        input_name=input_name,
        field_name=field_name,
        value=value)
  elif field_name == "learning_rate":
    _update_initial_learning_rate(configs, value)
  elif field_name == "batch_size":
    _update_batch_size(configs, value)
  elif field_name == "momentum_optimizer_value":
    _update_momentum_optimizer_value(configs, value)
  elif field_name == "classification_localization_weight_ratio":
    # Localization weight is fixed to 1.0.
    _update_classification_localization_weight_ratio(configs, value)
  elif field_name == "focal_loss_gamma":
    _update_focal_loss_gamma(configs, value)
  elif field_name == "focal_loss_alpha":
    _update_focal_loss_alpha(configs, value)
  elif field_name == "train_steps":
    _update_train_steps(configs, value)
  elif field_name == "label_map_path":
    _update_label_map_path(configs, value)
  elif field_name == "mask_type":
    _update_mask_type(configs, value)
  elif field_name == "sample_1_of_n_eval_examples":
    _update_all_eval_input_configs(configs, "sample_1_of_n_examples", value)
  elif field_name == "eval_num_epochs":
    _update_all_eval_input_configs(configs, "num_epochs", value)
  elif field_name == "eval_with_moving_averages":
    _update_use_moving_averages(configs, value)
  elif field_name == "retain_original_images_in_eval":
    _update_retain_original_images(configs["eval_config"], value)
  elif field_name == "use_bfloat16":
    _update_use_bfloat16(configs, value)
  elif field_name == "retain_original_image_additional_channels_in_eval":
    _update_retain_original_image_additional_channels(
        configs["eval_config"], value)
  elif field_name == "num_classes":
    _update_num_classes(configs["model"], value)
  else:
    return False
  return True


def _update_tf_record_input_path(input_config, input_path):
  """Updates input configuration to reflect a new input path.

  The input_config object is updated in place, and hence not returned.

  Args:
    input_config: An input_reader_pb2.InputReader.
    input_path: A path to data or list of paths.

  Raises:
    TypeError: if input reader type is not `tf_record_input_reader`.
  """
  input_reader_type = input_config.WhichOneof("input_reader")
  if input_reader_type == "tf_record_input_reader":
    input_config.tf_record_input_reader.ClearField("input_path")
    if isinstance(input_path, list):
      input_config.tf_record_input_reader.input_path.extend(input_path)
    else:
      input_config.tf_record_input_reader.input_path.append(input_path)
  else:
    raise TypeError("Input reader type must be `tf_record_input_reader`.")


def update_input_reader_config(configs,
                               key_name=None,
                               input_name=None,
                               field_name=None,
                               value=None,
                               path_updater=_update_tf_record_input_path):
  """Updates specified input reader config field.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key_name: Name of the input config we should update, either
      'train_input_config' or 'eval_input_configs'.
    input_name: String name used to identify input config to update with.
      Should be either None or the value of the 'name' field in one of the
      input reader configs.
    field_name: Field name in input_reader_pb2.InputReader.
    value: Value used to override existing field value.
    path_updater: helper function used to update the input path. Only used
      when field_name is "input_path".

  Raises:
    ValueError: when input field_name is None.
    ValueError: when input_name is None and the number of eval_input_readers
      does not equal 1.
  """
  if isinstance(configs[key_name], input_reader_pb2.InputReader):
    # Updates singular input_config object.
    target_input_config = configs[key_name]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is None and len(configs[key_name]) == 1:
    # Updates first (and the only) object of input_config list.
    target_input_config = configs[key_name][0]
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)
  elif input_name is not None and len(configs[key_name]):
    # Updates input_config whose name matches input_name.
    update_count = 0
    for input_config in configs[key_name]:
      if input_config.name == input_name:
        setattr(input_config, field_name, value)
        update_count = update_count + 1
    if not update_count:
      raise ValueError(
          "Input name {} not found when overriding.".format(input_name))
    elif update_count > 1:
      raise ValueError("Duplicate input name found when overriding.")
  else:
    key_name = "None" if key_name is None else key_name
    input_name = "None" if input_name is None else input_name
    field_name = "None" if field_name is None else field_name
    raise ValueError("Unknown input config overriding: "
                     "key_name:{}, input_name:{}, field_name:{}.".format(
                         key_name, input_name, field_name))
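
# Illustrative usage (a sketch; the record paths and the input name are
# hypothetical): point the training input reader at new TFRecord shards, and
# change the shuffle flag of a named eval input config.
#
#   update_input_reader_config(
#       configs,
#       key_name="train_input_config",
#       input_name=None,
#       field_name="input_path",
#       value=["/tmp/train-00000-of-00010", "/tmp/train-00001-of-00010"])
#
#   update_input_reader_config(
#       configs,
#       key_name="eval_input_configs",
#       input_name="eval_coco",
#       field_name="shuffle",
#       value=False)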


def _update_initial_learning_rate(configs, learning_rate):
  """Updates `configs` to reflect the new initial learning rate.

  This function updates the initial learning rate. For learning rate schedules,
  all other defined learning rates in the pipeline config are scaled to
  maintain their same ratio with the initial learning rate.
  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    learning_rate: Initial learning rate for optimizer.

  Raises:
    TypeError: if optimizer type is not supported, or if learning rate type is
      not supported.
  """
  optimizer_type = get_optimizer_type(configs["train_config"])
  if optimizer_type == "rms_prop_optimizer":
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
  elif optimizer_type == "momentum_optimizer":
    optimizer_config = configs["train_config"].optimizer.momentum_optimizer
  elif optimizer_type == "adam_optimizer":
    optimizer_config = configs["train_config"].optimizer.adam_optimizer
  else:
    raise TypeError("Optimizer %s is not supported." % optimizer_type)

  learning_rate_type = get_learning_rate_type(optimizer_config)
  if learning_rate_type == "constant_learning_rate":
    constant_lr = optimizer_config.learning_rate.constant_learning_rate
    constant_lr.learning_rate = learning_rate
  elif learning_rate_type == "exponential_decay_learning_rate":
    exponential_lr = (
        optimizer_config.learning_rate.exponential_decay_learning_rate)
    exponential_lr.initial_learning_rate = learning_rate
  elif learning_rate_type == "manual_step_learning_rate":
    manual_lr = optimizer_config.learning_rate.manual_step_learning_rate
    original_learning_rate = manual_lr.initial_learning_rate
    learning_rate_scaling = float(learning_rate) / original_learning_rate
    manual_lr.initial_learning_rate = learning_rate
    for schedule in manual_lr.schedule:
      schedule.learning_rate *= learning_rate_scaling
  elif learning_rate_type == "cosine_decay_learning_rate":
    cosine_lr = optimizer_config.learning_rate.cosine_decay_learning_rate
    learning_rate_base = cosine_lr.learning_rate_base
    warmup_learning_rate = cosine_lr.warmup_learning_rate
    warmup_scale_factor = warmup_learning_rate / learning_rate_base
    cosine_lr.learning_rate_base = learning_rate
    cosine_lr.warmup_learning_rate = warmup_scale_factor * learning_rate
  else:
    raise TypeError("Learning rate %s is not supported." % learning_rate_type)


def _update_batch_size(configs, batch_size):
  """Updates `configs` to reflect the new training batch size.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    batch_size: Batch size to use for training (Ideally a power of 2). Inputs
      are rounded, and capped to be 1 or greater.
  """
  configs["train_config"].batch_size = max(1, int(round(batch_size)))


def _validate_message_has_field(message, field):
  """Raises a ValueError if `message` does not have the given `field` set."""
  if not message.HasField(field):
    raise ValueError("Expecting message to have field %s" % field)


def _update_generic(configs, key, value):
  """Update a pipeline configuration parameter based on a generic key/value.

  Args:
    configs: Dictionary of pipeline configuration protos.
    key: A string key, dot-delimited to represent the argument key.
      e.g. "model.ssd.num_classes"
    value: A value to set the argument to. The type of the value must match the
      type for the protocol buffer. Note that setting the wrong type will
      result in a TypeError.
      e.g. 42

  Raises:
    ValueError if the message key does not match the existing proto fields.
    TypeError the value type doesn't match the protobuf field type.
  """
  fields = key.split(".")
  first_field = fields.pop(0)
  last_field = fields.pop()
  message = configs[first_field]
  for field in fields:
    _validate_message_has_field(message, field)
    message = getattr(message, field)
  _validate_message_has_field(message, last_field)
  setattr(message, last_field, value)
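
# Illustrative usage of the generic override path (a sketch; assumes an SSD
# pipeline where the nested fields below are already populated in the config):
#
#   _update_generic(configs, "model.ssd.num_classes", 42)
#   _update_generic(
#       configs, "train_config.optimizer.use_moving_average", False)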


def _update_momentum_optimizer_value(configs, momentum):
  """Updates `configs` to reflect the new momentum value.

  Momentum is only supported for RMSPropOptimizer and MomentumOptimizer. For
  any other optimizer, no changes take place. The configs dictionary is
  updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    momentum: New momentum value. Values are clipped at 0.0 and 1.0.

  Raises:
    TypeError: If the optimizer type is not `rms_prop_optimizer` or
      `momentum_optimizer`.
  """
  optimizer_type = get_optimizer_type(configs["train_config"])
  if optimizer_type == "rms_prop_optimizer":
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
  elif optimizer_type == "momentum_optimizer":
    optimizer_config = configs["train_config"].optimizer.momentum_optimizer
  else:
    raise TypeError("Optimizer type must be one of `rms_prop_optimizer` or "
                    "`momentum_optimizer`.")

  optimizer_config.momentum_optimizer_value = min(max(0.0, momentum), 1.0)


def _update_classification_localization_weight_ratio(configs, ratio):
  """Updates the classification/localization weight loss ratio.

  Detection models usually define a loss weight for both classification and
  objectness. This function updates the weights such that the ratio between
  classification weight to localization weight is the ratio provided.
  Arbitrarily, localization weight is set to 1.0.

  Note that in the case of Faster R-CNN, this same ratio is applied to the
  first stage objectness loss weight relative to localization loss weight.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    ratio: Desired ratio of classification (and/or objectness) loss weight to
      localization loss weight.
  """
  meta_architecture = configs["model"].WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    model = configs["model"].faster_rcnn
    model.first_stage_localization_loss_weight = 1.0
    model.first_stage_objectness_loss_weight = ratio
    model.second_stage_localization_loss_weight = 1.0
    model.second_stage_classification_loss_weight = ratio
  if meta_architecture == "ssd":
    model = configs["model"].ssd
    model.loss.localization_weight = 1.0
    model.loss.classification_weight = ratio
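
# Worked example (hypothetical numbers): a ratio of 2.0 on an SSD config sets
# loss.localization_weight = 1.0 and loss.classification_weight = 2.0, so the
# classification term carries twice the weight of the localization term.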


def _get_classification_loss(model_config):
  """Returns the classification loss for a model."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    model = model_config.faster_rcnn
    classification_loss = model.second_stage_classification_loss
  elif meta_architecture == "ssd":
    model = model_config.ssd
    classification_loss = model.loss.classification_loss
  else:
    raise TypeError("Did not recognize the model architecture.")
  return classification_loss


def _update_focal_loss_gamma(configs, gamma):
  """Updates the gamma value for a sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    gamma: Exponent term in focal loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  classification_loss = _get_classification_loss(configs["model"])
  classification_loss_type = classification_loss.WhichOneof(
      "classification_loss")
  if classification_loss_type != "weighted_sigmoid_focal":
    raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
  classification_loss.weighted_sigmoid_focal.gamma = gamma


def _update_focal_loss_alpha(configs, alpha):
  """Updates the alpha value for a sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    alpha: Class weight multiplier for sigmoid loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  classification_loss = _get_classification_loss(configs["model"])
  classification_loss_type = classification_loss.WhichOneof(
      "classification_loss")
  if classification_loss_type != "weighted_sigmoid_focal":
    raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
  classification_loss.weighted_sigmoid_focal.alpha = alpha


def _update_train_steps(configs, train_steps):
  """Updates `configs` to reflect new number of training steps."""
  configs["train_config"].num_steps = int(train_steps)


def _update_all_eval_input_configs(configs, field, value):
  """Updates the content of `field` with `value` for all eval input configs."""
  for eval_input_config in configs["eval_input_configs"]:
    setattr(eval_input_config, field, value)


def _update_label_map_path(configs, label_map_path):
  """Updates the label map path for both train and eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    label_map_path: New path to `StringIntLabelMap` pbtxt file.
  """
  configs["train_input_config"].label_map_path = label_map_path
  _update_all_eval_input_configs(configs, "label_map_path", label_map_path)


def _update_mask_type(configs, mask_type):
  """Updates the mask type for both train and eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    mask_type: A string name representing a value of
      input_reader_pb2.InstanceMaskType.
  """
  configs["train_input_config"].mask_type = mask_type
  _update_all_eval_input_configs(configs, "mask_type", mask_type)


def _update_use_moving_averages(configs, use_moving_averages):
  """Updates the eval config option to use or not use moving averages.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    use_moving_averages: Boolean indicating whether moving average variables
      should be loaded during evaluation.
  """
  configs["eval_config"].use_moving_averages = use_moving_averages


def _update_retain_original_images(eval_config, retain_original_images):
  """Updates eval config with option to retain original images.

  The eval_config object is updated in place, and hence not returned.

  Args:
    eval_config: An eval_pb2.EvalConfig.
    retain_original_images: Boolean indicating whether to retain original
      images in eval mode.
  """
  eval_config.retain_original_images = retain_original_images


def _update_use_bfloat16(configs, use_bfloat16):
  """Updates `configs` to reflect the new setup on whether to use bfloat16.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    use_bfloat16: A bool, indicating whether to use bfloat16 for training.
  """
  configs["train_config"].use_bfloat16 = use_bfloat16


def _update_retain_original_image_additional_channels(
    eval_config,
    retain_original_image_additional_channels):
  """Updates eval config to retain original image additional channels or not.

  The eval_config object is updated in place, and hence not returned.

  Args:
    eval_config: An eval_pb2.EvalConfig.
    retain_original_image_additional_channels: Boolean indicating whether to
      retain original image additional channels in eval mode.
  """
  eval_config.retain_original_image_additional_channels = (
      retain_original_image_additional_channels)


def remove_unecessary_ema(variables_to_restore, no_ema_collection=None):
  """Remaps and removes EMA variables that are not created during training.

  ExponentialMovingAverage.variables_to_restore() returns a map of EMA names
  to tf variables to restore. E.g.:
  {
      conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
      conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
      global_step: global_step
  }
  This function takes care of the extra ExponentialMovingAverage variables
  that get created during eval but aren't available in the checkpoint, by
  remapping the key to the shallow copy of the variable itself, and removing
  the entry of its EMA from the variables to restore. An example resulting
  dictionary would look like:
  {
      conv/batchnorm/gamma: conv/batchnorm/gamma,
      conv_4/conv2d_params: conv_4/conv2d_params,
      global_step: global_step
  }

  Args:
    variables_to_restore: A dictionary created by
      ExponentialMovingAverage.variables_to_restore().
    no_ema_collection: A list of namescope substrings; EMA mappings for
      variables whose names contain any of these substrings are removed.

  Returns:
    A variables_to_restore dictionary excluding the collection of unwanted
    EMA mappings.
  """
  if no_ema_collection is None:
    return variables_to_restore

  # Iterate over a copy of the keys, since entries are added to and removed
  # from the dictionary inside the loop.
  for key in list(variables_to_restore.keys()):
    if "ExponentialMovingAverage" in key:
      for name in no_ema_collection:
        if name in key:
          new_key = key.replace("/ExponentialMovingAverage", "")
          variables_to_restore[new_key] = variables_to_restore[key]
          del variables_to_restore[key]
  return variables_to_restore
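
# Illustrative usage (a sketch; the scope name "FeatureExtractor" is
# hypothetical): when evaluating with moving averages, drop EMA entries for
# scopes that were never trained with EMA before building the restore saver.
#
#   ema = tf.train.ExponentialMovingAverage(0.999)
#   variables_to_restore = ema.variables_to_restore()
#   variables_to_restore = remove_unecessary_ema(
#       variables_to_restore, no_ema_collection=["FeatureExtractor"])
#   saver = tf.train.Saver(variables_to_restore)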


def _update_num_classes(model_config, num_classes):
  """Sets the number of classes on a faster_rcnn or ssd model config."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "faster_rcnn":
    model_config.faster_rcnn.num_classes = num_classes
  if meta_architecture == "ssd":
    model_config.ssd.num_classes = num_classes