# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic JAX training loop for experiments."""

import functools
import os
from typing import (Any, Callable, Dict, Optional, Sequence, Tuple)

from absl import logging
from clu import metric_writers
import flax
from flax import jax_utils
from flax import linen as nn
from flax import struct
from flax.training import checkpoints
import gin
import jax
import jax.numpy as jnp
import metrics_summary
import optimizer_config as opt_config
import training_task
import numpy as np
import tensorflow.compat.v2 as tf


PRNGKeys = training_task.PRNGKeys
TrainState = training_task.TrainState
TrainingTask = training_task.TrainingTask
StepFunction = training_task.StepFunction
Metrics = training_task.Metrics
MetricWriter = metric_writers.MetricWriter
MetricsSummary = metrics_summary.MetricsSummary

gfile = tf.io.gfile
unfreeze = flax.core.unfreeze
flatten_dict = flax.traverse_util.flatten_dict
should_run = training_task.should_run


# TODO(cstaats): Use a Protocol to specify that it must be possible to call
# the function with parameters (step: int, mode: str). This won't be feasible
# until we start using Python 3.8 or later.
StepModeCallable = Callable[..., None]

# This variable should *only* be set from register_interstep_callbacks.
_interstep_callbacks: Optional[Tuple[StepModeCallable, ...]] = None


@gin.configurable
def register_interstep_callbacks(**kwargs: StepModeCallable) -> None:
  """Populates _interstep_callbacks from gin.

  This function should be called exactly ONCE and that call should happen
  AFTER flag initialization (and more specifically, after gin parsing). And
  the caller should NOT specify any arguments.

  In gin configurations, a callback can be specified with an arbitrary name
  like so:

    register_interstep_callbacks.my_callback_name = @my_callback_function

  Multiple callbacks can be registered without overriding each other as long
  as they all have different names. Conversely, if you *want* to override a
  callback, you need to give that callback the same name.

  Args:
    **kwargs: Specified by gin. Each argument should be a function (callable)
      that can be called as my_function(step, mode), where step is an int and
      mode is a str.

  Raises:
    ValueError: Raised on the second (and any subsequent) function call.
  """
  global _interstep_callbacks
  logging.info("registering functions: %s", kwargs.keys())
  if _interstep_callbacks is not None:
    raise ValueError("register_interstep_callbacks may only be called once.")
  _interstep_callbacks = tuple(kwargs.values())


def clear_interstep_callbacks():
  """Clear all registered callbacks, so that new ones can be registered."""
  global _interstep_callbacks
  _interstep_callbacks = None


