# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic JAX training loop for experiments.""" | |
import functools | |
import os | |
from typing import (Any, Callable, Dict, Optional, Sequence, Tuple) | |
from absl import logging | |
from clu import metric_writers | |
import flax | |
from flax import jax_utils | |
from flax import linen as nn | |
from flax import struct | |
from flax.training import checkpoints | |
import gin | |
import jax | |
import jax.numpy as jnp | |
import metrics_summary | |
import optimizer_config as opt_config | |
import training_task | |
import numpy as np | |
import tensorflow.compat.v2 as tf | |
PRNGKeys = training_task.PRNGKeys | |
TrainState = training_task.TrainState | |
TrainingTask = training_task.TrainingTask | |
StepFunction = training_task.StepFunction | |
Metrics = training_task.Metrics | |
MetricWriter = metric_writers.MetricWriter | |
MetricsSummary = metrics_summary.MetricsSummary | |
gfile = tf.io.gfile | |
unfreeze = flax.core.unfreeze | |
flatten_dict = flax.traverse_util.flatten_dict | |
should_run = training_task.should_run | |
# TODO(cstaats): Use a Protocol to specify that it must be possible to call | |
# the function with parameters (step: int, mode: str). This won't be feasible | |
# until we start using Python 3.8 or later. | |
StepModeCallable = Callable[..., None] | |
# This variable should *only* be set from register_interstep_callbacks. | |
_interstep_callbacks: Optional[Tuple[StepModeCallable, ...]] = None | |


def register_interstep_callbacks(**kwargs: StepModeCallable) -> None:
  """Populates _interstep_callbacks from gin.

  This function should be called exactly ONCE, and that call should happen
  AFTER flag initialization (and, more specifically, after gin parsing). The
  caller should NOT specify any arguments.

  In gin configurations, a callback can be specified with an arbitrary name
  like so:

    register_interstep_callbacks.my_callback_name = @my_callback_function

  Multiple callbacks can be registered without overriding each other as long
  as they all have different names. Conversely, if you *want* to override a
  callback, you need to give that callback the same name.

  Args:
    **kwargs: Specified by gin. Each argument should be a function (callable)
      that can be called as my_function(step, mode), where step is an int and
      mode is a str.

  Raises:
    ValueError: Raised on the second (and any subsequent) function call.
  """
  global _interstep_callbacks
  logging.info("registering functions: %s", kwargs.keys())
  if _interstep_callbacks is not None:
    raise ValueError("register_interstep_callbacks may only be called once.")
  _interstep_callbacks = tuple(kwargs.values())
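

# Illustrative sketch (not part of the original module): a hypothetical
# callback with the expected (step, mode) signature, plus the gin binding that
# would register it. The names `my_profiler_callback` and `profiler` are
# invented for this example.
#
#   def my_profiler_callback(step: int, mode: str) -> None:
#     if mode == "train" and step % 1000 == 0:
#       logging.info("Profiling snapshot at step %d", step)
#
#   # In the gin config file:
#   # register_interstep_callbacks.profiler = @my_profiler_callback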


def clear_interstep_callbacks():
  """Clear all registered callbacks, so that new ones can be registered."""
  global _interstep_callbacks
  _interstep_callbacks = None


def run_interstep_callbacks(mode: str, step: int, sub_step: int = 0):
  """Run the registered callbacks.

  Args:
    mode: mode of the task to execute callbacks for.
    step: training step number.
    sub_step: For tasks that execute multiple iterations within a step,
      e.g. a test cycle that runs multiple testing steps.
  """
  for func in _interstep_callbacks:
    func(sub_step or step, mode)


class Trainer:
  """Implements a JAX training loop."""

  # Returns a Flax module for the model.
  # Takes a single argument, mode, which can be "test", "train", or "generate".
  model_definition: Any = gin.REQUIRED

  # Iterator over training data.
  get_training_dataset_iterator: Callable[[], Any] = gin.REQUIRED

  # Iterator over test data.
  get_test_dataset_iterator: Optional[Callable[[], Any]] = None

  workdir: str = ""  # Working directory for checkpoints.
  load_dir: str = ""  # Optional directory to load model.
  num_steps: int = 100000  # Number of steps to train.
  status_every_steps: int = 10  # Log step number every N steps.
  log_every_steps: int = 100  # Log scalar data every N steps.
  test_every_steps: int = 10  # Test model every N steps.
  num_test_steps: int = 1  # Number of iterations to test.
  generate_every_steps: int = 1000  # Generate examples every N steps.
  print_input_every_steps: int = 1000  # Print example data every N steps.

  save_checkpoints: bool = True  # Save training checkpoints.
  checkpoint_every_steps: int = 5000  # Save checkpoints every N steps.
  restore_checkpoints: bool = True  # Restore from previous checkpoint.
  restore_state_variables: bool = True  # Restore TrainState.state from chkpt.

  # Record metrics for "train", "test", etc. in separate directories.
  # Otherwise they will be saved with separate prefixes.
  use_separate_metric_directories: bool = True

  # Optimizer options.
  optimizer_factory: opt_config.OptimizerConfig = gin.REQUIRED
  learning_rate_schedule: Callable[[jnp.ndarray, int], jnp.ndarray] = (
      opt_config.lr_cosine_decay)

  # Maximum steps for the LR schedule. Zero means use num_steps.
  max_scheduled_steps: int = 0
  warmup_steps: int = 1000  # Number of warmup steps.
  learning_rate_multiplier: float = 1.0  # Used to scale the learning rate.

  random_seed: int = 42  # Initial random seed.

  # Names of random number generators used by the model.
  rng_key_names: Optional[Sequence[str]] = ("dropout",)

  # Debug options.
  replicate_mode: bool = True  # pmap over multiple replicas.
  trace_debug_mode: bool = False  # Run in eager mode to trace results.
  print_variables: bool = False  # Dump parameters/variables to stdout.

  # Function to compute additional summary information.
  # Takes a MetricsSummary object and a mode string (e.g. "test") as arguments,
  # and returns a MetricsSummary object.
  process_summaries_function: Optional[Callable[[Any, str], Any]] = None

  # Function to pretty print the input for each training step.
  pretty_print_input_function: Optional[Callable[[Any], Any]] = None

  # Class to use for summarizing metrics.
  metrics_summary_factory: Any = metrics_summary.MetricsSummary

  extra_summaries_fn: training_task.ExtraSummariesFunction = (
      lambda mode, step: dict())

  post_save_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
  post_load_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
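
  # Illustrative gin bindings for the options above (a sketch, not from the
  # original config files; it assumes Trainer is registered with gin in this
  # project, and MyModel, my_train_ds, and the optimizer factory name are
  # placeholders):
  #
  #   Trainer.model_definition = @MyModel
  #   Trainer.get_training_dataset_iterator = @my_train_ds
  #   Trainer.optimizer_factory = @my_optimizer_factory()
  #   Trainer.num_steps = 250_000
  #   Trainer.workdir = "/tmp/my_experiment"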

  def learning_rate_schedule_fn(self, step):
    """Returns the learning rate for the given step."""
    # There are four components to the learning rate.
    #
    # The base_lrate is defined by the optimizer, and different optimizers have
    # different relative rates, e.g. Adafactor requires a higher LR than Adam.
    # By default, the base_lrate is 1.0 for Adafactor.
    #
    # The base_lrate is then multiplied by the learning rate decay schedule,
    # which typically starts at a maximum value and decays over time.
    # Each schedule can be individually configured, e.g. from 0.01 to 0.001.
    # The max_scheduled_steps parameter controls the decay rate of the
    # schedule.
    #
    # Finally, the LR is scaled by the learning_rate_multiplier, which provides
    # an easy way to scale the LR for hyperparameter tuning in a way that is
    # independent of the choice of schedule or optimizer. The default is 1.0.
    #
    # During the warmup period, the learning rate ramps up linearly from zero.
    step = jnp.asarray(step, dtype=jnp.float32)

    if self.max_scheduled_steps == 0:
      max_steps = self.num_steps
    else:
      max_steps = self.max_scheduled_steps

    base_lrate = float(self.optimizer_factory.learning_rate)
    lr_multiplier = float(self.learning_rate_multiplier)

    # Linear increase in learning rate up to warmup_steps.
    warmup_steps = float(self.warmup_steps)
    lr_warmup_ramp = jnp.minimum(step, warmup_steps) / warmup_steps

    # Hold step at a constant value during the warmup period.
    # Required for some schedules, like rsqrt_decay.
    step = jnp.maximum(step, warmup_steps)

    # Get the scheduled learning rate.
    lrate = self.learning_rate_schedule(step, max_steps)

    # Multiply lrate by the base, warmup, and multiplier factors.
    lrate = lrate * base_lrate * lr_warmup_ramp * lr_multiplier
    return jnp.asarray(lrate, dtype=jnp.float32)
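
  # Worked example of the formula above (numbers are illustrative): with a
  # base_lrate of 1.0, a schedule value of 0.01 at the current step, warmup
  # finished (ramp = 1.0), and learning_rate_multiplier = 0.5, the returned
  # rate is 1.0 * 0.01 * 1.0 * 0.5 = 0.005. During warmup, e.g. at step 250
  # with warmup_steps = 1000, the ramp factor would instead be 0.25.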

  def _init_rngs(self, rngs: PRNGKeys, step: int) -> PRNGKeys:
    # Get a new random number generator for each step.
    rngs = jax.random.fold_in(rngs, step)
    rngs = jax.random.split(rngs, len(self.rng_key_names))
    rngs = {key: rngs[i] for i, key in enumerate(self.rng_key_names)}
    return rngs

  def train_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a training step, pmapped over multiple devices.

    Args:
      model: The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x: A batch of inputs to train on.
      rngs: PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """
    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    # Refactor the model as a loss function from trainable params to loss, so
    # that we can differentiate with jax and get {d}loss/{d}params.
    # Inputs and non-trainable params are bound within the closure.
    # model::   x, { state_params } -> (loss, metrics), { new_state_params }
    # loss_fn:: params -> (loss, (metrics, new_state))
    def loss_fn(params):
      """Loss function."""
      (loss, mets), nstate = model.apply({"params": params, **tstate.state},
                                         x,
                                         rngs=rngs,
                                         mutable=mutable_keys)
      return loss, (mets, nstate)

    # grad_fn:: params -> ((loss, (aux, nstate)), param_gradients)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    # Run forward and backward pass.
    (loss, (metrics, new_state)), param_grads = grad_fn(tstate.optimizer.target)
    del loss  # Loss is only recorded if it is part of the metrics.
    if self.replicate_mode:
      param_grads = jax.lax.pmean(param_grads, axis_name="batch")

    lrate = self.learning_rate_schedule_fn(step)
    new_optimizer = tstate.optimizer.apply_gradient(
        param_grads, learning_rate=lrate)

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics, model.metrics_summary_operations(aggregate_over="devices"))
    metrics["learning_rate"] = lrate

    return (TrainState(new_optimizer, new_state), metrics)

  def other_step(self, model: nn.Module, tstate: TrainState, x: Any,
                 rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
    """Perform a test or generate step, pmapped over multiple devices.

    Args:
      model: The model to use for the step function.
      tstate: Values for state variables, and the optimizer.
      x: A batch of inputs to run on.
      rngs: PRNGKey (possibly replicated).

    Returns:
      Tuple of (new_tstate, metrics: dictionary of scalar values)
    """
    mutable_keys = [k for (k, _) in tstate.state.items()]
    step = tstate.optimizer.state.step
    rngs = self._init_rngs(rngs, step)

    params = tstate.optimizer.target
    (loss, metrics), new_state = model.apply({"params": params, **tstate.state},
                                             x,
                                             rngs=rngs,
                                             mutable=mutable_keys)
    del loss  # Loss is only recorded if it is part of the metrics.

    # Metrics are summary values that will be logged.
    if self.replicate_mode:
      # Merge metrics (take mean/sum etc.) over replicas on-device.
      summary_class = self.metrics_summary_factory
      metrics = summary_class.merge_replicated_metrics(
          metrics, model.metrics_summary_operations(aggregate_over="devices"))

    return (TrainState(tstate.optimizer, new_state), metrics)

  def initialize_model(self) -> Tuple[TrainState, int, nn.Module, PRNGKeys]:
    """Initialize the model and/or load it from a checkpoint.

    Returns:
      (tstate: TrainState,  -- The parameters and state for the model.
       start_step: int,     -- The step number, when restoring from checkpoint.
       imodel: nn.Module,   -- A model object (created with mode "init").
       rngs: PRNGKeys)      -- Initial random numbers.
    """
    # Set up random number generators.
    # ---------------------------------
    logging.info("==== Training loop: initializing model ====")
    logging.info("Process %d of %d", jax.process_index(), jax.process_count())
    logging.info("Local device count = %d", jax.local_device_count())
    logging.info("Number of replicas = %d",
                 jax.process_count() * jax.local_device_count())
    logging.info("Using random number seed %d", self.random_seed)

    prng = jax.random.PRNGKey(self.random_seed)
    prng, init_rng = jax.random.split(prng)

    # Grab rngs, which provide different random numbers for each replica.
    if self.replicate_mode:
      prngs = jax.random.split(prng, jax.local_device_count())
    else:
      prngs = prng
    del prng

    # Create a dictionary of prng keys for initialization.
    rng_key_names_init = list(self.rng_key_names) + ["params"]
    init_rngs = jax.random.split(init_rng, len(rng_key_names_init))
    init_rngs = {key: init_rngs[i] for i, key in enumerate(rng_key_names_init)}
    del init_rng

    # Build Model
    # -------------------------------------------------------------------------
    logging.info("Initializing the model.")

    # Create a model, which will be used to initialize trainable parameters.
    imodel = self.model_definition(mode="init")

    # The init function will lazily initialize the model, given a fake input.
    # It returns initialized variables, without doing a fwd pass.
    model_init_fn = jax.jit(imodel.init)
    variables = model_init_fn(init_rngs, imodel.get_fake_input())

    # Split variables into trainable and non-trainable sets.
    mstate, params = variables.pop("params")
    del variables  # Delete to avoid wasting resources.

    # Create an optimizer for params.
    optimizer_def = self.optimizer_factory.create_optimizer_def()
    optimizer = optimizer_def.create(params)

    # tstate holds the full training state of the model.
    tstate = TrainState(optimizer, mstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

    # Load a pre-trained model or restore it from a checkpoint.
    if self.workdir or self.load_dir:
      restore_checkpoints = self.restore_checkpoints
    else:
      restore_checkpoints = False

    start_step = 0
    if restore_checkpoints:
      tstate = self.restore_checkpoint(tstate)
      start_step = int(tstate.optimizer.state.step)

    # Log info on trainable parameters (before replicating them).
    self._write_parameter_info(tstate)

    # Replicate the training state across local devices.
    if self.replicate_mode:
      tstate = jax_utils.replicate(tstate)

    return (tstate, start_step, imodel, prngs)

  def restore_checkpoint(self, train_state: TrainState) -> TrainState:
    """Load a pre-trained model or restore it from a checkpoint."""
    # Figure out if we have an existing checkpoint.
    if not self.workdir:
      logging.info("No working directory specified.")
      existing_checkpoint = False
    elif not gfile.exists(self.workdir):
      logging.info("No existing checkpoint directory %s", self.workdir)
      existing_checkpoint = False
    elif not gfile.isdir(self.workdir):
      raise ValueError(f"workdir {self.workdir} must be a directory.")
    else:
      ckpath = checkpoints.latest_checkpoint(self.workdir, "checkpoint_")
      if ckpath:
        logging.info("Found existing checkpoint in %s", self.workdir)
        existing_checkpoint = True
      else:
        logging.info("No existing checkpoint in %s", self.workdir)
        existing_checkpoint = False

    # If any checkpoints exist in workdir, then use those first.
    # This will ensure that the task will restore properly if it's preempted.
    if existing_checkpoint:
      logging.info("Restoring model from last checkpoint %s:", self.workdir)
      load_dir = self.workdir
    elif self.load_dir:
      logging.info("Loading pre-trained model from %s:", self.load_dir)
      load_dir = self.load_dir
    else:
      logging.warning("Unable to load model.")
      return train_state

    loaded_train_state = checkpoints.restore_checkpoint(load_dir, train_state)
    step = int(loaded_train_state.optimizer.state.step)
    self.post_load_checkpoint_fn(load_dir, step)

    if self.restore_state_variables:
      # Restore complete state.
      logging.info("Restoring all variables and state.")
      train_state = loaded_train_state
      del loaded_train_state
    else:
      # Restore trainable variables, but not other state.
      logging.info("Only restoring trainable parameters.")
      train_state = TrainState(loaded_train_state.optimizer, train_state.state)
      del loaded_train_state

    return train_state

  def save_checkpoint(self, tstate: TrainState, step: int,
                      param_summary: Optional[MetricsSummary]):
    """Save a checkpoint with the model state.

    Args:
      tstate: The training state.
      step: The current step number.
      param_summary: Optional metrics summary to write parameter statistics.
    """
    logging.info("Saving checkpoint in directory %s", self.workdir)
    if self.replicate_mode:
      save_state = jax_utils.unreplicate(tstate)
    else:
      save_state = tstate
    checkpoints.save_checkpoint(self.workdir, save_state, step)

    # While we're at it, record distributions of trainable parameters.
    if param_summary is not None:
      logging.info("Recording parameter distributions.")
      params_dict = jax.device_get(
          _flatten_dict_string_keys(save_state.optimizer.target))
      param_distribs = self._compute_parameter_distributions(params_dict)
      param_summary.add(param_distribs)

  def create_training_task(self, mode: str, imodel: nn.Module, prngs: PRNGKeys,
                           writers: Dict[str, MetricWriter]) -> TrainingTask:
    """Create a new TrainingTask for the given mode.

    Args:
      mode: The mode for the task, e.g. "train", "test", "generate".
      imodel: The model object from initialize_model.
      prngs: The PRNGKeys from initialize_model.
      writers: A dictionary of summary writers.

    Returns:
      A TrainingTask object.
    """
    logging.info("Training loop: creating task for mode %s", mode)
    if self.use_separate_metric_directories:
      prefix = ""
    else:
      prefix = mode

    if mode == "train":
      ds = self.get_training_dataset_iterator
    elif mode == "test":
      ds = self.get_test_dataset_iterator
    else:
      ds = None

    # We summarize metrics over multiple training steps.
    # These types control how the summary is computed.
    metric_summary_ops = {
        "step_time": "mean",
        "learning_rate": "last",
        **imodel.metrics_summary_operations(aggregate_over="steps")
    }
    summary = self.metrics_summary_factory(metric_summary_ops)
    extra_summary = self.metrics_summary_factory({})
    summary_writer = self._get_summary_writer(mode, writers)

    return TrainingTask(
        mode=mode,
        dataset=ds,
        step_function=self._compile_step_function(mode),
        prng_keys=prngs,
        summary=summary,
        extra_summary=extra_summary,
        summary_writer=summary_writer,
        summary_prefix=prefix,
        # --- options ---
        replicate_mode=self.replicate_mode,
        print_input_every_steps=self.print_input_every_steps,
        pretty_print_input_function=self.pretty_print_input_function,
        process_summaries_function=self.process_summaries_function,
        extra_summaries_function=self.extra_summaries_fn)

  def train(self):
    """Runs the training and evaluation loop."""
    # The master process saves checkpoints and summaries to disk.
    is_master_process = jax.process_index() == 0

    if self.workdir:
      save_checkpoints = self.save_checkpoints
    else:
      save_checkpoints = False

    # --- Create and initialize the model. ---
    (tstate, start_step, imodel, prngs) = self.initialize_model()

    # Log experiment hyper-parameters.
    writers = {}
    train_writer = self._get_summary_writer("train", writers)
    if start_step == 0:
      self._write_config(train_writer)

    # Additional summary objects.
    param_summary = self.metrics_summary_factory({})  # Parameter statistics.

    # --- Create task objects for test, train, and generate. ---
    tasks = {}
    train_task = self.create_training_task("train", imodel, prngs, writers)
    tasks["train"] = train_task

    if (self.get_test_dataset_iterator is not None and
        self.test_every_steps != 0):
      test_task = self.create_training_task("test", imodel, prngs, writers)
      tasks["test"] = test_task
      if self.generate_every_steps != 0:
        gen_task = self.create_training_task("generate", imodel, prngs,
                                             writers)
        tasks["generate"] = gen_task

    # Register any additional actions.
    register_interstep_callbacks()

    # Main Training Loop
    # --------------------------------------------------------------------------
    logging.info("==== Training loop: starting main loop ====")
    with metric_writers.ensure_flushes(*writers.values()):
      for step in range(start_step, self.num_steps):
        # Log status every so often to monitor progress.
        if should_run(step, self.status_every_steps):
          logging.info("Step: %d", step)

        # Train.
        train_x = train_task.get_next_input()
        (tstate, _) = train_task.run_step(tstate, train_x, step)
        run_interstep_callbacks("train", step)
        del train_x

        # Test.
        if should_run(step, self.test_every_steps):
          if self.num_test_steps > 1:
            logging.info("Test cycle: %d iterations.", self.num_test_steps)
          for sub_step in range(0, self.num_test_steps):
            test_x = test_task.get_next_input()

            # TODO(delesley): This is an ugly hack to run generate steps.
            # Run a generate step using test data.
            # Generate is run just *before* the last test iteration.
            if ((sub_step == self.num_test_steps - 1) and
                should_run(step, self.generate_every_steps)):
              logging.info("Generate cycle.")
              (tstate, _) = gen_task.run_step(tstate, test_x, step)
              run_interstep_callbacks("generate", step)

            (tstate, _) = test_task.run_step(tstate, test_x, step,
                                             sub_step=sub_step)
            run_interstep_callbacks("test", step, sub_step)
            del test_x

        # --- Save checkpoints on the master host. ---
        is_last_step = (step == self.num_steps - 1)
        checkpoint_current_step = (
            save_checkpoints and
            (should_run(step, self.checkpoint_every_steps) or is_last_step))
        if checkpoint_current_step:
          if is_master_process:
            self.save_checkpoint(tstate, step, param_summary)
          self.post_save_checkpoint_fn(self.workdir, step)

        # --- Flush summaries to disk. ---
        if should_run(step, self.log_every_steps):
          for tsk in tasks.values():
            tsk.flush(step)
          param_summary.write(train_writer, step, prefix="params")

    logging.info("Training Finished.")
    if self.replicate_mode:
      tstate = jax_utils.unreplicate(tstate)
    if self.print_variables:
      logging.info("params = %s", tstate.optimizer.target)
      logging.info("state = %s", tstate.state)

  def _compile_step_function(self, mode: str) -> StepFunction:
    """Compile a step function (training or test)."""
    # Create a model object, and a step function that is a closure over the
    # object. Flax modules are supposed to be "stateless", in that all state
    # is contained in the TrainState object that is passed as an input
    # parameter. However, creating the model object may involve allocating
    # expensive data structures, or launching processes, and should only be
    # done once.
    model = self.model_definition(mode=mode)
    if mode == "train":
      step_fn = functools.partial(self.train_step, model)
    else:
      step_fn = functools.partial(self.other_step, model)

    if self.replicate_mode:
      assert not self.trace_debug_mode
      logging.info("Compiling mode %s with pmap.", mode)
      p_fn = jax.pmap(step_fn, donate_argnums=(0,), axis_name="batch")
    elif self.trace_debug_mode:
      logging.info("Compiling mode %s with trace_debug.", mode)
      p_fn = step_fn
    else:
      logging.info("Compiling mode %s with jit.", mode)
      p_fn = jax.jit(step_fn, donate_argnums=(0,))
    return p_fn

  def _get_summary_writer(self, mode: str,
                          writers: Dict[str, MetricWriter]) -> MetricWriter:
    """Create a summary writer for the given mode.

    Args:
      mode: the mode for the summaries, e.g. "test", "train".
      writers: a dictionary which caches previously-created writers.

    Returns:
      A writer for the given mode.
    """
    if self.use_separate_metric_directories:
      # Create a separate writer & directory for each mode.
      w_mode = mode
      summary_dir = os.path.join(self.workdir, mode)
    else:
      # Create a single default writer for all modes.
      w_mode = "train"
      summary_dir = self.workdir

    if w_mode in writers:
      # Return previously created and cached writer.
      logging.info("Returning cached summary writer (%s) for mode %s",
                   w_mode, mode)
      return writers[w_mode]

    if not self.workdir:
      # No working directory, so log only.
      logging.info("Creating logging writer (%s) for mode %s", w_mode, mode)
      writer = metric_writers.LoggingWriter()
    else:
      # Create a new writer for workdir.
      # Only the master will actually write summaries to workdir.
      logging.info("Creating summary writer (%s) for mode %s in directory %s",
                   w_mode, mode, summary_dir)
      is_master = jax.process_index() == 0
      gfile.makedirs(summary_dir)
      writer = metric_writers.create_default_writer(
          summary_dir, just_logging=not is_master)
    writers[w_mode] = writer
    return writer

  def _write_config(self, writer):
    """Write the configuration file to the working directory."""
    is_master = jax.process_index() == 0
    config_str = gin.operative_config_str()
    logging.info("Gin config: \n%s", config_str)

    # Write configuration to workdir.
    if is_master and self.workdir:
      config_file_name = os.path.join(self.workdir, "config.gin")
      with gfile.GFile(config_file_name, "w") as f:
        f.write(config_str)

    # Write config string text to tensorboard.
    writer.write_texts(0, {"config": gin.markdown(config_str)})

  def _write_parameter_info(self, tstate: TrainState):
    """Write information on state and trainable parameters to the log."""
    # Write information on parameters to log file.
    params_dict = _flatten_dict_string_keys(tstate.optimizer.target)
    total_nparams = 0
    for (k, v) in params_dict.items():
      nparams = np.prod(v.shape)
      total_nparams += nparams
      logging.info("parameter: %s, shape %s, size %d", k, v.shape, nparams)
    logging.info("Total parameters: %d", total_nparams)

    # Write information on state variables to log file.
    state_dict = _flatten_dict_string_keys(tstate.state)
    state_size = 0
    total_state = 0
    for (k, v) in state_dict.items():
      if hasattr(v, "shape"):
        state_size = np.prod(v.shape)
        total_state += state_size
        logging.info("state: %s, shape %s, size %d", k, v.shape, state_size)
      else:
        # Some other stuff may be stored in the state.
        logging.info("state: %s [unknown]", k)
    logging.info("Total state size: %d", total_state)

  def _compute_parameter_distributions(self, params_dict):
    """Compute info on distributions of parameters."""
    scalar_params_dict = {}
    for (k, v) in params_dict.items():
      # Convert from bfloat16, which crashes when serializing a NaN.
      v = np.asarray(v, dtype=jnp.float32)
      scalar_params_dict[k + "_mean"] = np.mean(v)
      scalar_params_dict[k + "_stddev"] = np.std(v)
      # scalar_params_dict[k + "_min"] = np.min(v)
      # scalar_params_dict[k + "_max"] = np.max(v)
    return scalar_params_dict


def _flatten_dict_string_keys(params):
  """Flattens a nested dictionary to have string keys and '/' separators."""
  return {"/".join(k): v for k, v in flatten_dict(unfreeze(params)).items()}