Spaces: Runtime error
Commit: runtime fix
app.py CHANGED
@@ -33,6 +33,10 @@ def set_seed(seed):
     torch.cuda.manual_seed_all(seed)
     random.seed(seed)
 
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
 
 def remove_prefix(text, prefix):
     if text.startswith(prefix):
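The four added lines give app.py a single module-level `device` switch instead of assuming CUDA everywhere. A minimal standalone sketch of the same fallback pattern (the tensor names below are illustrative, not taken from app.py):

import torch

# Pick the backend once at import time; every later allocation and
# .to() call reuses this single source of truth.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tensors created with device=... land on the chosen backend directly,
# with no intermediate CPU allocation and copy.
z = torch.randn(4, 8, device=device)
print(z.device)  # cuda:0 on a GPU machine, cpu otherwise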
@@ -176,9 +180,6 @@ class HandDiffOpts:
     num_workers: int = 10
     n_val_samples: int = 4
 
-if not torch.cuda.is_available():
-    raise ValueError("No GPU")
-
 # load models
 if NEW_MODEL:
     opts = HandDiffOpts()
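Removing the `raise ValueError("No GPU")` guard is the other half of the fix: instead of failing at import time on CPU-only Spaces hardware, the app now proceeds and lets the `device` fallback above decide where everything runs.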
@@ -202,15 +203,15 @@ if NEW_MODEL:
         latent_dim=opts.latent_dim,
         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
         learn_sigma=True,
-    ).
+    ).to(device)
     # ckpt_state_dict = torch.load(model_path)['model_state_dict']
-    ckpt_state_dict = torch.load(model_path, map_location=torch.device(
+    ckpt_state_dict = torch.load(model_path, map_location=torch.device(device))['ema_state_dict']
     missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
     model.eval()
     print(missing_keys, extra_keys)
     assert len(missing_keys) == 0
     vae_state_dict = torch.load(vae_path)['state_dict']
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
     autoencoder.eval()
     assert len(missing_keys) == 0
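`map_location` is what lets this checkpoint load survive on a machine without CUDA: without it, `torch.load` tries to restore each tensor to the device it was saved from. A hedged sketch of the pattern (the path and the `ema_state_dict` key mirror the diff but stand in for whatever the checkpoint actually contains):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# map_location remaps every stored tensor onto the available device;
# omitting it raises on a CPU-only machine if the checkpoint holds CUDA tensors.
ckpt = torch.load("checkpoint.pt", map_location=torch.device(device))

# Training loops often nest the weights under a key such as 'ema_state_dict'.
state_dict = ckpt.get("ema_state_dict", ckpt)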
@@ -225,18 +226,18 @@ else:
         latent_dim=opts.latent_dim,
         in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
         learn_sigma=True,
-    ).
+    ).to(device)
     ckpt_state_dict = torch.load(model_path)['state_dict']
     dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
     vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
     missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
     model.eval()
     assert len(missing_keys) == 0 and len(extra_keys) == 0
-    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).
+    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
     autoencoder.eval()
     assert len(missing_keys) == 0 and len(extra_keys) == 0
-sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")
+sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth", device=device)
 
 
 print("Mediapipe hand detector and SAM ready...")
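The same substitution recurs in this fallback branch: each freshly built module is moved with `.to(device)` before its state dict is loaded, and `init_sam`, the app's own SAM loader, now receives the shared `device` argument so the predictor ends up on the same backend as the diffusion model and the VQ-VAE.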
@@ -312,7 +313,7 @@ def get_ref_anno(ref):
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=(256, 256),
         latent_size=(32, 32),
     ):
@@ -348,7 +349,7 @@ def get_ref_anno(ref):
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=opts.image_size,
         latent_size=opts.latent_size,
     )
@@ -405,7 +406,7 @@ def get_target_anno(target):
         )
         * kpts_valid[:, None, None],
         dtype=torch.float,
-        device=
+        device=device,
     )[None, ...]
     target_cond = torch.cat(
         [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
@@ -525,12 +526,12 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
     set_seed(seed)
     z = torch.randn(
         (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
-        device=
+        device=device,
     )
     target_cond = target_cond.repeat(num_gen, 1, 1, 1)
     ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
     # novel view synthesis mode = off
-    nvs = torch.zeros(num_gen, dtype=torch.int, device=
+    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
     z = torch.cat([z, z], 0)
     model_kwargs = dict(
         target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
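Beyond the device plumbing, this hunk shows the classifier-free-guidance batching the sampler relies on: the noise is allocated directly on `device`, then duplicated so a conditional and an unconditional copy run in one forward pass. A small self-contained sketch of that trick (shapes are illustrative, not the app's real latent sizes):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_gen, latent_dim, h, w = 4, 4, 32, 32  # illustrative sizes

# Allocate the initial diffusion noise directly on the target device.
z = torch.randn(num_gen, latent_dim, h, w, device=device)
cond = torch.randn(num_gen, 3, h, w, device=device)

# Duplicate the noise and pair it with a zeroed condition so the model
# scores the conditional and unconditional batches in a single pass.
z_in = torch.cat([z, z], 0)
cond_in = torch.cat([cond, torch.zeros_like(cond)], 0)
assert z_in.shape[0] == 2 * num_gen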
@@ -546,7 +547,7 @@ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
         clip_denoised=False,
         model_kwargs=model_kwargs,
         progress=True,
-        device=
+        device=device,
     ).chunk(2)
     sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
     sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
@@ -635,14 +636,14 @@ def ready_sample(img_ori, inpaint_mask, keypts):
             inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
         ),
         dtype=torch.float,
-        device=
+        device=device,
     ).unsqueeze(0)[None, ...]
 
     def make_ref_cond(
         img,
         keypts,
         hand_mask,
-        device=
+        device=device,
         target_size=(256, 256),
         latent_size=(32, 32),
     ):
@@ -678,7 +679,7 @@ def ready_sample(img_ori, inpaint_mask, keypts):
         img,
         keypts,
         hand_mask * (1 - inpaint_mask),
-        device=
+        device=device,
         target_size=opts.image_size,
         latent_size=opts.latent_size,
     )
@@ -736,12 +737,12 @@ def sample_inpaint(
     jump_n_sample = quality
     cfg_scale = cfg
     z = torch.randn(
-        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=
+        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
     )
     target_cond_N = target_cond.repeat(N, 1, 1, 1)
     ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
     # novel view synthesis mode = off
-    nvs = torch.zeros(N, dtype=torch.int, device=
+    nvs = torch.zeros(N, dtype=torch.int, device=device)
     z = torch.cat([z, z], 0)
     model_kwargs = dict(
         target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
@@ -759,7 +760,7 @@ def sample_inpaint(
         clip_denoised=False,
         model_kwargs=model_kwargs,
         progress=True,
-        device=
+        device=device,
         jump_length=jump_length,
         jump_n_sample=jump_n_sample,
     ).chunk(2)
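Taken together, the remaining hunks are mechanical: every `device=` argument in `get_ref_anno`, `get_target_anno`, `sample_diff`, `ready_sample`, and `sample_inpaint` now receives the shared `device` variable, so the same sampling code path runs unchanged on both GPU and CPU Spaces hardware.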