FoundHand / app.py
Chaerin5's picture
init
49f816b
raw
history blame
54.9 kB
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from typing import Optional
# Number of result slots returned by sample_diff; also used for the gallery
# column count and the "Number of generations" slider maximum.
MAX_N = 6
# Number of result slots returned by sample_inpaint.
FIX_MAX_N = 6
# RGB image used to pad the result lists when fewer images are generated than
# there are slots (cv2.imread yields BGR, hence the conversion).
placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
# Selects which checkpoint-loading branch of the module-level script runs.
NEW_MODEL = True
# Selects the checkpoint file used when NEW_MODEL is True.
MODEL_EPOCH = 6
# When True, get_ref_anno detects keypoints and a SAM hand mask for the
# reference; when False, the heatmap and mask conditioning are zeroed instead.
REF_POSE_MASK = True
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Any value accepted by ``int()``; the converted integer is used
            for every generator.
    """
    value = int(seed)
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed_all(value)
def remove_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix``; unchanged if it is absent."""
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text
def unnormalize(x):
    """Map an array from the [-1, 1] range to uint8 pixel values in [0, 255]."""
    unit_range = (x + 1) / 2
    return (unit_range * 255).astype(np.uint8)
def visualize_hand(all_joints, img, side=("right", "left"), n_avail_joints=21):
    """Draw hand skeleton(s) over ``img`` and return the rendering as an array.

    Args:
        all_joints: Indexable of (x, y) joints. Rows 0..20 are read for the
            right hand and rows 21..41 for the left hand.
        img: Image array whose first two dimensions are (height, width).
        side: Container (or string) tested with ``in`` for ``"right"`` and
            ``"left"`` to decide which hands are drawn.
        n_avail_joints: Number of joints, per hand, that are available to draw.

    Returns:
        The matplotlib rendering, decoded from PNG and resized to the width
        and height of ``img``, as a numpy array.
    """
    # Bone connections in joint-index order, each with its own line color.
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W = img.shape[:2]

    start_is = []
    if "right" in side:
        start_is.append(0)
    if "left" in side:
        start_is.append(21)

    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.imshow(img)
        for start_i in start_is:
            joints = all_joints[start_i : start_i + n_avail_joints]
            if len(joints) == 0:
                # Nothing selected for this hand yet (e.g. after undoing the
                # only keypoint); indexing joints[0] below would fail.
                continue
            if len(joints) == 1:
                ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
            else:
                # The first len(joints) - 1 connections only reference joints
                # that are already available.
                for connection, color in connections[: len(joints) - 1]:
                    joint1 = joints[connection[0]]
                    joint2 = joints[connection[1]]
                    ax.plot(
                        [joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color
                    )
        ax.set_xlim([0, W])
        ax.set_ylim([0, H])
        ax.grid(False)
        ax.set_axis_off()
        ax.invert_yaxis()

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)

    # Decode the PNG back into a numpy array.
    buf.seek(0)
    img_pil = Image.open(buf)
    # PIL expects (width, height); the previous (H, W) order transposed the
    # output size for non-square images.
    img_pil = img_pil.resize((W, H))
    numpy_img = np.array(img_pil)
    return numpy_img
def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
    """Paint ``color`` over the masked pixels of a copy of ``image``.

    Args:
        image (H, W, 3) or (H, W): input image; it is not modified.
        mask (H, W): pixels where ``mask == 1`` are painted.
        color: value written into the masked pixels.
        alpha: weight of the painted image when blending.
        transparent: if True, blend the painted image with the original using
            ``alpha``; otherwise return the painted image as is.
    """
    painted = deepcopy(image)
    painted[mask == 1] = color
    if not transparent:
        return painted
    base = deepcopy(image)
    return cv2.addWeighted(painted, alpha, base, 1 - alpha, 0, base)
def scale_keypoint(keypoint, original_size, target_size):
    """Return a copy of ``keypoint`` rescaled from ``original_size`` to ``target_size``.

    Column 0 is scaled by the ratio of the sizes' first entries and column 1
    by the ratio of their second entries. The input array is left untouched.
    """
    ratio_0 = target_size[0] / original_size[0]
    ratio_1 = target_size[1] / original_size[1]
    scaled = keypoint.copy()
    scaled[:, 0] *= ratio_0
    scaled[:, 1] *= ratio_1
    return scaled
print("Configure...")


@dataclass
class HandDiffOpts:
    """Configuration values for the hand diffusion model.

    In this file the sizes, latent settings, keypoint/mask channel counts,
    ``test_sampling_steps`` and ``latent_scaling_factor`` are read; the
    remaining fields look like training-time settings (presumably kept for
    parity with the training configuration; confirm before removing).
    """

    # --- run identification and paths ---
    run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
    sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
    log_dir: str = "/users/kchen157/scratch/log"
    data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"

    # --- image / latent geometry and conditioning channels ---
    image_size: tuple = (256, 256)
    latent_size: tuple = (32, 32)
    latent_dim: int = 4
    mask_bg: bool = False
    kpts_form: str = "heatmap"
    n_keypoints: int = 42
    n_mask: int = 1

    # --- diffusion / sampling ---
    noise_steps: int = 1000
    test_sampling_steps: int = 250
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.0
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    latent_scaling_factor: float = 0.18215
    cfg_pose: float = 5.0
    cfg_appearance: float = 3.5

    # --- optimisation / trainer ---
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = "16-mixed"
    profiler: str = "simple"
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4
# Everything below places models on CUDA, so fail early without a GPU.
if not torch.cuda.is_available():
    raise ValueError("No GPU")

# load models
if NEW_MODEL:
    opts = HandDiffOpts()
    if MODEL_EPOCH == 7:
        model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
    elif MODEL_EPOCH == 6:
        # hf_hub_download was called without ever being imported, which raised
        # NameError on this (currently active) path.
        from huggingface_hub import hf_hub_download

        # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
        model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt")
    elif MODEL_EPOCH == 4:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
    elif MODEL_EPOCH == 10:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
    else:
        # The message now lists every epoch handled above.
        raise ValueError(f"new model epoch should be one of 4, 6, 7 or 10, got {MODEL_EPOCH}")
    vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
    # sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    # Input channels = latent + one heatmap per keypoint + mask channel(s).
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
        learn_sigma=True,
    ).cuda()
    # ckpt_state_dict = torch.load(model_path)['model_state_dict']
    # The EMA weights stored in the checkpoint are the ones loaded.
    ckpt_state_dict = torch.load(model_path, map_location=torch.device('cuda'))['ema_state_dict']
    missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
    model.eval()
    print(missing_keys, extra_keys)
    # Unexpected (extra) keys are tolerated here; missing ones are not.
    assert len(missing_keys) == 0
    vae_state_dict = torch.load(vae_path)['state_dict']
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    autoencoder.eval()
    assert len(missing_keys) == 0
