feat: use_auth_token + seed for dataset and model
dev/seq2seq/run_seq2seq_flax.py  CHANGED  +35 -12
@@ -129,6 +129,12 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Whether to stream the dataset."},
     )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use the authentication token for private datasets."
+        },
+    )
     max_source_length: Optional[int] = field(
         default=128,
         metadata={
@@ -256,9 +262,18 @@ class TrainingArguments:
         metadata={"help": "Log model to wandb at `save_steps` frequency."},
     )
 
-    seed: int = field(
+    seed_model: int = field(
         default=42,
-        metadata={"help": "Random seed that will be set at the beginning of training."},
+        metadata={
+            "help": "Random seed for the model that will be set at the beginning of training."
+        },
+    )
+    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
+    seed_dataset: int = field(
+        default=None,
+        metadata={
+            "help": "Random seed for the dataset that will be set at the beginning of training."
+        },
     )
 
     push_to_hub: bool = field(
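The two new fields follow the script's dataclass-argument pattern. As a minimal sketch (not part of the commit, and assuming the script's usual HfArgumentParser flow), this is how they surface as CLI flags; `seed_dataset` is typed `Optional[int]` here so the parser accepts the `None` default:

```python
# Sketch only: how the new fields become CLI flags via HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:
    seed_model: int = field(
        default=42,
        metadata={"help": "Random seed for the model."},
    )
    # None lets the script draw a fresh seed when resuming mid-epoch
    seed_dataset: Optional[int] = field(
        default=None,
        metadata={"help": "Random seed for the dataset."},
    )


# e.g. `python run_seq2seq_flax.py --seed_model 42 --seed_dataset 123`
(args,) = HfArgumentParser(TrainingArguments).parse_args_into_dataclasses(
    ["--seed_model", "42", "--seed_dataset", "123"]
)
print(args.seed_model, args.seed_dataset)  # 42 123
```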
@@ -304,7 +319,9 @@ class TrainState(train_state.TrainState):
 
 
 def data_loader(
-    rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
+    dataset: Dataset,
+    batch_size: int,
+    rng: jax.random.PRNGKey = None,
 ):
     """
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@@ -312,7 +329,7 @@ def data_loader(
     """
    steps_per_epoch = len(dataset) // batch_size
 
-    if shuffle:
+    if rng is not None:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))
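For clarity, a minimal sketch of the new contract (toy sizes, no device sharding, hypothetical helper name `batch_order`): omit the rng for a deterministic pass in dataset order, pass a PRNGKey to shuffle.

```python
import jax
import jax.numpy as jnp


def batch_order(n_examples, batch_size, rng=None):
    # Shuffle only when an rng is supplied, mirroring `if rng is not None:`
    if rng is not None:
        batch_idx = jax.random.permutation(rng, n_examples)
    else:
        batch_idx = jnp.arange(n_examples)
    steps = n_examples // batch_size  # drop the last partial batch
    return batch_idx[: steps * batch_size].reshape(steps, batch_size)


print(batch_order(10, 4))                             # fixed order: [[0..3], [4..7]]
print(batch_order(10, 4, rng=jax.random.PRNGKey(0)))  # permuted order
```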
@@ -432,6 +449,7 @@ def main():
         data_args.dataset_repo_or_path,
         data_files=data_files,
         streaming=data_args.streaming,
+        use_auth_token=data_args.use_auth_token,
     )
 
     # Set up wandb run
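A minimal usage sketch (hypothetical dataset name and files): with the flag set, `datasets.load_dataset` authenticates with the token cached by `huggingface-cli login`, which is what lets private Hub datasets load.

```python
from datasets import load_dataset

dataset = load_dataset(
    "my-org/private-encoded-captions",  # hypothetical private dataset repo
    data_files={"train": "train.jsonl", "validation": "valid.jsonl"},
    streaming=True,
    use_auth_token=True,  # read the token stored by `huggingface-cli login`
)
```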
@@ -483,7 +501,7 @@ def main():
 
     # Create a custom model and initialize it randomly
     model = CustomFlaxBartForConditionalGeneration(
-        config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
     )
 
     # Load tokenizer
@@ -561,7 +579,14 @@ def main():
         else train_dataset.select(range(data_args.max_train_samples))
     )
     if data_args.streaming:
-        train_dataset = train_dataset.shuffle(1000, training_args.seed)
+        train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
+    else:
+        seed_dataset = (
+            training_args.seed_dataset
+            if training_args.seed_dataset is not None
+            else np.random.get_state()[1][0]
+        )
+        rng_dataset = jax.random.PRNGKey(seed_dataset)
     if model.config.normalize_text:
         train_dataset = (
             train_dataset.map(normalize_text)
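A minimal sketch of the fallback branch in isolation: when `--seed_dataset` is not passed, a fresh value is drawn from NumPy's global RNG state, so an interrupted-and-restarted run does not replay the same shuffle order.

```python
import jax
import numpy as np

seed_dataset = None  # i.e. the flag was not passed
seed = seed_dataset if seed_dataset is not None else np.random.get_state()[1][0]
# np.random.get_state()[1] is the MT19937 key array; [0] is its first word
rng_dataset = jax.random.PRNGKey(int(seed))  # PRNGKey expects an int seed
```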
@@ -627,7 +652,7 @@ def main():
     )
 
     # Initialize our training
-    rng = jax.random.PRNGKey(training_args.seed)
+    rng = jax.random.PRNGKey(training_args.seed_model)
     rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
@@ -808,7 +833,7 @@ def main():
         if data_args.streaming:
             eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
         else:
-            eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+            eval_loader = data_loader(eval_dataset, eval_batch_size)
         eval_steps = (
             len_eval_dataset // eval_batch_size
             if len_eval_dataset is not None
@@ -927,10 +952,8 @@ def main():
             train_dataset.set_epoch(epoch)  # shuffle dataset
             train_loader = data_loader_streaming(train_dataset, train_batch_size)
         else:
-            rng, input_rng = jax.random.split(rng)
-            train_loader = data_loader(
-                input_rng, train_dataset, train_batch_size, shuffle=True
-            )
+            rng_dataset, input_rng = jax.random.split(rng_dataset)
+            train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
         # train
         for batch in tqdm(
             train_loader,
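A minimal sketch of the per-epoch key handling (standalone, toy sizes): consuming `rng_dataset` through `jax.random.split` each epoch yields a different permutation per epoch, while the whole sequence stays reproducible from the single dataset seed.

```python
import jax

rng_dataset = jax.random.PRNGKey(123)
for epoch in range(3):
    # split: keep one key for the next epoch, use the other to shuffle now
    rng_dataset, input_rng = jax.random.split(rng_dataset)
    print(epoch, jax.random.permutation(input_rng, 5))  # differs per epoch
```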