boris committed on
Commit
00710bc
1 Parent(s): f254058

fix(train): grads spec

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -1
tools/train/train.py CHANGED
@@ -837,7 +837,9 @@ def main():
837
  loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
838
  # enforce sharding constraints to avoid OOM
839
  loss = with_sharding_constraint(loss, PartitionSpec("batch"))
840
- grads = with_sharding_constraint(grads, PartitionSpec("batch"))
 
 
841
  # calculate the mean over all devices
842
  loss = jnp.mean(loss)
843
  grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)
 
837
  loss, grads = jax.vmap(loss_grad_per_device, in_axes=0, out_axes=(0, 0))(batch)
838
  # enforce sharding constraints to avoid OOM
839
  loss = with_sharding_constraint(loss, PartitionSpec("batch"))
840
+ grads = jax.tree_map(
841
+ lambda x: with_sharding_constraint(x, PartitionSpec("batch")), grads
842
+ )
843
  # calculate the mean over all devices
844
  loss = jnp.mean(loss)
845
  grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), grads)