ShaoTengLiu committed
Commit bb8b1f0
1 Parent(s): f56167c

update two buttons

Files changed (3):
  1. Video-P2P/run_videop2p.py +4 -4
  2. app_training.py +12 -0
  3. trainer.py +130 -1
Video-P2P/run_videop2p.py CHANGED
@@ -104,8 +104,8 @@ def main(
     mask_th = (.3, .3)
 
 
-    pretrained_model_path = pretrained_model_path
-    # pretrained_model_path = output_dir
+    # pretrained_model_path = pretrained_model_path
+    pretrained_model_path = output_dir
     image_path = train_data['video_path']
     prompt = train_data['prompt']
     # prompts = [prompt, ]
@@ -142,8 +142,8 @@ def main(
         pretrained_model_path,
         subfolder="vae",
     ).to(device, dtype=weight_dtype)
-    unet = UNet3DConditionModel.from_pretrained_2d(
-    # unet = UNet3DConditionModel.from_pretrained(
+    # unet = UNet3DConditionModel.from_pretrained_2d(
+    unet = UNet3DConditionModel.from_pretrained(
        pretrained_model_path, subfolder="unet"
    ).to(device)
    ldm_stable = TuneAVideoPipeline(
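
These two hunks switch inference from the base model to the fine-tuned result: pretrained_model_path is repointed at output_dir, and the UNet is loaded with the standard from_pretrained instead of from_pretrained_2d (which inflates 2D Stable Diffusion weights into the 3D video UNet). A minimal sketch of the two loading paths, assuming the Tune-A-Video-style UNet3DConditionModel this repo ships; the import path and checkpoint locations below are illustrative, not from this commit:

    from tuneavideo.models.unet import UNet3DConditionModel  # module path is an assumption

    # Before: inflate the base model's 2D UNet weights into a 3D video UNet.
    unet_inflated = UNet3DConditionModel.from_pretrained_2d(
        'CompVis/stable-diffusion-v1-4', subfolder='unet')

    # After: the fine-tuned checkpoint was already saved as a 3D UNet, so the
    # standard diffusers-style loader is sufficient.
    unet_tuned = UNet3DConditionModel.from_pretrained(
        'experiments/my-model', subfolder='unet')  # hypothetical output_dir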
app_training.py CHANGED
@@ -142,6 +142,18 @@ def create_training_demo(trainer: Trainer,
             remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
         ],
         outputs=output_message)
+    run_button.click(
+        fn=trainer.run_p2p,
+        inputs=[
+            training_video, training_prompt, output_model_name,
+            delete_existing_repo, validation_prompt, base_model,
+            resolution, num_training_steps, learning_rate,
+            gradient_accumulation, seed, fp16, use_8bit_adam,
+            checkpointing_steps, validation_epochs, upload_to_hub,
+            use_private_repo, delete_existing_repo, upload_to,
+            remove_gpu_after_training, input_token, blend_word_1, blend_word_2, eq_params_1, eq_params_2
+        ],
+        outputs=output_message)
     return demo
 
 
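The new binding reuses the existing input list but dispatches to trainer.run_p2p; note that delete_existing_repo appears twice, with its first occurrence filling run_p2p's overwrite_existing_model parameter. Gradio's Button.click passes the listed components' values to fn positionally and routes the return value to outputs. A self-contained sketch of the pattern (component names here are illustrative, not from app_training.py):

    import gradio as gr

    def run_p2p(video_path: str, prompt: str) -> str:
        # Stand-in for Trainer.run_p2p: receives component values positionally.
        return f'Would edit {video_path!r} with prompt {prompt!r}'

    with gr.Blocks() as demo:
        video = gr.Textbox(label='Video path')
        prompt = gr.Textbox(label='Prompt')
        run_button = gr.Button('Run Video-P2P')
        output_message = gr.Textbox(label='Output message')
        # Same wiring as the run_button.click(...) added above.
        run_button.click(fn=run_p2p, inputs=[video, prompt], outputs=output_message)

    demo.launch()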
trainer.py CHANGED
@@ -104,6 +104,136 @@ class Trainer:
             shutil.rmtree(output_dir, ignore_errors=True)
         output_dir.mkdir(parents=True)
 
+        if upload_to_hub:
+            self.join_model_library_org(
+                self.hf_token if self.hf_token else input_token)
+
+        config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
+        config.pretrained_model_path = self.download_base_model(base_model)
+        config.output_dir = output_dir.as_posix()
+        config.train_data.video_path = training_video.name  # type: ignore
+        config.train_data.prompt = training_prompt
+        config.train_data.n_sample_frames = 8
+        config.train_data.width = resolution
+        config.train_data.height = resolution
+        config.train_data.sample_start_idx = 0
+        config.train_data.sample_frame_rate = 1
+        config.validation_data.prompts = [validation_prompt]
+        config.validation_data.video_length = 8
+        config.validation_data.width = resolution
+        config.validation_data.height = resolution
+        config.validation_data.num_inference_steps = 50
+        config.validation_data.guidance_scale = 7.5
+        config.learning_rate = learning_rate
+        config.gradient_accumulation_steps = gradient_accumulation
+        config.train_batch_size = 1
+        config.max_train_steps = n_steps
+        config.checkpointing_steps = checkpointing_steps
+        config.validation_steps = validation_epochs
+        config.seed = seed
+        config.mixed_precision = 'fp16' if fp16 else ''
+        config.use_8bit_adam = use_8bit_adam
+        config.prompts = [training_prompt, validation_prompt]
+        config.blend_word = [blend_word_1, blend_word_2]
+        config.eq_params = {"words": [eq_params_1], "values": [int(eq_params_2)]}
+        if len(validation_prompt) == len(training_prompt):
+            config.is_word_swap = True
+        else:
+            config.is_word_swap = False
+
+        config_path = output_dir / 'config.yaml'
+        with open(config_path, 'w') as f:
+            OmegaConf.save(config, f)
+
+        command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
+        subprocess.run(shlex.split(command))
+        # command = f'python Video-P2P/run_videop2p.py --config {config_path}'
+        # subprocess.run(shlex.split(command))
+        save_model_card(save_dir=output_dir,
+                        base_model=base_model,
+                        training_prompt=training_prompt,
+                        test_prompt=validation_prompt,
+                        test_image_dir='results')
+
+        message = 'Training completed!'
+        print(message)
+
+        if upload_to_hub:
+            upload_message = self.model_uploader.upload_model(
+                folder_path=output_dir.as_posix(),
+                repo_name=output_model_name,
+                upload_to=upload_to,
+                private=use_private_repo,
+                delete_existing_repo=delete_existing_repo,
+                input_token=input_token)
+            print(upload_message)
+            message = message + '\n' + upload_message
+
+        if remove_gpu_after_training:
+            space_id = os.getenv('SPACE_ID')
+            if space_id:
+                api = HfApi(
+                    token=self.hf_token if self.hf_token else input_token)
+                api.request_space_hardware(repo_id=space_id,
+                                           hardware='cpu-basic')
+
+        return message
+
+
+    def run_p2p(
+        self,
+        training_video: str,
+        training_prompt: str,
+        output_model_name: str,
+        overwrite_existing_model: bool,
+        validation_prompt: str,
+        base_model: str,
+        resolution_s: str,
+        n_steps: int,
+        learning_rate: float,
+        gradient_accumulation: int,
+        seed: int,
+        fp16: bool,
+        use_8bit_adam: bool,
+        checkpointing_steps: int,
+        validation_epochs: int,
+        upload_to_hub: bool,
+        use_private_repo: bool,
+        delete_existing_repo: bool,
+        upload_to: str,
+        remove_gpu_after_training: bool,
+        input_token: str,
+        blend_word_1: str,
+        blend_word_2: str,
+        eq_params_1: str,
+        eq_params_2: str,
+    ) -> str:
+        # if SPACE_ID == ORIGINAL_SPACE_ID:
+        #     raise gr.Error(
+        #         'This Space does not work on this Shared UI. Duplicate the Space and attribute a GPU'
+        #     )
+        if not torch.cuda.is_available():
+            raise gr.Error('CUDA is not available.')
+        if training_video is None:
+            raise gr.Error('You need to upload a video.')
+        if not training_prompt:
+            raise gr.Error('The training prompt is missing.')
+        if not validation_prompt:
+            raise gr.Error('The validation prompt is missing.')
+
+        resolution = int(resolution_s)
+
+        if not output_model_name:
+            timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
+            output_model_name = f'video-p2p-{timestamp}'
+        output_model_name = slugify.slugify(output_model_name)
+
+        repo_dir = pathlib.Path(__file__).parent
+        output_dir = repo_dir / 'experiments' / output_model_name
+        if overwrite_existing_model or upload_to_hub:
+            shutil.rmtree(output_dir, ignore_errors=True)
+        output_dir.mkdir(parents=True)
+
         if upload_to_hub:
             self.join_model_library_org(
                 self.hf_token if self.hf_token else input_token)
@@ -147,7 +277,6 @@ class Trainer:
 
         # command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
         # subprocess.run(shlex.split(command))
-        # torch.cuda.empty_cache()
         command = f'python Video-P2P/run_videop2p.py --config {config_path}'
         subprocess.run(shlex.split(command))
         save_model_card(save_dir=output_dir,
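
run_p2p mirrors the existing training path: validate inputs, load a template YAML with OmegaConf, override fields from the UI, save the resolved config next to the outputs, and shell out to the tuning script. A minimal sketch of that config round-trip, assuming only the omegaconf package; the run directory and override values are illustrative:

    import pathlib
    import shlex
    import subprocess

    from omegaconf import OmegaConf

    # Load the template config and override a few fields, as run_p2p does.
    config = OmegaConf.load('Video-P2P/configs/man-skiing.yaml')
    config.output_dir = 'experiments/demo'  # hypothetical run directory
    config.train_data.prompt = 'a man is skiing'

    # Persist the resolved config so the training subprocess can read it back.
    config_path = pathlib.Path(config.output_dir) / 'config.yaml'
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    # Launch tuning out of process, mirroring the command built in run_p2p.
    command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
    subprocess.run(shlex.split(command))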