hysts HF staff committed on
Commit
d0ad885
1 Parent(s): fff06c1

Update for inference

Browse files
Files changed (7) hide show
  1. app.py +0 -84
  2. app_inference.py +89 -128
  3. app_training.py +0 -135
  4. app_upload.py +0 -106
  5. trainer.py +0 -166
  6. uploader.py +0 -44
  7. utils.py +0 -65
app.py DELETED
@@ -1,84 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
- from subprocess import getoutput
7
-
8
- import gradio as gr
9
- import torch
10
-
11
- from app_inference import create_inference_demo
12
- from app_training import create_training_demo
13
- from app_upload import create_upload_demo
14
- from inference import InferencePipeline
15
- from trainer import Trainer
16
-
17
- TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/) UI'
18
-
19
- ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
20
- SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
- GPU_DATA = getoutput('nvidia-smi')
22
- SHARED_UI_WARNING = f'''## Attention - Training doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
23
-
24
- <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
25
- '''
26
-
27
- if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
28
- SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
29
- else:
30
- SETTINGS = 'Settings'
31
-
32
- INVALID_GPU_WARNING = f'''## Attention - the specified GPU is invalid. Training may not work. Make sure you have selected a `T4 GPU` for this task.'''
33
-
34
- CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
35
- <center>
36
- You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
37
- You can use "T4 small/medium" to run this demo.
38
- </center>
39
- '''
40
-
41
- HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
42
- <center>
43
- You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
44
- You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
45
- </center>
46
- '''
47
-
48
- HF_TOKEN = os.getenv('HF_TOKEN')
49
-
50
-
51
- def show_warning(warning_text: str) -> gr.Blocks:
52
- with gr.Blocks() as demo:
53
- with gr.Box():
54
- gr.Markdown(warning_text)
55
- return demo
56
-
57
-
58
- pipe = InferencePipeline(HF_TOKEN)
59
- trainer = Trainer(HF_TOKEN)
60
-
61
- with gr.Blocks(css='style.css') as demo:
62
- if SPACE_ID == ORIGINAL_SPACE_ID:
63
- show_warning(SHARED_UI_WARNING)
64
- elif not torch.cuda.is_available():
65
- show_warning(CUDA_NOT_AVAILABLE_WARNING)
66
- elif (not 'T4' in GPU_DATA):
67
- show_warning(INVALID_GPU_WARNING)
68
-
69
- gr.Markdown(TITLE)
70
- with gr.Tabs():
71
- with gr.TabItem('Train'):
72
- create_training_demo(trainer, pipe)
73
- with gr.TabItem('Run'):
74
- create_inference_demo(pipe, HF_TOKEN)
75
- with gr.TabItem('Upload'):
76
- gr.Markdown('''
77
- - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
78
- ''')
79
- create_upload_demo(HF_TOKEN)
80
-
81
- if not HF_TOKEN:
82
- show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
83
-
84
- demo.queue(max_size=1).launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_inference.py CHANGED
@@ -2,19 +2,13 @@
2
 
3
  from __future__ import annotations
4
 
5
- import enum
6
 
7
  import gradio as gr
8
  from huggingface_hub import HfApi
9
 
10
- from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
  from inference import InferencePipeline
12
- from utils import find_exp_dirs
13
-
14
-
15
- class ModelSource(enum.Enum):
16
- HUB_LIB = UploadTarget.MODEL_LIBRARY.value
17
- LOCAL = 'Local'
18
 
19
 
20
  class InferenceUtil:
@@ -30,20 +24,6 @@ class InferenceUtil:
30
  return gr.update(choices=choices,
31
  value=choices[0] if choices else None)
32
 
