import os

from huggingface_hub import ModelHubMixin, snapshot_download
from huggingface_hub.utils import (
    get_tf_version,
    is_graphviz_available,
    is_pydot_available,
    is_tf_available,
    yaml_dump,
)


def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
    r"""
    Instantiate a pretrained Keras model from a model repository on the Hub.
    The repository is expected to contain a model saved in the Keras `.keras` format.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:
                - A string, the `model id` of a pretrained model hosted inside a
                  model repo on huggingface.co. Valid model ids can be located
                  at the root-level, like `bert-base-uncased`, or namespaced
                  under a user or organization name, like
                  `dbmdz/bert-base-german-cased`.
                - You can add a `revision` by appending `@` and the revision to
                  the model id, e.g. `dbmdz/bert-base-german-cased@main`. The
                  revision is the specific model version to use: it can be a
                  branch name, a tag name, or a commit id. Since we use a
                  git-based system for storing models and other artifacts on
                  huggingface.co, `revision` can be any identifier allowed by
                  git.
                - A path to a `directory` containing the saved model, e.g.,
                  `./my_model_directory/`.
                - `None` if you are both providing the configuration and state
                  dictionary (resp. with keyword arguments `config` and
                  `state_dict`).
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force the (re-)download of the model weights and
            configuration files, overriding the cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether to attempt to resume a previously interrupted download if
            an incompletely received file exists, instead of deleting it and
            downloading from scratch.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g.,
            `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The
            proxies are used on each request.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If
            `True`, will use the token generated when running `huggingface-cli
            login` (stored in `~/.huggingface`).
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which a downloaded pretrained model
            configuration should be cached if the standard cache should not be
            used.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only look at local files (i.e., do not try to download
            the model).
        model_kwargs (`Dict`, *optional*):
            Additional keyword arguments forwarded to
            `tf.keras.models.load_model` when loading the model (for example,
            `custom_objects`).

    <Tip>

    Passing `token=True` is required when you want to use a private
    model.

    </Tip>
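
    Example (a minimal usage sketch; `"username/my-keras-model"` is a
    placeholder repository id):

    ```python
    >>> from huggingface_hub import from_pretrained_keras

    >>> # Load a model from a repo on the Hub that contains a `.keras` file
    >>> model = from_pretrained_keras("username/my-keras-model")

    >>> # Or load from a local directory containing a `.keras` file
    >>> model = from_pretrained_keras("./my_model_directory/")
    ```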
    """
    return KerasModelHubMixin.from_pretrained(*args, **kwargs)


class KerasModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] to provide model Hub upload/download
    capabilities to Keras models.

    ```python
    >>> import tensorflow as tf
    >>> from huggingface_hub import KerasModelHubMixin

    >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
    ...     def __init__(self, **kwargs):
    ...         super().__init__()
    ...         self.config = kwargs.pop("config", None)
    ...         self.dummy_inputs = ...
    ...         self.layer = ...
    ...
    ...     def call(self, *args):
    ...         return ...

    >>> # Initialize and compile the model as you normally would
    >>> model = MyModel()
    >>> model.compile(...)
    >>> # Build the graph by training it or passing dummy inputs
    >>> _ = model(model.dummy_inputs)
    >>> # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")
    >>> # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")
    >>> # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/super-cool-model")
    ```
    """

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        token,
        **model_kwargs,
    ):
        """Shared implementation behind both the mixin API
        ([`KerasModelHubMixin.from_pretrained`]) and the functional API
        ([`from_pretrained_keras`]), so the two stay in sync.
        """
        if not is_tf_available():
            raise ImportError("Called a TensorFlow-specific function but could not import TensorFlow.")
        import tensorflow as tf

        cfg = model_kwargs.pop("config", None)

        # Resolve a local snapshot of the repository, downloading it from the Hub if needed.
        if not os.path.isdir(model_id):
            storage_folder = snapshot_download(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
                library_name="keras",
                library_version=get_tf_version(),
            )
        else:
            storage_folder = model_id

        # Locate a serialized Keras model (a file ending in ".keras") at the root of the snapshot.
        model_file_name = None
        for file_name in os.listdir(storage_folder):
            if file_name.endswith(".keras"):
                model_file_name = file_name
                break

        if model_file_name is None:
            raise ValueError("Repository does not contain a model file ending with '.keras'.")

        model_path = os.path.join(storage_folder, model_file_name)

        # Remaining kwargs (e.g. `custom_objects`) are forwarded to `tf.keras.models.load_model`.
        model = tf.keras.models.load_model(model_path, **model_kwargs)

        # Attach the (optional) configuration so it stays available on the loaded model.
        model.config = cfg

        return model
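

# A minimal end-to-end sketch (illustrative only). It assumes TensorFlow is installed,
# that "./my-awesome-model" is a local directory containing a file ending in ".keras",
# and that the saved model uses a custom layer; `MyCustomLayer` is a hypothetical
# stand-in. Extra keyword arguments such as `custom_objects` are forwarded unchanged
# to `tf.keras.models.load_model` by `_from_pretrained`.
if __name__ == "__main__":
    import tensorflow as tf

    class MyCustomLayer(tf.keras.layers.Layer):
        def call(self, inputs):
            return inputs

    model = from_pretrained_keras(
        "./my-awesome-model", custom_objects={"MyCustomLayer": MyCustomLayer}
    )
    model.summary()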