HugoVoxx committed on
Commit a5ccd04
1 Parent(s): 9fa58f4

Upload 5 files

aglib/meliad/metrics_summary.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Class to handle summarizing of metrics over multiple training steps."""
16
+
17
+ import abc
18
+ from typing import Any, Dict, Mapping, Optional, Tuple, Union
19
+ from absl import logging
20
+ from clu import metric_writers
21
+ import gin
22
+ import jax
23
+ from jax import numpy as jnp
24
+ import numpy as np
25
+
26
+
27
+ Array = Union[jnp.ndarray, np.ndarray]
28
+
29
+
30
+ class Aggregator(abc.ABC): # Superclass for type checks
31
+
32
+ @abc.abstractmethod
33
+ def add(self, value: Any):
34
+ pass
35
+
36
+ @abc.abstractmethod
37
+ def is_valid(self) -> bool:
38
+ pass
39
+
40
+ @abc.abstractmethod
41
+ def to_value(self):
42
+ pass
43
+
44
+
45
+ class _MeanAggregator(Aggregator):
46
+ """Maintains the mean of incoming values."""
47
+ mean: float = 0.0
48
+ weight: float = 0.0
49
+
50
+ def add(self, new_value: Any):
51
+ """Aggregates a new value into the mean."""
52
+ if np.ndim(new_value) == 0: # is a scalar; works with int, float, Array
53
+ val, weight = new_value, 1.0 # assuming weight 1 by default
54
+ else:
55
+ val, weight = new_value
56
+ if weight < 0.0:
57
+ raise ValueError("Adding value with negative weight.")
58
+ total_weight = self.weight + weight
59
+ if total_weight != 0.0 and weight > 0.0:
60
+ delta = (val - self.mean) * weight / total_weight
61
+ self.mean += delta
62
+ self.weight = total_weight
63
+
64
+ def is_valid(self) -> bool:
65
+ return self.weight > 0.0
66
+
67
+ def to_value(self):
68
+ assert self.weight > 0.0
69
+ return self.mean
70
+
71
+
72
+ class _SumAggregator(_MeanAggregator):
73
+ # We aggregate sum and mean in the same way as a tuple of the form:
74
+ # (weighted mean, total weights). "sum" can then be computed by
75
+ # multiplying the two values.
76
+
77
+ def is_valid(self) -> bool:
78
+ return True
79
+
80
+ def to_value(self):
81
+ return self.mean * self.weight
82
+
83
+
84
+ class _LastAggregator(Aggregator):
85
+ """Remembers the last value given."""
86
+ last_value: Optional[float] = None
87
+
88
+ def add(self, new_value: Any):
89
+ self.last_value = new_value
90
+
91
+ def is_valid(self) -> bool:
92
+ return self.last_value is not None
93
+
94
+ def to_value(self):
95
+ assert self.last_value is not None
96
+ return self.last_value
97
+
98
+
99
+ @gin.configurable
100
+ class MetricsSummary:
101
+ """Summarizes a set of a metrics over multiple training steps."""
102
+
103
+ def __init__(self,
104
+ metric_types: Mapping[str, str],
105
+ upscale_images: bool = True,
106
+ remove_outliers: bool = False):
107
+ """Creates a MetricsSummary.
108
+
109
+ Args:
110
+ metric_types: Map from metrics to the type of summary. Types are:
111
+ "mean" = Compute the cumulative moving average.
112
+ "sum" = Compute the sum.
113
+ "last" = No summary, just return the last value.
114
+ upscale_images: Upscale small images for easier viewing.
115
+ remove_outliers: Remove outliers from histograms.
116
+ """
117
+ self.metric_dict = {} # type: Dict[str, Aggregator]
118
+ self.text_dict = {}
119
+ self.metric_types = metric_types
120
+ self.upscale_images = upscale_images
121
+ self.remove_outliers = remove_outliers
122
+ self.constructor_map = {
123
+ "mean": _MeanAggregator,
124
+ "sum": _SumAggregator,
125
+ "last": _LastAggregator,
126
+ }
127
+ logging.debug("Registered metrics: %r", metric_types)
128
+
129
+ def current_metric_dict(self) -> Mapping[str, Aggregator]:
130
+ return self.metric_dict
131
+
132
+ def _is_image(self, image: Array) -> bool:
133
+ if image.ndim != 4:
134
+ return False
135
+ # Greyscale or RGB image.
136
+ return image.shape[-1] == 1 or image.shape[-1] == 3
137
+
138
+ def _upscale_image(self, image: Array) -> Array:
139
+ """Upscale small images to more pixels, for easier viewing."""
140
+ if not self.upscale_images:
141
+ return image
142
+ assert image.ndim == 4 # (num_images, ysize, xsize, num_channels)
143
+ ys = image.shape[1]
144
+ xs = image.shape[2]
145
+ if xs > 512 or ys > 512:
146
+ return image # No scaling.
147
+ elif xs > 256 or ys > 256:
148
+ scale = 2
149
+ else:
150
+ scale = 4
151
+ yidx = np.arange(ys * scale) // scale
152
+ xidx = np.arange(xs * scale) // scale
153
+ scaled_image = image[:, yidx, :, :][:, :, xidx, :]
154
+ return scaled_image
155
+
156
+ def _remove_outliers(self, v, std_range: float = 4):
157
+ if not self.remove_outliers:
158
+ return v
159
+ v_mean = np.mean(v)
160
+ v_std = np.std(v)
161
+ return np.where(np.abs(v) > (v_std * std_range), v_mean, v)
162
+
163
+ @staticmethod
164
+ def merge_replicated_metrics(device_metrics: Mapping[str, Any],
165
+ metric_types: Mapping[str, str]):
166
+ """Merge metrics across devices by psum over "batch" axis.
167
+
168
+ Args:
169
+ device_metrics: dictionary of device metrics.
170
+ metric_types: map from the metric name to { "mean", "sum" }
171
+
172
+ Returns:
173
+ A dictionary of metrics.
174
+ """
175
+ logging.info("Merging metrics across devices %r: ",
176
+ [(k, metric_types[k] if k in metric_types else None)
177
+ for k in device_metrics.keys()])
178
+
179
+ def aggregate_sum(value: Array) -> Array:
180
+ assert not isinstance(value, tuple), (
181
+ "Weighted sums are not supported when aggregating over devices.")
182
+ return jax.lax.psum(value, axis_name="batch")
183
+
184
+ def aggregate_mean(value: Array, weight: Array) -> Tuple[Array, Array]:
185
+ weighted_value = value * weight
186
+ weighted_value = jax.lax.psum(weighted_value, axis_name="batch")
187
+ weight = jax.lax.psum(weight, axis_name="batch")
188
+ return weighted_value / (weight + 1.0e-6), weight
189
+
190
+ aggregated_metrics = dict(device_metrics)
191
+ for k, value in aggregated_metrics.items():
192
+ if k not in metric_types:
193
+ # If no metric type is given, metric remains untouched.
194
+ continue
195
+ if metric_types[k] == "sum":
196
+ aggregated_metrics[k] = aggregate_sum(value)
197
+ elif metric_types[k] == "mean":
198
+ if not isinstance(aggregated_metrics[k], tuple):
199
+ logging.info("Metric '%s' has no weight; assuming 1.0.", k)
200
+ value = (value, jnp.array(1.0))
201
+ aggregated_metrics[k] = aggregate_mean(*value)
202
+ else:
203
+ raise ValueError("Can only aggregate 'sum' and 'mean' over devices. "
204
+ f"Got {metric_types[k]}.")
205
+ return aggregated_metrics
206
+
207
+ def _new_aggregator(self, key) -> Aggregator:
208
+ if key in self.metric_types:
209
+ return self.constructor_map[self.metric_types[key]]()
210
+ else:
211
+ # TODO(mrabe): The default to last_value is not obvious. Force all metric
212
+ # types to be given explicitly.
213
+ logging.debug("No metric type for accumulator: %s", key)
214
+ return _LastAggregator()
215
+
216
+ def add(self, metrics: Mapping[str, Any]):
217
+ """Add metrics from the current training step to the summary.
218
+
219
+ Args:
220
+ metrics: Dictionary of metrics.
221
+ """
222
+ for k, new_value in metrics.items():
223
+ if k not in self.metric_dict:
224
+ self.metric_dict[k] = self._new_aggregator(k)
225
+ self.metric_dict[k].add(new_value)
226
+
227
+ def add_text(self, text_metrics: Mapping[str, str]):
228
+ """Add text metrics from the current step to the summary."""
229
+ for (k, v) in text_metrics.items():
230
+ self.text_dict[k] = str(v)
231
+
232
+ def empty(self):
233
+ """Return true if there are no summaries to write."""
234
+ return not (self.metric_dict or self.text_dict)
235
+
236
+ def clear(self):
237
+ """Clear accumulated summaries."""
238
+ self.metric_dict = {}
239
+ self.text_dict = {}
240
+
241
+ def write(self, writer: metric_writers.MetricWriter, step: int, prefix: str):
242
+ """Write metrics using summary_writer, and clear all summaries."""
243
+ if self.empty():
244
+ return
245
+
246
+ # Special logic for organizing metrics under tensorboard.
247
+ # Tensorboard has top-level groups, but doesn't have subgroups.
248
+ # Scalars are put into separate top-level groups for easier viewing.
249
+ # e.g. all scalars in "train", "test", etc.
250
+ # For images, each set of images should be a different top-level group,
251
+ # otherwise all images will get tossed into a single group under,
252
+ # e.g. "generate".
253
+ if prefix:
254
+ s_prefix = prefix + "/"
255
+ i_prefix = prefix + "_"
256
+ else:
257
+ # Each prefix is stored in a separate subdirectory already.
258
+ s_prefix = ""
259
+ i_prefix = ""
260
+
261
+ # Split metrics into different types.
262
+ scalars = {}
263
+ images = {}
264
+ histograms = {}
265
+ text_dict = {}
266
+
267
+ # Sort metrics into scalars, images, text, and histograms.
268
+ for k, aggregator in self.metric_dict.items():
269
+ if not isinstance(aggregator, Aggregator):
270
+ raise ValueError("Internal error: metric_dict should contain only "
271
+ "Aggregator objects; contained %s" % aggregator)
272
+ if not aggregator.is_valid():
273
+ raise ValueError(f"No valid value for metric {k}.")
274
+
275
+ v = aggregator.to_value()
276
+
277
+ s_key = s_prefix + k
278
+ i_key = i_prefix + k
279
+
280
+ finite_mask = np.isfinite(v)
281
+ if not np.all(finite_mask):
282
+ logging.warning("Item %s contains non-finite elements.", k)
283
+ v = np.where(finite_mask, v, np.zeros_like(v))
284
+ if v is None:
285
+ logging.warning("Invalid value for %s", k)
286
+ elif np.ndim(v) == 0:
287
+ scalars[s_key] = v
288
+ elif self._is_image(v):
289
+ images[i_key] = self._upscale_image(v)
290
+ else:
291
+ histograms[s_key] = self._remove_outliers(v)
292
+
293
+ # Handle text data.
294
+ for (k, v) in self.text_dict.items():
295
+ s_key = s_prefix + k
296
+ text_dict[s_key] = v
297
+
298
+ # Write metrics.
299
+ if scalars:
300
+ writer.write_scalars(step, scalars)
301
+ if images:
302
+ writer.write_images(step, images)
303
+ if histograms:
304
+ writer.write_histograms(step, histograms)
305
+ if text_dict:
306
+ writer.write_texts(step, text_dict)
307
+
308
+ # Clear accumulated summaries.
309
+ self.clear()
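
The aggregation semantics above can be seen in a small usage sketch (not part of the uploaded file). The metric names are illustrative, and it assumes the module is importable as `metrics_summary`:

from metrics_summary import MetricsSummary

summary = MetricsSummary({"loss": "mean", "tokens": "sum", "lr": "last"})

# "mean" metrics accept either a plain scalar (weight 1.0) or a (value, weight) pair.
summary.add({"loss": (2.0, 3.0), "tokens": 128, "lr": 0.01})
summary.add({"loss": (1.0, 1.0), "tokens": 128, "lr": 0.009})

aggregators = summary.current_metric_dict()
print(aggregators["loss"].to_value())    # weighted mean: (2.0*3 + 1.0*1) / 4 = 1.75
print(aggregators["tokens"].to_value())  # running sum: 128 + 128 = 256.0
print(aggregators["lr"].to_value())      # last value seen: 0.009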
aglib/meliad/optimizer_config.py ADDED
@@ -0,0 +1,281 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Gin configurable optimizer definitions.
16
+ """
17
+
18
+ from typing import Any, Optional
19
+
20
+ from absl import logging
21
+ from flax import optim
22
+ from flax import struct
23
+ import gin
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+
28
+ OptimizerDef = Any
29
+
30
+
31
+ @struct.dataclass
32
+ class OptimizerConfig:
33
+ """Base class for optimizer configurations."""
34
+
35
+ learning_rate: float = 0.01 # All optimizers have a learning rate.
36
+
37
+ def create_optimizer_def(self) -> OptimizerDef:
38
+ raise ValueError("Not implemented.")
39
+
40
+
41
+ @gin.configurable
42
+ @struct.dataclass
43
+ class AdamConfig(OptimizerConfig):
44
+ """Creates and configures the Adam optimizer."""
45
+
46
+ # Adam does not use parameter scale, and thus requires a smaller lrate.
47
+ # This will be multiplied by the learning rate schedule.
48
+ learning_rate: float = 0.05
49
+
50
+ beta1: float = 0.9 # For moving average of gradient.
51
+ beta2: float = 0.98 # For moving average of gradient magnitude.
52
+ weight_decay_rate: float = 0.0 # Relative to learning rate.
53
+
54
+ def create_optimizer_def(self) -> optim.OptimizerDef:
55
+ logging.info("Using Adam Optimizer. lr=%f, b1=%f, b2=%f",
56
+ self.learning_rate, self.beta1, self.beta2)
57
+ return optim.Adam(beta1=self.beta1,
58
+ beta2=self.beta2,
59
+ weight_decay=self.weight_decay_rate)
60
+
61
+
62
+ @gin.configurable
63
+ @struct.dataclass
64
+ class FlaxAdafactorConfig(OptimizerConfig):
65
+ """Creates and configures the Adafactor optimizer."""
66
+
67
+ # Adafactor scales gradients according to parameter scale.
68
+ # This will be multiplied by the learning rate schedule.
69
+ learning_rate: float = 1.0
70
+ beta1: Optional[float] = 0.9 # Enables momentum with extra memory cost.
71
+
72
+ def create_optimizer_def(self) -> optim.OptimizerDef:
73
+ # Use wd_lr_exponent to get weight_decay relative to learning rate.
74
+ logging.info("Using Flax Adafactor Optimizer. lr=%f, b1=%f",
75
+ self.learning_rate, self.beta1)
76
+ return optim.Adafactor(beta1=self.beta1)
77
+
78
+
79
+
80
+
81
+ # ----------------------------------------------------------------------------
82
+ # Learning rate schedules for use with any optimizer.
83
+ #
84
+ # In keeping with the Chinchilla model: https://arxiv.org/abs/2203.15556.
85
+ # A learning rate schedule is a function that decays the learning rate from
86
+ # step zero to max_steps. The desired maximum number of steps must be set at
87
+ # the start of training.
88
+ # ----------------------------------------------------------------------------
89
+
90
+
91
+ @gin.configurable
92
+ def lr_constant(step: jnp.ndarray, max_steps: int,
93
+ learning_rate: float = 0.01) -> jnp.ndarray:
94
+ """Returns a constant learning rate on each step.
95
+
96
+ Args:
97
+ step: The current training step (unused).
98
+ max_steps: Unused.
99
+ learning_rate: The constant learning rate to use.
100
+
101
+ Returns:
102
+ The learning rate for the current step.
103
+ """
104
+ del step
105
+ del max_steps
106
+ return jnp.asarray(learning_rate, dtype=jnp.float32)
107
+
108
+
109
+ @gin.configurable
110
+ def lr_rsqrt_decay_std(step: jnp.ndarray, max_steps: int,
111
+ max_lr: Optional[float] = None) -> jnp.ndarray:
112
+ """Inverse square root decay function: LR = 1/sqrt(step).
113
+
114
+ Provided for compatibility. No min_lr, and it ignores max_steps.
115
+ Should be used with warmup: pass step = max(step, warmup_steps).
116
+ Maximum learning rate is 1/sqrt(warmup_steps) ~= 0.03 for 1000 warmup steps.
117
+
118
+ Args:
119
+ step: The current training step.
120
+ max_steps: Unused.
121
+ max_lr: If specified, learning rate will be clipped to the maximum value.
122
+
123
+ Returns:
124
+ The learning rate for the current step.
125
+ """
126
+ # This function implements standard rsqrt decay as used in the memorizing
127
+ # and block-recurrent transformer papers, (https://arxiv.org/abs/2203.08913,
128
+ # https://arxiv.org/abs/2203.07852) which does not decay to a specified
129
+ # minimum learning rate over max_steps.
130
+ del max_steps
131
+
132
+ # Avoid divide by zero; force at least 100 warmup steps and a max LR of 0.1.
133
+ step = jnp.maximum(step, 100.0)
134
+ lrate = 1.0 / jnp.sqrt(step)
135
+ if max_lr is not None:
136
+ lrate = jnp.minimum(lrate, max_lr) # Clip to max_lr
137
+ return lrate
138
+
139
+
140
+ @gin.configurable
141
+ def lr_rsqrt_decay(step: jnp.ndarray, max_steps: int,
142
+ max_lr: float = 0.05,
143
+ min_lr: float = 0.001) -> jnp.ndarray:
144
+ """Inverse sqrt decay from max_lr to min_lr over max_steps.
145
+
146
+ This function implements rsqrt decay, but adjusts the decay rate so that
147
+ min_lr is reached at max_steps.
148
+
149
+ Note: with a warmup period, the maximum LR produced by the schedule is:
150
+ min_lr / sqrt(warmup_steps / max_steps), which may be less than max_lr.
151
+ e.g. if min_lr is 0.001, then the maximum LR will be 0.01 for
152
+ warmup_steps=1000 and max_steps=100_000.
153
+
154
+ Args:
155
+ step: The current training step.
156
+ max_steps: The step value at the end of training.
157
+ max_lr: LR will be clipped to max at the start of training.
158
+ min_lr: LR to output at max_steps.
159
+
160
+ Returns:
161
+ The learning rate for the current step.
162
+ """
163
+ assert max_lr > min_lr
164
+
165
+ # Avoid divide by zero; force at least 100 warmup steps.
166
+ step = jnp.maximum(step, 100.0)
167
+ lrate = min_lr / jnp.sqrt(step / float(max_steps))
168
+ lrate = jnp.minimum(lrate, max_lr) # Clip to max_lr
169
+ return lrate
170
+
171
+
172
+ @gin.configurable
173
+ def lr_exponential_decay(step: jnp.ndarray, max_steps: int,
174
+ max_lr: float = 0.01,
175
+ min_lr: float = 0.001) -> jnp.ndarray:
176
+ """Exponential decay from max_lr to min_lr over max_steps.
177
+
178
+ Continues to decay at the same rate after max_steps.
179
+
180
+ Args:
181
+ step: The current training step.
182
+ max_steps: The step value at the end of training.
183
+ max_lr: LR to output at step 0.
184
+ min_lr: LR to output at max_steps.
185
+
186
+ Returns:
187
+ The learning rate for the current step.
188
+ """
189
+ assert max_lr > min_lr
190
+
191
+ lrate = max_lr * jnp.power(min_lr / max_lr, step / float(max_steps))
192
+ return lrate
193
+
194
+
195
+ @gin.configurable
196
+ def lr_linear_decay(step: jnp.ndarray, max_steps: int,
197
+ max_lr: float = 0.01,
198
+ min_lr: float = 0.001,
199
+ decay_after: bool = True) -> jnp.ndarray:
200
+ """Linear decay from max_lr to min_lr over max_steps.
201
+
202
+ If decay_after, then LR will continue to decay exponentially by a factor
203
+ of 2 every max_steps after the linear decay.
204
+
205
+ Args:
206
+ step: The current training step.
207
+ max_steps: The step value at the end of training.
208
+ max_lr: LR to output at step 0.
209
+ min_lr: LR to output at max_steps.
210
+ decay_after: If true, do exponential decay after the linear decay,
211
+ by a factor of 2 every max_steps.
212
+
213
+ Returns:
214
+ The learning rate for the current step.
215
+ """
216
+ assert max_lr > min_lr
217
+
218
+ lrate = min_lr + (max_lr - min_lr) * ((max_steps - step) / max_steps)
219
+ lrate = jnp.maximum(lrate, min_lr)
220
+
221
+ if decay_after:
222
+ exp_lrate = lr_exponential_decay(step, max_steps,
223
+ max_lr=2*min_lr, min_lr=min_lr)
224
+ lrate = jnp.where(step < max_steps, lrate, exp_lrate)
225
+
226
+ return lrate
227
+
228
+
229
+ @gin.configurable
230
+ def lr_cosine_decay(step: jnp.ndarray, max_steps: int,
231
+ max_lr: float = 0.01,
232
+ min_lr: float = 0.001,
233
+ decay_after: bool = True,
234
+ spike_steps: int = 0,
235
+ spike_lr: float = 0.0) -> jnp.ndarray:
236
+ """Cosine decay function from max_lr to min_lr over max_steps.
237
+
238
+ Used in the Chinchilla model: https://arxiv.org/abs/2203.15556.
239
+
240
+ If decay_after, then LR will continue to decay exponentially by a factor
241
+ of 2 every max_steps after the original ramp.
242
+
243
+ If spike_steps > 0, there will be an initial linear decay from spike_lr
244
+ down to max_lr over the first spike_steps steps. This implements a brief
245
+ period of higher LR early in training, similar to the curve for rsqrt_decay.
246
+ The model can generally tolerate a high LR early in training and can make a
247
+ lot of progress very quickly. Try spike_steps=10_000, spike_lr = 0.04.
248
+
249
+ Args:
250
+ step: The current training step.
251
+ max_steps: The number of training steps to decay over.
252
+ max_lr: The maximum learning rate at the start of training.
253
+ min_lr: The minimum learning rate at the end of training.
254
+ decay_after: If true, do exponential decay after the cosine decay,
255
+ by a factor of 2 every max_steps.
256
+ spike_steps: The number of steps for the initial spike.
257
+ spike_lr: The maximum LR during the initial spike.
258
+
259
+ Returns:
260
+ The learning rate for the current step.
261
+ """
262
+ assert max_lr > min_lr
263
+
264
+ pi = float(np.pi)
265
+ step_ramp = jnp.minimum(step, max_steps) / max_steps # ramp: 0 to 1.0.
266
+
267
+ lrate = (1 + jnp.cos(pi * step_ramp)) * 0.5 # ranges from 1 to 0.
268
+ lrate = min_lr + lrate * (max_lr - min_lr)
269
+
270
+ if spike_steps > 0 and spike_lr > 0.0:
271
+ assert spike_lr > max_lr
272
+ spike_lrate = spike_lr * ((spike_steps - step) / spike_steps)
273
+ lrate = jnp.maximum(lrate, spike_lrate)
274
+
275
+ if decay_after:
276
+ exp_lrate = lr_exponential_decay(step, max_steps,
277
+ max_lr=2*min_lr, min_lr=min_lr)
278
+ lrate = jnp.where(step < max_steps, lrate, exp_lrate)
279
+
280
+ return lrate
281
+
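
A small sketch (not part of the uploaded file) that evaluates the cosine schedule above at a few steps, assuming the module is importable as `optimizer_config`. With the defaults max_lr=0.01 and min_lr=0.001, the schedule starts at max_lr, reaches min_lr at max_steps, and then keeps halving every max_steps because decay_after is True:

import jax.numpy as jnp
from optimizer_config import lr_cosine_decay

for step in [0, 25_000, 50_000, 100_000, 200_000]:
    lr = lr_cosine_decay(jnp.asarray(step, dtype=jnp.float32), max_steps=100_000)
    print(step, float(lr))
# Approximate values: 0.01, 0.0087, 0.0055, 0.001, 0.0005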
aglib/meliad/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ absl-py>=1.0.0
2
+ clu>=0.0.7
3
+ gin-config>=0.5.0
4
+ flax>=0.5.0
5
+ jax>=0.3.13
6
+ optax>=0.1.2
7
+ numpy>=1.22.4
8
+ sentencepiece>=0.1.96
9
+ seqio>=0.0.7
10
+ tensorflow>=2.9.1
11
+ tensorflow-datasets>=4.5.2
aglib/meliad/training_loop.py ADDED
@@ -0,0 +1,757 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Generic JAX training loop for experiments."""
16
+
17
+ import functools
18
+ import os
19
+ from typing import (Any, Callable, Dict, Optional, Sequence, Tuple)
20
+
21
+ from absl import logging
22
+ from clu import metric_writers
23
+ import flax
24
+ from flax import jax_utils
25
+ from flax import linen as nn
26
+ from flax import struct
27
+ from flax.training import checkpoints
28
+ import gin
29
+ import jax
30
+ import jax.numpy as jnp
31
+ import metrics_summary
32
+ import optimizer_config as opt_config
33
+ import training_task
34
+ import numpy as np
35
+ import tensorflow.compat.v2 as tf
36
+
37
+
38
+ PRNGKeys = training_task.PRNGKeys
39
+ TrainState = training_task.TrainState
40
+ TrainingTask = training_task.TrainingTask
41
+ StepFunction = training_task.StepFunction
42
+ Metrics = training_task.Metrics
43
+ MetricWriter = metric_writers.MetricWriter
44
+ MetricsSummary = metrics_summary.MetricsSummary
45
+
46
+
47
+ gfile = tf.io.gfile
48
+ unfreeze = flax.core.unfreeze
49
+ flatten_dict = flax.traverse_util.flatten_dict
50
+ should_run = training_task.should_run
51
+
52
+
53
+ # TODO(cstaats): Use a Protocol to specify that it must be possible to call
54
+ # the function with parameters (step: int, mode: str). This won't be feasible
55
+ # until we start using Python 3.8 or later.
56
+ StepModeCallable = Callable[..., None]
57
+
58
+
59
+ # This variable should *only* be set from register_interstep_callbacks.
60
+ _interstep_callbacks: Optional[Tuple[StepModeCallable, ...]] = None
61
+
62
+
63
+ @gin.configurable
64
+ def register_interstep_callbacks(**kwargs: StepModeCallable) -> None:
65
+ """Populates _interstep_callbacks from gin.
66
+
67
+ This function should be called exactly ONCE and that call should happen AFTER
68
+ flag initialization (and more specifically, after gin parsing). And the caller
69
+ should NOT specify any arguments.
70
+
71
+ In gin configurations, a callback can be specified with an arbitrary name
72
+ like so:
73
+
74
+ register_interstep_callbacks.my_callback_name = @my_callback_function
75
+
76
+ Multiple callbacks can be registered without overriding each other as long as
77
+ they all have different names. Conversely, if you *want* to override a
78
+ callback, you need to give that callback the same name.
79
+
80
+ Args:
81
+ **kwargs: Specified by gin. Each argument should be a function (callable)
82
+ that can be called as my_function(step, mode), where step is an int and
83
+ mode is a str.
84
+
85
+ Raises:
86
+ ValueError: Raised on the second (and any subsequent) function call.
87
+ """
88
+ global _interstep_callbacks
89
+ logging.info("registering functions: %s", kwargs.keys())
90
+ if _interstep_callbacks is not None:
91
+ raise ValueError("register_interstep_callbacks may only be called once.")
92
+ _interstep_callbacks = tuple(kwargs.values())
93
+
94
+
95
+ def clear_interstep_callbacks():
96
+ """Clear all registered callbacks, so that new ones can be registered."""
97
+ global _interstep_callbacks
98
+ _interstep_callbacks = None
99
+
100
+
101
+ def run_interstep_callbacks(mode: str, step: int, sub_step: int = 0):
102
+ """Run the registered callbacks.
103
+
104
+ Args:
105
+ mode: mode of the task to execute callbacks for.
106
+ step: training step number.
107
+ sub_step: For tasks that execute multiple iterations within a step.
108
+ E.g. a test cycle that runs multiple testing steps.
109
+ """
110
+ for func in _interstep_callbacks:
111
+ func(sub_step or step, mode)
112
+
113
+
114
+ @gin.configurable
115
+ @struct.dataclass
116
+ class Trainer:
117
+ """Implements a JAX training loop."""
118
+
119
+ # Returns a Flax module for the model.
120
+ # Takes a single argument mode, which can be "test", "train", or "generate".
121
+ model_definition: Any = gin.REQUIRED
122
+
123
+ # Iterator over training data.
124
+ get_training_dataset_iterator: Callable[[], Any] = gin.REQUIRED
125
+
126
+ # Iterator over test data.
127
+ get_test_dataset_iterator: Optional[Callable[[], Any]] = None
128
+
129
+ workdir: str = "" # Working directory for checkpoints.
130
+ load_dir: str = "" # Optional directory to load model.
131
+ num_steps: int = 100000 # Number of steps to train.
132
+ status_every_steps: int = 10 # Log step number every N steps.
133
+ log_every_steps: int = 100 # Log scalar data every N steps.
134
+ test_every_steps: int = 10 # Test model every N steps.
135
+ num_test_steps: int = 1 # Number of iterations to test.
136
+ generate_every_steps: int = 1000 # Generate examples every N steps.
137
+ print_input_every_steps: int = 1000 # Print example data every N steps.
138
+
139
+ save_checkpoints: bool = True # Save training checkpoints
140
+ checkpoint_every_steps: int = 5000 # Save checkpoints every N steps.
141
+ restore_checkpoints: bool = True # Restore from previous checkpoint.
142
+ restore_state_variables: bool = True # Restore TrainState.state from chkpt.
143
+
144
+ # Record metrics for "train", "test", etc. in separate directories.
145
+ # Otherwise they will be saved with separate prefixes.
146
+ use_separate_metric_directories: bool = True
147
+
148
+ # Optimizer options.
149
+ optimizer_factory: opt_config.OptimizerConfig = gin.REQUIRED
150
+ learning_rate_schedule: Callable[[jnp.ndarray, int], jnp.ndarray] = (
151
+ opt_config.lr_cosine_decay)
152
+
153
+ # Maximum steps for the LR schedule. Zero means use num_steps.
154
+ max_scheduled_steps: int = 0
155
+ warmup_steps: int = 1000 # Number of warmup steps.
156
+ learning_rate_multiplier: float = 1.0 # Used to scale the learning rate.
157
+
158
+ random_seed: int = 42 # Initial random seed.
159
+
160
+ # Names of random number generators used by the model.
161
+ rng_key_names: Optional[Sequence[str]] = ("dropout",)
162
+
163
+ # Debug options.
164
+ replicate_mode: bool = True # pmap over multiple replicas.
165
+ trace_debug_mode: bool = False # Run in eager mode to trace results.
166
+ print_variables: bool = False # Dump parameters/variables to stdout.
167
+
168
+ # Function to compute additional summary information.
169
+ # Takes a MetricsSummary object and a mode string (e.g. "test") as arguments,
170
+ # returns a MetricsSummary object.
171
+ process_summaries_function: Optional[Callable[[Any, str], Any]] = None
172
+
173
+ # Function to pretty print the input for each training step.
174
+ pretty_print_input_function: Optional[Callable[[Any], Any]] = None
175
+
176
+ # Classes to use for summarizing metrics.
177
+ metrics_summary_factory: Any = metrics_summary.MetricsSummary
178
+ extra_summaries_fn: training_task.ExtraSummariesFunction = (
179
+ lambda mode, step: dict())
180
+
181
+ post_save_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
182
+ post_load_checkpoint_fn: Callable[[str, int], None] = lambda mode, step: None
183
+
184
+ def learning_rate_schedule_fn(self, step):
185
+ """Returns the learning rate for the given step."""
186
+
187
+ # There are four components to the learning rate.
188
+ #
189
+ # The base_lrate is defined by the optimizer, and different optimizers have
190
+ # different relative rates, e.g. Adafactor requires a higher LR than Adam.
191
+ # By default, the base_lrate is 1.0 for Adafactor.
192
+ #
193
+ # The base_lrate is then multiplied by the learning rate decay schedule,
194
+ # which typically starts at a maximum value and decays over time.
195
+ # Each schedule can be individually configured, e.g. from 0.01 to 0.001.
196
+ # The max_scheduled_steps parameter controls the decay rate of the schedule.
197
+ #
198
+ # Finally, the LR is scaled by the learning_rate_multiplier, which provides
199
+ # an easy way to scale the LR for hyperparameter tuning in a way that is
200
+ # independent of the choice of schedule or optimizer. The default is 1.0.
201
+ #
202
+ # During the warmup period, the learning rate ramps up linearly from zero.
203
+
204
+ step = jnp.asarray(step, dtype=jnp.float32)
205
+ if self.max_scheduled_steps == 0:
206
+ max_steps = self.num_steps
207
+ else:
208
+ max_steps = self.max_scheduled_steps
209
+
210
+ base_lrate = float(self.optimizer_factory.learning_rate)
211
+ lr_multiplier = float(self.learning_rate_multiplier)
212
+
213
+ # Linear increase in learning rate up to warmup_steps.
214
+ warmup_steps = float(self.warmup_steps)
215
+ lr_warmup_ramp = jnp.minimum(step, warmup_steps) / warmup_steps
216
+
217
+ # Hold step at a constant value during the warmup period.
218
+ # Required for some schedules, like rsqrt_decay.
219
+ step = jnp.maximum(step, warmup_steps)
220
+
221
+ # Get the scheduled learning rate.
222
+ lrate = self.learning_rate_schedule(step, max_steps)
223
+
224
+ # Multiply lrate by the base, warmup and multiplier factors.
225
+ lrate = lrate * base_lrate * lr_warmup_ramp * lr_multiplier
226
+ return jnp.asarray(lrate, dtype=jnp.float32)
227
+
228
+ def _init_rngs(self, rngs: PRNGKeys, step: int) -> PRNGKeys:
229
+ # Get a new random number generator for each step
230
+ rngs = jax.random.fold_in(rngs, step)
231
+ rngs = jax.random.split(rngs, len(self.rng_key_names))
232
+ rngs = {key: rngs[i] for i, key in enumerate(self.rng_key_names)}
233
+ return rngs
234
+
235
+ def train_step(self, model: nn.Module, tstate: TrainState, x: Any,
236
+ rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
237
+ """Perform a training step, pmapped over multiple devices.
238
+
239
+ Args:
240
+ model: The model to use for the step function.
241
+ tstate: Values for state variables, and the optimizer.
242
+ x: A batch of inputs to train on.
243
+ rngs: PRNGKey (possibly replicated).
244
+
245
+ Returns:
246
+ Tuple of (new_tstate, metrics: dictionary of scalar values)
247
+ """
248
+
249
+ mutable_keys = [k for (k, _) in tstate.state.items()]
250
+ step = tstate.optimizer.state.step
251
+ rngs = self._init_rngs(rngs, step)
252
+
253
+ # Refactor the model as a loss function from trainable params to loss, so
254
+ # that we can differentiate with jax and get {d}loss/{d}params.
255
+ # Inputs and non-trainable params are bound within the closure.
256
+ # model:: x, { state_params } -> (loss, metrics), { new_state_params }
257
+ # loss_fn:: params -> (loss, (metrics, new_state))
258
+ def loss_fn(params):
259
+ """Loss function."""
260
+ (loss, mets), nstate = model.apply({"params": params, **tstate.state},
261
+ x,
262
+ rngs=rngs,
263
+ mutable=mutable_keys)
264
+ return loss, (mets, nstate)
265
+
266
+ # grad_fn:: params -> ((loss, (aux, nstate)), param_gradients)
267
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
268
+
269
+ # Run forward and backward pass.
270
+ (loss, (metrics, new_state)), param_grads = grad_fn(tstate.optimizer.target)
271
+ del loss # loss is only recorded if it is part of the metrics
272
+ if self.replicate_mode:
273
+ param_grads = jax.lax.pmean(param_grads, axis_name="batch")
274
+ lrate = self.learning_rate_schedule_fn(step)
275
+ new_optimizer = tstate.optimizer.apply_gradient(
276
+ param_grads, learning_rate=lrate)
277
+
278
+ # Metrics are summary values that will be logged.
279
+ if self.replicate_mode:
280
+ # Merge metrics (take mean/sum etc.) over replicas on-device.
281
+ summary_class = self.metrics_summary_factory
282
+ metrics = summary_class.merge_replicated_metrics(
283
+ metrics, model.metrics_summary_operations(aggregate_over="devices"))
284
+
285
+ metrics["learning_rate"] = lrate
286
+ return (TrainState(new_optimizer, new_state), metrics)
287
+
288
+ def other_step(self, model: nn.Module, tstate: TrainState, x: Any,
289
+ rngs: PRNGKeys) -> Tuple[TrainState, Metrics]:
290
+ """Perform a test or generate step, pmapped over multiple devices.
291
+
292
+ Args:
293
+ model: The model to use for the step function.
294
+ tstate: Values for state variables, and the optimizer.
295
+ x: A batch of inputs to train on.
296
+ rngs: PRNGKey (possibly replicated).
297
+
298
+ Returns:
299
+ Tuple of (new_tstate, metrics: dictionary of scalar values)
300
+ """
301
+
302
+ mutable_keys = [k for (k, _) in tstate.state.items()]
303
+ step = tstate.optimizer.state.step
304
+ rngs = self._init_rngs(rngs, step)
305
+
306
+ params = tstate.optimizer.target
307
+ (loss, metrics), new_state = model.apply({"params": params, **tstate.state},
308
+ x,
309
+ rngs=rngs,
310
+ mutable=mutable_keys)
311
+ del loss # loss is only recorded if it is part of the metrics
312
+
313
+ # Metrics are summary values that will be logged.
314
+ if self.replicate_mode:
315
+ # Merge metrics (take mean/sum etc.) over replicas on-device.
316
+ summary_class = self.metrics_summary_factory
317
+ metrics = summary_class.merge_replicated_metrics(
318
+ metrics, model.metrics_summary_operations(aggregate_over="devices"))
319
+
320
+ return (TrainState(tstate.optimizer, new_state), metrics)
321
+
322
+ def initialize_model(self) -> Tuple[TrainState, int, nn.Module, PRNGKeys]:
323
+ """Initialize the model and/or load it from a checkpoint.
324
+
325
+ Returns:
326
+ (tstate: TrainState, -- The parameters and state for the model.
327
+ start_step: int, -- The step number, when restoring from checkpoint.
328
+ imodel: nn.Module, -- A model object (created with mode "init").
329
+ rngs: PRNGkeys) -- Initial random numbers.
330
+ """
331
+
332
+ # Set up random number generators.
333
+ # ---------------------------------
334
+ logging.info("==== Training loop: initializing model ====")
335
+ logging.info("Process %d of %d", jax.process_index(), jax.process_count())
336
+ logging.info("Local device count = %d", jax.local_device_count())
337
+ logging.info("Number of replicas = %d",
338
+ jax.process_count() * jax.local_device_count())
339
+ logging.info("Using random number seed %d", self.random_seed)
340
+
341
+ prng = jax.random.PRNGKey(self.random_seed)
342
+ prng, init_rng = jax.random.split(prng)
343
+
344
+ # Grab rngs, which provide different random numbers for each replica.
345
+ if self.replicate_mode:
346
+ prngs = jax.random.split(prng, jax.local_device_count())
347
+ else:
348
+ prngs = prng
349
+ del prng
350
+
351
+ # Create a dictionary of prng keys for initialization.
352
+ rng_key_names_init = list(self.rng_key_names) + ["params"]
353
+ init_rngs = jax.random.split(init_rng, len(rng_key_names_init))
354
+ init_rngs = {key: init_rngs[i] for i, key in enumerate(rng_key_names_init)}
355
+ del init_rng
356
+
357
+ # Build Model
358
+ # -------------------------------------------------------------------------
359
+ logging.info("Initializing the model.")
360
+
361
+ # Create a model, which will be used to initialize trainable parameters.
362
+ imodel = self.model_definition(mode="init")
363
+
364
+ # The init function will lazily initialize the model, given a fake input.
365
+ # It returns initialized variables, without doing a fwd pass.
366
+ model_init_fn = jax.jit(imodel.init)
367
+ variables = model_init_fn(init_rngs, imodel.get_fake_input())
368
+
369
+ # Split variables into trainable and non-trainable sets.
370
+ mstate, params = variables.pop("params")
371
+ del variables # Delete to avoid wasting resources.
372
+
373
+ # Create an optimizer for params.
374
+ optimizer_def = self.optimizer_factory.create_optimizer_def()
375
+ optimizer = optimizer_def.create(params)
376
+
377
+ # tstate holds the full training state of the model.
378
+ tstate = TrainState(optimizer, mstate)
379
+ if self.print_variables:
380
+ logging.info("params = %s", tstate.optimizer.target)
381
+ logging.info("state = %s", tstate.state)
382
+
383
+ # Load a pre-trained model or restore it from checkpoint.
384
+ if self.workdir or self.load_dir:
385
+ restore_checkpoints = self.restore_checkpoints
386
+ else:
387
+ restore_checkpoints = False
388
+
389
+ start_step = 0
390
+ if restore_checkpoints:
391
+ tstate = self.restore_checkpoint(tstate)
392
+ start_step = int(tstate.optimizer.state.step)
393
+
394
+ # Log info on trainable parameters (before replicating them).
395
+ self._write_parameter_info(tstate)
396
+ # raise ValueError("That's all folks!")
397
+
398
+ # Replicate the training state across local devices.
399
+ if self.replicate_mode:
400
+ tstate = jax_utils.replicate(tstate)
401
+
402
+ return (tstate, start_step, imodel, prngs)
403
+
404
+ def restore_checkpoint(self, train_state: TrainState) -> TrainState:
405
+ """Load a pre-trained model or restore it from a checkpoint."""
406
+
407
+ # Figure out if we have an existing checkpoint.
408
+ if not self.workdir:
409
+ logging.info("No working directory specified.")
410
+ existing_checkpoint = False
411
+ elif not gfile.exists(self.workdir):
412
+ logging.info("No existing checkpoint directory %s", self.workdir)
413
+ existing_checkpoint = False
414
+ elif not gfile.isdir(self.workdir):
415
+ raise ValueError(f"workdir {self.workdir} must be a directory.")
416
+ else:
417
+ ckpath = checkpoints.latest_checkpoint(self.workdir, "checkpoint_")
418
+ if ckpath:
419
+ logging.info("Found existing checkpoint in %s", self.workdir)
420
+ existing_checkpoint = True
421
+ else:
422
+ logging.info("No existing checkpoint in %s", self.workdir)
423
+ existing_checkpoint = False
424
+
425
+ # If any checkpoints exist in workdir, then use those first.
426
+ # This will ensure that the task will restore properly if it's preempted.
427
+ if existing_checkpoint:
428
+ logging.info("Restoring model from last checkpoint %s:", self.workdir)
429
+ load_dir = self.workdir
430
+ elif self.load_dir:
431
+ logging.info("Loading pre-trained model from %s:", self.load_dir)
432
+ load_dir = self.load_dir
433
+ else:
434
+ logging.warning("Unable to load model.")
435
+ return train_state
436
+ loaded_train_state = checkpoints.restore_checkpoint(load_dir, train_state)
437
+ step = int(loaded_train_state.optimizer.state.step)
438
+ self.post_load_checkpoint_fn(load_dir, step)
439
+
440
+ if self.restore_state_variables:
441
+ # Restore complete state.
442
+ logging.info("Restoring all variables and state.")
443
+ train_state = loaded_train_state
444
+ del loaded_train_state
445
+ else:
446
+ # Restore trainable variables, but not other state.
447
+ logging.info("Only restoring trainable parameters.")
448
+ train_state = TrainState(loaded_train_state.optimizer, train_state.state)
449
+ del loaded_train_state
450
+
451
+ return train_state
452
+
453
+ def save_checkpoint(self, tstate: TrainState, step: int,
454
+ param_summary: Optional[MetricsSummary]):
455
+ """Save a checkpoint with the model state.
456
+
457
+ Args:
458
+ tstate: The training state.
459
+ step: The current step number.
460
+ param_summary: Optional metrics summary to write parameter statistics.
461
+ """
462
+
463
+ logging.info("Saving checkpoint in directory %s", self.workdir)
464
+ if self.replicate_mode:
465
+ save_state = jax_utils.unreplicate(tstate)
466
+ else:
467
+ save_state = tstate
468
+ checkpoints.save_checkpoint(self.workdir, save_state, step)
469
+
470
+ # While we're at it, record distributions of trainable parameters.
471
+ if param_summary is not None:
472
+ logging.info("Recording parameter distributions.")
473
+ params_dict = jax.device_get(
474
+ _flatten_dict_string_keys(save_state.optimizer.target))
475
+ param_distribs = self._compute_parameter_distributions(params_dict)
476
+ param_summary.add(param_distribs)
477
+
478
+ def create_training_task(self, mode: str, imodel: nn.Module, prngs: PRNGKeys,
479
+ writers: Dict[str, MetricWriter]) -> TrainingTask:
480
+ """Create a new TrainingTask for the given mode.
481
+
482
+ Args:
483
+ mode: The mode for the task, e.g. "train", "test", "generate".
484
+ imodel: The model object from initialize_model.
485
+ prngs: The PRNGKeys from initialize_model.
486
+ writers: A dictionary of summary writers.
487
+
488
+ Returns:
489
+ A TrainingTask object.
490
+ """
491
+
492
+ logging.info("Training loop: creating task for mode %s", mode)
493
+ if self.use_separate_metric_directories:
494
+ prefix = ""
495
+ else:
496
+ prefix = mode
497
+
498
+ if mode == "train":
499
+ ds = self.get_training_dataset_iterator
500
+ elif mode == "test":
501
+ ds = self.get_test_dataset_iterator
502
+ else:
503
+ ds = None
504
+
505
+ # We summarize metrics over multiple training steps.
506
+ # These types control how the summary is computed.
507
+ metric_summary_ops = {
508
+ "step_time": "mean",
509
+ "learning_rate": "last",
510
+ **imodel.metrics_summary_operations(aggregate_over="steps")
511
+ }
512
+ summary = self.metrics_summary_factory(metric_summary_ops)
513
+ extra_summary = self.metrics_summary_factory({})
514
+ summary_writer = self._get_summary_writer(mode, writers)
515
+
516
+ return TrainingTask(
517
+ mode=mode,
518
+ dataset=ds,
519
+ step_function=self._compile_step_function(mode),
520
+ prng_keys=prngs,
521
+ summary=summary,
522
+ extra_summary=extra_summary,
523
+ summary_writer=summary_writer,
524
+ summary_prefix=prefix,
525
+ # --- options ---
526
+ replicate_mode=self.replicate_mode,
527
+ print_input_every_steps=self.print_input_every_steps,
528
+ pretty_print_input_function=self.pretty_print_input_function,
529
+ process_summaries_function=self.process_summaries_function,
530
+ extra_summaries_function=self.extra_summaries_fn)
531
+
532
+ def train(self):
533
+ """Runs the training and evaluation loop."""
534
+
535
+ # The master process saves checkpoints and summaries to disk.
536
+ is_master_process = jax.process_index() == 0
537
+ if self.workdir:
538
+ save_checkpoints = self.save_checkpoints
539
+ else:
540
+ save_checkpoints = False
541
+
542
+ # --- Create and initialize the model. ---
543
+ (tstate, start_step, imodel, prngs) = self.initialize_model()
544
+
545
+ # Log experiment hyper-parameters.
546
+ writers = {}
547
+ train_writer = self._get_summary_writer("train", writers)
548
+ if start_step == 0:
549
+ self._write_config(train_writer)
550
+
551
+ # Additional summary objects.
552
+ param_summary = self.metrics_summary_factory({}) # Parameter statistics.
553
+
554
+ # --- Create task objects for test, train, and generate. ---
555
+ tasks = {}
556
+ train_task = self.create_training_task("train", imodel, prngs, writers)
557
+ tasks["train"] = train_task
558
+
559
+ if (self.get_test_dataset_iterator is not None and
560
+ self.test_every_steps != 0):
561
+ test_task = self.create_training_task("test", imodel, prngs, writers)
562
+ tasks["test"] = test_task
563
+ if self.generate_every_steps != 0:
564
+ gen_task = self.create_training_task("generate", imodel, prngs,
565
+ writers)
566
+ tasks["generate"] = gen_task
567
+
568
+ # Register any additional actions.
569
+ register_interstep_callbacks()
570
+
571
+ # Main Training Loop
572
+ # --------------------------------------------------------------------------
573
+ logging.info("==== Training loop: starting main loop ====")
574
+ with metric_writers.ensure_flushes(*writers.values()):
575
+ for step in range(start_step, self.num_steps):
576
+ # Log status every so often to monitor progress.
577
+ if should_run(step, self.status_every_steps):
578
+ logging.info("Step: %d", step)
579
+
580
+ # Train.
581
+ train_x = train_task.get_next_input()
582
+ (tstate, _) = train_task.run_step(tstate, train_x, step)
583
+ run_interstep_callbacks("train", step)
584
+ del train_x
585
+
586
+ # Test.
587
+ if should_run(step, self.test_every_steps):
588
+ if self.num_test_steps > 1:
589
+ logging.info("Test cycle: %d iterations.", self.num_test_steps)
590
+ for sub_step in range(0, self.num_test_steps):
591
+ test_x = test_task.get_next_input()
592
+
593
+ # TODO(delesley): This is an ugly hack to run generate steps.
594
+ # Run a generate step using test data.
595
+ # Generate is run just *before* the last test iteration.
596
+ if ((sub_step == self.num_test_steps - 1) and
597
+ should_run(step, self.generate_every_steps)):
598
+ logging.info("Generate cycle.")
599
+ (tstate, _) = gen_task.run_step(tstate, test_x, step)
600
+ run_interstep_callbacks("generate", step)
601
+
602
+ (tstate, _) = test_task.run_step(tstate, test_x, step,
603
+ sub_step=sub_step)
604
+ run_interstep_callbacks("test", step, sub_step)
605
+ del test_x
606
+
607
+ # --- Save checkpoints on the master host. ---
608
+ is_last_step = (step == self.num_steps - 1)
609
+ checkpoint_current_step = (
610
+ save_checkpoints and
611
+ (should_run(step, self.checkpoint_every_steps) or is_last_step))
612
+ if checkpoint_current_step:
613
+ if is_master_process:
614
+ self.save_checkpoint(tstate, step, param_summary)
615
+ self.post_save_checkpoint_fn(self.workdir, step)
616
+
617
+ # --- Flush summaries to disk. ---
618
+ if should_run(step, self.log_every_steps):
619
+ for tsk in tasks.values():
620
+ tsk.flush(step)
621
+ param_summary.write(train_writer, step, prefix="params")
622
+
623
+ logging.info("Training Finished.")
624
+ if self.replicate_mode:
625
+ tstate = jax_utils.unreplicate(tstate)
626
+ if self.print_variables:
627
+ logging.info("params = %s", tstate.optimizer.target)
628
+ logging.info("state = %s", tstate.state)
629
+
630
+ def _compile_step_function(self, mode: str) -> StepFunction:
631
+ """Compile a step function (training or test)."""
632
+
633
+ # Create a model object, and a step function that is a closure over the
634
+ # object. Flax modules are supposed to be "stateless", in that all state
635
+ # is contained in the TrainState object that is passed as an input parameter.
636
+ # However, creating the model object may involve allocating expensive
637
+ # data structures, or launching processes, and should only be done once.
638
+ model = self.model_definition(mode=mode)
639
+ if mode == "train":
640
+ step_fn = functools.partial(self.train_step, model)
641
+ else:
642
+ step_fn = functools.partial(self.other_step, model)
643
+
644
+ if self.replicate_mode:
645
+ assert not self.trace_debug_mode
646
+ logging.info("Compiling mode %s with pmap.", mode)
647
+ p_fn = jax.pmap(step_fn, donate_argnums=(0,), axis_name="batch")
648
+ elif self.trace_debug_mode:
649
+ logging.info("Compiling mode %s with trace_debug.", mode)
650
+ p_fn = step_fn
651
+ else:
652
+ logging.info("Compiling mode %s with jit.", mode)
653
+ p_fn = jax.jit(step_fn, donate_argnums=(0,))
654
+ return p_fn
655
+
656
+ def _get_summary_writer(self, mode: str,
657
+ writers: Dict[str, MetricWriter]) -> MetricWriter:
658
+ """Create a summary writer for the given mode.
659
+
660
+ Args:
661
+ mode: the mode for the summaries, e.g. "test", "train"
662
+ writers: a dictionary which caches previously-created writers.
663
+
664
+ Returns:
665
+ A writer for the given mode.
666
+ """
667
+
668
+ if self.use_separate_metric_directories:
669
+ # Create a separate writer & directory for each mode.
670
+ w_mode = mode
671
+ summary_dir = os.path.join(self.workdir, mode)
672
+ else:
673
+ # Create a single default writer for all modes.
674
+ w_mode = "train"
675
+ summary_dir = self.workdir
676
+
677
+ if w_mode in writers:
678
+ # Return previously created and cached writer.
679
+ logging.info("Returning cached summary writer (%s) for mode %s",
680
+ w_mode, mode)
681
+ return writers[w_mode]
682
+
683
+ if not self.workdir:
684
+ # No working directory, so log only.
685
+ logging.info("Creating logging writer (%s) for mode %s", w_mode, mode)
686
+ writer = metric_writers.LoggingWriter()
687
+ else:
688
+ # Create a new writer for workdir.
689
+ # Only the master will actually write summaries to workdir.
690
+ logging.info("Creating summary writer (%s) for mode %s in directory %s",
691
+ w_mode, mode, summary_dir)
692
+ is_master = jax.process_index() == 0
693
+ gfile.makedirs(summary_dir)
694
+ writer = metric_writers.create_default_writer(summary_dir,
695
+ just_logging=not is_master)
696
+ writers[w_mode] = writer
697
+ return writer
698
+
699
+ def _write_config(self, writer):
700
+ """Write the configuration file to the working directory."""
701
+
702
+ is_master = jax.process_index() == 0
703
+ config_str = gin.operative_config_str()
704
+ logging.info("Gin config: \n%s", config_str)
705
+
706
+ # Write configuration to workdir.
707
+ if is_master and self.workdir:
708
+ config_file_name = os.path.join(self.workdir, "config.gin")
709
+ with gfile.GFile(config_file_name, "w") as f:
710
+ f.write(config_str)
711
+
712
+ # Write config string text to tensorboard.
713
+ writer.write_texts(0, {"config": gin.markdown(config_str)})
714
+
715
+ def _write_parameter_info(self, tstate: TrainState):
716
+ """Write information on state and trainable parameters to the log."""
717
+
718
+ # Write information on parameters to log file.
719
+ params_dict = _flatten_dict_string_keys(tstate.optimizer.target)
720
+ total_nparams = 0
721
+ for (k, v) in params_dict.items():
722
+ nparams = np.prod(v.shape)
723
+ total_nparams += nparams
724
+ logging.info("parameter: %s, shape %s, size %d", k, v.shape, nparams)
725
+ logging.info("Total parameters: %d", total_nparams)
726
+
727
+ # Write information on state variables to log file.
728
+ state_dict = _flatten_dict_string_keys(tstate.state)
729
+ state_size = 0
730
+ total_state = 0
731
+ for (k, v) in state_dict.items():
732
+ if hasattr(v, "shape"):
733
+ state_size = np.prod(v.shape)
734
+ total_state += state_size
735
+ logging.info("state: %s, shape %s, size %d", k, v.shape, state_size)
736
+ else:
737
+ # Some other stuff may be stored in the state.
738
+ logging.info("state: %s [unknown]", k)
739
+ logging.info("Total state size: %d", total_state)
740
+
741
+ def _compute_parameter_distributions(self, params_dict):
742
+ """Compute info on distributions of parameters."""
743
+
744
+ scalar_params_dict = {}
745
+ for (k, v) in params_dict.items():
746
+ # Convert from bfloat16, which crashes when serializing a NaN.
747
+ v = np.asarray(v, dtype=jnp.float32)
748
+ scalar_params_dict[k + "_mean"] = np.mean(v)
749
+ scalar_params_dict[k + "_stddev"] = np.std(v)
750
+ # scalar_params_dict[k + "_min"] = np.min(v)
751
+ # scalar_params_dict[k + "_max"] = np.max(v)
752
+ return scalar_params_dict
753
+
754
+
755
+ def _flatten_dict_string_keys(params):
756
+ """Flattens a nested dictionary to have string keys and '/' separators."""
757
+ return {"/".join(k): v for k, v in flatten_dict(unfreeze(params)).items()}
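
The four-factor learning rate described in Trainer.learning_rate_schedule_fn can be written out as a plain-Python sketch (not part of the uploaded file; it ignores the decay_after and spike options, and all numbers are illustrative):

import math

def toy_learning_rate(step, max_steps=100_000, warmup_steps=1_000,
                      base_lrate=1.0, lr_multiplier=1.0,
                      max_lr=0.01, min_lr=0.001):
    # Linear warmup ramp from 0 to 1 over the first warmup_steps.
    warmup_ramp = min(step, warmup_steps) / warmup_steps
    # Hold the step at warmup_steps during warmup (matters for rsqrt decay).
    sched_step = max(step, warmup_steps)
    # Cosine decay from max_lr to min_lr, as in optimizer_config.lr_cosine_decay.
    ramp = min(sched_step, max_steps) / max_steps
    scheduled = min_lr + (max_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * ramp))
    # Compose the base rate, schedule, warmup ramp, and tuning multiplier.
    return scheduled * base_lrate * warmup_ramp * lr_multiplier

print(toy_learning_rate(100))     # ~0.001  (still ramping up)
print(toy_learning_rate(1_000))   # ~0.01   (warmup done, near max_lr)
print(toy_learning_rate(50_000))  # ~0.0055 (halfway through the cosine decay)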
aglib/meliad/training_task.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright 2022 Google.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TrainingTask encapsulates the state associated with a model step."""
16
+
17
+ import time
18
+ from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Tuple)
19
+
20
+ from absl import logging
21
+ from clu import metric_writers
22
+ from flax import optim
23
+ from flax import struct
24
+ import jax
25
+ import metrics_summary
26
+ import numpy as np
27
+
28
+
29
+ @struct.dataclass
30
+ class TrainState:
31
+ optimizer: optim.Optimizer # Trainable parameters.
32
+ state: Any # Other state, e.g. XL cache or memory.
33
+
34
+
35
+ PRNGKeys = Any
36
+ Metrics = Dict[str, Any]
37
+ MetricsSummary = metrics_summary.MetricsSummary
38
+
39
+ Dataset = Callable[[], Iterator[Any]]
40
+ StepFunction = Callable[[TrainState, Any, Any], Tuple[TrainState, Metrics]]
41
+ PrettyPrintInputFunction = Optional[Callable[[Any], str]]
42
+ ProcessSummariesFunction = Optional[Callable[[Any, str], Any]]
43
+ ExtraSummariesFunction = Optional[Callable[[str, int], Mapping[str, Any]]]
44
+
45
+
46
+ def should_run(step: int, every_steps: int) -> bool:
47
+ """Returns true if a periodic action should be run."""
48
+ return (step > 0) and (every_steps > 0) and (step % every_steps == 0)
49
+
50
+
51
+ class TrainingTask:
52
+ """A TrainingTask encapsulates the state associated with a training task.
53
+
54
+ Examples of tasks include training steps, test or validation runs,
55
+ or inference (generation). State includes the input pipeline, and
56
+ summary information that is averaged over multiple steps.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ *, # Pass arguments by keyword only.
62
+ mode: str,
63
+ dataset: Dataset,
64
+ step_function: StepFunction,
65
+ prng_keys: PRNGKeys,
66
+ summary: MetricsSummary,
67
+ extra_summary: MetricsSummary,
68
+ summary_writer: metric_writers.MetricWriter,
69
+ summary_prefix: str = "",
70
+ # --- Options from TrainingLoop ---
71
+ replicate_mode: bool = True,
72
+ print_input_every_steps: int = 0,
73
+ pretty_print_input_function: PrettyPrintInputFunction = None,
74
+ process_summaries_function: ProcessSummariesFunction = None,
75
+ extra_summaries_function: Optional[ExtraSummariesFunction] = None):
76
+ # Local state.
77
+ self.mode = mode
78
+ self.dataset = dataset
79
+ self.step_function = step_function
80
+ self.prng_keys = prng_keys
81
+ self.summary = summary
82
+ self.extra_summary = extra_summary
83
+ self.summary_writer = summary_writer
84
+ self.summary_prefix = summary_prefix
85
+
86
+ # Options carried over from TrainingLoop.
87
+ self.replicate_mode = replicate_mode
88
+ self.print_input_every_steps = print_input_every_steps
89
+ self.pretty_print_input_fn = pretty_print_input_function
90
+ self.process_summaries_fn = process_summaries_function
91
+ self.extra_summaries_fn = extra_summaries_function
92
+
93
+ # Local state.
94
+ if self.dataset is not None:
95
+ self.ds_iterator = self.dataset()
96
+ self.epoch = 0
97
+
98
+ def _get_metrics(self, device_metrics: Metrics) -> Metrics:
99
+ """Read a dictionary of metrics from device."""
100
+ if self.replicate_mode:
101
+ # x[0] gets the metric from device 0 -- the first replica.
102
+ # We assume that merge_replicated_metrics has already combined the
103
+ # metrics from multiple devices.
104
+ device_metrics = jax.tree_map(lambda x: x[0], device_metrics)
105
+ metrics_np = jax.device_get(device_metrics) # Get numpy arrays.
106
+ return metrics_np
107
+
108
+ def get_next_input(self) -> Any:
109
+ """Grab the next input from the data pipeline."""
110
+ if self.dataset is None:
111
+ logging.warning("No dataset for mode %s", self.mode)
112
+ return None
113
+
114
+ try:
115
+ x = next(self.ds_iterator)
116
+ except StopIteration:
117
+ logging.info("End of epoch %d for mode %s.", self.epoch, self.mode)
118
+ self.ds_iterator = self.dataset()
119
+ x = next(self.ds_iterator)
120
+ self.epoch += 1
121
+ return x
122
+
123
+ def run_step(self, tstate: TrainState, x: Any,
124
+ step: int, sub_step: int = 0) -> Tuple[TrainState, Metrics]:
125
+ """Run the model for a single step.
126
+
127
+ Args:
128
+ tstate: The current model state.
129
+ x: The input for the model -- from get_next_input.
130
+ step: The training step number.
131
+ sub_step: For tasks that run multiple iterations within a step.
132
+ E.g. A test cycle will call run_step multiple times to cover the test
133
+ set. The step counter will not increment, but sub_step will.
134
+
135
+ Returns:
136
+ An updated model state.
137
+ """
138
+
139
+ start_time = time.perf_counter()
140
+
141
+ # Split a batch of inputs among local replicas.
142
+ if self.replicate_mode:
143
+ x = split_batch_dimension(x, jax.local_device_count())
144
+
145
+ # Pretty-print the input to the summary and log file every so often.
146
+ if (sub_step == 0 and self.pretty_print_input_fn is not None and
147
+ should_run(step, self.print_input_every_steps)):
148
+ x_first = jax.tree_map(lambda x: x[0], x) if self.replicate_mode else x
149
+ x_strs = self.pretty_print_input_fn(x_first)
150
+ logging.info("[%d] Input (%s) = %s", step, self.mode, x_strs)
151
+ self.summary.add_text({"input": x_strs})
152
+
153
+ # Run the step function on the input.
154
+ with jax.profiler.StepTraceAnnotation(self.mode, step_num=step):
155
+ (tstate, metrics) = self.step_function(tstate, x, self.prng_keys)
156
+
157
+ # Read metrics from device.
158
+ metrics_np = self._get_metrics(metrics)
159
+ end_time = time.perf_counter()
160
+ metrics_np["step_time"] = end_time - start_time
161
+ if "epoch" not in metrics_np.keys():
162
+ metrics_np["epoch"] = self.epoch
163
+
164
+ # Add metrics to the current summary.
165
+ self.summary.add(metrics_np)
166
+ return (tstate, metrics_np)
167
+
168
+ def flush(self, step: int):
169
+ """Flush accumulated metric summaries to disk."""
170
+
171
+ if self.summary_writer is None:
172
+ self.summary.clear() # Clear summary if we can't write it.
173
+ return
174
+
175
+ if self.summary.empty():
176
+ return
177
+
178
+ # Do post-processing of the summaries.
179
+ if self.process_summaries_fn is not None:
180
+ self.summary = self.process_summaries_fn(self.summary, self.mode) # pylint: disable=not-callable
181
+
182
+ # Write and clear summary data.
183
+ logging.info("Writing summaries for mode %s.", self.mode)
184
+ self.summary.write(self.summary_writer, step, prefix=self.summary_prefix)
185
+
186
+ # Add extra summaries that are not computed by the step function.
187
+ if self.extra_summaries_fn is not None:
188
+ self.extra_summary.add(self.extra_summaries_fn(self.mode, step))
189
+ self.extra_summary.write(self.summary_writer, step, prefix="")
190
+
191
+
192
+ def split_batch_dimension(inputs: Any, num_replicas: int) -> Any:
193
+ """Splits the leading batch dimension.
194
+
195
+ Given inputs of shape [num_replicas * batch_size, ...], it will reshape
196
+ them to [num_replicas, batch_size, ...]. This operation is intended to be
197
+ used right before calling pmap, which will eliminate the num_replicas
198
+ dimension.
199
+
200
+ Args:
201
+ inputs: Tuple of inputs to split.
202
+ num_replicas: Number of replicas.
203
+
204
+ Returns:
205
+ inputs with extra batch dimension.
206
+ """
207
+
208
+ def split_batch_dim(x):
209
+ assert x.ndim > 0
210
+ if (x.shape[0] % num_replicas) != 0:
211
+ raise ValueError(f"Can't split {x.shape} into {num_replicas} replicas.")
212
+ batch_size = x.shape[0] // num_replicas
213
+ split_shape = [num_replicas, batch_size] + list(x.shape[1:])
214
+ return np.reshape(x, split_shape)
215
+
216
+ return jax.tree_map(split_batch_dim, inputs)
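
Finally, a small sketch (not part of the uploaded file) of what split_batch_dimension does to a batch before it is handed to pmap; the array shapes are illustrative:

import numpy as np
from training_task import split_batch_dimension

batch = {
    "inputs": np.zeros((8, 512), dtype=np.int32),   # global batch of 8 examples
    "targets": np.zeros((8, 512), dtype=np.int32),
}
sharded = split_batch_dimension(batch, num_replicas=4)
print(sharded["inputs"].shape)  # (4, 2, 512): 4 replicas, per-replica batch of 2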