ShaoTengLiu commited on
Commit
e5176ce
·
1 Parent(s): 2d81855

add UI from TAV

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. Dockerfile +57 -0
  3. LICENSE +21 -0
  4. Video-P2P-Demo/.DS_Store +0 -0
  5. Video-P2P-Demo/README.md +27 -0
  6. {configs → Video-P2P-Demo/configs}/.DS_Store +0 -0
  7. {configs → Video-P2P-Demo/configs}/man-motor-tune.yaml +0 -0
  8. {configs → Video-P2P-Demo/configs}/rabbit-jump-p2p.yaml +0 -0
  9. {configs → Video-P2P-Demo/configs}/rabbit-jump-tune.yaml +0 -0
  10. {data → Video-P2P-Demo/data}/.DS_Store +0 -0
  11. {data → Video-P2P-Demo/data}/motorbike/.DS_Store +0 -0
  12. {data → Video-P2P-Demo/data}/motorbike/1.jpg +0 -0
  13. {data → Video-P2P-Demo/data}/motorbike/2.jpg +0 -0
  14. {data → Video-P2P-Demo/data}/motorbike/3.jpg +0 -0
  15. {data → Video-P2P-Demo/data}/motorbike/4.jpg +0 -0
  16. {data → Video-P2P-Demo/data}/motorbike/5.jpg +0 -0
  17. {data → Video-P2P-Demo/data}/motorbike/6.jpg +0 -0
  18. {data → Video-P2P-Demo/data}/motorbike/7.jpg +0 -0
  19. {data → Video-P2P-Demo/data}/motorbike/8.jpg +0 -0
  20. {data → Video-P2P-Demo/data}/rabbit/1.jpg +0 -0
  21. {data → Video-P2P-Demo/data}/rabbit/2.jpg +0 -0
  22. {data → Video-P2P-Demo/data}/rabbit/3.jpg +0 -0
  23. {data → Video-P2P-Demo/data}/rabbit/4.jpg +0 -0
  24. {data → Video-P2P-Demo/data}/rabbit/5.jpg +0 -0
  25. {data → Video-P2P-Demo/data}/rabbit/6.jpg +0 -0
  26. {data → Video-P2P-Demo/data}/rabbit/7.jpg +0 -0
  27. {data → Video-P2P-Demo/data}/rabbit/8.jpg +0 -0
  28. ptp_utils.py → Video-P2P-Demo/ptp_utils.py +0 -0
  29. Video-P2P-Demo/requirements.txt +15 -0
  30. run_tuning.py → Video-P2P-Demo/run_tuning.py +0 -0
  31. run_videop2p.py → Video-P2P-Demo/run_videop2p.py +0 -0
  32. script.sh → Video-P2P-Demo/script.sh +0 -0
  33. seq_aligner.py → Video-P2P-Demo/seq_aligner.py +0 -0
  34. {tuneavideo → Video-P2P-Demo/tuneavideo}/data/dataset.py +0 -0
  35. {tuneavideo → Video-P2P-Demo/tuneavideo}/models/attention.py +0 -0
  36. {tuneavideo → Video-P2P-Demo/tuneavideo}/models/resnet.py +0 -0
  37. {tuneavideo → Video-P2P-Demo/tuneavideo}/models/unet.py +0 -0
  38. {tuneavideo → Video-P2P-Demo/tuneavideo}/models/unet_blocks.py +0 -0
  39. {tuneavideo → Video-P2P-Demo/tuneavideo}/pipelines/pipeline_tuneavideo.py +0 -0
  40. {tuneavideo → Video-P2P-Demo/tuneavideo}/util.py +0 -0
  41. app.py +84 -0
  42. app_inference.py +170 -0
  43. app_training.py +135 -0
  44. app_upload.py +106 -0
  45. constants.py +10 -0
  46. inference.py +109 -0
  47. packages.txt +1 -0
  48. patch +15 -0
  49. style.css +3 -0
  50. trainer.py +166 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+ RUN apt-get update && \
