ShaoTengLiu committed
Commit bb8b1f0 · Parent(s): f56167c

update two buttons

Files changed:
- Video-P2P/run_videop2p.py +4 -4
- app_training.py +12 -0
- trainer.py +130 -1
Video-P2P/run_videop2p.py CHANGED

@@ -104,8 +104,8 @@ def main(
     mask_th = (.3, .3)
 
 
-    pretrained_model_path = pretrained_model_path
-
+    # pretrained_model_path = pretrained_model_path
+    pretrained_model_path = output_dir
     image_path = train_data['video_path']
     prompt = train_data['prompt']
     # prompts = [prompt, ]
@@ -142,8 +142,8 @@ def main(
         pretrained_model_path,
         subfolder="vae",
     ).to(device, dtype=weight_dtype)
-    unet = UNet3DConditionModel.from_pretrained_2d(
-
+    # unet = UNet3DConditionModel.from_pretrained_2d(
+    unet = UNet3DConditionModel.from_pretrained(
         pretrained_model_path, subfolder="unet"
     ).to(device)
     ldm_stable = TuneAVideoPipeline(
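Note on this change: inference now points `pretrained_model_path` at `output_dir`, the directory the tuning step just wrote, and loads the UNet with the stock `from_pretrained` instead of the `from_pretrained_2d` inflation helper, since the tuned checkpoint saved there is already a 3D UNet. A minimal sketch of the two loading paths; the import path follows the Tune-A-Video codebase and the checkpoint directories are illustrative, not taken from this commit:

# Sketch: the two UNet loading paths. Import path and checkpoint
# directories are assumptions, not from this commit.
from tuneavideo.models.unet import UNet3DConditionModel

# Starting a tuning run: inflate the 2D Stable Diffusion UNet to 3D.
unet = UNet3DConditionModel.from_pretrained_2d(
    'checkpoints/stable-diffusion-v1-5', subfolder='unet')

# Reloading a finished run: the saved weights are already 3D, so the
# plain diffusers-style loader applies.
unet = UNet3DConditionModel.from_pretrained(
    'experiments/my-run', subfolder='unet')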
app_training.py CHANGED

@@ -142,6 +142,18 @@ def create_training_demo(trainer: Trainer,
                 remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
             ],
             outputs=output_message)
+        run_button.click(
+            fn=trainer.run_p2p,
+            inputs=[
+                training_video, training_prompt, output_model_name,
+                delete_existing_repo, validation_prompt, base_model,
+                resolution, num_training_steps, learning_rate,
+                gradient_accumulation, seed, fp16, use_8bit_adam,
+                checkpointing_steps, validation_epochs, upload_to_hub,
+                use_private_repo, delete_existing_repo, upload_to,
+                remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
+            ],
+            outputs=output_message)
         return demo
 
 
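This hunk adds the second of the commit's two buttons: `run_button` now triggers `trainer.run_p2p` with the same components as the tuning callback plus the Prompt-to-Prompt fields (`blend_word_1/2`, `eq_params_1/2`). Note that `delete_existing_repo` appears twice in the inputs list; positionally, its first occurrence feeds `run_p2p`'s `overwrite_existing_model` parameter. A minimal sketch of the wiring pattern, with hypothetical stand-in callbacks and a reduced component list:

# Sketch of binding two buttons to two trainer methods in gr.Blocks.
# The callbacks and components are illustrative stand-ins.
import gradio as gr

def tune(video, prompt):     # stand-in for trainer.run
    return f'tuned on: {prompt}'

def edit(video, prompt):     # stand-in for trainer.run_p2p
    return f'edited with: {prompt}'

with gr.Blocks() as demo:
    training_video = gr.Video(label='Training video')
    training_prompt = gr.Textbox(label='Training prompt')
    output_message = gr.Textbox(label='Output message')
    tune_button = gr.Button('Tune')
    run_button = gr.Button('Run Video-P2P')
    # Each .click() registers one handler; both buttons may share
    # input components and write to the same output textbox.
    tune_button.click(fn=tune,
                      inputs=[training_video, training_prompt],
                      outputs=output_message)
    run_button.click(fn=edit,
                     inputs=[training_video, training_prompt],
                     outputs=output_message)

demo.queue().launch()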
trainer.py CHANGED

@@ -104,6 +104,136 @@ class Trainer:
         shutil.rmtree(output_dir, ignore_errors=True)
         output_dir.mkdir(parents=True)
 
+        if upload_to_hub:
+            self.join_model_library_org(
+                self.hf_token if self.hf_token else input_token)
+
+        config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
+        config.pretrained_model_path = self.download_base_model(base_model)
+        config.output_dir = output_dir.as_posix()
+        config.train_data.video_path = training_video.name  # type: ignore
+        config.train_data.prompt = training_prompt
+        config.train_data.n_sample_frames = 8
+        config.train_data.width = resolution
+        config.train_data.height = resolution
+        config.train_data.sample_start_idx = 0
+        config.train_data.sample_frame_rate = 1
+        config.validation_data.prompts = [validation_prompt]
+        config.validation_data.video_length = 8
+        config.validation_data.width = resolution
+        config.validation_data.height = resolution
+        config.validation_data.num_inference_steps = 50
+        config.validation_data.guidance_scale = 7.5
+        config.learning_rate = learning_rate
+        config.gradient_accumulation_steps = gradient_accumulation
+        config.train_batch_size = 1
+        config.max_train_steps = n_steps
+        config.checkpointing_steps = checkpointing_steps
+        config.validation_steps = validation_epochs
+        config.seed = seed
+        config.mixed_precision = 'fp16' if fp16 else ''
+        config.use_8bit_adam = use_8bit_adam
+        config.prompts = [training_prompt, validation_prompt]
+        config.blend_word = [blend_word_1, blend_word_2]
+        config.eq_params = {"words":[eq_params_1], "values":[int(eq_params_2)]}
+        if len(validation_prompt) == len(training_prompt):
+            config.is_word_swap = True
+        else:
+            config.is_word_swap = False
+
+        config_path = output_dir / 'config.yaml'
+        with open(config_path, 'w') as f:
+            OmegaConf.save(config, f)
+
+        command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
+        subprocess.run(shlex.split(command))
+        # command = f'python Video-P2P/run_videop2p.py --config {config_path}'
+        # subprocess.run(shlex.split(command))
+        save_model_card(save_dir=output_dir,
+                        base_model=base_model,
+                        training_prompt=training_prompt,
+                        test_prompt=validation_prompt,
+                        test_image_dir='results')
+
+        message = 'Training completed!'
+        print(message)
+
+        if upload_to_hub:
+            upload_message = self.model_uploader.upload_model(
+                folder_path=output_dir.as_posix(),
+                repo_name=output_model_name,
+                upload_to=upload_to,
+                private=use_private_repo,
+                delete_existing_repo=delete_existing_repo,
+                input_token=input_token)
+            print(upload_message)
+            message = message + '\n' + upload_message
+
+        if remove_gpu_after_training:
+            space_id = os.getenv('SPACE_ID')
+            if space_id:
+                api = HfApi(
+                    token=self.hf_token if self.hf_token else input_token)
+                api.request_space_hardware(repo_id=space_id,
+                                           hardware='cpu-basic')
+
+        return message
+
+
+    def run_p2p(
+        self,
+        training_video: str,
+        training_prompt: str,
+        output_model_name: str,
+        overwrite_existing_model: bool,
+        validation_prompt: str,
+        base_model: str,
+        resolution_s: str,
+        n_steps: int,
+        learning_rate: float,
+        gradient_accumulation: int,
+        seed: int,
+        fp16: bool,
+        use_8bit_adam: bool,
+        checkpointing_steps: int,
+        validation_epochs: int,
+        upload_to_hub: bool,
+        use_private_repo: bool,
+        delete_existing_repo: bool,
+        upload_to: str,
+        remove_gpu_after_training: bool,
+        input_token: str,
+        blend_word_1: str,
+        blend_word_2: str,
+        eq_params_1: str,
+        eq_params_2: str,
+    ) -> str:
+        # if SPACE_ID == ORIGINAL_SPACE_ID:
+        #     raise gr.Error(
+        #         'This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU'
+        #     )
+        if not torch.cuda.is_available():
+            raise gr.Error('CUDA is not available.')
+        if training_video is None:
+            raise gr.Error('You need to upload a video.')
+        if not training_prompt:
+            raise gr.Error('The training prompt is missing.')
+        if not validation_prompt:
+            raise gr.Error('The validation prompt is missing.')
+
+        resolution = int(resolution_s)
+
+        if not output_model_name:
+            timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+            output_model_name = f'video-p2p-{timestamp}'
+        output_model_name = slugify.slugify(output_model_name)
+
+        repo_dir = pathlib.Path(__file__).parent
+        output_dir = repo_dir / 'experiments' / output_model_name
+        if overwrite_existing_model or upload_to_hub:
+            shutil.rmtree(output_dir, ignore_errors=True)
+        output_dir.mkdir(parents=True)
+
         if upload_to_hub:
             self.join_model_library_org(
                 self.hf_token if self.hf_token else input_token)
@@ -147,7 +277,6 @@ class Trainer:
 
         # command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
         # subprocess.run(shlex.split(command))
-        # torch.cuda.empty_cache()
         command = f'python Video-P2P/run_videop2p.py --config {config_path}'
         subprocess.run(shlex.split(command))
         save_model_card(save_dir=output_dir,
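The core of the new `run_p2p` method is a config round-trip: load the repo's `man-skiing.yaml` template, override the fields exposed in the UI, write the merged config into the experiment directory, and launch training as a subprocess on that file. A compact sketch of the pattern; the override values are illustrative:

# Sketch of the OmegaConf round-trip run_p2p performs. Override
# values are illustrative.
import shlex
import subprocess
from pathlib import Path

from omegaconf import OmegaConf

config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
config.train_data.prompt = 'a man is skiing'  # from the UI
config.max_train_steps = 300                  # from the UI

output_dir = Path('experiments/demo')
output_dir.mkdir(parents=True, exist_ok=True)
config_path = output_dir / 'config.yaml'
with open(config_path, 'w') as f:
    OmegaConf.save(config, f)  # serialize the edited config back to YAML

# shlex.split produces an argv list, so no shell is involved.
command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
subprocess.run(shlex.split(command), check=True)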
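The `remove_gpu_after_training` branch makes the Space release its paid GPU once training finishes by requesting free hardware from the Hub, which restarts the Space on `cpu-basic`. A minimal standalone sketch; the token sourcing here is an assumption (the trainer uses `self.hf_token` or the token typed into the UI):

# Sketch of the GPU release performed after training. SPACE_ID is set
# automatically inside a running Space; HF_TOKEN here is illustrative.
import os

from huggingface_hub import HfApi

space_id = os.getenv('SPACE_ID')
if space_id:
    api = HfApi(token=os.getenv('HF_TOKEN'))
    # Moving to 'cpu-basic' stops GPU billing; the Space restarts
    # on the new hardware.
    api.request_space_hardware(repo_id=space_id, hardware='cpu-basic')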