33
- @staticmethod
34
- def load_local_model_list() -> dict:
35
- choices = find_exp_dirs()
36
- return gr.update(choices=choices,
37
- value=choices[0] if choices else None)
38
-
39
- def reload_model_list(self, model_source: str) -> dict:
40
- if model_source == ModelSource.HUB_LIB.value:
41
- return self.load_hub_model_list()
42
- elif model_source == ModelSource.LOCAL.value:
43
- return self.load_local_model_list()
44
- else:
45
- raise ValueError
46
-
47
  def load_model_info(self, model_id: str) -> tuple[str, str]:
48
  try:
49
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
@@ -53,118 +33,99 @@ class InferenceUtil:
53
  training_prompt = getattr(card.data, 'training_prompt', '')
54
  return base_model, training_prompt
55
 
56
- def reload_model_list_and_update_model_info(
57
- self, model_source: str) -> tuple[dict, str, str]:
58
- model_list_update = self.reload_model_list(model_source)
59
  model_list = model_list_update['choices']
60
  model_info = self.load_model_info(model_list[0] if model_list else '')
61
  return model_list_update, *model_info
62
 
63
 
64
- def create_inference_demo(pipe: InferencePipeline,
65
- hf_token: str | None = None) -> gr.Blocks:
66
- app = InferenceUtil(hf_token)
67
-
68
- with gr.Blocks() as demo:
69
- with gr.Row():
70
- with gr.Column():
71
- with gr.Box():
72
- model_source = gr.Radio(
73
- label='Model Source',
74
- choices=[_.value for _ in ModelSource],
75
- value=ModelSource.HUB_LIB.value)
76
- reload_button = gr.Button('Reload Model List')
77
- model_id = gr.Dropdown(label='Model ID',
78
- choices=None,
79
- value=None)
80
- with gr.Accordion(
81
- label=
82
- 'Model info (Base model and prompt used for training)',
83
- open=False):
84
- with gr.Row():
85
- base_model_used_for_training = gr.Text(
86
- label='Base model', interactive=False)
87
- prompt_used_for_training = gr.Text(
88
- label='Training prompt', interactive=False)
89
- prompt = gr.Textbox(
90
- label='Prompt',
91
- max_lines=1,
92
- placeholder='Example: "A panda is surfing"')
93
- video_length = gr.Slider(label='Video length',
94
- minimum=4,
95
- maximum=12,
96
- step=1,
97
- value=8)
98
- fps = gr.Slider(label='FPS',
99
- minimum=1,
100
- maximum=12,
101
- step=1,
102
- value=1)
103
- seed = gr.Slider(label='Seed',
104
- minimum=0,
105
- maximum=100000,
106
- step=1,
107
- value=0)
108
- with gr.Accordion('Other Parameters', open=False):
109
- num_steps = gr.Slider(label='Number of Steps',
110
- minimum=0,
111
- maximum=100,
112
- step=1,
113
- value=50)
114
- guidance_scale = gr.Slider(label='CFG Scale',
115
- minimum=0,
116
- maximum=50,
117
- step=0.1,
118
- value=7.5)
119
-
120
- run_button = gr.Button('Generate')
121
-
122
- gr.Markdown('''
123
- - After training, you can press "Reload Model List" button to load your trained model names.
124
- - It takes a few minutes to download model first.
125
- - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
126
- ''')
127
- with gr.Column():
128
- result = gr.Video(label='Result')
129
-
130
- model_source.change(fn=app.reload_model_list_and_update_model_info,
131
- inputs=model_source,
132
- outputs=[
133
- model_id,
134
- base_model_used_for_training,
135
- prompt_used_for_training,
136
- ])
137
- reload_button.click(fn=app.reload_model_list_and_update_model_info,
138
- inputs=model_source,
139
- outputs=[
140
- model_id,
141
- base_model_used_for_training,
142
- prompt_used_for_training,
143
- ])
144
- model_id.change(fn=app.load_model_info,
145
- inputs=model_id,
146
  outputs=[
 
147
  base_model_used_for_training,
148
  prompt_used_for_training,
149
  ])
