Chaerin5 committed on
Commit 702c185 · 1 Parent(s): 7df9bdd
Files changed (1)
  1. app.py +168 -49
app.py CHANGED
@@ -256,19 +256,128 @@ hands = mp_hands.Hands(
     min_detection_confidence=0.1,
 )
 
-def make_ref_cond(
-    image
-):
-    print("ready to run autoencoder")
-    # print(f"image.device: {image.device}, type(image): {type(image)}")
-    # image = image.to("cuda")
-    print(f"autoencoder device: {next(autoencoder.parameters()).device}")
-    latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
-    return image[None, ...], latent
+# def make_ref_cond(
+#     image
+# ):
+#     print("ready to run autoencoder")
+#     # print(f"image.device: {image.device}, type(image): {type(image)}")
+#     # image = image.to("cuda")
+#     print(f"autoencoder device: {next(autoencoder.parameters()).device}")
+#     latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
+#     return image[None, ...], latent
+
+
+# def get_ref_anno(ref):
+#     print("inside get_ref_anno")
+#     if ref is None:
+#         return (
+#             None,
+#             None,
+#             None,
+#             None,
+#             None,
+#         )
+#     img = ref["composite"][..., :3]
+#     img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
+#     keypts = np.zeros((42, 2))
+#     print("ready to run mediapipe")
+#     if REF_POSE_MASK:
+#         print(f"type(img): {type(img)}, img.shape: {img.shape}, img.dtype: {img.dtype}")
+#         mp_pose = hands.process(img)
+#         print("processed mediapipe")
+#         detected = np.array([0, 0])
+#         start_idx = 0
+#         if mp_pose.multi_hand_landmarks:
+#             # handedness is flipped assuming the input image is mirrored in MediaPipe
+#             for hand_landmarks, handedness in zip(
+#                 mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
+#             ):
+#                 # actually right hand
+#                 if handedness.classification[0].label == "Left":
+#                     start_idx = 0
+#                     detected[0] = 1
+#                 # actually left hand
+#                 elif handedness.classification[0].label == "Right":
+#                     start_idx = 21
+#                     detected[1] = 1
+#                 for i, landmark in enumerate(hand_landmarks.landmark):
+#                     keypts[start_idx + i] = [
+#                         landmark.x * opts.image_size[1],
+#                         landmark.y * opts.image_size[0],
+#                     ]
+
+#             sam_predictor.set_image(img)
+#             l = keypts[:21].shape[0]
+#             if keypts[0].sum() != 0 and keypts[21].sum() != 0:
+#                 input_point = np.array([keypts[0], keypts[21]])
+#                 input_label = np.array([1, 1])
+#             elif keypts[0].sum() != 0:
+#                 input_point = np.array(keypts[:1])
+#                 input_label = np.array([1])
+#             elif keypts[21].sum() != 0:
+#                 input_point = np.array(keypts[21:22])
+#                 input_label = np.array([1])
+#             print("ready to run SAM")
+#             masks, _, _ = sam_predictor.predict(
+#                 point_coords=input_point,
+#                 point_labels=input_label,
+#                 multimask_output=False,
+#             )
+#             print("finished SAM")
+#             hand_mask = masks[0]
+#             masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
+#             ref_pose = visualize_hand(keypts, masked_img)
+#         else:
+#             raise gr.Error("No hands detected in the reference image.")
+#     else:
+#         hand_mask = np.zeros_like(img[:,:, 0])
+#         ref_pose = np.zeros_like(img)
+
+#     image_transform = Compose(
+#         [
+#             ToTensor(),
+#             Resize(opts.image_size),
+#             Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+#         ]
+#     )
+#     image = image_transform(img)
+#     kpts_valid = check_keypoints_validity(keypts, opts.image_size)
+#     heatmaps = torch.tensor(
+#         keypoint_heatmap(
+#             scale_keypoint(keypts, opts.image_size, opts.latent_size), opts.latent_size, var=1.0
+#         )
+#         * kpts_valid[:, None, None],
+#         dtype=torch.float,
+#         # device=device,
+#     )[None, ...]
+#     mask = torch.tensor(
+#         cv2.resize(
+#             hand_mask.astype(int),
+#             dsize=opts.latent_size,
+#             interpolation=cv2.INTER_NEAREST,
+#         ),
+#         dtype=torch.float,
+#         # device=device,
+#     ).unsqueeze(0)[None, ...]
+#     image, latent = make_ref_cond(
+#         image,
+#         # keypts,
+#         # hand_mask,
+#         # device=device,
+#         # target_size=opts.image_size,
+#         # latent_size=opts.latent_size,
+#     )
+#     print("finished autoencoder")
+
+#     if not REF_POSE_MASK:
+#         heatmaps = torch.zeros_like(heatmaps)
+#         mask = torch.zeros_like(mask)
+#     ref_cond = torch.cat([latent, heatmaps, mask], 1)
+
+#     return img, ref_pose, ref_cond
 
 
 def get_ref_anno(ref):
-    print("inside get_ref_anno")
     if ref is None:
         return (
             None,
@@ -280,11 +389,8 @@ def get_ref_anno(ref):
     img = ref["composite"][..., :3]
     img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
     keypts = np.zeros((42, 2))
-    print("ready to run mediapipe")
     if REF_POSE_MASK:
-        print(f"type(img): {type(img)}, img.shape: {img.shape}, img.dtype: {img.dtype}")
         mp_pose = hands.process(img)
-        print("processed mediapipe")
         detected = np.array([0, 0])
         start_idx = 0
         if mp_pose.multi_hand_landmarks:
@@ -317,13 +423,11 @@ def get_ref_anno(ref):
             elif keypts[21].sum() != 0:
                 input_point = np.array(keypts[21:22])
                 input_label = np.array([1])
-            print("ready to run SAM")
             masks, _, _ = sam_predictor.predict(
                 point_coords=input_point,
                 point_labels=input_label,
                 multimask_output=False,
             )
-            print("finished SAM")
             hand_mask = masks[0]
             masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
             ref_pose = visualize_hand(keypts, masked_img)
@@ -332,47 +436,62 @@ def get_ref_anno(ref):
     else:
         hand_mask = np.zeros_like(img[:,:, 0])
         ref_pose = np.zeros_like(img)
+    print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")
 
-    image_transform = Compose(
-        [
-            ToTensor(),
-            Resize(opts.image_size),
-            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-        ]
-    )
-    image = image_transform(img)
-    kpts_valid = check_keypoints_validity(keypts, opts.image_size)
-    heatmaps = torch.tensor(
-        keypoint_heatmap(
-            scale_keypoint(keypts, opts.image_size, opts.latent_size), opts.latent_size, var=1.0
-        )
-        * kpts_valid[:, None, None],
-        dtype=torch.float,
-        # device=device,
-    )[None, ...]
-    mask = torch.tensor(
-        cv2.resize(
-            hand_mask.astype(int),
-            dsize=opts.latent_size,
-            interpolation=cv2.INTER_NEAREST,
-        ),
-        dtype=torch.float,
-        # device=device,
-    ).unsqueeze(0)[None, ...]
-    image, latent = make_ref_cond(
-        image,
-        # keypts,
-        # hand_mask,
-        # device=device,
-        # target_size=opts.image_size,
-        # latent_size=opts.latent_size,
-    )
-    print("finished autoencoder")
+    def make_ref_cond(
+        img,
+        keypts,
+        hand_mask,
+        device="cuda",
+        target_size=(256, 256),
+        latent_size=(32, 32),
+    ):
+        image_transform = Compose(
+            [
+                ToTensor(),
+                Resize(target_size),
+                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+        image = image_transform(img)
+        kpts_valid = check_keypoints_validity(keypts, target_size)
+        heatmaps = torch.tensor(
+            keypoint_heatmap(
+                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
+            )
+            * kpts_valid[:, None, None],
+            dtype=torch.float,
+        )[None, ...]
+        mask = torch.tensor(
+            cv2.resize(
+                hand_mask.astype(int),
+                dsize=latent_size,
+                interpolation=cv2.INTER_NEAREST,
+            ),
+            dtype=torch.float,
+        ).unsqueeze(0)[None, ...]
+        return image[None, ...], heatmaps, mask
 
+    print(f"img.max(): {img.max()}, img.min(): {img.min()}")
+    image, heatmaps, mask = make_ref_cond(
+        img,
+        keypts,
+        hand_mask,
+        device="cuda",
+        target_size=opts.image_size,
+        latent_size=opts.latent_size,
+    )
+    print(f"image.max(): {image.max()}, image.min(): {image.min()}")
+    print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
+    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
+    print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
     if not REF_POSE_MASK:
         heatmaps = torch.zeros_like(heatmaps)
         mask = torch.zeros_like(mask)
+    print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
+    print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
     ref_cond = torch.cat([latent, heatmaps, mask], 1)
+    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
 
     return img, ref_pose, ref_cond
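The helpers `keypoint_heatmap`, `scale_keypoint`, and `check_keypoints_validity` used above are defined elsewhere in app.py and are not part of this diff. For readers following along, here is a minimal sketch of the usual Gaussian-heatmap formulation — a hypothetical stand-in, not the Space's actual implementation — matching the shapes and the `var=1.0` argument used in the commit:

```python
import numpy as np

def gaussian_keypoint_heatmap(keypts, size, var=1.0):
    """Hypothetical stand-in for keypoint_heatmap: one (H, W) Gaussian
    bump per (x, y) keypoint, returned as an array of shape (N, H, W)."""
    h, w = size
    ys, xs = np.mgrid[0:h, 0:w]  # pixel-coordinate grids
    heatmaps = np.zeros((len(keypts), h, w), dtype=np.float32)
    for i, (x, y) in enumerate(keypts):
        heatmaps[i] = np.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2.0 * var))
    return heatmaps

# 42 keypoints scaled to a 32x32 latent grid, mirroring the diff's
# scale_keypoint(keypts, target_size, latent_size) call.
kpts = np.random.rand(42, 2) * 32
hm = gaussian_keypoint_heatmap(kpts, (32, 32), var=1.0)
assert hm.shape == (42, 32, 32)
```

In the commit, heatmaps for undetected keypoints are then zeroed out by multiplying with `check_keypoints_validity`'s 0/1 vector, broadcast as `kpts_valid[:, None, None]`.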
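A shape sketch of the final conditioning tensor may also help: `ref_cond` concatenates the scaled VAE latent, the 42 keypoint heatmaps, and the resized hand mask along dim 1. The 4-channel latent and 32×32 grid below are assumptions for illustration (the real values come from `opts`), not confirmed by the diff:

```python
import torch

latent_channels = 4     # assumed; typical for Stable-Diffusion-style VAEs
latent_size = (32, 32)  # matches make_ref_cond's latent_size default above

latent = torch.randn(1, latent_channels, *latent_size)  # encode(image).sample() * scale
heatmaps = torch.rand(1, 42, *latent_size)              # one heatmap per keypoint
mask = torch.rand(1, 1, *latent_size)                   # hand mask at latent resolution

ref_cond = torch.cat([latent, heatmaps, mask], 1)       # as in the commit
assert ref_cond.shape == (1, latent_channels + 42 + 1, *latent_size)  # (1, 47, 32, 32)
```

Under these assumptions, `ref_cond` carries 4 + 42 + 1 = 47 channels at latent resolution.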