Chaerin5 committed on
Commit
5b1e740
·
1 Parent(s): 0d08ee0

runtime fix

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -33,6 +33,10 @@ def set_seed(seed):
33
  torch.cuda.manual_seed_all(seed)
34
  random.seed(seed)
35
 
 
 
 
 
36
 
37
  def remove_prefix(text, prefix):
38
  if text.startswith(prefix):
@@ -176,9 +180,6 @@ class HandDiffOpts:
176
  num_workers: int = 10
177
  n_val_samples: int = 4
178
 
179
- if not torch.cuda.is_available():
180
- raise ValueError("No GPU")
181
-
182
  # load models
183
  if NEW_MODEL:
184
  opts = HandDiffOpts()
@@ -202,15 +203,15 @@ if NEW_MODEL:
202
  latent_dim=opts.latent_dim,
203
  in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
204
  learn_sigma=True,
205
- ).cuda()
206
  # ckpt_state_dict = torch.load(model_path)['model_state_dict']
207
- ckpt_state_dict = torch.load(model_path, map_location=torch.device('cuda'))['ema_state_dict']
208
  missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
209
  model.eval()
210
  print(missing_keys, extra_keys)
211
  assert len(missing_keys) == 0
212
  vae_state_dict = torch.load(vae_path)['state_dict']
213
- autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
214
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
215
  autoencoder.eval()
216
  assert len(missing_keys) == 0
@@ -225,18 +226,18 @@ else:
225
  latent_dim=opts.latent_dim,
226
  in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
227
  learn_sigma=True,
228
- ).cuda()
229
  ckpt_state_dict = torch.load(model_path)['state_dict']
230
  dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
231
  vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
232
  missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
233
  model.eval()
234
  assert len(missing_keys) == 0 and len(extra_keys) == 0
235
- autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
236
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
237
  autoencoder.eval()
238
  assert len(missing_keys) == 0 and len(extra_keys) == 0
239
- sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")
240
 
241
 
242
  print("Mediapipe hand detector and SAM ready...")
@@ -312,7 +313,7 @@ def get_ref_anno(ref):
312
  img,
313
  keypts,
314
  hand_mask,
315
- device="cuda",
316
  target_size=(256, 256),
317
  latent_size=(32, 32),
318
  ):
@@ -348,7 +349,7 @@ def get_ref_anno(ref):
348
  img,
349
  keypts,
350
  hand_mask,
351
- device="cuda",
352
  target_size=opts.image_size,
353
  latent_size=opts.latent_size,
354
  )
@@ -405,7 +406,7 @@ def get_target_anno(target):
405
  )
406
  * kpts_valid[:, None, None],
407
  dtype=torch.float,
408
- device="cuda",
409
  )[None, ...]
410
  target_cond = torch.cat(
411
  [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
@@ -525,12 +526,12 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
525
  set_seed(seed)
526
  z = torch.randn(
527
  (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
528
- device="cuda",
529
  )
530
  target_cond = target_cond.repeat(num_gen, 1, 1, 1)
531
  ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
532
  # novel view synthesis mode = off
533
- nvs = torch.zeros(num_gen, dtype=torch.int, device="cuda")
534
  z = torch.cat([z, z], 0)
535
  model_kwargs = dict(
536
  target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
@@ -546,7 +547,7 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
546
  clip_denoised=False,
547
  model_kwargs=model_kwargs,
548
  progress=True,
549
- device="cuda",
550
  ).chunk(2)
551
  sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
552
  sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
@@ -635,14 +636,14 @@ def ready_sample(img_ori, inpaint_mask, keypts):
635
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
636
  ),
637
  dtype=torch.float,
638
- device="cuda",
639
  ).unsqueeze(0)[None, ...]
640
 
641
  def make_ref_cond(
642
  img,
643
  keypts,
644
  hand_mask,
645
- device="cuda",
646
  target_size=(256, 256),
647
  latent_size=(32, 32),
648
  ):
@@ -678,7 +679,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
678
  img,
679
  keypts,
680
  hand_mask * (1 - inpaint_mask),
681
- device="cuda",
682
  target_size=opts.image_size,
683
  latent_size=opts.latent_size,
684
  )
