# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TrainingTask encapsulates the state associated with a single model step."""

import time
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Tuple)

from absl import logging
from clu import metric_writers
from flax import optim
from flax import struct
import jax
import metrics_summary
import numpy as np


@struct.dataclass
class TrainState:
  optimizer: optim.Optimizer  # Trainable parameters.
  state: Any                  # Other state, e.g. XL cache or memory.


PRNGKeys = Any
Metrics = Dict[str, Any]
MetricsSummary = metrics_summary.MetricsSummary

Dataset = Callable[[], Iterator[Any]]
StepFunction = Callable[[TrainState, Any, Any], Tuple[TrainState, Metrics]]

PrettyPrintInputFunction = Optional[Callable[[Any], str]]
ProcessSummariesFunction = Optional[Callable[[Any, str], Any]]
ExtraSummariesFunction = Optional[Callable[[str, int], Mapping[str, Any]]]


def should_run(step: int, every_steps: int) -> bool:
  """Returns True if a periodic action should be run on the given step."""
  return (step > 0) and (every_steps > 0) and (step % every_steps == 0)
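
# Illustrative values for should_run, derived from the definition above.
# Periodic actions never fire on step 0, and every_steps <= 0 disables the
# action entirely:
#
#   should_run(step=0,    every_steps=1000)  -> False  (skipped at startup)
#   should_run(step=1000, every_steps=1000)  -> True
#   should_run(step=1500, every_steps=1000)  -> False
#   should_run(step=2000, every_steps=1000)  -> True
#   should_run(step=1000, every_steps=0)     -> False  (action disabled)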
class TrainingTask:
  """A TrainingTask encapsulates the state associated with a training task.

  Examples of tasks include training steps, test or validation runs, and
  inference (generation). State includes the input pipeline and the summary
  information that is averaged over multiple steps.
  """

  def __init__(
      self,
      *,  # Pass arguments by keyword only.
      mode: str,
      dataset: Dataset,
      step_function: StepFunction,
      prng_keys: PRNGKeys,
      summary: MetricsSummary,
      extra_summary: MetricsSummary,
      summary_writer: metric_writers.MetricWriter,
      summary_prefix: str = "",
      # --- Options from TrainingLoop ---
      replicate_mode: bool = True,
      print_input_every_steps: int = 0,
      pretty_print_input_function: PrettyPrintInputFunction = None,
      process_summaries_function: ProcessSummariesFunction = None,
      extra_summaries_function: Optional[ExtraSummariesFunction] = None):
    # Local state.
    self.mode = mode
    self.dataset = dataset
    self.step_function = step_function
    self.prng_keys = prng_keys
    self.summary = summary
    self.extra_summary = extra_summary
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix

    # Options carried over from TrainingLoop.
    self.replicate_mode = replicate_mode
    self.print_input_every_steps = print_input_every_steps
    self.pretty_print_input_fn = pretty_print_input_function
    self.process_summaries_fn = process_summaries_function
    self.extra_summaries_fn = extra_summaries_function

    # Input pipeline state.
    if self.dataset is not None:
      self.ds_iterator = self.dataset()
    self.epoch = 0

  def _get_metrics(self, device_metrics: Metrics) -> Metrics:
    """Reads a dictionary of metrics from the device."""
    if self.replicate_mode:
      # x[0] gets the metric from device 0 -- the first replica.
      # We assume that merge_replicated_metrics has already combined the
      # metrics from multiple devices.
      device_metrics = jax.tree_map(lambda x: x[0], device_metrics)
    metrics_np = jax.device_get(device_metrics)  # Get numpy arrays.
    return metrics_np

  def get_next_input(self) -> Any:
    """Grabs the next input from the data pipeline."""
    if self.dataset is None:
      logging.warning("No dataset for mode %s", self.mode)
      return None

    try:
      x = next(self.ds_iterator)
    except StopIteration:
      # Restart the input pipeline at the end of each epoch.
      logging.info("End of epoch %d for mode %s.", self.epoch, self.mode)
      self.ds_iterator = self.dataset()
      x = next(self.ds_iterator)
      self.epoch += 1
    return x

  def run_step(self, tstate: TrainState, x: Any, step: int,
               sub_step: int = 0) -> Tuple[TrainState, Metrics]:
    """Runs the model for a single step.

    Args:
      tstate: The current model state.
      x: The input for the model -- from get_next_input.
      step: The training step number.
      sub_step: For tasks that run multiple iterations within a step.
        E.g. a test cycle will call run_step multiple times to cover the
        test set. The step counter will not increment, but sub_step will.

    Returns:
      A tuple of (updated model state, metrics from the step).
    """
    start_time = time.perf_counter()

    # Split a batch of inputs among local replicas.
    if self.replicate_mode:
      x = split_batch_dimension(x, jax.local_device_count())

    # Pretty-print the input to the summary and log file every so often.
    if (sub_step == 0 and self.pretty_print_input_fn is not None and
        should_run(step, self.print_input_every_steps)):
      # Only print the slice of the input that goes to the first replica.
      x_first = jax.tree_map(lambda x: x[0], x) if self.replicate_mode else x
      x_strs = self.pretty_print_input_fn(x_first)
      logging.info("[%d] Input (%s) = %s", step, self.mode, x_strs)
      self.summary.add_text({"input": x_strs})

    # Run the step function on the input.
    with jax.profiler.StepTraceAnnotation(self.mode, step_num=step):
      (tstate, metrics) = self.step_function(tstate, x, self.prng_keys)

    # Read metrics from the device.
    metrics_np = self._get_metrics(metrics)
    end_time = time.perf_counter()
    metrics_np["step_time"] = end_time - start_time
    if "epoch" not in metrics_np:
      metrics_np["epoch"] = self.epoch

    # Add metrics to the current summary.
    self.summary.add(metrics_np)
    return (tstate, metrics_np)

  def flush(self, step: int):
    """Flushes accumulated metric summaries to disk."""
    if self.summary_writer is None:
      self.summary.clear()  # Clear the summary if we can't write it.
      return
    if self.summary.empty():
      return

    # Do post-processing of the summaries.
    if self.process_summaries_fn is not None:
      self.summary = self.process_summaries_fn(self.summary, self.mode)  # pylint: disable=not-callable

    # Write and clear summary data.
    logging.info("Writing summaries for mode %s.", self.mode)
    self.summary.write(self.summary_writer, step, prefix=self.summary_prefix)

    # Add extra summaries that are not computed by the step function.
    if self.extra_summaries_fn is not None:
      self.extra_summary.add(self.extra_summaries_fn(self.mode, step))
      self.extra_summary.write(self.summary_writer, step, prefix="")


def split_batch_dimension(inputs: Any, num_replicas: int) -> Any:
  """Splits the leading batch dimension.

  Given inputs of shape [num_replicas * batch_size, ...], this will reshape
  them to [num_replicas, batch_size, ...].  This operation is intended to be
  used right before calling pmap, which will eliminate the num_replicas
  dimension.

  Args:
    inputs: Tuple of inputs to split.
    num_replicas: Number of replicas.

  Returns:
    inputs with an extra leading num_replicas dimension.
  """
  def split_batch_dim(x):
    assert x.ndim > 0
    if (x.shape[0] % num_replicas) != 0:
      raise ValueError(f"Can't split {x.shape} into {num_replicas} replicas.")
    batch_size = x.shape[0] // num_replicas
    split_shape = [num_replicas, batch_size] + list(x.shape[1:])
    return np.reshape(x, split_shape)

  return jax.tree_map(split_batch_dim, inputs)
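
# Illustrative sketch (not part of the original file): how split_batch_dimension
# is typically used before pmap.  The device count and batch sizes here are
# assumptions for the example.
#
#   num_devices = jax.local_device_count()            # e.g. 8
#   batch = {"tokens": np.zeros((32, 512), np.int32)}
#   sharded = split_batch_dimension(batch, num_devices)
#   # sharded["tokens"].shape == (8, 4, 512): one leading slice per device.
#   # jax.pmap(step_function) then consumes the leading axis, so each device
#   # runs the step on a [4, 512] batch slice.
#
# Hypothetical driver loop (the real one lives in TrainingLoop elsewhere in
# this codebase); shown only to illustrate how the methods above compose:
#
#   task = TrainingTask(mode="train", dataset=..., step_function=..., ...)
#   for step in range(1, num_steps + 1):
#     x = task.get_next_input()
#     (tstate, _) = task.run_step(tstate, x, step)
#     if should_run(step, summary_every_steps):
#       task.flush(step)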