Spaces:
Build error
Build error
from huggingface_hub import PyTorchModelHubMixin | |
from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME | |
from huggingface_hub.file_download import hf_hub_download | |
from .unifiedmodel import RRUM | |
import os | |
import torch | |
class YoutubeVideoSimilarityModel(RRUM, PyTorchModelHubMixin):
    """
    Hugging Face `PyTorchModelHubMixin` wrapper for RegretsReporter `RRUM` model.

    This allows loading, using, and saving the model from Hugging Face model hub
    with default Hugging Face methods `from_pretrained` and `save_pretrained`.
    """

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token=None,
        map_location="cpu",
        strict=False,
        token=None,
        **model_kwargs,
    ):
        """Build a model instance and load its weights from a local dir or the Hub.

        This is the hook that `PyTorchModelHubMixin.from_pretrained` invokes on
        the class, so it must be a classmethod.

        Args:
            model_id: Path to a local directory containing the weights file, or
                a Hub repository id to download it from.
            revision, cache_dir, force_download, proxies, resume_download,
            local_files_only: Forwarded unchanged to `hf_hub_download` when
                `model_id` is not a local directory.
            use_auth_token: Auth token as passed by older `huggingface_hub`
                versions.
            map_location: Device the weights are loaded onto (default "cpu").
            strict: Forwarded to `load_state_dict`. When False, missing or
                unexpected keys in the checkpoint are tolerated.
            token: Auth token as passed by newer `huggingface_hub` versions.
                Takes precedence over `use_auth_token` when both are given.
            **model_kwargs: Extra constructor arguments. A `config` entry
                (dict) is flattened into constructor kwargs. Explicit kwargs
                override values from the config.

        Returns:
            The constructed model with weights loaded, switched to eval mode.
        """
        # Older hub versions pass `use_auth_token`; newer ones pass `token`.
        # Accepting both keeps `token` from leaking into `model_kwargs` and
        # from there into the RRUM constructor.
        auth_token = token if token is not None else use_auth_token

        map_location = torch.device(map_location)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=PYTORCH_WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                use_auth_token=auth_token,
                local_files_only=local_files_only,
            )

        # Convert the Hugging Face config into RRUM constructor arguments.
        # `pop` always strips `config` from the kwargs. The `None` check avoids
        # the crash that `{**None}` would cause.
        config = model_kwargs.pop("config", None)
        if config is not None:
            model_kwargs = {**config, **model_kwargs}
        model = cls(**model_kwargs)

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load weights from trusted repositories/paths.
        state_dict = torch.load(model_file, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        model.eval()
        return model