@@ -736,12 +737,12 @@ def sample_inpaint(
736
  jump_n_sample = quality
737
  cfg_scale = cfg
738
  z = torch.randn(
739
- (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device="cuda"
740
  )
741
  target_cond_N = target_cond.repeat(N, 1, 1, 1)
742
  ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
743
  # novel view synthesis mode = off
744
- nvs = torch.zeros(N, dtype=torch.int, device="cuda")
745
  z = torch.cat([z, z], 0)
746
  model_kwargs = dict(
747
  target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
@@ -759,7 +760,7 @@ def sample_inpaint(
759
  clip_denoised=False,
760
  model_kwargs=model_kwargs,
761
  progress=True,
762
- device="cuda",
763
  jump_length=jump_length,
764
  jump_n_sample=jump_n_sample,
765
  ).chunk(2)
 
33
  torch.cuda.manual_seed_all(seed)
34
  random.seed(seed)
35
 
36
+ if torch.cuda.is_available():
37
+ device = "cuda"
38
+ else:
39
+ device = "cpu"
40
 
41
  def remove_prefix(text, prefix):
42
  if text.startswith(prefix):
 
180
  num_workers: int = 10
181
  n_val_samples: int = 4
182
 
 
 
 
183
  # load models
184
  if NEW_MODEL:
185
  opts = HandDiffOpts()
 
203
  latent_dim=opts.latent_dim,
204
  in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
205
  learn_sigma=True,
206
+ ).to(device)
207
  # ckpt_state_dict = torch.load(model_path)['model_state_dict']
208
+ ckpt_state_dict = torch.load(model_path, map_location=torch.device(device))['ema_state_dict']
209
  missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
210
  model.eval()
211
  print(missing_keys, extra_keys)
212
  assert len(missing_keys) == 0
213
  vae_state_dict = torch.load(vae_path)['state_dict']
214
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
215
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
216
  autoencoder.eval()
217
  assert len(missing_keys) == 0
 
226
  latent_dim=opts.latent_dim,
227
  in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
228
  learn_sigma=True,
229
+ ).to(device)
230
  ckpt_state_dict = torch.load(model_path)['state_dict']
231
  dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
232
  vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
233
  missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
234
  model.eval()
235
  assert len(missing_keys) == 0 and len(extra_keys) == 0
236
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
237
  missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
238
  autoencoder.eval()
239
  assert len(missing_keys) == 0 and len(extra_keys) == 0
240
+ sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth", device=device)
241
 
242
 
243
  print("Mediapipe hand detector and SAM ready...")
 
313
  img,
314
  keypts,
315
  hand_mask,
316
+ device=device,
317
  target_size=(256, 256),
318
  latent_size=(32, 32),
319
  ):
 
349
  img,
350
  keypts,
351
  hand_mask,
352
+ device=device,
353
  target_size=opts.image_size,
354
  latent_size=opts.latent_size,
355
  )
 
406
  )
407
  * kpts_valid[:, None, None],
408
  dtype=torch.float,
409
+ device=device,
410
  )[None, ...]
411
  target_cond = torch.cat(
412
  [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
 
526
  set_seed(seed)
527
  z = torch.randn(
528
  (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
529
+ device=device,
530
  )
531
  target_cond = target_cond.repeat(num_gen, 1, 1, 1)
532
  ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
533
  # novel view synthesis mode = off
534
+ nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
535
  z = torch.cat([z, z], 0)
536
  model_kwargs = dict(
537
  target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
 
547
  clip_denoised=False,
548
  model_kwargs=model_kwargs,
549
  progress=True,
550
+ device=device,
551
  ).chunk(2)
552
  sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
553
  sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
 
636
  inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
637
  ),
638
  dtype=torch.float,
639
+ device=device,
640
  ).unsqueeze(0)[None, ...]
641
 
642
  def make_ref_cond(
643
  img,
644
  keypts,
645
  hand_mask,
646
+ device=device,
647
  target_size=(256, 256),
648
  latent_size=(32, 32),
649
  ):
 
679
  img,
680
  keypts,
681
  hand_mask * (1 - inpaint_mask),
682
+ device=device,
683
  target_size=opts.image_size,
684
  latent_size=opts.latent_size,
685
  )
 
737
  jump_n_sample = quality
738
  cfg_scale = cfg
739
  z = torch.randn(
740
+ (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
741
  )
742
  target_cond_N = target_cond.repeat(N, 1, 1, 1)
743
  ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
744
  # novel view synthesis mode = off
745
+ nvs = torch.zeros(N, dtype=torch.int, device=device)
746
  z = torch.cat([z, z], 0)
747
  model_kwargs = dict(
748
  target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
 
760
  clip_denoised=False,
761
  model_kwargs=model_kwargs,
762
  progress=True,
763
+ device=device,
764
  jump_length=jump_length,
765
  jump_n_sample=jump_n_sample,
766
  ).chunk(2)