def run_interstep_callbacks(mode: str, step: int, sub_step: int = 0):
  """Run the registered callbacks.

  Args:
    mode: mode of the task to execute callbacks for.
    step: training step number.
    sub_step: For tasks that execute multiple iterations within a step.
      E.g. a test cycle that runs multiple testing steps.
""" for func in _interstep_callbacks: func(sub_step or step, mode) @gin.configurable @struct.dataclass class Trainer: """Implements a JAX training loop.""" # Returns a Flax module for the model. # Takes a single argument mode, which can be "test", "train", or "generate". model_definition: Any = gin.REQUIRED # Iterator over trainining data. get_training_dataset_iterator: Callable[[], Any] = gin.REQUIRED # Iterator over test data. get_test_dataset_iterator: Optional[Callable[[], Any]] = None workdir: str = "" # Working directory for checkpoints. load_dir: str = "" # Optional directory to load model. num_steps: int = 100000 # Number of steps to train. status_every_steps: int = 10 # Log step number every N steps. log_every_steps: int = 100 # Log scalar data every N steps. test_every_steps: int = 10 # Test model every N steps. num_test_steps: int = 1 # Number of iterations to test. generate_every_steps: int = 1000 # Generate examples every N steps. print_input_every_steps: int = 1000 # Print example data every N steps. save_checkpoints: bool = True # Save training checkpoints checkpoint_every_steps: int = 5000 # Save checkpoints every N steps. restore_checkpoints: bool = True # Restore from previous checkpoint. restore_state_variables: bool = True # Restore TrainState.state from chkpt. # Record metrics for "train", "test", etc. in separate directories. # Otherwise they will be saved with separate prefixes. use_separate_metric_directories: bool = True # Optimizer options. optimizer_factory: opt_config.OptimizerConfig = gin.REQUIRED learning_rate_schedule: Callable[[jnp.ndarray, int], jnp.ndarray] = ( opt_config.lr_cosine_decay) # Maximum steps for the LR schedule. Zero means use num_steps. max_scheduled_steps: int = 0 warmup_steps: int = 1000 # Number of warmup steps. learning_rate_multiplier: float = 1.0 # Used to scale the learning rate. random_seed: int = 42 # Initial random seed. # Names of random number generators used by the model. rng_key_names: Optional[Sequence[str]] = ("dropout",) # Debug options. replicate_mode: bool = True # pmap over multiple replicas. trace_debug_mode: bool = False # Run in eager mode to trace results. print_variables: bool = False # Dump parameters/variables to stdout. # Function to compute additional summary information. # Takes a MetricsSummary object and a mode string (e.g. "test") as arguments, # returns a MetricsSummary object. process_summaries_function: Optional[Callable[[Any, str], Any]] = None # Function to pretty print the input for each training step. pretty_print_input_function: Optional[Callable[[Any], Any]] = None # Classes to use for summarizing metrics. metrics_summary_factory: Any = metrics_summary.MetricsSummary extra_summaries_fn: training_task.ExtraSummariesFunction = ( lambda mode, step: dict()) post_save_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None post_load_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None def learning_rate_schedule_fn(self, step): """Returns the learning rate for the given step.""" # There are four components to the learning rate. # # The base_lrate is defined by the optimizer, and different optimizers have # different relative rates, e.g. Adafactor requires a higher LR than Adam. # By default, the base_lrate is 1.0 for Adafactor. # # The base_lrate is then multiplied by the learning rate decay schedule, # which typically starts at a maximum value and decays over time. # Each schedule can be individually configured, e.g. from 0.01 to 0.001. 
    # The max_scheduled_steps parameter controls the decay rate of the
    # schedule.
    #
    # Finally, the LR is scaled by the learning_rate_multiplier, which provides
    # an easy way to scale the LR for hyperparameter tuning in a way that is
    # independent of the choice of schedule or optimizer. The default is 1.0.
    #
    # During the warmup period, the learning rate ramps up linearly from zero.

    step = jnp.asarray(step, dtype=jnp.float32)
    if self.max_scheduled_steps == 0:
      max_steps = self.num_steps
    else:
      max_steps = self.max_scheduled_steps

    base_lrate = float(self.optimizer_factory.learning_rate)
    lr_multiplier = float(self.learning_rate_multiplier)

    # Linear increase in learning rate up to warmup_steps.
    warmup_steps = float(self.warmup_steps)
    lr_warmup_ramp = jnp.minimum(step, warmup_steps) / warmup_steps

    # Hold step at a constant value during the warmup period.
    # Required for some schedules, like rsqrt_decay.
    step = jnp.maximum(step, warmup_steps)

    # Get the scheduled learning rate.
    lrate = self.learning_rate_schedule(step, max_steps)

    # Multiply lrate by the base, warmup and multiplier factors.
    lrate = lrate * base_lrate * lr_warmup_ramp * lr_multiplier
    return jnp.asarray(lrate, dtype=jnp.float32)

  def _init_rngs(self, rngs: PRNGKeys, step: int) -> PRNGKeys:
    # Get a new random number generator for each step.
    rngs = jax.random.fold_in(rngs, step)
    rngs = jax.random.split(rngs, len(self.rng_key_names))
    rngs = {key: rngs[i] for i, key in enumerate(self.rng_key_names)}
    return rngs

  def train_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a training step, pmapped over multiple devices.

    Args:
      model: The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x: A batch of inputs to train on.
      rngs: PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """
    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    # Refactor the model as a loss function from trainable params to loss, so
    # that we can differentiate with jax and get {d}loss/{d}params.
    # Inputs and non-trainable params are bound within the closure.
    # model::   x, { state_params } -> (loss, metrics), { new_state_params }
    # loss_fn:: params -> (loss, (metrics, new_state))
    def loss_fn(params):
      """Loss function."""
      (loss, mets), nstate = model.apply({"params": params, **tstate.state},
                                         x,
                                         rngs=rngs,
                                         mutable=mutable_keys)
      return loss, (mets, nstate)

    # grad_fn:: params -> ((loss, (aux, nstate)), param_gradients)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    # Run forward and backward pass.
    (loss, (metrics, new_state)), param_grads = grad_fn(tstate.optimizer.target)
    del loss  # loss is only recorded if it is part of the metrics

    if self.replicate_mode:
      param_grads = jax.lax.pmean(param_grads, axis_name="batch")
    lrate = self.learning_rate_schedule_fn(step)
    new_optimizer = tstate.optimizer.apply_gradient(
        param_grads, learning_rate=lrate)

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics,
          model.metrics_summary_operations(aggregate_over="devices"))
    metrics["learning_rate"] = lrate

    return (TrainState(new_optimizer, new_state), metrics)

  def other_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a test or generate step, pmapped over multiple devices.

    Args:
      model: The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x: A batch of inputs to evaluate on.
      rngs: PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """
    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    params = tstate.optimizer.target
    (loss, metrics), new_state = model.apply({"params": params, **tstate.state},
                                             x,
                                             rngs=rngs,
                                             mutable=mutable_keys)
    del loss  # loss is only recorded if it is part of the metrics

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics,
          model.metrics_summary_operations(aggregate_over="devices"))

    return (TrainState(tstate.optimizer, new_state), metrics)

  def initialize_model(self) -> Tuple[TrainState, int, nn.Module, PRNGKeys]:
    """Initialize the model and/or load it from a checkpoint.

    Returns:
      (tstate: TrainState,  -- The parameters and state for the model.
       start_step: int,     -- The step number, when restoring from checkpoint.
       imodel: nn.Module,   -- A model object (created with mode "init").
       rngs: PRNGKeys)      -- Initial random numbers.
    """
    # Set up random number generators.
    # ---------------------------------
    logging.info("==== Training loop: initializing model ====")
    logging.info("Process %d of %d", jax.process_index(), jax.process_count())
    logging.info("Local device count = %d", jax.local_device_count())
    logging.info("Number of replicas = %d",
                 jax.process_count() * jax.local_device_count())
    logging.info("Using random number seed %d", self.random_seed)

    prng = jax.random.PRNGKey(self.random_seed)
    prng, init_rng = jax.random.split(prng)

    # Grab rngs, which provide different random numbers for each replica.
    if self.replicate_mode:
      prngs = jax.random.split(prng, jax.local_device_count())
    else:
      prngs = prng
    del prng

    # Create a dictionary of prng keys for initialization.
    rng_key_names_init = list(self.rng_key_names) + ["params"]
    init_rngs = jax.random.split(init_rng, len(rng_key_names_init))
    init_rngs = {key: init_rngs[i]
                 for i, key in enumerate(rng_key_names_init)}
    del init_rng

    # Build Model
    # -------------------------------------------------------------------------
    logging.info("Initializing the model.")

    # Create a model, which will be used to initialize trainable parameters.
    imodel = self.model_definition(mode="init")

    # The init function will lazily initialize the model, given a fake input.
    # It returns initialized variables, without doing a fwd pass.
    model_init_fn = jax.jit(imodel.init)
    variables = model_init_fn(init_rngs, imodel.get_fake_input())

    # Split variables into trainable and non-trainable sets.
    mstate, params = variables.pop("params")
    del variables  # Delete to avoid wasting resources.

    # Create an optimizer for params.
    optimizer_def = self.optimizer_factory.create_optimizer_def()
    optimizer = optimizer_def.create(params)

    # tstate holds the full training state of the model.
    tstate = TrainState(optimizer, mstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

    # Load a pre-trained model or restore it from a checkpoint.
    if self.workdir or self.load_dir:
      restore_checkpoints = self.restore_checkpoints
    else:
      restore_checkpoints = False

    start_step = 0
    if restore_checkpoints:
      tstate = self.restore_checkpoint(tstate)
      start_step = int(tstate.optimizer.state.step)

    # Log info on trainable parameters (before replicating them).
    self._write_parameter_info(tstate)

    # raise ValueError("That's all folks!")

    # Replicate the training state across local devices.
    if self.replicate_mode:
      tstate = jax_utils.replicate(tstate)

    return (tstate, start_step, imodel, prngs)

  def restore_checkpoint(self, train_state: TrainState) -> TrainState:
    """Load a pre-trained model or restore it from a checkpoint."""
    # Figure out if we have an existing checkpoint.
    if not self.workdir:
      logging.info("No working directory specified.")
      existing_checkpoint = False
    elif not gfile.exists(self.workdir):
      logging.info("No existing checkpoint directory %s", self.workdir)
      existing_checkpoint = False
    elif not gfile.isdir(self.workdir):
      raise ValueError(f"workdir {self.workdir} must be a directory.")
    else:
      ckpath = checkpoints.latest_checkpoint(self.workdir, "checkpoint_")
      if ckpath:
        logging.info("Found existing checkpoint in %s", self.workdir)
        existing_checkpoint = True
      else:
        logging.info("No existing checkpoint in %s", self.workdir)
        existing_checkpoint = False

    # If any checkpoints exist in workdir, then use those first.
    # This will ensure that the task will restore properly if it's preempted.
    if existing_checkpoint:
      logging.info("Restoring model from last checkpoint %s:", self.workdir)
      load_dir = self.workdir
    elif self.load_dir:
      logging.info("Loading pre-trained model from %s:", self.load_dir)
      load_dir = self.load_dir
    else:
      logging.warning("Unable to load model.")
      return train_state

    loaded_train_state = checkpoints.restore_checkpoint(load_dir, train_state)
    step = int(loaded_train_state.optimizer.state.step)
    self.post_load_checkpoint_fn(load_dir, step)

    if self.restore_state_variables:
      # Restore complete state.
      logging.info("Restoring all variables and state.")
      train_state = loaded_train_state
      del loaded_train_state
    else:
      # Restore trainable variables, but not other state.
      logging.info("Only restoring trainable parameters.")
      train_state = TrainState(loaded_train_state.optimizer,
                               train_state.state)
      del loaded_train_state
    return train_state

  def save_checkpoint(self, tstate: TrainState, step: int,
                      param_summary: Optional[MetricsSummary]):
    """Save a checkpoint with the model state.

    Args:
      tstate: The training state.
      step: The current step number.
      param_summary: Optional metrics summary to write parameter statistics.
    """
    logging.info("Saving checkpoint in directory %s", self.workdir)
    if self.replicate_mode:
      save_state = jax_utils.unreplicate(tstate)
    else:
      save_state = tstate
    checkpoints.save_checkpoint(self.workdir, save_state, step)

    # While we're at it, record distributions of trainable parameters.
    if param_summary is not None:
      logging.info("Recording parameter distributions.")
      params_dict = jax.device_get(
          _flatten_dict_string_keys(save_state.optimizer.target))
      param_distribs = self._compute_parameter_distributions(params_dict)
      param_summary.add(param_distribs)

  def create_training_task(self, mode: str, imodel: nn.Module,
                           prngs: PRNGKeys,
                           writers: Dict[str, MetricWriter]) -> TrainingTask:
    """Create a new TrainingTask for the given mode.

    Args:
      mode: The mode for the task, e.g. "train", "test", "generate".
      imodel: The model object from initialize_model.
      prngs: The PRNGKeys from initialize_model.
      writers: A dictionary of summary writers.

    Returns:
      A TrainingTask object.
    """
    logging.info("Training loop: creating task for mode %s", mode)
    if self.use_separate_metric_directories:
      prefix = ""
    else:
      prefix = mode

    if mode == "train":
      ds = self.get_training_dataset_iterator
    elif mode == "test":
      ds = self.get_test_dataset_iterator
    else:
      ds = None

    # We summarize metrics over multiple training steps.
    # These types control how the summary is computed.
    metric_summary_ops = {
        "step_time": "mean",
        "learning_rate": "last",
        **imodel.metrics_summary_operations(aggregate_over="steps")
    }
    summary = self.metrics_summary_factory(metric_summary_ops)
    extra_summary = self.metrics_summary_factory({})
    summary_writer = self._get_summary_writer(mode, writers)

    return TrainingTask(
        mode=mode,
        dataset=ds,
        step_function=self._compile_step_function(mode),
        prng_keys=prngs,
        summary=summary,
        extra_summary=extra_summary,
        summary_writer=summary_writer,
        summary_prefix=prefix,
        # --- options ---
        replicate_mode=self.replicate_mode,
        print_input_every_steps=self.print_input_every_steps,
        pretty_print_input_function=self.pretty_print_input_function,
        process_summaries_function=self.process_summaries_function,
        extra_summaries_function=self.extra_summaries_fn)

  def train(self):
    """Runs the training and evaluation loop."""
    # The master process saves checkpoints and summaries to disk.
    is_master_process = jax.process_index() == 0

    if self.workdir:
      save_checkpoints = self.save_checkpoints
    else:
      save_checkpoints = False

    # --- Create and initialize the model. ---
    (tstate, start_step, imodel, prngs) = self.initialize_model()

    # Log experiment hyper-parameters.
    writers = {}
    train_writer = self._get_summary_writer("train", writers)
    if start_step == 0:
      self._write_config(train_writer)

    # Additional summary objects.
    param_summary = self.metrics_summary_factory({})  # Parameter statistics.

    # --- Create task objects for test, train, and generate. ---
    tasks = {}
    train_task = self.create_training_task("train", imodel, prngs, writers)
    tasks["train"] = train_task

    if (self.get_test_dataset_iterator is not None and
        self.test_every_steps != 0):
      test_task = self.create_training_task("test", imodel, prngs, writers)
      tasks["test"] = test_task

    if self.generate_every_steps != 0:
      gen_task = self.create_training_task("generate", imodel, prngs, writers)
      tasks["generate"] = gen_task

    # Register any additional actions.
    register_interstep_callbacks()

    # Main Training Loop
    # --------------------------------------------------------------------------
    logging.info("==== Training loop: starting main loop ====")
    with metric_writers.ensure_flushes(*writers.values()):
      for step in range(start_step, self.num_steps):
        # Log status every so often to monitor progress.
        if should_run(step, self.status_every_steps):
          logging.info("Step: %d", step)

        # Train.
        train_x = train_task.get_next_input()
        (tstate, _) = train_task.run_step(tstate, train_x, step)
        run_interstep_callbacks("train", step)
        del train_x

        # Test.
        if should_run(step, self.test_every_steps):
          if self.num_test_steps > 1:
            logging.info("Test cycle: %d iterations.", self.num_test_steps)
          for sub_step in range(0, self.num_test_steps):
            test_x = test_task.get_next_input()

            # TODO(delesley): This is an ugly hack to run generate steps.
            # Run a generate step using test data.
            # Generate is run just *before* the last test iteration.
            if ((sub_step == self.num_test_steps - 1) and
                should_run(step, self.generate_every_steps)):
              logging.info("Generate cycle.")
              (tstate, _) = gen_task.run_step(tstate, test_x, step)
              run_interstep_callbacks("generate", step)

            (tstate, _) = test_task.run_step(tstate, test_x, step,
                                             sub_step=sub_step)
            run_interstep_callbacks("test", step, sub_step)
            del test_x

        # --- Save checkpoints on the master host. ---
        is_last_step = (step == self.num_steps - 1)
        checkpoint_current_step = (
            save_checkpoints and
            (should_run(step, self.checkpoint_every_steps) or is_last_step))
        if checkpoint_current_step:
          if is_master_process:
            self.save_checkpoint(tstate, step, param_summary)
          self.post_save_checkpoint_fn(self.workdir, step)

        # --- Flush summaries to disk. ---
        if should_run(step, self.log_every_steps):
          for tsk in tasks.values():
            tsk.flush(step)
          param_summary.write(train_writer, step, prefix="params")

    logging.info("Training Finished.")
    if self.replicate_mode:
      tstate = jax_utils.unreplicate(tstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

  def _compile_step_function(self, mode: str) -> StepFunction:
    """Compile a step function (training or test)."""
    # Create a model object, and a step function that is a closure over the
    # object. Flax modules are supposed to be "stateless", in that all state
    # is contained in the TrainState object that is passed as an input
    # parameter. However, creating the model object may involve allocating
    # expensive data structures, or launching processes, and should only be
    # done once.
    model = self.model_definition(mode=mode)
    if mode == "train":
      step_fn = functools.partial(self.train_step, model)
    else:
      step_fn = functools.partial(self.other_step, model)

    if self.replicate_mode:
      assert not self.trace_debug_mode
      logging.info("Compiling mode %s with pmap.", mode)
      p_fn = jax.pmap(step_fn, donate_argnums=(0,), axis_name="batch")
    elif self.trace_debug_mode:
      logging.info("Compiling mode %s with trace_debug.", mode)
      p_fn = step_fn
    else:
      logging.info("Compiling mode %s with jit.", mode)
      p_fn = jax.jit(step_fn, donate_argnums=(0,))
    return p_fn

  def _get_summary_writer(self, mode: str,
                          writers: Dict[str, MetricWriter]) -> MetricWriter:
    """Create a summary writer for the given mode.

    Args:
      mode: the mode for the summaries, e.g. "test", "train".
      writers: a dictionary which caches previously-created writers.

    Returns:
      A writer for the given mode.
    """
    if self.use_separate_metric_directories:
      # Create a separate writer & directory for each mode.
      w_mode = mode
      summary_dir = os.path.join(self.workdir, mode)
    else:
      # Create a single default writer for all modes.
      w_mode = "train"
      summary_dir = self.workdir

    if w_mode in writers:
      # Return previously created and cached writer.
      logging.info("Returning cached summary writer (%s) for mode %s",
                   w_mode, mode)
      return writers[w_mode]

    if not self.workdir:
      # No working directory, so log only.
      logging.info("Creating logging writer (%s) for mode %s", w_mode, mode)
      writer = metric_writers.LoggingWriter()
    else:
      # Create a new writer for workdir.
      # Only the master will actually write summaries to workdir.
logging.info("Creating summary writer (%s) for mode %s in directory %s", w_mode, mode, summary_dir) is_master = jax.process_index() == 0 gfile.makedirs(summary_dir) writer = metric_writers.create_default_writer(summary_dir, just_logging=not is_master) writers[w_mode] = writer return writer def _write_config(self, writer): """Write the configuration file to the working directory.""" is_master = jax.process_index() == 0 config_str = gin.operative_config_str() logging.info("Gin config: \n%s", config_str) # Write configuration to workdir. if is_master and self.workdir: config_file_name = os.path.join(self.workdir, "config.gin") with gfile.GFile(config_file_name, "w") as f: f.write(config_str) # Write config string text to tensorboard. writer.write_texts(0, {"config": gin.markdown(config_str)}) def _write_parameter_info(self, tstate: TrainState): """Write information on state and trainable parameters to the log.""" # Write information on parameters to log file. params_dict = _flatten_dict_string_keys(tstate.optimizer.target) total_nparams = 0 for (k, v) in params_dict.items(): nparams = np.prod(v.shape) total_nparams += nparams logging.info("parameter: %s, shape %s, size %d", k, v.shape, nparams) logging.info("Total parameters: %d", total_nparams) # Write information on state variables to log file. state_dict = _flatten_dict_string_keys(tstate.state) state_size = 0 total_state = 0 for (k, v) in state_dict.items(): if hasattr(v, "shape"): state_size = np.prod(v.shape) total_state += state_size logging.info("state: %s, shape %s, size %d", k, v.shape, state_size) else: # Some other stuff may be stored in the state. logging.info("state: %s [unknown]", k) logging.info("Total state size: %d", total_state) def _compute_parameter_distributions(self, params_dict): """Compute info on distributions of parameters.""" scalar_params_dict = {} for (k, v) in params_dict.items(): # Convert from bfloat16, which crashes when serializing a NaN. v = np.asarray(v, dtype=jnp.float32) scalar_params_dict[k + "_mean"] = np.mean(v) scalar_params_dict[k + "_stddev"] = np.std(v) # scalar_params_dict[k + "_min"] = np.min(v) # scalar_params_dict[k + "_max"] = np.max(v) return scalar_params_dict def _flatten_dict_string_keys(params): """Flattens a nested dictionary to have string keys and '/' separators.""" return {"/".join(k): v for k, v in flatten_dict(unfreeze(params)).items()}