feat: restore weights on CPU
Files changed:
- src/dalle_mini/model/modeling.py  +255 -3
- tools/train/train.py  +10 -8
src/dalle_mini/model/modeling.py
CHANGED
@@ -15,16 +15,30 @@
 """ DalleBart model. """
 
 import math
+import os
 from functools import partial
-from typing import Optional, Tuple
+from pickle import UnpicklingError
+from typing import Optional, Tuple, Union
 
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import msgpack.exceptions
 from flax.core.frozen_dict import unfreeze
 from flax.linen import make_causal_mask
-from flax.traverse_util import flatten_dict
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
 from jax.random import PRNGKey
+from transformers.configuration_utils import PretrainedConfig
+from transformers.file_utils import (
+    FLAX_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+)
 from transformers.modeling_flax_outputs import (
     FlaxCausalLMOutputWithCrossAttentions,
     FlaxSeq2SeqLMOutput,
@@ -300,7 +314,8 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
     - added num_params property
     - config_class replaced to DalleBartConfig
    - __init__ accepts abstract_init which does uses parameter shape to initialize the model
-    - init weights on CPU
+    - init weights on CPU with `load_on_cpu`
+    - restore weights on CPU with custom `from_pretrained`
     """
 
     config_class = DalleBartConfig
@@ -359,6 +374,243 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
         ).values()
         return sum(list(num_params))
 
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        **kwargs,
+    ):
+        config = kwargs.pop("config", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        from_pt = kwargs.pop("from_pt", False)
+        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {
+            "file_type": "model",
+            "framework": "flax",
+            "from_auto_class": from_auto_class,
+        }
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = (
+                config if config is not None else pretrained_model_name_or_path
+            )
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            if os.path.isdir(pretrained_model_name_or_path):
+                if from_pt and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                ):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, WEIGHTS_NAME
+                    )
+                elif os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
+                ):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(
+                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
+                    )
+                else:
+                    raise EnvironmentError(
+                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
+                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
+                pretrained_model_name_or_path
+            ):
+                archive_file = pretrained_model_name_or_path
+            else:
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
+                    revision=revision,
+                )
+
+            # redirect to the cache, if necessary
+            try:
+                resolved_archive_file = cached_path(
+                    archive_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    user_agent=user_agent,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
+                    f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                )
+                raise EnvironmentError(msg)
+
+            if resolved_archive_file == archive_file:
+                logger.info(f"loading weights file {archive_file}")
+            else:
+                logger.info(
+                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
+                )
+        else:
+            resolved_archive_file = None
+
+        # init random models
+        model = cls(config, *model_args, **model_kwargs)
+
+        with open(resolved_archive_file, "rb") as state_f:
+            try:
+                state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(resolved_archive_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please install "
+                                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                                "you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(
+                        f"Unable to convert {archive_file} to Flax deserializable object. "
+                    )
+
+        # if model is base model only use model_prefix key
+        if (
+            cls.base_model_prefix not in dict(model.params)
+            and cls.base_model_prefix in state
+        ):
+            state = state[cls.base_model_prefix]
+
+        # if model is head model and we are loading weights from base model
+        # we initialize new params dict with base_model_prefix
+        if (
+            cls.base_model_prefix in dict(model.params)
+            and cls.base_model_prefix not in state
+        ):
+            state = {cls.base_model_prefix: state}
+
+        # flatten dicts
+        state = flatten_dict(state)
+
+        random_state = flatten_dict(unfreeze(model.params))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state.keys():
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append(
+                        (key, state[key].shape, random_state[key].shape)
+                    )
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )
+
+        # add missing keys as random parameters
+        for missing_key in missing_keys:
+            state[missing_key] = random_state[missing_key]
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(
+                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
+            )
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+
+        # set correct parameters
+        model.params = unflatten_dict(state)
+
+        return model
+
 
 class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     """
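For context, a minimal usage sketch of the new loading path (not part of the commit; `DalleBart` as the class exported by `dalle_mini.model` and the checkpoint path are assumptions, while `abstract_init` and `load_on_cpu` are the constructor flags listed in the docstring above and are forwarded through the keyword arguments of the custom `from_pretrained`):

# Hedged sketch: the class name and checkpoint location are assumptions,
# not verified against this exact revision of the repository.
from dalle_mini.model import DalleBart

model = DalleBart.from_pretrained(
    "path/to/checkpoint",  # directory containing a Flax checkpoint (flax_model.msgpack)
    abstract_init=True,    # build parameters from shapes instead of a full random init
    load_on_cpu=True,      # keep initialized weights on CPU
)

With the custom `from_pretrained` above, the deserialized state is assigned to `model.params` directly, so the restored arrays stay in host memory rather than being transferred to an accelerator device at load time.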
tools/train/train.py
CHANGED
@@ -249,6 +249,9 @@ class TrainingArguments:
             "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
+    gradient_checkpointing: bool = field(
+        default=False, metadata={"help": "Use gradient checkpointing."}
+    )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
@@ -515,10 +518,8 @@ def main():
         load_on_cpu=True,
     )
 
-    # Load tokenizer
-    tokenizer = DalleBartTokenizer.from_pretrained(
-        model_args.tokenizer_name, use_fast=True
-    )
+    # update model config per training args
+    model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -526,14 +527,15 @@ def main():
     # convert params to frozen dict
     model._params = freeze(model.params)
 
+    # Load tokenizer
+    tokenizer = DalleBartTokenizer.from_pretrained(
+        model_args.tokenizer_name, use_fast=True
+    )
+
     # Preprocessing the datasets.
     # We need to normalize and tokenize inputs and targets.
-
     dataset.preprocess(tokenizer=tokenizer, config=model.config)
 
-    # no dropout (hardcoded)
-    model.config.dropout = 0.0
-
     # Initialize our training
     dropout_rng = jax.random.PRNGKey(training_args.seed_model)
 
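On the train.py side, the new `gradient_checkpointing` argument is copied onto the model config before the parameters are partitioned. As background, gradient checkpointing (rematerialization) saves memory by recomputing intermediate activations during the backward pass instead of storing them; in Flax this is typically done with `nn.remat`. A self-contained, illustrative sketch of the mechanism (generic Flax code, not this repository's model):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.gelu(nn.Dense(self.features)(x))

class Stack(nn.Module):
    features: int
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # When enabled, wrap the block with nn.remat so its activations are
        # recomputed during the backward pass instead of being kept in memory.
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        for _ in range(4):
            x = block_cls(self.features)(x)
        return x

x = jnp.ones((2, 8))
model = Stack(features=8, gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: jnp.sum(model.apply(p, x) ** 2))(params)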