pere committed
Commit: bc603e8
1 parent: 6364611
Files changed (2)
  1. run.sh +3 -3
  2. run_mlm_flax_stream.py +1 -1
run.sh CHANGED
@@ -1,13 +1,13 @@
 python run_mlm_flax_stream.py \
- --output_dir="../roberta-debug-32" \
+ --output_dir="../roberta-debug-pod-32" \
  --model_name_or_path="xlm-roberta-base" \
  --config_name="./" \
  --tokenizer_name="./" \
  --dataset_name="NbAiLab/scandinavian" \
  --max_seq_length="512" \
  --weight_decay="0.01" \
- --per_device_train_batch_size="12" \
- --per_device_eval_batch_size="12" \
+ --per_device_train_batch_size="62" \
+ --per_device_eval_batch_size="62" \
  --learning_rate="1e-4" \
  --warmup_steps="10000" \
  --overwrite_output_dir \
run_mlm_flax_stream.py CHANGED
@@ -451,7 +451,7 @@ if __name__ == "__main__":
     num_epochs = int(training_args.num_train_epochs)
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
-
+    breakpoint()
     print("***************************")
     print(f"Train Batch Size: {train_batch_size}")
     print("***************************")