Spaces:
Sleeping
Sleeping
import os | |
import shutil | |
from pathlib import Path | |
from typing import Iterable, List | |
import gradio as gr | |
import kagglehub | |
from gradio_logsview.logsview import Log, LogsView, LogsViewRunner | |
from huggingface_hub import HfApi | |
KAGGLE_JSON = os.environ.get("KAGGLE_JSON") | |
KAGGLE_JSON_PATH = Path("~/.kaggle/kaggle.json").expanduser().resolve() | |
if KAGGLE_JSON_PATH.exists(): | |
print(f"Found existing kaggle.json file at {KAGGLE_JSON_PATH}") | |
elif KAGGLE_JSON is not None: | |
print( | |
"KAGGLE_JSON is set as secret. Will be able to be authenticated when downloading files from Kaggle." | |
) | |
KAGGLE_JSON_PATH.parent.mkdir(parents=True, exist_ok=True) | |
KAGGLE_JSON_PATH.write_text(KAGGLE_JSON) | |
else: | |
print( | |
f"No kaggle.json file found at {KAGGLE_JSON_PATH}. You will not be able to download private/gated files from Kaggle." | |
) | |
# Markdown rendered at the top of the page. The note about Kaggle credentials
# is appended only when a kaggle.json file exists when this module runs.
# Authentication to the Hub goes through the "Sign in with HF" button, not a
# pasted token, so the instructions say so.
MARKDOWN_DESCRIPTION = """
# Keggla-importer GUI
The fastest way to import a model from KaggleHub to the Hugging Face Hub 🔥
Specify a Kaggle handle and sign in with your Hugging Face account to import a model from KaggleHub to the Hugging Face Hub.
To find the Kaggle handle from a web UI, click on the "download dropdown" and copy the handle from the code snippet.
Example: `"keras/gemma/keras/gemma_instruct_2b_en"`.
"""
if KAGGLE_JSON_PATH.exists():
    MARKDOWN_DESCRIPTION += """
**Note**: a `kaggle.json` file exists in the home directory. This means the Space will be able to download **SOME** private/gated files from Kaggle.
To access other models, please duplicate this Space to a private Space and set the `KAGGLE_JSON` environment variable with the content of the `kaggle.json`
you've downloaded from your Kaggle user account.
"""
def _infer_repo_name(kaggle_model: str) -> str:
    """Derive a default Hugging Face repo name from a Kaggle model handle.

    Kaggle model handles have the form ``owner/model/framework/variation``
    with an optional trailing ``/version`` number. The variation is used as
    the repo name. A trailing numeric version segment is skipped so that a
    versioned handle does not produce a repo named e.g. ``"2"``.
    """
    segments = kaggle_model.strip("/").split("/")
    if len(segments) == 5 and segments[-1].isdigit():
        segments.pop()
    return segments[-1]


def import_model(
    kaggle_model: str, repo_name: str, token: gr.OAuthToken | None
) -> Iterable[List[Log]]:
    """Download a model from KaggleHub and upload it to a Hugging Face repo.

    This is a generator: each ``yield`` emits the runner's current logs so the
    ``LogsView`` component can stream progress to the user.

    Args:
        kaggle_model: Kaggle model handle, e.g.
            ``"keras/gemma/keras/gemma_instruct_2b_en"``. Surrounding
            whitespace is ignored.
        repo_name: Target HF repo name. If empty, it is inferred from the
            Kaggle handle.
        token: OAuth token of the signed-in HF user. It is not listed in the
            click ``inputs``; Gradio supplies it from the login session and
            passes ``None`` when the user is not signed in.

    Yields:
        Log lists produced by ``LogsViewRunner``.

    Raises:
        gr.Error: If the Kaggle handle is empty or the user is not signed in.
    """
    runner = LogsViewRunner()

    kaggle_model = (kaggle_model or "").strip()
    if not kaggle_model:
        yield runner.log("Kaggle model is required.", level="ERROR")
        raise gr.Error("Kaggle model is required.")
    repo_name = (repo_name or "").strip()
    if not repo_name:
        repo_name = _infer_repo_name(kaggle_model)
    if not token:
        yield runner.log("You must sign in with HF before proceeding.", level="ERROR")
        raise gr.Error("Authentication is required.")

    api = HfApi(token=token.token)
    yield runner.log(f"Creating HF repo {repo_name}")
    repo_url = api.create_repo(repo_name, exist_ok=True)
    yield runner.log(f"Created HF repo: {repo_url}")
    repo_id = repo_url.repo_id

    # Refuse to import into a repo that already has content. The `> 1`
    # threshold tolerates the single file (presumably `.gitattributes`) that
    # a freshly created repo already contains.
    info = api.model_info(repo_id)
    if len(info.siblings or []) > 1:
        yield runner.log(
            f"Model repo {repo_id} is not empty. Please delete it or set a different repo name.",
            level="ERROR",
        )
        return

    yield runner.log(f"Downloading model {kaggle_model} from Kaggle.")
    yield from runner.run_python(kagglehub.model_download, handle=kaggle_model)
    if runner.exit_code != 0:
        yield runner.log("Failed to download model from Kaggle.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        return
    # The download above populated the local kagglehub cache, so this call is
    # expected to only resolve the cached path rather than download again.
    cache_path = kagglehub.model_download(kaggle_model)
    yield runner.log(f"Model successfully downloaded from Kaggle to {cache_path}.")

    yield runner.log(f"Uploading model to HF repo {repo_id}.")
    yield from runner.run_python(
        api.upload_folder, repo_id=repo_id, folder_path=cache_path
    )
    if runner.exit_code != 0:
        yield runner.log("Failed to upload model to HF repo.", level="ERROR")
        api.delete_repo(repo_id=repo_id)
        # Drop the downloaded files too, so failed imports do not accumulate
        # model weights on local disk.
        shutil.rmtree(cache_path, ignore_errors=True)
        return
    yield runner.log(f"Model successfully uploaded to HF: {repo_url}.")

    yield runner.log(f"Deleting local cache from {cache_path}.")
    shutil.rmtree(cache_path)
    yield runner.log("Done!")
# Gradio UI: description, the two text inputs plus the HF sign-in button,
# the import trigger, and a live view of the import logs.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN_DESCRIPTION)

    with gr.Row():
        kaggle_model_input = gr.Textbox(
            label="Kaggle Model*",
            placeholder="keras/codegemma/keras/code_gemma_7b_en",
            lines=1,
        )
        repo_name_input = gr.Textbox(
            label="Repo name",
            placeholder="Optional. Will infer from Kaggle Model if empty.",
            lines=1,
        )
        # Signing in is what gives `import_model` its OAuth token.
        gr.LoginButton(min_width=250)

    import_button = gr.Button("Import", variant="primary")
    logs_view = LogsView(label="Terminal output")

    import_button.click(
        import_model,
        inputs=[kaggle_model_input, repo_name_input],
        outputs=[logs_view],
    )

# Process at most one import at a time, then start the server.
demo.queue(default_concurrency_limit=1)
demo.launch()