# GPT-2-LoRA-HealthCare / importHuggingFaceHubModel.py
# Author (from Hub page header): PascalZhan
# Commit message: "Create importHuggingFaceHubModel.py"
# Commit: be6d5bf (verified)
from huggingface_hub import ModelHubMixin, snapshot_download
import os
from huggingface_hub.utils import (
get_tf_version,
is_graphviz_available,
is_pydot_available,
is_tf_available,
yaml_dump,
)
def from_pretrained_keras(*args, **kwargs) -> "KerasModelHubMixin":
    r"""
    Load a pretrained Keras model from the Hub or from a local directory.

    Functional counterpart of [`KerasModelHubMixin.from_pretrained`]: every
    positional and keyword argument is forwarded to it unchanged and its
    result is returned as-is. The loader in this file looks for a file whose
    name ends with `.keras` inside the resolved model folder.

    NOTE(review): which of the download options below are actually honoured
    is decided by `KerasModelHubMixin._from_pretrained`; confirm there before
    relying on any of them.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Can be either:
                - A string, the `model id` of a pretrained model hosted
                  inside a model repo on huggingface.co. Valid model ids can
                  be located at the root-level, like `bert-base-uncased`, or
                  namespaced under a user or organization name, like
                  `dbmdz/bert-base-german-cased`.
                - A path to a local `directory` containing the saved model,
                  e.g., `./my_model_directory/`.
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a
            tag name, or a commit id, since models on huggingface.co are
            stored in a git-based system.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to force the (re-)download of the model files,
            overriding the cached versions if they exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether to attempt to resume a download when an incompletely
            received file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint,
            e.g., `{'http': 'foo.bar:3128', 'http://hostname':
            'foo.bar:4012'}`.
        token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files.
            If `True`, the locally stored login token is used.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory in which downloaded files should be cached
            if the standard cache should not be used.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only look at local files (i.e., do not try to
            download the model).
        model_kwargs (`Dict`, *optional*):
            Extra keyword arguments handed to the model loader.

    Returns:
        Whatever [`KerasModelHubMixin.from_pretrained`] returns for the given
        arguments.

    <Tip>

    Passing `token=True` is required when you want to use a private model.

    </Tip>
    """
    # Resolve the bound classmethod first, then delegate with the caller's
    # arguments untouched so the functional and mixin APIs cannot drift apart.
    loader = KerasModelHubMixin.from_pretrained
    return loader(*args, **kwargs)
class KerasModelHubMixin(ModelHubMixin):
    """
    Implementation of [`ModelHubMixin`] that loads Keras models stored as a
    `.keras` file, either from a Hub repository or from a local directory.

    NOTE(review): only `_from_pretrained` is overridden in this class; the
    `save_pretrained` / `push_to_hub` lines in the example below rely on
    whatever `ModelHubMixin` itself provides — confirm before relying on them.

    Example:

    ```python
    >>> import tensorflow as tf
    >>> from huggingface_hub import KerasModelHubMixin


    >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
    ...     def __init__(self, **kwargs):
    ...         super().__init__()
    ...         self.config = kwargs.pop("config", None)
    ...         self.dummy_inputs = ...
    ...         self.layer = ...

    ...     def call(self, *args):
    ...         return ...


    >>> # Initialize and compile the model as you normally would
    >>> model = MyModel()
    >>> model.compile(...)
    >>> # Build the graph by training it or passing dummy inputs
    >>> _ = model(model.dummy_inputs)
    >>> # Save model weights to local directory
    >>> model.save_pretrained("my-awesome-model")
    >>> # Push model weights to the Hub
    >>> model.push_to_hub("my-awesome-model")
    >>> # Download and initialize weights from the Hub
    >>> model = MyModel.from_pretrained("username/super-cool-model")
    ```
    """

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        token,
        **model_kwargs,
    ):
        """Resolve a model folder, find its `.keras` file and load it.

        Args:
            model_id: Path of an existing local directory, or the id of a Hub
                repository to download a snapshot of.
            revision: Repository revision passed to `snapshot_download`.
            cache_dir: Cache directory passed to `snapshot_download`.
            force_download: Passed to `snapshot_download`.
            proxies: Passed to `snapshot_download`.
            resume_download: Accepted for interface compatibility only and
                not forwarded. NOTE(review): this option appears to be
                deprecated in recent `huggingface_hub` releases — confirm
                before forwarding it.
            local_files_only: Passed to `snapshot_download`.
            token: Passed to `snapshot_download`, so that private
                repositories can be fetched.
            **model_kwargs: A `config` entry, if present, is removed and
                attached to the loaded model as `model.config`; everything
                else is passed to `tf.keras.models.load_model`.

        Returns:
            The object returned by `tf.keras.models.load_model`, with an
            extra `config` attribute (`None` when no config was supplied).

        Raises:
            ImportError: If TensorFlow is not available.
            ValueError: If the resolved folder contains no file whose name
                ends with `.keras`.
        """
        if not is_tf_available():
            raise ImportError("Called a TensorFlow-specific function but could not import it.")
        import tensorflow as tf

        # TODO - Figure out what to do about these config values. Config is
        # not needed to load the model, so it is only carried along.
        cfg = model_kwargs.pop("config", None)

        # Root is either a local directory matching model_id or a cached
        # snapshot of the Hub repository.
        if os.path.isdir(model_id):
            storage_folder = model_id
        else:
            storage_folder = snapshot_download(
                repo_id=model_id,
                revision=revision,
                cache_dir=cache_dir,
                library_name="keras",
                library_version=get_tf_version(),
                # Forward the caller's download options; previously they were
                # dropped, so tokens / offline mode / forced refresh had no
                # effect.
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
            )

        # os.listdir returns entries in arbitrary order; sort so the chosen
        # file is deterministic when several `.keras` files are present.
        model_file_name = next(
            (name for name in sorted(os.listdir(storage_folder)) if name.endswith(".keras")),
            None,
        )
        if model_file_name is None:
            raise ValueError("Repesitory does not have model ends with .keras")

        # os.path.join keeps the path portable and also accepts PathLike
        # folders, unlike string concatenation with '/'.
        model_path = os.path.join(storage_folder, model_file_name)
        model = tf.keras.models.load_model(model_path, **model_kwargs)

        # For now, expose the config supplied by the caller as a plain
        # attribute on the loaded model.
        model.config = cfg
        return model