Spaces:
Running
Running
Merge pull request #127 from borisdayma/pjit-t5x
Browse filesfeat(train): pjit optimization and distributed shampoo support
- src/dalle_mini/data.py +6 -40
- src/dalle_mini/model/modeling.py +1 -1
- tools/train/distributed_shampoo.py +427 -61
- tools/train/train.py +215 -106
src/dalle_mini/data.py
CHANGED
@@ -152,24 +152,15 @@ class Dataset:
|
|
152 |
),
|
153 |
)
|
154 |
|
155 |
-
def dataloader(
|
156 |
-
self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
|
157 |
-
):
|
158 |
-
num_devices = jax.local_device_count()
|
159 |
-
|
160 |
def _dataloader_datasets_non_streaming(
|
161 |
dataset: Dataset,
|
162 |
-
per_device_batch_size: int,
|
163 |
-
gradient_accumulation_steps: int,
|
164 |
rng: jax.random.PRNGKey = None,
|
165 |
):
|
166 |
"""
|
167 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
168 |
Shuffle batches if rng is set.
|
169 |
"""
|
170 |
-
batch_size = (
|
171 |
-
per_device_batch_size * num_devices * gradient_accumulation_steps
|
172 |
-
)
|
173 |
steps_per_epoch = len(dataset) // batch_size
|
174 |
|
175 |
if rng is not None:
|
@@ -185,18 +176,10 @@ class Dataset:
|
|
185 |
for idx in batch_idx:
|
186 |
batch = dataset[idx]
|
187 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
188 |
-
if gradient_accumulation_steps is not None:
|
189 |
-
batch = jax.tree_map(
|
190 |
-
lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
|
191 |
-
batch,
|
192 |
-
)
|
193 |
yield batch
|
194 |
|
195 |
def _dataloader_datasets_streaming(
|
196 |
dataset: Dataset,
|
197 |
-
split: str,
|
198 |
-
per_device_batch_size: int,
|
199 |
-
gradient_accumulation_steps: int,
|
200 |
epoch: int,
|
201 |
):
|
202 |
keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
|
@@ -208,28 +191,15 @@ class Dataset:
|
|
208 |
# For validation data we put the entire set on each host as we could lose
|
209 |
# too many samples on pods
|
210 |
if epoch is not None:
|
211 |
-
|
|
|
212 |
dataset.set_epoch(epoch)
|
213 |
epoch += 1
|
214 |
for item in dataset:
|
215 |
for k, v in item.items():
|
216 |
batch[k].append(v)
|
217 |
-
|
218 |
-
# (40, 3, 3) -> shard 8 x (5, 3, 3)
|
219 |
-
# (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
|
220 |
-
if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
|
221 |
-
gradient_accumulation_steps
|
222 |
-
if gradient_accumulation_steps is not None
|
223 |
-
else 1
|
224 |
-
):
|
225 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
226 |
-
if gradient_accumulation_steps is not None:
|
227 |
-
batch = jax.tree_map(
|
228 |
-
lambda x: x.reshape(
|
229 |
-
(-1, per_device_batch_size) + x.shape[1:]
|
230 |
-
),
|
231 |
-
batch,
|
232 |
-
)
|
233 |
yield batch
|
234 |
batch = {k: [] for k in keys}
|
235 |
first_loop = False
|
@@ -242,15 +212,11 @@ class Dataset:
|
|
242 |
raise ValueError(f'split must be "train" or "eval", got {split}')
|
243 |
|
244 |
if self.streaming:
|
245 |
-
return _dataloader_datasets_streaming(
|
246 |
-
ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
|
247 |
-
)
|
248 |
else:
|
249 |
if split == "train":
|
250 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
251 |
-
return _dataloader_datasets_non_streaming(
|
252 |
-
ds, per_device_batch_size, gradient_accumulation_steps, input_rng
|
253 |
-
)
|
254 |
|
255 |
@property
|
256 |
def length(self):
|
|
|
152 |
),
|
153 |
)
|
154 |
|
155 |
+
def dataloader(self, split, batch_size, epoch=None):
|
|
|
|
|
|
|
|
|
156 |
def _dataloader_datasets_non_streaming(
|
157 |
dataset: Dataset,
|
|
|
|
|
158 |
rng: jax.random.PRNGKey = None,
|
159 |
):
|
160 |
"""
|
161 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
162 |
Shuffle batches if rng is set.
|
163 |
"""
|
|
|
|
|
|
|
164 |
steps_per_epoch = len(dataset) // batch_size
|
165 |
|
166 |
if rng is not None:
|
|
|
176 |
for idx in batch_idx:
|
177 |
batch = dataset[idx]
|
178 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
179 |
yield batch
|
180 |
|
181 |
def _dataloader_datasets_streaming(
|
182 |
dataset: Dataset,
|
|
|
|
|
|
|
183 |
epoch: int,
|
184 |
):
|
185 |
keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
|
|
|
191 |
# For validation data we put the entire set on each host as we could lose
|
192 |
# too many samples on pods
|
193 |
if epoch is not None:
|
194 |
+
assert split == "train"
|
195 |
+
# reshuffle training data at each epoch
|
196 |
dataset.set_epoch(epoch)
|
197 |
epoch += 1
|
198 |
for item in dataset:
|
199 |
for k, v in item.items():
|
200 |
batch[k].append(v)
|
201 |
+
if len(batch[keys[0]]) == batch_size:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
yield batch
|
204 |
batch = {k: [] for k in keys}
|
205 |
first_loop = False
|
|
|
212 |
raise ValueError(f'split must be "train" or "eval", got {split}')
|
213 |
|
214 |
if self.streaming:
|
215 |
+
return _dataloader_datasets_streaming(ds, epoch)
|
|
|
|
|
216 |
else:
|
217 |
if split == "train":
|
218 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
219 |
+
return _dataloader_datasets_non_streaming(ds, input_rng)
|
|
|
|
|
220 |
|
221 |
@property
|
222 |
def length(self):
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -312,7 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
|
|
312 |
seed: int = 0,
|
313 |
dtype: jnp.dtype = jnp.float32,
|
314 |
abstract_init: bool = False,
|
315 |
-
load_on_cpu: bool =
|
316 |
**kwargs,
|
317 |
):
|
318 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
|
|
312 |
seed: int = 0,
|
313 |
dtype: jnp.dtype = jnp.float32,
|
314 |
abstract_init: bool = False,
|
315 |
+
load_on_cpu: bool = False,
|
316 |
**kwargs,
|
317 |
):
|
318 |
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
tools/train/distributed_shampoo.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
-
"""File copied from https://github.com/google-research/google-research/edit/master/scalable_shampoo/optax/distributed_shampoo.py"""
|
2 |
-
|
3 |
# coding=utf-8
|
4 |
-
# Copyright
|
5 |
#
|
6 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
# you may not use this file except in compliance with the License.
|
@@ -147,6 +145,12 @@ class QuantizedValue:
|
|
147 |
return val
|
148 |
|
149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
# Per parameter optimizer state used in data-parallel training.
|
151 |
class ParameterStats(NamedTuple):
|
152 |
"""State associated to each parameter of the model being trained."""
|
@@ -156,6 +160,7 @@ class ParameterStats(NamedTuple):
|
|
156 |
preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
|
157 |
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
158 |
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
|
|
159 |
|
160 |
|
161 |
# For training extremely large model; We keep a global state with a concatenated
|
@@ -166,6 +171,7 @@ class ParameterStats(NamedTuple):
|
|
166 |
class GlobalShardedParameterStats:
|
167 |
statistics: chex.Array # Statistics
|
168 |
preconditioners: chex.Array # Preconditioners
|
|
|
169 |
|
170 |
|
171 |
# These are per-parameter local states; All statistics here mirror the parameter
|
@@ -177,12 +183,34 @@ class LocalShardedParameterStats:
|
|
177 |
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
|
178 |
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
179 |
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
|
|
180 |
index_start: np.int32 = struct.field(
|
181 |
pytree_node=False
|
182 |
) # Index into global statistics array
|
183 |
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
|
184 |
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
class ShardedShampooStats(NamedTuple):
|
187 |
"""Shampoo state in sharded mode."""
|
188 |
|
@@ -195,6 +223,12 @@ class ShampooState(NamedTuple):
|
|
195 |
stats: Any
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
class GraftingType(enum.IntEnum):
|
199 |
SGD = 1
|
200 |
ADAGRAD = 2
|
@@ -292,6 +326,8 @@ def matrix_inverse_pth_root(
|
|
292 |
matrix^(-1/p)
|
293 |
"""
|
294 |
|
|
|
|
|
295 |
# We use float32 for the matrix inverse pth root.
|
296 |
# Switch to f64 if you have hardware that supports it.
|
297 |
matrix_size = matrix.shape[0]
|
@@ -615,6 +651,7 @@ def _convert_to_parameter_stats(global_stats, local_stat):
|
|
615 |
new_preconditioners,
|
616 |
local_stat.diagonal_momentum,
|
617 |
local_stat.momentum,
|
|
|
618 |
)
|
619 |
|
620 |
|
@@ -624,11 +661,40 @@ def _convert_from_parameter_stats(parameter_stats, local_stats):
|
|
624 |
parameter_stats.diagonal_statistics,
|
625 |
parameter_stats.diagonal_momentum,
|
626 |
parameter_stats.momentum,
|
|
|
627 |
local_stats.index_start,
|
628 |
local_stats.sizes,
|
629 |
)
|
630 |
|
631 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
def batch(x, num_devices):
|
633 |
"""Batch `x` so that so that leading axis is num_devices."""
|
634 |
n = len(x)
|
@@ -670,7 +736,8 @@ def distributed_shampoo(
|
|
670 |
batch_axis_name=None,
|
671 |
### Only set following 3 params in pjit/spmd mode.
|
672 |
### WARNING: Experimental
|
673 |
-
|
|
|
674 |
num_devices_for_pjit=None,
|
675 |
shard_optimizer_states=False,
|
676 |
###
|
@@ -730,7 +797,8 @@ def distributed_shampoo(
|
|
730 |
exponent_override: Override the exponent used in matrix inverse.
|
731 |
batch_axis_name: labeled axis over pmap for data-parallel training the
|
732 |
optimizer used for.
|
733 |
-
|
|
|
734 |
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
|
735 |
shard_optimizer_states: Shard optimizer states to save memory in model
|
736 |
parallel training.
|
@@ -830,6 +898,11 @@ def distributed_shampoo(
|
|
830 |
)
|
831 |
|
832 |
def sharded_init_fn(params):
|
|
|
|
|
|
|
|
|
|
|
833 |
params_flat, treedef = jax.tree_flatten(params)
|
834 |
# Find max size to pad to.
|
835 |
max_size = 0
|
@@ -845,6 +918,7 @@ def distributed_shampoo(
|
|
845 |
padded_statistics = []
|
846 |
padded_preconditioners = []
|
847 |
local_stats_flat = []
|
|
|
848 |
for param in params_flat:
|
849 |
preconditioner = Preconditioner(
|
850 |
param, block_size, best_effort_shape_interpretation
|
@@ -862,6 +936,12 @@ def distributed_shampoo(
|
|
862 |
preconditioners = [jnp.eye(max_size) for s in shapes]
|
863 |
padded_statistics.extend(statistics)
|
864 |
padded_preconditioners.extend(preconditioners)
|
|
|
|
|
|
|
|
|
|
|
|
|
865 |
|
866 |
diagonal_statistics = []
|
867 |
if graft_type != GraftingType.SGD:
|
@@ -871,6 +951,7 @@ def distributed_shampoo(
|
|
871 |
_quantize_diagonal_statistics(diagonal_statistics),
|
872 |
_quantize_momentum(jnp.zeros_like(param)),
|
873 |
_quantize_momentum(jnp.zeros_like(param)),
|
|
|
874 |
index_start,
|
875 |
sizes,
|
876 |
)
|
@@ -888,14 +969,238 @@ def distributed_shampoo(
|
|
888 |
padded_preconditioners.extend(
|
889 |
[jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
|
890 |
)
|
|
|
891 |
global_stats = GlobalShardedParameterStats(
|
892 |
-
jnp.stack(padded_statistics),
|
|
|
|
|
893 |
)
|
894 |
return ShampooState(
|
895 |
count=jnp.zeros([], jnp.int32),
|
896 |
stats=ShardedShampooStats(global_stats, local_stats),
|
897 |
)
|
898 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
899 |
def sharded_update_fn(grads, state, params):
|
900 |
"""Transform the input gradient and update all statistics in sharded mode.
|
901 |
|
@@ -923,20 +1228,6 @@ def distributed_shampoo(
|
|
923 |
params_flat,
|
924 |
)
|
925 |
|
926 |
-
exponents = []
|
927 |
-
for stat, param in zip(new_stats_flat, params_flat):
|
928 |
-
num_statistics = len(stat.statistics)
|
929 |
-
if num_statistics > 0:
|
930 |
-
preconditioner = Preconditioner(
|
931 |
-
param, block_size, best_effort_shape_interpretation
|
932 |
-
)
|
933 |
-
exponent = (
|
934 |
-
preconditioner.exponent_for_preconditioner()
|
935 |
-
if exponent_override == 0
|
936 |
-
else exponent_override
|
937 |
-
)
|
938 |
-
exponents.extend([exponent] * num_statistics)
|
939 |
-
|
940 |
outputs = jax.tree_multimap(
|
941 |
lambda g, s, p: _transform_grad(g, s, p, state.count),
|
942 |
grads_flat,
|
@@ -951,7 +1242,6 @@ def distributed_shampoo(
|
|
951 |
_convert_from_parameter_stats(new_stat, local_stat)
|
952 |
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
|
953 |
]
|
954 |
-
new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
|
955 |
|
956 |
max_size = global_stats.statistics.shape[1]
|
957 |
new_padded_statistics = []
|
@@ -974,22 +1264,16 @@ def distributed_shampoo(
|
|
974 |
for _ in range(to_pad)
|
975 |
]
|
976 |
)
|
977 |
-
exponents.extend([1 for _ in range(to_pad)])
|
978 |
new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
|
979 |
-
|
980 |
-
|
981 |
-
|
982 |
-
mi_pth_root = functools.partial(
|
983 |
-
matrix_inverse_pth_root,
|
984 |
-
ridge_epsilon=matrix_epsilon,
|
985 |
-
precision=precision,
|
986 |
-
)
|
987 |
-
preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
|
988 |
-
return preconditioners, errors
|
989 |
|
990 |
def _internal_inverse_pth_root_all():
|
991 |
-
preconditioners, errors =
|
992 |
-
new_stacked_padded_statistics,
|
|
|
|
|
993 |
)
|
994 |
return preconditioners, errors
|
995 |
|
@@ -1000,13 +1284,18 @@ def distributed_shampoo(
|
|
1000 |
# shaped tensors. Note statistics will be ignored as we are passing in
|
1001 |
# a large init value for error.
|
1002 |
preconditioners_init = new_stacked_padded_statistics
|
1003 |
-
|
|
|
1004 |
init_state = [preconditioners_init, errors_init]
|
1005 |
perform_step = state.count % preconditioning_compute_steps == 0
|
1006 |
new_preconditioners, errors = efficient_cond(
|
1007 |
perform_step, _internal_inverse_pth_root_all, init_state
|
1008 |
)
|
1009 |
|
|
|
|
|
|
|
|
|
1010 |
errors = errors.reshape((-1, 1, 1))
|
1011 |
predicate = jnp.logical_or(
|
1012 |
jnp.isnan(errors), errors >= inverse_failure_threshold
|
@@ -1017,7 +1306,9 @@ def distributed_shampoo(
|
|
1017 |
+ (1.0 - predicate) * new_preconditioners
|
1018 |
)
|
1019 |
new_global_stats = GlobalShardedParameterStats(
|
1020 |
-
new_stacked_padded_statistics,
|
|
|
|
|
1021 |
)
|
1022 |
new_shampoo_state = ShampooState(
|
1023 |
count=state.count + 1,
|
@@ -1048,6 +1339,7 @@ def distributed_shampoo(
|
|
1048 |
_maybe_quantize_preconditioners(preconditioners),
|
1049 |
_quantize_momentum(jnp.zeros_like(param)),
|
1050 |
_quantize_momentum(jnp.zeros_like(param)),
|
|
|
1051 |
)
|
1052 |
|
1053 |
return ShampooState(
|
@@ -1092,6 +1384,7 @@ def distributed_shampoo(
|
|
1092 |
state.preconditioners,
|
1093 |
state.diagonal_momentum,
|
1094 |
state.momentum,
|
|
|
1095 |
)
|
1096 |
|
1097 |
def _matrix_inverse_pth_root_vmap(xs, ps):
|
@@ -1115,33 +1408,27 @@ def distributed_shampoo(
|
|
1115 |
|
1116 |
return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
|
1117 |
|
1118 |
-
def _matrix_inverse_pth_root_pjit(xs, ps):
|
1119 |
-
mesh_axis_names_tuple = tuple(mesh_axis_names)
|
1120 |
# Partition the concatenated statistics matrix across all cores.
|
1121 |
-
|
1122 |
-
|
1123 |
-
|
1124 |
-
|
1125 |
-
|
1126 |
-
),
|
1127 |
-
)(xs, ps)
|
1128 |
# Run matrix inverse pth root on each shard.
|
1129 |
partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
|
1130 |
partitioned_xs, partitioned_ps
|
1131 |
)
|
|
|
|
|
|
|
|
|
|
|
1132 |
# Recombine the outputs at each core.
|
1133 |
-
preconditioners
|
1134 |
-
|
1135 |
-
|
1136 |
-
|
1137 |
-
mesh_axis_names_tuple,
|
1138 |
-
),
|
1139 |
-
pjit.PartitionSpec(
|
1140 |
-
mesh_axis_names_tuple,
|
1141 |
-
),
|
1142 |
-
),
|
1143 |
-
out_axis_resources=(None, None),
|
1144 |
-
)(partitioned_preconditioners, partitioned_errors)
|
1145 |
return preconditioners, errors
|
1146 |
|
1147 |
def _pmap_compute_preconditioners(
|
@@ -1223,31 +1510,54 @@ def distributed_shampoo(
|
|
1223 |
)
|
1224 |
|
1225 |
new_preconditioners_flat = []
|
|
|
1226 |
for p, shape, prev_p, error in zip(
|
1227 |
preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
|
1228 |
):
|
1229 |
new_preconditioners_flat.append(
|
1230 |
_select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
|
1231 |
)
|
|
|
1232 |
|
1233 |
assert len(states) == len(num_statistics_per_state)
|
1234 |
assert len(new_preconditioners_flat) == num_statistics
|
|
|
1235 |
|
1236 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1237 |
preconditioners_for_states = []
|
1238 |
idx = 0
|
|
|
1239 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1240 |
if num_statistics == 0:
|
1241 |
preconditioners_for_states.append([])
|
|
|
1242 |
else:
|
1243 |
preconditioners_for_state = new_preconditioners_flat[
|
1244 |
idx : idx + num_statistics
|
1245 |
]
|
1246 |
assert len(state.statistics) == len(preconditioners_for_state)
|
1247 |
preconditioners_for_states.append(preconditioners_for_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1248 |
idx += num_statistics
|
1249 |
new_states = []
|
1250 |
-
for state, new_preconditioners in zip(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1251 |
new_states.append(
|
1252 |
ParameterStats(
|
1253 |
state.diagonal_statistics,
|
@@ -1255,6 +1565,7 @@ def distributed_shampoo(
|
|
1255 |
new_preconditioners,
|
1256 |
state.diagonal_momentum,
|
1257 |
state.momentum,
|
|
|
1258 |
)
|
1259 |
)
|
1260 |
|
@@ -1413,6 +1724,7 @@ def distributed_shampoo(
|
|
1413 |
new_quantized_preconditioners_flat = []
|
1414 |
new_quantized_diagonals_flat = []
|
1415 |
new_quantized_bucket_sizes_flat = []
|
|
|
1416 |
for p, d, b, shape, prev_p, error in zip(
|
1417 |
quantized_preconditioners_flat,
|
1418 |
quantized_diagonals_flat,
|
@@ -1432,6 +1744,7 @@ def distributed_shampoo(
|
|
1432 |
new_quantized_bucket_sizes_flat.append(
|
1433 |
_select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
|
1434 |
)
|
|
|
1435 |
|
1436 |
assert len(states) == len(num_statistics_per_state)
|
1437 |
assert len(new_quantized_preconditioners_flat) == num_statistics
|
@@ -1440,10 +1753,12 @@ def distributed_shampoo(
|
|
1440 |
|
1441 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1442 |
preconditioners_for_states = []
|
|
|
1443 |
idx = 0
|
1444 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1445 |
if num_statistics == 0:
|
1446 |
preconditioners_for_states.append([])
|
|
|
1447 |
else:
|
1448 |
quantized_preconditioners_for_state = (
|
1449 |
new_quantized_preconditioners_flat[idx : idx + num_statistics]
|
@@ -1454,10 +1769,14 @@ def distributed_shampoo(
|
|
1454 |
quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
|
1455 |
idx : idx + num_statistics
|
1456 |
]
|
|
|
|
|
|
|
1457 |
|
1458 |
assert len(state.statistics) == len(quantized_preconditioners_for_state)
|
1459 |
assert len(state.statistics) == len(quantized_diagonals_for_state)
|
1460 |
assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
|
|
|
1461 |
|
1462 |
quantized_preconditioners = []
|
1463 |
for qv, qd, qb in zip(
|
@@ -1469,9 +1788,21 @@ def distributed_shampoo(
|
|
1469 |
QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
|
1470 |
)
|
1471 |
preconditioners_for_states.append(quantized_preconditioners)
|
|
|
1472 |
idx += num_statistics
|
1473 |
new_states = []
|
1474 |
-
for state, new_preconditioners in zip(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1475 |
new_states.append(
|
1476 |
ParameterStats(
|
1477 |
state.diagonal_statistics,
|
@@ -1479,6 +1810,7 @@ def distributed_shampoo(
|
|
1479 |
new_preconditioners,
|
1480 |
state.diagonal_momentum,
|
1481 |
state.momentum,
|
|
|
1482 |
)
|
1483 |
)
|
1484 |
|
@@ -1560,31 +1892,53 @@ def distributed_shampoo(
|
|
1560 |
)
|
1561 |
|
1562 |
new_preconditioners_flat = []
|
|
|
1563 |
for p, shape, prev_p, error in zip(
|
1564 |
preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
|
1565 |
):
|
1566 |
new_preconditioners_flat.append(
|
1567 |
_select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
|
1568 |
)
|
|
|
1569 |
|
1570 |
assert len(states) == len(num_statistics_per_state)
|
1571 |
assert len(new_preconditioners_flat) == num_statistics
|
1572 |
|
1573 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1574 |
preconditioners_for_states = []
|
|
|
1575 |
idx = 0
|
1576 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1577 |
if num_statistics == 0:
|
1578 |
preconditioners_for_states.append([])
|
|
|
1579 |
else:
|
1580 |
preconditioners_for_state = new_preconditioners_flat[
|
1581 |
idx : idx + num_statistics
|
1582 |
]
|
1583 |
assert len(state.statistics) == len(preconditioners_for_state)
|
1584 |
preconditioners_for_states.append(preconditioners_for_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
1585 |
idx += num_statistics
|
|
|
1586 |
new_states = []
|
1587 |
-
for state, new_preconditioners in zip(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1588 |
new_states.append(
|
1589 |
ParameterStats(
|
1590 |
state.diagonal_statistics,
|
@@ -1592,6 +1946,7 @@ def distributed_shampoo(
|
|
1592 |
new_preconditioners,
|
1593 |
state.diagonal_momentum,
|
1594 |
state.momentum,
|
|
|
1595 |
)
|
1596 |
)
|
1597 |
|
@@ -1778,7 +2133,9 @@ def distributed_shampoo(
|
|
1778 |
state.preconditioners,
|
1779 |
_quantize_momentum(grafting_update_with_wd_momentum),
|
1780 |
_quantize_momentum(shampoo_update_with_wd_momentum),
|
|
|
1781 |
)
|
|
|
1782 |
return transformed_update, param_stats
|
1783 |
|
1784 |
def update_fn(grads, state, params):
|
@@ -1821,6 +2178,15 @@ def distributed_shampoo(
|
|
1821 |
return updates, new_state
|
1822 |
|
1823 |
if shard_optimizer_states:
|
1824 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1825 |
else:
|
1826 |
return optax.GradientTransformation(init_fn, update_fn)
|
|
|
|
|
|
|
1 |
# coding=utf-8
|
2 |
+
# Copyright 2022 The Google Research Authors.
|
3 |
#
|
4 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
# you may not use this file except in compliance with the License.
|
|
|
145 |
return val
|
146 |
|
147 |
|
148 |
+
@struct.dataclass
|
149 |
+
class TrainingMetrics:
|
150 |
+
inverse_pth_root_errors: chex.Array # Error for inverse-pth roots.
|
151 |
+
# TODO(rohananil): Add more important metrics to track during training.
|
152 |
+
|
153 |
+
|
154 |
# Per parameter optimizer state used in data-parallel training.
|
155 |
class ParameterStats(NamedTuple):
|
156 |
"""State associated to each parameter of the model being trained."""
|
|
|
160 |
preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
|
161 |
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
162 |
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
163 |
+
training_metrics: TrainingMetrics # Metrics (optional for training).
|
164 |
|
165 |
|
166 |
# For training extremely large model; We keep a global state with a concatenated
|
|
|
171 |
class GlobalShardedParameterStats:
|
172 |
statistics: chex.Array # Statistics
|
173 |
preconditioners: chex.Array # Preconditioners
|
174 |
+
exponents: chex.Array # exponents
|
175 |
|
176 |
|
177 |
# These are per-parameter local states; All statistics here mirror the parameter
|
|
|
183 |
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
|
184 |
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
|
185 |
momentum: QuantizedValue # Momentum for the shampoo preconditioner
|
186 |
+
training_metrics: TrainingMetrics # Metrics (optional for training).
|
187 |
index_start: np.int32 = struct.field(
|
188 |
pytree_node=False
|
189 |
) # Index into global statistics array
|
190 |
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
|
191 |
|
192 |
|
193 |
+
def init_training_metrics(num_statistics):
|
194 |
+
if num_statistics:
|
195 |
+
return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
|
196 |
+
else:
|
197 |
+
return TrainingMetrics([])
|
198 |
+
|
199 |
+
|
200 |
+
def init_training_metrics_shapes(num_statistics):
|
201 |
+
if num_statistics:
|
202 |
+
return TrainingMetrics([[num_statistics], jnp.float32])
|
203 |
+
else:
|
204 |
+
return TrainingMetrics([None, jnp.float32])
|
205 |
+
|
206 |
+
|
207 |
+
def init_training_metrics_pspec(num_statistics):
|
208 |
+
if num_statistics:
|
209 |
+
return TrainingMetrics(pjit.PartitionSpec())
|
210 |
+
else:
|
211 |
+
return TrainingMetrics(None)
|
212 |
+
|
213 |
+
|
214 |
class ShardedShampooStats(NamedTuple):
|
215 |
"""Shampoo state in sharded mode."""
|
216 |
|
|
|
223 |
stats: Any
|
224 |
|
225 |
|
226 |
+
class InitFnState(NamedTuple):
|
227 |
+
init_fn: Any
|
228 |
+
pspec_fn: Any
|
229 |
+
shape_and_dtype_fn: Any
|
230 |
+
|
231 |
+
|
232 |
class GraftingType(enum.IntEnum):
|
233 |
SGD = 1
|
234 |
ADAGRAD = 2
|
|
|
326 |
matrix^(-1/p)
|
327 |
"""
|
328 |
|
329 |
+
assert matrix.shape[0] == matrix.shape[1]
|
330 |
+
|
331 |
# We use float32 for the matrix inverse pth root.
|
332 |
# Switch to f64 if you have hardware that supports it.
|
333 |
matrix_size = matrix.shape[0]
|
|
|
651 |
new_preconditioners,
|
652 |
local_stat.diagonal_momentum,
|
653 |
local_stat.momentum,
|
654 |
+
local_stat.training_metrics,
|
655 |
)
|
656 |
|
657 |
|
|
|
661 |
parameter_stats.diagonal_statistics,
|
662 |
parameter_stats.diagonal_momentum,
|
663 |
parameter_stats.momentum,
|
664 |
+
parameter_stats.training_metrics,
|
665 |
local_stats.index_start,
|
666 |
local_stats.sizes,
|
667 |
)
|
668 |
|
669 |
|
670 |
+
def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
|
671 |
+
"""Adds errors back into local statistics."""
|
672 |
+
new_local_stats = []
|
673 |
+
for local_stat in local_stats:
|
674 |
+
index_start = int(local_stat.index_start)
|
675 |
+
index_end = int(len(local_stat.sizes)) + index_start
|
676 |
+
per_stat_error = errors[index_start:index_end]
|
677 |
+
if local_stat.sizes:
|
678 |
+
per_stat_error = jnp.where(
|
679 |
+
jnp.logical_and(
|
680 |
+
per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
|
681 |
+
),
|
682 |
+
per_stat_error,
|
683 |
+
local_stat.training_metrics.inverse_pth_root_errors,
|
684 |
+
)
|
685 |
+
new_local_stats.append(
|
686 |
+
LocalShardedParameterStats(
|
687 |
+
local_stat.diagonal_statistics,
|
688 |
+
local_stat.diagonal_momentum,
|
689 |
+
local_stat.momentum,
|
690 |
+
TrainingMetrics(per_stat_error),
|
691 |
+
local_stat.index_start,
|
692 |
+
local_stat.sizes,
|
693 |
+
)
|
694 |
+
)
|
695 |
+
return new_local_stats
|
696 |
+
|
697 |
+
|
698 |
def batch(x, num_devices):
|
699 |
"""Batch `x` so that so that leading axis is num_devices."""
|
700 |
n = len(x)
|
|
|
736 |
batch_axis_name=None,
|
737 |
### Only set following 3 params in pjit/spmd mode.
|
738 |
### WARNING: Experimental
|
739 |
+
statistics_partition_spec=None,
|
740 |
+
preconditioner_partition_spec=None,
|
741 |
num_devices_for_pjit=None,
|
742 |
shard_optimizer_states=False,
|
743 |
###
|
|
|
797 |
exponent_override: Override the exponent used in matrix inverse.
|
798 |
batch_axis_name: labeled axis over pmap for data-parallel training the
|
799 |
optimizer used for.
|
800 |
+
statistics_partition_spec: PartitionSpec to be used in sharded mode.
|
801 |
+
preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
|
802 |
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
|
803 |
shard_optimizer_states: Shard optimizer states to save memory in model
|
804 |
parallel training.
|
|
|
898 |
)
|
899 |
|
900 |
def sharded_init_fn(params):
|
901 |
+
"""Returns optimizer state (for PJIT mode).
|
902 |
+
|
903 |
+
Args:
|
904 |
+
params: the parameters that should be updated.
|
905 |
+
"""
|
906 |
params_flat, treedef = jax.tree_flatten(params)
|
907 |
# Find max size to pad to.
|
908 |
max_size = 0
|
|
|
918 |
padded_statistics = []
|
919 |
padded_preconditioners = []
|
920 |
local_stats_flat = []
|
921 |
+
exponents = []
|
922 |
for param in params_flat:
|
923 |
preconditioner = Preconditioner(
|
924 |
param, block_size, best_effort_shape_interpretation
|
|
|
936 |
preconditioners = [jnp.eye(max_size) for s in shapes]
|
937 |
padded_statistics.extend(statistics)
|
938 |
padded_preconditioners.extend(preconditioners)
|
939 |
+
exponent = (
|
940 |
+
preconditioner.exponent_for_preconditioner()
|
941 |
+
if exponent_override == 0
|
942 |
+
else exponent_override
|
943 |
+
)
|
944 |
+
exponents.extend([exponent] * len(shapes))
|
945 |
|
946 |
diagonal_statistics = []
|
947 |
if graft_type != GraftingType.SGD:
|
|
|
951 |
_quantize_diagonal_statistics(diagonal_statistics),
|
952 |
_quantize_momentum(jnp.zeros_like(param)),
|
953 |
_quantize_momentum(jnp.zeros_like(param)),
|
954 |
+
init_training_metrics(len(sizes)),
|
955 |
index_start,
|
956 |
sizes,
|
957 |
)
|
|
|
969 |
padded_preconditioners.extend(
|
970 |
[jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
|
971 |
)
|
972 |
+
exponents.extend([1 for _ in range(to_pad)])
|
973 |
global_stats = GlobalShardedParameterStats(
|
974 |
+
jnp.stack(padded_statistics),
|
975 |
+
jnp.stack(padded_preconditioners),
|
976 |
+
jnp.stack(exponents),
|
977 |
)
|
978 |
return ShampooState(
|
979 |
count=jnp.zeros([], jnp.int32),
|
980 |
stats=ShardedShampooStats(global_stats, local_stats),
|
981 |
)
|
982 |
|
983 |
+
def _max_statistics_size_from_params(params):
|
984 |
+
max_size = 0
|
985 |
+
for param in params:
|
986 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
987 |
+
preconditioner = Preconditioner(
|
988 |
+
param_clone, block_size, best_effort_shape_interpretation
|
989 |
+
)
|
990 |
+
if not _skip_preconditioning(param):
|
991 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
992 |
+
sizes = [s[0] for s in shapes]
|
993 |
+
max_size = max(max(sizes), max_size)
|
994 |
+
return max_size
|
995 |
+
|
996 |
+
def _remove_leading_sharding_annotation(pspec):
|
997 |
+
"""Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
|
998 |
+
# None and PSpec(None) are valid PSpecs.
|
999 |
+
if pspec and len(pspec) > 1:
|
1000 |
+
return pjit.PartitionSpec(*pspec[1:])
|
1001 |
+
else:
|
1002 |
+
return None
|
1003 |
+
|
1004 |
+
def sharded_init_partition_spec_fn(
|
1005 |
+
params, params_partition_spec, partition_spec_for_statistics
|
1006 |
+
):
|
1007 |
+
"""Returns a parallel state tree with PartitionSpec associated with state.
|
1008 |
+
|
1009 |
+
|
1010 |
+
Args:
|
1011 |
+
params: A pytree with params.
|
1012 |
+
params_partition_spec: A pytree with PartitionSpec for params.
|
1013 |
+
partition_spec_for_statistics: PartitionSpec for the statistics.
|
1014 |
+
"""
|
1015 |
+
# Parallel lists of spec, and params.
|
1016 |
+
param_pspec_flat, _ = jax.tree_flatten(
|
1017 |
+
params_partition_spec, is_leaf=lambda x: x is None
|
1018 |
+
)
|
1019 |
+
params_flat, treedef = jax.tree_flatten(params)
|
1020 |
+
assert param_pspec_flat
|
1021 |
+
assert params_flat
|
1022 |
+
# Step is replicated across cores.
|
1023 |
+
# None means cores.
|
1024 |
+
local_stats_flat = []
|
1025 |
+
num_statistics = 0
|
1026 |
+
for param, param_pspec in zip(params_flat, param_pspec_flat):
|
1027 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
1028 |
+
preconditioner = Preconditioner(
|
1029 |
+
param_clone, block_size, best_effort_shape_interpretation
|
1030 |
+
)
|
1031 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1032 |
+
sizes = []
|
1033 |
+
|
1034 |
+
index_start = num_statistics
|
1035 |
+
if not _skip_preconditioning(param):
|
1036 |
+
sizes = [s[0] for s in shapes]
|
1037 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1038 |
+
num_statistics += len(shapes)
|
1039 |
+
|
1040 |
+
diagonal_statistics_pspec = []
|
1041 |
+
diagonal_statistics_scale_pspec = []
|
1042 |
+
if graft_type != GraftingType.SGD:
|
1043 |
+
# Identically shaped param.
|
1044 |
+
diagonal_statistics_pspec = param_pspec
|
1045 |
+
if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
|
1046 |
+
diagonal_statistics_scale_pspec = (
|
1047 |
+
_remove_leading_sharding_annotation(param_pspec)
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
m1_pspec = param_pspec
|
1051 |
+
m2_pspec = param_pspec
|
1052 |
+
|
1053 |
+
m1_scale_pspec = []
|
1054 |
+
m2_scale_pspec = []
|
1055 |
+
|
1056 |
+
if quantized_dtype_for_momentum_buffers() != jnp.float32:
|
1057 |
+
m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
|
1058 |
+
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
|
1059 |
+
|
1060 |
+
local_stats_flat.append(
|
1061 |
+
LocalShardedParameterStats(
|
1062 |
+
QuantizedValue(
|
1063 |
+
diagonal_statistics_pspec,
|
1064 |
+
[],
|
1065 |
+
diagonal_statistics_scale_pspec,
|
1066 |
+
quantized_dtype_for_diagonal_statistics_buffers(),
|
1067 |
+
False,
|
1068 |
+
list(param.shape),
|
1069 |
+
),
|
1070 |
+
QuantizedValue(
|
1071 |
+
m1_pspec,
|
1072 |
+
[],
|
1073 |
+
m1_scale_pspec,
|
1074 |
+
quantized_dtype_for_momentum_buffers(),
|
1075 |
+
False,
|
1076 |
+
list(param.shape),
|
1077 |
+
),
|
1078 |
+
QuantizedValue(
|
1079 |
+
m2_pspec,
|
1080 |
+
[],
|
1081 |
+
m2_scale_pspec,
|
1082 |
+
quantized_dtype_for_momentum_buffers(),
|
1083 |
+
False,
|
1084 |
+
list(param.shape),
|
1085 |
+
),
|
1086 |
+
init_training_metrics_pspec(len(sizes)),
|
1087 |
+
index_start,
|
1088 |
+
sizes,
|
1089 |
+
)
|
1090 |
+
)
|
1091 |
+
|
1092 |
+
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
1093 |
+
global_stats = GlobalShardedParameterStats(
|
1094 |
+
partition_spec_for_statistics,
|
1095 |
+
partition_spec_for_statistics,
|
1096 |
+
pjit.PartitionSpec(),
|
1097 |
+
)
|
1098 |
+
count_pspec = pjit.PartitionSpec()
|
1099 |
+
return ShampooState(
|
1100 |
+
count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
def sharded_init_shape_and_dtype_fn(params):
|
1104 |
+
"""Returns a parallel state tree with shape, dtype associated with state.
|
1105 |
+
|
1106 |
+
|
1107 |
+
Args:
|
1108 |
+
params: A pytree with params.
|
1109 |
+
"""
|
1110 |
+
# Parallel lists of spec, and params.
|
1111 |
+
params_flat, treedef = jax.tree_flatten(params)
|
1112 |
+
assert params_flat
|
1113 |
+
# Step is replicated across cores.
|
1114 |
+
# None means cores.
|
1115 |
+
local_stats_flat = []
|
1116 |
+
num_statistics = 0
|
1117 |
+
for param in params_flat:
|
1118 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
1119 |
+
preconditioner = Preconditioner(
|
1120 |
+
param_clone, block_size, best_effort_shape_interpretation
|
1121 |
+
)
|
1122 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1123 |
+
sizes = []
|
1124 |
+
|
1125 |
+
index_start = num_statistics
|
1126 |
+
if not _skip_preconditioning(param):
|
1127 |
+
sizes = [s[0] for s in shapes]
|
1128 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1129 |
+
num_statistics += len(shapes)
|
1130 |
+
|
1131 |
+
diagonal_statistics_shape_and_dtype = []
|
1132 |
+
diagonal_statistics_scale_shape_and_dtype = []
|
1133 |
+
if graft_type != GraftingType.SGD:
|
1134 |
+
diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
|
1135 |
+
qdtype = quantized_dtype_for_diagonal_statistics_buffers()
|
1136 |
+
if qdtype != jnp.float32:
|
1137 |
+
diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
|
1138 |
+
diagonal_statistics_scale_shape_and_dtype = [
|
1139 |
+
list(param.shape)[1:],
|
1140 |
+
param.dtype,
|
1141 |
+
]
|
1142 |
+
|
1143 |
+
m1_shape_and_dtype = [list(param.shape), param.dtype]
|
1144 |
+
m2_shape_and_dtype = [list(param.shape), param.dtype]
|
1145 |
+
|
1146 |
+
m1_scale_shape_and_dtype = []
|
1147 |
+
m2_scale_shape_and_dtype = []
|
1148 |
+
|
1149 |
+
qdtype = quantized_dtype_for_momentum_buffers()
|
1150 |
+
if qdtype != jnp.float32:
|
1151 |
+
m1_shape_and_dtype = [list(param.shape), qdtype]
|
1152 |
+
m2_shape_and_dtype = [list(param.shape), qdtype]
|
1153 |
+
|
1154 |
+
m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
|
1155 |
+
m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
|
1156 |
+
|
1157 |
+
local_stats_flat.append(
|
1158 |
+
LocalShardedParameterStats(
|
1159 |
+
QuantizedValue(
|
1160 |
+
diagonal_statistics_shape_and_dtype,
|
1161 |
+
[],
|
1162 |
+
diagonal_statistics_scale_shape_and_dtype,
|
1163 |
+
quantized_dtype_for_diagonal_statistics_buffers(),
|
1164 |
+
False,
|
1165 |
+
list(param.shape),
|
1166 |
+
),
|
1167 |
+
QuantizedValue(
|
1168 |
+
m1_shape_and_dtype,
|
1169 |
+
[],
|
1170 |
+
m1_scale_shape_and_dtype,
|
1171 |
+
quantized_dtype_for_momentum_buffers(),
|
1172 |
+
False,
|
1173 |
+
list(param.shape),
|
1174 |
+
),
|
1175 |
+
QuantizedValue(
|
1176 |
+
m2_shape_and_dtype,
|
1177 |
+
[],
|
1178 |
+
m2_scale_shape_and_dtype,
|
1179 |
+
quantized_dtype_for_momentum_buffers(),
|
1180 |
+
False,
|
1181 |
+
list(param.shape),
|
1182 |
+
),
|
1183 |
+
init_training_metrics_shapes(len(sizes)),
|
1184 |
+
index_start,
|
1185 |
+
sizes,
|
1186 |
+
)
|
1187 |
+
)
|
1188 |
+
|
1189 |
+
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
1190 |
+
max_statistics_size = _max_statistics_size_from_params(params_flat)
|
1191 |
+
to_pad = -num_statistics % num_devices_for_pjit
|
1192 |
+
num_statistics += to_pad
|
1193 |
+
statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
|
1194 |
+
global_stats = GlobalShardedParameterStats(
|
1195 |
+
[statistics_shape, jnp.float32],
|
1196 |
+
[statistics_shape, jnp.float32],
|
1197 |
+
[[num_statistics], jnp.int32],
|
1198 |
+
)
|
1199 |
+
return ShampooState(
|
1200 |
+
count=[[], jnp.float32],
|
1201 |
+
stats=ShardedShampooStats(global_stats, local_stats),
|
1202 |
+
)
|
1203 |
+
|
1204 |
def sharded_update_fn(grads, state, params):
|
1205 |
"""Transform the input gradient and update all statistics in sharded mode.
|
1206 |
|
|
|
1228 |
params_flat,
|
1229 |
)
|
1230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1231 |
outputs = jax.tree_multimap(
|
1232 |
lambda g, s, p: _transform_grad(g, s, p, state.count),
|
1233 |
grads_flat,
|
|
|
1242 |
_convert_from_parameter_stats(new_stat, local_stat)
|
1243 |
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
|
1244 |
]
|
|
|
1245 |
|
1246 |
max_size = global_stats.statistics.shape[1]
|
1247 |
new_padded_statistics = []
|
|
|
1264 |
for _ in range(to_pad)
|
1265 |
]
|
1266 |
)
|
|
|
1267 |
new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
|
1268 |
+
new_stacked_padded_statistics = pjit.with_sharding_constraint(
|
1269 |
+
new_stacked_padded_statistics, statistics_partition_spec
|
1270 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1271 |
|
1272 |
def _internal_inverse_pth_root_all():
|
1273 |
+
preconditioners, errors = _matrix_inverse_pth_root_pjit(
|
1274 |
+
new_stacked_padded_statistics,
|
1275 |
+
global_stats.exponents,
|
1276 |
+
statistics_partition_spec,
|
1277 |
)
|
1278 |
return preconditioners, errors
|
1279 |
|
|
|
1284 |
# shaped tensors. Note statistics will be ignored as we are passing in
|
1285 |
# a large init value for error.
|
1286 |
preconditioners_init = new_stacked_padded_statistics
|
1287 |
+
n = new_stacked_padded_statistics.shape[0]
|
1288 |
+
errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
|
1289 |
init_state = [preconditioners_init, errors_init]
|
1290 |
perform_step = state.count % preconditioning_compute_steps == 0
|
1291 |
new_preconditioners, errors = efficient_cond(
|
1292 |
perform_step, _internal_inverse_pth_root_all, init_state
|
1293 |
)
|
1294 |
|
1295 |
+
new_local_stats_flat = _add_error_into_local_stats(
|
1296 |
+
new_local_stats_flat, errors, inverse_failure_threshold
|
1297 |
+
)
|
1298 |
+
new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
|
1299 |
errors = errors.reshape((-1, 1, 1))
|
1300 |
predicate = jnp.logical_or(
|
1301 |
jnp.isnan(errors), errors >= inverse_failure_threshold
|
|
|
1306 |
+ (1.0 - predicate) * new_preconditioners
|
1307 |
)
|
1308 |
new_global_stats = GlobalShardedParameterStats(
|
1309 |
+
new_stacked_padded_statistics,
|
1310 |
+
new_conditional_preconditioners,
|
1311 |
+
global_stats.exponents,
|
1312 |
)
|
1313 |
new_shampoo_state = ShampooState(
|
1314 |
count=state.count + 1,
|
|
|
1339 |
_maybe_quantize_preconditioners(preconditioners),
|
1340 |
_quantize_momentum(jnp.zeros_like(param)),
|
1341 |
_quantize_momentum(jnp.zeros_like(param)),
|
1342 |
+
init_training_metrics(len(statistics)),
|
1343 |
)
|
1344 |
|
1345 |
return ShampooState(
|
|
|
1384 |
state.preconditioners,
|
1385 |
state.diagonal_momentum,
|
1386 |
state.momentum,
|
1387 |
+
state.training_metrics,
|
1388 |
)
|
1389 |
|
1390 |
def _matrix_inverse_pth_root_vmap(xs, ps):
|
|
|
1408 |
|
1409 |
return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
|
1410 |
|
1411 |
+
def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
|
|
|
1412 |
# Partition the concatenated statistics matrix across all cores.
|
1413 |
+
pspec_for_partition = preconditioner_partition_spec
|
1414 |
+
partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
|
1415 |
+
partitioned_ps = pjit.with_sharding_constraint(
|
1416 |
+
ps, pjit.PartitionSpec(preconditioner_partition_spec[0])
|
1417 |
+
)
|
|
|
|
|
1418 |
# Run matrix inverse pth root on each shard.
|
1419 |
partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
|
1420 |
partitioned_xs, partitioned_ps
|
1421 |
)
|
1422 |
+
# Reshard output to have the same PSpec as input. This is required to avoid
|
1423 |
+
# vmap seeing the full set of statistics.
|
1424 |
+
partitioned_preconditioners = pjit.with_sharding_constraint(
|
1425 |
+
partitioned_preconditioners, pspec_for_partition
|
1426 |
+
)
|
1427 |
# Recombine the outputs at each core.
|
1428 |
+
preconditioners = pjit.with_sharding_constraint(
|
1429 |
+
partitioned_preconditioners, statistics_partition_spec
|
1430 |
+
)
|
1431 |
+
errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1432 |
return preconditioners, errors
|
1433 |
|
1434 |
def _pmap_compute_preconditioners(
|
|
|
1510 |
)
|
1511 |
|
1512 |
new_preconditioners_flat = []
|
1513 |
+
new_errors_flat = []
|
1514 |
for p, shape, prev_p, error in zip(
|
1515 |
preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
|
1516 |
):
|
1517 |
new_preconditioners_flat.append(
|
1518 |
_select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
|
1519 |
)
|
1520 |
+
new_errors_flat.append(error)
|
1521 |
|
1522 |
assert len(states) == len(num_statistics_per_state)
|
1523 |
assert len(new_preconditioners_flat) == num_statistics
|
1524 |
+
assert len(new_errors_flat) == num_statistics
|
1525 |
|
1526 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1527 |
preconditioners_for_states = []
|
1528 |
idx = 0
|
1529 |
+
errors_for_states = []
|
1530 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1531 |
if num_statistics == 0:
|
1532 |
preconditioners_for_states.append([])
|
1533 |
+
errors_for_states.append([])
|
1534 |
else:
|
1535 |
preconditioners_for_state = new_preconditioners_flat[
|
1536 |
idx : idx + num_statistics
|
1537 |
]
|
1538 |
assert len(state.statistics) == len(preconditioners_for_state)
|
1539 |
preconditioners_for_states.append(preconditioners_for_state)
|
1540 |
+
|
1541 |
+
errors_for_state = jnp.stack(
|
1542 |
+
new_errors_flat[idx : idx + num_statistics]
|
1543 |
+
)
|
1544 |
+
assert len(state.statistics) == len(errors_for_state)
|
1545 |
+
errors_for_states.append(errors_for_state)
|
1546 |
+
|
1547 |
idx += num_statistics
|
1548 |
new_states = []
|
1549 |
+
for state, new_preconditioners, new_errors in zip(
|
1550 |
+
states, preconditioners_for_states, errors_for_states
|
1551 |
+
):
|
1552 |
+
if state.statistics:
|
1553 |
+
new_errors = jnp.where(
|
1554 |
+
jnp.logical_and(
|
1555 |
+
new_errors > 0.0, new_errors != inverse_failure_threshold
|
1556 |
+
),
|
1557 |
+
new_errors,
|
1558 |
+
state.training_metrics.inverse_pth_root_errors,
|
1559 |
+
)
|
1560 |
+
new_training_metrics = TrainingMetrics(new_errors)
|
1561 |
new_states.append(
|
1562 |
ParameterStats(
|
1563 |
state.diagonal_statistics,
|
|
|
1565 |
new_preconditioners,
|
1566 |
state.diagonal_momentum,
|
1567 |
state.momentum,
|
1568 |
+
new_training_metrics,
|
1569 |
)
|
1570 |
)
|
1571 |
|
|
|
1724 |
new_quantized_preconditioners_flat = []
|
1725 |
new_quantized_diagonals_flat = []
|
1726 |
new_quantized_bucket_sizes_flat = []
|
1727 |
+
new_errors_flat = []
|
1728 |
for p, d, b, shape, prev_p, error in zip(
|
1729 |
quantized_preconditioners_flat,
|
1730 |
quantized_diagonals_flat,
|
|
|
1744 |
new_quantized_bucket_sizes_flat.append(
|
1745 |
_select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
|
1746 |
)
|
1747 |
+
new_errors_flat.append(error)
|
1748 |
|
1749 |
assert len(states) == len(num_statistics_per_state)
|
1750 |
assert len(new_quantized_preconditioners_flat) == num_statistics
|
|
|
1753 |
|
1754 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1755 |
preconditioners_for_states = []
|
1756 |
+
errors_for_states = []
|
1757 |
idx = 0
|
1758 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1759 |
if num_statistics == 0:
|
1760 |
preconditioners_for_states.append([])
|
1761 |
+
errors_for_states.append([])
|
1762 |
else:
|
1763 |
quantized_preconditioners_for_state = (
|
1764 |
new_quantized_preconditioners_flat[idx : idx + num_statistics]
|
|
|
1769 |
quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
|
1770 |
idx : idx + num_statistics
|
1771 |
]
|
1772 |
+
errors_for_state = jnp.stack(
|
1773 |
+
new_errors_flat[idx : idx + num_statistics]
|
1774 |
+
)
|
1775 |
|
1776 |
assert len(state.statistics) == len(quantized_preconditioners_for_state)
|
1777 |
assert len(state.statistics) == len(quantized_diagonals_for_state)
|
1778 |
assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
|
1779 |
+
assert len(state.statistics) == len(errors_for_state)
|
1780 |
|
1781 |
quantized_preconditioners = []
|
1782 |
for qv, qd, qb in zip(
|
|
|
1788 |
QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
|
1789 |
)
|
1790 |
preconditioners_for_states.append(quantized_preconditioners)
|
1791 |
+
errors_for_states.append(errors_for_state)
|
1792 |
idx += num_statistics
|
1793 |
new_states = []
|
1794 |
+
for state, new_preconditioners, new_errors in zip(
|
1795 |
+
states, preconditioners_for_states, errors_for_states
|
1796 |
+
):
|
1797 |
+
if state.statistics:
|
1798 |
+
new_errors = jnp.where(
|
1799 |
+
jnp.logical_and(
|
1800 |
+
new_errors > 0.0, new_errors != inverse_failure_threshold
|
1801 |
+
),
|
1802 |
+
new_errors,
|
1803 |
+
state.training_metrics.inverse_pth_root_errors,
|
1804 |
+
)
|
1805 |
+
new_training_metrics = TrainingMetrics(new_errors)
|
1806 |
new_states.append(
|
1807 |
ParameterStats(
|
1808 |
state.diagonal_statistics,
|
|
|
1810 |
new_preconditioners,
|
1811 |
state.diagonal_momentum,
|
1812 |
state.momentum,
|
1813 |
+
new_training_metrics,
|
1814 |
)
|
1815 |
)
|
1816 |
|
|
|
1892 |
)
|
1893 |
|
1894 |
new_preconditioners_flat = []
|
1895 |
+
new_errors_flat = []
|
1896 |
for p, shape, prev_p, error in zip(
|
1897 |
preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
|
1898 |
):
|
1899 |
new_preconditioners_flat.append(
|
1900 |
_select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
|
1901 |
)
|
1902 |
+
new_errors_flat.append(error)
|
1903 |
|
1904 |
assert len(states) == len(num_statistics_per_state)
|
1905 |
assert len(new_preconditioners_flat) == num_statistics
|
1906 |
|
1907 |
# Add back empty preconditioners so we that we can set the optimizer state.
|
1908 |
preconditioners_for_states = []
|
1909 |
+
errors_for_states = []
|
1910 |
idx = 0
|
1911 |
for num_statistics, state in zip(num_statistics_per_state, states):
|
1912 |
if num_statistics == 0:
|
1913 |
preconditioners_for_states.append([])
|
1914 |
+
errors_for_states.append([])
|
1915 |
else:
|
1916 |
preconditioners_for_state = new_preconditioners_flat[
|
1917 |
idx : idx + num_statistics
|
1918 |
]
|
1919 |
assert len(state.statistics) == len(preconditioners_for_state)
|
1920 |
preconditioners_for_states.append(preconditioners_for_state)
|
1921 |
+
|
1922 |
+
errors_for_state = jnp.stack(
|
1923 |
+
new_errors_flat[idx : idx + num_statistics]
|
1924 |
+
)
|
1925 |
+
assert len(state.statistics) == len(errors_for_state)
|
1926 |
+
errors_for_states.append(errors_for_state)
|
1927 |
idx += num_statistics
|
1928 |
+
|
1929 |
new_states = []
|
1930 |
+
for state, new_preconditioners, new_errors in zip(
|
1931 |
+
states, preconditioners_for_states, errors_for_states
|
1932 |
+
):
|
1933 |
+
if state.statistics:
|
1934 |
+
new_errors = jnp.where(
|
1935 |
+
jnp.logical_and(
|
1936 |
+
new_errors > 0.0, new_errors != inverse_failure_threshold
|
1937 |
+
),
|
1938 |
+
new_errors,
|
1939 |
+
state.training_metrics.inverse_pth_root_errors,
|
1940 |
+
)
|
1941 |
+
new_training_metrics = TrainingMetrics(new_errors)
|
1942 |
new_states.append(
|
1943 |
ParameterStats(
|
1944 |
state.diagonal_statistics,
|
|
|
1946 |
new_preconditioners,
|
1947 |
state.diagonal_momentum,
|
1948 |
state.momentum,
|
1949 |
+
new_training_metrics,
|
1950 |
)
|
1951 |
)
|
1952 |
|
|
|
2133 |
state.preconditioners,
|
2134 |
_quantize_momentum(grafting_update_with_wd_momentum),
|
2135 |
_quantize_momentum(shampoo_update_with_wd_momentum),
|
2136 |
+
state.training_metrics,
|
2137 |
)
|
2138 |
+
|
2139 |
return transformed_update, param_stats
|
2140 |
|
2141 |
def update_fn(grads, state, params):
|
|
|
2178 |
return updates, new_state
|
2179 |
|
2180 |
if shard_optimizer_states:
|
2181 |
+
# Hijacks the init_fn signature so we can return an OptState with
|
2182 |
+
# appropriate init_fns.
|
2183 |
+
def _init_fns(unused_params):
|
2184 |
+
return InitFnState(
|
2185 |
+
init_fn=sharded_init_fn,
|
2186 |
+
pspec_fn=sharded_init_partition_spec_fn,
|
2187 |
+
shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
|
2188 |
+
)
|
2189 |
+
|
2190 |
+
return optax.GradientTransformation(_init_fns, sharded_update_fn)
|
2191 |
else:
|
2192 |
return optax.GradientTransformation(init_fn, update_fn)
|
tools/train/train.py
CHANGED
@@ -25,7 +25,7 @@ import sys
|
|
25 |
import time
|
26 |
from dataclasses import asdict, dataclass, field
|
27 |
from pathlib import Path
|
28 |
-
from typing import Callable, Optional
|
29 |
|
30 |
import datasets
|
31 |
import jax
|
@@ -36,12 +36,12 @@ import transformers
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
-
from flax.core.frozen_dict import freeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
43 |
from jax.experimental import PartitionSpec, maps
|
44 |
-
from jax.experimental.pjit import pjit
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
47 |
|
@@ -248,6 +248,10 @@ class TrainingArguments:
|
|
248 |
default=1024,
|
249 |
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
250 |
)
|
|
|
|
|
|
|
|
|
251 |
preconditioning_compute_steps: int = field(
|
252 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
253 |
)
|
@@ -478,6 +482,7 @@ def main():
|
|
478 |
artifact_dir,
|
479 |
dtype=getattr(jnp, model_args.dtype),
|
480 |
abstract_init=True,
|
|
|
481 |
)
|
482 |
|
483 |
# load tokenizer
|
@@ -501,12 +506,14 @@ def main():
|
|
501 |
seed=training_args.seed_model,
|
502 |
dtype=getattr(jnp, model_args.dtype),
|
503 |
abstract_init=True,
|
|
|
504 |
)
|
505 |
else:
|
506 |
model = DalleBart(
|
507 |
config,
|
508 |
seed=training_args.seed_model,
|
509 |
dtype=getattr(jnp, model_args.dtype),
|
|
|
510 |
)
|
511 |
|
512 |
# Load tokenizer
|
@@ -520,6 +527,12 @@ def main():
|
|
520 |
use_fast=True,
|
521 |
)
|
522 |
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
# Preprocessing the datasets.
|
524 |
# We need to normalize and tokenize inputs and targets.
|
525 |
|
@@ -536,14 +549,14 @@ def main():
|
|
536 |
|
537 |
# Store some constant
|
538 |
num_epochs = training_args.num_train_epochs
|
539 |
-
# batch size
|
540 |
-
|
541 |
-
training_args.per_device_train_batch_size *
|
542 |
)
|
543 |
-
batch_size_per_node =
|
544 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
545 |
eval_batch_size = (
|
546 |
-
training_args.per_device_eval_batch_size *
|
547 |
)
|
548 |
len_train_dataset, len_eval_dataset = dataset.length
|
549 |
steps_per_epoch = (
|
@@ -599,14 +612,17 @@ def main():
|
|
599 |
beta2=training_args.beta2,
|
600 |
diagonal_epsilon=1e-10,
|
601 |
matrix_epsilon=1e-8,
|
602 |
-
start_preconditioning_step=training_args.
|
603 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
604 |
statistics_compute_steps=1,
|
605 |
best_effort_shape_interpretation=True,
|
606 |
graft_type=GraftingType.RMSPROP_NORMALIZED,
|
607 |
nesterov=False,
|
608 |
exponent_override=0,
|
609 |
-
|
|
|
|
|
|
|
610 |
inverse_failure_threshold=0.1,
|
611 |
moving_average_for_momentum=True,
|
612 |
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
|
@@ -614,6 +630,13 @@ def main():
|
|
614 |
precision=jax.lax.Precision.HIGHEST,
|
615 |
best_effort_memory_usage_reduction=training_args.optim_quantized,
|
616 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
617 |
|
618 |
elif training_args.optim == "adam":
|
619 |
optimizer = optax.adamw(
|
@@ -630,31 +653,45 @@ def main():
|
|
630 |
clipping_threshold=training_args.max_grad_norm,
|
631 |
)
|
632 |
|
633 |
-
# get
|
634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
635 |
|
636 |
-
|
637 |
-
|
|
|
638 |
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
# other variables such as count
|
647 |
-
return None
|
648 |
else:
|
649 |
-
# TODO: create spec for Distributed Shampoo
|
650 |
raise NotImplementedError
|
|
|
651 |
|
652 |
-
opt_state_spec =
|
653 |
-
opt_state_spec_per_leaf,
|
654 |
-
opt_state_shape,
|
655 |
-
# return None spec for empty elements
|
656 |
-
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
657 |
-
)
|
658 |
|
659 |
# create a mesh
|
660 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
@@ -674,51 +711,61 @@ def main():
|
|
674 |
tx=optimizer,
|
675 |
)
|
676 |
|
677 |
-
opt_state, attr_state = None, None
|
678 |
-
if training_args.resume_from_checkpoint is not None:
|
679 |
-
# restore opt_state
|
680 |
-
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
681 |
-
opt_state = from_bytes(opt_state_shape, f.read())
|
682 |
-
# need to freeze dict for pjit
|
683 |
-
opt_state = jax.tree_map(
|
684 |
-
lambda x: freeze(x) if isinstance(x, dict) else x,
|
685 |
-
opt_state,
|
686 |
-
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
687 |
-
)
|
688 |
-
# restore other attributes
|
689 |
-
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
690 |
-
attr_state = json.load(f)
|
691 |
-
|
692 |
# create training state
|
693 |
-
|
694 |
if training_args.resume_from_checkpoint is None:
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
701 |
else:
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
711 |
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
|
|
|
|
719 |
|
720 |
-
# free memory
|
721 |
-
del model._params
|
|
|
|
|
|
|
|
|
|
|
722 |
|
723 |
# label smoothed cross entropy
|
724 |
def loss_fn(logits, labels):
|
@@ -728,11 +775,24 @@ def main():
|
|
728 |
|
729 |
# Define gradient update step fn
|
730 |
def train_step(state, batch, delta_time):
|
731 |
-
|
732 |
-
#
|
733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
734 |
|
735 |
-
def compute_loss(params, minibatch):
|
|
|
|
|
736 |
labels = minibatch.pop("labels")
|
737 |
logits = state.apply_fn(
|
738 |
**minibatch, params=params, dropout_rng=dropout_rng, train=True
|
@@ -741,36 +801,75 @@ def main():
|
|
741 |
|
742 |
grad_fn = jax.value_and_grad(compute_loss)
|
743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
744 |
if training_args.gradient_accumulation_steps == 1:
|
745 |
-
minibatch = jax.tree_map(lambda x: x[0], batch)
|
746 |
-
loss, grads = grad_fn(state.params, minibatch)
|
747 |
-
else:
|
748 |
|
749 |
-
def
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
cumul_loss_grads,
|
754 |
-
grad_fn(state.params, minibatch),
|
755 |
-
)
|
756 |
|
757 |
-
|
|
|
|
|
|
|
758 |
0.0,
|
759 |
jax.tree_map(jnp.zeros_like, state.params),
|
760 |
)
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
769 |
)
|
770 |
|
|
|
|
|
771 |
state = state.apply_gradients(
|
772 |
grads=grads,
|
773 |
-
dropout_rng=
|
774 |
train_time=state.train_time + delta_time,
|
775 |
train_samples=state.train_samples + batch_size_per_step,
|
776 |
)
|
@@ -784,6 +883,7 @@ def main():
|
|
784 |
|
785 |
# Define eval fn
|
786 |
def eval_step(params, batch):
|
|
|
787 |
labels = batch.pop("labels")
|
788 |
logits = model(**batch, params=params, train=False)[0]
|
789 |
loss = loss_fn(logits, labels)
|
@@ -795,13 +895,13 @@ def main():
|
|
795 |
# Create parallel version of the train and eval step
|
796 |
p_train_step = pjit(
|
797 |
train_step,
|
798 |
-
in_axis_resources=(state_spec,
|
799 |
out_axis_resources=(state_spec, None),
|
800 |
donate_argnums=(0,),
|
801 |
)
|
802 |
p_eval_step = pjit(
|
803 |
eval_step,
|
804 |
-
in_axis_resources=(param_spec,
|
805 |
out_axis_resources=None,
|
806 |
)
|
807 |
|
@@ -842,9 +942,7 @@ def main():
|
|
842 |
# ======================== Evaluating ==============================
|
843 |
eval_metrics = []
|
844 |
if training_args.do_eval:
|
845 |
-
eval_loader = dataset.dataloader(
|
846 |
-
"eval", training_args.per_device_eval_batch_size
|
847 |
-
)
|
848 |
eval_steps = (
|
849 |
len_eval_dataset // eval_batch_size
|
850 |
if len_eval_dataset is not None
|
@@ -857,8 +955,8 @@ def main():
|
|
857 |
leave=False,
|
858 |
total=eval_steps,
|
859 |
):
|
860 |
-
#
|
861 |
-
metrics = p_eval_step(state.params, batch)
|
862 |
eval_metrics.append(metrics)
|
863 |
|
864 |
# normalize eval metrics
|
@@ -962,8 +1060,7 @@ def main():
|
|
962 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
963 |
train_loader = dataset.dataloader(
|
964 |
"train",
|
965 |
-
|
966 |
-
training_args.gradient_accumulation_steps,
|
967 |
epoch,
|
968 |
)
|
969 |
# train
|
@@ -974,15 +1071,27 @@ def main():
|
|
974 |
leave=False,
|
975 |
total=steps_per_epoch,
|
976 |
):
|
977 |
-
|
978 |
# calculate delta time (we have a lag of one step but it's ok)
|
979 |
new_time = time.perf_counter()
|
980 |
delta_time = new_time - last_time
|
981 |
last_time = new_time
|
982 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
983 |
# train step
|
984 |
-
state, train_metrics = p_train_step(state, batch, delta_time)
|
985 |
-
step = state.step
|
986 |
|
987 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
988 |
all_metrics = metrics_logger.get_all_train_metrics(
|
|
|
25 |
import time
|
26 |
from dataclasses import asdict, dataclass, field
|
27 |
from pathlib import Path
|
28 |
+
from typing import Any, Callable, NamedTuple, Optional
|
29 |
|
30 |
import datasets
|
31 |
import jax
|
|
|
36 |
import wandb
|
37 |
from datasets import Dataset
|
38 |
from distributed_shampoo import GraftingType, distributed_shampoo
|
39 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
40 |
from flax.serialization import from_bytes, to_bytes
|
41 |
from flax.training import train_state
|
42 |
from flax.training.common_utils import onehot, stack_forest
|
43 |
from jax.experimental import PartitionSpec, maps
|
44 |
+
from jax.experimental.pjit import pjit, with_sharding_constraint
|
45 |
from tqdm import tqdm
|
46 |
from transformers import HfArgumentParser
|
47 |
|
|
|
248 |
default=1024,
|
249 |
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
250 |
)
|
251 |
+
start_preconditioning_step: int = field(
|
252 |
+
default=100,
|
253 |
+
metadata={"help": "Number of steps before starting to update preconditioner."},
|
254 |
+
)
|
255 |
preconditioning_compute_steps: int = field(
|
256 |
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
257 |
)
|
|
|
482 |
artifact_dir,
|
483 |
dtype=getattr(jnp, model_args.dtype),
|
484 |
abstract_init=True,
|
485 |
+
load_on_cpu=True,
|
486 |
)
|
487 |
|
488 |
# load tokenizer
|
|
|
506 |
seed=training_args.seed_model,
|
507 |
dtype=getattr(jnp, model_args.dtype),
|
508 |
abstract_init=True,
|
509 |
+
load_on_cpu=True,
|
510 |
)
|
511 |
else:
|
512 |
model = DalleBart(
|
513 |
config,
|
514 |
seed=training_args.seed_model,
|
515 |
dtype=getattr(jnp, model_args.dtype),
|
516 |
+
load_on_cpu=True,
|
517 |
)
|
518 |
|
519 |
# Load tokenizer
|
|
|
527 |
use_fast=True,
|
528 |
)
|
529 |
|
530 |
+
# get PartitionSpec for model params (required to be a dict)
|
531 |
+
param_spec = set_partitions(model.params)
|
532 |
+
|
533 |
+
# convert params to frozen dict
|
534 |
+
model._params = freeze(model.params)
|
535 |
+
|
536 |
# Preprocessing the datasets.
|
537 |
# We need to normalize and tokenize inputs and targets.
|
538 |
|
|
|
549 |
|
550 |
# Store some constant
|
551 |
num_epochs = training_args.num_train_epochs
|
552 |
+
# batch size
|
553 |
+
minibatch_size = (
|
554 |
+
training_args.per_device_train_batch_size * training_args.dp_devices
|
555 |
)
|
556 |
+
batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
|
557 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
558 |
eval_batch_size = (
|
559 |
+
training_args.per_device_eval_batch_size * training_args.dp_devices
|
560 |
)
|
561 |
len_train_dataset, len_eval_dataset = dataset.length
|
562 |
steps_per_epoch = (
|
|
|
612 |
beta2=training_args.beta2,
|
613 |
diagonal_epsilon=1e-10,
|
614 |
matrix_epsilon=1e-8,
|
615 |
+
start_preconditioning_step=training_args.start_preconditioning_step,
|
616 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
617 |
statistics_compute_steps=1,
|
618 |
best_effort_shape_interpretation=True,
|
619 |
graft_type=GraftingType.RMSPROP_NORMALIZED,
|
620 |
nesterov=False,
|
621 |
exponent_override=0,
|
622 |
+
statistics_partition_spec=PartitionSpec(None, "batch", None),
|
623 |
+
preconditioner_partition_spec=PartitionSpec("batch", None, None),
|
624 |
+
num_devices_for_pjit=training_args.dp_devices,
|
625 |
+
shard_optimizer_states=True,
|
626 |
inverse_failure_threshold=0.1,
|
627 |
moving_average_for_momentum=True,
|
628 |
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
|
|
|
630 |
precision=jax.lax.Precision.HIGHEST,
|
631 |
best_effort_memory_usage_reduction=training_args.optim_quantized,
|
632 |
)
|
633 |
+
# get the real optimizer and helper functions
|
634 |
+
update_fn = optimizer.update
|
635 |
+
optimizer = optimizer.init(model.params)
|
636 |
+
opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
|
637 |
+
optimizer.pspec_fn, optimizer.shape_and_dtype_fn
|
638 |
+
)
|
639 |
+
optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
|
640 |
|
641 |
elif training_args.optim == "adam":
|
642 |
optimizer = optax.adamw(
|
|
|
653 |
clipping_threshold=training_args.max_grad_norm,
|
654 |
)
|
655 |
|
656 |
+
# get PartitionSpec for optimizer state
|
657 |
+
def get_opt_state_spec_and_shape(param_spec):
|
658 |
+
if training_args.optim in ["adam", "adafactor"]:
|
659 |
+
# get opt_state shape without actual init
|
660 |
+
opt_state_shape = jax.eval_shape(optimizer.init, model.params)
|
661 |
+
|
662 |
+
if training_args.optim == "adam":
|
663 |
+
|
664 |
+
def _opt_state_spec_per_leaf(x):
|
665 |
+
if isinstance(x, FrozenDict):
|
666 |
+
# variables with same structure as params
|
667 |
+
return param_spec
|
668 |
+
else:
|
669 |
+
# other variables such as count
|
670 |
+
return None
|
671 |
+
|
672 |
+
opt_state_spec = jax.tree_map(
|
673 |
+
_opt_state_spec_per_leaf,
|
674 |
+
opt_state_shape,
|
675 |
+
# return None spec for empty elements
|
676 |
+
is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
|
677 |
+
)
|
678 |
|
679 |
+
elif training_args.optim == "adafactor":
|
680 |
+
# factorized state must be replicated (rank different than params)
|
681 |
+
opt_state_spec = None
|
682 |
|
683 |
+
elif training_args.optim == "distributed_shampoo":
|
684 |
+
opt_state_spec = opt_fn.pspec_fn(
|
685 |
+
params=model.params,
|
686 |
+
params_partition_spec=param_spec,
|
687 |
+
partition_spec_for_statistics=PartitionSpec(None, "batch", None),
|
688 |
+
)
|
689 |
+
opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
|
|
|
|
|
690 |
else:
|
|
|
691 |
raise NotImplementedError
|
692 |
+
return opt_state_spec, opt_state_shape
|
693 |
|
694 |
+
opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
|
|
|
|
|
|
|
|
|
|
|
695 |
|
696 |
# create a mesh
|
697 |
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
|
|
711 |
tx=optimizer,
|
712 |
)
|
713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
714 |
# create training state
|
715 |
+
with maps.mesh(mesh.devices, mesh.axis_names):
|
716 |
if training_args.resume_from_checkpoint is None:
|
717 |
+
|
718 |
+
def init_state(params):
|
719 |
+
return TrainState.create(
|
720 |
+
apply_fn=model.__call__,
|
721 |
+
tx=optimizer,
|
722 |
+
params=params,
|
723 |
+
dropout_rng=dropout_rng,
|
724 |
+
)
|
725 |
+
|
726 |
+
state = pjit(
|
727 |
+
init_state,
|
728 |
+
in_axis_resources=(param_spec,),
|
729 |
+
out_axis_resources=state_spec,
|
730 |
+
donate_argnums=(0,),
|
731 |
+
)(model.params)
|
732 |
+
|
733 |
else:
|
734 |
+
# restore opt_state
|
735 |
+
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
736 |
+
opt_state = from_bytes(opt_state_shape, f.read())
|
737 |
+
|
738 |
+
# restore other attributes
|
739 |
+
with (Path(artifact_dir) / "training_state.json").open("r") as f:
|
740 |
+
attr_state = json.load(f)
|
741 |
+
|
742 |
+
def restore_state(params, opt_state):
|
743 |
+
return TrainState(
|
744 |
+
apply_fn=model.__call__,
|
745 |
+
tx=optimizer,
|
746 |
+
params=params,
|
747 |
+
opt_state=opt_state,
|
748 |
+
dropout_rng=dropout_rng,
|
749 |
+
**attr_state,
|
750 |
+
)
|
751 |
|
752 |
+
state = pjit(
|
753 |
+
restore_state,
|
754 |
+
in_axis_resources=(param_spec, opt_state_spec),
|
755 |
+
out_axis_resources=state_spec,
|
756 |
+
donate_argnums=(0, 1),
|
757 |
+
)(model.params, opt_state)
|
758 |
+
|
759 |
+
# remove opt_state from CPU
|
760 |
+
del opt_state
|
761 |
|
762 |
+
# free memory
|
763 |
+
del model._params
|
764 |
+
|
765 |
+
# define batch specs
|
766 |
+
keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
|
767 |
+
batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
|
768 |
+
grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
|
769 |
|
770 |
# label smoothed cross entropy
|
771 |
def loss_fn(logits, labels):
|
|
|
775 |
|
776 |
# Define gradient update step fn
|
777 |
def train_step(state, batch, delta_time):
|
778 |
+
# batch is (gradient_accumulation_steps, minibatch_size, ...)
|
779 |
+
# check correct batch shape during compilation
|
780 |
+
assert batch["labels"].shape[0:3] == (
|
781 |
+
training_args.gradient_accumulation_steps,
|
782 |
+
training_args.dp_devices,
|
783 |
+
training_args.per_device_train_batch_size,
|
784 |
+
), f"Expected label batch of shape dp_devices x gradient_acculumation x batch_per_device and got {batch['labels'].shape}"
|
785 |
+
|
786 |
+
# get a minibatch (one gradient accumulation slice)
|
787 |
+
def get_minibatch(batch, grad_idx):
|
788 |
+
return jax.tree_map(
|
789 |
+
lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
|
790 |
+
batch,
|
791 |
+
)
|
792 |
|
793 |
+
def compute_loss(params, minibatch, dropout_rng):
|
794 |
+
# minibatch has dim (batch_size, ...)
|
795 |
+
minibatch = unfreeze(minibatch)
|
796 |
labels = minibatch.pop("labels")
|
797 |
logits = state.apply_fn(
|
798 |
**minibatch, params=params, dropout_rng=dropout_rng, train=True
|
|
|
801 |
|
802 |
grad_fn = jax.value_and_grad(compute_loss)
|
803 |
|
804 |
+
def loss_and_grad(grad_idx, dropout_rng):
|
805 |
+
# minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
|
806 |
+
minibatch = get_minibatch(batch, grad_idx)
|
807 |
+
# ensure batch is sharded over devices
|
808 |
+
minibatch = jax.tree_map(
|
809 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
|
810 |
+
)
|
811 |
+
# calculate loss and grads independently per dp_device
|
812 |
+
loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
|
813 |
+
state.params, minibatch, dropout_rng
|
814 |
+
)
|
815 |
+
# ensure they are sharded over devices
|
816 |
+
loss_grads = jax.tree_map(
|
817 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
818 |
+
loss_grads,
|
819 |
+
)
|
820 |
+
|
821 |
+
# average across all devices
|
822 |
+
loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
|
823 |
+
|
824 |
+
# return loss and grads
|
825 |
+
return loss_grads
|
826 |
+
|
827 |
+
# create a new rng
|
828 |
+
dropout_rng, _ = jax.random.split(state.dropout_rng)
|
829 |
+
# use a different rng per node
|
830 |
+
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
831 |
+
|
832 |
if training_args.gradient_accumulation_steps == 1:
|
|
|
|
|
|
|
833 |
|
834 |
+
def batch_step(dropout_rng):
|
835 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
836 |
+
loss_grad = loss_and_grad(0, dropout_rng)
|
837 |
+
return loss_grad, new_dropout_rng
|
|
|
|
|
|
|
838 |
|
839 |
+
loss_grad, dropout_rng = batch_step(dropout_rng)
|
840 |
+
else:
|
841 |
+
# create initial state for per_minibatch_step loop
|
842 |
+
init_cumul_loss_grad = (
|
843 |
0.0,
|
844 |
jax.tree_map(jnp.zeros_like, state.params),
|
845 |
)
|
846 |
+
init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
|
847 |
+
|
848 |
+
# accumulate gradients
|
849 |
+
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
850 |
+
cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
|
851 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
852 |
+
loss_grad = loss_and_grad(grad_idx, dropout_rng)
|
853 |
+
cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
|
854 |
+
return cumul_loss_grad, new_dropout_rng
|
855 |
+
|
856 |
+
# loop over gradients
|
857 |
+
loss_grad, dropout_rng = jax.lax.fori_loop(
|
858 |
+
0,
|
859 |
+
training_args.gradient_accumulation_steps,
|
860 |
+
cumul_minibatch_step,
|
861 |
+
init_minibatch_step,
|
862 |
+
)
|
863 |
+
# sum -> mean
|
864 |
+
loss_grad = jax.tree_map(
|
865 |
+
lambda x: x / training_args.gradient_accumulation_steps, loss_grad
|
866 |
)
|
867 |
|
868 |
+
# update state
|
869 |
+
loss, grads = loss_grad
|
870 |
state = state.apply_gradients(
|
871 |
grads=grads,
|
872 |
+
dropout_rng=dropout_rng,
|
873 |
train_time=state.train_time + delta_time,
|
874 |
train_samples=state.train_samples + batch_size_per_step,
|
875 |
)
|
|
|
883 |
|
884 |
# Define eval fn
|
885 |
def eval_step(params, batch):
|
886 |
+
batch = unfreeze(batch)
|
887 |
labels = batch.pop("labels")
|
888 |
logits = model(**batch, params=params, train=False)[0]
|
889 |
loss = loss_fn(logits, labels)
|
|
|
895 |
# Create parallel version of the train and eval step
|
896 |
p_train_step = pjit(
|
897 |
train_step,
|
898 |
+
in_axis_resources=(state_spec, grad_batch_spec, None),
|
899 |
out_axis_resources=(state_spec, None),
|
900 |
donate_argnums=(0,),
|
901 |
)
|
902 |
p_eval_step = pjit(
|
903 |
eval_step,
|
904 |
+
in_axis_resources=(param_spec, batch_spec),
|
905 |
out_axis_resources=None,
|
906 |
)
|
907 |
|
|
|
942 |
# ======================== Evaluating ==============================
|
943 |
eval_metrics = []
|
944 |
if training_args.do_eval:
|
945 |
+
eval_loader = dataset.dataloader("eval", eval_batch_size)
|
|
|
|
|
946 |
eval_steps = (
|
947 |
len_eval_dataset // eval_batch_size
|
948 |
if len_eval_dataset is not None
|
|
|
955 |
leave=False,
|
956 |
total=eval_steps,
|
957 |
):
|
958 |
+
# TODO: make this more efficient once training loop is fast
|
959 |
+
metrics = p_eval_step(state.params, freeze(batch))
|
960 |
eval_metrics.append(metrics)
|
961 |
|
962 |
# normalize eval metrics
|
|
|
1060 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
1061 |
train_loader = dataset.dataloader(
|
1062 |
"train",
|
1063 |
+
batch_size_per_node,
|
|
|
1064 |
epoch,
|
1065 |
)
|
1066 |
# train
|
|
|
1071 |
leave=False,
|
1072 |
total=steps_per_epoch,
|
1073 |
):
|
|
|
1074 |
# calculate delta time (we have a lag of one step but it's ok)
|
1075 |
new_time = time.perf_counter()
|
1076 |
delta_time = new_time - last_time
|
1077 |
last_time = new_time
|
1078 |
|
1079 |
+
# reshape data into (gradient_accumulation_steps, dp_devices, batch_per_dp, ...)
|
1080 |
+
batch = jax.tree_map(
|
1081 |
+
lambda x: x.reshape(
|
1082 |
+
(
|
1083 |
+
training_args.gradient_accumulation_steps,
|
1084 |
+
training_args.dp_devices,
|
1085 |
+
training_args.per_device_train_batch_size,
|
1086 |
+
)
|
1087 |
+
+ x.shape[1:]
|
1088 |
+
),
|
1089 |
+
batch,
|
1090 |
+
)
|
1091 |
+
|
1092 |
# train step
|
1093 |
+
state, train_metrics = p_train_step(state, freeze(batch), delta_time)
|
1094 |
+
step = int(state.step)
|
1095 |
|
1096 |
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
1097 |
all_metrics = metrics_logger.get_all_train_metrics(
|