import os
import random
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import Literal, Optional

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import mediapipe as mp
import numpy as np
import requests
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

import vit
import vqvae
from diffusion import create_diffusion
from segment_hoi import init_sam
from utils import check_keypoints_validity, keypoint_heatmap, scale_keypoint

MAX_N = 6
FIX_MAX_N = 6

placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6
REF_POSE_MASK = True
def set_seed(seed):
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


# if torch.cuda.is_available():
device = "cuda"
# else:
#     device = "cpu"


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def unnormalize(x):
    return (((x + 1) / 2) * 255).astype(np.uint8)
def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W, C = img.shape

    # Create a figure and axis, then plot joints as points
    plt.figure()
    ax = plt.gca()
    ax.imshow(img)

    start_is = []
    if "right" in side:
        start_is.append(0)
    if "left" in side:
        start_is.append(21)
    for start_i in start_is:
        joints = all_joints[start_i : start_i + n_avail_joints]
        if len(joints) == 1:
            ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
        else:
            for connection, color in connections[: len(joints) - 1]:
                joint1 = joints[connection[0]]
                joint2 = joints[connection[1]]
                ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    # plt.subplots_adjust(wspace=0.01)
    # plt.show()

    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()

    # Convert the BytesIO object back to a numpy array
    buf.seek(0)
    img_pil = Image.open(buf)
    img_pil = img_pil.resize((W, H))  # PIL expects (width, height)
    numpy_img = np.array(img_pil)
    return numpy_img
def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
    """Overlay mask on image for visualization purposes.

    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of the overlaid mask
        alpha: the transparency of the mask
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    if transparent:
        out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    else:
        out = img
    return out


def scale_keypoint(keypoint, original_size, target_size):
    """Scale a keypoint based on the resizing of the image."""
    keypoint_copy = keypoint.copy()
    keypoint_copy[:, 0] *= target_size[0] / original_size[0]
    keypoint_copy[:, 1] *= target_size[1] / original_size[1]
    return keypoint_copy
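

# Hyperparameters for the hand diffusion model used throughout this demo. The cluster
# paths (sd_path, log_dir, data_root) come from training and appear unused at inference
# time here; the demo mainly reads the image/latent sizes, keypoint counts, and the
# sampling and latent-scaling settings.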
print("Configure...") | |
class HandDiffOpts: | |
run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5" | |
sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt" | |
log_dir: str = "/users/kchen157/scratch/log" | |
data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff" | |
image_size: tuple = (256, 256) | |
latent_size: tuple = (32, 32) | |
latent_dim: int = 4 | |
mask_bg: bool = False | |
kpts_form: str = "heatmap" | |
n_keypoints: int = 42 | |
n_mask: int = 1 | |
noise_steps: int = 1000 | |
test_sampling_steps: int = 250 | |
ddim_steps: int = 100 | |
ddim_discretize: str = "uniform" | |
ddim_eta: float = 0.0 | |
beta_start: float = 8.5e-4 | |
beta_end: float = 0.012 | |
latent_scaling_factor: float = 0.18215 | |
cfg_pose: float = 5.0 | |
cfg_appearance: float = 3.5 | |
batch_size: int = 25 | |
lr: float = 1e-5 | |
max_epochs: int = 500 | |
log_every_n_steps: int = 100 | |
limit_val_batches: int = 1 | |
n_gpu: int = 8 | |
num_nodes: int = 1 | |
precision: str = "16-mixed" | |
profiler: str = "simple" | |
swa_epoch_start: int = 10 | |
swa_lrs: float = 1e-3 | |
num_workers: int = 10 | |
n_val_samples: int = 4 | |
# load models
token = os.getenv("HF_TOKEN")
if NEW_MODEL:
    opts = HandDiffOpts()
    if MODEL_EPOCH == 7:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt"
    elif MODEL_EPOCH == 6:
        # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
        model_path = hf_hub_download(
            repo_id="Chaerin5/FoundHand-weights",
            filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt",
            token=token,
        )
    elif MODEL_EPOCH == 4:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
    elif MODEL_EPOCH == 10:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
    else:
        raise ValueError(f"model epoch should be 4, 6, 7, or 10, got {MODEL_EPOCH}")

    # vae_path = "./vae-ft-mse-840000-ema-pruned.ckpt"
    vae_path = hf_hub_download(
        repo_id="Chaerin5/FoundHand-weights",
        filename="vae-ft-mse-840000-ema-pruned.ckpt",
        token=token,
    )
    # sd_path = "./sd-v1-4.ckpt"

    print("Load diffusion model...")
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim + opts.n_keypoints + opts.n_mask,
        learn_sigma=True,
    ).to(device)
    # ckpt_state_dict = torch.load(model_path)["model_state_dict"]
    ckpt_state_dict = torch.load(model_path, map_location="cpu")["ema_state_dict"]
    missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
    model = model.to(device)
    model.eval()
    print(missing_keys, extra_keys)
    assert len(missing_keys) == 0

    vae_state_dict = torch.load(vae_path, map_location="cpu")["state_dict"]
    print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
    print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
    print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    autoencoder = autoencoder.to(device)
    autoencoder.eval()
    print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
    assert len(missing_keys) == 0
# else:
#     opts = HandDiffOpts()
#     model_path = './finetune_epoch=5-step=130000.ckpt'
#     sd_path = './sd-v1-4.ckpt'
#     print('Load diffusion model...')
#     diffusion = create_diffusion(str(opts.test_sampling_steps))
#     model = vit.DiT_XL_2(
#         input_size=opts.latent_size[0],
#         latent_dim=opts.latent_dim,
#         in_channels=opts.latent_dim + opts.n_keypoints + opts.n_mask,
#         learn_sigma=True,
#     ).to(device)
#     ckpt_state_dict = torch.load(model_path)['state_dict']
#     dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
#     vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
#     missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
#     model.eval()
#     assert len(missing_keys) == 0 and len(extra_keys) == 0
#     autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
#     missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
#     autoencoder.eval()
#     assert len(missing_keys) == 0 and len(extra_keys) == 0

sam_path = hf_hub_download(
    repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token
)
sam_predictor = init_sam(ckpt_path=sam_path, device="cpu")
print("Mediapipe hand detector and SAM ready...")

mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # use False if the image is part of a video stream
    max_num_hands=2,  # maximum number of hands to detect
    min_detection_confidence=0.1,
)
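

# get_ref_anno: run MediaPipe on the cropped reference image to get up to 42
# (right + left) keypoints, segment the hand with SAM, and assemble the reference
# conditioning (scaled VAE latent + keypoint heatmaps + hand mask) that the
# diffusion model consumes.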
def get_ref_anno(ref):
    if ref is None:
        return None, None, None
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    img = ref["composite"][..., :3]
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = np.zeros((42, 2))
    if REF_POSE_MASK:
        mp_pose = hands.process(img)
        detected = np.array([0, 0])
        start_idx = 0
        if mp_pose.multi_hand_landmarks:
            # handedness is flipped because MediaPipe assumes the input image is mirrored
            for hand_landmarks, handedness in zip(
                mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
            ):
                # actually the right hand
                if handedness.classification[0].label == "Left":
                    start_idx = 0
                    detected[0] = 1
                # actually the left hand
                elif handedness.classification[0].label == "Right":
                    start_idx = 21
                    detected[1] = 1
                for i, landmark in enumerate(hand_landmarks.landmark):
                    keypts[start_idx + i] = [
                        landmark.x * opts.image_size[1],
                        landmark.y * opts.image_size[0],
                    ]

            sam_predictor.set_image(img)
            if keypts[0].sum() != 0 and keypts[21].sum() != 0:
                input_point = np.array([keypts[0], keypts[21]])
                input_label = np.array([1, 1])
            elif keypts[0].sum() != 0:
                input_point = np.array(keypts[:1])
                input_label = np.array([1])
            elif keypts[21].sum() != 0:
                input_point = np.array(keypts[21:22])
                input_label = np.array([1])
            masks, _, _ = sam_predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
            hand_mask = masks[0]
            masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
            ref_pose = visualize_hand(keypts, masked_img)
        else:
            raise gr.Error("No hands detected in the reference image.")
    else:
        hand_mask = np.zeros_like(img[:, :, 0])
        ref_pose = np.zeros_like(img)
    print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    print(f"img.max(): {img.max()}, img.min(): {img.min()}")
    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    print(f"image.max(): {image.max()}, image.min(): {image.min()}")
    print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
    print(f"autoencoder encoder before encoding min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before encoding max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
    print(f"autoencoder encoder before encoding dtype: {next(autoencoder.encoder.parameters()).dtype}")
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
    if not REF_POSE_MASK:
        heatmaps = torch.zeros_like(heatmaps)
        mask = torch.zeros_like(mask)
    print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
    print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
    return img, ref_pose, ref_cond
def get_target_anno(target):
    if target is None:
        return (
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value=None),
        )
    pose_img = target["composite"][..., :3]
    pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
    # detect keypoints
    mp_pose = hands.process(pose_img)
    target_keypts = np.zeros((42, 2))
    detected = np.array([0, 0])
    start_idx = 0
    if mp_pose.multi_hand_landmarks:
        # handedness is flipped because MediaPipe assumes the input image is mirrored
        for hand_landmarks, handedness in zip(
            mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
        ):
            # actually the right hand
            if handedness.classification[0].label == "Left":
                start_idx = 0
                detected[0] = 1
            # actually the left hand
            elif handedness.classification[0].label == "Right":
                start_idx = 21
                detected[1] = 1
            for i, landmark in enumerate(hand_landmarks.landmark):
                target_keypts[start_idx + i] = [
                    landmark.x * opts.image_size[1],
                    landmark.y * opts.image_size[0],
                ]

        target_pose = visualize_hand(target_keypts, pose_img)
        kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
        target_heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
                opts.latent_size,
                var=1.0,
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            # device=device,
        )[None, ...]
        target_cond = torch.cat(
            [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
        )
    else:
        raise gr.Error("No hands detected in the target image.")
    return pose_img, target_pose, target_cond, target_keypts
def get_mask_inpaint(ref):
    inpaint_mask = np.array(ref["layers"][0])[..., -1]
    inpaint_mask = cv2.resize(
        inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
    )
    inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
    return inpaint_mask


def visualize_ref(crop, brush):
    if crop is None or brush is None:
        return None
    inpainted = brush["layers"][0][..., -1]
    img = crop["background"][..., :3]
    img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
    mask = inpainted < 128
    # img = img.astype(np.int32)
    # img[mask, :] = img[mask, :] - 50
    # img[np.any(img < 0, axis=-1)] = 0
    # img = img.astype(np.uint8)
    img = mask_image(img, mask)
    return img
def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
    if keypoints is None:
        keypoints = [[], []]
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 21:
            gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[0].append(list(evt.index))
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 21:
            gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
        else:
            keypoints[1].append(list(evt.index))
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def undo_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 0:
            return img, keypoints
        keypoints[0].pop()
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 0:
            return img, keypoints
        keypoints[1].pop()
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def reset_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    if side == "right":
        keypoints[0] = []
    elif side == "left":
        keypoints[1] = []
    return img, keypoints
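

# sample_diff: generate `num_gen` hand images that follow the target pose while keeping
# the reference appearance. The latent batch is duplicated so the conditional and
# unconditional branches are denoised together for classifier-free guidance.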
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
    set_seed(seed)
    z = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device=device,
    )
    print(f"z.device: {z.device}")
    target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
    print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
    print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")

    # novel view synthesis mode = off
    nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg,
    )
    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    results = []
    results_pose = []
    for i in range(MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    print(f"results[0].max(): {results[0].max()}")
    return results, results_pose
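

# ready_sample: from the cropped image, the brushed inpaint mask, and the clicked
# keypoints, build everything the inpainting sampler needs: a SAM hand mask, the VAE
# latent of the masked image, the latent-resolution inpaint mask, and visualizations
# of the inpaint mask at both resolutions.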
# @spaces.GPU(duration=120)
def ready_sample(img_ori, inpaint_mask, keypts):
    img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    sam_predictor.set_image(img)
    if len(keypts[0]) == 0:
        keypts[0] = np.zeros((21, 2))
    elif len(keypts[0]) == 21:
        keypts[0] = np.array(keypts[0], dtype=np.float32)
    else:
        gr.Info("Number of right hand keypoints should be either 0 or 21.")
        return (None,) * 7
    if len(keypts[1]) == 0:
        keypts[1] = np.zeros((21, 2))
    elif len(keypts[1]) == 21:
        keypts[1] = np.array(keypts[1], dtype=np.float32)
    else:
        gr.Info("Number of left hand keypoints should be either 0 or 21.")
        return (None,) * 7
    keypts = np.concatenate(keypts, axis=0)
    keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
    # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
    #     input_point = np.array([keypts[0], keypts[21]])
    #     input_label = np.array([1, 1])
    # elif keypts[0].sum() != 0:
    #     input_point = np.array(keypts[:1])
    #     input_label = np.array([1])
    # elif keypts[21].sum() != 0:
    #     input_point = np.array(keypts[21:22])
    #     input_label = np.array([1])
    box_shift_ratio = 0.5
    box_size_factor = 1.2
    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        input_point = np.array(keypts)
        input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
    elif keypts[0].sum() != 0:
        input_point = np.array(keypts[:21])
        input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
    elif keypts[21].sum() != 0:
        input_point = np.array(keypts[21:])
        input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
    else:
        raise ValueError(
            "Something went wrong. If no hand was detected, it should not reach here."
        )
    input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
    box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
    input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    hand_mask = masks[0]

    inpaint_latent_mask = torch.tensor(
        cv2.resize(
            inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
        ),
        dtype=torch.float,
        # device=device,
    ).unsqueeze(0)[None, ...]

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device=device,
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            # device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            # device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask * (1 - inpaint_mask),
        device=device,
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    ref_cond = torch.zeros_like(ref_cond)

    img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
    assert mask.max() == 1
    vis_mask32 = mask_image(
        img32, inpaint_latent_mask[0, 0].cpu().numpy(), (255, 255, 255), transparent=False
    ).astype(np.uint8)
    assert np.unique(inpaint_mask).shape[0] <= 2
    assert hand_mask.dtype == bool
    mask256 = inpaint_mask  # hand_mask * (1 - inpaint_mask)
    vis_mask256 = mask_image(img, mask256, (255, 255, 255), transparent=False).astype(
        np.uint8
    )
    return (
        ref_cond,
        target_cond,
        latent,
        inpaint_latent_mask,
        keypts,
        vis_mask32,
        vis_mask256,
    )
def switch_mask_size(radio):
    if radio == "256x256":
        out = (gr.update(visible=False), gr.update(visible=True))
    elif radio == "latent size (32x32)":
        out = (gr.update(visible=True), gr.update(visible=False))
    return out
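

# sample_inpaint: regenerate only the brushed region. `jump_length` and `jump_n_sample`
# (the "Quality" slider maps to jump_n_sample) control the resampling schedule of
# `inpaint_p_sample_loop` (likely a RePaint-style scheme, an assumption based on the
# parameter names), so a higher quality setting means more resampling and longer runtime.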
def sample_inpaint(
    ref_cond,
    target_cond,
    latent,
    inpaint_latent_mask,
    keypts,
    num_gen,
    seed,
    cfg,
    quality,
):
    set_seed(seed)
    N = num_gen
    jump_length = 10
    jump_n_sample = quality
    cfg_scale = cfg
    z = torch.randn(
        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
    )
    target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device)
    ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device)

    # novel view synthesis mode = off
    nvs = torch.zeros(N, dtype=torch.int, device=device)
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
        ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg_scale,
    )
    samples, _ = diffusion.inpaint_p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        latent.to(z.device),
        inpaint_latent_mask.to(z.device),
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=z.device,
        jump_length=jump_length,
        jump_n_sample=jump_n_sample,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    # visualize
    results = []
    results_pose = []
    for i in range(FIX_MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    return results, results_pose


def flip_hand(
    img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None
):
    if cond is None:  # clear was clicked
        return None, None, None, None
    img["composite"] = img["composite"][:, ::-1, :]
    img["background"] = img["background"][:, ::-1, :]
    img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
    pose_img = pose_img[:, ::-1, :]
    cond = cond.flip(-1)
    if keypts is not None:  # cond is target_cond
        if keypts[:21, :].sum() != 0:
            keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
            # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
        if keypts[21:, :].sum() != 0:
            keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
            # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
    return img, pose_img, cond, keypts


def resize_to_full(img):
    img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
    img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
    img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
    return img
def clear_all():
    return (
        None,
        None,
        False,
        None,
        None,
        False,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        42,
        3.0,
    )


def fix_clear_all():
    return (
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        # (0, 0),
        42,
        3.0,
        10,
    )


def enable_component(image1, image2):
    if image1 is None or image2 is None:
        return gr.update(interactive=False)
    if "background" in image1 and "layers" in image1 and "composite" in image1:
        if (
            image1["background"].sum() == 0
            and (sum([im.sum() for im in image1["layers"]]) == 0)
            and image1["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    if "background" in image2 and "layers" in image2 and "composite" in image2:
        if (
            image2["background"].sum() == 0
            and (sum([im.sum() for im in image2["layers"]]) == 0)
            and image2["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    return gr.update(interactive=True)


def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left):
    if kpts is None:
        kpts = [[], []]
    if "Right hand" not in checkbox:
        kpts[0] = []
        vis_right = img_clean
        update_right = gr.update(visible=False)
        update_r_info = gr.update(visible=False)
    else:
        vis_right = img_pose_right
        update_right = gr.update(visible=True)
        update_r_info = gr.update(visible=True)
    if "Left hand" not in checkbox:
        kpts[1] = []
        vis_left = img_clean
        update_left = gr.update(visible=False)
        update_l_info = gr.update(visible=False)
    else:
        vis_left = img_pose_left
        update_left = gr.update(visible=True)
        update_l_info = gr.update(visible=True)
    return (
        kpts,
        vis_right,
        vis_left,
        update_right,
        update_right,
        update_right,
        update_left,
        update_left,
        update_left,
        update_r_info,
        update_l_info,
    )
LENGTH = 480

example_imgs = [
    ["sample_images/sample1.jpg"],
    ["sample_images/sample2.jpg"],
    ["sample_images/sample3.jpg"],
    ["sample_images/sample4.jpg"],
    ["sample_images/sample5.jpg"],
    ["sample_images/sample6.jpg"],
    ["sample_images/sample7.jpg"],
    ["sample_images/sample8.jpg"],
    ["sample_images/sample9.jpg"],
    ["sample_images/sample10.jpg"],
    ["sample_images/sample11.jpg"],
    ["pose_images/pose1.jpg"],
    ["pose_images/pose2.jpg"],
    ["pose_images/pose3.jpg"],
    ["pose_images/pose4.jpg"],
    ["pose_images/pose5.jpg"],
    ["pose_images/pose6.jpg"],
    ["pose_images/pose7.jpg"],
    ["pose_images/pose8.jpg"],
]
fix_example_imgs = [
    ["bad_hands/1.jpg"],
    ["bad_hands/2.jpg"],
    ["bad_hands/3.jpg"],
    ["bad_hands/4.jpg"],
    ["bad_hands/5.jpg"],
    ["bad_hands/6.jpg"],
    ["bad_hands/7.jpg"],
    ["bad_hands/8.jpg"],
    ["bad_hands/9.jpg"],
    ["bad_hands/10.jpg"],
    ["bad_hands/11.jpg"],
    ["bad_hands/12.jpg"],
    ["bad_hands/13.jpg"],
]

custom_css = """
.gradio-container .examples img {
    width: 240px !important;
    height: 240px !important;
}
"""

_HEADER_ = '''
<h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1>
<h2>
📝<a href='https://arxiv.org/abs/2412.02690' target='_blank'>Paper</a>
📢<a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank'>Project</a>
</h2>
'''

_CITE_ = r"""
```
@article{chen2024foundhand,
    title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation},
    author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath},
    journal={arXiv preprint arXiv:2412.02690},
    year={2024}
}
```
"""
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(_HEADER_)
    with gr.Tab("Edit Hand Poses"):
        ref_img = gr.State(value=None)
        ref_cond = gr.State(value=None)
        keypts = gr.State(value=None)
        target_img = gr.State(value=None)
        target_cond = gr.State(value=None)
        target_keypts = gr.State(value=None)
        dump = gr.State(value=None)
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold;">1. Reference</p>"""
                )
                gr.Markdown("""<p style="text-align: center;"><br></p>""")
                ref = gr.ImageEditor(
                    type="numpy",
                    label="Reference",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
                ref_pose = gr.Image(
                    type="numpy",
                    label="Reference Pose",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                )
                ref_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Reference)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold;">2. Target</p>"""
                )
                target = gr.ImageEditor(
                    type="numpy",
                    label="Target",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                target_finish_crop = gr.Button(
                    value="Finish Cropping", interactive=False
                )
                target_pose = gr.Image(
                    type="numpy",
                    label="Target Pose",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                )
                target_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Target)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold;">3. Result</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Run is enabled after the images have been processed.</p>"""
                )
                run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">~20s per generation with an RTX 3090; ~50s with an A100.<br>(For example, if you set Number of generations to 2, it takes around 40s.)</p>"""
                )
                results = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                results_pose = gr.Gallery(
                    type="numpy",
                    label="Results Pose",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                clear = gr.ClearButton()
        with gr.Row():
            n_generation = gr.Slider(
                label="Number of generations",
                value=1,
                minimum=1,
                maximum=MAX_N,
                step=1,
                randomize=False,
                interactive=True,
            )
            seed = gr.Slider(
                label="Seed",
                value=42,
                minimum=0,
                maximum=10000,
                step=1,
                randomize=False,
                interactive=True,
            )
            cfg = gr.Slider(
                label="Classifier-free guidance scale",
                value=2.5,
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                randomize=False,
                interactive=True,
            )

        ref.change(enable_component, [ref, ref], ref_finish_crop)
        ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
        ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
        ref_flip.select(
            flip_hand, [ref, ref_pose, ref_cond], [ref, ref_pose, ref_cond, dump]
        )
        target.change(enable_component, [target, target], target_finish_crop)
        target_finish_crop.click(
            get_target_anno,
            [target],
            [target_img, target_pose, target_cond, target_keypts],
        )
        target_pose.change(enable_component, [target_img, target_pose], target_flip)
        target_flip.select(
            flip_hand,
            [target, target_pose, target_cond, target_keypts],
            [target, target_pose, target_cond, target_keypts],
        )
        ref_pose.change(enable_component, [ref_pose, target_pose], run)
        target_pose.change(enable_component, [ref_pose, target_pose], run)
        run.click(
            sample_diff,
            [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
            [results, results_pose],
        )
        clear.click(
            clear_all,
            [],
            [
                ref,
                ref_pose,
                ref_flip,
                target,
                target_pose,
                target_flip,
                results,
                results_pose,
                ref_img,
                ref_cond,
                # mask,
                target_img,
                target_cond,
                target_keypts,
                n_generation,
                seed,
                cfg,
            ],
        )

        gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
        with gr.Tab("Reference"):
            with gr.Row():
                gr.Examples(example_imgs, [ref], examples_per_page=20)
        with gr.Tab("Target"):
            with gr.Row():
                gr.Examples(example_imgs, [target], examples_per_page=20)
with gr.Tab("Fix Hands"): | |
fix_inpaint_mask = gr.State(value=None) | |
fix_original = gr.State(value=None) | |
fix_img = gr.State(value=None) | |
fix_kpts = gr.State(value=None) | |
fix_kpts_np = gr.State(value=None) | |
fix_ref_cond = gr.State(value=None) | |
fix_target_cond = gr.State(value=None) | |
fix_latent = gr.State(value=None) | |
fix_inpaint_latent = gr.State(value=None) | |
# fix_size_memory = gr.State(value=(0, 0)) | |
gr.Markdown("""<p style="text-align: center; font-size: 25px; font-weight: bold; ">⚠️ Note</p>""") | |
gr.Markdown("""<p>Fix Hands with A100 needs around 6 mins, which is beyond the ZeroGPU quota (5mins). Please either purchase additional gpus from Hugging Face or kindly use your own gpus with our opensourced code (We will opensource the code upon acceptance).</p>""") | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">Crop the image around the hand.<br>Then, brush area (e.g., wrong finger) that needs to be fixed.</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>""" | |
) | |
fix_crop = gr.ImageEditor( | |
type="numpy", | |
sources=["upload", "webcam", "clipboard"], | |
label="Image crop", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
layers=False, | |
crop_size="1:1", | |
brush=False, | |
image_mode="RGBA", | |
container=False, | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>""" | |
) | |
fix_ref = gr.ImageEditor( | |
type="numpy", | |
label="Image brush", | |
sources=(), | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
layers=False, | |
transforms=("brush"), | |
brush=gr.Brush( | |
colors=["rgb(255, 255, 255)"], default_size=20 | |
), # 204, 50, 50 | |
image_mode="RGBA", | |
container=False, | |
interactive=False, | |
) | |
fix_finish_crop = gr.Button( | |
value="Finish Croping & Brushing", interactive=False | |
) | |
gr.Markdown( | |
"""<p style="text-align: left; font-size: 20px; font-weight: bold; ">OpenPose keypoints convention</p>""" | |
) | |
fix_openpose = gr.Image( | |
value="openpose.png", | |
type="numpy", | |
label="OpenPose keypoints convention", | |
show_label=True, | |
height=LENGTH // 3 * 2, | |
width=LENGTH // 3 * 2, | |
interactive=False, | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\" on the bottom left.</p>""" | |
) | |
fix_checkbox = gr.CheckboxGroup( | |
["Right hand", "Left hand"], | |
# value=["Right hand", "Left hand"], | |
label="Hand side", | |
info="Which side this hand is? Could be both.", | |
interactive=False, | |
) | |
fix_kp_r_info = gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""", | |
visible=False, | |
) | |
fix_kp_right = gr.Image( | |
type="numpy", | |
label="Keypoint Selection (right hand)", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
visible=False, | |
sources=[], | |
) | |
with gr.Row(): | |
fix_undo_right = gr.Button( | |
value="Undo", interactive=False, visible=False | |
) | |
fix_reset_right = gr.Button( | |
value="Reset", interactive=False, visible=False | |
) | |
fix_kp_l_info = gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""", | |
visible=False | |
) | |
fix_kp_left = gr.Image( | |
type="numpy", | |
label="Keypoint Selection (left hand)", | |
show_label=True, | |
height=LENGTH, | |
width=LENGTH, | |
interactive=False, | |
visible=False, | |
sources=[], | |
) | |
with gr.Row(): | |
fix_undo_left = gr.Button( | |
value="Undo", interactive=False, visible=False | |
) | |
fix_reset_left = gr.Button( | |
value="Reset", interactive=False, visible=False | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>""" | |
) | |
gr.Markdown( | |
"""<p style="text-align: center;">In Fix Hands, not segmentation mask, but only inpaint mask is used.</p>""" | |
) | |
fix_ready = gr.Button(value="Ready", interactive=False) | |
fix_mask_size = gr.Radio( | |
["256x256", "latent size (32x32)"], | |
label="Visualized inpaint mask size", | |
interactive=False, | |
value="256x256", | |
) | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Visualized inpaint masks</p>""" | |
) | |
fix_vis_mask32 = gr.Image( | |
type="numpy", | |
label=f"Visualized {opts.latent_size} Inpaint Mask", | |
show_label=True, | |
height=opts.latent_size, | |
width=opts.latent_size, | |
interactive=False, | |
visible=False, | |
) | |
fix_vis_mask256 = gr.Image( | |
type="numpy", | |
label=f"Visualized {opts.image_size} Inpaint Mask", | |
visible=True, | |
show_label=True, | |
height=opts.image_size, | |
width=opts.image_size, | |
interactive=False, | |
) | |
with gr.Column(): | |
gr.Markdown( | |
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>""" | |
) | |
fix_run = gr.Button(value="Run", interactive=False) | |
gr.Markdown( | |
"""<p style="text-align: center;">>3min and ~24GB per generation</p>""" | |
) | |
fix_result = gr.Gallery( | |
type="numpy", | |
label="Results", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=FIX_MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
fix_result_pose = gr.Gallery( | |
type="numpy", | |
label="Results Pose", | |
show_label=True, | |
height=LENGTH, | |
min_width=LENGTH, | |
columns=FIX_MAX_N, | |
interactive=False, | |
preview=True, | |
) | |
fix_clear = gr.ClearButton() | |
gr.Markdown( | |
"[NOTE] Currently, Number of generation > 1 could lead to out-of-memory" | |
) | |
with gr.Row(): | |
fix_n_generation = gr.Slider( | |
label="Number of generations", | |
value=1, | |
minimum=1, | |
maximum=FIX_MAX_N, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_seed = gr.Slider( | |
label="Seed", | |
value=42, | |
minimum=0, | |
maximum=10000, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_cfg = gr.Slider( | |
label="Classifier free guidance scale", | |
value=3.0, | |
minimum=0.0, | |
maximum=10.0, | |
step=0.1, | |
randomize=False, | |
interactive=True, | |
) | |
fix_quality = gr.Slider( | |
label="Quality", | |
value=10, | |
minimum=1, | |
maximum=10, | |
step=1, | |
randomize=False, | |
interactive=True, | |
) | |
        fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
        fix_crop.change(resize_to_full, fix_crop, fix_ref)
        fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
        fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
        # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
        # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
        fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
        fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
        )
        # fix_inpaint_mask.change(
        #     enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
        # )
        fix_checkbox.select(
            set_visible,
            [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
            [
                fix_kpts,
                fix_kp_right,
                fix_kp_left,
                fix_kp_right,
                fix_undo_right,
                fix_reset_right,
                fix_kp_left,
                fix_undo_left,
                fix_reset_left,
                fix_kp_r_info,
                fix_kp_l_info,
            ],
        )
        fix_kp_right.select(
            get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_undo_right.click(
            undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_reset_right.click(
            reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_kp_left.select(
            get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_undo_left.click(
            undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_reset_left.click(
            reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        # fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
        # fix_run.click(lambda x: gr.update(value=None), [], [fix_result, fix_result_pose])
        fix_vis_mask32.change(
            enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
        )
        fix_vis_mask32.change(
            enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size
        )
        fix_ready.click(
            ready_sample,
            [fix_original, fix_inpaint_mask, fix_kpts],
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_vis_mask32,
                fix_vis_mask256,
            ],
        )
        fix_mask_size.select(
            switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256]
        )
        fix_run.click(
            sample_inpaint,
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
            [fix_result, fix_result_pose],
        )
        fix_clear.click(
            fix_clear_all,
            [],
            [
                fix_crop,
                fix_ref,
                fix_kp_right,
                fix_kp_left,
                fix_result,
                fix_result_pose,
                fix_inpaint_mask,
                fix_original,
                fix_img,
                fix_vis_mask32,
                fix_vis_mask256,
                fix_kpts,
                fix_kpts_np,
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_n_generation,
                # fix_size_memory,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
        )

        gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
        fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False)
        fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False)
        with gr.Column():
            fix_example = gr.Examples(
                fix_example_imgs,
                # run_on_click=True,
                # fn=parse_fix_example,
                # inputs=[fix_dump_ex, fix_dump_ex_masked],
                # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
                inputs=[fix_crop],
                examples_per_page=20,
            )

    gr.Markdown("<h1>Citation</h1>")
    gr.Markdown(_CITE_)

# print("Ready to launch..")
# _, _, shared_url = demo.queue().launch(
#     share=True, server_name="0.0.0.0", server_port=7739
# )
demo.launch(share=True)