# Lint as: python3 # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Common configuration settings.""" from typing import Optional, Union import dataclasses from official.modeling.hyperparams import base_config from official.modeling.optimization.configs import optimization_config from official.utils import registry OptimizationConfig = optimization_config.OptimizationConfig @dataclasses.dataclass class DataConfig(base_config.Config): """The base configuration for building datasets. Attributes: input_path: The path to the input. It can be either (1) a file pattern, or (2) multiple file patterns separated by comma. It should not be specified when the following `tfds_name` is specified. tfds_name: The name of the tensorflow dataset (TFDS). It should not be specified when the above `input_path` is specified. tfds_split: A str indicating which split of the data to load from TFDS. It is required when above `tfds_name` is specified. global_batch_size: The global batch size across all replicas. is_training: Whether this data is used for training or not. drop_remainder: Whether the last batch should be dropped in the case it has fewer than `global_batch_size` elements. shuffle_buffer_size: The buffer size used for shuffling training data. cache: Whether to cache dataset examples. Can be used to avoid re-reading from disk on the second epoch. Requires significant memory overhead. cycle_length: The number of files that will be processed concurrently when interleaving files. block_length: The number of consecutive elements to produce from each input element before cycling to another input element when interleaving files. sharding: Whether sharding is used in the input pipeline. examples_consume: An `integer` specifying the number of examples it will produce. If positive, it only takes this number of examples and raises tf.error.OutOfRangeError after that. Default is -1, meaning it will exhaust all the examples in the dataset. tfds_data_dir: A str specifying the directory to read/write TFDS data. tfds_download: A bool to indicate whether to download data using TFDS. tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the returned tf.data.Dataset will have a 2-tuple structure (input, label) according to builder.info.supervised_keys; if False, the default, the returned tf.data.Dataset will have a dictionary with all the features. tfds_skip_decoding_feature: A str to indicate which features are skipped for decoding when loading dataset from TFDS. Use comma to separate multiple features. The main use case is to skip the image/video decoding for better performance. """ input_path: str = "" tfds_name: str = "" tfds_split: str = "" global_batch_size: int = 0 is_training: bool = None drop_remainder: bool = True shuffle_buffer_size: int = 100 cache: bool = False cycle_length: int = 8 block_length: int = 1 sharding: bool = True examples_consume: int = -1 tfds_data_dir: str = "" tfds_download: bool = False tfds_as_supervised: bool = False tfds_skip_decoding_feature: str = "" @dataclasses.dataclass class RuntimeConfig(base_config.Config): """High-level configurations for Runtime. These include parameters that are not directly related to the experiment, e.g. directories, accelerator type, etc. Attributes: distribution_strategy: e.g. 'mirrored', 'tpu', etc. enable_xla: Whether or not to enable XLA. per_gpu_thread_count: thread count per GPU. gpu_thread_mode: Whether and how the GPU device uses its own threadpool. dataset_num_private_threads: Number of threads for a private threadpool created for all datasets computation. tpu: The address of the TPU to use, if any. num_gpus: The number of GPUs to use, if any. worker_hosts: comma-separated list of worker ip:port pairs for running multi-worker models with DistributionStrategy. task_index: If multi-worker training, the task index of this worker. all_reduce_alg: Defines the algorithm for performing all-reduce. num_packs: Sets `num_packs` in the cross device ops used in MirroredStrategy. For details, see tf.distribute.NcclAllReduce. mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32', 'float16', or 'bfloat16'. loss_scale: The type of loss scale, or 'float' value. This is used when setting the mixed precision policy. run_eagerly: Whether or not to run the experiment eagerly. batchnorm_spatial_persistent: Whether or not to enable the spatial persistent mode for CuDNN batch norm kernel for improved GPU performance. """ distribution_strategy: str = "mirrored" enable_xla: bool = False gpu_thread_mode: Optional[str] = None dataset_num_private_threads: Optional[int] = None per_gpu_thread_count: int = 0 tpu: Optional[str] = None num_gpus: int = 0 worker_hosts: Optional[str] = None task_index: int = -1 all_reduce_alg: Optional[str] = None num_packs: int = 1 loss_scale: Optional[Union[str, float]] = None mixed_precision_dtype: Optional[str] = None run_eagerly: bool = False batchnorm_spatial_persistent: bool = False @dataclasses.dataclass class TensorboardConfig(base_config.Config): """Configuration for Tensorboard. Attributes: track_lr: Whether or not to track the learning rate in Tensorboard. Defaults to True. write_model_weights: Whether or not to write the model weights as images in Tensorboard. Defaults to False. """ track_lr: bool = True write_model_weights: bool = False @dataclasses.dataclass class CallbacksConfig(base_config.Config): """Configuration for Callbacks. Attributes: enable_checkpoint_and_export: Whether or not to enable checkpoints as a Callback. Defaults to True. enable_tensorboard: Whether or not to enable Tensorboard as a Callback. Defaults to True. enable_time_history: Whether or not to enable TimeHistory Callbacks. Defaults to True. """ enable_checkpoint_and_export: bool = True enable_tensorboard: bool = True enable_time_history: bool = True @dataclasses.dataclass class TrainerConfig(base_config.Config): """Configuration for trainer. Attributes: optimizer_config: optimizer config, it includes optimizer, learning rate, and warmup schedule configs. train_tf_while_loop: whether or not to use tf while loop. train_tf_function: whether or not to use tf_function for training loop. eval_tf_function: whether or not to use tf_function for eval. steps_per_loop: number of steps per loop. summary_interval: number of steps between each summary. checkpoint_intervals: number of steps between checkpoints. max_to_keep: max checkpoints to keep. continuous_eval_timeout: maximum number of seconds to wait between checkpoints, if set to None, continuous eval will wait indefinetely. """ optimizer_config: OptimizationConfig = OptimizationConfig() train_tf_while_loop: bool = True train_tf_function: bool = True eval_tf_function: bool = True steps_per_loop: int = 1000 summary_interval: int = 1000 checkpoint_interval: int = 1000 max_to_keep: int = 5 continuous_eval_timeout: Optional[int] = None @dataclasses.dataclass class TaskConfig(base_config.Config): network: base_config.Config = None train_data: DataConfig = DataConfig() validation_data: DataConfig = DataConfig() @dataclasses.dataclass class ExperimentConfig(base_config.Config): """Top-level configuration.""" task: TaskConfig = TaskConfig() trainer: TrainerConfig = TrainerConfig() runtime: RuntimeConfig = RuntimeConfig() train_steps: int = 0 validation_steps: Optional[int] = None validation_interval: int = 100 _REGISTERED_CONFIGS = {} def register_config_factory(name): """Register ExperimentConfig factory method.""" return registry.register(_REGISTERED_CONFIGS, name) def get_exp_config_creater(exp_name: str): """Looks up ExperimentConfig factory methods.""" exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name) return exp_creater