else:
    opts = HandDiffOpts()
    model_path = './finetune_epoch=5-step=130000.ckpt'
    sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
        learn_sigma=True,
    ).cuda()
    # A single checkpoint holds both sub-models, told apart by key prefix.
    ckpt_state_dict = torch.load(model_path)['state_dict']
    dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
    vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
    missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
    model.eval()
    assert len(missing_keys) == 0 and len(extra_keys) == 0
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    autoencoder.eval()
    assert len(missing_keys) == 0 and len(extra_keys) == 0
# SAM predictor shared by get_ref_anno and ready_sample to segment the hand.
sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")
# NOTE: this message is printed before the MediaPipe detector below is built.
print("Mediapipe hand detector and SAM ready...")
mp_hands = mp.solutions.hands
# Hand landmark detector shared by get_ref_anno and get_target_anno.
hands = mp_hands.Hands(
static_image_mode=True, # Use False if image is part of a video stream
max_num_hands=2, # Maximum number of hands to detect
min_detection_confidence=0.1,
)
def get_ref_anno(ref):
    """Build the reference image, pose visualisation and conditioning tensor.

    Args:
        ref: Image editor value (a dict with a ``"composite"`` array) or None.

    Returns:
        ``(img, ref_pose, ref_cond)``: the resized RGB reference image, the
        pose visualisation, and the concatenation of the scaled image latent,
        keypoint heatmaps and hand mask along dim 1. ``(None, None, None)``
        when ``ref`` is None.

    Raises:
        gr.Error: If ``REF_POSE_MASK`` is set and no hand is detected.
    """
    if ref is None:
        # One value per output wired to this callback (ref_img, ref_pose,
        # ref_cond); the previous five-value return did not match.
        return None, None, None

    img = ref["composite"][..., :3]
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    # Rows 0..20: right hand, rows 21..41: left hand; zeros mean "not detected".
    keypts = np.zeros((42, 2))
    if REF_POSE_MASK:
        mp_pose = hands.process(img)
        if not mp_pose.multi_hand_landmarks:
            raise gr.Error("No hands detected in the reference image.")

        start_idx = 0
        # handedness is flipped assuming the input image is mirrored in MediaPipe
        for hand_landmarks, handedness in zip(
            mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
        ):
            label = handedness.classification[0].label
            if label == "Left":
                # actually right hand
                start_idx = 0
            elif label == "Right":
                # actually left hand
                start_idx = 21
            for i, landmark in enumerate(hand_landmarks.landmark):
                keypts[start_idx + i] = [
                    landmark.x * opts.image_size[1],
                    landmark.y * opts.image_size[0],
                ]

        sam_predictor.set_image(img)
        # Prompt SAM with the first keypoint of each detected hand.
        has_right = keypts[0].sum() != 0
        has_left = keypts[21].sum() != 0
        if has_right and has_left:
            input_point = np.array([keypts[0], keypts[21]])
            input_label = np.array([1, 1])
        elif has_right:
            input_point = np.array(keypts[:1])
            input_label = np.array([1])
        elif has_left:
            input_point = np.array(keypts[21:22])
            input_label = np.array([1])
        else:
            # Previously this fell through with input_point unbound.
            raise gr.Error("No hands detected in the reference image.")
        masks, _, _ = sam_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        hand_mask = masks[0]
        # White out everything outside the hand mask for the visualisation.
        masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
        ref_pose = visualize_hand(keypts, masked_img)
    else:
        hand_mask = np.zeros_like(img[:, :, 0])
        ref_pose = np.zeros_like(img)

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        """Return the normalised image batch, keypoint heatmaps and latent-size mask."""
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    if not REF_POSE_MASK:
        heatmaps = torch.zeros_like(heatmaps)
        mask = torch.zeros_like(mask)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    return img, ref_pose, ref_cond
