Spaces:
Runtime error
Runtime error
# Apache Software License 2.0 | |
# | |
# Copyright (c) ZenML GmbH 2023. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
from typing_extensions import Annotated | |
from huggingface_hub import create_branch, login, HfApi | |
from zenml import step, log_artifact_metadata | |
from zenml.client import Client | |
from zenml.logger import get_logger | |
# Initialize logger | |
logger = get_logger(__name__) | |
def deploy_to_huggingface( | |
repo_name: str, | |
) -> Annotated[str, "huggingface_url"]: | |
""" | |
This step deploy the model to huggingface. | |
Args: | |
repo_name: The name of the repo to create/use on huggingface. | |
""" | |
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ### | |
secret = Client().get_secret("huggingface_creds") | |
assert secret, "No secret found with name 'huggingface_creds'. Please create one that includes your `username` and `token`." | |
token = secret.secret_values["token"] | |
api = HfApi(token=token) | |
hf_repo = api.create_repo(repo_id=repo_name, repo_type="space", space_sdk="gradio", exist_ok=True) | |
zenml_repo_root = Client().root | |
if not zenml_repo_root: | |
logger.warning( | |
"You're running the `deploy_to_huggingface` step outside of a ZenML repo. " | |
"Since the deployment step to huggingface is all about pushing the repo to huggingface, " | |
"this step will not work outside of a ZenML repo where the gradio folder is present." | |
) | |
raise | |
url = api.upload_folder( | |
folder_path=zenml_repo_root, repo_id=hf_repo.repo_id, repo_type="space", | |
) | |
repo_commits = api.list_repo_commits( | |
repo_id=hf_repo.repo_id, | |
repo_type="space", | |
) | |
log_artifact_metadata( | |
artifact_name="huggingface_url", | |
metadata={ | |
"repo_id": hf_repo.repo_id, | |
"revision": repo_commits[0].commit_id, | |
}, | |
) | |
logger.info(f"Model updated: {url}") | |
### YOUR CODE ENDS HERE ### | |
return url | |