150
- inputs = [
151
- model_id,
152
- prompt,
153
- video_length,
154
- fps,
155
- seed,
156
- num_steps,
157
- guidance_scale,
158
- ]
159
- prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
160
- run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
161
- return demo
162
-
163
-
164
- if __name__ == '__main__':
165
- import os
166
-
167
- hf_token = os.getenv('HF_TOKEN')
168
- pipe = InferencePipeline(hf_token)
169
- demo = create_inference_demo(pipe, hf_token)
170
- demo.queue(max_size=10).launch(share=False)
 
2
 
3
  from __future__ import annotations
4
 
5
+ import os
6
 
7
  import gradio as gr
8
  from huggingface_hub import HfApi
9
 
10
+ from constants import MODEL_LIBRARY_ORG_NAME
11
  from inference import InferencePipeline
 
 
 
 
 
 
12
 
13
 
14
  class InferenceUtil:
 
24
  return gr.update(choices=choices,
25
  value=choices[0] if choices else None)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def load_model_info(self, model_id: str) -> tuple[str, str]:
28
  try:
29
  card = InferencePipeline.get_model_card(model_id, self.hf_token)
 
33
  training_prompt = getattr(card.data, 'training_prompt', '')
34
  return base_model, training_prompt
35
 
36
+ def reload_model_list_and_update_model_info(self) -> tuple[dict, str, str]:
37
+ model_list_update = self.load_hub_model_list()
 
38
  model_list = model_list_update['choices']
39
  model_info = self.load_model_info(model_list[0] if model_list else '')
40
  return model_list_update, *model_info
41
 
42
 
43
+ TITLE = '# [Tune-A-Video](https://tuneavideo.github.io/)'
44
+ HF_TOKEN = os.getenv('HF_TOKEN')
45
+ pipe = InferencePipeline(HF_TOKEN)
46
+ app = InferenceUtil(HF_TOKEN)
47
+
48
+ with gr.Blocks(css='style.css') as demo:
49
+ gr.Markdown(TITLE)
50
+
51
+ with gr.Row():
52
+ with gr.Column():
53
+ with gr.Box():
54
+ reload_button = gr.Button('Reload Model List')
55
+ model_id = gr.Dropdown(label='Model ID',
56
+ choices=None,
57
+ value=None)
58
+ with gr.Accordion(
59
+ label=
60
+ 'Model info (Base model and prompt used for training)',
61
+ open=False):
62
+ with gr.Row():
63
+ base_model_used_for_training = gr.Text(
64
+ label='Base model', interactive=False)
65
+ prompt_used_for_training = gr.Text(
66
+ label='Training prompt', interactive=False)
67
+ prompt = gr.Textbox(label='Prompt',
68
+ max_lines=1,
69
+ placeholder='Example: "A panda is surfing"')
70
+ video_length = gr.Slider(label='Video length',
71
+ minimum=4,
72
+ maximum=12,
73
+ step=1,
74
+ value=8)
75
+ fps = gr.Slider(label='FPS',
76
+ minimum=1,
77
+ maximum=12,
78
+ step=1,
79
+ value=1)
80
+ seed = gr.Slider(label='Seed',
81
+ minimum=0,
82
+ maximum=100000,
83
+ step=1,
84
+ value=0)
85
+ with gr.Accordion('Other Parameters', open=False):
86
+ num_steps = gr.Slider(label='Number of Steps',
87
+ minimum=0,
88
+ maximum=100,
89
+ step=1,
90
+ value=50)
91
+ guidance_scale = gr.Slider(label='CFG Scale',
92
+ minimum=0,
93
+ maximum=50,
94
+ step=0.1,
95
+ value=7.5)
96
+
97
+ run_button = gr.Button('Generate')
98
+
99
+ gr.Markdown('''
100
+ - It takes a few minutes to download model first.
101
+ - Expected time to generate an 8-frame video: 70 seconds with T4, 24 seconds with A10G, (10 seconds with A100)
102
+ ''')
103
+ with gr.Column():
104
+ result = gr.Video(label='Result')
105
+
106
+ reload_button.click(fn=app.reload_model_list_and_update_model_info,
107
+ inputs=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  outputs=[
109
+ model_id,
110
  base_model_used_for_training,
111
  prompt_used_for_training,
112
  ])