def get_target_anno(target):
    """Detect the target hand pose and build the target conditioning tensor.

    Args:
        target: Image editor value (a dict with a ``"composite"`` array) or None.

    Returns:
        ``(pose_img, target_pose, target_cond, target_keypts)``: the resized
        RGB target image, the pose visualisation, the keypoint heatmaps with
        one extra all-zero channel appended along dim 1, and the (42, 2)
        keypoint array. Four ``None`` values when ``target`` is None.

    Raises:
        gr.Error: If no hand is detected in the target image.
    """
    if target is None:
        # Plain None clears each output. The per-component ``.update()``
        # class methods used before were removed in Gradio 4, which this
        # file requires (it uses gr.ImageEditor).
        return None, None, None, None

    pose_img = target["composite"][..., :3]
    pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)

    # detect keypoints
    mp_pose = hands.process(pose_img)
    if not mp_pose.multi_hand_landmarks:
        raise gr.Error("No hands detected in the target image.")

    # Rows 0..20: right hand, rows 21..41: left hand.
    target_keypts = np.zeros((42, 2))
    start_idx = 0
    # handedness is flipped assuming the input image is mirrored in MediaPipe
    for hand_landmarks, handedness in zip(
        mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
    ):
        label = handedness.classification[0].label
        if label == "Left":
            # actually right hand
            start_idx = 0
        elif label == "Right":
            # actually left hand
            start_idx = 21
        for i, landmark in enumerate(hand_landmarks.landmark):
            target_keypts[start_idx + i] = [
                landmark.x * opts.image_size[1],
                landmark.y * opts.image_size[0],
            ]

    target_pose = visualize_hand(target_keypts, pose_img)
    kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
    target_heatmaps = torch.tensor(
        keypoint_heatmap(
            scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
            opts.latent_size,
            var=1.0,
        )
        * kpts_valid[:, None, None],
        dtype=torch.float,
        device="cuda",
    )[None, ...]
    # Append one all-zero channel after the heatmaps.
    target_cond = torch.cat(
        [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
    )
    return pose_img, target_pose, target_cond, target_keypts
# def draw_grid(ref):
# if ref is None or ref["composite"] is None: # or len(ref["layers"])==0:
# return ref
# # if len(ref["layers"]) == 1:
# # need_draw = True
# # # elif ref["composite"].shape[0] != size_memory[0] or ref["composite"].shape[1] != size_memory[1]:
# # # need_draw = True
# # else:
# # need_draw = False
# # size_memory = ref["composite"].shape[0], ref["composite"].shape[1]
# # if not need_draw:
# # return size_memory, ref
# h, w = ref["composite"].shape[:2]
# grid_h, grid_w = h // 32, w // 32
# # grid = np.zeros((h, w, 4), dtype=np.uint8)
# for i in range(1, grid_h):
# ref["composite"][i * 32, :, :3] = 255 # 0.5 * ref["composite"][i * 32, :, :3] +
# for i in range(1, grid_w):
# ref["composite"][:, i * 32, :3] = 255 # 0.5 * ref["composite"][:, i * 32, :3] +
# # if len(ref["layers"]) == 1:
# # ref["layers"].append(grid)
# # else:
# # ref["layers"][1] = grid
# return ref["composite"]
def get_mask_inpaint(ref):
    """Return a binary uint8 mask from the last channel of the first editor layer.

    The channel is resized to ``opts.image_size`` and thresholded at 128.
    """
    layer_channel = np.array(ref["layers"][0])[..., -1]
    resized = cv2.resize(layer_channel, opts.image_size, interpolation=cv2.INTER_AREA)
    return (resized >= 128).astype(np.uint8)
def visualize_ref(crop, brush):
    """Preview the cropped image with non-brushed pixels darkened.

    Returns None when either editor value is missing. Otherwise the crop's
    background is resized to the brush layer's size and every pixel whose
    brush-layer last channel is below 128 is overlaid via ``mask_image``.
    """
    if brush is None or crop is None:
        return None
    brush_channel = brush["layers"][0][..., -1]
    # cv2 expects (width, height), i.e. the reversed array shape.
    target_wh = brush_channel.shape[::-1]
    background = cv2.resize(
        crop["background"][..., :3], target_wh, interpolation=cv2.INTER_AREA
    )
    return mask_image(background, brush_channel < 128)
def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
    """Record a clicked keypoint for one hand and redraw the selection.

    Args:
        img: Image the keypoints are drawn on.
        keypoints: ``[right_points, left_points]`` lists of ``[x, y]`` clicks,
            or None before the first click.
        side: Which hand the click belongs to.
        evt: Gradio select event; ``evt.index`` is the clicked position.

    Returns:
        ``(vis_hand, keypoints)``: the visualisation of the points selected so
        far for ``side`` and the updated keypoint lists.

    Raises:
        ValueError: If ``side`` is neither ``"right"`` nor ``"left"``.
    """
    if keypoints is None:
        keypoints = [[], []]
    if side == "right":
        hand_idx, offset = 0, 0
    elif side == "left":
        hand_idx, offset = 1, 21
    else:
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")

    selected = keypoints[hand_idx]
    if len(selected) == 21:
        gr.Info(
            f"21 keypoints for {side} hand already selected. Try reset if something looks wrong."
        )
    else:
        selected.append(list(evt.index))

    # Computed after the branch so the already-full case still redraws the
    # 21 points; previously len_kps was unbound on that path and it crashed.
    len_kps = len(selected)
    kps = np.zeros((42, 2))
    kps[offset : offset + len_kps] = np.array(selected)
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints
def undo_kps(img, keypoints, side: Literal["right", "left"]):
    """Remove the most recently selected keypoint of one hand and redraw.

    Args:
        img: Image the keypoints are drawn on.
        keypoints: ``[right_points, left_points]`` lists of ``[x, y]`` clicks,
            or None if nothing was selected yet.
        side: Which hand to undo.

    Returns:
        ``(image, keypoints)``. ``img`` is returned unchanged when no
        keypoints remain for ``side``; otherwise the visualisation of the
        remaining points.

    Raises:
        ValueError: If ``side`` is neither ``"right"`` nor ``"left"``.
    """
    if keypoints is None:
        return img, None
    if side == "right":
        hand_idx, offset = 0, 0
    elif side == "left":
        hand_idx, offset = 1, 21
    else:
        raise ValueError(f"side must be 'right' or 'left', got {side!r}")

    selected = keypoints[hand_idx]
    if len(selected) == 0:
        return img, keypoints
    selected.pop()
    len_kps = len(selected)
    if len_kps == 0:
        # Nothing left to draw. Assigning an empty array into the (0, 2)
        # slice below raised a broadcasting ValueError before.
        return img, keypoints

    kps = np.zeros((42, 2))
    kps[offset : offset + len_kps] = np.array(selected)
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints
def reset_kps(img, keypoints, side: Literal["right", "left"]):
    """Clear every selected keypoint of one hand.

    Returns ``(img, keypoints)`` with the list for ``side`` emptied in place;
    ``keypoints`` is passed through untouched when it is None or ``side`` is
    not recognised.
    """
    if keypoints is not None:
        hand_index = 0 if side == "right" else (1 if side == "left" else None)
        if hand_index is not None:
            keypoints[hand_index] = []
    return img, keypoints
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
    """Sample ``num_gen`` images for the given reference/target conditioning.

    The batch is doubled for classifier-free guidance: the first half carries
    the real conditions with ``nvs`` 0, the second half zeroed conditions with
    ``nvs`` 2. Only the first half of the sampler output is decoded.

    Returns:
        ``(results, results_pose)``: two lists of ``MAX_N`` images; slots
        beyond ``num_gen`` hold the placeholder image.
    """
    set_seed(seed)
    noise = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device="cuda",
    )
    target_batch = target_cond.repeat(num_gen, 1, 1, 1)
    ref_batch = ref_cond.repeat(num_gen, 1, 1, 1)
    # novel view synthesis mode = off
    nvs_flags = torch.zeros(num_gen, dtype=torch.int, device="cuda")
    doubled_noise = torch.cat([noise, noise], 0)
    model_kwargs = {
        "target_cond": torch.cat([target_batch, torch.zeros_like(target_batch)]),
        "ref_cond": torch.cat([ref_batch, torch.zeros_like(ref_batch)]),
        "nvs": torch.cat([nvs_flags, 2 * torch.ones_like(nvs_flags)]),
        "cfg_scale": cfg,
    }
    sampler_output = diffusion.p_sample_loop(
        model.forward_with_cfg,
        doubled_noise.shape,
        doubled_noise,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device="cuda",
    )
    samples, _ = sampler_output.chunk(2)

    decoded = autoencoder.decode(samples / opts.latent_scaling_factor)
    decoded = torch.clamp(decoded, min=-1.0, max=1.0)
    sampled_images = unnormalize(decoded.permute(0, 2, 3, 1).cpu().numpy())

    results, results_pose = [], []
    for slot in range(MAX_N):
        if slot >= num_gen:
            # Pad unused slots with the placeholder image.
            results.append(placeholder)
            results_pose.append(placeholder)
            continue
        results.append(sampled_images[slot])
        results_pose.append(visualize_hand(target_keypts, sampled_images[slot]))
    return results, results_pose
