Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides flags that are common to scripts. | |
Common flags from train/eval/vis/export_model.py are collected in this script. | |
""" | |
import collections | |
import copy | |
import json | |
import tensorflow as tf | |
flags = tf.app.flags | |
# Flags for input preprocessing. | |
flags.DEFINE_integer('min_resize_value', None, | |
'Desired size of the smaller image side.') | |
flags.DEFINE_integer('max_resize_value', None, | |
'Maximum allowed size of the larger image side.') | |
flags.DEFINE_integer('resize_factor', None, | |
'Resized dimensions are multiple of factor plus one.') | |
flags.DEFINE_boolean('keep_aspect_ratio', True, | |
'Keep aspect ratio after resizing or not.') | |
# Model dependent flags. | |
flags.DEFINE_integer('logits_kernel_size', 1, | |
'The kernel size for the convolutional kernel that ' | |
'generates logits.') | |
# When using 'mobilent_v2', we set atrous_rates = decoder_output_stride = None. | |
# When using 'xception_65' or 'resnet_v1' model variants, we set | |
# atrous_rates = [6, 12, 18] (output stride 16) and decoder_output_stride = 4. | |
# See core/feature_extractor.py for supported model variants. | |
flags.DEFINE_string('model_variant', 'mobilenet_v2', 'DeepLab model variant.') | |
flags.DEFINE_multi_float('image_pyramid', None, | |
'Input scales for multi-scale feature extraction.') | |
flags.DEFINE_boolean('add_image_level_feature', True, | |
'Add image level feature.') | |
flags.DEFINE_list( | |
'image_pooling_crop_size', None, | |
'Image pooling crop size [height, width] used in the ASPP module. When ' | |
'value is None, the model performs image pooling with "crop_size". This' | |
'flag is useful when one likes to use different image pooling sizes.') | |
flags.DEFINE_list( | |
'image_pooling_stride', '1,1', | |
'Image pooling stride [height, width] used in the ASPP image pooling. ') | |
flags.DEFINE_boolean('aspp_with_batch_norm', True, | |
'Use batch norm parameters for ASPP or not.') | |
flags.DEFINE_boolean('aspp_with_separable_conv', True, | |
'Use separable convolution for ASPP or not.') | |
# Defaults to None. Set multi_grid = [1, 2, 4] when using provided | |
# 'resnet_v1_{50,101}_beta' checkpoints. | |
flags.DEFINE_multi_integer('multi_grid', None, | |
'Employ a hierarchy of atrous rates for ResNet.') | |
flags.DEFINE_float('depth_multiplier', 1.0, | |
'Multiplier for the depth (number of channels) for all ' | |
'convolution ops used in MobileNet.') | |
flags.DEFINE_integer('divisible_by', None, | |
'An integer that ensures the layer # channels are ' | |
'divisible by this value. Used in MobileNet.') | |
# For `xception_65`, use decoder_output_stride = 4. For `mobilenet_v2`, use | |
# decoder_output_stride = None. | |
flags.DEFINE_list('decoder_output_stride', None, | |
'Comma-separated list of strings with the number specifying ' | |
'output stride of low-level features at each network level.' | |
'Current semantic segmentation implementation assumes at ' | |
'most one output stride (i.e., either None or a list with ' | |
'only one element.') | |
flags.DEFINE_boolean('decoder_use_separable_conv', True, | |
'Employ separable convolution for decoder or not.') | |
flags.DEFINE_enum('merge_method', 'max', ['max', 'avg'], | |
'Scheme to merge multi scale features.') | |
flags.DEFINE_boolean( | |
'prediction_with_upsampled_logits', True, | |
'When performing prediction, there are two options: (1) bilinear ' | |
'upsampling the logits followed by softmax, or (2) softmax followed by ' | |
'bilinear upsampling.') | |
flags.DEFINE_string( | |
'dense_prediction_cell_json', | |
'', | |
'A JSON file that specifies the dense prediction cell.') | |
flags.DEFINE_integer( | |
'nas_stem_output_num_conv_filters', 20, | |
'Number of filters of the stem output tensor in NAS models.') | |
flags.DEFINE_bool('nas_use_classification_head', False, | |
'Use image classification head for NAS model variants.') | |
flags.DEFINE_bool('nas_remove_os32_stride', False, | |
'Remove the stride in the output stride 32 branch.') | |
flags.DEFINE_bool('use_bounded_activation', False, | |
'Whether or not to use bounded activations. Bounded ' | |
'activations better lend themselves to quantized inference.') | |
flags.DEFINE_boolean('aspp_with_concat_projection', True, | |
'ASPP with concat projection.') | |
flags.DEFINE_boolean('aspp_with_squeeze_and_excitation', False, | |
'ASPP with squeeze and excitation.') | |
flags.DEFINE_integer('aspp_convs_filters', 256, 'ASPP convolution filters.') | |
flags.DEFINE_boolean('decoder_use_sum_merge', False, | |
'Decoder uses simply sum merge.') | |
flags.DEFINE_integer('decoder_filters', 256, 'Decoder filters.') | |
flags.DEFINE_boolean('decoder_output_is_logits', False, | |
'Use decoder output as logits or not.') | |
flags.DEFINE_boolean('image_se_uses_qsigmoid', False, 'Use q-sigmoid.') | |
flags.DEFINE_multi_float( | |
'label_weights', None, | |
'A list of label weights, each element represents the weight for the label ' | |
'of its index, for example, label_weights = [0.1, 0.5] means the weight ' | |
'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all ' | |
'the labels have the same weight 1.0.') | |
flags.DEFINE_float('batch_norm_decay', 0.9997, 'Batchnorm decay.') | |
FLAGS = flags.FLAGS | |
# Constants | |
# Perform semantic segmentation predictions. | |
OUTPUT_TYPE = 'semantic' | |
# Semantic segmentation item names. | |
LABELS_CLASS = 'labels_class' | |
IMAGE = 'image' | |
HEIGHT = 'height' | |
WIDTH = 'width' | |
IMAGE_NAME = 'image_name' | |
LABEL = 'label' | |
ORIGINAL_IMAGE = 'original_image' | |
# Test set name. | |
TEST_SET = 'test' | |
class ModelOptions(
    collections.namedtuple('ModelOptions', [
        'outputs_to_num_classes',
        'crop_size',
        'atrous_rates',
        'output_stride',
        'preprocessed_images_dtype',
        'merge_method',
        'add_image_level_feature',
        'image_pooling_crop_size',
        'image_pooling_stride',
        'aspp_with_batch_norm',
        'aspp_with_separable_conv',
        'multi_grid',
        'decoder_output_stride',
        'decoder_use_separable_conv',
        'logits_kernel_size',
        'model_variant',
        'depth_multiplier',
        'divisible_by',
        'prediction_with_upsampled_logits',
        'dense_prediction_cell_config',
        'nas_architecture_options',
        'use_bounded_activation',
        'aspp_with_concat_projection',
        'aspp_with_squeeze_and_excitation',
        'aspp_convs_filters',
        'decoder_use_sum_merge',
        'decoder_filters',
        'decoder_output_is_logits',
        'image_se_uses_qsigmoid',
        'label_weights',
        'sync_batch_norm_method',
        'batch_norm_decay',
    ])):
  """Immutable namedtuple that bundles all model configuration options.

  Most fields are populated from the module-level FLAGS at construction
  time; only the explicit __new__ arguments vary per call site.
  """

  # No per-instance __dict__: instances stay lightweight and truly immutable.
  __slots__ = ()

  def __new__(cls,
              outputs_to_num_classes,
              crop_size=None,
              atrous_rates=None,
              output_stride=8,
              preprocessed_images_dtype=tf.float32):
    """Constructor to set default values.

    Args:
      outputs_to_num_classes: A dictionary from output type to the number of
        classes. For example, for the task of semantic segmentation with 21
        semantic classes, we would have outputs_to_num_classes['semantic'] = 21.
      crop_size: A tuple [crop_height, crop_width].
      atrous_rates: A list of atrous convolution rates for ASPP.
      output_stride: The ratio of input to output spatial resolution.
      preprocessed_images_dtype: The type after the preprocessing function.

    Returns:
      A new ModelOptions instance.

    Raises:
      ValueError: If --decoder_output_stride is not sorted in descending
        order.
    """
    # Optionally load the dense prediction cell spec from the JSON file
    # pointed to by the flag; otherwise leave it unset.
    dense_prediction_cell_config = None
    if FLAGS.dense_prediction_cell_json:
      with tf.gfile.Open(FLAGS.dense_prediction_cell_json, 'r') as config_file:
        dense_prediction_cell_config = json.load(config_file)

    # The flag arrives as a list of strings; convert to ints and enforce
    # descending order, which the decoder implementation relies on.
    decoder_output_stride = None
    if FLAGS.decoder_output_stride:
      decoder_output_stride = list(map(int, FLAGS.decoder_output_stride))
      if decoder_output_stride != sorted(decoder_output_stride, reverse=True):
        raise ValueError('Decoder output stride need to be sorted in the '
                         'descending order.')

    image_pooling_crop_size = None
    if FLAGS.image_pooling_crop_size:
      image_pooling_crop_size = list(map(int, FLAGS.image_pooling_crop_size))

    image_pooling_stride = [1, 1]
    if FLAGS.image_pooling_stride:
      image_pooling_stride = list(map(int, FLAGS.image_pooling_stride))

    # No per-label weights given: weight every label equally.
    label_weights = 1.0 if FLAGS.label_weights is None else FLAGS.label_weights

    nas_architecture_options = {
        'nas_stem_output_num_conv_filters':
            FLAGS.nas_stem_output_num_conv_filters,
        'nas_use_classification_head': FLAGS.nas_use_classification_head,
        'nas_remove_os32_stride': FLAGS.nas_remove_os32_stride,
    }

    # Keyword arguments make the field-to-value mapping explicit.
    return super(ModelOptions, cls).__new__(
        cls,
        outputs_to_num_classes=outputs_to_num_classes,
        crop_size=crop_size,
        atrous_rates=atrous_rates,
        output_stride=output_stride,
        preprocessed_images_dtype=preprocessed_images_dtype,
        merge_method=FLAGS.merge_method,
        add_image_level_feature=FLAGS.add_image_level_feature,
        image_pooling_crop_size=image_pooling_crop_size,
        image_pooling_stride=image_pooling_stride,
        aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
        aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
        multi_grid=FLAGS.multi_grid,
        decoder_output_stride=decoder_output_stride,
        decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
        logits_kernel_size=FLAGS.logits_kernel_size,
        model_variant=FLAGS.model_variant,
        depth_multiplier=FLAGS.depth_multiplier,
        divisible_by=FLAGS.divisible_by,
        prediction_with_upsampled_logits=FLAGS.prediction_with_upsampled_logits,
        dense_prediction_cell_config=dense_prediction_cell_config,
        nas_architecture_options=nas_architecture_options,
        use_bounded_activation=FLAGS.use_bounded_activation,
        aspp_with_concat_projection=FLAGS.aspp_with_concat_projection,
        aspp_with_squeeze_and_excitation=FLAGS.aspp_with_squeeze_and_excitation,
        aspp_convs_filters=FLAGS.aspp_convs_filters,
        decoder_use_sum_merge=FLAGS.decoder_use_sum_merge,
        decoder_filters=FLAGS.decoder_filters,
        decoder_output_is_logits=FLAGS.decoder_output_is_logits,
        image_se_uses_qsigmoid=FLAGS.image_se_uses_qsigmoid,
        label_weights=label_weights,
        sync_batch_norm_method='None',
        batch_norm_decay=FLAGS.batch_norm_decay)

  def __deepcopy__(self, memo):
    """Deep-copies only outputs_to_num_classes; other fields are re-derived
    from FLAGS inside __new__, matching the original behavior."""
    return ModelOptions(
        copy.deepcopy(self.outputs_to_num_classes),
        crop_size=self.crop_size,
        atrous_rates=self.atrous_rates,
        output_stride=self.output_stride,
        preprocessed_images_dtype=self.preprocessed_images_dtype)