File size: 19,540 Bytes
0b8359d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
# Lint as: python3
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import dataclasses
from dataclasses import dataclass
import os
from typing import Any, List, Mapping, Optional, Tuple, Union

from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds

from official.modeling.hyperparams import base_config
from official.vision.image_classification import augment
from official.vision.image_classification import preprocessing


# Registry mapping an augmenter name (as used in `AugmentConfig.name`) to the
# augmenter class that `AugmentConfig.build()` instantiates.
AUGMENTERS = dict(
    autoaugment=augment.AutoAugment,
    randaugment=augment.RandAugment,
)


@dataclass
class AugmentConfig(base_config.Config):
  """Configuration for image augmenters.

  Attributes:
    name: The name of the image augmentation to use. Possible options are
      None (default), 'autoaugment', or 'randaugment'.
    params: Any parameters used to initialize the augmenter. They are passed
      as keyword arguments to the augmenter class selected by `name`.
  """
  name: Optional[str] = None
  params: Optional[Mapping[str, Any]] = None

  def build(self) -> Optional[augment.ImageAugment]:
    """Build the augmenter using this config.

    Returns:
      A new augmenter instance, or None when `name` is unset (None or empty),
      meaning no augmentation should be applied.

    Raises:
      ValueError: If `name` is set but is not a key of `AUGMENTERS`. Failing
        loudly here prevents a misspelled name from silently disabling
        augmentation.
    """
    if not self.name:
      return None
    augmenter_cls = AUGMENTERS.get(self.name)
    if augmenter_cls is None:
      raise ValueError(
          'Unknown augmenter name {!r}. Supported names: {}'.format(
              self.name, sorted(AUGMENTERS)))
    params = self.params or {}
    return augmenter_cls(**params)


