Merge pull request #118 from borisdayma/feat-optim
- src/dalle_mini/data.py +11 -12
- tools/inference/inference_pipeline.ipynb +46 -19
- tools/train/train.py +129 -101
src/dalle_mini/data.py
CHANGED
@@ -161,7 +161,7 @@ class Dataset:
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-            Shuffle batches if
+            Shuffle batches if rng is set.
             """
             steps_per_epoch = len(dataset) // batch_size

@@ -182,19 +182,20 @@ class Dataset:
                 yield batch

         def _dataloader_datasets_streaming(
-            dataset: Dataset, batch_size: int, epoch: int
+            dataset: Dataset, split: str, batch_size: int, epoch: int
         ):
-            # epoch is only use for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            first_loop = True
-            while self.multi_hosts or first_loop:
+            first_loop = True  # stop after one loop in some cases
+            while (self.multi_hosts and split == "train") or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
-                # at the same time and
+                # at the same time and training data may not be split equally
+                # For validation data we put the entire set on each host as we could lose
+                # too many samples on pods
+                if epoch is not None:
+                    # reshuffle training data at each epoch (not applicable with validation set)
                     dataset.set_epoch(epoch)
+                    epoch += 1
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)

@@ -213,9 +214,7 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')

         if self.streaming:
-            ds.set_epoch(epoch)
-            return _dataloader_datasets_streaming(ds, batch_size, epoch)
+            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
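For orientation, a minimal standalone sketch (a simplification I am assuming, not the repository's `dalle_mini.data` code) of the looping pattern this hunk introduces: a streaming train split is iterated indefinitely on multi-host setups, while validation and single-host passes stop after one loop.

# Standalone sketch of the streaming-loader pattern above (assumed simplification).
def stream_batches(dataset, split, batch_size, epoch=None, multi_hosts=False):
    batch = []
    first_loop = True  # stop after one loop in some cases
    while (multi_hosts and split == "train") or first_loop:
        if epoch is not None:
            # a real implementation would also reshuffle here, e.g. dataset.set_epoch(epoch)
            epoch += 1
        for item in dataset:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch  # only full batches are yielded
                batch = []
        first_loop = False

# usage with a toy iterable: yields [0, 1, 2, 3] and [4, 5, 6, 7]
for b in stream_batches(range(10), split="eval", batch_size=4):
    print(b)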
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -70,15 +70,15 @@
     "# Model references\n",
     "\n",
     "# dalle-mini\n",
-    "DALLE_MODEL =
+    "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\"  # can be wandb artifact or 🤗 Hub or local folder\n",
     "DALLE_COMMIT_ID = None  # used only with 🤗 hub\n",
     "\n",
     "# VQGAN model\n",
-    "VQGAN_REPO =
-    "VQGAN_COMMIT_ID =
+    "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+    "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
     "\n",
     "# CLIP model\n",
-    "CLIP_REPO =
+    "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
     "CLIP_COMMIT_ID = None"
    ]
   },
@@ -121,18 +121,28 @@
     "import wandb\n",
     "\n",
     "# Load dalle-mini\n",
-    "if
+    "if \":\" in DALLE_MODEL:\n",
     "    # wandb artifact\n",
     "    artifact = wandb.Api().artifact(DALLE_MODEL)\n",
     "    # we only download required files (no need for opt_state which is large)\n",
-    "    model_files = [
+    "    model_files = [\n",
+    "        \"config.json\",\n",
+    "        \"flax_model.msgpack\",\n",
+    "        \"merges.txt\",\n",
+    "        \"special_tokens_map.json\",\n",
+    "        \"tokenizer.json\",\n",
+    "        \"tokenizer_config.json\",\n",
+    "        \"vocab.json\",\n",
+    "    ]\n",
     "    for f in model_files:\n",
-    "        artifact.get_path(f).download(
-    "    model = DalleBart.from_pretrained(
-    "    tokenizer = AutoTokenizer.from_pretrained(
+    "        artifact.get_path(f).download(\"model\")\n",
+    "    model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
     "else:\n",
     "    # local folder or 🤗 Hub\n",
-    "    model = DalleBart.from_pretrained(
+    "    model = DalleBart.from_pretrained(\n",
+    "        DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+    "    )\n",
     "    tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
     "\n",
     "# Load VQGAN\n",
@@ -191,7 +201,7 @@
     "from functools import partial\n",
     "\n",
     "# model inference\n",
-    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
+    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
     "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
     "    return model.generate(\n",
     "        **tokenized_prompt,\n",
@@ -203,11 +213,13 @@
     "        top_p=top_p\n",
     "    )\n",
     "\n",
+    "\n",
     "# decode images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_decode(indices, params):\n",
     "    return vqgan.decode_code(indices, params=params)\n",
     "\n",
+    "\n",
     "# score images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_clip(inputs, params):\n",
@@ -235,7 +247,7 @@
     "import random\n",
     "\n",
     "# create a random key\n",
-    "seed = random.randint(0, 2**32-1)\n",
+    "seed = random.randint(0, 2 ** 32 - 1)\n",
     "key = jax.random.PRNGKey(seed)"
    ]
   },
@@ -287,7 +299,7 @@
   },
   "outputs": [],
   "source": [
-    "prompt =
+    "prompt = \"a red T-shirt\""
   ]
  },
  {
@@ -323,7 +335,13 @@
     "repeated_prompts = [processed_prompt] * jax.device_count()\n",
     "\n",
     "# tokenize\n",
-    "tokenized_prompt = tokenizer(
+    "tokenized_prompt = tokenizer(\n",
+    "    repeated_prompts,\n",
+    "    return_tensors=\"jax\",\n",
+    "    padding=\"max_length\",\n",
+    "    truncation=True,\n",
+    "    max_length=128,\n",
+    ").data\n",
     "tokenized_prompt"
    ]
   },
@@ -408,12 +426,14 @@
     "    # get a new key\n",
     "    key, subkey = jax.random.split(key)\n",
     "    # generate images\n",
-    "    encoded_images = p_generate(
+    "    encoded_images = p_generate(\n",
+    "        tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
+    "    )\n",
     "    # remove BOS\n",
     "    encoded_images = encoded_images.sequences[..., 1:]\n",
     "    # decode images\n",
     "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
-    "    decoded_images = decoded_images.clip(0
+    "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
     "    for img in decoded_images:\n",
     "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
   ]
@@ -436,7 +456,14 @@
   "outputs": [],
   "source": [
    "# get clip scores\n",
-    "clip_inputs = processor(
+    "clip_inputs = processor(\n",
+    "    text=[prompt] * jax.device_count(),\n",
+    "    images=images,\n",
+    "    return_tensors=\"np\",\n",
+    "    padding=\"max_length\",\n",
+    "    max_length=77,\n",
+    "    truncation=True,\n",
+    ").data\n",
    "logits = p_clip(shard(clip_inputs), clip_params)\n",
    "logits = logits.squeeze().flatten()"
   ]
@@ -458,10 +485,10 @@
   },
   "outputs": [],
   "source": [
-    "print(f
+    "print(f\"Prompt: {prompt}\\n\")\n",
    "for idx in logits.argsort()[::-1]:\n",
    "    display(images[idx])\n",
-    "    print(f
+    "    print(f\"Score: {logits[idx]:.2f}\\n\")"
   ]
  }
 ],
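As a side note, the notebook cells above rely on two pmap details that the diff only shows in passing: sampling arguments are marked static with static_broadcasted_argnums, and the PRNG key is split per device with shard_prng_key so each device samples independently. A tiny self-contained illustration (my own example, not part of the notebook):

from functools import partial

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard_prng_key

# argument 2 (scale) is a static Python value broadcast to every device
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(2,))
def p_sample(x, key, scale):
    # each device receives its own sub-key and draws independent noise
    return x + scale * jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
x = jnp.zeros((jax.device_count(), 4))
out = p_sample(x, shard_prng_key(key), 0.1)  # shape (n_devices, 4)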
tools/train/train.py
CHANGED
@@ -65,7 +65,7 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Pretrained config name or path if not the same as
+            "help": "Pretrained config name or path if not the same as model_name_or_path"
         },
     )
     tokenizer_name: Optional[str] = field(

@@ -77,7 +77,7 @@ class ModelArguments:
     dtype: Optional[str] = field(
         default="float32",
         metadata={
-            "help": "Floating-point format in which the
+            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )

@@ -106,11 +106,15 @@ class DataTrainingArguments:
     )
     train_file: Optional[str] = field(
         default=None,
+        metadata={
+            "help": "The input training data file (glob & braceexpand acceptable)."
+        },
     )
     validation_file: Optional[str] = field(
         default=None,
+        metadata={
+            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+        },
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: Optional[bool] = field(

@@ -132,15 +136,13 @@ class DataTrainingArguments:
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of training examples."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
         },
     )
     preprocessing_num_workers: Optional[int] = field(

@@ -191,42 +193,40 @@ class TrainingArguments:

     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the
+        default=False, metadata={"help": "Whether to run eval on the validation set."}
     )

     per_device_train_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
     per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
     )

     gradient_accumulation_steps: int = field(
         default=1,
         metadata={
-            "help": "Number of updates steps to accumulate before performing
+            "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )

     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
     )
-        default=False,
-        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
-    )
-    weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay if we apply some."}
-    )
-        default=0.999,
-    )
+    optim: str = field(
+        default="distributed_shampoo",
+        metadata={
+            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+        },
+    )
+    weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
+    beta1: float = field(
+        default=0.9,
+        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
+    )
+    beta2: float = field(
+        default=0.999,
+        metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
+    )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}

@@ -234,9 +234,47 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    block_size: int = field(
+        default=1024,
+        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
+    )
+    preconditioning_compute_steps: int = field(
+        default=10, metadata={"help": "Number of steps to update preconditioner."}
+    )
+    skip_preconditioning_dim_size_gt: int = field(
+        default=4096,
+        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
+    )
+    optim_quantized: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
+        },
+    )
+
+    lr_decay: str = field(
+        default=None,
+        metadata={
+            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+        },
+    )
+    lr_transition_steps: int = field(
+        default=None,
+        metadata={
+            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+        },
+    )
+    lr_decay_rate: float = field(
+        default=None,
+        metadata={
+            "help": "Decay rate associated with learning rate when using exponential decay."
+        },
+    )
+    lr_staircase: bool = field(
         default=False,
+        metadata={
+            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+        },
     )

     num_train_epochs: float = field(

@@ -267,18 +305,18 @@ class TrainingArguments:
         },
     )

-    push_to_hub: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether or not to upload the trained model to the model hub after training."
-        },
-    )
-
     resume_from_checkpoint: Optional[str] = field(
         default=None,
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )

+    def __post_init__(self):
+        assert self.optim in [
+            "distributed_shampoo",
+            "adam",
+            "adafactor",
+        ], f"Selected optimizer not supported: {self.optim}"
+

 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None

@@ -309,33 +347,6 @@ class TrainState(train_state.TrainState):
         )


-def create_learning_rate_fn(
-    num_warmup_steps: int,
-    learning_rate: float,
-    use_decay: bool,
-    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    if use_decay:
-        assert (
-            num_train_steps is not None
-        ), "Learning rate with decay requires number of training steps"
-    warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-    )
-    if not use_decay:
-        return warmup_fn
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
-        end_value=0,
-        transition_steps=num_train_steps - num_warmup_steps,
-    )
-    schedule_fn = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-    )
-    return schedule_fn
-
-
 class MetricsLogger:
     def __init__(self, state):
         self.step = state.step

@@ -529,12 +540,37 @@ def main():
     num_params = model.num_params

     # Create learning rate schedule
+    def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+        """Create the learning rate function."""
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0,
+            end_value=training_args.learning_rate,
+            transition_steps=training_args.warmup_steps,
+        )
+        if training_args.lr_decay is None:
+            return warmup_fn
+        elif training_args.lr_decay == "linear":
+            assert (
+                num_train_steps is not None
+            ), "linear decay requires knowing the dataset length"
+            decay_fn = optax.linear_schedule(
+                init_value=training_args.learning_rate,
+                end_value=0,
+                transition_steps=num_train_steps - training_args.warmup_steps,
+            )
+        elif training_args.lr_decay == "exponential":
+            decay_fn = optax.exponential_decay(
+                init_value=training_args.learning_rate,
+                transition_steps=training_args.lr_transition_steps,
+                decay_rate=training_args.lr_decay_rate,
+                staircase=training_args.lr_staircase,
+            )
+        schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+        )
+        return schedule_fn
+
+    learning_rate_fn = create_learning_rate_fn()

     # We use Optax's "masking" functionality to not apply weight decay
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a

@@ -558,29 +594,22 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)

     # create adam optimizer
-    if training_args.
-        # We use the default parameters here to initialize adafactor,
-        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-        optimizer = optax.adafactor(
-            learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
-            clipping_threshold=training_args.max_grad_norm,
-        )
-    elif training_args.distributed_shampoo:
+    if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
         # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
-            block_size=
-            beta1=
-            beta2=
+            block_size=training_args.block_size,
+            beta1=training_args.beta1,
+            beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=
+            weight_decay=training_args.weight_decay
+            if training_args.weight_decay is not None
+            else 0.0,
+            start_preconditioning_step=training_args.warmup_steps,
+            preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,

@@ -589,23 +618,32 @@ def main():
             batch_axis_name="batch",
             inverse_failure_threshold=0.1,
             moving_average_for_momentum=True,
-            skip_preconditioning_dim_size_gt=
+            skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
-            best_effort_memory_usage_reduction=
+            best_effort_memory_usage_reduction=training_args.optim_quantized,
         )

+    elif training_args.optim == "adam":
         optimizer = optax.adamw(
             learning_rate=learning_rate_fn,
-            b1=training_args.
-            b2=training_args.
+            b1=training_args.beta1,
+            b2=training_args.beta2,
             eps=training_args.adam_epsilon,
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
             mask=decay_mask_fn,
         )
+    elif training_args.optim == "adafactor":
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
+        )

     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:

@@ -821,16 +859,6 @@ def main():

         wandb.run.log_artifact(artifact)

-        # save to the hub
-        if training_args.push_to_hub:
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-                push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
-                temp_dir=True,  # avoid issues with being in a repository
-            )
-
     # init variables
     last_time = time.perf_counter()
     train_metrics = None

@@ -841,7 +869,7 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))

         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size)
+        train_loader = dataset.dataloader("train", train_batch_size, epoch)
        # train
         for batch in tqdm(
             train_loader,
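For context, a condensed sketch of how the new flags compose an optax optimizer (a hypothetical helper of my own, not the training script itself): a linear warmup is optionally joined with an exponential decay, and the optim choice switches between AdamW and Adafactor. Distributed Shampoo comes from a separate implementation and is omitted here.

import optax

def make_optimizer(optim="adam", learning_rate=5e-5, warmup_steps=1000,
                   lr_decay=None, lr_transition_steps=2000, lr_decay_rate=0.8,
                   lr_staircase=False, weight_decay=0.0):
    # linear warmup from 0 to the peak learning rate
    schedule = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
    )
    if lr_decay == "exponential":
        decay_fn = optax.exponential_decay(
            init_value=learning_rate,
            transition_steps=lr_transition_steps,
            decay_rate=lr_decay_rate,
            staircase=lr_staircase,
        )
        # warmup first, then decay once warmup_steps have elapsed
        schedule = optax.join_schedules(
            schedules=[schedule, decay_fn], boundaries=[warmup_steps]
        )
    if optim == "adam":
        return optax.adamw(learning_rate=schedule, weight_decay=weight_decay)
    elif optim == "adafactor":
        return optax.adafactor(learning_rate=schedule)
    raise ValueError(f"Selected optimizer not supported: {optim}")

optimizer = make_optimizer(optim="adam", lr_decay="exponential")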