113
+ model_id.change(fn=app.load_model_info,
114
+ inputs=model_id,
115
+ outputs=[
116
+ base_model_used_for_training,
117
+ prompt_used_for_training,
118
+ ])
119
+ inputs = [
120
+ model_id,
121
+ prompt,
122
+ video_length,
123
+ fps,
124
+ seed,
125
+ num_steps,
126
+ guidance_scale,
127
+ ]
128
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
129
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
130
+
131
+ demo.queue().launch()
 
 
app_training.py DELETED
@@ -1,135 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
-
7
- import gradio as gr
8
-
9
- from constants import MODEL_LIBRARY_ORG_NAME, SAMPLE_MODEL_REPO, UploadTarget
10
- from inference import InferencePipeline
11
- from trainer import Trainer
12
-
13
-
14
- def create_training_demo(trainer: Trainer,
15
- pipe: InferencePipeline | None = None) -> gr.Blocks:
16
- hf_token = os.getenv('HF_TOKEN')
17
- with gr.Blocks() as demo:
18
- with gr.Row():
19
- with gr.Column():
20
- with gr.Box():
21
- gr.Markdown('Training Data')
22
- training_video = gr.File(label='Training video')
23
- training_prompt = gr.Textbox(
24
- label='Training prompt',
25
- max_lines=1,
26
- placeholder='A man is surfing')
27
- gr.Markdown('''
28
- - Upload a video and write a `Training Prompt` that describes the video.
29
- ''')
30
-
31
- with gr.Column():
32
- with gr.Box():
33
- gr.Markdown('Training Parameters')
34
- with gr.Row():
35
- base_model = gr.Text(
36
- label='Base Model',
37
- value='CompVis/stable-diffusion-v1-4',
38
- max_lines=1)
39
- resolution = gr.Dropdown(choices=['512', '768'],
40
- value='512',
41
- label='Resolution',
42
- visible=False)
43
-
44
- input_token = gr.Text(label='Hugging Face Write Token',
45
- placeholder='',
46
- visible=False if hf_token else True)
47
- with gr.Accordion('Advanced settings', open=False):
48
- num_training_steps = gr.Number(
49
- label='Number of Training Steps',
50
- value=300,
51
- precision=0)
52
- learning_rate = gr.Number(label='Learning Rate',
53
- value=0.000035)
54
- gradient_accumulation = gr.Number(
55
- label='Number of Gradient Accumulation',
56
- value=1,
57
- precision=0)
58
- seed = gr.Slider(label='Seed',
59
- minimum=0,
60
- maximum=100000,
61
- step=1,
62
- randomize=True,
63
- value=0)
64
- fp16 = gr.Checkbox(label='FP16', value=True)
65
- use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
66
- value=False)
67
- checkpointing_steps = gr.Number(
68
- label='Checkpointing Steps',
69
- value=1000,
70
- precision=0)
71
- validation_epochs = gr.Number(
72
- label='Validation Epochs', value=100, precision=0)
73
- gr.Markdown('''
74
- - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
75
- - Expected time to train a model for 300 steps: ~20 minutes with T4
76
- - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
77
- ''')
78
-
79
- with gr.Row():
80
- with gr.Column():
81
- gr.Markdown('Output Model')
82
- output_model_name = gr.Text(label='Name of your model',
83
- placeholder='The surfer man',
84
- max_lines=1)
85
- validation_prompt = gr.Text(
86
- label='Validation Prompt',
87
- placeholder=
88
- 'prompt to test the model, e.g: a dog is surfing')
89
- with gr.Column():
90
- gr.Markdown('Upload Settings')
91
- with gr.Row():
92
- upload_to_hub = gr.Checkbox(label='Upload model to Hub',
93
- value=True)
94
- use_private_repo = gr.Checkbox(label='Private', value=True)
95
- delete_existing_repo = gr.Checkbox(
96
- label='Delete existing repo of the same name',
97
- value=False)
98
- upload_to = gr.Radio(
99
- label='Upload to',
100
- choices=[_.value for _ in UploadTarget],
101
- value=UploadTarget.MODEL_LIBRARY.value)
102
-
103
- remove_gpu_after_training = gr.Checkbox(
104
- label='Remove GPU after training',
105
- value=False,
106
- interactive=bool(os.getenv('SPACE_ID')),
107
- visible=False)
108
- run_button = gr.Button('Start Training')
109
-
110
- with gr.Box():
111
- gr.Markdown('Output message')
112
- output_message = gr.Markdown()
113
-
114
- if pipe is not None:
115
- run_button.click(fn=pipe.clear)
116
- run_button.click(
117
- fn=trainer.run,
118
- inputs=[
119
- training_video, training_prompt, output_model_name,
120
- delete_existing_repo, validation_prompt, base_model,
121
- resolution, num_training_steps, learning_rate,
122
- gradient_accumulation, seed, fp16, use_8bit_adam,
123
- checkpointing_steps, validation_epochs, upload_to_hub,
124
- use_private_repo, delete_existing_repo, upload_to,
125
- remove_gpu_after_training, input_token
126
- ],
127
- outputs=output_message)
128
- return demo
129
-
130
-
131
- if __name__ == '__main__':
132
- hf_token = os.getenv('HF_TOKEN')
133
- trainer = Trainer(hf_token)
134
- demo = create_training_demo(trainer)
135
- demo.queue(max_size=1).launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app_upload.py DELETED
@@ -1,106 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- from __future__ import annotations
4
-
5
- import pathlib
6
-
7
- import gradio as gr
8
- import slugify
9
-
10
- from constants import MODEL_LIBRARY_ORG_NAME, UploadTarget
11
- from uploader import Uploader
12
- from utils import find_exp_dirs
13
-
14
-
15
- class ModelUploader(Uploader):
16
- def upload_model(
17
- self,
18
- folder_path: str,
19
- repo_name: str,
20
- upload_to: str,
21
- private: bool,
22
- delete_existing_repo: bool,
23
- input_token: str | None = None,
24
- ) -> str:
25
- if not folder_path:
26
- raise ValueError
27
- if not repo_name:
28
- repo_name = pathlib.Path(folder_path).name
29
- repo_name = slugify.slugify(repo_name)
30
-
31
- if upload_to == UploadTarget.PERSONAL_PROFILE.value:
32
- organization = ''
33
- elif upload_to == UploadTarget.MODEL_LIBRARY.value:
34
- organization = MODEL_LIBRARY_ORG_NAME
35
- else:
36
- raise ValueError
37
-
38
- return self.upload(folder_path,
39
- repo_name,
40
- organization=organization,
41
- private=private,
42
- delete_existing_repo=delete_existing_repo,
43
- input_token=input_token)
44
-
45
-
46
- def load_local_model_list() -> dict:
47
- choices = find_exp_dirs()
48
- return gr.update(choices=choices, value=choices[0] if choices else None)
49
-
50
-
51
- def create_upload_demo(hf_token: str | None) -> gr.Blocks:
52
- uploader = ModelUploader(hf_token)
53
- model_dirs = find_exp_dirs()
54
-
55
- with gr.Blocks() as demo:
56
- with gr.Box():
57
- gr.Markdown('Local Models')
58
- reload_button = gr.Button('Reload Model List')
59
- model_dir = gr.Dropdown(
60
- label='Model names',
61
- choices=model_dirs,
62
- value=model_dirs[0] if model_dirs else None)
63
- with gr.Box():
64
- gr.Markdown('Upload Settings')
65
- with gr.Row():
66
- use_private_repo = gr.Checkbox(label='Private', value=True)
67
- delete_existing_repo = gr.Checkbox(
68
- label='Delete existing repo of the same name', value=False)
69
- upload_to = gr.Radio(label='Upload to',
70
- choices=[_.value for _ in UploadTarget],
71
- value=UploadTarget.MODEL_LIBRARY.value)
72
- model_name = gr.Textbox(label='Model Name')
73
- input_token = gr.Text(label='Hugging Face Write Token',
74
- placeholder='',
75
- visible=False if hf_token else True)
76
- upload_button = gr.Button('Upload')
77
- gr.Markdown(f'''
78
- - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{{your_username}}/{{model_name}}) or to the public [Tune-A-Video Library](https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}) (i.e. https://huggingface.co/{MODEL_LIBRARY_ORG_NAME}/{{model_name}}).
79
- ''')
80
- with gr.Box():
81
- gr.Markdown('Output message')
82
- output_message = gr.Markdown()
83
-
84
- reload_button.click(fn=load_local_model_list,
85
- inputs=None,
86
- outputs=model_dir)
87
- upload_button.click(fn=uploader.upload_model,
88
- inputs=[
89
- model_dir,
90
- model_name,
91
- upload_to,
92
- use_private_repo,
93
- delete_existing_repo,
94
- input_token,
95
- ],
96
- outputs=output_message)
97
-
98
- return demo
99
-
100
-
101
- if __name__ == '__main__':
102
- import os
103
-
104
- hf_token = os.getenv('HF_TOKEN')
105
- demo = create_upload_demo(hf_token)
106
- demo.queue(max_size=1).launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
trainer.py DELETED
@@ -1,166 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import datetime
4
- import os
5
- import pathlib
6
- import shlex
7
- import shutil
8
- import subprocess
9
- import sys
10
-
11
- import gradio as gr
12
- import slugify
13
- import torch
14
- from huggingface_hub import HfApi
15
- from omegaconf import OmegaConf
16
-
17
- from app_upload import ModelUploader
18
- from utils import save_model_card
19
-
20
- sys.path.append('Tune-A-Video')
21
-
22
- URL_TO_JOIN_MODEL_LIBRARY_ORG = 'https://huggingface.co/organizations/Tune-A-Video-library/share/YjTcaNJmKyeHFpMBioHhzBcTzCYddVErEk'
23
- ORIGINAL_SPACE_ID = 'Tune-A-Video-library/Tune-A-Video-Training-UI'
24
- SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
25
-
26
-
27
- class Trainer:
28
- def __init__(self, hf_token: str | None = None):
29
- self.hf_token = hf_token
30
- self.model_uploader = ModelUploader(hf_token)
31
-
32
- self.checkpoint_dir = pathlib.Path('checkpoints')
33
- self.checkpoint_dir.mkdir(exist_ok=True)
34
-
35
- def download_base_model(self, base_model_id: str) -> str:
36
- model_dir = self.checkpoint_dir / base_model_id
37
- if not model_dir.exists():
38
- org_name = base_model_id.split('/')[0]
39
- org_dir = self.checkpoint_dir / org_name
40
- org_dir.mkdir(exist_ok=True)
41
- subprocess.run(shlex.split(
42
- f'git clone https://huggingface.co/{base_model_id}'),
43
- cwd=org_dir)
44
- return model_dir.as_posix()
45
-
46
- def join_model_library_org(self, token: str) -> None:
47
- subprocess.run(
48
- shlex.split(
49
- f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
50
- ))
51
-
52
- def run(
53
- self,
54
- training_video: str,
55
- training_prompt: str,
56
- output_model_name: str,
57
- overwrite_existing_model: bool,
58
- validation_prompt: str,
59
- base_model: str,
60
- resolution_s: str,
61
- n_steps: int,
62
- learning_rate: float,
63
- gradient_accumulation: int,
64
- seed: int,
65
- fp16: bool,
66
- use_8bit_adam: bool,
67
- checkpointing_steps: int,
68
- validation_epochs: int,
69
- upload_to_hub: bool,
70
- use_private_repo: bool,
71
- delete_existing_repo: bool,
72
- upload_to: str,
73
- remove_gpu_after_training: bool,
74
- input_token: str,
75
- ) -> str:
76
- if SPACE_ID == ORIGINAL_SPACE_ID:
77
- raise gr.Error(
78
- 'This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU'
79
- )
80
- if not torch.cuda.is_available():
81
- raise gr.Error('CUDA is not available.')
82
- if training_video is None:
83
- raise gr.Error('You need to upload a video.')
84
- if not training_prompt:
85
- raise gr.Error('The training prompt is missing.')
86
- if not validation_prompt:
87
- raise gr.Error('The validation prompt is missing.')
88
-
89
- resolution = int(resolution_s)
90
-
91
- if not output_model_name:
92
- timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
93
- output_model_name = f'tune-a-video-{timestamp}'
94
- output_model_name = slugify.slugify(output_model_name)
95
-
96
- repo_dir = pathlib.Path(__file__).parent
97
- output_dir = repo_dir / 'experiments' / output_model_name
98
- if overwrite_existing_model or upload_to_hub:
99
- shutil.rmtree(output_dir, ignore_errors=True)
100
- output_dir.mkdir(parents=True)
101
-
102
- if upload_to_hub:
103
- self.join_model_library_org(
104
- self.hf_token if self.hf_token else input_token)
105
-
106
- config = OmegaConf.load('Tune-A-Video/configs/man-surfing.yaml')
107
- config.pretrained_model_path = self.download_base_model(base_model)
108
- config.output_dir = output_dir.as_posix()
109
- config.train_data.video_path = training_video.name # type: ignore
110
- config.train_data.prompt = training_prompt
111
- config.train_data.n_sample_frames = 8
112
- config.train_data.width = resolution
113
- config.train_data.height = resolution
114
- config.train_data.sample_start_idx = 0
115
- config.train_data.sample_frame_rate = 1
116
- config.validation_data.prompts = [validation_prompt]
117
- config.validation_data.video_length = 8
118
- config.validation_data.width = resolution
119
- config.validation_data.height = resolution
120
- config.validation_data.num_inference_steps = 50
121
- config.validation_data.guidance_scale = 7.5
122
- config.learning_rate = learning_rate
123
- config.gradient_accumulation_steps = gradient_accumulation
124
- config.train_batch_size = 1
125
- config.max_train_steps = n_steps
126
- config.checkpointing_steps = checkpointing_steps
127
- config.validation_steps = validation_epochs
128
- config.seed = seed
129
- config.mixed_precision = 'fp16' if fp16 else ''
130
- config.use_8bit_adam = use_8bit_adam
131
-
132
- config_path = output_dir / 'config.yaml'
133
- with open(config_path, 'w') as f:
134
- OmegaConf.save(config, f)
135
-
136
- command = f'accelerate launch Tune-A-Video/train_tuneavideo.py --config {config_path}'
137
- subprocess.run(shlex.split(command))
138
- save_model_card(save_dir=output_dir,
139
- base_model=base_model,
140
- training_prompt=training_prompt,
141
- test_prompt=validation_prompt,
142
- test_image_dir='samples')
143
-
144
- message = 'Training completed!'
145
- print(message)
146
-
147
- if upload_to_hub:
148
- upload_message = self.model_uploader.upload_model(
149
- folder_path=output_dir.as_posix(),
150
- repo_name=output_model_name,
151
- upload_to=upload_to,
152
- private=use_private_repo,
153
- delete_existing_repo=delete_existing_repo,
154
- input_token=input_token)
155
- print(upload_message)
156
- message = message + '\n' + upload_message
157
-
158
- if remove_gpu_after_training:
159
- space_id = os.getenv('SPACE_ID')
160
- if space_id:
161
- api = HfApi(
162
- token=self.hf_token if self.hf_token else input_token)
163
- api.request_space_hardware(repo_id=space_id,
164
- hardware='cpu-basic')
165
-
166
- return message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
uploader.py DELETED
@@ -1,44 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from huggingface_hub import HfApi
4
-
5
-
6
- class Uploader:
7
- def __init__(self, hf_token: str | None):
8
- self.hf_token = hf_token
9
-
10
- def upload(self,
11
- folder_path: str,
12
- repo_name: str,
13
- organization: str = '',
14
- repo_type: str = 'model',
15
- private: bool = True,
16
- delete_existing_repo: bool = False,
17
- input_token: str | None = None) -> str:
18
-
19
- api = HfApi(token=self.hf_token if self.hf_token else input_token)
20
-
21
- if not folder_path:
22
- raise ValueError
23
- if not repo_name:
24
- raise ValueError
25
- if not organization:
26
- organization = api.whoami()['name']
27
-
28
- repo_id = f'{organization}/{repo_name}'
29
- if delete_existing_repo:
30
- try:
31
- api.delete_repo(repo_id, repo_type=repo_type)
32
- except Exception:
33
- pass
34
- try:
35
- api.create_repo(repo_id, repo_type=repo_type, private=private)
36
- api.upload_folder(repo_id=repo_id,
37
- folder_path=folder_path,
38
- path_in_repo='.',
39
- repo_type=repo_type)
40
- url = f'https://huggingface.co/{repo_id}'
41
- message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
42
- except Exception as e:
43
- message = str(e)
44
- return message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py DELETED
@@ -1,65 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import pathlib
4
-
5
-
6
- def find_exp_dirs() -> list[str]:
7
- repo_dir = pathlib.Path(__file__).parent
8
- exp_root_dir = repo_dir / 'experiments'
9
- if not exp_root_dir.exists():
10
- return []
11
- exp_dirs = sorted(exp_root_dir.glob('*'))
12
- exp_dirs = [
13
- exp_dir for exp_dir in exp_dirs
14
- if (exp_dir / 'model_index.json').exists()
15
- ]
16
- return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
17
-
18
-
19
- def save_model_card(
20
- save_dir: pathlib.Path,
21
- base_model: str,
22
- training_prompt: str,
23
- test_prompt: str = '',
24
- test_image_dir: str = '',
25
- ) -> None:
26
- image_str = ''
27
- if test_prompt and test_image_dir:
28
- image_paths = sorted((save_dir / test_image_dir).glob('*.gif'))
29
- if image_paths:
30
- image_path = image_paths[-1]
31
- rel_path = image_path.relative_to(save_dir)
32
- image_str = f'''## Samples
33
- Test prompt: {test_prompt}
34
-
35
- ![{image_path.stem}]({rel_path})'''
36
-
37
- model_card = f'''---
38
- license: creativeml-openrail-m
39
- base_model: {base_model}
40
- training_prompt: {training_prompt}
41
- tags:
42
- - stable-diffusion
43
- - stable-diffusion-diffusers
44
- - text-to-image
45
- - diffusers
46
- - text-to-video
47
- - tune-a-video
48
- inference: false
49
- ---
50
-
51
- # Tune-A-Video - {save_dir.name}
52
-
53
- ## Model description
54
- - Base model: [{base_model}](https://huggingface.co/{base_model})
55
- - Training prompt: {training_prompt}
56
-
57
- {image_str}
58
-
59
- ## Related papers:
60
- - [Tune-A-Video](https://arxiv.org/abs/2212.11565): One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation
61
- - [Stable-Diffusion](https://arxiv.org/abs/2112.10752): High-Resolution Image Synthesis with Latent Diffusion Models
62
- '''
63
-
64
- with open(save_dir / 'README.md', 'w') as f:
65
- f.write(model_card)