File size: 3,836 Bytes
5dfd462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr
import slugify

from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
from uploader import Uploader
from utils import find_exp_dirs


class ModelUploader(Uploader):
    """Uploader front-end that validates UI inputs before calling ``upload``."""

    def upload_model(
        self,
        folder_path: str,
        repo_name: str,
        upload_to: str,
        private: bool,
        delete_existing_repo: bool,
        input_token: str | None = None,
    ) -> str:
        """Normalize the request and forward it to ``self.upload``.

        Args:
            folder_path: Path of the local folder to upload; must be non-empty.
            repo_name: Requested repository name. When empty, the last path
                component of ``folder_path`` is used. The name is always
                passed through ``slugify.slugify`` before use.
            upload_to: One of the ``UploadTarget`` enum values; selects the
                organization the repository is created under.
            private: Forwarded unchanged to ``self.upload``.
            delete_existing_repo: Forwarded unchanged to ``self.upload``.
            input_token: Forwarded unchanged to ``self.upload``.

        Returns:
            The value returned by ``self.upload``.

        Raises:
            ValueError: If ``folder_path`` is empty, if the repository name is
                empty after slugification, or if ``upload_to`` is not a known
                ``UploadTarget`` value.
        """
        if not folder_path:
            raise ValueError('No model folder was selected for upload.')

        # Fall back to the folder's own name when no explicit name is given.
        if not repo_name:
            repo_name = pathlib.Path(folder_path).name
        requested_name = repo_name
        repo_name = slugify.slugify(repo_name)
        # Slugification can strip every character (e.g. punctuation-only
        # names); fail here with a clear message instead of passing on an
        # empty repository name.
        if not repo_name:
            raise ValueError(
                f'Repository name {requested_name!r} is empty after '
                'slugification.')

        if upload_to == UploadTarget.PERSONAL_PROFILE.value:
            # An empty organization is what the personal-profile target maps to.
            organization = ''
        elif upload_to == UploadTarget.MODEL_LIBRARY.value:
            organization = MODEL_LIBRARY_ORG_NAME
        else:
            raise ValueError(f'Unknown upload target: {upload_to!r}')

        return self.upload(folder_path,
                           repo_name,
                           organization=organization,
                           private=private,
                           delete_existing_repo=delete_existing_repo,
                           input_token=input_token)


def load_local_model_list() -> dict:
    """Re-scan local experiment directories and build a dropdown update.

    Returns:
        The result of ``gr.update`` carrying the directories reported by
        ``find_exp_dirs`` as choices, with the first one selected (or
        ``None`` when there are none).
    """
    exp_dirs = find_exp_dirs()
    if exp_dirs:
        selected = exp_dirs[0]
    else:
        selected = None
    return gr.update(choices=exp_dirs, value=selected)


def create_upload_demo(hf_token: str | None) -> gr.Blocks:
    """Assemble the Gradio Blocks UI for uploading a local model.

    Args:
        hf_token: Token handed to ``ModelUploader``. When it is falsy, a
            text field is shown so the token can be entered in the UI.

    Returns:
        The constructed ``gr.Blocks`` instance with its events wired up.
    """
    model_uploader = ModelUploader(hf_token)
    local_models = find_exp_dirs()

    # Values computed up front so the layout section below stays declarative.
    initial_model = local_models[0] if local_models else None
    target_choices = [target.value for target in UploadTarget]
    # The token field is only shown when no token was supplied by the caller.
    show_token_field = not hf_token

    with gr.Blocks() as demo:
        with gr.Box():
            gr.Markdown('Local Models')
            reload_button = gr.Button('Reload Model List')
            model_dir = gr.Dropdown(label='Model names',
                                    choices=local_models,
                                    value=initial_model)
        with gr.Box():
            gr.Markdown('Upload Settings')
            with gr.Row():
                use_private_repo = gr.Checkbox(label='Private', value=True)
                delete_existing_repo = gr.Checkbox(
                    label='Delete existing repo of the same name',
                    value=False,
                )
            upload_to = gr.Radio(label='Upload to',
                                 choices=target_choices,
                                 value=UploadTarget.MODEL_LIBRARY.value)
            model_name = gr.Textbox(label='Model Name')
            input_token = gr.Text(label='Hugging Face Write Token',
                                  placeholder='',
                                  visible=show_token_field)
        upload_button = gr.Button('Upload')
        gr.Markdown(f'''
            - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
            ''')
        with gr.Box():
            gr.Markdown('Output message')
            output_message = gr.Markdown()

        # Refresh the dropdown from the current set of local directories.
        reload_button.click(fn=load_local_model_list,
                            inputs=None,
                            outputs=model_dir)

        # Order must match the positional parameters of ``upload_model``.
        upload_inputs = [
            model_dir,
            model_name,
            upload_to,
            use_private_repo,
            delete_existing_repo,
            input_token,
        ]
        upload_button.click(fn=model_uploader.upload_model,
                            inputs=upload_inputs,
                            outputs=output_message)

    return demo


if __name__ == '__main__':
    import os

    # Script entry point: read the token from the environment (None when
    # HF_TOKEN is unset), build the UI, then enable queuing and launch.
    hf_token = os.environ.get('HF_TOKEN')
    demo = create_upload_demo(hf_token)
    demo = demo.queue(max_size=1)
    demo.launch(share=False)