feat: simplify parameters
dev/seq2seq/run_seq2seq_flax.py (+20 -47)
@@ -151,7 +151,7 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-
+    use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
     )
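
Note (not part of the diff): a minimal sketch of how a boolean dataclass field such as `use_decay` surfaces on the command line through transformers' HfArgumentParser. The stripped-down class below is an assumption for illustration; the real DataTrainingArguments has many more fields.

# Hypothetical, stripped-down reproduction. Passing --use_decay flips the
# default False to True; omitting it keeps the default.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    use_decay: bool = field(
        default=False,
        metadata={"help": "Whether to use decay in the learning rate scheduler."},
    )


parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(["--use_decay"])
print(data_args.use_decay)  # True
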
@@ -170,18 +170,16 @@ class DataTrainingArguments:
         },
     )
     preprocessing_num_workers: Optional[int] = field(
-        default=80,  # ensure we have the same datasets cached data and avoid using too much space
-        metadata={"help": "The number of processes to use for the preprocessing."},
-    )
-    source_prefix: Optional[str] = field(
         default=None,
         metadata={
-            "help": "
+            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
         },
     )
     overwrite_cache: bool = field(
         default=False,
-        metadata={
+        metadata={
+            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
+        },
     )
     log_interval: Optional[int] = field(
         default=40,
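
Aside (illustration, not from the script): the "Not used in streaming mode" wording reflects that a regular `datasets.Dataset.map()` accepts `num_proc` and cache controls, while a streaming `IterableDataset.map()` is lazy and takes neither. A toy, self-contained example with a made-up two-row dataset:

# num_proc=None (the new default) means no multiprocessing;
# load_from_cache_file=False mirrors overwrite_cache=True.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"]})
ds = ds.map(
    lambda ex: {"n_chars": len(ex["text"])},
    num_proc=None,               # preprocessing_num_workers
    load_from_cache_file=False,  # i.e. overwrite_cache=True
)
print(ds[0])  # {'text': 'hello', 'n_chars': 5}
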
@@ -189,41 +187,16 @@ class DataTrainingArguments:
     )
     log_model: bool = field(
         default=False,
-        metadata={"help": "
+        metadata={"help": "Log frequency for model"},
     )
     save_model_steps: Optional[int] = field(
-        default=5000,
-        metadata={
-            "help": "For logging the model more frequently. Used only when `log_model` is set."
-        },
+        default=5000,
+        metadata={"help": "For saving/logging the model more frequently"},
     )

     def __post_init__(self):
         if self.dataset_repo_or_path is None:
             raise ValueError("Need a dataset repository or path.")
-        if self.train_file is None or self.validation_file is None:
-            raise ValueError("Need training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`train_file` should be a tsv, csv or json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`validation_file` should be a tsv, csv or json file."
-        if self.streaming and (self.len_train is None or self.len_eval is None):
-            raise ValueError(
-                "Streaming requires providing length of training and validation datasets"
-            )


 class TrainState(train_state.TrainState):
@@ -291,7 +264,7 @@ def create_learning_rate_fn(
     num_train_epochs: int,
     num_warmup_steps: int,
     learning_rate: float,
-
+    use_decay: bool,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
@@ -299,7 +272,7 @@ def create_learning_rate_fn(
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
-    if
+    if not use_decay:
         return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate,
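
Standalone sketch of what the new `use_decay` flag toggles (function and argument names below are illustrative, not the script's exact `create_learning_rate_fn`): linear warmup, then either a constant rate or a linear decay to zero.

# optax.linear_schedule holds its end_value after transition_steps, which is
# what makes the "no decay" branch a constant learning rate after warmup.
import optax


def make_lr_fn(total_steps: int, warmup_steps: int, lr: float, use_decay: bool):
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=lr, transition_steps=warmup_steps
    )
    if not use_decay:
        return warmup_fn  # stays at lr once warmup is done
    decay_fn = optax.linear_schedule(
        init_value=lr, end_value=0.0, transition_steps=total_steps - warmup_steps
    )
    return optax.join_schedules([warmup_fn, decay_fn], boundaries=[warmup_steps])


lr_fn = make_lr_fn(total_steps=10_000, warmup_steps=500, lr=3e-4, use_decay=True)
print(lr_fn(0), lr_fn(500), lr_fn(10_000))  # 0.0, 3e-4, ~0.0
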
@@ -372,10 +345,13 @@ def main():
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
-
-
-
-
+    if data_args.train_file is not None or data_args.validation_file is not None:
+        data_files = {
+            "train": data_args.train_file,
+            "validation": data_args.validation_file,
+        }
+    else:
+        data_files = None
     dataset = load_dataset(
         data_args.dataset_repo_or_path,
         data_files=data_files,
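
A runnable toy version of the new `data_files` handling (file names and the "caption" column are made up, and the generic "json" loader stands in for `data_args.dataset_repo_or_path`): explicit train/validation files go into a dict for `load_dataset`, otherwise `data_files=None` lets the dataset repo provide its own splits.

# Illustration only; tiny jsonl files are created so the example runs end to end.
import json
from pathlib import Path

from datasets import load_dataset

Path("train.jsonl").write_text(json.dumps({"caption": "a red square"}) + "\n")
Path("valid.jsonl").write_text(json.dumps({"caption": "a blue circle"}) + "\n")

train_file, validation_file = "train.jsonl", "valid.jsonl"
if train_file is not None or validation_file is not None:
    data_files = {"train": train_file, "validation": validation_file}
else:
    data_files = None  # fall back to the repo's default splits

dataset = load_dataset("json", data_files=data_files)
print(dataset["train"][0])  # {'caption': 'a red square'}
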
@@ -449,8 +425,6 @@ def main():
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"

-    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.

@@ -475,7 +449,6 @@ def main():

     def preprocess_function(examples):
         inputs = examples[text_column]
-        inputs = [prefix + inp for inp in inputs] if prefix else inputs
         # Setting padding="max_length" as we need fixed length inputs for jitted functions
         model_inputs = tokenizer(
             inputs,
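
For reference, a small sketch of the fixed-length tokenization the comment above refers to; the checkpoint name and max_length are assumptions, not values from the script.

# Padding every example to the same max_length keeps array shapes static,
# so the jit-compiled train step is traced once instead of per batch shape.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model_inputs = tokenizer(
    ["a short caption", "another slightly longer caption"],
    padding="max_length",
    max_length=64,
    truncation=True,
    return_tensors="np",
)
print(model_inputs["input_ids"].shape)  # (2, 64) regardless of text lengths
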
@@ -617,7 +590,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.
+        data_args.use_decay,
     )

     # We use Optax's "masking" functionality to not apply weight decay
@@ -625,8 +598,6 @@ def main():
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
     # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
         layer_norm_params = [
@@ -649,6 +620,8 @@ def main():
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn
         )
     else:
         optimizer = optax.adamw(
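
To make the masking idea concrete, a self-contained sketch (the tiny parameter tree and the naming rule are invented; the real `decay_mask_fn` matches FlaxBart's layer-norm names): the callable returns a boolean pytree with the params' structure, and adafactor applies weight decay only where it is True.

# Assumed illustration of optax weight-decay masking, not the script's code.
import jax.numpy as jnp
import optax
from flax import traverse_util


def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    # decay everything except biases and layer-norm scales (illustrative rule)
    flat_mask = {
        path: not (path[-1] == "bias" or "layernorm" in path[-1].lower())
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)


params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}
tx = optax.adafactor(
    learning_rate=1e-3,
    weight_decay_rate=0.01,
    weight_decay_mask=decay_mask_fn,
)
opt_state = tx.init(params)  # kernel is decayed, bias is not
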