def ready_sample(img_ori, inpaint_mask, keypts):
    """Prepare every tensor and preview needed before inpainting a hand.

    Args:
        img_ori: Original image array; its first three channels are used.
        inpaint_mask: Binary array marking the region to be inpainted.
        keypts: ``[right_points, left_points]`` lists of clicked keypoints;
            each hand must have either 0 or 21 points. The lists are not
            modified.

    Returns:
        ``(ref_cond, target_cond, latent, inpaint_latent_mask, keypts,
        vis_mask32, vis_mask256)`` on success. When a hand has an invalid
        number of keypoints an info message is shown and ``(None, None)`` is
        returned.
        NOTE(review): that early return has 2 values while the success path
        has 7; confirm it matches the outputs wired to this callback.

    Raises:
        ValueError: If both hands end up without keypoints.
    """
    # Validate into fresh arrays. The previous code overwrote the entries of
    # the caller's list with numpy arrays, so the shared keypoint state no
    # longer held plain lists afterwards.
    hand_kpts = []
    for hand_name, selected in (("right", keypts[0]), ("left", keypts[1])):
        if len(selected) == 0:
            hand_kpts.append(np.zeros((21, 2)))
        elif len(selected) == 21:
            hand_kpts.append(np.array(selected, dtype=np.float32))
        else:
            gr.Info(f"Number of {hand_name} hand keypoints should be either 0 or 21.")
            return None, None

    # Run the SAM image embedding only once the input is known to be valid.
    img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    sam_predictor.set_image(img)

    keypts = np.concatenate(hand_kpts, axis=0)
    # Rescale from a LENGTH x LENGTH coordinate frame to the model image size.
    keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)

    box_shift_ratio = 0.5
    box_size_factor = 1.2
    # Prompt SAM with the keypoints of whichever hand(s) are present and
    # with their bounding box.
    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        visible = keypts
    elif keypts[0].sum() != 0:
        visible = keypts[:21]
    elif keypts[21].sum() != 0:
        visible = keypts[21:]
    else:
        raise ValueError(
            "Something wrong. If no hand detected, it should not reach here."
        )
    input_point = np.array(visible)
    input_box = np.stack([visible.min(axis=0), visible.max(axis=0)])
    input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
    # Enlarge the box by box_size_factor around the weighted point between
    # its two corners (the centre for a ratio of 0.5).
    box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
    input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    hand_mask = masks[0]

    inpaint_latent_mask = torch.tensor(
        cv2.resize(
            inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
        ),
        dtype=torch.float,
        device="cuda",
    ).unsqueeze(0)[None, ...]

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        """Return the normalised image batch, keypoint heatmaps and latent-size mask."""
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    # The hand mask excludes the region that is going to be inpainted.
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask * (1 - inpaint_mask),
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
    # The reference conditioning is all zeros here; the concatenation only
    # supplies its shape, dtype and device.
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    ref_cond = torch.zeros_like(ref_cond)

    img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
    assert mask.max() == 1
    vis_mask32 = mask_image(
        img32, inpaint_latent_mask[0, 0].cpu().numpy(), (255, 255, 255), transparent=False
    ).astype(np.uint8)

    assert np.unique(inpaint_mask).shape[0] <= 2
    assert hand_mask.dtype == bool
    mask256 = inpaint_mask
    vis_mask256 = mask_image(img, mask256, (255, 255, 255), transparent=False).astype(
        np.uint8
    )

    return (
        ref_cond,
        target_cond,
        latent,
        inpaint_latent_mask,
        keypts,
        vis_mask32,
        vis_mask256,
    )
def switch_mask_size(radio):
    """Return visibility updates for the two mask previews.

    Args:
        radio: Selected option, ``"256x256"`` or ``"latent size (32x32)"``.

    Returns:
        A pair of ``gr.update`` values: ``(hidden, visible)`` for
        ``"256x256"`` and ``(visible, hidden)`` for the latent-size option.
        Which components they target depends on the event wiring (not shown
        here; presumably the 32x32 preview first, then the 256x256 one).

    Raises:
        ValueError: If ``radio`` is not one of the two known options
            (previously this surfaced as an unbound local variable).
    """
    if radio == "256x256":
        return (gr.update(visible=False), gr.update(visible=True))
    if radio == "latent size (32x32)":
        return (gr.update(visible=True), gr.update(visible=False))
    raise ValueError(f"Unknown mask size option: {radio!r}")
def sample_inpaint(
    ref_cond,
    target_cond,
    latent,
    inpaint_latent_mask,
    keypts,
    num_gen,
    seed,
    cfg,
    quality,
):
    """Inpaint the masked latent region ``num_gen`` times.

    The batch is doubled for classifier-free guidance: the first half carries
    the real conditions with ``nvs`` 0, the second half zeroed conditions with
    ``nvs`` 2. ``quality`` is forwarded to the sampler as ``jump_n_sample``
    with a fixed ``jump_length`` of 10.

    Returns:
        ``(results, results_pose)``: two lists of ``FIX_MAX_N`` images; slots
        beyond ``num_gen`` hold the placeholder image.
    """
    set_seed(seed)
    noise = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device="cuda",
    )
    target_batch = target_cond.repeat(num_gen, 1, 1, 1)
    ref_batch = ref_cond.repeat(num_gen, 1, 1, 1)
    # novel view synthesis mode = off
    nvs_flags = torch.zeros(num_gen, dtype=torch.int, device="cuda")
    doubled_noise = torch.cat([noise, noise], 0)
    model_kwargs = {
        "target_cond": torch.cat([target_batch, torch.zeros_like(target_batch)]),
        "ref_cond": torch.cat([ref_batch, torch.zeros_like(ref_batch)]),
        "nvs": torch.cat([nvs_flags, 2 * torch.ones_like(nvs_flags)]),
        "cfg_scale": cfg,
    }
    sampler_output = diffusion.inpaint_p_sample_loop(
        model.forward_with_cfg,
        doubled_noise.shape,
        latent,
        inpaint_latent_mask,
        doubled_noise,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device="cuda",
        jump_length=10,
        jump_n_sample=quality,
    )
    samples, _ = sampler_output.chunk(2)

    decoded = autoencoder.decode(samples / opts.latent_scaling_factor)
    decoded = torch.clamp(decoded, min=-1.0, max=1.0)
    sampled_images = unnormalize(decoded.permute(0, 2, 3, 1).cpu().numpy())

    # visualize
    results, results_pose = [], []
    for slot in range(FIX_MAX_N):
        if slot >= num_gen:
            # Pad unused slots with the placeholder image.
            results.append(placeholder)
            results_pose.append(placeholder)
            continue
        results.append(sampled_images[slot])
        results_pose.append(visualize_hand(keypts, sampled_images[slot]))
    return results, results_pose
