Spaces:
Running
Running
feat(train): different rng per node
Browse files- tools/train/train.py +2 -0
tools/train/train.py
CHANGED
@@ -727,6 +727,8 @@ def main():
|
|
727 |
# Define gradient update step fn
|
728 |
def train_step(state, batch, delta_time):
|
729 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
|
|
|
|
730 |
|
731 |
def compute_loss(params, minibatch):
|
732 |
labels = minibatch.pop("labels")
|
|
|
727 |
# Define gradient update step fn
|
728 |
def train_step(state, batch, delta_time):
|
729 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
730 |
+
# use a different rng per node
|
731 |
+
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
732 |
|
733 |
def compute_loss(params, minibatch):
|
734 |
labels = minibatch.pop("labels")
|