after restart
Browse files
events.out.tfevents.1628300146.t1v-n-1a0a7c50-w-0.1616525.3.v2
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc846e02789f45a7b1ada21713b56d4496b6dfa15c056619f2a5b46eb22ecdc0
|
3 |
+
size 52250633
|
flax_model_backup350k.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7a1b3c986f8f90bfb67978326d950efc15e64ad9b6d5f4884bcde1013a65968
|
3 |
+
size 1100762015
|
run_streaming.sh
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
--model_type="t5" \
|
4 |
--config_name="./" \
|
5 |
--tokenizer_name="./" \
|
|
|
6 |
--dataset_name="pere/norwegian_colossal_corpus_v2_short100k" \
|
7 |
--max_seq_length="512" \
|
8 |
--weight_decay="0.01" \
|
|
|
3 |
--model_type="t5" \
|
4 |
--config_name="./" \
|
5 |
--tokenizer_name="./" \
|
6 |
+
--model_name_or_path="./" \
|
7 |
--dataset_name="pere/norwegian_colossal_corpus_v2_short100k" \
|
8 |
--max_seq_length="512" \
|
9 |
--weight_decay="0.01" \
|
run_t5_mlm_flax_streaming.py
CHANGED
@@ -552,16 +552,17 @@ if __name__ == "__main__":
|
|
552 |
rng = jax.random.PRNGKey(training_args.seed)
|
553 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
554 |
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
|
|
565 |
|
566 |
|
567 |
# Data collator
|
|
|
552 |
rng = jax.random.PRNGKey(training_args.seed)
|
553 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
554 |
|
555 |
+
# Changed 13 August (Pere): the model used to be built unconditionally with
# FlaxT5ForConditionalGeneration(config, ...). It is now loaded from
# model_args.model_name_or_path when that argument is set.
if model_args.model_name_or_path:
    model = FlaxT5ForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )
else:
    # No checkpoint given: from_pretrained() expects a model name or path as
    # its first argument, not a config object, so construct the model directly.
    model = FlaxT5ForConditionalGeneration(
        config,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
    )
|
566 |
|
567 |
|
568 |
# Data collator
|