def flip_hand(
    img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None
):
    """Mirror an editor value, its pose image and conditioning left-to-right.

    Returns four ``None`` values when ``cond`` is None (clear was clicked).
    Otherwise the arrays in ``img`` are replaced by horizontally flipped
    views, ``cond`` is flipped along its last dimension and, when ``keypts``
    is given, the x coordinates of each hand that has any non-zero entries
    are mirrored in place around ``opts.image_size[1]``.
    """
    if cond is None:  # clear clicked
        return None, None, None, None

    for key in ("composite", "background"):
        img[key] = img[key][:, ::-1, :]
    img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
    flipped_pose = pose_img[:, ::-1, :]
    flipped_cond = cond.flip(-1)

    if keypts is not None:  # cond is target_cond
        # Rows 0..20 and rows 21.. are handled independently.
        for hand_rows in (slice(0, 21), slice(21, None)):
            if keypts[hand_rows, :].sum() != 0:
                keypts[hand_rows, 0] = opts.image_size[1] - keypts[hand_rows, 0]
    return img, flipped_pose, flipped_cond, keypts
def resize_to_full(img):
    """Resize every array of an editor value to ``LENGTH`` x ``LENGTH`` in place."""
    full_size = (LENGTH, LENGTH)
    for key in ("background", "composite"):
        img[key] = cv2.resize(img[key], full_size)
    img["layers"] = [cv2.resize(layer, full_size) for layer in img["layers"]]
    return img
def clear_all():
    """Return the reset value for each of the 16 outputs of the clear button.

    Order: (image, pose, flip checkbox) for reference and then target, seven
    cleared values, then number of generations, seed and guidance scale.
    NOTE(review): the guidance scale resets to 3.0 while the slider is
    created with 2.5; confirm which is intended.
    """
    editor_pose_flip = (None, None, False)
    cleared = (None,) * 7
    slider_defaults = (1, 42, 3.0)
    return editor_pose_flip * 2 + cleared + slider_defaults
def fix_clear_all():
    """Return reset values: seventeen ``None`` entries followed by 1, 42, 3.0 and 10."""
    cleared = (None,) * 17
    trailing_defaults = (1, 42, 3.0, 10)
    return cleared + trailing_defaults
def _editor_value_is_blank(value):
    """True if ``value`` has background/layers/composite entries that all sum to zero."""
    if not ("background" in value and "layers" in value and "composite" in value):
        return False
    return (
        value["background"].sum() == 0
        and (sum(im.sum() for im in value["layers"]) == 0)
        and value["composite"].sum() == 0
    )


def enable_component(image1, image2):
    """Return an update enabling a component only when both inputs hold content.

    The component is disabled when either input is None or is an editor value
    whose background, layers and composite are all zero.
    """
    if image1 is None or image2 is None:
        return gr.update(interactive=False)
    if _editor_value_is_blank(image1) or _editor_value_is_blank(image2):
        return gr.update(interactive=False)
    return gr.update(interactive=True)
def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left):
    """Show or hide the per-hand keypoint widgets according to the checkboxes.

    For an unchecked hand its keypoints are cleared and the clean image is
    shown; for a checked hand the current pose image is kept.

    Returns:
        ``(kpts, vis_right, vis_left, right update x3, left update x3,
        right info update, left info update)``.
    """
    if kpts is None:
        kpts = [[], []]

    show_right = "Right hand" in checkbox
    show_left = "Left hand" in checkbox

    if not show_right:
        kpts[0] = []
    vis_right = img_pose_right if show_right else img_clean
    update_right = gr.update(visible=show_right)
    update_r_info = gr.update(visible=show_right)

    if not show_left:
        kpts[1] = []
    vis_left = img_pose_left if show_left else img_clean
    update_left = gr.update(visible=show_left)
    update_l_info = gr.update(visible=show_left)

    return (
        kpts,
        vis_right,
        vis_left,
        update_right,
        update_right,
        update_right,
        update_left,
        update_left,
        update_left,
        update_r_info,
        update_l_info,
    )
# def parse_fix_example(ex_img, ex_masked):
# original_img = ex_img
# # ex_img = cv2.resize(ex_img, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
# # ex_masked = cv2.resize(ex_masked, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
# inpaint_mask = np.all(ex_masked > 250, axis=-1).astype(np.uint8)
# layer = np.ones_like(ex_img) * 255
# layer = np.concatenate([layer, np.zeros_like(ex_img[..., 0:1])], axis=-1)
# layer[inpaint_mask == 1, 3] = 255
# ref_value = {
# "composite": ex_masked,
# "background": ex_img,
# "layers": [layer],
# }
# inpaint_mask = cv2.resize(
# inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
# )
# kp_img = visualize_ref(ref_value)
# return (
# original_img,
# gr.update(value=ref_value),
# kp_img,
# inpaint_mask,
# )
# Side length in pixels used for the image widgets and for resize_to_full.
LENGTH = 480

# Example rows (one image path each): sample1..sample11 then pose1..pose8.
example_imgs = [
    [f"sample_images/sample{index}.jpg"] for index in range(1, 12)
] + [[f"pose_images/pose{index}.jpg"] for index in range(1, 9)]

# Example rows: bad_hands/1.jpg .. bad_hands/13.jpg.
fix_example_imgs = [[f"bad_hands/{index}.jpg"] for index in range(1, 14)]