4
+ apt-get upgrade -y && \
5
+ apt-get install -y --no-install-recommends \
6
+ git \
7
+ git-lfs \
8
+ wget \
9
+ curl \
10
+ # ffmpeg \
11
+ ffmpeg \
12
+ x264 \
13
+ # python build dependencies \
14
+ build-essential \
15
+ libssl-dev \
16
+ zlib1g-dev \
17
+ libbz2-dev \
18
+ libreadline-dev \
19
+ libsqlite3-dev \
20
+ libncursesw5-dev \
21
+ xz-utils \
22
+ tk-dev \
23
+ libxml2-dev \
24
+ libxmlsec1-dev \
25
+ libffi-dev \
26
+ liblzma-dev && \
27
+ apt-get clean && \
28
+ rm -rf /var/lib/apt/lists/*
29
+
30
+ RUN useradd -m -u 1000 user
31
+ USER user
32
+ ENV HOME=/home/user \
33
+ PATH=/home/user/.local/bin:${PATH}
34
+ WORKDIR ${HOME}/app
35
+
36
+ RUN curl https://pyenv.run | bash
37
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
38
+ ENV PYTHON_VERSION=3.10.9
39
+ RUN pyenv install ${PYTHON_VERSION} && \
40
+ pyenv global ${PYTHON_VERSION} && \
41
+ pyenv rehash && \
42
+ pip install --no-cache-dir -U pip setuptools wheel
43
+
44
+ RUN pip install --no-cache-dir -U torch==1.13.1 torchvision==0.14.1
45
+ COPY --chown=1000 requirements.txt /tmp/requirements.txt
46
+ RUN pip install --no-cache-dir -U -r /tmp/requirements.txt
47
+
48
+ COPY --chown=1000 . ${HOME}/app
49
+ RUN cd Tune-A-Video && patch -p1 < ../patch
50
+ ENV PYTHONPATH=${HOME}/app \
51
+ PYTHONUNBUFFERED=1 \
52
+ GRADIO_ALLOW_FLAGGING=never \
53
+ GRADIO_NUM_PORTS=1 \
54
+ GRADIO_SERVER_NAME=0.0.0.0 \
55
+ GRADIO_THEME=huggingface \
56
+ SYSTEM=spaces
57
+ CMD ["python", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Video-P2P-Demo/.DS_Store ADDED
Binary file (6.15 kB). View file
 
Video-P2P-Demo/README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Video-P2P Demo
3
+ emoji: 🐶
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Video-P2P
12
+
13
+ ## Setup
14
+
15
+ All required packages are listed in the requirements file.
16
+ The code was tested on a Tesla V100 32GB but should work on other cards with at least **16GB** VRAM.
17
+
18
+ ## Quickstart
19
+
20
+ ``` bash
21
+ bash script.sh
22
+ ```
23
+
24
+ ## References
25
+ * prompt-to-prompt: https://github.com/google/prompt-to-prompt
26
+ * Tune-A-Video: https://github.com/showlab/Tune-A-Video
27
+ * diffusers: https://github.com/huggingface/diffusers
{configs → Video-P2P-Demo/configs}/.DS_Store RENAMED
File without changes
{configs → Video-P2P-Demo/configs}/man-motor-tune.yaml RENAMED
File without changes
{configs → Video-P2P-Demo/configs}/rabbit-jump-p2p.yaml RENAMED
File without changes
{configs → Video-P2P-Demo/configs}/rabbit-jump-tune.yaml RENAMED
File without changes
{data → Video-P2P-Demo/data}/.DS_Store RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/.DS_Store RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/1.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/2.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/3.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/4.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/5.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/6.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/7.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/motorbike/8.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/1.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/2.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/3.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/4.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/5.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/6.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/7.jpg RENAMED
File without changes
{data → Video-P2P-Demo/data}/rabbit/8.jpg RENAMED
File without changes
ptp_utils.py → Video-P2P-Demo/ptp_utils.py RENAMED
File without changes
Video-P2P-Demo/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.12.1
2
+ torchvision==0.13.1
3
+ diffusers[torch]==0.11.1
4
+ transformers>=4.25.1
5
+ bitsandbytes==0.35.4
6
+ decord==0.6.0
7
+ accelerate
8
+ tensorboard
9
+ modelcards
10
+ omegaconf
11
+ einops
12
+ imageio
13
+ ftfy
14
+ opencv-python
15
+ ipywidgets
run_tuning.py → Video-P2P-Demo/run_tuning.py RENAMED
File without changes
run_videop2p.py → Video-P2P-Demo/run_videop2p.py RENAMED
File without changes
script.sh → Video-P2P-Demo/script.sh RENAMED
File without changes
seq_aligner.py → Video-P2P-Demo/seq_aligner.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/data/dataset.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/models/attention.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/models/resnet.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/models/unet.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/models/unet_blocks.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/pipelines/pipeline_tuneavideo.py RENAMED
File without changes
{tuneavideo → Video-P2P-Demo/tuneavideo}/util.py RENAMED
File without changes
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from subprocess import getoutput
7
+
8
+ import gradio as gr
9
+ import torch
10
+
11
+ # from app_inference import create_inference_demo
12
+ from app_training import create_training_demo
13
+ # from app_upload import create_upload_demo
14
+ from inference import InferencePipeline
15
+ from trainer import Trainer
16
+
17
+ TITLE = '# [Video-P2P](https://video-p2p.github.io/) UI'
18
+
19
+ ORIGINAL_SPACE_ID = 'Shaldon/Video-P2P-Training-UI'
20
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
+ GPU_DATA = getoutput('nvidia-smi')
22
+ SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
+
24
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
25
+ '''
26
+
27
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
+ else:
30
+ SETTINGS = 'Settings'
31
+
32
+ INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
33
+
34
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
35
+ <center>
36
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
37
+ You can use "T4 small/medium" to run this demo.
38
+ </center>
39
+ '''
40
+
41
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
42
+ <center>
43
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
44
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
45
+ </center>
46
+ '''
47
+
48
+ HF_TOKEN = os.getenv('HF_TOKEN')
49
+
50
+
51
+ def show_warning(warning_text: str) -> gr.Blocks:
52
+ with gr.Blocks() as demo:
53
+ with gr.Box():
54
+ gr.Markdown(warning_text)
55
+ return demo
56
+
57
+
58
+ pipe = InferencePipeline(HF_TOKEN)
59
+ trainer = Trainer(HF_TOKEN)
60
+
61
+ with gr.Blocks(css='style.css') as demo:
62
+ if SPACE_ID == ORIGINAL_SPACE_ID:
63
+ show_warning(SHARED_UI_WARNING)
64
+ elif not torch.cuda.is_available():
65
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
+ elif (not 'T4' in GPU_DATA):
67
+ show_warning(INVALID_GPU_WARNING)
68
+
69
+ gr.Markdown(TITLE)
70
+ with gr.Tabs():
71
+ with gr.TabItem('Train'):
72
+ create_training_demo(trainer, pipe)
73
+ # with gr.TabItem('Run'):
74
+ # create_inference_demo(pipe, HF_TOKEN)
75
+ # with gr.TabItem('Upload'):
76
+ # gr.Markdown('''
77
+ # - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
78
+ # ''')
79
+ # create_upload_demo(HF_TOKEN)
80
+
81
+ if not HF_TOKEN:
82
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
83
+
84
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from inference import InferencePipeline
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelSource(enum.Enum):
16
+ HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
+ LOCAL = 'Local'
18
+
19
+
20
+ class InferenceUtil:
21
+ def __init__(self, hf_token: str | None):
22
+ self.hf_token = hf_token
23
+
24
+ def load_hub_model_list(self) -> dict:
25
+ api = HfApi(token=self.hf_token)
26
+ choices = [
27
+ info.modelId
28
+ for info in api.list_models(author=MODEL_LIBRARY_ORG_NAME)
29
+ ]
30
+ return gr.update(choices=choices,
31
+ value=choices[0] if choices else None)
32
+
33
+ @staticmethod
34
+ def load_local_model_list() -> dict:
35
+ choices = find_exp_dirs()
36
+ return gr.update(choices=choices,
37
+ value=choices[0] if choices else None)
38
+
39
+ def reload_model_list(self, model_source: str) -> dict:
40
+ if model_source == ModelSource.HUB_LIB.value:
41
+ return self.load_hub_model_list()
42
+ elif model_source == ModelSource.LOCAL.value:
43
+ return self.load_local_model_list()
44
+ else:
45
+ raise ValueError
46
+
47
+ def load_model_info(self, model_id: str) -> tuple[str, str]:
48
+ try:
49
+ card = InferencePipeline.get_model_card(model_id, self.hf_token)
50
+ except Exception:
51
+ return '', ''
52
+ base_model = getattr(card.data, 'base_model', '')
53
+ training_prompt = getattr(card.data, 'training_prompt', '')
54
+ return base_model, training_prompt
55
+
56
+ def reload_model_list_and_update_model_info(
57
+ self, model_source: str) -> tuple[dict, str, str]:
58
+ model_list_update = self.reload_model_list(model_source)
59
+ model_list = model_list_update['choices']
60
+ model_info = self.load_model_info(model_list[0] if model_list else '')
61
+ return model_list_update, *model_info
62
+
63
+
64
+ def create_inference_demo(pipe: InferencePipeline,
65
+ hf_token: str | None = None) -> gr.Blocks:
66
+ app = InferenceUtil(hf_token)
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row():
70
+ with gr.Column():
71
+ with gr.Box():
72
+ model_source = gr.Radio(
73
+ label='Model Source',
74
+ choices=[_.value for _ in ModelSource],
75
+ value=ModelSource.HUB_LIB.value)
76
+ reload_button = gr.Button('Reload Model List')
77
+ model_id = gr.Dropdown(label='Model ID',
78
+ choices=None,
79
+ value=None)
80
+ with gr.Accordion(
81
+ label=
82
+ 'Model info (Base model and prompt used for training)',
83
+ open=False):
84
+ with gr.Row():
85
+ base_model_used_for_training = gr.Text(
86
+ label='Base model', interactive=False)
87
+ prompt_used_for_training = gr.Text(
88
+ label='Training prompt', interactive=False)
89
+ prompt = gr.Textbox(
90
+ label='Prompt',
91
+ max_lines=1,
92
+ placeholder='Example: "A panda is surfing"')
93
+ video_length = gr.Slider(label='Video length',
94
+ minimum=4,
95
+ maximum=12,
96
+ step=1,
97
+ value=8)
98
+ fps = gr.Slider(label='FPS',
99
+ minimum=1,
100
+ maximum=12,
101
+ step=1,
102
+ value=1)
103
+ seed = gr.Slider(label='Seed',
104
+ minimum=0,
105
+ maximum=100000,
106
+ step=1,
107
+ value=0)
108
+ with gr.Accordion('Other Parameters', open=False):
109
+ num_steps = gr.Slider(label='Number of Steps',
110
+ minimum=0,
111
+ maximum=100,
112
+ step=1,
113
+ value=50)
114
+ guidance_scale = gr.Slider(label='CFG Scale',
115
+ minimum=0,
116
+ maximum=50,
117
+ step=0.1,
118
+ value=7.5)
119
+
120
+ run_button = gr.Button('Generate')
121
+
122
+ gr.Markdown('''
123
+ - After training, you can press "Reload Model List" button to load your trained model names.
124
+ - It takes a few minutes to download model first.
125
+ - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
126
+ ''')
127
+ with gr.Column():
128
+ result = gr.Video(label='Result')
129
+
130
+ model_source.change(fn=app.reload_model_list_and_update_model_info,
131
+ inputs=model_source,
132
+ outputs=[
133
+ model_id,
134
+ base_model_used_for_training,
135
+ prompt_used_for_training,
136
+ ])
137
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
138
+ inputs=model_source,
139
+ outputs=[
140
+ model_id,
141
+ base_model_used_for_training,
142
+ prompt_used_for_training,
143
+ ])
144
+ model_id.change(fn=app.load_model_info,
145
+ inputs=model_id,
146
+ outputs=[
147
+ base_model_used_for_training,
148
+ prompt_used_for_training,
149
+ ])
150
+ inputs = [
151
+ model_id,
152
+ prompt,
153
+ video_length,
154
+ fps,
155
+ seed,
156
+ num_steps,
157
+ guidance_scale,
158
+ ]
159
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
+ return demo
162
+
163
+
164
+ if __name__ == '__main__':
165
+ import os
166
+
167
+ hf_token = os.getenv('HF_TOKEN')
168
+ pipe = InferencePipeline(hf_token)
169
+ demo = create_inference_demo(pipe, hf_token)
170
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ hf_token = os.getenv('HF_TOKEN')
17
+ with gr.Blocks() as demo:
18
+ with gr.Row():
19
+ with gr.Column():
20
+ with gr.Box():
21
+ gr.Markdown('Training Data')
22
+ training_video = gr.File(label='Training video')
23
+ training_prompt = gr.Textbox(
24
+ label='Training prompt',
25
+ max_lines=1,
26
+ placeholder='A rabbit is jumping on the grass')
27
+ gr.Markdown('''
28
+ - Upload a video and write a `Training Prompt` that describes the video.
29
+ ''')
30
+
31
+ with gr.Column():
32
+ with gr.Box():
33
+ gr.Markdown('Training Parameters')
34
+ with gr.Row():
35
+ base_model = gr.Text(
36
+ label='Base Model',
37
+ value='CompVis/stable-diffusion-v1-5',
38
+ max_lines=1)
39
+ resolution = gr.Dropdown(choices=['512', '768'],
40
+ value='512',
41
+ label='Resolution',
42
+ visible=False)
43
+
44
+ input_token = gr.Text(label='Hugging Face Write Token',
45
+ placeholder='',
46
+ visible=False if hf_token else True)
47
+ with gr.Accordion('Advanced settings', open=False):
48
+ num_training_steps = gr.Number(
49
+ label='Number of Training Steps',
50
+ value=300,
51
+ precision=0)
52
+ learning_rate = gr.Number(label='Learning Rate',
53
+ value=0.000035)
54
+ gradient_accumulation = gr.Number(
55
+ label='Number of Gradient Accumulation',
56
+ value=1,
57
+ precision=0)
58
+ seed = gr.Slider(label='Seed',
59
+ minimum=0,
60
+ maximum=100000,
61
+ step=1,
62
+ randomize=True,
63
+ value=0)
64
+ fp16 = gr.Checkbox(label='FP16', value=True)
65
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
66
+ value=False)
67
+ checkpointing_steps = gr.Number(
68
+ label='Checkpointing Steps',
69
+ value=1000,
70
+ precision=0)
71
+ validation_epochs = gr.Number(
72
+ label='Validation Epochs', value=100, precision=0)
73
+ gr.Markdown('''
74
+ - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
75
+ - Expected time to train a model for 300 steps: ~20 minutes with T4
76
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
77
+ ''')
78
+
79
+ with gr.Row():
80
+ with gr.Column():
81
+ gr.Markdown('Output Model')
82
+ output_model_name = gr.Text(label='Name of your model',
83
+ placeholder='The surfer man',
84
+ max_lines=1)
85
+ validation_prompt = gr.Text(
86
+ label='Validation Prompt',
87
+ placeholder=
88
+ 'prompt to test the model, e.g: a dog is surfing')
89
+ with gr.Column():
90
+ gr.Markdown('Upload Settings')
91
+ with gr.Row():
92
+ upload_to_hub = gr.Checkbox(label='Upload model to Hub',
93
+ value=True)
94
+ use_private_repo = gr.Checkbox(label='Private', value=True)
95
+ delete_existing_repo = gr.Checkbox(
96
+ label='Delete existing repo of the same name',
97
+ value=False)
98
+ upload_to = gr.Radio(
99
+ label='Upload to',
100
+ choices=[_.value for _ in UploadTarget],
101
+ value=UploadTarget.MODEL_LIBRARY.value)
102
+
103
+ remove_gpu_after_training = gr.Checkbox(
104
+ label='Remove GPU after training',
105
+ value=False,
106
+ interactive=bool(os.getenv('SPACE_ID')),
107
+ visible=False)
108
+ run_button = gr.Button('Start Training')
109
+
110
+ with gr.Box():
111
+ gr.Markdown('Output message')
112
+ output_message = gr.Markdown()
113
+
114
+ if pipe is not None:
115
+ run_button.click(fn=pipe.clear)
116
+ run_button.click(
117
+ fn=trainer.run,
118
+ inputs=[
119
+ training_video, training_prompt, output_model_name,
120
+ delete_existing_repo, validation_prompt, base_model,
121
+ resolution, num_training_steps, learning_rate,
122
+ gradient_accumulation, seed, fp16, use_8bit_adam,
123
+ checkpointing_steps, validation_epochs, upload_to_hub,
124
+ use_private_repo, delete_existing_repo, upload_to,
125
+ remove_gpu_after_training, input_token
126
+ ],
127
+ outputs=output_message)
128
+ return demo
129
+
130
+
131
+ if __name__ == '__main__':
132
+ hf_token = os.getenv('HF_TOKEN')
133
+ trainer = Trainer(hf_token)
134
+ demo = create_training_demo(trainer)
135
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class ModelUploader(Uploader):
16
+ def upload_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ input_token: str | None = None,
24
+ ) -> str:
25
+ if not folder_path:
26
+ raise ValueError
27
+ if not repo_name:
28
+ repo_name = pathlib.Path(folder_path).name
29
+ repo_name = slugify.slugify(repo_name)
30
+
31
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
32
+ organization = ''
33
+ elif upload_to == UploadTarget.MODEL_LIBRARY.value:
34
+ organization = MODEL_LIBRARY_ORG_NAME
35
+ else:
36
+ raise ValueError
37
+
38
+ return self.upload(folder_path,
39
+ repo_name,
40
+ organization=organization,
41
+ private=private,
42
+ delete_existing_repo=delete_existing_repo,
43
+ input_token=input_token)
44
+
45
+
46
+ def load_local_model_list() -> dict:
47
+ choices = find_exp_dirs()
48
+ return gr.update(choices=choices, value=choices[0] if choices else None)
49
+
50
+
51
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
52
+ uploader = ModelUploader(hf_token)
53
+ model_dirs = find_exp_dirs()
54
+
55
+ with gr.Blocks() as demo:
56
+ with gr.Box():
57
+ gr.Markdown('Local Models')
58
+ reload_button = gr.Button('Reload Model List')
59
+ model_dir = gr.Dropdown(
60
+ label='Model names',
61
+ choices=model_dirs,
62
+ value=model_dirs[0] if model_dirs else None)
63
+ with gr.Box():
64
+ gr.Markdown('Upload Settings')
65
+ with gr.Row():
66
+ use_private_repo = gr.Checkbox(label='Private', value=True)
67
+ delete_existing_repo = gr.Checkbox(
68
+ label='Delete existing repo of the same name', value=False)
69
+ upload_to = gr.Radio(label='Upload to',
70
+ choices=[_.value for _ in UploadTarget],
71
+ value=UploadTarget.MODEL_LIBRARY.value)
72
+ model_name = gr.Textbox(label='Model Name')
73
+ input_token = gr.Text(label='Hugging Face Write Token',
74
+ placeholder='',
75
+ visible=False if hf_token else True)
76
+ upload_button = gr.Button('Upload')
77
+ gr.Markdown(f'''
78
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
79
+ ''')
80
+ with gr.Box():
81
+ gr.Markdown('Output message')
82
+ output_message = gr.Markdown()
83
+
84
+ reload_button.click(fn=load_local_model_list,
85
+ inputs=None,
86
+ outputs=model_dir)
87
+ upload_button.click(fn=uploader.upload_model,
88
+ inputs=[
89
+ model_dir,
90
+ model_name,
91
+ upload_to,
92
+ use_private_repo,
93
+ delete_existing_repo,
94
+ input_token,
95
+ ],
96
+ outputs=output_message)
97
+
98
+ return demo
99
+
100
+
101
+ if __name__ == '__main__':
102
+ import os
103
+
104
+ hf_token = os.getenv('HF_TOKEN')
105
+ demo = create_upload_demo(hf_token)
106
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ MODEL_LIBRARY = 'Tune-A-Video Library'
7
+
8
+
9
+ MODEL_LIBRARY_ORG_NAME = 'Tune-A-Video-library'
10
+ SAMPLE_MODEL_REPO = 'Tune-A-Video-library/a-man-is-surfing'
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+ import sys
6
+ import tempfile
7
+
8
+ import gradio as gr
9
+ import imageio
10
+ import PIL.Image
11
+ import torch
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from einops import rearrange
14
+ from huggingface_hub import ModelCard
15
+
16
+ sys.path.append('Tune-A-Video')
17
+
18
+ from tuneavideo.models.unet import UNet3DConditionModel
19
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
20
+
21
+
22
+ class InferencePipeline:
23
+ def __init__(self, hf_token: str | None = None):
24
+ self.hf_token = hf_token
25
+ self.pipe = None
26
+ self.device = torch.device(
27
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
28
+ self.model_id = None
29
+
30
+ def clear(self) -> None:
31
+ self.model_id = None
32
+ del self.pipe
33
+ self.pipe = None
34
+ torch.cuda.empty_cache()
35
+ gc.collect()
36
+
37
+ @staticmethod
38
+ def check_if_model_is_local(model_id: str) -> bool:
39
+ return pathlib.Path(model_id).exists()
40
+
41
+ @staticmethod
42
+ def get_model_card(model_id: str,
43
+ hf_token: str | None = None) -> ModelCard:
44
+ if InferencePipeline.check_if_model_is_local(model_id):
45
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
46
+ else:
47
+ card_path = model_id
48
+ return ModelCard.load(card_path, token=hf_token)
49
+
50
+ @staticmethod
51
+ def get_base_model_info(model_id: str, hf_token: str | None = None) -> str:
52
+ card = InferencePipeline.get_model_card(model_id, hf_token)
53
+ return card.data.base_model
54
+
55
+ def load_pipe(self, model_id: str) -> None:
56
+ if model_id == self.model_id:
57
+ return
58
+ base_model_id = self.get_base_model_info(model_id, self.hf_token)
59
+ unet = UNet3DConditionModel.from_pretrained(
60
+ model_id,
61
+ subfolder='unet',
62
+ torch_dtype=torch.float16,
63
+ use_auth_token=self.hf_token)
64
+ pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
65
+ unet=unet,
66
+ torch_dtype=torch.float16,
67
+ use_auth_token=self.hf_token)
68
+ pipe = pipe.to(self.device)
69
+ if is_xformers_available():
70
+ pipe.unet.enable_xformers_memory_efficient_attention()
71
+ self.pipe = pipe
72
+ self.model_id = model_id # type: ignore
73
+
74
+ def run(
75
+ self,
76
+ model_id: str,
77
+ prompt: str,
78
+ video_length: int,
79
+ fps: int,
80
+ seed: int,
81
+ n_steps: int,
82
+ guidance_scale: float,
83
+ ) -> PIL.Image.Image:
84
+ if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
+
87
+ self.load_pipe(model_id)
88
+
89
+ generator = torch.Generator(device=self.device).manual_seed(seed)
90
+ out = self.pipe(
91
+ prompt,
92
+ video_length=video_length,
93
+ width=512,
94
+ height=512,
95
+ num_inference_steps=n_steps,
96
+ guidance_scale=guidance_scale,
97
+ generator=generator,
98
+ ) # type: ignore
99
+
100
+ frames = rearrange(out.videos[0], 'c t h w -> t h w c')
101
+ frames = (frames * 255).to(torch.uint8).numpy()
102
+
103
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
104
+ writer = imageio.get_writer(out_file.name, fps=fps)
105
+ for frame in frames:
106
+ writer.append_data(frame)
107
+ writer.close()
108
+
109
+ return out_file.name
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
patch ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/train_tuneavideo.py b/train_tuneavideo.py
2
+ index 66d51b2..86b2a5d 100644
3
+ --- a/train_tuneavideo.py
4
+ +++ b/train_tuneavideo.py
5
+ @@ -94,8 +94,8 @@ def main(
6
+
7
+ # Handle the output folder creation
8
+ if accelerator.is_main_process:
9
+ - now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
10
+ - output_dir = os.path.join(output_dir, now)
11
+ + #now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
12
+ + #output_dir = os.path.join(output_dir, now)
13
+ os.makedirs(output_dir, exist_ok=True)
14
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
15
+
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
trainer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+ from omegaconf import OmegaConf
16
+
17
+ from app_upload import ModelUploader
18
+ from utils import save_model_card
19
+
20
+ sys.path.append('Video-P2P-Demo')
21
+
22
+ # URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
+ ORIGINAL_SPACE_ID = 'Shaldon/Video-P2P-Training-UI'
24
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
+
26
+
27
+ class Trainer:
28
+ def __init__(self, hf_token: str | None = None):
29
+ self.hf_token = hf_token
30
+ self.model_uploader = ModelUploader(hf_token)
31
+
32
+ self.checkpoint_dir = pathlib.Path('checkpoints')
33
+ self.checkpoint_dir.mkdir(exist_ok=True)
34
+
35
+ def download_base_model(self, base_model_id: str) -> str:
36
+ model_dir = self.checkpoint_dir / base_model_id
37
+ if not model_dir.exists():
38
+ org_name = base_model_id.split('/')[0]
39
+ org_dir = self.checkpoint_dir / org_name
40
+ org_dir.mkdir(exist_ok=True)
41
+ subprocess.run(shlex.split(
42
+ f'git clone https://huggingface.co/{base_model_id}'),
43
+ cwd=org_dir)
44
+ return model_dir.as_posix()
45
+
46
+ # def join_model_library_org(self, token: str) -> None:
47
+ # subprocess.run(
48
+ # shlex.split(
49
+ # f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
50
+ # ))
51
+
52
+ def run(
53
+ self,
54
+ training_video: str,
55
+ training_prompt: str,
56
+ output_model_name: str,
57
+ overwrite_existing_model: bool,
58
+ validation_prompt: str,
59
+ base_model: str,
60
+ resolution_s: str,
61
+ n_steps: int,
62
+ learning_rate: float,
63
+ gradient_accumulation: int,
64
+ seed: int,
65
+ fp16: bool,
66
+ use_8bit_adam: bool,
67
+ checkpointing_steps: int,
68
+ validation_epochs: int,
69
+ upload_to_hub: bool,
70
+ use_private_repo: bool,
71
+ delete_existing_repo: bool,
72
+ upload_to: str,
73
+ remove_gpu_after_training: bool,
74
+ input_token: str,
75
+ ) -> str:
76
+ if SPACE_ID == ORIGINAL_SPACE_ID:
77
+ raise gr.Error(
78
+ 'This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU'
79
+ )
80
+ if not torch.cuda.is_available():
81
+ raise gr.Error('CUDA is not available.')
82
+ if training_video is None:
83
+ raise gr.Error('You need to upload a video.')
84
+ if not training_prompt:
85
+ raise gr.Error('The training prompt is missing.')
86
+ if not validation_prompt:
87
+ raise gr.Error('The validation prompt is missing.')
88
+
89
+ resolution = int(resolution_s)
90
+
91
+ if not output_model_name:
92
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
93
+ output_model_name = f'video-p2p-{timestamp}'
94
+ output_model_name = slugify.slugify(output_model_name)
95
+
96
+ repo_dir = pathlib.Path(__file__).parent
97
+ output_dir = repo_dir / 'experiments' / output_model_name
98
+ if overwrite_existing_model or upload_to_hub:
99
+ shutil.rmtree(output_dir, ignore_errors=True)
100
+ output_dir.mkdir(parents=True)
101
+
102
+ # if upload_to_hub:
103
+ # self.join_model_library_org(
104
+ # self.hf_token if self.hf_token else input_token)
105
+
106
+ config = OmegaConf.load('Video-P2P-Demo/configs/rabbit-jump-tune.yaml')
107
+ config.pretrained_model_path = self.download_base_model(base_model)
108
+ config.output_dir = output_dir.as_posix()
109
+ config.train_data.video_path = training_video.name # type: ignore
110
+ config.train_data.prompt = training_prompt
111
+ config.train_data.n_sample_frames = 8
112
+ config.train_data.width = resolution
113
+ config.train_data.height = resolution
114
+ config.train_data.sample_start_idx = 0
115
+ config.train_data.sample_frame_rate = 1
116
+ config.validation_data.prompts = [validation_prompt]
117
+ config.validation_data.video_length = 8
118
+ config.validation_data.width = resolution
119
+ config.validation_data.height = resolution
120
+ config.validation_data.num_inference_steps = 50
121
+ config.validation_data.guidance_scale = 7.5
122
+ config.learning_rate = learning_rate
123
+ config.gradient_accumulation_steps = gradient_accumulation
124
+ config.train_batch_size = 1
125
+ config.max_train_steps = n_steps
126
+ config.checkpointing_steps = checkpointing_steps
127
+ config.validation_steps = validation_epochs
128
+ config.seed = seed
129
+ config.mixed_precision = 'fp16' if fp16 else ''
130
+ config.use_8bit_adam = use_8bit_adam
131
+
132
+ config_path = output_dir / 'config.yaml'
133
+ with open(config_path, 'w') as f:
134
+ OmegaConf.save(config, f)
135
+
136
+ command = f'accelerate launch Video-P2P-Demo/train_tuneavideo.py --config {config_path}'
137
+ subprocess.run(shlex.split(command))
138
+ save_model_card(save_dir=output_dir,
139
+ base_model=base_model,
140
+ training_prompt=training_prompt,
141
+ test_prompt=validation_prompt,
142
+ test_image_dir='samples')
143
+
144
+ message = 'Training completed!'
145
+ print(message)
146
+
147
+ if upload_to_hub:
148
+ upload_message = self.model_uploader.upload_model(
149
+ folder_path=output_dir.as_posix(),
150
+ repo_name=output_model_name,
151
+ upload_to=upload_to,
152
+ private=use_private_repo,
153
+ delete_existing_repo=delete_existing_repo,
154
+ input_token=input_token)
155
+ print(upload_message)
156
+ message = message + '\n' + upload_message
157
+
158
+ if remove_gpu_after_training:
159
+ space_id = os.getenv('SPACE_ID')
160
+ if space_id:
161
+ api = HfApi(
162
+ token=self.hf_token if self.hf_token else input_token)
163
+ api.request_space_hardware(repo_id=space_id,
164
+ hardware='cpu-basic')
165
+
166
+ return message