Spaces:
Running
Running
feat: load data first
Browse files- tools/train/train.py +3 -3
tools/train/train.py
CHANGED
@@ -375,9 +375,6 @@ def main():
|
|
375 |
datasets.utils.logging.set_verbosity_error()
|
376 |
transformers.utils.logging.set_verbosity_error()
|
377 |
|
378 |
-
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
379 |
-
assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
|
380 |
-
|
381 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
382 |
logger.info(f"Training/evaluation parameters {training_args}")
|
383 |
|
@@ -388,6 +385,9 @@ def main():
|
|
388 |
do_eval=training_args.do_eval,
|
389 |
)
|
390 |
|
|
|
|
|
|
|
391 |
# Set up wandb run
|
392 |
if jax.process_index() == 0:
|
393 |
wandb.init(
|
|
|
375 |
datasets.utils.logging.set_verbosity_error()
|
376 |
transformers.utils.logging.set_verbosity_error()
|
377 |
|
|
|
|
|
|
|
378 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
379 |
logger.info(f"Training/evaluation parameters {training_args}")
|
380 |
|
|
|
385 |
do_eval=training_args.do_eval,
|
386 |
)
|
387 |
|
388 |
+
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
389 |
+
assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
|
390 |
+
|
391 |
# Set up wandb run
|
392 |
if jax.process_index() == 0:
|
393 |
wandb.init(
|