@dataclass
class DatasetConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    name: The name of the Dataset. Usually should correspond to a TFDS dataset.
    data_dir: The path where the dataset files are stored, if available.
    filenames: Optional list of strings representing the TFRecord names.
    builder: The builder type used to load the dataset. Value should be one of
      'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
      (generate dummy synthetic data without reading from files).
    split: The split of the dataset. Usually 'train', 'validation', or 'test'.
    image_size: The size of the image in the dataset. This assumes that
      `width` == `height`. Set to 'infer' to infer the image size from TFDS
      info. This requires `name` to be a registered dataset in TFDS.
    num_classes: The number of classes given by the dataset. Set to 'infer'
      to infer the number of classes from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    num_channels: The number of channels given by the dataset. Set to 'infer'
      to infer the number of channels from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    num_examples: The number of examples given by the dataset. Set to 'infer'
      to infer the number of examples from TFDS info. This requires `name` to
      be a registered dataset in TFDS.
    batch_size: The base batch size for the dataset.
    use_per_replica_batch_size: Whether to scale the batch size based on
      available resources. If set to `True`, the dataset builder will return
      batch_size multiplied by `num_devices`, the number of device replicas
      (e.g., the number of GPUs or TPU cores). This setting should be `True` if
      the strategy argument is passed to `build()` and `num_devices > 1`.
    num_devices: The number of replica devices to use. This should be set by
      `strategy.num_replicas_in_sync` when using a distribution strategy.
    dtype: The desired dtype of the dataset. This will be set during
      preprocessing.
    one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
      label smoothing.
    augmenter: The augmenter config to use. No augmentation is used by default.
    download: Whether to download data using TFDS.
    shuffle_buffer_size: The buffer size used for shuffling training data.
    file_shuffle_buffer_size: The buffer size used for shuffling raw training
      files.
    skip_decoding: Whether to skip image decoding when loading from TFDS.
    cache: Whether to cache the dataset examples. Can be used to avoid
      re-reading from disk on the second epoch. Requires significant memory
      overhead.
    tf_data_service: The URI of a tf.data service to offload preprocessing onto
      during training. The URI should be in the format "protocol://address",
      e.g. "grpc://tf-data-service:5050".
    mean_subtract: Whether or not to apply mean subtraction to the dataset.
    standardize: Whether or not to apply standardization to the dataset.
  """
  name: Optional[str] = None
  data_dir: Optional[str] = None
  filenames: Optional[List[str]] = None
  builder: str = 'tfds'
  split: str = 'train'
  image_size: Union[int, str] = 'infer'
  num_classes: Union[int, str] = 'infer'
  num_channels: Union[int, str] = 'infer'
  num_examples: Union[int, str] = 'infer'
  batch_size: int = 128
  use_per_replica_batch_size: bool = True
  num_devices: int = 1
  dtype: str = 'float32'
  one_hot: bool = True
  # A fresh AugmentConfig per instance: a dataclass instance is mutable and
  # unhashable, so using one directly as a default is shared across configs
  # and is rejected outright by `dataclasses` on Python 3.11+.
  augmenter: AugmentConfig = dataclasses.field(default_factory=AugmentConfig)
  download: bool = False
  shuffle_buffer_size: int = 10000
  file_shuffle_buffer_size: int = 1024
  skip_decoding: bool = True
  cache: bool = False
  tf_data_service: Optional[str] = None
  mean_subtract: bool = False
  standardize: bool = False

  @property
  def has_data(self) -> bool:
    """Whether this dataset has any data associated with it.

    True when at least one of `name`, `data_dir`, or `filenames` is set to a
    non-empty value.
    """
    return bool(self.name or self.data_dir or self.filenames)


@dataclass
class ImageNetConfig(DatasetConfig):
  """Default dataset config for ImageNet (TFDS name `imagenet2012`).

  Loads from TFRecords by default and uses 224x224 images with a base batch
  size of 128.
  """
  batch_size: int = 128
  image_size: int = 224
  # For a dataset as large as ImageNet, reading TFRecords directly is faster
  # than going through TFDS.
  builder: str = 'records'
  name: str = 'imagenet2012'


@dataclass
class Cifar10Config(DatasetConfig):
  """Default dataset config for CIFAR-10 (TFDS name `cifar10`).

  Downloads the data through TFDS if needed, caches the examples, and uses
  an image size of 224 with a base batch size of 128.
  """
  cache: bool = True
  download: bool = True
  batch_size: int = 128
  image_size: int = 224
  name: str = 'cifar10'


class DatasetBuilder:
  """An object for building datasets.

  Allows building various pipelines fetching examples, preprocessing, etc.
  Maintains additional state information calculated from the dataset, i.e.,
  training set split, batch size, and number of steps (batches).
  """

  def __init__(self, config: DatasetConfig, **overrides: Any):
    """Initialize the builder from the config.

    Args:
      config: The base dataset configuration.
      **overrides: Field overrides applied to `config` via `config.replace`.
    """
    self.config = config.replace(**overrides)
    self.builder_info = None
    # Populated by `_build()`. Initialized here so that `load_*()` and
    # `pipeline()` can also be called directly without an AttributeError.
    self.input_context = None

    if self.config.augmenter is not None:
      logging.info('Using augmentation: %s', self.config.augmenter.name)
      self.augmenter = self.config.augmenter.build()
    else:
      self.augmenter = None

  @property
  def is_training(self) -> bool:
    """Whether this is the training set."""
    return self.config.split == 'train'

  @property
  def batch_size(self) -> int:
    """The batch size, multiplied by the number of replicas (if configured)."""
    if self.config.use_per_replica_batch_size:
      return self.config.batch_size * self.config.num_devices
    else:
      return self.config.batch_size

  @property
  def global_batch_size(self):
    """The global batch size across all replicas."""
    return self.batch_size

  @property
  def local_batch_size(self):
    """The base unscaled batch size."""
    if self.config.use_per_replica_batch_size:
      return self.config.batch_size
    else:
      return self.config.batch_size // self.config.num_devices

  @property
  def num_steps(self) -> int:
    """The number of steps (batches) to exhaust this dataset."""
    # Always divide by the global batch size to get the correct # of steps
    return self.num_examples // self.global_batch_size

  @property
  def dtype(self) -> tf.dtypes.DType:
    """Converts the config's dtype string to a tf dtype.

    Returns:
      The `tf.dtypes.DType` corresponding to `config.dtype`.

    Raises:
      ValueError: If the config's dtype is not supported.
    """
    dtype_map = {
        'float32': tf.float32,
        'bfloat16': tf.bfloat16,
        'float16': tf.float16,
        'fp32': tf.float32,
        'bf16': tf.bfloat16,
    }
    try:
      return dtype_map[self.config.dtype]
    except (KeyError, TypeError) as e:
      raise ValueError('Invalid DType provided. Supported types: {}'.format(
          dtype_map.keys())) from e

  @property
  def image_size(self) -> int:
    """The size of each image (can be inferred from the dataset)."""
    if self.config.image_size == 'infer':
      return self.info.features['image'].shape[0]
    else:
      return int(self.config.image_size)

  @property
  def num_channels(self) -> int:
    """The number of image channels (can be inferred from the dataset)."""
    if self.config.num_channels == 'infer':
      return self.info.features['image'].shape[-1]
    else:
      return int(self.config.num_channels)

  @property
  def num_examples(self) -> int:
    """The number of examples (can be inferred from the dataset)."""
    if self.config.num_examples == 'infer':
      return self.info.splits[self.config.split].num_examples
    else:
      return int(self.config.num_examples)

  @property
  def num_classes(self) -> int:
    """The number of classes (can be inferred from the dataset)."""
    if self.config.num_classes == 'infer':
      return self.info.features['label'].num_classes
    else:
      return int(self.config.num_classes)

  @property
  def info(self) -> tfds.core.DatasetInfo:
    """The TFDS dataset info, if available (fetched lazily and memoized)."""
    if self.builder_info is None:
      # Use the same data_dir as `load_tfds` so the inferred split sizes and
      # features describe the data that will actually be read.
      self.builder_info = tfds.builder(
          self.config.name, data_dir=self.config.data_dir).info
    return self.builder_info

  def build(self, strategy: Optional[tf.distribute.Strategy] = None
            ) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it using an optional strategy.

    Args:
      strategy: a strategy that, if passed, will distribute the dataset
        according to that strategy. If passed and `num_devices > 1`,
        `use_per_replica_batch_size` must be set to `True`.

    Returns:
      A TensorFlow dataset outputting batched images and labels.
    """
    if strategy:
      if strategy.num_replicas_in_sync != self.config.num_devices:
        logging.warning('Passed a strategy with %d devices, but expected '
                        '%d devices.',
                        strategy.num_replicas_in_sync,
                        self.config.num_devices)
      dataset = strategy.experimental_distribute_datasets_from_function(
          self._build)
    else:
      dataset = self._build()

    return dataset

  def _build(self,
             input_context: Optional[tf.distribute.InputContext] = None
             ) -> tf.data.Dataset:
    """Construct a dataset end-to-end and return it.

    Args:
      input_context: An optional context provided by `tf.distribute` for
        cross-replica training.

    Returns:
      A TensorFlow dataset outputting batched images and labels.

    Raises:
      ValueError: If `config.builder` is not a known builder type.
    """
    builders = {
        'tfds': self.load_tfds,
        'records': self.load_records,
        'synthetic': self.load_synthetic,
    }

    builder = builders.get(self.config.builder, None)

    if builder is None:
      raise ValueError('Unknown builder type {}'.format(self.config.builder))

    self.input_context = input_context
    dataset = builder()
    dataset = self.pipeline(dataset)

    return dataset

  def load_tfds(self) -> tf.data.Dataset:
    """Return a dataset loading files from TFDS."""

    logging.info('Using TFDS to load data.')

    builder = tfds.builder(self.config.name,
                           data_dir=self.config.data_dir)

    if self.config.download:
      builder.download_and_prepare()

    decoders = {}

    if self.config.skip_decoding:
      decoders['image'] = tfds.decode.SkipDecoding()

    # TFDS shards by `input_context` itself, which is why `pipeline()` only
    # applies manual sharding for non-TFDS builders.
    read_config = tfds.ReadConfig(
        interleave_cycle_length=10,
        interleave_block_length=1,
        input_context=self.input_context)

    dataset = builder.as_dataset(
        split=self.config.split,
        as_supervised=True,
        shuffle_files=True,
        decoders=decoders,
        read_config=read_config)

    return dataset

  def load_records(self) -> tf.data.Dataset:
    """Return a dataset of TFRecord file names.

    The files themselves are read later in `pipeline()`.

    Raises:
      ValueError: If neither `filenames` nor `data_dir` is configured.
    """
    logging.info('Using TFRecords to load data.')
    if self.config.filenames is None:
      if self.config.data_dir is None:
        raise ValueError('Dataset must specify a path for the data files.')

      file_pattern = os.path.join(self.config.data_dir,
                                  '{}*'.format(self.config.split))
      # Deterministic order so that every input pipeline shards the same list.
      dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    else:
      dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)

    return dataset

  def load_synthetic(self) -> tf.data.Dataset:
    """Return an infinite dataset of all-zero images and zero labels."""
    logging.info('Generating a synthetic dataset.')

    def generate_data(_):
      image = tf.zeros([self.image_size, self.image_size, self.num_channels],
                       dtype=self.dtype)
      label = tf.zeros([1], dtype=tf.int32)
      return image, label

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(generate_data,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset

  def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Build a pipeline fetching, shuffling, and preprocessing the dataset.

    Args:
      dataset: A `tf.data.Dataset` that loads raw files.

    Returns:
      A TensorFlow dataset outputting batched images and labels.

    Raises:
      ValueError: If distributing across more than one replica without
        `use_per_replica_batch_size`, or if `tf_data_service` is set on a
        TensorFlow version without `tf.data.experimental.service`.
    """
    if (self.config.builder != 'tfds' and self.input_context
        and self.input_context.num_input_pipelines > 1):
      dataset = dataset.shard(self.input_context.num_input_pipelines,
                              self.input_context.input_pipeline_id)
      logging.info('Sharding the dataset: input_pipeline_id=%d '
                   'num_input_pipelines=%d',
                   self.input_context.input_pipeline_id,
                   self.input_context.num_input_pipelines)

    if self.is_training and self.config.builder == 'records':
      # Shuffle the input files. `shuffle` returns a new dataset, so the
      # result must be reassigned or the shuffle has no effect.
      dataset = dataset.shuffle(
          buffer_size=self.config.file_shuffle_buffer_size)

    if self.is_training and not self.config.cache:
      dataset = dataset.repeat()

    if self.config.builder == 'records':
      # Read the data from disk in parallel
      dataset = dataset.interleave(
          tf.data.TFRecordDataset,
          cycle_length=10,
          block_length=1,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.config.cache:
      dataset = dataset.cache()

    if self.is_training:
      dataset = dataset.shuffle(self.config.shuffle_buffer_size)
      dataset = dataset.repeat()

    # Parse, pre-process, and batch the data in parallel
    if self.config.builder == 'records':
      preprocess = self.parse_record
    else:
      preprocess = self.preprocess
    dataset = dataset.map(preprocess,
                          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    if self.input_context and self.config.num_devices > 1:
      if not self.config.use_per_replica_batch_size:
        raise ValueError(
            'The builder does not support a global batch size with more than '
            'one replica. Got {} replicas. Please set a '
            '`per_replica_batch_size` and enable '
            '`use_per_replica_batch_size=True`.'.format(
                self.config.num_devices))

      # The batch size of the dataset will be multiplied by the number of
      # replicas automatically when strategy.distribute_datasets_from_function
      # is called, so we use local batch size here.
      dataset = dataset.batch(self.local_batch_size,
                              drop_remainder=self.is_training)
    else:
      dataset = dataset.batch(self.global_batch_size,
                              drop_remainder=self.is_training)

    # Prefetch overlaps in-feed with training
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if self.config.tf_data_service:
      if not hasattr(tf.data.experimental, 'service'):
        raise ValueError('The tf_data_service flag requires Tensorflow version '
                         '>= 2.3.0, but the version is {}'.format(
                             tf.__version__))
      dataset = dataset.apply(
          tf.data.experimental.service.distribute(
              processing_mode='parallel_epochs',
              service=self.config.tf_data_service,
              job_name='resnet_train'))
      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

  def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parse an ImageNet record from a serialized string Tensor.

    Args:
      record: A scalar string tensor holding one serialized `tf.train.Example`.

    Returns:
      The `(image, label)` pair produced by `self.preprocess`.
    """
    keys_to_features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, ''),
        'image/format':
            tf.io.FixedLenFeature((), tf.string, 'jpeg'),
        'image/class/label':
            tf.io.FixedLenFeature([], tf.int64, -1),
        'image/class/text':
            tf.io.FixedLenFeature([], tf.string, ''),
        'image/object/bbox/xmin':
            tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymin':
            tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/xmax':
            tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/bbox/ymax':
            tf.io.VarLenFeature(dtype=tf.float32),
        'image/object/class/label':
            tf.io.VarLenFeature(dtype=tf.int64),
    }

    parsed = tf.io.parse_single_example(record, keys_to_features)

    label = tf.reshape(parsed['image/class/label'], shape=[1])

    # Subtract one so that labels are in [0, 1000)
    # NOTE(review): assumes the records store 1-based labels — confirm.
    label -= 1

    image_bytes = tf.reshape(parsed['image/encoded'], shape=[])
    image, label = self.preprocess(image_bytes, label)

    return image, label

  def preprocess(self, image: tf.Tensor, label: tf.Tensor
                ) -> Tuple[tf.Tensor, tf.Tensor]:
    """Apply image preprocessing and augmentation to the image and label.

    The label is cast to int32 and, if `config.one_hot` is set, converted to a
    one-hot vector of length `num_classes`.
    """
    if self.is_training:
      image = preprocessing.preprocess_for_train(
          image,
          image_size=self.image_size,
          mean_subtract=self.config.mean_subtract,
          standardize=self.config.standardize,
          dtype=self.dtype,
          augmenter=self.augmenter)
    else:
      image = preprocessing.preprocess_for_eval(
          image,
          image_size=self.image_size,
          num_channels=self.num_channels,
          mean_subtract=self.config.mean_subtract,
          standardize=self.config.standardize,
          dtype=self.dtype)

    label = tf.cast(label, tf.int32)
    if self.config.one_hot:
      label = tf.one_hot(label, self.num_classes)
      label = tf.reshape(label, [self.num_classes])

    return image, label

  @classmethod
  def from_params(cls, *args, **kwargs):
    """Construct a dataset builder from a default config and any overrides."""
    config = DatasetConfig.from_args(*args, **kwargs)
    return cls(config)