Spaces:
Runtime error
Runtime error
ShaoTengLiu
commited on
Commit
·
f527f9c
1
Parent(s):
7fef50a
update two buttons
Browse files- app_training.py +10 -1
- trainer.py +2 -1
app_training.py
CHANGED
@@ -40,6 +40,15 @@ def create_training_demo(trainer: Trainer,
|
|
40 |
value='512',
|
41 |
label='Resolution',
|
42 |
visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
input_token = gr.Text(label='Hugging Face Write Token',
|
45 |
placeholder='',
|
@@ -153,7 +162,7 @@ def create_training_demo(trainer: Trainer,
|
|
153 |
gradient_accumulation, seed, fp16, use_8bit_adam,
|
154 |
checkpointing_steps, validation_epochs, upload_to_hub,
|
155 |
use_private_repo, delete_existing_repo, upload_to,
|
156 |
-
remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
|
157 |
],
|
158 |
outputs=output_message)
|
159 |
return demo
|
|
|
40 |
value='512',
|
41 |
label='Resolution',
|
42 |
visible=False)
|
43 |
+
with gr.Row():
|
44 |
+
tuned_model = gr.Text(
|
45 |
+
label='Path to tuned model',
|
46 |
+
value='xxx/xxx,
|
47 |
+
max_lines=1)
|
48 |
+
resolution = gr.Dropdown(choices=['512', '768'],
|
49 |
+
value='512',
|
50 |
+
label='Resolution',
|
51 |
+
visible=False)
|
52 |
|
53 |
input_token = gr.Text(label='Hugging Face Write Token',
|
54 |
placeholder='',
|
|
|
162 |
gradient_accumulation, seed, fp16, use_8bit_adam,
|
163 |
checkpointing_steps, validation_epochs, upload_to_hub,
|
164 |
use_private_repo, delete_existing_repo, upload_to,
|
165 |
+
remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2, tuned_model
|
166 |
],
|
167 |
outputs=output_message)
|
168 |
return demo
|
trainer.py
CHANGED
@@ -207,6 +207,7 @@ class Trainer:
|
|
207 |
blend_word_2: str,
|
208 |
eq_params_1: str,
|
209 |
eq_params_2: str,
|
|
|
210 |
) -> str:
|
211 |
# if SPACE_ID == ORIGINAL_SPACE_ID:
|
212 |
# raise gr.Error(
|
@@ -239,7 +240,7 @@ class Trainer:
|
|
239 |
self.hf_token if self.hf_token else input_token)
|
240 |
|
241 |
config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
|
242 |
-
config.pretrained_model_path = self.download_base_model(
|
243 |
config.output_dir = output_dir.as_posix()
|
244 |
config.train_data.video_path = training_video.name # type: ignore
|
245 |
config.train_data.prompt = training_prompt
|
|
|
207 |
blend_word_2: str,
|
208 |
eq_params_1: str,
|
209 |
eq_params_2: str,
|
210 |
+
tuned_model: str = None,
|
211 |
) -> str:
|
212 |
# if SPACE_ID == ORIGINAL_SPACE_ID:
|
213 |
# raise gr.Error(
|
|
|
240 |
self.hf_token if self.hf_token else input_token)
|
241 |
|
242 |
config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
|
243 |
+
config.pretrained_model_path = self.download_base_model(tuned_model)
|
244 |
config.output_dir = output_dir.as_posix()
|
245 |
config.train_data.video_path = training_video.name # type: ignore
|
246 |
config.train_data.prompt = training_prompt
|