Spaces:
Running
Running
fix(train): grads spec
Browse files- tools/train/train.py +3 -1
tools/train/train.py
CHANGED
@@ -837,7 +837,9 @@ def main():
|
|
837 |
loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
|
838 |
# enforce sharding constraints to avoid OOM
|
839 |
loss = with_sharding_constraint(loss, PartitionSpec("batch"))
|
840 |
-
grads =
|
|
|
|
|
841 |
# calculate the mean over all devices
|
842 |
loss = jnp.mean(loss)
|
843 |
grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
|
|
|
837 |
loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
|
838 |
# enforce sharding constraints to avoid OOM
|
839 |
loss = with_sharding_constraint(loss, PartitionSpec("batch"))
|
840 |
+
grads = jax.tree_map(
|
841 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")), grads
|
842 |
+
)
|
843 |
# calculate the mean over all devices
|
844 |
loss = jnp.mean(loss)
|
845 |
grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
|