# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic JAX training loop for experiments."""

import functools
import os
from typing import (Any, Callable, Dict, Optional, Sequence, Tuple)

from absl import logging
from clu import metric_writers
import flax
from flax import jax_utils
from flax import linen as nn
from flax import struct
from flax.training import checkpoints
import gin
import jax
import jax.numpy as jnp
import metrics_summary
import optimizer_config as opt_config
import training_task
import numpy as np
import tensorflow.compat.v2 as tf


PRNGKeys = training_task.PRNGKeys
TrainState = training_task.TrainState
TrainingTask = training_task.TrainingTask
StepFunction = training_task.StepFunction
Metrics = training_task.Metrics
MetricWriter = metric_writers.MetricWriter
MetricsSummary = metrics_summary.MetricsSummary


gfile = tf.io.gfile
unfreeze = flax.core.unfreeze
flatten_dict = flax.traverse_util.flatten_dict
should_run = training_task.should_run


# TODO(cstaats): Use a Protocol to specify that it must be possible to call
# the function with parameters (step: int, mode: str). This won't be feasible
# until we start using Python 3.8 or later.
StepModeCallable = Callable[..., None]


# This variable should *only* be set from register_interstep_callbacks.
_interstep_callbacks: Optional[Tuple[StepModeCallable, ...]] = None


@gin.configurable
def register_interstep_callbacks(**kwargs: StepModeCallable) -> None:
  """Populates _interstep_callbacks from gin.

  This function should be called exactly ONCE, and that call should happen
  AFTER flag initialization (more specifically, after gin parsing).  The
  caller should NOT specify any arguments.

  In gin configurations, a callback can be specified with an arbitrary name
  like so:

      register_interstep_callbacks.my_callback_name = @my_callback_function

  Multiple callbacks can be registered without overriding each other as long as
  they all have different names. Conversely, if you *want* to override a
  callback, you need to give that callback the same name.

  Args:
    **kwargs: Specified by gin. Each argument should be a function (callable)
      that can be called as my_function(step, mode), where step is an int and
      mode is a str.

  Raises:
    ValueError: Raised on the second (and any subsequent) function call.
  """
  global _interstep_callbacks
  logging.info("registering functions: %s", kwargs.keys())
  if _interstep_callbacks is not None:
    raise ValueError("register_interstep_callbacks may only be called once.")
  _interstep_callbacks = tuple(kwargs.values())
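
# A minimal sketch of a callback and its gin binding (the names below are
# illustrative and are not part of this module).  Any function with the
# (step, mode) signature described above will work; note that the callback
# itself must be registered with gin (e.g. via @gin.configurable) so that the
# config can reference it:
#
#   @gin.configurable
#   def example_logging_callback(step: int, mode: str) -> None:
#     logging.info("interstep callback: mode=%s, step=%d", mode, step)
#
#   register_interstep_callbacks.example_logger = @example_logging_callback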


def clear_interstep_callbacks():
  """Clear all registered callbacks, so that new ones can be registered."""
  global _interstep_callbacks
  _interstep_callbacks = None


def run_interstep_callbacks(mode: str, step: int, sub_step: int = 0):
  """Run the registered callbacks.

  Args:
    mode: mode of the task to execute callbacks for.
    step: training step number.
    sub_step: For tasks that execute multiple iterations within a step.
      E.g. a test cycle that runs multiple testing steps.
  """
  for func in _interstep_callbacks:
    func(sub_step or step, mode)


@gin.configurable
@struct.dataclass
class Trainer:
  """Implements a JAX training loop."""

  # Returns a Flax module for the model.
  # Takes a single argument mode, which can be "test", "train", or "generate".
  model_definition: Any = gin.REQUIRED

  # Function that returns an iterator over training data.
  get_training_dataset_iterator: Callable[[], Any] = gin.REQUIRED

  # Function that returns an iterator over test data.
  get_test_dataset_iterator: Optional[Callable[[], Any]] = None

  workdir: str = ""                    # Working directory for checkpoints.
  load_dir: str = ""                   # Optional directory to load model.
  num_steps: int = 100000              # Number of steps to train.
  status_every_steps: int = 10         # Log step number every N steps.
  log_every_steps: int = 100           # Log scalar data every N steps.
  test_every_steps: int = 10           # Test model every N steps.
  num_test_steps: int = 1              # Number of iterations to test.
  generate_every_steps: int = 1000     # Generate examples every N steps.
  print_input_every_steps: int = 1000  # Print example data every N steps.

  save_checkpoints: bool = True        # Save training checkpoints
  checkpoint_every_steps: int = 5000   # Save checkpoints every N steps.
  restore_checkpoints: bool = True     # Restore from previous checkpoint.
  restore_state_variables: bool = True  # Restore TrainState.state from chkpt.

  # Record metrics for "train", "test", etc. in separate directories.
  # Otherwise they will be saved with separate prefixes.
  use_separate_metric_directories: bool = True

  # Optimizer options.
  optimizer_factory: opt_config.OptimizerConfig = gin.REQUIRED
  learning_rate_schedule: Callable[[jnp.ndarray, int], jnp.ndarray] = (
      opt_config.lr_cosine_decay)

  # Maximum steps for the LR schedule.  Zero means use num_steps.
  max_scheduled_steps: int = 0
  warmup_steps: int = 1000               # Number of warmup steps.
  learning_rate_multiplier: float = 1.0  # Used to scale the learning rate.

  random_seed: int = 42                  # Initial random seed.

  # Names of random number generators used by the model.
  rng_key_names: Optional[Sequence[str]] = ("dropout",)

  # Debug options.
  replicate_mode: bool = True     # pmap over multiple replicas.
  trace_debug_mode: bool = False  # Run in eager mode to trace results.
  print_variables: bool = False   # Dump parameters/variables to stdout.

  # Function to compute additional summary information.
  # Takes a MetricsSummary object and a mode string (e.g. "test") as arguments,
  # returns a MetricsSummary object.
  process_summaries_function: Optional[Callable[[Any, str], Any]] = None

  # Function to pretty print the input for each training step.
  pretty_print_input_function: Optional[Callable[[Any], Any]] = None

  # Classes to use for summarizing metrics.
  metrics_summary_factory: Any = metrics_summary.MetricsSummary
  extra_summaries_fn: training_task.ExtraSummariesFunction = (
      lambda mode, step: dict())

  post_save_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
  post_load_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
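
  # A minimal sketch of how these fields might be bound in a gin config (the
  # model, dataset, and optimizer names below are hypothetical placeholders,
  # not part of this codebase):
  #
  #   Trainer.model_definition = @MyFlaxModel
  #   Trainer.get_training_dataset_iterator = @my_training_dataset
  #   Trainer.get_test_dataset_iterator = @my_test_dataset
  #   Trainer.optimizer_factory = @MyOptimizerConfig()
  #   Trainer.num_steps = 100000
  #   Trainer.workdir = "/tmp/my_experiment"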

  def learning_rate_schedule_fn(self, step):
    """Returns the learning rate for the given step."""

    # There are four components to the learning rate.
    #
    # The base_lrate is defined by the optimizer, and different optimizers have
    # different relative rates, e.g. Adafactor requires a higher LR than Adam.
    # By default, the base_lrate is 1.0 for Adafactor.
    #
    # The base_lrate is then multiplied by the learning rate decay schedule,
    # which typically starts at a maximum value and decays over time.
    # Each schedule can be individually configured, e.g. from 0.01 to 0.001.
    # The max_scheduled_steps parameter controls the decay rate of the schedule.
    #
    # Finally, the LR is scaled by the learning_rate_multiplier, which provides
    # an easy way to scale the LR for hyperparameter tuning in a way that is
    # independent of the choice of schedule or optimizer.  The default is 1.0.
    #
    # During the warmup period, the learning rate ramps up linearly from zero.
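    #
    # For example (illustrative numbers only): with base_lrate = 1.0,
    # learning_rate_multiplier = 1.0, warmup_steps = 1000, and a schedule
    # value of 0.01 (the step is held at warmup_steps when evaluating the
    # schedule during warmup), the learning rate at step 500 would be
    # 0.01 * 1.0 * (500 / 1000) * 1.0 = 0.005.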

    step = jnp.asarray(step, dtype=jnp.float32)
    if self.max_scheduled_steps == 0:
      max_steps = self.num_steps
    else:
      max_steps = self.max_scheduled_steps

    base_lrate = float(self.optimizer_factory.learning_rate)
    lr_multiplier = float(self.learning_rate_multiplier)

    # Linear increase in learning rate up to warmup_steps.
    warmup_steps = float(self.warmup_steps)
    lr_warmup_ramp = jnp.minimum(step, warmup_steps) / warmup_steps

    # Hold step at a constant value during the warmup period.
    # Required for some schedules, like rsqrt_decay.
    step = jnp.maximum(step, warmup_steps)

    # Get the scheduled learning rate.
    lrate = self.learning_rate_schedule(step, max_steps)

    # Multiply lrate by the base, warmup and multiplier factors.
    lrate = lrate * base_lrate * lr_warmup_ramp * lr_multiplier
    return jnp.asarray(lrate, dtype=jnp.float32)

  def _init_rngs(self, rngs: PRNGKeys, step: int) -> PRNGKeys:
    # Get a new random number generator for each step
    rngs = jax.random.fold_in(rngs, step)
    rngs = jax.random.split(rngs, len(self.rng_key_names))
    rngs = {key: rngs[i] for i, key in enumerate(self.rng_key_names)}
    return rngs

  def train_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a training step, pmapped over multiple devices.

    Args:
      model:  The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x:      A batch of inputs to train on.
      rngs:   PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """

    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    # Refactor the model as a loss function from trainable params to loss, so
    # that we can differentiate with jax and get {d}loss/{d}params.
    # Inputs and non-trainable params are bound within the closure.
    # model:: x, { state_params } -> (loss, metrics), { new_state_params }
    # loss_fn:: params -> (loss, (metrics, new_state))
    def loss_fn(params):
      """Loss function."""
      (loss, mets), nstate = model.apply({"params": params, **tstate.state},
                                         x,
                                         rngs=rngs,
                                         mutable=mutable_keys)
      return loss, (mets, nstate)

    # grad_fn:: params -> ((loss, (aux, nstate)), param_gradients)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    # Run forward and backward pass.
    (loss, (metrics, new_state)), param_grads = grad_fn(tstate.optimizer.target)
    del loss  # loss is only recorded if it is part of the metrics
    if self.replicate_mode:
      param_grads = jax.lax.pmean(param_grads, axis_name="batch")
    lrate = self.learning_rate_schedule_fn(step)
    new_optimizer = tstate.optimizer.apply_gradient(
        param_grads, learning_rate=lrate)

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics, model.metrics_summary_operations(aggregate_over="devices"))

    metrics["learning_rate"] = lrate
    return (TrainState(new_optimizer, new_state), metrics)

  def other_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a test or generate step, pmapped over multiple devices.

    Args:
      model:  The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x:      A batch of inputs to evaluate.
      rngs:   PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """

    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    params = tstate.optimizer.target
    (loss, metrics), new_state = model.apply({"params": params, **tstate.state},
                                             x,
                                             rngs=rngs,
                                             mutable=mutable_keys)
    del loss  # loss is only recorded if it is part of the metrics

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics, model.metrics_summary_operations(aggregate_over="devices"))

    return (TrainState(tstate.optimizer, new_state), metrics)

  def initialize_model(self) -> Tuple[TrainState, int, nn.Module, PRNGKeys]:
    """Initialize the model and/or load it from a checkpoint.

    Returns:
      (tstate: TrainState,  -- The parameters and state for the model.
       start_step: int,     -- The step number, when restoring from checkpoint.
       imodel: nn.Module,   -- A model object (created with mode "init").
       rngs: PRNGKeys)      -- Initial random numbers.
    """

    # Set up random number generators.
    # ---------------------------------
    logging.info("==== Training loop: initializing model ====")
    logging.info("Process %d of %d", jax.process_index(), jax.process_count())
    logging.info("Local device count = %d", jax.local_device_count())
    logging.info("Number of replicas = %d",
                 jax.process_count() * jax.local_device_count())
    logging.info("Using random number seed %d", self.random_seed)

    prng = jax.random.PRNGKey(self.random_seed)
    prng, init_rng = jax.random.split(prng)

    # Grab rngs, which provide different random numbers for each replica.
    if self.replicate_mode:
      prngs = jax.random.split(prng, jax.local_device_count())
    else:
      prngs = prng
    del prng

    # Create a dictionary of prng keys for initialization.
    rng_key_names_init = list(self.rng_key_names) + ["params"]
    init_rngs = jax.random.split(init_rng, len(rng_key_names_init))
    init_rngs = {key: init_rngs[i] for i, key in enumerate(rng_key_names_init)}
    del init_rng

    # Build Model
    # -------------------------------------------------------------------------
    logging.info("Initializing the model.")

    # Create a model, which will be used to initialize trainable parameters.
    imodel = self.model_definition(mode="init")

    # The init function will lazily initialize the model, given a fake input.
    # It returns initialized variables, without doing a fwd pass.
    model_init_fn = jax.jit(imodel.init)
    variables = model_init_fn(init_rngs, imodel.get_fake_input())

    # Split variables into trainable and non-trainable sets.
    mstate, params = variables.pop("params")
    del variables  # Delete to avoid wasting resources.

    # Create an optimizer for params.
    optimizer_def = self.optimizer_factory.create_optimizer_def()
    optimizer = optimizer_def.create(params)

    # tstate holds the full training state of the model.
    tstate = TrainState(optimizer, mstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

    # Load a pre-trained model or restore it from checkpoint.
    if self.workdir or self.load_dir:
      restore_checkpoints = self.restore_checkpoints
    else:
      restore_checkpoints = False

    start_step = 0
    if restore_checkpoints:
      tstate = self.restore_checkpoint(tstate)
      start_step = int(tstate.optimizer.state.step)

    # Log info on trainable parameters (before replicating them).
    self._write_parameter_info(tstate)
    # raise ValueError("That's all folks!")

    # Replicate the training state across local devices.
    if self.replicate_mode:
      tstate = jax_utils.replicate(tstate)

    return (tstate, start_step, imodel, prngs)

  def restore_checkpoint(self, train_state: TrainState) -> TrainState:
    """Load a pre-trained model or restore it from a checkpoint."""

    # Figure out if we have an existing checkpoint.
    if not self.workdir:
      logging.info("No working directory specified.")
      existing_checkpoint = False
    elif not gfile.exists(self.workdir):
      logging.info("No existing checkpoint directory %s", self.workdir)
      existing_checkpoint = False
    elif not gfile.isdir(self.workdir):
      raise ValueError(f"workdir {self.workdir} must be a directory.")
    else:
      ckpath = checkpoints.latest_checkpoint(self.workdir, "checkpoint_")
      if ckpath:
        logging.info("Found existing checkpoint in %s", self.workdir)
        existing_checkpoint = True
      else:
        logging.info("No existing checkpoint in %s", self.workdir)
        existing_checkpoint = False

    # If any checkpoints exist in workdir, then use those first.
    # This will ensure that the task will restore properly if it's preempted.
    if existing_checkpoint:
      logging.info("Restoring model from last checkpoint %s:", self.workdir)
      load_dir = self.workdir
    elif self.load_dir:
      logging.info("Loading pre-trained model from %s:", self.load_dir)
      load_dir = self.load_dir
    else:
      logging.warning("Unable to load model.")
      return train_state
    loaded_train_state = checkpoints.restore_checkpoint(load_dir, train_state)
    step = int(loaded_train_state.optimizer.state.step)
    self.post_load_checkpoint_fn(load_dir, step)

    if self.restore_state_variables:
      # Restore complete state.
      logging.info("Restoring all variables and state.")
      train_state = loaded_train_state
      del loaded_train_state
    else:
      # Restore trainable variables, but not other state.
      logging.info("Only restoring trainable parameters.")
      train_state = TrainState(loaded_train_state.optimizer, train_state.state)
      del loaded_train_state

    return train_state

  def save_checkpoint(self, tstate: TrainState, step: int,
                      param_summary: Optional[MetricsSummary]):
    """Save a checkpoint with the model state.

    Args:
      tstate: The training state.
      step: The current step number.
      param_summary: Optional metrics summary to write parameter statistics.
    """

    logging.info("Saving checkpoint in directory %s", self.workdir)
    if self.replicate_mode:
      save_state = jax_utils.unreplicate(tstate)
    else:
      save_state = tstate
    checkpoints.save_checkpoint(self.workdir, save_state, step)

    # While we're at it, record distributions of trainable parameters.
    if param_summary is not None:
      logging.info("Recording parameter distributions.")
      params_dict = jax.device_get(
          _flatten_dict_string_keys(save_state.optimizer.target))
      param_distribs = self._compute_parameter_distributions(params_dict)
      param_summary.add(param_distribs)

  def create_training_task(self, mode: str, imodel: nn.Module, prngs: PRNGKeys,
                           writers: Dict[str, MetricWriter]) -> TrainingTask:
    """Create a new TrainingTask for the given mode.

    Args:
      mode: The mode for the task, e.g. "train", "test", "generate".
      imodel: The model object from initialize_model.
      prngs: The PRNGKeys from initialize_model.
      writers: A dictionary of summary writers.

    Returns:
      A TrainingTask object.
    """

    logging.info("Training loop: creating task for mode %s", mode)
    if self.use_separate_metric_directories:
      prefix = ""
    else:
      prefix = mode

    if mode == "train":
      ds = self.get_training_dataset_iterator
    elif mode == "test":
      ds = self.get_test_dataset_iterator
    else:
      ds = None

    # We summarize metrics over multiple training steps.
    # The operations below control how each metric is aggregated.
    metric_summary_ops = {
        "step_time": "mean",
        "learning_rate": "last",
        **imodel.metrics_summary_operations(aggregate_over="steps")
    }
    summary = self.metrics_summary_factory(metric_summary_ops)
    extra_summary = self.metrics_summary_factory({})
    summary_writer = self._get_summary_writer(mode, writers)

    return TrainingTask(
        mode=mode,
        dataset=ds,
        step_function=self._compile_step_function(mode),
        prng_keys=prngs,
        summary=summary,
        extra_summary=extra_summary,
        summary_writer=summary_writer,
        summary_prefix=prefix,
        # --- options ---
        replicate_mode=self.replicate_mode,
        print_input_every_steps=self.print_input_every_steps,
        pretty_print_input_function=self.pretty_print_input_function,
        process_summaries_function=self.process_summaries_function,
        extra_summaries_function=self.extra_summaries_fn)

  def train(self):
    """Runs the training and evaluation loop."""

    # The master process saves checkpoints and summaries to disk.
    is_master_process = jax.process_index() == 0
    if self.workdir:
      save_checkpoints = self.save_checkpoints
    else:
      save_checkpoints = False

    # --- Create and initialize the model. ---
    (tstate, start_step, imodel, prngs) = self.initialize_model()

    # Log experiment hyper-parameters.
    writers = {}
    train_writer = self._get_summary_writer("train", writers)
    if start_step == 0:
      self._write_config(train_writer)

    # Additional summary objects.
    param_summary = self.metrics_summary_factory({})  # Parameter statistics.

    # --- Create task objects for test, train, and generate. ---
    tasks = {}
    train_task = self.create_training_task("train", imodel, prngs, writers)
    tasks["train"] = train_task

    if (self.get_test_dataset_iterator is not None and
        self.test_every_steps != 0):
      test_task = self.create_training_task("test", imodel, prngs, writers)
      tasks["test"] = test_task
      if self.generate_every_steps != 0:
        gen_task = self.create_training_task("generate", imodel, prngs,
                                             writers)
        tasks["generate"] = gen_task

    # Register any additional actions.
    register_interstep_callbacks()

    # Main Training Loop
    # --------------------------------------------------------------------------
    logging.info("==== Training loop: starting main loop ====")
    with metric_writers.ensure_flushes(*writers.values()):
      for step in range(start_step, self.num_steps):
        # Log status every so often to monitor progress.
        if should_run(step, self.status_every_steps):
          logging.info("Step: %d", step)

        # Train.
        train_x = train_task.get_next_input()
        (tstate, _) = train_task.run_step(tstate, train_x, step)
        run_interstep_callbacks("train", step)
        del train_x

        # Test.
        if should_run(step, self.test_every_steps):
          if self.num_test_steps > 1:
            logging.info("Test cycle: %d iterations.", self.num_test_steps)
          for sub_step in range(0, self.num_test_steps):
            test_x = test_task.get_next_input()

            # TODO(delesley): This is an ugly hack to run generate steps.
            # Run a generate step using test data.
            # Generate is run just *before* the last test iteration.
            if ((sub_step == self.num_test_steps - 1) and
                should_run(step, self.generate_every_steps)):
              logging.info("Generate cycle.")
              (tstate, _) = gen_task.run_step(tstate, test_x, step)
              run_interstep_callbacks("generate", step)

            (tstate, _) = test_task.run_step(tstate, test_x, step,
                                             sub_step=sub_step)
            run_interstep_callbacks("test", step, sub_step)
          del test_x

        # --- Save checkpoints on the master host. ---
        is_last_step = (step == self.num_steps - 1)
        checkpoint_current_step = (
            save_checkpoints and
            (should_run(step, self.checkpoint_every_steps) or is_last_step))
        if checkpoint_current_step:
          if is_master_process:
            self.save_checkpoint(tstate, step, param_summary)
          self.post_save_checkpoint_fn(self.workdir, step)

        # --- Flush summaries to disk. ---
        if should_run(step, self.log_every_steps):
          for tsk in tasks.values():
            tsk.flush(step)
          param_summary.write(train_writer, step, prefix="params")

    logging.info("Training Finished.")
    if self.replicate_mode:
      tstate = jax_utils.unreplicate(tstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

  def _compile_step_function(self, mode: str) -> StepFunction:
    """Compile a step function (training or test)."""

    # Create a model object, and a step function that is a closure over the
    # object.  Flax modules are supposed to be "stateless", in that all state
    # is contained in the TrainState object that is passed as an input
    # parameter.
    # However, creating the model object may involve allocating expensive
    # data structures, or launching processes, and should only be done once.
    model = self.model_definition(mode=mode)
    if mode == "train":
      step_fn = functools.partial(self.train_step, model)
    else:
      step_fn = functools.partial(self.other_step, model)

    if self.replicate_mode:
      assert not self.trace_debug_mode
      logging.info("Compiling mode %s with pmap.", mode)
      p_fn = jax.pmap(step_fn, donate_argnums=(0,), axis_name="batch")
    elif self.trace_debug_mode:
      logging.info("Compiling mode %s with trace_debug.", mode)
      p_fn = step_fn
    else:
      logging.info("Compiling mode %s with jit.", mode)
      p_fn = jax.jit(step_fn, donate_argnums=(0,))
    return p_fn

  def _get_summary_writer(self, mode: str,
                          writers: Dict[str, MetricWriter]) -> MetricWriter:
    """Create a summary writer for the given mode.

    Args:
      mode: the mode for the summaries, e.g. "test", "train"
      writers: a dictionary which caches previously-created writers.

    Returns:
      A writer for the given mode.
    """

    if self.use_separate_metric_directories:
      # Create a separate writer & directory for each mode.
      w_mode = mode
      summary_dir = os.path.join(self.workdir, mode)
    else:
      # Create a single default writer for all modes.
      w_mode = "train"
      summary_dir = self.workdir

    if w_mode in writers:
      # Return previously created and cached writer.
      logging.info("Returning cached summary writer (%s) for mode %s",
                   w_mode, mode)
      return writers[w_mode]

    if not self.workdir:
      # No working directory, so log only.
      logging.info("Creating logging writer (%s) for mode %s", w_mode, mode)
      writer = metric_writers.LoggingWriter()
    else:
      # Create a new writer for workdir.
      # Only the master will actually write summaries to workdir.
      logging.info("Creating summary writer (%s) for mode %s in directory %s",
                   w_mode, mode, summary_dir)
      is_master = jax.process_index() == 0
      gfile.makedirs(summary_dir)
      writer = metric_writers.create_default_writer(summary_dir,
                                                    just_logging=not is_master)
    writers[w_mode] = writer
    return writer

  def _write_config(self, writer):
    """Write the configuration file to the working directory."""

    is_master = jax.process_index() == 0
    config_str = gin.operative_config_str()
    logging.info("Gin config: \n%s", config_str)

    # Write configuration to workdir.
    if is_master and self.workdir:
      config_file_name = os.path.join(self.workdir, "config.gin")
      with gfile.GFile(config_file_name, "w") as f:
        f.write(config_str)

    # Write config string text to tensorboard.
    writer.write_texts(0, {"config": gin.markdown(config_str)})

  def _write_parameter_info(self, tstate: TrainState):
    """Write information on state and trainable parameters to the log."""

    # Write information on parameters to log file.
    params_dict = _flatten_dict_string_keys(tstate.optimizer.target)
    total_nparams = 0
    for (k, v) in params_dict.items():
      nparams = np.prod(v.shape)
      total_nparams += nparams
      logging.info("parameter: %s, shape %s, size %d", k, v.shape, nparams)
    logging.info("Total parameters: %d", total_nparams)

    # Write information on state variables to log file.
    state_dict = _flatten_dict_string_keys(tstate.state)
    state_size = 0
    total_state = 0
    for (k, v) in state_dict.items():
      if hasattr(v, "shape"):
        state_size = np.prod(v.shape)
        total_state += state_size
        logging.info("state: %s, shape %s, size %d", k, v.shape, state_size)
      else:
        # Some other stuff may be stored in the state.
        logging.info("state: %s [unknown]", k)
    logging.info("Total state size: %d", total_state)

  def _compute_parameter_distributions(self, params_dict):
    """Compute info on distributions of parameters."""

    scalar_params_dict = {}
    for (k, v) in params_dict.items():
      # Convert from bfloat16, since serializing a bfloat16 NaN can crash.
      v = np.asarray(v, dtype=jnp.float32)
      scalar_params_dict[k + "_mean"] = np.mean(v)
      scalar_params_dict[k + "_stddev"] = np.std(v)
      # scalar_params_dict[k + "_min"] = np.min(v)
      # scalar_params_dict[k + "_max"] = np.max(v)
    return scalar_params_dict


def _flatten_dict_string_keys(params):
  """Flattens a nested dictionary to have string keys and '/' separators."""
  return {"/".join(k): v for k, v in flatten_dict(unfreeze(params)).items()}