# Fixes the size of example thumbnails.
custom_css = """
.gradio-container .examples img {
width: 240px !important;
height: 240px !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
with gr.Tab("Edit Hand Poses"):
ref_img = gr.State(value=None)
ref_cond = gr.State(value=None)
keypts = gr.State(value=None)
target_img = gr.State(value=None)
target_cond = gr.State(value=None)
target_keypts = gr.State(value=None)
dump = gr.State(value=None)
with gr.Row():
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Reference</p>"""
)
gr.Markdown("""<p style="text-align: center;"><br></p>""")
ref = gr.ImageEditor(
type="numpy",
label="Reference",
show_label=True,
height=LENGTH,
width=LENGTH,
brush=False,
layers=False,
crop_size="1:1",
)
ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
ref_pose = gr.Image(
type="numpy",
label="Reference Pose",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
)
ref_flip = gr.Checkbox(
value=False, label="Flip Handedness (Reference)", interactive=False
)
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold;">2. Target</p>"""
)
target = gr.ImageEditor(
type="numpy",
label="Target",
show_label=True,
height=LENGTH,
width=LENGTH,
brush=False,
layers=False,
crop_size="1:1",
)
target_finish_crop = gr.Button(
value="Finish Cropping", interactive=False
)
target_pose = gr.Image(
type="numpy",
label="Target Pose",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
)
target_flip = gr.Checkbox(
value=False, label="Flip Handedness (Target)", interactive=False
)
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold;">3. Result</p>"""
)
gr.Markdown(
"""<p style="text-align: center;">Run is enabled after the images have been processed</p>"""
)
run = gr.Button(value="Run", interactive=False)
gr.Markdown(
"""<p style="text-align: center;">~20s per generation. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
)
results = gr.Gallery(
type="numpy",
label="Results",
show_label=True,
height=LENGTH,
min_width=LENGTH,
columns=MAX_N,
interactive=False,
preview=True,
)
results_pose = gr.Gallery(
type="numpy",
label="Results Pose",
show_label=True,
height=LENGTH,
min_width=LENGTH,
columns=MAX_N,
interactive=False,
preview=True,
)
clear = gr.ClearButton()
with gr.Row():
n_generation = gr.Slider(
label="Number of generations",
value=1,
minimum=1,
maximum=MAX_N,
step=1,
randomize=False,
interactive=True,
)
seed = gr.Slider(
label="Seed",
value=42,
minimum=0,
maximum=10000,
step=1,
randomize=False,
interactive=True,
)
cfg = gr.Slider(
label="Classifier free guidance scale",
value=2.5,
minimum=0.0,
maximum=10.0,
step=0.1,
randomize=False,
interactive=True,
)
ref.change(enable_component, [ref, ref], ref_finish_crop)
ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
ref_flip.select(
flip_hand, [ref, ref_pose, ref_cond], [ref, ref_pose, ref_cond, dump]
)
target.change(enable_component, [target, target], target_finish_crop)
target_finish_crop.click(
get_target_anno,
[target],
[target_img, target_pose, target_cond, target_keypts],
)
target_pose.change(enable_component, [target_img, target_pose], target_flip)
target_flip.select(
flip_hand,
[target, target_pose, target_cond, target_keypts],
[target, target_pose, target_cond, target_keypts],
)
ref_pose.change(enable_component, [ref_pose, target_pose], run)
target_pose.change(enable_component, [ref_pose, target_pose], run)
run.click(
sample_diff,
[ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
[results, results_pose],
)
clear.click(
clear_all,
[],
[
ref,
ref_pose,
ref_flip,
target,
target_pose,
target_flip,
results,
results_pose,
ref_img,
ref_cond,
# mask,
target_img,
target_cond,
target_keypts,
n_generation,
seed,
cfg,
],
)
gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
with gr.Tab("Reference"):
with gr.Row():
gr.Examples(example_imgs, [ref], examples_per_page=20)
with gr.Tab("Target"):
with gr.Row():
gr.Examples(example_imgs, [target], examples_per_page=20)
with gr.Tab("Fix Hands"):
fix_inpaint_mask = gr.State(value=None)
fix_original = gr.State(value=None)
fix_img = gr.State(value=None)
fix_kpts = gr.State(value=None)
fix_kpts_np = gr.State(value=None)
fix_ref_cond = gr.State(value=None)
fix_target_cond = gr.State(value=None)
fix_latent = gr.State(value=None)
fix_inpaint_latent = gr.State(value=None)
# fix_size_memory = gr.State(value=(0, 0))
with gr.Row():
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>"""
)
gr.Markdown(
"""<p style="text-align: center;">Crop the image around the hand.<br>Then, brush area (e.g., wrong finger) that needs to be fixed.</p>"""
)
gr.Markdown(
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>"""
)
fix_crop = gr.ImageEditor(
type="numpy",
sources=["upload", "webcam", "clipboard"],
label="Image crop",
show_label=True,
height=LENGTH,
width=LENGTH,
layers=False,
crop_size="1:1",
brush=False,
image_mode="RGBA",
container=False,
)
gr.Markdown(
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>"""
)
# Brush stage: receives its image from the fix_crop.change handlers (no upload
# sources of its own) and starts non-interactive. The user paints the region
# to repaint with a single white brush; get_mask_inpaint later reads this
# component to build the inpaint mask.
fix_ref = gr.ImageEditor(
    type="numpy",
    label="Image brush",
    sources=(),
    show_label=True,
    height=LENGTH,
    width=LENGTH,
    layers=False,
    # No transform tools on this editor. `("brush")` is just the str "brush"
    # (parentheses without a comma do not make a tuple) and "brush" is not a
    # transform name, so pass an explicit empty tuple instead.
    transforms=(),
    brush=gr.Brush(
        colors=["rgb(255, 255, 255)"], default_size=20
    ),  # 204, 50, 50
    image_mode="RGBA",
    container=False,
    interactive=False,
)
# Confirms step 1. Starts disabled; fix_ref.change routes enable_component's
# result to this button. Label spelling corrected ("Croping" -> "Cropping").
fix_finish_crop = gr.Button(
    value="Finish Cropping & Brushing", interactive=False
)
# Static, read-only reference picture for the keypoint ordering. Its value is
# the file "openpose.png" and it is drawn at two thirds of LENGTH per side.
gr.Markdown(
"""<p style="text-align: left; font-size: 20px; font-weight: bold; ">OpenPose keypoints convention</p>"""
)
fix_openpose = gr.Image(
value="openpose.png",
type="numpy",
label="OpenPose keypoints convention",
show_label=True,
height=LENGTH // 3 * 2,
width=LENGTH // 3 * 2,
interactive=False,
)
# Step 2 of the Fix Hands flow: pick the hand side(s), then click keypoints on
# the matching image.
# NOTE(review): nesting of the Column/Row blocks below is not recoverable from
# this view because leading indentation is missing.
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>"""
)
gr.Markdown(
"""<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\" on the bottom left.</p>"""
)
# Hand-side selector: starts disabled with nothing ticked. Its select event
# calls set_visible, whose outputs include every right/left widget below.
fix_checkbox = gr.CheckboxGroup(
["Right hand", "Left hand"],
# value=["Right hand", "Left hand"],
label="Hand side",
info="Which side this hand is? Could be both.",
interactive=False,
)
# Right-hand widgets: all created hidden (visible=False).
fix_kp_r_info = gr.Markdown(
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
visible=False,
)
# Right-hand keypoint image: no upload sources; its select event is handled
# by get_kps.
fix_kp_right = gr.Image(
type="numpy",
label="Keypoint Selection (right hand)",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
visible=False,
sources=[],
)
with gr.Row():
# Wired to undo_kps / reset_kps for the right hand.
fix_undo_right = gr.Button(
value="Undo", interactive=False, visible=False
)
fix_reset_right = gr.Button(
value="Reset", interactive=False, visible=False
)
# Left-hand widgets: same configuration as the right-hand ones above.
fix_kp_l_info = gr.Markdown(
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
visible=False
)
# Left-hand keypoint image: no upload sources; its select event is handled
# by get_kps.
fix_kp_left = gr.Image(
type="numpy",
label="Keypoint Selection (left hand)",
show_label=True,
height=LENGTH,
width=LENGTH,
interactive=False,
visible=False,
sources=[],
)
with gr.Row():
# Wired to undo_kps / reset_kps for the left hand.
fix_undo_left = gr.Button(
value="Undo", interactive=False, visible=False
)
fix_reset_left = gr.Button(
value="Reset", interactive=False, visible=False
)
# Step 3 of the Fix Hands flow: build the sampling inputs and preview the
# inpaint mask.
# NOTE(review): nesting of the Column below is not recoverable from this view
# because leading indentation is missing.
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>"""
)
gr.Markdown(
"""<p style="text-align: center;">In Fix Hands, not segmentation mask, but only inpaint mask is used.</p>"""
)
# Starts disabled; its click runs ready_sample.
fix_ready = gr.Button(value="Ready", interactive=False)
# Chooses which mask preview to show; its select event calls switch_mask_size,
# which outputs to both preview images below.
# NOTE(review): the choice labels hard-code 256x256 and 32x32 while the
# previews are sized from opts.image_size / opts.latent_size -- presumably
# these match; confirm.
fix_mask_size = gr.Radio(
["256x256", "latent size (32x32)"],
label="Visualized inpaint mask size",
interactive=False,
value="256x256",
)
gr.Markdown(
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Visualized inpaint masks</p>"""
)
# Latent-size preview: sized by opts.latent_size, created hidden. Its change
# events trigger enable_component for fix_run and fix_mask_size.
fix_vis_mask32 = gr.Image(
type="numpy",
label=f"Visualized {opts.latent_size} Inpaint Mask",
show_label=True,
height=opts.latent_size,
width=opts.latent_size,
interactive=False,
visible=False,
)
# Image-size preview: sized by opts.image_size, created visible.
fix_vis_mask256 = gr.Image(
type="numpy",
label=f"Visualized {opts.image_size} Inpaint Mask",
visible=True,
show_label=True,
height=opts.image_size,
width=opts.image_size,
interactive=False,
)
# Step 4 of the Fix Hands flow: run sampling and show the outputs.
# NOTE(review): nesting of the Column below is not recoverable from this view
# because leading indentation is missing.
with gr.Column():
gr.Markdown(
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>"""
)
# Starts disabled; its click runs sample_inpaint.
fix_run = gr.Button(value="Run", interactive=False)
gr.Markdown(
"""<p style="text-align: center;">>3min and ~24GB per generation</p>"""
)
# Both galleries are outputs of sample_inpaint and lay items out in
# FIX_MAX_N columns.
fix_result = gr.Gallery(
type="numpy",
label="Results",
show_label=True,
height=LENGTH,
min_width=LENGTH,
columns=FIX_MAX_N,
interactive=False,
preview=True,
)
fix_result_pose = gr.Gallery(
type="numpy",
label="Results Pose",
show_label=True,
height=LENGTH,
min_width=LENGTH,
columns=FIX_MAX_N,
interactive=False,
preview=True,
)
# Created without a component list; its click is wired to fix_clear_all.
fix_clear = gr.ClearButton()
gr.Markdown(
"[NOTE] Currently, Number of generation > 1 could lead to out-of-memory"
)
# Sampling controls for sample_inpaint, laid out in one row. Every slider is
# user-editable and keeps its declared default (randomize=False).
_fix_slider_common = dict(randomize=False, interactive=True)
with gr.Row():
    # Number of samples; capped at FIX_MAX_N.
    fix_n_generation = gr.Slider(
        minimum=1,
        maximum=FIX_MAX_N,
        step=1,
        value=1,
        label="Number of generations",
        **_fix_slider_common,
    )
    # Integer seed in [0, 10000].
    fix_seed = gr.Slider(
        minimum=0,
        maximum=10000,
        step=1,
        value=42,
        label="Seed",
        **_fix_slider_common,
    )
    # Guidance scale in [0.0, 10.0], adjustable in 0.1 steps.
    fix_cfg = gr.Slider(
        minimum=0.0,
        maximum=10.0,
        step=0.1,
        value=3.0,
        label="Classifier free guidance scale",
        **_fix_slider_common,
    )
    # Integer quality setting in [1, 10]; defaults to the maximum.
    fix_quality = gr.Slider(
        minimum=1,
        maximum=10,
        step=1,
        value=10,
        label="Quality",
        **_fix_slider_common,
    )
# Crop editor -> brush editor: two handlers on the same event, both writing to
# fix_ref (enable_component first, then resize_to_full).
fix_crop.change(fn=enable_component, inputs=[fix_crop, fix_crop], outputs=fix_ref)
fix_crop.change(fn=resize_to_full, inputs=fix_crop, outputs=fix_ref)

# Brush editor -> finish button.
fix_ref.change(
    fn=enable_component, inputs=[fix_ref, fix_ref], outputs=fix_finish_crop
)

# Finish button: three independent handlers, registered in this order.
fix_finish_crop.click(
    fn=get_mask_inpaint, inputs=[fix_ref], outputs=[fix_inpaint_mask]
)
# fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
# fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
fix_finish_crop.click(
    fn=lambda x: x["background"], inputs=[fix_crop], outputs=[fix_original]
)
fix_finish_crop.click(
    fn=visualize_ref, inputs=[fix_crop, fix_ref], outputs=[fix_img]
)

# Copy the visualised image unchanged into the right, then the left, keypoint
# image whenever fix_img changes.
for _kp_image in (fix_kp_right, fix_kp_left):
    fix_img.change(fn=lambda x: x, inputs=[fix_img], outputs=[_kp_image])

# Each inpaint-mask change runs enable_component once per component below, in
# this order.
for _gated_component in (
    fix_checkbox,
    fix_kp_right,
    fix_undo_right,
    fix_reset_right,
    fix_kp_left,
    fix_undo_left,
    fix_reset_left,
    fix_ready,
):
    fix_inpaint_mask.change(
        fn=enable_component,
        inputs=[fix_inpaint_mask, fix_inpaint_mask],
        outputs=_gated_component,
    )
# fix_inpaint_mask.change(
#     enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
# )
# Hand-side selection: set_visible receives the checkbox value, the keypoint
# state, the visualised image and both keypoint images.
# NOTE(review): fix_kp_right and fix_kp_left each appear twice in the outputs,
# so set_visible is presumably expected to supply eleven values -- confirm
# against its definition.
fix_checkbox.select(
    fn=set_visible,
    inputs=[fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
    outputs=[
        # keypoint state and both images
        fix_kpts,
        fix_kp_right,
        fix_kp_left,
        # right-hand widgets
        fix_kp_right,
        fix_undo_right,
        fix_reset_right,
        # left-hand widgets
        fix_kp_left,
        fix_undo_left,
        fix_reset_left,
        # per-side headings
        fix_kp_r_info,
        fix_kp_l_info,
    ],
)
# Per-side keypoint editing. For each hand, three handlers are registered in
# the order select -> undo -> reset; each gets its own gr.State holding the
# side name, and each writes back to that side's image and the shared
# keypoint state. The right hand is wired before the left.
for _side, _kp_image, _undo_button, _reset_button in (
    ("right", fix_kp_right, fix_undo_right, fix_reset_right),
    ("left", fix_kp_left, fix_undo_left, fix_reset_left),
):
    _kp_image.select(
        fn=get_kps,
        inputs=[fix_img, fix_kpts, gr.State(_side)],
        outputs=[_kp_image, fix_kpts],
    )
    _undo_button.click(
        fn=undo_kps,
        inputs=[fix_img, fix_kpts, gr.State(_side)],
        outputs=[_kp_image, fix_kpts],
    )
    _reset_button.click(
        fn=reset_kps,
        inputs=[fix_img, fix_kpts, gr.State(_side)],
        outputs=[_kp_image, fix_kpts],
    )
# fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
# fix_run.click(lambda x:gr.update(value=None), [], [fix_result, fix_result_pose])

# A change of the latent-size mask preview runs enable_component for the Run
# button first and the mask-size selector second.
for _unlocked in (fix_run, fix_mask_size):
    fix_vis_mask32.change(
        fn=enable_component,
        inputs=[fix_vis_mask32, fix_vis_mask256],
        outputs=_unlocked,
    )

# Ready: ready_sample turns the original crop, the inpaint mask and the
# keypoints into the five sampling states plus the two mask previews.
fix_ready.click(
    fn=ready_sample,
    inputs=[fix_original, fix_inpaint_mask, fix_kpts],
    outputs=[
        fix_ref_cond,
        fix_target_cond,
        fix_latent,
        fix_inpaint_latent,
        fix_kpts_np,
        fix_vis_mask32,
        fix_vis_mask256,
    ],
)

# Mask-size selector: switch_mask_size outputs to both previews.
fix_mask_size.select(
    fn=switch_mask_size,
    inputs=[fix_mask_size],
    outputs=[fix_vis_mask32, fix_vis_mask256],
)
# Run: sample_inpaint receives the five states written by ready_sample followed
# by the four sliders, and fills both result galleries.
fix_run.click(
    fn=sample_inpaint,
    inputs=[
        # prepared states
        fix_ref_cond,
        fix_target_cond,
        fix_latent,
        fix_inpaint_latent,
        fix_kpts_np,
        # sliders
        fix_n_generation,
        fix_seed,
        fix_cfg,
        fix_quality,
    ],
    outputs=[fix_result, fix_result_pose],
)

# Clear: fix_clear_all takes no inputs and supplies one value per listed
# output, in exactly this order.
fix_clear.click(
    fn=fix_clear_all,
    inputs=[],
    outputs=[
        # editors, keypoint images and galleries
        fix_crop,
        fix_ref,
        fix_kp_right,
        fix_kp_left,
        fix_result,
        fix_result_pose,
        # states and mask previews
        fix_inpaint_mask,
        fix_original,
        fix_img,
        fix_vis_mask32,
        fix_vis_mask256,
        fix_kpts,
        fix_kpts_np,
        fix_ref_cond,
        fix_target_cond,
        fix_latent,
        fix_inpaint_latent,
        # sliders
        fix_n_generation,
        # fix_size_memory,
        fix_seed,
        fix_cfg,
        fix_quality,
    ],
)
gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
# Hidden image components; within this section they are referenced only by
# the disabled run_on_click wiring kept as comments below.
fix_dump_ex = gr.Image(label="Original Image", value=None, visible=False)
fix_dump_ex_masked = gr.Image(label="After Brushing", value=None, visible=False)
with gr.Column():
    # Clicking an example loads it into the crop editor.
    fix_example = gr.Examples(
        examples=fix_example_imgs,
        inputs=[fix_crop],
        examples_per_page=20,
        # run_on_click=True,
        # fn=parse_fix_example,
        # inputs=[fix_dump_ex, fix_dump_ex_masked],
        # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
    )
print("Ready to launch..")
# Start the queued app on all interfaces, port 7739, with a share link.
# prevent_thread_lock=True makes launch() return immediately, so the share
# URL is available while the server is running and the blocking is done
# explicitly below.
_, _, shared_url = demo.queue().launch(
    share=True,
    server_name="0.0.0.0",
    server_port=7739,
    prevent_thread_lock=True,
)
# gr.Blocks has no `block()` method; `block_thread()` is the call that keeps
# the main thread alive until interrupted and then closes the server.
demo.block_thread()