Chaerin5 committed on
Commit 49f816b · 1 Parent(s): 76e0b86
README.md CHANGED
@@ -1,13 +1 @@
- ---
- title: FoundHand
- emoji: 🏆
- colorFrom: gray
- colorTo: purple
- sdk: gradio
- sdk_version: 5.9.1
- app_file: app.py
- pinned: false
- short_description: FoundHand
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation
app.py ADDED
@@ -0,0 +1,1581 @@
1
+ import torch
2
+ from dataclasses import dataclass
3
+ import gradio as gr
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import cv2
7
+ import mediapipe as mp
8
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
9
+ import vqvae
10
+ import vit
11
+ from typing import Literal
12
+ from diffusion import create_diffusion
13
+ from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
14
+ from segment_hoi import init_sam
15
+ from io import BytesIO
16
+ from PIL import Image
17
+ import random
18
+ from copy import deepcopy
19
+ from typing import Optional
+ from huggingface_hub import hf_hub_download  # required by the checkpoint download below
20
+
21
+ MAX_N = 6
22
+ FIX_MAX_N = 6
23
+
24
+ placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
25
+ NEW_MODEL = True
26
+ MODEL_EPOCH = 6
27
+ REF_POSE_MASK = True
28
+
29
+ def set_seed(seed):
30
+ seed = int(seed)
31
+ torch.manual_seed(seed)
32
+ np.random.seed(seed)
33
+ torch.cuda.manual_seed_all(seed)
34
+ random.seed(seed)
35
+
36
+
37
+ def remove_prefix(text, prefix):
38
+ if text.startswith(prefix):
39
+ return text[len(prefix) :]
40
+ return text
41
+
42
+
43
+ def unnormalize(x):
44
+ return (((x + 1) / 2) * 255).astype(np.uint8)
45
+
46
+
47
+ def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
48
+ # Define the connections between joints for drawing lines and their corresponding colors
49
+ connections = [
50
+ ((0, 1), "red"),
51
+ ((1, 2), "green"),
52
+ ((2, 3), "blue"),
53
+ ((3, 4), "purple"),
54
+ ((0, 5), "orange"),
55
+ ((5, 6), "pink"),
56
+ ((6, 7), "brown"),
57
+ ((7, 8), "cyan"),
58
+ ((0, 9), "yellow"),
59
+ ((9, 10), "magenta"),
60
+ ((10, 11), "lime"),
61
+ ((11, 12), "indigo"),
62
+ ((0, 13), "olive"),
63
+ ((13, 14), "teal"),
64
+ ((14, 15), "navy"),
65
+ ((15, 16), "gray"),
66
+ ((0, 17), "lavender"),
67
+ ((17, 18), "silver"),
68
+ ((18, 19), "maroon"),
69
+ ((19, 20), "fuchsia"),
70
+ ]
71
+ H, W, C = img.shape
72
+
73
+ # Create a figure and axis
74
+ plt.figure()
75
+ ax = plt.gca()
76
+ # Plot joints as points
77
+ ax.imshow(img)
78
+ start_is = []
79
+ if "right" in side:
80
+ start_is.append(0)
81
+ if "left" in side:
82
+ start_is.append(21)
83
+ for start_i in start_is:
84
+ joints = all_joints[start_i : start_i + n_avail_joints]
85
+ if len(joints) == 1:
86
+ ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
87
+ else:
88
+ for connection, color in connections[: len(joints) - 1]:
89
+ joint1 = joints[connection[0]]
90
+ joint2 = joints[connection[1]]
91
+ ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
92
+
93
+ ax.set_xlim([0, W])
94
+ ax.set_ylim([0, H])
95
+ ax.grid(False)
96
+ ax.set_axis_off()
97
+ ax.invert_yaxis()
98
+ # plt.subplots_adjust(wspace=0.01)
99
+ # plt.show()
100
+ buf = BytesIO()
101
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
102
+ plt.close()
103
+
104
+ # Convert BytesIO object to numpy array
105
+ buf.seek(0)
106
+ img_pil = Image.open(buf)
107
+ img_pil = img_pil.resize((H, W))
108
+ numpy_img = np.array(img_pil)
109
+
110
+ return numpy_img
111
+
112
+
113
+ def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
114
+ """Overlay mask on image for visualization purpose.
115
+ Args:
116
+ image (H, W, 3) or (H, W): input image
117
+ mask (H, W): mask to be overlaid
118
+ color: the color of overlaid mask
119
+ alpha: the transparency of the mask
120
+ """
121
+ out = deepcopy(image)
122
+ img = deepcopy(image)
123
+ img[mask == 1] = color
124
+ if transparent:
125
+ out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
126
+ else:
127
+ out = img
128
+ return out
129
+
130
+
131
+ def scale_keypoint(keypoint, original_size, target_size):
132
+ """Scale a keypoint based on the resizing of the image."""
133
+ keypoint_copy = keypoint.copy()
134
+ keypoint_copy[:, 0] *= target_size[0] / original_size[0]
135
+ keypoint_copy[:, 1] *= target_size[1] / original_size[1]
136
+ return keypoint_copy
137
+
138
+
139
+ print("Configure...")
140
+
141
+
142
+ @dataclass
143
+ class HandDiffOpts:
144
+ run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
145
+ sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
146
+ log_dir: str = "/users/kchen157/scratch/log"
147
+ data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
148
+ image_size: tuple = (256, 256)
149
+ latent_size: tuple = (32, 32)
150
+ latent_dim: int = 4
151
+ mask_bg: bool = False
152
+ kpts_form: str = "heatmap"
153
+ n_keypoints: int = 42
154
+ n_mask: int = 1
155
+ noise_steps: int = 1000
156
+ test_sampling_steps: int = 250
157
+ ddim_steps: int = 100
158
+ ddim_discretize: str = "uniform"
159
+ ddim_eta: float = 0.0
160
+ beta_start: float = 8.5e-4
161
+ beta_end: float = 0.012
162
+ latent_scaling_factor: float = 0.18215
163
+ cfg_pose: float = 5.0
164
+ cfg_appearance: float = 3.5
165
+ batch_size: int = 25
166
+ lr: float = 1e-5
167
+ max_epochs: int = 500
168
+ log_every_n_steps: int = 100
169
+ limit_val_batches: int = 1
170
+ n_gpu: int = 8
171
+ num_nodes: int = 1
172
+ precision: str = "16-mixed"
173
+ profiler: str = "simple"
174
+ swa_epoch_start: int = 10
175
+ swa_lrs: float = 1e-3
176
+ num_workers: int = 10
177
+ n_val_samples: int = 4
178
+
179
+ if not torch.cuda.is_available():
180
+ raise ValueError("No GPU")
181
+
182
+ # load models
183
+ if NEW_MODEL:
184
+ opts = HandDiffOpts()
185
+ if MODEL_EPOCH == 7:
186
+ model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
187
+ elif MODEL_EPOCH == 6:
188
+ # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
189
+ model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt")
190
+ elif MODEL_EPOCH == 4:
191
+ model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
192
+ elif MODEL_EPOCH == 10:
193
+ model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
194
+ else:
195
+ raise ValueError(f"new model epoch should be either 6 or 7, got {MODEL_EPOCH}")
196
+ vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
197
+ # sd_path = './sd-v1-4.ckpt'
198
+ print('Load diffusion model...')
199
+ diffusion = create_diffusion(str(opts.test_sampling_steps))
200
+ model = vit.DiT_XL_2(
201
+ input_size=opts.latent_size[0],
202
+ latent_dim=opts.latent_dim,
203
+ in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
204
+ learn_sigma=True,
205
+ ).cuda()
206
+ # ckpt_state_dict = torch.load(model_path)['model_state_dict']
207
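+ # inference loads the EMA (exponential moving average) weights stored in the checkpoint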
+ ckpt_state_dict = torch.load(model_path, map_location=torch.device('cuda'))['ema_state_dict']
208
+ missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
209
+ model.eval()
210
+ print(missing_keys, extra_keys)
211
+ assert len(missing_keys) == 0
212
+ vae_state_dict = torch.load(vae_path)['state_dict']
213
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
214
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
215
+ autoencoder.eval()
216
+ assert len(missing_keys) == 0
217
+ else:
218
+ opts = HandDiffOpts()
219
+ model_path = './finetune_epoch=5-step=130000.ckpt'
220
+ sd_path = './sd-v1-4.ckpt'
221
+ print('Load diffusion model...')
222
+ diffusion = create_diffusion(str(opts.test_sampling_steps))
223
+ model = vit.DiT_XL_2(
224
+ input_size=opts.latent_size[0],
225
+ latent_dim=opts.latent_dim,
226
+ in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
227
+ learn_sigma=True,
228
+ ).cuda()
229
+ ckpt_state_dict = torch.load(model_path)['state_dict']
230
+ dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
231
+ vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
232
+ missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
233
+ model.eval()
234
+ assert len(missing_keys) == 0 and len(extra_keys) == 0
235
+ autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
236
+ missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
237
+ autoencoder.eval()
238
+ assert len(missing_keys) == 0 and len(extra_keys) == 0
239
+ sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")
240
+
241
+
242
+ print("Mediapipe hand detector and SAM ready...")
243
+ mp_hands = mp.solutions.hands
244
+ hands = mp_hands.Hands(
245
+ static_image_mode=True, # Use False if image is part of a video stream
246
+ max_num_hands=2, # Maximum number of hands to detect
247
+ min_detection_confidence=0.1,
248
+ )
249
+
250
+
251
+ def get_ref_anno(ref):
252
+ if ref is None:
253
+ # return exactly three values to match the click outputs [ref_img, ref_pose, ref_cond]
+ return None, None, None
260
+ img = ref["composite"][..., :3]
261
+ img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
262
+ keypts = np.zeros((42, 2))
263
+ if REF_POSE_MASK:
264
+ mp_pose = hands.process(img)
265
+ detected = np.array([0, 0])
266
+ start_idx = 0
267
+ if mp_pose.multi_hand_landmarks:
268
+ # handedness is flipped assuming the input image is mirrored in MediaPipe
269
+ for hand_landmarks, handedness in zip(
270
+ mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
271
+ ):
272
+ # actually right hand
273
+ if handedness.classification[0].label == "Left":
274
+ start_idx = 0
275
+ detected[0] = 1
276
+ # actually left hand
277
+ elif handedness.classification[0].label == "Right":
278
+ start_idx = 21
279
+ detected[1] = 1
280
+ for i, landmark in enumerate(hand_landmarks.landmark):
281
+ keypts[start_idx + i] = [
282
+ landmark.x * opts.image_size[1],
283
+ landmark.y * opts.image_size[0],
284
+ ]
285
+
286
+ sam_predictor.set_image(img)
287
+ l = keypts[:21].shape[0]
288
+ if keypts[0].sum() != 0 and keypts[21].sum() != 0:
289
+ input_point = np.array([keypts[0], keypts[21]])
290
+ input_label = np.array([1, 1])
291
+ elif keypts[0].sum() != 0:
292
+ input_point = np.array(keypts[:1])
293
+ input_label = np.array([1])
294
+ elif keypts[21].sum() != 0:
295
+ input_point = np.array(keypts[21:22])
296
+ input_label = np.array([1])
297
+ masks, _, _ = sam_predictor.predict(
298
+ point_coords=input_point,
299
+ point_labels=input_label,
300
+ multimask_output=False,
301
+ )
302
+ hand_mask = masks[0]
303
+ masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
304
+ ref_pose = visualize_hand(keypts, masked_img)
305
+ else:
306
+ raise gr.Error("No hands detected in the reference image.")
307
+ else:
308
+ hand_mask = np.zeros_like(img[:,:, 0])
309
+ ref_pose = np.zeros_like(img)
310
+
311
+ def make_ref_cond(
312
+ img,
313
+ keypts,
314
+ hand_mask,
315
+ device="cuda",
316
+ target_size=(256, 256),
317
+ latent_size=(32, 32),
318
+ ):
319
+ image_transform = Compose(
320
+ [
321
+ ToTensor(),
322
+ Resize(target_size),
323
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
324
+ ]
325
+ )
326
+ image = image_transform(img).to(device)
327
+ kpts_valid = check_keypoints_validity(keypts, target_size)
328
+ heatmaps = torch.tensor(
329
+ keypoint_heatmap(
330
+ scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
331
+ )
332
+ * kpts_valid[:, None, None],
333
+ dtype=torch.float,
334
+ device=device,
335
+ )[None, ...]
336
+ mask = torch.tensor(
337
+ cv2.resize(
338
+ hand_mask.astype(int),
339
+ dsize=latent_size,
340
+ interpolation=cv2.INTER_NEAREST,
341
+ ),
342
+ dtype=torch.float,
343
+ device=device,
344
+ ).unsqueeze(0)[None, ...]
345
+ return image[None, ...], heatmaps, mask
346
+
347
+ image, heatmaps, mask = make_ref_cond(
348
+ img,
349
+ keypts,
350
+ hand_mask,
351
+ device="cuda",
352
+ target_size=opts.image_size,
353
+ latent_size=opts.latent_size,
354
+ )
355
+ latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
356
+ if not REF_POSE_MASK:
357
+ heatmaps = torch.zeros_like(heatmaps)
358
+ mask = torch.zeros_like(mask)
359
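+ # reference condition: 4-channel image latent + 42 keypoint heatmaps + 1 hand-mask channel,
+ # matching in_channels = latent_dim + n_keypoints + n_mask of the DiT model built above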
+ ref_cond = torch.cat([latent, heatmaps, mask], 1)
360
+
361
+ return img, ref_pose, ref_cond
362
+
363
+
364
+ def get_target_anno(target):
365
+ if target is None:
366
+ # the per-component update() classmethods were removed in recent Gradio; plain None values suffice
+ return None, None, None, None
372
+ pose_img = target["composite"][..., :3]
373
+ pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
374
+ # detect keypoints
375
+ mp_pose = hands.process(pose_img)
376
+ target_keypts = np.zeros((42, 2))
377
+ detected = np.array([0, 0])
378
+ start_idx = 0
379
+ if mp_pose.multi_hand_landmarks:
380
+ # handedness is flipped assuming the input image is mirrored in MediaPipe
381
+ for hand_landmarks, handedness in zip(
382
+ mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
383
+ ):
384
+ # actually right hand
385
+ if handedness.classification[0].label == "Left":
386
+ start_idx = 0
387
+ detected[0] = 1
388
+ # actually left hand
389
+ elif handedness.classification[0].label == "Right":
390
+ start_idx = 21
391
+ detected[1] = 1
392
+ for i, landmark in enumerate(hand_landmarks.landmark):
393
+ target_keypts[start_idx + i] = [
394
+ landmark.x * opts.image_size[1],
395
+ landmark.y * opts.image_size[0],
396
+ ]
397
+
398
+ target_pose = visualize_hand(target_keypts, pose_img)
399
+ kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
400
+ target_heatmaps = torch.tensor(
401
+ keypoint_heatmap(
402
+ scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
403
+ opts.latent_size,
404
+ var=1.0,
405
+ )
406
+ * kpts_valid[:, None, None],
407
+ dtype=torch.float,
408
+ device="cuda",
409
+ )[None, ...]
410
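+ # target condition: 42 keypoint heatmaps plus one zero channel in place of the hand mask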
+ target_cond = torch.cat(
411
+ [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
412
+ )
413
+ else:
414
+ raise gr.Error("No hands detected in the target image.")
415
+
416
+ return pose_img, target_pose, target_cond, target_keypts
417
+
418
+
419
+ # def draw_grid(ref):
420
+ # if ref is None or ref["composite"] is None: # or len(ref["layers"])==0:
421
+ # return ref
422
+
423
+ # # if len(ref["layers"]) == 1:
424
+ # # need_draw = True
425
+ # # # elif ref["composite"].shape[0] != size_memory[0] or ref["composite"].shape[1] != size_memory[1]:
426
+ # # # need_draw = True
427
+ # # else:
428
+ # # need_draw = False
429
+
430
+ # # size_memory = ref["composite"].shape[0], ref["composite"].shape[1]
431
+ # # if not need_draw:
432
+ # # return size_memory, ref
433
+
434
+ # h, w = ref["composite"].shape[:2]
435
+ # grid_h, grid_w = h // 32, w // 32
436
+ # # grid = np.zeros((h, w, 4), dtype=np.uint8)
437
+ # for i in range(1, grid_h):
438
+ # ref["composite"][i * 32, :, :3] = 255 # 0.5 * ref["composite"][i * 32, :, :3] +
439
+ # for i in range(1, grid_w):
440
+ # ref["composite"][:, i * 32, :3] = 255 # 0.5 * ref["composite"][:, i * 32, :3] +
441
+ # # if len(ref["layers"]) == 1:
442
+ # # ref["layers"].append(grid)
443
+ # # else:
444
+ # # ref["layers"][1] = grid
445
+ # return ref["composite"]
446
+
447
+
448
+ def get_mask_inpaint(ref):
449
+ inpaint_mask = np.array(ref["layers"][0])[..., -1]
450
+ inpaint_mask = cv2.resize(
451
+ inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
452
+ )
453
+ inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
454
+ return inpaint_mask
455
+
456
+
457
+ def visualize_ref(crop, brush):
458
+ if crop is None or brush is None:
459
+ return None
460
+ inpainted = brush["layers"][0][..., -1]
461
+ img = crop["background"][..., :3]
462
+ img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
463
+ mask = inpainted < 128
464
+ # img = img.astype(np.int32)
465
+ # img[mask, :] = img[mask, :] - 50
466
+ # img[np.any(img<0, axis=-1)]=0
467
+ # img = img.astype(np.uint8)
468
+ img = mask_image(img, mask)
469
+ return img
470
+
471
+
472
+ def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
473
+ if keypoints is None:
474
+ keypoints = [[], []]
475
+ kps = np.zeros((42, 2))
476
+ if side == "right":
477
+ if len(keypoints[0]) == 21:
478
+ gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
479
+ else:
480
+ keypoints[0].append(list(evt.index))
481
+ len_kps = len(keypoints[0])
482
+ kps[:len_kps] = np.array(keypoints[0])
483
+ elif side == "left":
484
+ if len(keypoints[1]) == 21:
485
+ gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
486
+ else:
487
+ keypoints[1].append(list(evt.index))
488
+ len_kps = len(keypoints[1])
489
+ kps[21 : 21 + len_kps] = np.array(keypoints[1])
490
+ vis_hand = visualize_hand(kps, img, side, len_kps)
491
+ return vis_hand, keypoints
492
+
493
+
494
+ def undo_kps(img, keypoints, side: Literal["right", "left"]):
495
+ if keypoints is None:
496
+ return img, None
497
+ kps = np.zeros((42, 2))
498
+ if side == "right":
499
+ if len(keypoints[0]) == 0:
500
+ return img, keypoints
501
+ keypoints[0].pop()
502
+ len_kps = len(keypoints[0])
503
+ kps[:len_kps] = np.array(keypoints[0])
504
+ elif side == "left":
505
+ if len(keypoints[1]) == 0:
506
+ return img, keypoints
507
+ keypoints[1].pop()
508
+ len_kps = len(keypoints[1])
509
+ kps[21 : 21 + len_kps] = np.array(keypoints[1])
510
+ vis_hand = visualize_hand(kps, img, side, len_kps)
511
+ return vis_hand, keypoints
512
+
513
+
514
+ def reset_kps(img, keypoints, side: Literal["right", "left"]):
515
+ if keypoints is None:
516
+ return img, None
517
+ if side == "right":
518
+ keypoints[0] = []
519
+ elif side == "left":
520
+ keypoints[1] = []
521
+ return img, keypoints
522
+
523
+
524
+ def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
525
+ set_seed(seed)
526
+ z = torch.randn(
527
+ (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
528
+ device="cuda",
529
+ )
530
+ target_cond = target_cond.repeat(num_gen, 1, 1, 1)
531
+ ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
532
+ # novel view synthesis mode = off
533
+ nvs = torch.zeros(num_gen, dtype=torch.int, device="cuda")
534
+ z = torch.cat([z, z], 0)
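+ # The noise batch is doubled and paired below with real vs. zeroed (null) conditions,
+ # so forward_with_cfg applies classifier-free guidance with scale `cfg` in a single pass.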
535
+ model_kwargs = dict(
536
+ target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
537
+ ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
538
+ nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
539
+ cfg_scale=cfg,
540
+ )
541
+
542
+ samples, _ = diffusion.p_sample_loop(
543
+ model.forward_with_cfg,
544
+ z.shape,
545
+ z,
546
+ clip_denoised=False,
547
+ model_kwargs=model_kwargs,
548
+ progress=True,
549
+ device="cuda",
550
+ ).chunk(2)
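+ # chunk(2) keeps the first (conditioned) half of the doubled batch; the latents are then
+ # rescaled by 1/latent_scaling_factor and decoded back to RGB images in [-1, 1]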
551
+ sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
552
+ sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
553
+ sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
554
+
555
+ results = []
556
+ results_pose = []
557
+ for i in range(MAX_N):
558
+ if i < num_gen:
559
+ results.append(sampled_images[i])
560
+ results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
561
+ else:
562
+ results.append(placeholder)
563
+ results_pose.append(placeholder)
564
+ return results, results_pose
565
+
566
+
567
+ def ready_sample(img_ori, inpaint_mask, keypts):
568
+ img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
569
+ sam_predictor.set_image(img)
570
+ if len(keypts[0]) == 0:
571
+ keypts[0] = np.zeros((21, 2))
572
+ elif len(keypts[0]) == 21:
573
+ keypts[0] = np.array(keypts[0], dtype=np.float32)
574
+ else:
575
+ gr.Info("Number of right hand keypoints should be either 0 or 21.")
576
+ return None, None
577
+
578
+ if len(keypts[1]) == 0:
579
+ keypts[1] = np.zeros((21, 2))
580
+ elif len(keypts[1]) == 21:
581
+ keypts[1] = np.array(keypts[1], dtype=np.float32)
582
+ else:
583
+ gr.Info("Number of left hand keypoints should be either 0 or 21.")
584
+ return None, None
585
+
586
+ keypts = np.concatenate(keypts, axis=0)
587
+ keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
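+ # clicked keypoints are recorded in the LENGTH x LENGTH (480 x 480) editor space and
+ # rescaled here to the 256 x 256 model input resolution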
588
+ # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
589
+ # input_point = np.array([keypts[0], keypts[21]])
590
+ # # input_point = keypts
591
+ # input_label = np.array([1, 1])
592
+ # # input_label = np.ones_like(input_point[:, 0])
593
+ # elif keypts[0].sum() != 0:
594
+ # input_point = np.array(keypts[:1])
595
+ # # input_point = keypts[:21]
596
+ # input_label = np.array([1])
597
+ # # input_label = np.ones_like(input_point[:21, 0])
598
+ # elif keypts[21].sum() != 0:
599
+ # input_point = np.array(keypts[21:22])
600
+ # # input_point = keypts[21:]
601
+ # input_label = np.array([1])
602
+ # # input_label = np.ones_like(input_point[21:, 0])
603
+
604
+ box_shift_ratio = 0.5
605
+ box_size_factor = 1.2
606
+
607
+ if keypts[0].sum() != 0 and keypts[21].sum() != 0:
608
+ input_point = np.array(keypts)
609
+ input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
610
+ elif keypts[0].sum() != 0:
611
+ input_point = np.array(keypts[:21])
612
+ input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
613
+ elif keypts[21].sum() != 0:
614
+ input_point = np.array(keypts[21:])
615
+ input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
616
+ else:
617
+ raise ValueError(
618
+ "Something wrong. If no hand detected, it should not reach here."
619
+ )
620
+
621
+ input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
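+ # grow the keypoint bounding box by box_size_factor around its midpoint and pass it,
+ # together with the keypoints, to SAM as the segmentation prompt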
622
+ box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
623
+ input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
624
+
625
+ masks, _, _ = sam_predictor.predict(
626
+ point_coords=input_point,
627
+ point_labels=input_label,
628
+ box=input_box[None, :],
629
+ multimask_output=False,
630
+ )
631
+ hand_mask = masks[0]
632
+
633
+ inpaint_latent_mask = torch.tensor(
634
+ cv2.resize(
635
+ inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
636
+ ),
637
+ dtype=torch.float,
638
+ device="cuda",
639
+ ).unsqueeze(0)[None, ...]
640
+
641
+ def make_ref_cond(
642
+ img,
643
+ keypts,
644
+ hand_mask,
645
+ device="cuda",
646
+ target_size=(256, 256),
647
+ latent_size=(32, 32),
648
+ ):
649
+ image_transform = Compose(
650
+ [
651
+ ToTensor(),
652
+ Resize(target_size),
653
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
654
+ ]
655
+ )
656
+ image = image_transform(img).to(device)
657
+ kpts_valid = check_keypoints_validity(keypts, target_size)
658
+ heatmaps = torch.tensor(
659
+ keypoint_heatmap(
660
+ scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
661
+ )
662
+ * kpts_valid[:, None, None],
663
+ dtype=torch.float,
664
+ device=device,
665
+ )[None, ...]
666
+ mask = torch.tensor(
667
+ cv2.resize(
668
+ hand_mask.astype(int),
669
+ dsize=latent_size,
670
+ interpolation=cv2.INTER_NEAREST,
671
+ ),
672
+ dtype=torch.float,
673
+ device=device,
674
+ ).unsqueeze(0)[None, ...]
675
+ return image[None, ...], heatmaps, mask
676
+
677
+ image, heatmaps, mask = make_ref_cond(
678
+ img,
679
+ keypts,
680
+ hand_mask * (1 - inpaint_mask),
681
+ device="cuda",
682
+ target_size=opts.image_size,
683
+ latent_size=opts.latent_size,
684
+ )
685
+ latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
686
+ target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
687
+ ref_cond = torch.cat([latent, heatmaps, mask], 1)
688
+ ref_cond = torch.zeros_like(ref_cond)
689
+
690
+ img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
691
+ assert mask.max() == 1
692
+ vis_mask32 = mask_image(
693
+ img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False
694
+ ).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy()
695
+
696
+ assert np.unique(inpaint_mask).shape[0] <= 2
697
+ assert hand_mask.dtype == bool
698
+ mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask)
699
+ vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype(
700
+ np.uint8
701
+ ) # 1 - mask256
702
+
703
+ return (
704
+ ref_cond,
705
+ target_cond,
706
+ latent,
707
+ inpaint_latent_mask,
708
+ keypts,
709
+ vis_mask32,
710
+ vis_mask256,
711
+ )
712
+
713
+
714
+ def switch_mask_size(radio):
715
+ if radio == "256x256":
716
+ out = (gr.update(visible=False), gr.update(visible=True))
717
+ elif radio == "latent size (32x32)":
718
+ out = (gr.update(visible=True), gr.update(visible=False))
719
+ return out
720
+
721
+
722
+ def sample_inpaint(
723
+ ref_cond,
724
+ target_cond,
725
+ latent,
726
+ inpaint_latent_mask,
727
+ keypts,
728
+ num_gen,
729
+ seed,
730
+ cfg,
731
+ quality,
732
+ ):
733
+ set_seed(seed)
734
+ N = num_gen
735
+ jump_length = 10
736
+ jump_n_sample = quality
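+ # jump_length / jump_n_sample set the RePaint-style resampling schedule used by
+ # inpaint_p_sample_loop; the "Quality" slider maps directly to jump_n_sample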
737
+ cfg_scale = cfg
738
+ z = torch.randn(
739
+ (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device="cuda"
740
+ )
741
+ target_cond_N = target_cond.repeat(N, 1, 1, 1)
742
+ ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
743
+ # novel view synthesis mode = off
744
+ nvs = torch.zeros(N, dtype=torch.int, device="cuda")
745
+ z = torch.cat([z, z], 0)
746
+ model_kwargs = dict(
747
+ target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
748
+ ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
749
+ nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
750
+ cfg_scale=cfg_scale,
751
+ )
752
+
753
+ samples, _ = diffusion.inpaint_p_sample_loop(
754
+ model.forward_with_cfg,
755
+ z.shape,
756
+ latent,
757
+ inpaint_latent_mask,
758
+ z,
759
+ clip_denoised=False,
760
+ model_kwargs=model_kwargs,
761
+ progress=True,
762
+ device="cuda",
763
+ jump_length=jump_length,
764
+ jump_n_sample=jump_n_sample,
765
+ ).chunk(2)
766
+ sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
767
+ sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
768
+ sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
769
+
770
+ # visualize
771
+ results = []
772
+ results_pose = []
773
+ for i in range(FIX_MAX_N):
774
+ if i < num_gen:
775
+ results.append(sampled_images[i])
776
+ results_pose.append(visualize_hand(keypts, sampled_images[i]))
777
+ else:
778
+ results.append(placeholder)
779
+ results_pose.append(placeholder)
780
+ return results, results_pose
781
+
782
+
783
+ def flip_hand(
784
+ img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None
785
+ ):
786
+ if cond is None: # clear clicked
787
+ return None, None, None, None
788
+ img["composite"] = img["composite"][:, ::-1, :]
789
+ img["background"] = img["background"][:, ::-1, :]
790
+ img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
791
+ pose_img = pose_img[:, ::-1, :]
792
+ cond = cond.flip(-1)
793
+ if keypts is not None: # cond is target_cond
794
+ if keypts[:21, :].sum() != 0:
795
+ keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
796
+ # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
797
+ if keypts[21:, :].sum() != 0:
798
+ keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
799
+ # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
800
+ return img, pose_img, cond, keypts
801
+
802
+
803
+ def resize_to_full(img):
804
+ img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
805
+ img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
806
+ img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
807
+ return img
808
+
809
+
810
+ def clear_all():
811
+ return (
812
+ None,
813
+ None,
814
+ False,
815
+ None,
816
+ None,
817
+ False,
818
+ None,
819
+ None,
820
+ None,
821
+ None,
822
+ None,
823
+ None,
824
+ None,
825
+ 1,
826
+ 42,
827
+ 3.0,
828
+ )
829
+
830
+
831
+ def fix_clear_all():
832
+ return (
833
+ None,
834
+ None,
835
+ None,
836
+ None,
837
+ None,
838
+ None,
839
+ None,
840
+ None,
841
+ None,
842
+ None,
843
+ None,
844
+ None,
845
+ None,
846
+ None,
847
+ None,
848
+ None,
849
+ None,
850
+ 1,
851
+ # (0,0),
852
+ 42,
853
+ 3.0,
854
+ 10,
855
+ )
856
+
857
+
858
+ def enable_component(image1, image2):
859
+ if image1 is None or image2 is None:
860
+ return gr.update(interactive=False)
861
+ if "background" in image1 and "layers" in image1 and "composite" in image1:
862
+ if (
863
+ image1["background"].sum() == 0
864
+ and (sum([im.sum() for im in image1["layers"]]) == 0)
865
+ and image1["composite"].sum() == 0
866
+ ):
867
+ return gr.update(interactive=False)
868
+ if "background" in image2 and "layers" in image2 and "composite" in image2:
869
+ if (
870
+ image2["background"].sum() == 0
871
+ and (sum([im.sum() for im in image2["layers"]]) == 0)
872
+ and image2["composite"].sum() == 0
873
+ ):
874
+ return gr.update(interactive=False)
875
+ return gr.update(interactive=True)
876
+
877
+
878
+ def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left):
879
+ if kpts is None:
880
+ kpts = [[], []]
881
+ if "Right hand" not in checkbox:
882
+ kpts[0] = []
883
+ vis_right = img_clean
884
+ update_right = gr.update(visible=False)
885
+ update_r_info = gr.update(visible=False)
886
+ else:
887
+ vis_right = img_pose_right
888
+ update_right = gr.update(visible=True)
889
+ update_r_info = gr.update(visible=True)
890
+
891
+ if "Left hand" not in checkbox:
892
+ kpts[1] = []
893
+ vis_left = img_clean
894
+ update_left = gr.update(visible=False)
895
+ update_l_info = gr.update(visible=False)
896
+ else:
897
+ vis_left = img_pose_left
898
+ update_left = gr.update(visible=True)
899
+ update_l_info = gr.update(visible=True)
900
+
901
+ return (
902
+ kpts,
903
+ vis_right,
904
+ vis_left,
905
+ update_right,
906
+ update_right,
907
+ update_right,
908
+ update_left,
909
+ update_left,
910
+ update_left,
911
+ update_r_info,
912
+ update_l_info,
913
+ )
914
+
915
+
916
+ # def parse_fix_example(ex_img, ex_masked):
917
+ # original_img = ex_img
918
+ # # ex_img = cv2.resize(ex_img, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
919
+ # # ex_masked = cv2.resize(ex_masked, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
920
+ # inpaint_mask = np.all(ex_masked > 250, axis=-1).astype(np.uint8)
921
+ # layer = np.ones_like(ex_img) * 255
922
+ # layer = np.concatenate([layer, np.zeros_like(ex_img[..., 0:1])], axis=-1)
923
+ # layer[inpaint_mask == 1, 3] = 255
924
+ # ref_value = {
925
+ # "composite": ex_masked,
926
+ # "background": ex_img,
927
+ # "layers": [layer],
928
+ # }
929
+ # inpaint_mask = cv2.resize(
930
+ # inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
931
+ # )
932
+ # kp_img = visualize_ref(ref_value)
933
+ # return (
934
+ # original_img,
935
+ # gr.update(value=ref_value),
936
+ # kp_img,
937
+ # inpaint_mask,
938
+ # )
939
+
940
+
941
+ LENGTH = 480
942
+
943
+ example_imgs = [
944
+ [
945
+ "sample_images/sample1.jpg",
946
+ ],
947
+ [
948
+ "sample_images/sample2.jpg",
949
+ ],
950
+ [
951
+ "sample_images/sample3.jpg",
952
+ ],
953
+ [
954
+ "sample_images/sample4.jpg",
955
+ ],
956
+ [
957
+ "sample_images/sample5.jpg",
958
+ ],
959
+ [
960
+ "sample_images/sample6.jpg",
961
+ ],
962
+ [
963
+ "sample_images/sample7.jpg",
964
+ ],
965
+ [
966
+ "sample_images/sample8.jpg",
967
+ ],
968
+ [
969
+ "sample_images/sample9.jpg",
970
+ ],
971
+ [
972
+ "sample_images/sample10.jpg",
973
+ ],
974
+ [
975
+ "sample_images/sample11.jpg",
976
+ ],
977
+ ["pose_images/pose1.jpg"],
978
+ ["pose_images/pose2.jpg"],
979
+ ["pose_images/pose3.jpg"],
980
+ ["pose_images/pose4.jpg"],
981
+ ["pose_images/pose5.jpg"],
982
+ ["pose_images/pose6.jpg"],
983
+ ["pose_images/pose7.jpg"],
984
+ ["pose_images/pose8.jpg"],
985
+ ]
986
+
987
+ fix_example_imgs = [
988
+ ["bad_hands/1.jpg"], # "bad_hands/1_mask.jpg"],
989
+ ["bad_hands/2.jpg"], # "bad_hands/2_mask.jpg"],
990
+ ["bad_hands/3.jpg"], # "bad_hands/3_mask.jpg"],
991
+ ["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"],
992
+ ["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"],
993
+ ["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"],
994
+ ["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"],
995
+ ["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"],
996
+ ["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"],
997
+ ["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"],
998
+ ["bad_hands/11.jpg"], # "bad_hands/11_mask.jpg"],
999
+ ["bad_hands/12.jpg"], # "bad_hands/12_mask.jpg"],
1000
+ ["bad_hands/13.jpg"], # "bad_hands/13_mask.jpg"],
1001
+ ]
1002
+ custom_css = """
1003
+ .gradio-container .examples img {
1004
+ width: 240px !important;
1005
+ height: 240px !important;
1006
+ }
1007
+ """
1008
+
1009
+
1010
+ with gr.Blocks(css=custom_css) as demo:
1011
+ with gr.Tab("Edit Hand Poses"):
1012
+ ref_img = gr.State(value=None)
1013
+ ref_cond = gr.State(value=None)
1014
+ keypts = gr.State(value=None)
1015
+ target_img = gr.State(value=None)
1016
+ target_cond = gr.State(value=None)
1017
+ target_keypts = gr.State(value=None)
1018
+ dump = gr.State(value=None)
1019
+ with gr.Row():
1020
+ with gr.Column():
1021
+ gr.Markdown(
1022
+ """<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Reference</p>"""
1023
+ )
1024
+ gr.Markdown("""<p style="text-align: center;"><br></p>""")
1025
+ ref = gr.ImageEditor(
1026
+ type="numpy",
1027
+ label="Reference",
1028
+ show_label=True,
1029
+ height=LENGTH,
1030
+ width=LENGTH,
1031
+ brush=False,
1032
+ layers=False,
1033
+ crop_size="1:1",
1034
+ )
1035
+ ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
1036
+ ref_pose = gr.Image(
1037
+ type="numpy",
1038
+ label="Reference Pose",
1039
+ show_label=True,
1040
+ height=LENGTH,
1041
+ width=LENGTH,
1042
+ interactive=False,
1043
+ )
1044
+ ref_flip = gr.Checkbox(
1045
+ value=False, label="Flip Handedness (Reference)", interactive=False
1046
+ )
1047
+ with gr.Column():
1048
+ gr.Markdown(
1049
+ """<p style="text-align: center; font-size: 25px; font-weight: bold;">2. Target</p>"""
1050
+ )
1051
+ target = gr.ImageEditor(
1052
+ type="numpy",
1053
+ label="Target",
1054
+ show_label=True,
1055
+ height=LENGTH,
1056
+ width=LENGTH,
1057
+ brush=False,
1058
+ layers=False,
1059
+ crop_size="1:1",
1060
+ )
1061
+ target_finish_crop = gr.Button(
1062
+ value="Finish Cropping", interactive=False
1063
+ )
1064
+ target_pose = gr.Image(
1065
+ type="numpy",
1066
+ label="Target Pose",
1067
+ show_label=True,
1068
+ height=LENGTH,
1069
+ width=LENGTH,
1070
+ interactive=False,
1071
+ )
1072
+ target_flip = gr.Checkbox(
1073
+ value=False, label="Flip Handedness (Target)", interactive=False
1074
+ )
1075
+ with gr.Column():
1076
+ gr.Markdown(
1077
+ """<p style="text-align: center; font-size: 25px; font-weight: bold;">3. Result</p>"""
1078
+ )
1079
+ gr.Markdown(
1080
+ """<p style="text-align: center;">Run is enabled after the images have been processed</p>"""
1081
+ )
1082
+ run = gr.Button(value="Run", interactive=False)
1083
+ gr.Markdown(
1084
+ """<p style="text-align: center;">~20s per generation. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
1085
+ )
1086
+ results = gr.Gallery(
1087
+ type="numpy",
1088
+ label="Results",
1089
+ show_label=True,
1090
+ height=LENGTH,
1091
+ min_width=LENGTH,
1092
+ columns=MAX_N,
1093
+ interactive=False,
1094
+ preview=True,
1095
+ )
1096
+ results_pose = gr.Gallery(
1097
+ type="numpy",
1098
+ label="Results Pose",
1099
+ show_label=True,
1100
+ height=LENGTH,
1101
+ min_width=LENGTH,
1102
+ columns=MAX_N,
1103
+ interactive=False,
1104
+ preview=True,
1105
+ )
1106
+ clear = gr.ClearButton()
1107
+
1108
+ with gr.Row():
1109
+ n_generation = gr.Slider(
1110
+ label="Number of generations",
1111
+ value=1,
1112
+ minimum=1,
1113
+ maximum=MAX_N,
1114
+ step=1,
1115
+ randomize=False,
1116
+ interactive=True,
1117
+ )
1118
+ seed = gr.Slider(
1119
+ label="Seed",
1120
+ value=42,
1121
+ minimum=0,
1122
+ maximum=10000,
1123
+ step=1,
1124
+ randomize=False,
1125
+ interactive=True,
1126
+ )
1127
+ cfg = gr.Slider(
1128
+ label="Classifier free guidance scale",
1129
+ value=2.5,
1130
+ minimum=0.0,
1131
+ maximum=10.0,
1132
+ step=0.1,
1133
+ randomize=False,
1134
+ interactive=True,
1135
+ )
1136
+
1137
+ ref.change(enable_component, [ref, ref], ref_finish_crop)
1138
+ ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
1139
+ ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
1140
+ ref_flip.select(
1141
+ flip_hand, [ref, ref_pose, ref_cond], [ref, ref_pose, ref_cond, dump]
1142
+ )
1143
+ target.change(enable_component, [target, target], target_finish_crop)
1144
+ target_finish_crop.click(
1145
+ get_target_anno,
1146
+ [target],
1147
+ [target_img, target_pose, target_cond, target_keypts],
1148
+ )
1149
+ target_pose.change(enable_component, [target_img, target_pose], target_flip)
1150
+ target_flip.select(
1151
+ flip_hand,
1152
+ [target, target_pose, target_cond, target_keypts],
1153
+ [target, target_pose, target_cond, target_keypts],
1154
+ )
1155
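+ # the Run button is enabled only once both a reference pose and a target pose exist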
+ ref_pose.change(enable_component, [ref_pose, target_pose], run)
1156
+ target_pose.change(enable_component, [ref_pose, target_pose], run)
1157
+ run.click(
1158
+ sample_diff,
1159
+ [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
1160
+ [results, results_pose],
1161
+ )
1162
+ clear.click(
1163
+ clear_all,
1164
+ [],
1165
+ [
1166
+ ref,
1167
+ ref_pose,
1168
+ ref_flip,
1169
+ target,
1170
+ target_pose,
1171
+ target_flip,
1172
+ results,
1173
+ results_pose,
1174
+ ref_img,
1175
+ ref_cond,
1176
+ # mask,
1177
+ target_img,
1178
+ target_cond,
1179
+ target_keypts,
1180
+ n_generation,
1181
+ seed,
1182
+ cfg,
1183
+ ],
1184
+ )
1185
+
1186
+ gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
1187
+ with gr.Tab("Reference"):
1188
+ with gr.Row():
1189
+ gr.Examples(example_imgs, [ref], examples_per_page=20)
1190
+ with gr.Tab("Target"):
1191
+ with gr.Row():
1192
+ gr.Examples(example_imgs, [target], examples_per_page=20)
1193
+ with gr.Tab("Fix Hands"):
1194
+ fix_inpaint_mask = gr.State(value=None)
1195
+ fix_original = gr.State(value=None)
1196
+ fix_img = gr.State(value=None)
1197
+ fix_kpts = gr.State(value=None)
1198
+ fix_kpts_np = gr.State(value=None)
1199
+ fix_ref_cond = gr.State(value=None)
1200
+ fix_target_cond = gr.State(value=None)
1201
+ fix_latent = gr.State(value=None)
1202
+ fix_inpaint_latent = gr.State(value=None)
1203
+ # fix_size_memory = gr.State(value=(0, 0))
1204
+ with gr.Row():
1205
+ with gr.Column():
1206
+ gr.Markdown(
1207
+ """<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>"""
1208
+ )
1209
+ gr.Markdown(
1210
+ """<p style="text-align: center;">Crop the image around the hand.<br>Then, brush area (e.g., wrong finger) that needs to be fixed.</p>"""
1211
+ )
1212
+ gr.Markdown(
1213
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>"""
1214
+ )
1215
+ fix_crop = gr.ImageEditor(
1216
+ type="numpy",
1217
+ sources=["upload", "webcam", "clipboard"],
1218
+ label="Image crop",
1219
+ show_label=True,
1220
+ height=LENGTH,
1221
+ width=LENGTH,
1222
+ layers=False,
1223
+ crop_size="1:1",
1224
+ brush=False,
1225
+ image_mode="RGBA",
1226
+ container=False,
1227
+ )
1228
+ gr.Markdown(
1229
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>"""
1230
+ )
1231
+ fix_ref = gr.ImageEditor(
1232
+ type="numpy",
1233
+ label="Image brush",
1234
+ sources=(),
1235
+ show_label=True,
1236
+ height=LENGTH,
1237
+ width=LENGTH,
1238
+ layers=False,
1239
+ transforms=("brush"),
1240
+ brush=gr.Brush(
1241
+ colors=["rgb(255, 255, 255)"], default_size=20
1242
+ ), # 204, 50, 50
1243
+ image_mode="RGBA",
1244
+ container=False,
1245
+ interactive=False,
1246
+ )
1247
+ fix_finish_crop = gr.Button(
1248
+ value="Finish Croping & Brushing", interactive=False
1249
+ )
1250
+ gr.Markdown(
1251
+ """<p style="text-align: left; font-size: 20px; font-weight: bold; ">OpenPose keypoints convention</p>"""
1252
+ )
1253
+ fix_openpose = gr.Image(
1254
+ value="openpose.png",
1255
+ type="numpy",
1256
+ label="OpenPose keypoints convention",
1257
+ show_label=True,
1258
+ height=LENGTH // 3 * 2,
1259
+ width=LENGTH // 3 * 2,
1260
+ interactive=False,
1261
+ )
1262
+ with gr.Column():
1263
+ gr.Markdown(
1264
+ """<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>"""
1265
+ )
1266
+ gr.Markdown(
1267
+ """<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\" on the bottom left.</p>"""
1268
+ )
1269
+ fix_checkbox = gr.CheckboxGroup(
1270
+ ["Right hand", "Left hand"],
1271
+ # value=["Right hand", "Left hand"],
1272
+ label="Hand side",
1273
+ info="Which side this hand is? Could be both.",
1274
+ interactive=False,
1275
+ )
1276
+ fix_kp_r_info = gr.Markdown(
1277
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
1278
+ visible=False,
1279
+ )
1280
+ fix_kp_right = gr.Image(
1281
+ type="numpy",
1282
+ label="Keypoint Selection (right hand)",
1283
+ show_label=True,
1284
+ height=LENGTH,
1285
+ width=LENGTH,
1286
+ interactive=False,
1287
+ visible=False,
1288
+ sources=[],
1289
+ )
1290
+ with gr.Row():
1291
+ fix_undo_right = gr.Button(
1292
+ value="Undo", interactive=False, visible=False
1293
+ )
1294
+ fix_reset_right = gr.Button(
1295
+ value="Reset", interactive=False, visible=False
1296
+ )
1297
+ fix_kp_l_info = gr.Markdown(
1298
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
1299
+ visible=False
1300
+ )
1301
+ fix_kp_left = gr.Image(
1302
+ type="numpy",
1303
+ label="Keypoint Selection (left hand)",
1304
+ show_label=True,
1305
+ height=LENGTH,
1306
+ width=LENGTH,
1307
+ interactive=False,
1308
+ visible=False,
1309
+ sources=[],
1310
+ )
1311
+ with gr.Row():
1312
+ fix_undo_left = gr.Button(
1313
+ value="Undo", interactive=False, visible=False
1314
+ )
1315
+ fix_reset_left = gr.Button(
1316
+ value="Reset", interactive=False, visible=False
1317
+ )
1318
+ with gr.Column():
1319
+ gr.Markdown(
1320
+ """<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>"""
1321
+ )
1322
+ gr.Markdown(
1323
+ """<p style="text-align: center;">In Fix Hands, not segmentation mask, but only inpaint mask is used.</p>"""
1324
+ )
1325
+ fix_ready = gr.Button(value="Ready", interactive=False)
1326
+ fix_mask_size = gr.Radio(
1327
+ ["256x256", "latent size (32x32)"],
1328
+ label="Visualized inpaint mask size",
1329
+ interactive=False,
1330
+ value="256x256",
1331
+ )
1332
+ gr.Markdown(
1333
+ """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Visualized inpaint masks</p>"""
1334
+ )
1335
+ fix_vis_mask32 = gr.Image(
1336
+ type="numpy",
1337
+ label=f"Visualized {opts.latent_size} Inpaint Mask",
1338
+ show_label=True,
1339
+ height=opts.latent_size,
1340
+ width=opts.latent_size,
1341
+ interactive=False,
1342
+ visible=False,
1343
+ )
1344
+ fix_vis_mask256 = gr.Image(
1345
+ type="numpy",
1346
+ label=f"Visualized {opts.image_size} Inpaint Mask",
1347
+ visible=True,
1348
+ show_label=True,
1349
+ height=opts.image_size,
1350
+ width=opts.image_size,
1351
+ interactive=False,
1352
+ )
1353
+ with gr.Column():
1354
+ gr.Markdown(
1355
+ """<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>"""
1356
+ )
1357
+ fix_run = gr.Button(value="Run", interactive=False)
1358
+ gr.Markdown(
1359
+ """<p style="text-align: center;">>3min and ~24GB per generation</p>"""
1360
+ )
1361
+ fix_result = gr.Gallery(
1362
+ type="numpy",
1363
+ label="Results",
1364
+ show_label=True,
1365
+ height=LENGTH,
1366
+ min_width=LENGTH,
1367
+ columns=FIX_MAX_N,
1368
+ interactive=False,
1369
+ preview=True,
1370
+ )
1371
+ fix_result_pose = gr.Gallery(
1372
+ type="numpy",
1373
+ label="Results Pose",
1374
+ show_label=True,
1375
+ height=LENGTH,
1376
+ min_width=LENGTH,
1377
+ columns=FIX_MAX_N,
1378
+ interactive=False,
1379
+ preview=True,
1380
+ )
1381
+ fix_clear = gr.ClearButton()
1382
+ gr.Markdown(
1383
+ "[NOTE] Currently, Number of generation > 1 could lead to out-of-memory"
1384
+ )
1385
+ with gr.Row():
1386
+ fix_n_generation = gr.Slider(
1387
+ label="Number of generations",
1388
+ value=1,
1389
+ minimum=1,
1390
+ maximum=FIX_MAX_N,
1391
+ step=1,
1392
+ randomize=False,
1393
+ interactive=True,
1394
+ )
1395
+ fix_seed = gr.Slider(
1396
+ label="Seed",
1397
+ value=42,
1398
+ minimum=0,
1399
+ maximum=10000,
1400
+ step=1,
1401
+ randomize=False,
1402
+ interactive=True,
1403
+ )
1404
+ fix_cfg = gr.Slider(
1405
+ label="Classifier free guidance scale",
1406
+ value=3.0,
1407
+ minimum=0.0,
1408
+ maximum=10.0,
1409
+ step=0.1,
1410
+ randomize=False,
1411
+ interactive=True,
1412
+ )
1413
+ fix_quality = gr.Slider(
1414
+ label="Quality",
1415
+ value=10,
1416
+ minimum=1,
1417
+ maximum=10,
1418
+ step=1,
1419
+ randomize=False,
1420
+ interactive=True,
1421
+ )
1422
+ fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
1423
+ fix_crop.change(resize_to_full, fix_crop, fix_ref)
1424
+ fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
1425
+ fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
1426
+ # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
1427
+ # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
1428
+ fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
1429
+ fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
1430
+ fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
1431
+ fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
1432
+ fix_inpaint_mask.change(
1433
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
1434
+ )
1435
+ fix_inpaint_mask.change(
1436
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
1437
+ )
1438
+ fix_inpaint_mask.change(
1439
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
1440
+ )
1441
+ fix_inpaint_mask.change(
1442
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
1443
+ )
1444
+ fix_inpaint_mask.change(
1445
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
1446
+ )
1447
+ fix_inpaint_mask.change(
1448
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
1449
+ )
1450
+ fix_inpaint_mask.change(
1451
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
1452
+ )
1453
+ fix_inpaint_mask.change(
1454
+ enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
1455
+ )
1456
+ # fix_inpaint_mask.change(
1457
+ # enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
1458
+ # )
1459
+ fix_checkbox.select(
1460
+ set_visible,
1461
+ [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
1462
+ [
1463
+ fix_kpts,
1464
+ fix_kp_right,
1465
+ fix_kp_left,
1466
+ fix_kp_right,
1467
+ fix_undo_right,
1468
+ fix_reset_right,
1469
+ fix_kp_left,
1470
+ fix_undo_left,
1471
+ fix_reset_left,
1472
+ fix_kp_r_info,
1473
+ fix_kp_l_info,
1474
+ ],
1475
+ )
1476
+ fix_kp_right.select(
1477
+ get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1478
+ )
1479
+ fix_undo_right.click(
1480
+ undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1481
+ )
1482
+ fix_reset_right.click(
1483
+ reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1484
+ )
1485
+ fix_kp_left.select(
1486
+ get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1487
+ )
1488
+ fix_undo_left.click(
1489
+ undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1490
+ )
1491
+ fix_reset_left.click(
1492
+ reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1493
+ )
1494
+ # fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
1495
+ # fix_run.click(lambda x:gr.update(value=None), [], [fix_result, fix_result_pose])
1496
+ fix_vis_mask32.change(
1497
+ enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
1498
+ )
1499
+ fix_vis_mask32.change(
1500
+ enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size
1501
+ )
1502
+ fix_ready.click(
1503
+ ready_sample,
1504
+ [fix_original, fix_inpaint_mask, fix_kpts],
1505
+ [
1506
+ fix_ref_cond,
1507
+ fix_target_cond,
1508
+ fix_latent,
1509
+ fix_inpaint_latent,
1510
+ fix_kpts_np,
1511
+ fix_vis_mask32,
1512
+ fix_vis_mask256,
1513
+ ],
1514
+ )
1515
+ fix_mask_size.select(
1516
+ switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256]
1517
+ )
1518
+ fix_run.click(
1519
+ sample_inpaint,
1520
+ [
1521
+ fix_ref_cond,
1522
+ fix_target_cond,
1523
+ fix_latent,
1524
+ fix_inpaint_latent,
1525
+ fix_kpts_np,
1526
+ fix_n_generation,
1527
+ fix_seed,
1528
+ fix_cfg,
1529
+ fix_quality,
1530
+ ],
1531
+ [fix_result, fix_result_pose],
1532
+ )
1533
+ fix_clear.click(
1534
+ fix_clear_all,
1535
+ [],
1536
+ [
1537
+ fix_crop,
1538
+ fix_ref,
1539
+ fix_kp_right,
1540
+ fix_kp_left,
1541
+ fix_result,
1542
+ fix_result_pose,
1543
+ fix_inpaint_mask,
1544
+ fix_original,
1545
+ fix_img,
1546
+ fix_vis_mask32,
1547
+ fix_vis_mask256,
1548
+ fix_kpts,
1549
+ fix_kpts_np,
1550
+ fix_ref_cond,
1551
+ fix_target_cond,
1552
+ fix_latent,
1553
+ fix_inpaint_latent,
1554
+ fix_n_generation,
1555
+ # fix_size_memory,
1556
+ fix_seed,
1557
+ fix_cfg,
1558
+ fix_quality,
1559
+ ],
1560
+ )
1561
+
1562
+ gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
1563
+ fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False)
1564
+ fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False)
1565
+ with gr.Column():
1566
+ fix_example = gr.Examples(
1567
+ fix_example_imgs,
1568
+ # run_on_click=True,
1569
+ # fn=parse_fix_example,
1570
+ # inputs=[fix_dump_ex, fix_dump_ex_masked],
1571
+ # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
1572
+ inputs=[fix_crop],
1573
+ examples_per_page=20,
1574
+ )
1575
+
1576
+
1577
+ print("Ready to launch..")
1578
+ _, _, shared_url = demo.queue().launch(
1579
+ share=True, server_name="0.0.0.0", server_port=7739
1580
+ )
1581
+ # launch() blocks by default, so no extra call is needed to keep the app running
diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
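+ # Minimal usage sketch, mirroring app.py (which passes str(opts.test_sampling_steps)):
+ #   diffusion = create_diffusion(timestep_respacing="250")
+ #   samples, _ = diffusion.p_sample_loop(model.forward_with_cfg, z.shape, z,
+ #       clip_denoised=False, model_kwargs=model_kwargs, device="cuda").chunk(2)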
diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (986 Bytes)
diffusion/__pycache__/diffusion_utils.cpython-38.pyc ADDED
Binary file (2.86 kB)
diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc ADDED
Binary file (27.6 kB)
diffusion/__pycache__/respace.cpython-38.pyc ADDED
Binary file (5.04 kB)
diffusion/__pycache__/scheduler.cpython-38.pyc ADDED
Binary file (3.99 kB)
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
+ # Modified from OpenAI's diffusion repos
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+ import torch as th
+ import numpy as np
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     Compute the KL divergence between two gaussians.
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, th.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for th.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+         for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + th.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+     )
+
+
+ def approx_standard_normal_cdf(x):
+     """
+     A fast approximation of the cumulative distribution function of the
+     standard normal.
+     """
+     return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+     """
+     Compute the log-likelihood of a continuous Gaussian distribution.
+     :param x: the targets
+     :param means: the Gaussian mean Tensor.
+     :param log_scales: the Gaussian log stddev Tensor.
+     :return: a tensor like x of log probabilities (in nats).
+     """
+     centered_x = x - means
+     inv_stdv = th.exp(-log_scales)
+     normalized_x = centered_x * inv_stdv
+     log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+     return log_probs
+
+
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+     """
+     Compute the log-likelihood of a Gaussian distribution discretizing to a
+     given image.
+     :param x: the target images. It is assumed that this was uint8 values,
+               rescaled to the range [-1, 1].
+     :param means: the Gaussian mean Tensor.
+     :param log_scales: the Gaussian log stddev Tensor.
+     :return: a tensor like x of log probabilities (in nats).
+     """
+     assert x.shape == means.shape == log_scales.shape
+     centered_x = x - means
+     inv_stdv = th.exp(-log_scales)
+     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+     cdf_plus = approx_standard_normal_cdf(plus_in)
+     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+     cdf_min = approx_standard_normal_cdf(min_in)
+     log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+     log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+     cdf_delta = cdf_plus - cdf_min
+     log_probs = th.where(
+         x < -0.999,
+         log_cdf_plus,
+         th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+     )
+     assert log_probs.shape == x.shape
+     return log_probs
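A quick sanity check of the two likelihood helpers above (hedged: the `diffusion.diffusion_utils` import path assumes this commit's layout). The KL of a Gaussian against itself is zero, and the discretized likelihood expects images already scaled to [-1, 1]:

```python
import torch as th

# Assumes the file above is importable as diffusion.diffusion_utils.
from diffusion.diffusion_utils import normal_kl, discretized_gaussian_log_likelihood

mean = th.zeros(2, 3, 8, 8)
logvar = th.zeros(2, 3, 8, 8)  # log-variance 0, i.e. unit variance

# KL(N(0, 1) || N(0, 1)) is exactly zero, elementwise.
kl = normal_kl(mean, logvar, mean, logvar)
assert th.allclose(kl, th.zeros_like(kl))

# Fake uint8 images rescaled to [-1, 1], scored under a unit-variance decoder.
x = th.randint(0, 256, (2, 3, 8, 8)).float() / 127.5 - 1.0
ll = discretized_gaussian_log_likelihood(x, means=mean, log_scales=0.5 * logvar)
print(ll.shape)  # torch.Size([2, 3, 8, 8]); entries are log-probabilities in nats
```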
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1118 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+ from collections import defaultdict
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+ from .scheduler import get_schedule_jump
15
+
16
+
17
+ def mean_flat(tensor):
18
+ """
19
+ Take the mean over all non-batch dimensions.
20
+ """
21
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
22
+
23
+
24
+ class ModelMeanType(enum.Enum):
25
+ """
26
+ Which type of output the model predicts.
27
+ """
28
+
29
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
30
+ START_X = enum.auto() # the model predicts x_0
31
+ EPSILON = enum.auto() # the model predicts epsilon
32
+
33
+
34
+ class ModelVarType(enum.Enum):
35
+ """
36
+ What is used as the model's output variance.
37
+ The LEARNED_RANGE option has been added to allow the model to predict
38
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39
+ """
40
+
41
+ LEARNED = enum.auto()
42
+ FIXED_SMALL = enum.auto()
43
+ FIXED_LARGE = enum.auto()
44
+ LEARNED_RANGE = enum.auto()
45
+
46
+
47
+ class LossType(enum.Enum):
48
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49
+ RESCALED_MSE = (
50
+ enum.auto()
51
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
52
+ KL = enum.auto() # use the variational lower-bound
53
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54
+
55
+ def is_vb(self):
56
+ return self == LossType.KL or self == LossType.RESCALED_KL
57
+
58
+
59
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
62
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63
+ return betas
64
+
65
+
66
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
67
+ """
68
+ This is the deprecated API for creating beta schedules.
69
+ See get_named_beta_schedule() for the new library of schedules.
70
+ """
71
+ if beta_schedule == "quad":
72
+ betas = (
73
+ np.linspace(
74
+ beta_start ** 0.5,
75
+ beta_end ** 0.5,
76
+ num_diffusion_timesteps,
77
+ dtype=np.float64,
78
+ )
79
+ ** 2
80
+ )
81
+ elif beta_schedule == "linear":
82
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
83
+ elif beta_schedule == "warmup10":
84
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
85
+ elif beta_schedule == "warmup50":
86
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
87
+ elif beta_schedule == "const":
88
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
89
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
90
+ betas = 1.0 / np.linspace(
91
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
92
+ )
93
+ else:
94
+ raise NotImplementedError(beta_schedule)
95
+ assert betas.shape == (num_diffusion_timesteps,)
96
+ return betas
97
+
98
+
99
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
100
+ """
101
+ Get a pre-defined beta schedule for the given name.
102
+ The beta schedule library consists of beta schedules which remain similar
103
+ in the limit of num_diffusion_timesteps.
104
+ Beta schedules may be added, but should not be removed or changed once
105
+ they are committed to maintain backwards compatibility.
106
+ """
107
+ if schedule_name == "linear":
108
+ # Linear schedule from Ho et al, extended to work for any number of
109
+ # diffusion steps.
110
+ scale = 1000 / num_diffusion_timesteps
111
+ return get_beta_schedule(
112
+ "linear",
113
+ beta_start=scale * 0.0001,
114
+ beta_end=scale * 0.02,
115
+ num_diffusion_timesteps=num_diffusion_timesteps,
116
+ )
117
+ elif schedule_name == "squaredcos_cap_v2":
118
+ return betas_for_alpha_bar(
119
+ num_diffusion_timesteps,
120
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
121
+ )
122
+ else:
123
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
124
+
125
+
126
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
127
+ """
128
+ Create a beta schedule that discretizes the given alpha_t_bar function,
129
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
130
+ :param num_diffusion_timesteps: the number of betas to produce.
131
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
132
+ produces the cumulative product of (1-beta) up to that
133
+ part of the diffusion process.
134
+ :param max_beta: the maximum beta to use; use values lower than 1 to
135
+ prevent singularities.
136
+ """
137
+ betas = []
138
+ for i in range(num_diffusion_timesteps):
139
+ t1 = i / num_diffusion_timesteps
140
+ t2 = (i + 1) / num_diffusion_timesteps
141
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
142
+ return np.array(betas)
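A small sketch comparing the two named schedule families defined above (hedged: the `diffusion.gaussian_diffusion` import path is an assumption about this commit's layout); both keep the cumulative alpha product strictly decreasing:

```python
import numpy as np

# Assumes this file is importable as diffusion.gaussian_diffusion.
from diffusion.gaussian_diffusion import get_named_beta_schedule

linear = get_named_beta_schedule("linear", 1000)
cosine = get_named_beta_schedule("squaredcos_cap_v2", 1000)
assert linear.shape == cosine.shape == (1000,)

# The cumulative product of (1 - beta) decays monotonically under both schedules.
assert np.all(np.diff(np.cumprod(1.0 - linear)) < 0)
assert np.all(np.diff(np.cumprod(1.0 - cosine)) < 0)
```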
143
+
144
+
145
+ class GaussianDiffusion:
146
+ """
147
+ Utilities for training and sampling diffusion models.
148
+ Original ported from this codebase:
149
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
151
+ starting at T and going to 1.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ *,
157
+ betas,
158
+ model_mean_type,
159
+ model_var_type,
160
+ loss_type
161
+ ):
162
+
163
+ self.model_mean_type = model_mean_type
164
+ self.model_var_type = model_var_type
165
+ self.loss_type = loss_type
166
+
167
+ # Use float64 for accuracy.
168
+ betas = np.array(betas, dtype=np.float64)
169
+ self.betas = betas
170
+ assert len(betas.shape) == 1, "betas must be 1-D"
171
+ assert (betas > 0).all() and (betas <= 1).all()
172
+
173
+ self.num_timesteps = int(betas.shape[0])
174
+
175
+ alphas = 1.0 - betas
176
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
177
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
178
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
179
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
180
+
181
+ # calculations for diffusion q(x_t | x_{t-1}) and others
182
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
183
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
184
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
185
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
186
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
187
+
188
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
189
+ self.posterior_variance = (
190
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
191
+ )
192
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
193
+ self.posterior_log_variance_clipped = np.log(
194
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
195
+ ) if len(self.posterior_variance) > 1 else np.array([])
196
+
197
+ self.posterior_mean_coef1 = (
198
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
199
+ )
200
+ self.posterior_mean_coef2 = (
201
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
202
+ )
203
+
204
+ def q_mean_variance(self, x_start, t):
205
+ """
206
+ Get the distribution q(x_t | x_0).
207
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
208
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210
+ """
211
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
212
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
214
+ return mean, variance, log_variance
215
+
216
+ def q_sample(self, x_start, t, noise=None):
217
+ """
218
+ Diffuse the data for a given number of diffusion steps.
219
+ In other words, sample from q(x_t | x_0).
220
+ :param x_start: the initial data batch.
221
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
222
+ :param noise: if specified, the split-out normal noise.
223
+ :return: A noisy version of x_start.
224
+ """
225
+ if noise is None:
226
+ noise = th.randn_like(x_start)
227
+ assert noise.shape == x_start.shape
228
+ return (
229
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
230
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
231
+ )
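q_sample implements the closed form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A standalone restatement (a sketch with an assumed linear schedule, independent of the class) makes the scale behaviour easy to check:

```python
import torch as th

# Assumed linear schedule: 1000 steps with beta in [1e-4, 0.02], matching
# get_named_beta_schedule("linear", 1000) above.
betas = th.linspace(1e-4, 0.02, 1000, dtype=th.float64)
alphas_cumprod = th.cumprod(1.0 - betas, dim=0)

x0 = th.randn(4, 3, 64, 64, dtype=th.float64)
t = th.tensor([0, 250, 500, 999])
eps = th.randn_like(x0)

a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps

# Early timesteps stay close to the data; late ones approach pure noise.
print([round((x_t[i] - x0[i]).abs().mean().item(), 3) for i in range(4)])
```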
232
+
233
+ def q_posterior_mean_variance(self, x_start, x_t, t):
234
+ """
235
+ Compute the mean and variance of the diffusion posterior:
236
+ q(x_{t-1} | x_t, x_0)
237
+ """
238
+ assert x_start.shape == x_t.shape
239
+ posterior_mean = (
240
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
241
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
242
+ )
243
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
244
+ posterior_log_variance_clipped = _extract_into_tensor(
245
+ self.posterior_log_variance_clipped, t, x_t.shape
246
+ )
247
+ assert (
248
+ posterior_mean.shape[0]
249
+ == posterior_variance.shape[0]
250
+ == posterior_log_variance_clipped.shape[0]
251
+ == x_start.shape[0]
252
+ )
253
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
254
+
255
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
256
+ """
257
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
258
+ the initial x, x_0.
259
+ :param model: the model, which takes a signal and a batch of timesteps
260
+ as input.
261
+ :param x: the [N x C x ...] tensor at time t.
262
+ :param t: a 1-D Tensor of timesteps.
263
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
264
+ :param denoised_fn: if not None, a function which applies to the
265
+ x_start prediction before it is used to sample. Applies before
266
+ clip_denoised.
267
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
268
+ pass to the model. This can be used for conditioning.
269
+ :return: a dict with the following keys:
270
+ - 'mean': the model mean output.
271
+ - 'variance': the model variance output.
272
+ - 'log_variance': the log of 'variance'.
273
+ - 'pred_xstart': the prediction for x_0.
274
+ """
275
+ if model_kwargs is None:
276
+ model_kwargs = {}
277
+
278
+ B, C = x.shape[:2]
279
+ assert t.shape == (B,)
280
+ model_output = model(x, t, **model_kwargs)
281
+ if isinstance(model_output, tuple):
282
+ model_output, extra = model_output
283
+ else:
284
+ extra = None
285
+
286
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
287
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
288
+ model_output, model_var_values = th.split(model_output, C, dim=1)
289
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
290
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
291
+ # The model_var_values is [-1, 1] for [min_var, max_var].
292
+ frac = (model_var_values + 1) / 2
293
+ model_log_variance = frac * max_log + (1 - frac) * min_log
294
+ model_variance = th.exp(model_log_variance)
295
+ else:
296
+ model_variance, model_log_variance = {
297
+ # for fixedlarge, we set the initial (log-)variance like so
298
+ # to get a better decoder log likelihood.
299
+ ModelVarType.FIXED_LARGE: (
300
+ np.append(self.posterior_variance[1], self.betas[1:]),
301
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
302
+ ),
303
+ ModelVarType.FIXED_SMALL: (
304
+ self.posterior_variance,
305
+ self.posterior_log_variance_clipped,
306
+ ),
307
+ }[self.model_var_type]
308
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
309
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
310
+
311
+ def process_xstart(x):
312
+ if denoised_fn is not None:
313
+ x = denoised_fn(x)
314
+ if clip_denoised:
315
+ return x.clamp(-1, 1)
316
+ return x
317
+
318
+ if self.model_mean_type == ModelMeanType.START_X:
319
+ pred_xstart = process_xstart(model_output)
320
+ else:
321
+ pred_xstart = process_xstart(
322
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
323
+ )
324
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
325
+
326
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
327
+ return {
328
+ "mean": model_mean,
329
+ "variance": model_variance,
330
+ "log_variance": model_log_variance,
331
+ "pred_xstart": pred_xstart,
332
+ "extra": extra,
333
+ }
334
+
335
+ def _predict_xstart_from_eps(self, x_t, t, eps):
336
+ assert x_t.shape == eps.shape
337
+ return (
338
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
339
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
340
+ )
341
+
342
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
343
+ return (
344
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
345
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
346
+
347
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
348
+ """
349
+ Compute the mean for the previous step, given a function cond_fn that
350
+ computes the gradient of a conditional log probability with respect to
351
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
352
+ condition on y.
353
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
354
+ """
355
+ gradient = cond_fn(x, t, **model_kwargs)
356
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
357
+ return new_mean
358
+
359
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
360
+ """
361
+ Compute what the p_mean_variance output would have been, should the
362
+ model's score function be conditioned by cond_fn.
363
+ See condition_mean() for details on cond_fn.
364
+ Unlike condition_mean(), this instead uses the conditioning strategy
365
+ from Song et al (2020).
366
+ """
367
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
368
+
369
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
370
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
371
+
372
+ out = p_mean_var.copy()
373
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
374
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
375
+ return out
376
+
377
+ def p_sample(
378
+ self,
379
+ model,
380
+ x,
381
+ t,
382
+ clip_denoised=True,
383
+ denoised_fn=None,
384
+ cond_fn=None,
385
+ model_kwargs=None,
386
+ ):
387
+ """
388
+ Sample x_{t-1} from the model at the given timestep.
389
+ :param model: the model to sample from.
390
+ :param x: the current tensor at x_{t-1}.
391
+ :param t: the value of t, starting at 0 for the first diffusion step.
392
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393
+ :param denoised_fn: if not None, a function which applies to the
394
+ x_start prediction before it is used to sample.
395
+ :param cond_fn: if not None, this is a gradient function that acts
396
+ similarly to the model.
397
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
398
+ pass to the model. This can be used for conditioning.
399
+ :return: a dict containing the following keys:
400
+ - 'sample': a random sample from the model.
401
+ - 'pred_xstart': a prediction of x_0.
402
+ """
403
+ out = self.p_mean_variance(
404
+ model,
405
+ x,
406
+ t,
407
+ clip_denoised=clip_denoised,
408
+ denoised_fn=denoised_fn,
409
+ model_kwargs=model_kwargs,
410
+ )
411
+ noise = th.randn_like(x)
412
+ nonzero_mask = (
413
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
414
+ ) # no noise when t == 0
415
+ if cond_fn is not None:
416
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
417
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
418
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
419
+
420
+ def p_sample_loop(
421
+ self,
422
+ model,
423
+ shape,
424
+ noise=None,
425
+ clip_denoised=True,
426
+ denoised_fn=None,
427
+ cond_fn=None,
428
+ model_kwargs=None,
429
+ device=None,
430
+ progress=False,
431
+ ):
432
+ """
433
+ Generate samples from the model.
434
+ :param model: the model module.
435
+ :param shape: the shape of the samples, (N, C, H, W).
436
+ :param noise: if specified, the noise from the encoder to sample.
437
+ Should be of the same shape as `shape`.
438
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
439
+ :param denoised_fn: if not None, a function which applies to the
440
+ x_start prediction before it is used to sample.
441
+ :param cond_fn: if not None, this is a gradient function that acts
442
+ similarly to the model.
443
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
444
+ pass to the model. This can be used for conditioning.
445
+ :param device: if specified, the device to create the samples on.
446
+ If not specified, use a model parameter's device.
447
+ :param progress: if True, show a tqdm progress bar.
448
+ :return: a non-differentiable batch of samples.
449
+ """
450
+ final = None
451
+ for sample in self.p_sample_loop_progressive(
452
+ model,
453
+ shape,
454
+ noise=noise,
455
+ clip_denoised=clip_denoised,
456
+ denoised_fn=denoised_fn,
457
+ cond_fn=cond_fn,
458
+ model_kwargs=model_kwargs,
459
+ device=device,
460
+ progress=progress,
461
+ ):
462
+ final = sample
463
+ return final["sample"]
464
+
465
+ def inpaint_p_sample_loop(
466
+ self,
467
+ model,
468
+ shape,
469
+ x0,
470
+ mask,
471
+ noise=None,
472
+ clip_denoised=True,
473
+ denoised_fn=None,
474
+ cond_fn=None,
475
+ model_kwargs=None,
476
+ device=None,
477
+ progress=False,
478
+ jump_length=10,
479
+ jump_n_sample=10,
480
+ ):
481
+ """
482
+ Generate samples from the model.
483
+ :param model: the model module.
484
+ :param shape: the shape of the samples, (N, C, H, W).
485
+ :param noise: if specified, the noise from the encoder to sample.
486
+ Should be of the same shape as `shape`.
487
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
488
+ :param denoised_fn: if not None, a function which applies to the
489
+ x_start prediction before it is used to sample.
490
+ :param cond_fn: if not None, this is a gradient function that acts
491
+ similarly to the model.
492
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
493
+ pass to the model. This can be used for conditioning.
494
+ :param device: if specified, the device to create the samples on.
495
+ If not specified, use a model parameter's device.
496
+ :param progress: if True, show a tqdm progress bar.
497
+ :return: a non-differentiable batch of samples.
498
+ """
499
+ final = None
500
+ for sample in self.inpaint_p_sample_loop_progressive(
501
+ model,
502
+ shape,
503
+ x0,
504
+ mask,
505
+ noise=noise,
506
+ clip_denoised=clip_denoised,
507
+ denoised_fn=denoised_fn,
508
+ cond_fn=cond_fn,
509
+ model_kwargs=model_kwargs,
510
+ device=device,
511
+ progress=progress,
512
+ jump_length=jump_length,
513
+ jump_n_sample=jump_n_sample,
514
+ ):
515
+ final = sample
516
+ return final["sample"]
517
+
518
+ def inpaint_p_sample_loop_progressive(
519
+ self,
520
+ model,
521
+ shape,
522
+ x0,
523
+ mask,
524
+ noise=None,
525
+ clip_denoised=True,
526
+ denoised_fn=None,
527
+ cond_fn=None,
528
+ model_kwargs=None,
529
+ device=None,
530
+ progress=False,
531
+ jump_length=10,
532
+ jump_n_sample=10,
533
+ ):
534
+ """
535
+ Generate samples from the model and yield intermediate samples from
536
+ each timestep of diffusion.
537
+
538
+ Arguments are the same as p_sample_loop().
539
+ Returns a generator over dicts, where each dict is the return value of
540
+ p_sample().
541
+ """
542
+ # if device is None:
543
+ # device = next(model.parameters()).device
544
+ # assert isinstance(shape, (tuple, list))
545
+ # if noise is not None:
546
+ # img = noise
547
+ # else:
548
+ # img = th.randn(*shape, device=device)
549
+ # indices = list(range(self.num_timesteps))[::-1]
550
+
551
+ # if progress:
552
+ # # Lazy import so that we don't depend on tqdm.
553
+ # from tqdm.auto import tqdm
554
+
555
+ # indices = tqdm(indices)
556
+ # pred_xstart = None
557
+ # for i in indices:
558
+ # t = th.tensor([i] * shape[0], device=device)
559
+ # with th.no_grad():
560
+ # out = self.inpaint_p_sample(
561
+ # model,
562
+ # img,
563
+ # t,
564
+ # x0,
565
+ # mask,
566
+ # clip_denoised=clip_denoised,
567
+ # denoised_fn=denoised_fn,
568
+ # cond_fn=cond_fn,
569
+ # model_kwargs=model_kwargs,
570
+ # pred_xstart=pred_xstart,
571
+ # )
572
+ # yield out
573
+ # img = out["sample"]
574
+ # pred_xstart = out["pred_xstart"]
575
+
576
+ if device is None:
577
+ device = next(model.parameters()).device
578
+ assert isinstance(shape, (tuple, list))
579
+ if noise is not None:
580
+ image_after_step = noise
581
+ else:
582
+ image_after_step = th.randn(*shape, device=device)
583
+
584
+ self.gt_noises = None # reset for next image
585
+
586
+
587
+ pred_xstart = None
588
+
589
+ idx_wall = -1
590
+ sample_idxs = defaultdict(lambda: 0)
591
+
592
+ times = get_schedule_jump(t_T=250, n_sample=1, jump_length=jump_length, jump_n_sample=jump_n_sample)
593
+ time_pairs = list(zip(times[:-1], times[1:]))
594
+
595
+ if progress:
596
+ from tqdm.auto import tqdm
597
+ time_pairs = tqdm(time_pairs)
598
+
599
+ for t_last, t_cur in time_pairs:
600
+ idx_wall += 1
601
+ t_last_t = th.tensor([t_last] * shape[0], # pylint: disable=not-callable
602
+ device=device)
603
+
604
+ if t_cur < t_last: # reverse
605
+ with th.no_grad():
606
+ image_before_step = image_after_step.clone()
607
+ out = self.inpaint_p_sample(
608
+ model,
609
+ image_after_step,
610
+ t_last_t,
611
+ x0,
612
+ mask,
613
+ clip_denoised=clip_denoised,
614
+ denoised_fn=denoised_fn,
615
+ cond_fn=cond_fn,
616
+ model_kwargs=model_kwargs,
617
+ pred_xstart=pred_xstart
618
+ )
619
+ image_after_step = out["sample"]
620
+ pred_xstart = out["pred_xstart"]
621
+
622
+ sample_idxs[t_cur] += 1
623
+
624
+ yield out
625
+
626
+ else:
627
+ t_shift = 1
628
+ image_before_step = image_after_step.clone()
629
+ image_after_step = self.undo(
630
+ image_before_step, image_after_step,
631
+ est_x_0=out['pred_xstart'], t=t_last_t+t_shift, debug=False)
632
+ pred_xstart = out["pred_xstart"]
633
+
634
+ def inpaint_p_sample(
635
+ self,
636
+ model,
637
+ x,
638
+ t,
639
+ x0,
640
+ mask,
641
+ clip_denoised=True,
642
+ denoised_fn=None,
643
+ cond_fn=None,
644
+ model_kwargs=None,
645
+ pred_xstart=None,
646
+ ):
647
+ """
648
+ Sample x_{t-1} from the model at the given timestep.
649
+ :param model: the model to sample from.
650
+ :param x: the current tensor at x_{t-1}.
651
+ :param t: the value of t, starting at 0 for the first diffusion step.
652
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
653
+ :param denoised_fn: if not None, a function which applies to the
654
+ x_start prediction before it is used to sample.
655
+ :param cond_fn: if not None, this is a gradient function that acts
656
+ similarly to the model.
657
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
658
+ pass to the model. This can be used for conditioning.
659
+ :return: a dict containing the following keys:
660
+ - 'sample': a random sample from the model.
661
+ - 'pred_xstart': a prediction of x_0.
662
+ """
663
+ noise = th.randn_like(x)
664
+
665
+ if pred_xstart is not None:
666
+ alpha_cumprod = _extract_into_tensor(
667
+ self.alphas_cumprod, t, x.shape)
668
+ weighed_gt = th.sqrt(alpha_cumprod) * x0 + th.sqrt((1 - alpha_cumprod)) * th.randn_like(x)
669
+
670
+ x = (1 - mask) * weighed_gt + mask * x
671
+
672
+ out = self.p_mean_variance(
673
+ model,
674
+ x,
675
+ t,
676
+ clip_denoised=clip_denoised,
677
+ denoised_fn=denoised_fn,
678
+ model_kwargs=model_kwargs,
679
+ )
680
+
681
+ nonzero_mask = (
682
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
683
+ )
684
+
685
+ if cond_fn is not None:
686
+ out["mean"] = self.condition_mean(
687
+ cond_fn, out, x, t, model_kwargs=model_kwargs
688
+ )
689
+
690
+ sample = out["mean"] + nonzero_mask * \
691
+ th.exp(0.5 * out["log_variance"]) * noise
692
+
693
+ result = {"sample": sample,
694
+ "pred_xstart": out["pred_xstart"], 'gt': model_kwargs.get('gt')}
695
+
696
+ return result
697
+
698
+ def undo(self, image_before_step, img_after_model, est_x_0, t, debug=False):
699
+ return self._undo(img_after_model, t)
700
+
701
+ def _undo(self, img_out, t):
702
+ beta = _extract_into_tensor(self.betas, t, img_out.shape)
703
+
704
+ img_in_est = th.sqrt(1 - beta) * img_out + \
705
+ th.sqrt(beta) * th.randn_like(img_out)
706
+
707
+ return img_in_est
708
+
709
+ def p_sample_loop_progressive(
710
+ self,
711
+ model,
712
+ shape,
713
+ noise=None,
714
+ clip_denoised=True,
715
+ denoised_fn=None,
716
+ cond_fn=None,
717
+ model_kwargs=None,
718
+ device=None,
719
+ progress=False,
720
+ ):
721
+ """
722
+ Generate samples from the model and yield intermediate samples from
723
+ each timestep of diffusion.
724
+ Arguments are the same as p_sample_loop().
725
+ Returns a generator over dicts, where each dict is the return value of
726
+ p_sample().
727
+ """
728
+ if device is None:
729
+ device = next(model.parameters()).device
730
+ assert isinstance(shape, (tuple, list))
731
+ if noise is not None:
732
+ img = noise
733
+ else:
734
+ img = th.randn(*shape, device=device)
735
+ indices = list(range(self.num_timesteps))[::-1]
736
+
737
+ if progress:
738
+ # Lazy import so that we don't depend on tqdm.
739
+ from tqdm.auto import tqdm
740
+
741
+ indices = tqdm(indices)
742
+
743
+ for i in indices:
744
+ t = th.tensor([i] * shape[0], device=device)
745
+ with th.no_grad():
746
+ out = self.p_sample(
747
+ model,
748
+ img,
749
+ t,
750
+ clip_denoised=clip_denoised,
751
+ denoised_fn=denoised_fn,
752
+ cond_fn=cond_fn,
753
+ model_kwargs=model_kwargs,
754
+ )
755
+ yield out
756
+ img = out["sample"]
757
+
758
+ def ddim_sample(
759
+ self,
760
+ model,
761
+ x,
762
+ t,
763
+ clip_denoised=True,
764
+ denoised_fn=None,
765
+ cond_fn=None,
766
+ model_kwargs=None,
767
+ eta=0.0,
768
+ ):
769
+ """
770
+ Sample x_{t-1} from the model using DDIM.
771
+ Same usage as p_sample().
772
+ """
773
+ out = self.p_mean_variance(
774
+ model,
775
+ x,
776
+ t,
777
+ clip_denoised=clip_denoised,
778
+ denoised_fn=denoised_fn,
779
+ model_kwargs=model_kwargs,
780
+ )
781
+ if cond_fn is not None:
782
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
783
+
784
+ # Usually our model outputs epsilon, but we re-derive it
785
+ # in case we used x_start or x_prev prediction.
786
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
787
+
788
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
789
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
790
+ sigma = (
791
+ eta
792
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
793
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
794
+ )
795
+ # Equation 12.
796
+ noise = th.randn_like(x)
797
+ mean_pred = (
798
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
799
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
800
+ )
801
+ nonzero_mask = (
802
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
803
+ ) # no noise when t == 0
804
+ sample = mean_pred + nonzero_mask * sigma * noise
805
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
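For reference, the step this method computes ("Equation 12" in the comment refers to the DDIM paper) can be written compactly; eta = 0 gives the deterministic DDIM update, while larger eta re-injects noise:

```latex
\sigma_t = \eta \sqrt{\tfrac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}
           \sqrt{1-\tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}}, \qquad
x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat x_0
          + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\hat\epsilon_\theta(x_t, t)
          + \sigma_t z, \quad z \sim \mathcal{N}(0, I)
```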
806
+
807
+ def ddim_reverse_sample(
808
+ self,
809
+ model,
810
+ x,
811
+ t,
812
+ clip_denoised=True,
813
+ denoised_fn=None,
814
+ cond_fn=None,
815
+ model_kwargs=None,
816
+ eta=0.0,
817
+ ):
818
+ """
819
+ Sample x_{t+1} from the model using DDIM reverse ODE.
820
+ """
821
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
822
+ out = self.p_mean_variance(
823
+ model,
824
+ x,
825
+ t,
826
+ clip_denoised=clip_denoised,
827
+ denoised_fn=denoised_fn,
828
+ model_kwargs=model_kwargs,
829
+ )
830
+ if cond_fn is not None:
831
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
832
+ # Usually our model outputs epsilon, but we re-derive it
833
+ # in case we used x_start or x_prev prediction.
834
+ eps = (
835
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
836
+ - out["pred_xstart"]
837
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
838
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
839
+
840
+ # Equation 12. reversed
841
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
842
+
843
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
844
+
845
+ def ddim_sample_loop(
846
+ self,
847
+ model,
848
+ shape,
849
+ noise=None,
850
+ clip_denoised=True,
851
+ denoised_fn=None,
852
+ cond_fn=None,
853
+ model_kwargs=None,
854
+ device=None,
855
+ progress=False,
856
+ eta=0.0,
857
+ ):
858
+ """
859
+ Generate samples from the model using DDIM.
860
+ Same usage as p_sample_loop().
861
+ """
862
+ final = None
863
+ for sample in self.ddim_sample_loop_progressive(
864
+ model,
865
+ shape,
866
+ noise=noise,
867
+ clip_denoised=clip_denoised,
868
+ denoised_fn=denoised_fn,
869
+ cond_fn=cond_fn,
870
+ model_kwargs=model_kwargs,
871
+ device=device,
872
+ progress=progress,
873
+ eta=eta,
874
+ ):
875
+ final = sample
876
+ return final["sample"]
877
+
878
+ def ddim_sample_loop_progressive(
879
+ self,
880
+ model,
881
+ shape,
882
+ noise=None,
883
+ clip_denoised=True,
884
+ denoised_fn=None,
885
+ cond_fn=None,
886
+ model_kwargs=None,
887
+ device=None,
888
+ progress=False,
889
+ eta=0.0,
890
+ ):
891
+ """
892
+ Use DDIM to sample from the model and yield intermediate samples from
893
+ each timestep of DDIM.
894
+ Same usage as p_sample_loop_progressive().
895
+ """
896
+ if device is None:
897
+ device = next(model.parameters()).device
898
+ assert isinstance(shape, (tuple, list))
899
+ if noise is not None:
900
+ img = noise
901
+ else:
902
+ img = th.randn(*shape, device=device)
903
+ indices = list(range(self.num_timesteps))[::-1]
904
+
905
+ if progress:
906
+ # Lazy import so that we don't depend on tqdm.
907
+ from tqdm.auto import tqdm
908
+
909
+ indices = tqdm(indices)
910
+
911
+ for i in indices:
912
+ t = th.tensor([i] * shape[0], device=device)
913
+ with th.no_grad():
914
+ out = self.ddim_sample(
915
+ model,
916
+ img,
917
+ t,
918
+ clip_denoised=clip_denoised,
919
+ denoised_fn=denoised_fn,
920
+ cond_fn=cond_fn,
921
+ model_kwargs=model_kwargs,
922
+ eta=eta,
923
+ )
924
+ yield out
925
+ img = out["sample"]
926
+
927
+ def _vb_terms_bpd(
928
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
929
+ ):
930
+ """
931
+ Get a term for the variational lower-bound.
932
+ The resulting units are bits (rather than nats, as one might expect).
933
+ This allows for comparison to other papers.
934
+ :return: a dict with the following keys:
935
+ - 'output': a shape [N] tensor of NLLs or KLs.
936
+ - 'pred_xstart': the x_0 predictions.
937
+ """
938
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
939
+ x_start=x_start, x_t=x_t, t=t
940
+ )
941
+ out = self.p_mean_variance(
942
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
943
+ )
944
+ kl = normal_kl(
945
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
946
+ )
947
+ kl = mean_flat(kl) / np.log(2.0)
948
+
949
+ decoder_nll = -discretized_gaussian_log_likelihood(
950
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
951
+ )
952
+ assert decoder_nll.shape == x_start.shape
953
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
954
+
955
+ # At the first timestep return the decoder NLL,
956
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
957
+ output = th.where((t == 0), decoder_nll, kl)
958
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
959
+
960
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
961
+ """
962
+ Compute training losses for a single timestep.
963
+ :param model: the model to evaluate loss on.
964
+ :param x_start: the [N x C x ...] tensor of inputs.
965
+ :param t: a batch of timestep indices.
966
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
967
+ pass to the model. This can be used for conditioning.
968
+ :param noise: if specified, the specific Gaussian noise to try to remove.
969
+ :return: a dict with the key "loss" containing a tensor of shape [N].
970
+ Some mean or variance settings may also have other keys.
971
+ """
972
+ if model_kwargs is None:
973
+ model_kwargs = {}
974
+ if noise is None:
975
+ noise = th.randn_like(x_start)
976
+ x_t = self.q_sample(x_start, t, noise=noise)
977
+
978
+ terms = {}
979
+
980
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
981
+ terms["loss"] = self._vb_terms_bpd(
982
+ model=model,
983
+ x_start=x_start,
984
+ x_t=x_t,
985
+ t=t,
986
+ clip_denoised=False,
987
+ model_kwargs=model_kwargs,
988
+ )["output"]
989
+ if self.loss_type == LossType.RESCALED_KL:
990
+ terms["loss"] *= self.num_timesteps
991
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
992
+ model_output = model(x_t, t, **model_kwargs)
993
+
994
+ if self.model_var_type in [
995
+ ModelVarType.LEARNED,
996
+ ModelVarType.LEARNED_RANGE,
997
+ ]:
998
+ B, C = x_t.shape[:2]
999
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
1000
+ model_output, model_var_values = th.split(model_output, C, dim=1)
1001
+ # Learn the variance using the variational bound, but don't let
1002
+ # it affect our mean prediction.
1003
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1004
+ terms["vb"] = self._vb_terms_bpd(
1005
+ model=lambda *args, r=frozen_out: r,
1006
+ x_start=x_start,
1007
+ x_t=x_t,
1008
+ t=t,
1009
+ clip_denoised=False,
1010
+ )["output"]
1011
+ if self.loss_type == LossType.RESCALED_MSE:
1012
+ # Divide by 1000 for equivalence with initial implementation.
1013
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1014
+ terms["vb"] *= self.num_timesteps / 1000.0
1015
+
1016
+ target = {
1017
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1018
+ x_start=x_start, x_t=x_t, t=t
1019
+ )[0],
1020
+ ModelMeanType.START_X: x_start,
1021
+ ModelMeanType.EPSILON: noise,
1022
+ }[self.model_mean_type]
1023
+ assert model_output.shape == target.shape == x_start.shape
1024
+ terms["mse"] = mean_flat((target - model_output) ** 2)
1025
+ if "vb" in terms:
1026
+ terms["loss"] = terms["mse"] + terms["vb"]
1027
+ else:
1028
+ terms["loss"] = terms["mse"]
1029
+ else:
1030
+ raise NotImplementedError(self.loss_type)
1031
+
1032
+ return terms
1033
+
1034
+ def _prior_bpd(self, x_start):
1035
+ """
1036
+ Get the prior KL term for the variational lower-bound, measured in
1037
+ bits-per-dim.
1038
+ This term can't be optimized, as it only depends on the encoder.
1039
+ :param x_start: the [N x C x ...] tensor of inputs.
1040
+ :return: a batch of [N] KL values (in bits), one per batch element.
1041
+ """
1042
+ batch_size = x_start.shape[0]
1043
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1044
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1045
+ kl_prior = normal_kl(
1046
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1047
+ )
1048
+ return mean_flat(kl_prior) / np.log(2.0)
1049
+
1050
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1051
+ """
1052
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1053
+ as well as other related quantities.
1054
+ :param model: the model to evaluate loss on.
1055
+ :param x_start: the [N x C x ...] tensor of inputs.
1056
+ :param clip_denoised: if True, clip denoised samples.
1057
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1058
+ pass to the model. This can be used for conditioning.
1059
+ :return: a dict containing the following keys:
1060
+ - total_bpd: the total variational lower-bound, per batch element.
1061
+ - prior_bpd: the prior term in the lower-bound.
1062
+ - vb: an [N x T] tensor of terms in the lower-bound.
1063
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1064
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1065
+ """
1066
+ device = x_start.device
1067
+ batch_size = x_start.shape[0]
1068
+
1069
+ vb = []
1070
+ xstart_mse = []
1071
+ mse = []
1072
+ for t in list(range(self.num_timesteps))[::-1]:
1073
+ t_batch = th.tensor([t] * batch_size, device=device)
1074
+ noise = th.randn_like(x_start)
1075
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1076
+ # Calculate VLB term at the current timestep
1077
+ with th.no_grad():
1078
+ out = self._vb_terms_bpd(
1079
+ model,
1080
+ x_start=x_start,
1081
+ x_t=x_t,
1082
+ t=t_batch,
1083
+ clip_denoised=clip_denoised,
1084
+ model_kwargs=model_kwargs,
1085
+ )
1086
+ vb.append(out["output"])
1087
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1088
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1089
+ mse.append(mean_flat((eps - noise) ** 2))
1090
+
1091
+ vb = th.stack(vb, dim=1)
1092
+ xstart_mse = th.stack(xstart_mse, dim=1)
1093
+ mse = th.stack(mse, dim=1)
1094
+
1095
+ prior_bpd = self._prior_bpd(x_start)
1096
+ total_bpd = vb.sum(dim=1) + prior_bpd
1097
+ return {
1098
+ "total_bpd": total_bpd,
1099
+ "prior_bpd": prior_bpd,
1100
+ "vb": vb,
1101
+ "xstart_mse": xstart_mse,
1102
+ "mse": mse,
1103
+ }
1104
+
1105
+
1106
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1107
+ """
1108
+ Extract values from a 1-D numpy array for a batch of indices.
1109
+ :param arr: the 1-D numpy array.
1110
+ :param timesteps: a tensor of indices into the array to extract.
1111
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1112
+ dimension equal to the length of timesteps.
1113
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1114
+ """
1115
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1116
+ while len(res.shape) < len(broadcast_shape):
1117
+ res = res[..., None]
1118
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
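The _extract_into_tensor helper above is the glue used throughout this file: it gathers one scalar per batch element from a 1-D schedule table and broadcasts it against an image-shaped tensor. A hedged sketch (the import path is an assumption, and the function is module-private upstream):

```python
import numpy as np
import torch as th

# Assumes this file is importable as diffusion.gaussian_diffusion.
from diffusion.gaussian_diffusion import _extract_into_tensor

table = np.linspace(0.0, 1.0, 1000)      # e.g. an alphas_cumprod-style lookup table
t = th.tensor([0, 499, 999])             # one timestep per batch element
out = _extract_into_tensor(table, t, (3, 3, 64, 64))

print(out.shape)          # torch.Size([3, 3, 64, 64])
print(out[:, 0, 0, 0])    # tensor([0.0000, 0.4995, 1.0000]) -- broadcast per sample
```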
diffusion/respace.py ADDED
@@ -0,0 +1,129 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
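Two concrete uses of space_timesteps (a sketch; the diffusion.respace import path is an assumption about this commit's layout), matching the string forms described in the docstring above:

```python
# Assumes this file is importable as diffusion.respace.
from diffusion.respace import space_timesteps

# Retain 250 evenly spaced steps out of an original 1000-step process.
assert len(space_timesteps(1000, "250")) == 250

# "ddimN" uses a fixed integer stride; 1000 steps with stride 40 gives 25 steps.
assert space_timesteps(1000, "ddim25") == set(range(0, 1000, 40))

# Three sections of the original process, strided to 10, 15 and 20 steps each.
assert len(space_timesteps(1000, "10,15,20")) == 45
```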
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)
diffusion/scheduler.py ADDED
@@ -0,0 +1,224 @@
1
+ # Copyright (c) 2022 Huawei Technologies Co., Ltd.
2
+ # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7
+ #
8
+ # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license
16
+
17
+ def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0):
18
+ if n_steplength > 1:
19
+ if not n_sample > 1:
20
+ raise RuntimeError('n_steplength has no effect if n_sample=1')
21
+
22
+ t = t_T
23
+ times = [t]
24
+ while t >= 0:
25
+ t = t - 1
26
+ times.append(t)
27
+ n_steplength_cur = min(n_steplength, t_T - t)
28
+
29
+ for _ in range(n_sample - 1):
30
+
31
+ for _ in range(n_steplength_cur):
32
+ t = t + 1
33
+ times.append(t)
34
+ for _ in range(n_steplength_cur):
35
+ t = t - 1
36
+ times.append(t)
37
+
38
+ _check_times(times, t_0, t_T)
39
+
40
+ if debug == 2:
41
+ for x in [list(range(0, 50)), list(range(-1, -50, -1))]:
42
+ _plot_times(x=x, times=[times[i] for i in x])
43
+
44
+ return times
45
+
46
+
47
+ def _check_times(times, t_0, t_T):
48
+ # Check end
49
+ assert times[0] > times[1], (times[0], times[1])
50
+
51
+ # Check beginning
52
+ assert times[-1] == -1, times[-1]
53
+
54
+ # Steplength = 1
55
+ for t_last, t_cur in zip(times[:-1], times[1:]):
56
+ assert abs(t_last - t_cur) == 1, (t_last, t_cur)
57
+
58
+ # Value range
59
+ for t in times:
60
+ assert t >= t_0, (t, t_0)
61
+ assert t <= t_T, (t, t_T)
62
+
63
+
64
+ def _plot_times(x, times):
65
+ import matplotlib.pyplot as plt
66
+ plt.plot(x, times)
67
+ plt.show()
68
+
69
+
70
+ def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample,
71
+ jump2_length=1, jump2_n_sample=1,
72
+ jump3_length=1, jump3_n_sample=1,
73
+ start_resampling=100000000):
74
+
75
+ jumps = {}
76
+ for j in range(0, t_T - jump_length, jump_length):
77
+ jumps[j] = jump_n_sample - 1
78
+
79
+ jumps2 = {}
80
+ for j in range(0, t_T - jump2_length, jump2_length):
81
+ jumps2[j] = jump2_n_sample - 1
82
+
83
+ jumps3 = {}
84
+ for j in range(0, t_T - jump3_length, jump3_length):
85
+ jumps3[j] = jump3_n_sample - 1
86
+
87
+ t = t_T
88
+ ts = []
89
+
90
+ while t >= 1:
91
+ t = t-1
92
+ ts.append(t)
93
+
94
+ if (
95
+ t + 1 < t_T - 1 and
96
+ t <= start_resampling
97
+ ):
98
+ for _ in range(n_sample - 1):
99
+ t = t + 1
100
+ ts.append(t)
101
+
102
+ if t >= 0:
103
+ t = t - 1
104
+ ts.append(t)
105
+
106
+ if (
107
+ jumps3.get(t, 0) > 0 and
108
+ t <= start_resampling - jump3_length
109
+ ):
110
+ jumps3[t] = jumps3[t] - 1
111
+ for _ in range(jump3_length):
112
+ t = t + 1
113
+ ts.append(t)
114
+
115
+ if (
116
+ jumps2.get(t, 0) > 0 and
117
+ t <= start_resampling - jump2_length
118
+ ):
119
+ jumps2[t] = jumps2[t] - 1
120
+ for _ in range(jump2_length):
121
+ t = t + 1
122
+ ts.append(t)
123
+ jumps3 = {}
124
+ for j in range(0, t_T - jump3_length, jump3_length):
125
+ jumps3[j] = jump3_n_sample - 1
126
+
127
+ if (
128
+ jumps.get(t, 0) > 0 and
129
+ t <= start_resampling - jump_length
130
+ ):
131
+ jumps[t] = jumps[t] - 1
132
+ for _ in range(jump_length):
133
+ t = t + 1
134
+ ts.append(t)
135
+ jumps2 = {}
136
+ for j in range(0, t_T - jump2_length, jump2_length):
137
+ jumps2[j] = jump2_n_sample - 1
138
+
139
+ jumps3 = {}
140
+ for j in range(0, t_T - jump3_length, jump3_length):
141
+ jumps3[j] = jump3_n_sample - 1
142
+
143
+ ts.append(-1)
144
+
145
+ _check_times(ts, -1, t_T)
146
+
147
+ return ts
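This is the RePaint-style resampling schedule that inpaint_p_sample_loop_progressive in diffusion/gaussian_diffusion.py consumes; a short sketch (the diffusion.scheduler import path is an assumption) shows how it becomes the (t_last, t_cur) pairs that the loop iterates over:

```python
# Assumes this file is importable as diffusion.scheduler.
from diffusion.scheduler import get_schedule_jump

# Same settings the inpainting sampler uses: 250 steps, jump back 10 steps,
# resampling each jumped region 10 times.
times = get_schedule_jump(t_T=250, n_sample=1, jump_length=10, jump_n_sample=10)
time_pairs = list(zip(times[:-1], times[1:]))

down = sum(1 for a, b in time_pairs if b < a)  # ordinary denoising transitions
up = sum(1 for a, b in time_pairs if b > a)    # re-noising ("undo") transitions
print(len(time_pairs), down, up)
assert times[0] == 249 and times[-1] == -1
```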
148
+
149
+
150
+ def get_schedule_jump_paper():
151
+ t_T = 250
152
+ jump_length = 10
153
+ jump_n_sample = 10
154
+
155
+ jumps = {}
156
+ for j in range(0, t_T - jump_length, jump_length):
157
+ jumps[j] = jump_n_sample - 1
158
+
159
+ t = t_T
160
+ ts = []
161
+
162
+ while t >= 1:
163
+ t = t-1
164
+ ts.append(t)
165
+
166
+ if jumps.get(t, 0) > 0:
167
+ jumps[t] = jumps[t] - 1
168
+ for _ in range(jump_length):
169
+ t = t + 1
170
+ ts.append(t)
171
+
172
+ ts.append(-1)
173
+
174
+ _check_times(ts, -1, t_T)
175
+
176
+ return ts
177
+
178
+
179
+ def get_schedule_jump_test(to_supplement=False):
180
+ ts = get_schedule_jump(t_T=250, n_sample=1,
181
+ jump_length=10, jump_n_sample=10,
182
+ jump2_length=1, jump2_n_sample=1,
183
+ jump3_length=1, jump3_n_sample=1,
184
+ start_resampling=250)
185
+
186
+ import matplotlib.pyplot as plt
187
+ SMALL_SIZE = 8*3
188
+ MEDIUM_SIZE = 10*3
189
+ BIGGER_SIZE = 12*3
190
+
191
+ plt.rc('font', size=SMALL_SIZE) # controls default text sizes
192
+ plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title
193
+ plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
194
+ plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels
195
+ plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels
196
+ plt.rc('legend', fontsize=SMALL_SIZE) # legend fontsize
197
+ plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
198
+
199
+ plt.plot(ts)
200
+
201
+ fig = plt.gcf()
202
+ fig.set_size_inches(20, 10)
203
+
204
+ ax = plt.gca()
205
+ ax.set_xlabel('Number of Transitions')
206
+ ax.set_ylabel('Diffusion time $t$')
207
+
208
+ fig.tight_layout()
209
+
210
+ if to_supplement:
211
+ out_path = "/cluster/home/alugmayr/gdiff/paper/supplement/figures/jump_sched.pdf"
212
+ plt.savefig(out_path)
213
+
214
+ out_path = "./schedule.png"
215
+ plt.savefig(out_path)
216
+ print(out_path)
217
+
218
+
219
+ def main():
220
+ get_schedule_jump_test()
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
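A quick usage sketch for the samplers defined in this file. The module path assumes the file lives at diffusion/timestep_sampler.py as added in this commit, and the stub diffusion object is an assumption: UniformSampler only reads num_timesteps from it.

# Hypothetical sketch: importance-sample training timesteps.
import types
import torch as th
from diffusion.timestep_sampler import create_named_schedule_sampler

diffusion = types.SimpleNamespace(num_timesteps=1000)   # stand-in; only num_timesteps is used here
sampler = create_named_schedule_sampler("uniform", diffusion)
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
# t: (8,) long tensor of timestep indices; weights: (8,) floats, all 1.0 for the uniform sampler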
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,223 @@
1
+ absl-py==2.1.0
2
+ accelerate==0.34.2
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.4.3
5
+ aiohttp==3.10.10
6
+ aiosignal==1.3.1
7
+ albumentations==0.5.2
8
+ annotated-types==0.7.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==4.4.0
11
+ astunparse==1.6.3
12
+ async-timeout==4.0.3
13
+ attrs==23.2.0
14
+ beautifulsoup4==4.12.3
15
+ bitsandbytes==0.44.1
16
+ boto==2.49.0
17
+ boto3==1.28.57
18
+ botocore==1.34.131
19
+ cachetools==5.5.0
20
+ certifi==2022.12.7
21
+ cffi==1.16.0
22
+ chardet==5.2.0
23
+ charset-normalizer==2.1.1
24
+ click==8.1.7
25
+ click-default-group==1.2.4
26
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
27
+ cmake==3.30.3
28
+ colorlog==6.8.2
29
+ commonmark==0.9.1
30
+ contourpy==1.1.1
31
+ cycler==0.12.1
32
+ decord==0.6.0
33
+ deepspeed==0.15.1
34
+ diffusers==0.25.0
35
+ docker-pycreds==0.4.0
36
+ ego4d==1.3.2
37
+ einops==0.8.0
38
+ embreex==2.17.7.post5
39
+ envlight @ git+https://github.com/ashawkey/envlight.git@05b5851e854429d72ecaf5b206ed64ce55fae677
40
+ exceptiongroup==1.2.2
41
+ fastapi==0.112.0
42
+ ffmpy==0.4.0
43
+ filelock==3.13.1
44
+ flatbuffers==24.3.25
45
+ fonttools==4.53.1
46
+ frozenlist==1.4.1
47
+ fsspec==2024.2.0
48
+ ftfy==6.2.3
49
+ gast==0.4.0
50
+ gdown==5.2.0
51
+ gevent==23.9.1
52
+ gevent-websocket==0.10.1
53
+ gitdb==4.0.11
54
+ GitPython==3.1.43
55
+ google-auth==2.35.0
56
+ google-auth-oauthlib==1.0.0
57
+ google-pasta==0.2.0
58
+ gradio==4.40.0
59
+ gradio_client==1.2.0
60
+ greenlet==2.0.2
61
+ grpcio==1.66.1
62
+ h11==0.14.0
63
+ h5py==3.11.0
64
+ hjson==3.1.0
65
+ httpcore==1.0.5
66
+ httpx==0.27.0
67
+ huggingface-hub==0.24.5
68
+ idna==3.4
69
+ imageio==2.34.2
70
+ imageio-ffmpeg==0.5.1
71
+ imgaug==0.4.0
72
+ importlib_metadata==8.2.0
73
+ importlib_resources==6.4.0
74
+ jax==0.4.13
75
+ jaxlib==0.4.13
76
+ jaxtyping==0.2.19
77
+ Jinja2==3.1.3
78
+ jmespath==1.0.1
79
+ jsonschema==4.23.0
80
+ jsonschema-specifications==2023.12.1
81
+ keras==2.13.1
82
+ kiwisolver==1.4.5
83
+ kornia==0.7.3
84
+ kornia_rs==0.1.5
85
+ lazy_loader==0.4
86
+ libclang==18.1.1
87
+ libigl==2.5.1
88
+ lightning-utilities==0.11.8
89
+ lit==18.1.8
90
+ lxml==5.3.0
91
+ manifold3d==2.5.1
92
+ Markdown==3.7
93
+ markdown-it-py==3.0.0
94
+ MarkupSafe==2.1.5
95
+ matplotlib==3.7.5
96
+ mdurl==0.1.2
97
+ mediapipe==0.10.11
98
+ ml-dtypes==0.2.0
99
+ mpmath==1.3.0
100
+ multidict==6.1.0
101
+ mypy-extensions==1.0.0
102
+ nerfacc @ git+https://github.com/KAIR-BAIR/nerfacc.git@d84cdf3afd7dcfc42150e0f0506db58a5ce62812
103
+ networkx==3.0
104
+ ninja==1.11.1.1
105
+ numpy==1.24.1
106
+ nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@729261dc64c4241ea36efda84fbf532cc8b425b8
107
+ nvidia-cublas-cu11==11.10.3.66
108
+ nvidia-cuda-cupti-cu11==11.7.101
109
+ nvidia-cuda-nvrtc-cu11==11.7.99
110
+ nvidia-cuda-runtime-cu11==11.7.99
111
+ nvidia-cudnn-cu11==8.5.0.96
112
+ nvidia-cufft-cu11==10.9.0.58
113
+ nvidia-curand-cu11==10.2.10.91
114
+ nvidia-cusolver-cu11==11.4.0.1
115
+ nvidia-cusparse-cu11==11.7.4.91
116
+ nvidia-nccl-cu11==2.14.3
117
+ nvidia-nvtx-cu11==11.7.91
118
+ oauthlib==3.2.2
119
+ omegaconf==2.3.0
120
+ open-clip-torch==2.7.0
121
+ opencv-contrib-python==4.10.0.84
122
+ opencv-python==4.10.0.84
123
+ opencv-python-headless==4.10.0.84
124
+ opt-einsum==3.3.0
125
+ orjson==3.10.6
126
+ packaging==24.1
127
+ pandas==2.0.3
128
+ pillow==10.2.0
129
+ pkgutil_resolve_name==1.3.10
130
+ platformdirs==4.3.6
131
+ prometheus-client==0.13.1
132
+ propcache==0.2.0
133
+ protobuf==3.20.3
134
+ psutil==6.0.0
135
+ py-cpuinfo==9.0.0
136
+ pyasn1==0.6.1
137
+ pyasn1_modules==0.4.1
138
+ pycollada==0.8
139
+ pycparser==2.22
140
+ pydantic==2.8.2
141
+ pydantic_core==2.20.1
142
+ pydub==0.25.1
143
+ Pygments==2.18.0
144
+ pyparsing==3.1.2
145
+ pyre-extensions==0.0.29
146
+ pysdf==0.1.9
147
+ PySocks==1.7.1
148
+ python-dateutil==2.9.0.post0
149
+ python-multipart==0.0.9
150
+ pytorch-lightning==2.1.0
151
+ pytz==2024.1
152
+ PyWavelets==1.4.1
153
+ PyYAML==6.0.1
154
+ referencing==0.35.1
155
+ regex==2024.7.24
156
+ requests==2.32.3
157
+ requests-oauthlib==2.0.0
158
+ rich==13.7.1
159
+ rich-click==1.6.1
160
+ rpds-py==0.20.0
161
+ rsa==4.9
162
+ Rtree==1.3.0
163
+ ruff==0.5.6
164
+ s3transfer==0.7.0
165
+ safetensors==0.4.3
166
+ scikit-image==0.21.0
167
+ scipy==1.10.1
168
+ segment-anything==1.0
169
+ semantic-version==2.10.0
170
+ sentencepiece==0.1.99
171
+ sentry-sdk==2.17.0
172
+ setproctitle==1.3.3
173
+ sh==1.14.3
174
+ shapely==2.0.6
175
+ shellingham==1.5.4
176
+ six==1.16.0
177
+ smmap==5.0.1
178
+ sniffio==1.3.1
179
+ sounddevice==0.4.7
180
+ soupsieve==2.6
181
+ starlette==0.37.2
182
+ svg.path==6.3
183
+ sympy==1.12
184
+ taming-transformers-rom1504==0.0.6
185
+ tensorboard==2.13.0
186
+ tensorboard-data-server==0.7.2
187
+ tensorflow==2.13.1
188
+ tensorflow-estimator==2.13.0
189
+ tensorflow-io-gcs-filesystem==0.34.0
190
+ termcolor==2.4.0
191
+ tifffile==2023.7.10
192
+ timm==0.9.12
193
+ tinycudann @ git+https://github.com/NVlabs/tiny-cuda-nn/@c91138bcd4c6877c8d5e60e483c0581aafc70cce#subdirectory=bindings/torch
194
+ tokenizers==0.20.0
195
+ tomlkit==0.12.0
196
+ torch==2.0.1+cu118
197
+ torchaudio==2.0.2+cu118
198
+ torchmetrics==1.5.0
199
+ torchvision==0.15.2+cu118
200
+ tqdm==4.66.4
201
+ transformers==4.45.1
202
+ trimesh==4.5.0
203
+ triton==2.0.0
204
+ typeguard==4.3.0
205
+ typer==0.12.3
206
+ typing-inspect==0.9.0
207
+ typing_extensions==4.12.2
208
+ tzdata==2024.1
209
+ urllib3==2.2.3
210
+ uvicorn==0.30.5
211
+ wandb==0.18.5
212
+ wcwidth==0.2.13
213
+ websockets==10.4
214
+ Werkzeug==3.0.4
215
+ wrapt==1.16.0
216
+ xatlas==0.0.9
217
+ xformers==0.0.20
218
+ xmltodict==0.12.0
219
+ xxhash==3.5.0
220
+ yarl==1.15.2
221
+ zipp==3.19.2
222
+ zope.event==5.0
223
+ zope.interface==6.0
segment_hoi.py ADDED
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
4
+
5
+
6
+ def show_mask(mask, ax, random_color=False):
7
+ if random_color:
8
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
9
+ else:
10
+ color = np.array([30/255, 144/255, 255/255, 0.6])
11
+ h, w = mask.shape[-2:]
12
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
13
+ ax.imshow(mask_image)
14
+
15
+
16
+ def show_points(coords, labels, ax, marker_size=375):
17
+ pos_points = coords[labels==1]
18
+ neg_points = coords[labels==0]
19
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
20
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
21
+
22
+
23
+ def show_box(box, ax):
24
+ x0, y0 = box[0], box[1]
25
+ w, h = box[2] - box[0], box[3] - box[1]
26
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
27
+
28
+
29
+ def merge_bounding_boxes(bbox1, bbox2):
30
+ xmin1, ymin1, xmax1, ymax1 = bbox1
31
+ xmin2, ymin2, xmax2, ymax2 = bbox2
32
+
33
+ xmin_merged = min(xmin1, xmin2)
34
+ ymin_merged = min(ymin1, ymin2)
35
+ xmax_merged = max(xmax1, xmax2)
36
+ ymax_merged = max(ymax1, ymax2)
37
+
38
+ return np.array([xmin_merged, ymin_merged, xmax_merged, ymax_merged])
39
+
40
+
41
+ def init_sam(
42
+ device="cuda",
43
+ ckpt_path='/users/kchen157/scratch/weights/SAM/sam_vit_h_4b8939.pth'
44
+ ):
45
+ sam = sam_model_registry['vit_h'](checkpoint=ckpt_path)
46
+ sam.to(device=device)
47
+ predictor = SamPredictor(sam)
48
+ return predictor
49
+
50
+
51
+ def segment_hand_and_object(
52
+ predictor,
53
+ image,
54
+ hand_kpts,
55
+ hand_mask=None,
56
+ box_shift_ratio = 0.3,
57
+ box_size_factor = 2.,
58
+ area_threshold = 0.2,
59
+ overlap_threshold = 200):
60
+ # Find bounding box for HOI
61
+ input_box = {}
62
+ for hand_type in ['right', 'left']:
63
+ if hand_type not in hand_kpts:
64
+ continue
65
+ input_box[hand_type] = np.stack([hand_kpts[hand_type].min(axis=0), hand_kpts[hand_type].max(axis=0)])
66
+ box_trans = input_box[hand_type][0] * box_shift_ratio + input_box[hand_type][1] * (1 - box_shift_ratio)
67
+ input_box[hand_type] = ((input_box[hand_type] - box_trans) * box_size_factor + box_trans).reshape(-1)
68
+
69
+ if len(input_box) == 2:
70
+ input_box = merge_bounding_boxes(input_box['right'], input_box['left'])
71
+ input_point = np.array([hand_kpts['right'][0], hand_kpts['left'][0]])
72
+ input_label = np.array([1, 1])
73
+ elif 'right' in input_box:
74
+ input_box = input_box['right']
75
+ input_point = np.array([hand_kpts['right'][0]])
76
+ input_label = np.array([1])
77
+ elif 'left' in input_box:
78
+ input_box = input_box['left']
79
+ input_point = np.array([hand_kpts['left'][0]])
80
+ input_label = np.array([1])
81
+
82
+ box_area = (input_box[2] - input_box[0]) * (input_box[3] - input_box[1])
83
+
84
+ # segment hand using the wrist point
85
+ predictor.set_image(image)
86
+ if hand_mask is None:
87
+ masks, scores, logits = predictor.predict(
88
+ point_coords=input_point,
89
+ point_labels=input_label,
90
+ multimask_output=False,
91
+ )
92
+ hand_mask = masks[0]
93
+
94
+ # segment object in hand
95
+ input_label = np.zeros_like(input_label)
96
+ masks, scores, _ = predictor.predict(
97
+ point_coords=input_point,
98
+ point_labels=input_label,
99
+ box=input_box[None, :],
100
+ multimask_output=False,
101
+ )
102
+ object_mask = masks[0]
103
+
104
+ if (masks[0].astype(int) * hand_mask).sum() > overlap_threshold:
105
+ # print('False positive: The mask overlaps the hand.')
106
+ object_mask = np.zeros_like(object_mask)
107
+ elif object_mask.astype(int).sum() / box_area > area_threshold:
108
+ # print('False positive: The area is very big, probably the background')
109
+ object_mask = np.zeros_like(object_mask)
110
+
111
+ return object_mask, hand_mask
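A hedged usage sketch for the segmentation helpers above. The checkpoint path, image, and keypoints below are placeholders; only the call signature follows the code in this file.

# Hypothetical sketch: segment the hand and the in-hand object with SAM.
import numpy as np
from segment_hoi import init_sam, segment_hand_and_object

predictor = init_sam(device="cuda", ckpt_path="sam_vit_h_4b8939.pth")  # assumed local SAM checkpoint
image = np.zeros((512, 512, 3), dtype=np.uint8)                        # placeholder RGB frame
hand_kpts = {"right": np.random.rand(21, 2) * 512}                     # 21 (x, y) keypoints in pixels
obj_mask, hand_mask = segment_hand_and_object(predictor, image, hand_kpts)
# Both outputs are HxW masks; obj_mask is zeroed when the heuristics flag a likely false positive.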
utils.py ADDED
@@ -0,0 +1,289 @@
1
+ import io
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ from skimage.transform import resize
6
+ import matplotlib.pyplot as plt
7
+ from mpl_toolkits.mplot3d import Axes3D
8
+
9
+
10
+ def draw_hand3d(keypoints):
11
+ # Define the connections between keypoints as tuples (start, end)
12
+ bones = [
13
+ ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
14
+ ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
15
+ ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'blueviolet'),
16
+ ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'crimson'), ((15, 16), 'cornsilk'),
17
+ ((0, 17), 'aqua'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
18
+ ]
19
+
20
+ fig = plt.figure()
21
+ ax = fig.add_subplot(111, projection='3d')
22
+
23
+ # Plot the bones
24
+ for bone, color in bones:
25
+ start_point = keypoints[bone[0], :]
26
+ end_point = keypoints[bone[1], :]
27
+
28
+ ax.plot([start_point[0], end_point[0]],
29
+ [start_point[1], end_point[1]],
30
+ [start_point[2], end_point[2]], color=color)
31
+
32
+ ax.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2], color='gray', s=15)
33
+
34
+ # Set the aspect ratio to be equal
35
+ max_range = np.array([keypoints[:,0].max()-keypoints[:,0].min(),
36
+ keypoints[:,1].max()-keypoints[:,1].min(),
37
+ keypoints[:,2].max()-keypoints[:,2].min()]).max() / 2.0
38
+
39
+ mid_x = (keypoints[:,0].max()+keypoints[:,0].min()) * 0.5
40
+ mid_y = (keypoints[:,1].max()+keypoints[:,1].min()) * 0.5
41
+ mid_z = (keypoints[:,2].max()+keypoints[:,2].min()) * 0.5
42
+
43
+ ax.set_xlim(mid_x - max_range, mid_x + max_range)
44
+ ax.set_ylim(mid_y - max_range, mid_y + max_range)
45
+ ax.set_zlim(mid_z - max_range, mid_z + max_range)
46
+
47
+ # Set labels for axes
48
+ ax.set_xlabel('X')
49
+ ax.set_ylabel('Y')
50
+ ax.set_zlabel('Z')
51
+
52
+ plt.show()
53
+
54
+
55
+ def visualize_hand(joints, img):
56
+ # Define the connections between joints for drawing lines and their corresponding colors
57
+ connections = [
58
+ ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
59
+ ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
60
+ ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
61
+ ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
62
+ ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
63
+ ]
64
+ H, W, C = img.shape
65
+
66
+ # Create a figure and axis
67
+ plt.figure()
68
+ ax = plt.gca()
69
+ # Plot joints as points
70
+ ax.imshow(img)
71
+ ax.scatter(joints[:, 0], joints[:, 1], color='white', s=15)
72
+ # Plot lines connecting joints with different colors for each bone
73
+ for connection, color in connections:
74
+ joint1 = joints[connection[0]]
75
+ joint2 = joints[connection[1]]
76
+ ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
77
+
78
+ ax.set_xlim([0, W])
79
+ ax.set_ylim([0, H])
80
+ ax.grid(False)
81
+ ax.set_axis_off()
82
+ ax.invert_yaxis()
83
+ plt.subplots_adjust(wspace=0.01)
84
+ plt.show()
85
+
86
+
87
+ def draw_hand_skeleton(joints, image_size, thickness=5):
88
+ # Create a blank white image
89
+ image = np.zeros((image_size[0], image_size[1]), dtype=np.uint8)
90
+
91
+ # Define the connections between joints
92
+ connections = [
93
+ (0, 1),
94
+ (1, 2),
95
+ (2, 3),
96
+ (3, 4),
97
+ (0, 5),
98
+ (5, 6),
99
+ (6, 7),
100
+ (7, 8),
101
+ (0, 9),
102
+ (9, 10),
103
+ (10, 11),
104
+ (11, 12),
105
+ (0, 13),
106
+ (13, 14),
107
+ (14, 15),
108
+ (15, 16),
109
+ (0, 17),
110
+ (17, 18),
111
+ (18, 19),
112
+ (19, 20),
113
+ ]
114
+
115
+ # Draw lines connecting joints
116
+ for connection in connections:
117
+ joint1 = joints[connection[0]].astype("int")
118
+ joint2 = joints[connection[1]].astype("int")
119
+ cv2.line(image, tuple(joint1), tuple(joint2), color=1, thickness=thickness)
120
+
121
+ return image
122
+
123
+
124
+ def draw_hand(joints, img):
125
+ # Define the connections between joints for drawing lines and their corresponding colors
126
+ connections = [
127
+ ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
128
+ ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
129
+ ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
130
+ ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
131
+ ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
132
+ ]
133
+ H, W, C = img.shape
134
+
135
+ # Create a figure and axis with the same size as the input image
136
+ fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100)
137
+ # Plot joints as points
138
+ ax.imshow(img)
139
+ ax.scatter(joints[:, 0], joints[:, 1], color='white', s=15)
140
+ # Plot lines connecting joints with different colors for each bone
141
+ for connection, color in connections:
142
+ joint1 = joints[connection[0]]
143
+ joint2 = joints[connection[1]]
144
+ ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
145
+
146
+ ax.set_xlim([0, W])
147
+ ax.set_ylim([0, H])
148
+ ax.grid(False)
149
+ ax.set_axis_off()
150
+ ax.invert_yaxis()
151
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.01, hspace=0.01)
152
+
153
+ # Save the plot to a buffer
154
+ buf = io.BytesIO()
155
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
156
+ plt.close(fig) # Close the figure to free memory
157
+
158
+ # Load the image from the buffer into a PIL image and then into a numpy array
159
+ buf.seek(0)
160
+ img_arr = np.array(Image.open(buf))
161
+
162
+ return img_arr[..., :3]
163
+
164
+
165
+ def keypoint_heatmap(pts, size, var=1.0):
166
+ H, W = size
167
+ x = np.linspace(0, W - 1, W)
168
+ y = np.linspace(0, H - 1, H)
169
+ xv, yv = np.meshgrid(x, y)
170
+ grid = np.stack((xv, yv), axis=-1)
171
+
172
+ # Expanding dims for broadcasting subtraction between pts and every grid position
173
+ modes_exp = np.expand_dims(np.expand_dims(pts, axis=1), axis=1)
174
+
175
+ # Calculating squared difference
176
+ diff = grid - modes_exp
177
+ normal = np.exp(-np.sum(diff**2, axis=-1) / (2 * var)) / (
178
+ 2.0 * np.pi * var
179
+ )
180
+ return normal
181
+
182
+
183
+ def check_keypoints_validity(keypoints, image_size):
184
+ H, W = image_size
185
+ # Check if x coordinates are valid: 0 < x < W
186
+ valid_x = (keypoints[:, 0] > 0) & (keypoints[:, 0] < W)
187
+
188
+ # Check if y coordinates are valid: 0 < y < H
189
+ valid_y = (keypoints[:, 1] > 0) & (keypoints[:, 1] < H)
190
+
191
+ # Combine the validity checks for both x and y
192
+ valid_keypoints = valid_x & valid_y
193
+
194
+ # Convert boolean array to integer (1 for True, 0 for False)
195
+ return valid_keypoints.astype(int)
196
+
197
+
198
+ def find_bounding_box(mask, margin=30):
199
+ """Find the bounding box of a binary mask. Return None if the mask is empty."""
200
+ rows = np.any(mask, axis=1)
201
+ cols = np.any(mask, axis=0)
202
+ if not rows.any() or not cols.any(): # Mask is empty
203
+ return None
204
+ ymin, ymax = np.where(rows)[0][[0, -1]]
205
+ xmin, xmax = np.where(cols)[0][[0, -1]]
206
+ xmin -= margin
207
+ xmax += margin
208
+ ymin -= margin
209
+ ymax += margin
210
+ return xmin, ymin, xmax, ymax
211
+
212
+
213
+ def adjust_box_to_image(xmin, ymin, xmax, ymax, image_width, image_height):
214
+ """Adjust the bounding box to fit within the image boundaries."""
215
+ box_width = xmax - xmin
216
+ box_height = ymax - ymin
217
+ # Determine the side length of the square (the larger of the two dimensions)
218
+ side_length = max(box_width, box_height)
219
+
220
+ # Adjust to maintain a square by expanding or contracting sides
221
+ xmin = max(0, xmin - (side_length - box_width) // 2)
222
+ xmax = xmin + side_length
223
+ ymin = max(0, ymin - (side_length - box_height) // 2)
224
+ ymax = ymin + side_length
225
+
226
+ # Ensure the box is still within the image boundaries after adjustments
227
+ if xmax > image_width:
228
+ shift = xmax - image_width
229
+ xmin -= shift
230
+ xmax -= shift
231
+ if ymax > image_height:
232
+ shift = ymax - image_height
233
+ ymin -= shift
234
+ ymax -= shift
235
+
236
+ # After shifting, double-check if any side is out-of-bounds and adjust if necessary
237
+ xmin = max(0, xmin)
238
+ ymin = max(0, ymin)
239
+ xmax = min(image_width, xmax)
240
+ ymax = min(image_height, ymax)
241
+
242
+ # It's possible the adjustments made the box not square (due to boundary constraints),
243
+ # so we might need to slightly adjust the size to keep it as square as possible
244
+ # This could involve a final adjustment based on the specific requirements,
245
+ # like reducing the side length to fit or deciding which dimension to prioritize.
246
+
247
+ return xmin, ymin, xmax, ymax
248
+
249
+
250
+ def scale_keypoint(keypoint, original_size, target_size):
251
+ """Scale a keypoint based on the resizing of the image."""
252
+ keypoint_copy = keypoint.copy()
253
+ keypoint_copy[:, 0] *= target_size[0] / original_size[0]
254
+ keypoint_copy[:, 1] *= target_size[1] / original_size[1]
255
+ return keypoint_copy
256
+
257
+
258
+ def crop_and_adjust_image_and_annotations(image, hand_mask, obj_mask, hand_pose, intrinsics, target_size=(512, 512)):
259
+ # Find bounding boxes for each mask, handling potentially empty masks
260
+ xmin, ymin, xmax, ymax = find_bounding_box(hand_mask) if np.any(hand_mask) else None
261
+
262
+ # Adjust bounding box to fit within the image and be square
263
+ xmin, ymin, xmax, ymax = adjust_box_to_image(xmin, ymin, xmax, ymax, image.shape[1], image.shape[0])
264
+
265
+ # Crop the image and mask
266
+ # masked_hand_image = (image * np.maximum(hand_mask, obj_mask)[..., None].astype(float)).astype(np.uint8)
267
+ cropped_hand_image = image[ymin:ymax, xmin:xmax]
268
+ cropped_hand_mask = hand_mask[ymin:ymax, xmin:xmax].astype(np.uint8)
269
+ cropped_obj_mask = obj_mask[ymin:ymax, xmin:xmax].astype(np.uint8)
270
+
271
+ # Resize the image
272
+ resized_image = resize(cropped_hand_image, target_size, anti_aliasing=True)
273
+ resized_hand_mask = cv2.resize(cropped_hand_mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)
274
+ resized_obj_mask = cv2.resize(cropped_obj_mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)
275
+
276
+ # adjust and scale 2d keypoints
277
+ for hand_type, kps2d in hand_pose.items():
278
+ kps2d[:, 0] -= xmin
279
+ kps2d[:, 1] -= ymin
280
+ hand_pose[hand_type] = scale_keypoint(kps2d, (xmax - xmin, ymax - ymin), target_size)
281
+
282
+ # adjust instrinsics
283
+ resized_intrinsics= np.array(intrinsics, copy=True)
284
+ resized_intrinsics[0, 2] -= xmin
285
+ resized_intrinsics[1, 2] -= ymin
286
+ resized_intrinsics[0, :] *= target_size[0] / (xmax - xmin)
287
+ resized_intrinsics[1, :] *= target_size[1] / (ymax - ymin)
288
+
289
+ return (resized_image, resized_hand_mask, resized_obj_mask, hand_pose, resized_intrinsics)
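A small sketch of the keypoint helpers above. The shapes follow the code in this file; the random keypoints and image size are placeholders.

# Hypothetical sketch: Gaussian heatmaps and a validity mask for 21 hand joints.
import numpy as np
from utils import keypoint_heatmap, check_keypoints_validity

kpts = np.random.rand(21, 2) * 256                              # (x, y) pixel coordinates
heatmaps = keypoint_heatmap(kpts, size=(256, 256), var=2.0)     # (21, 256, 256), one Gaussian per joint
valid = check_keypoints_validity(kpts, image_size=(256, 256))   # (21,) ints, 1 = inside the image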
vit.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
6
+
7
+
8
+ def modulate(x, shift, scale):
9
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
10
+
11
+
12
+ #################################################################################
13
+ # Embedding Layers for Timesteps and Class Labels #
14
+ #################################################################################
15
+
16
+ class TimestepEmbedder(nn.Module):
17
+ """
18
+ Embeds scalar timesteps into vector representations.
19
+ """
20
+ def __init__(self, hidden_size, frequency_embedding_size=256):
21
+ super().__init__()
22
+ self.mlp = nn.Sequential(
23
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
24
+ nn.SiLU(),
25
+ nn.Linear(hidden_size, hidden_size, bias=True),
26
+ )
27
+ self.frequency_embedding_size = frequency_embedding_size
28
+
29
+ @staticmethod
30
+ def timestep_embedding(t, dim, max_period=10000):
31
+ """
32
+ Create sinusoidal timestep embeddings.
33
+ :param t: a 1-D Tensor of N indices, one per batch element.
34
+ These may be fractional.
35
+ :param dim: the dimension of the output.
36
+ :param max_period: controls the minimum frequency of the embeddings.
37
+ :return: an (N, D) Tensor of positional embeddings.
38
+ """
39
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
40
+ half = dim // 2
41
+ freqs = torch.exp(
42
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
43
+ ).to(device=t.device)
44
+ args = t[:, None].float() * freqs[None]
45
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
46
+ if dim % 2:
47
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
48
+ return embedding
49
+
50
+ def forward(self, t):
51
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
52
+ t_emb = self.mlp(t_freq)
53
+ return t_emb
54
+
55
+
56
+ class LabelEmbedder(nn.Module):
57
+ """
58
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
59
+ """
60
+ def __init__(self, num_classes, hidden_size, dropout_prob):
61
+ super().__init__()
62
+ use_cfg_embedding = dropout_prob > 0
63
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
64
+ self.num_classes = num_classes
65
+ self.dropout_prob = dropout_prob
66
+
67
+ def token_drop(self, labels, force_drop_ids=None):
68
+ """
69
+ Drops labels to enable classifier-free guidance.
70
+ """
71
+ if force_drop_ids is None:
72
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
73
+ else:
74
+ drop_ids = force_drop_ids == 1
75
+ labels = torch.where(drop_ids, self.num_classes, labels)
76
+ return labels
77
+
78
+ def forward(self, labels, train, force_drop_ids=None):
79
+ use_dropout = self.dropout_prob > 0
80
+ if (train and use_dropout) or (force_drop_ids is not None):
81
+ labels = self.token_drop(labels, force_drop_ids)
82
+ embeddings = self.embedding_table(labels)
83
+ return embeddings
84
+
85
+
86
+ class DiTBlock(nn.Module):
87
+ """
88
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
89
+ """
90
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
91
+ super().__init__()
92
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
93
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
94
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
95
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
96
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
97
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
98
+ self.adaLN_modulation = nn.Sequential(
99
+ nn.SiLU(),
100
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
101
+ )
102
+
103
+ def forward(self, x, c):
104
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
105
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
106
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
107
+ return x
108
+
109
+
110
+ class FinalLayer(nn.Module):
111
+ """
112
+ The final layer of DiT.
113
+ """
114
+ def __init__(self, hidden_size, patch_size, out_channels):
115
+ super().__init__()
116
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
117
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
118
+ self.adaLN_modulation = nn.Sequential(
119
+ nn.SiLU(),
120
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
121
+ )
122
+
123
+ def forward(self, x, c):
124
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
125
+ x = modulate(self.norm_final(x), shift, scale)
126
+ x = self.linear(x)
127
+ return x
128
+
129
+
130
+ class DiT(nn.Module):
131
+ """
132
+ Diffusion model with a Transformer backbone.
133
+ """
134
+ def __init__(
135
+ self,
136
+ input_size=32,
137
+ patch_size=2,
138
+ latent_dim=4,
139
+ in_channels=47,
140
+ hidden_size=1152,
141
+ depth=28,
142
+ num_heads=16,
143
+ mlp_ratio=4.0,
144
+ learn_sigma=True,
145
+ ):
146
+ super().__init__()
147
+ self.learn_sigma = learn_sigma
148
+ self.in_channels = in_channels
149
+ self.out_channels = latent_dim * 2 if learn_sigma else latent_dim
150
+ self.patch_size = patch_size
151
+ self.num_heads = num_heads
152
+
153
+ #self.x_embedder = PatchEmbed(input_size, patch_size, latent_dim, hidden_size, bias=True)
154
+ self.feature_aligned_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
155
+
156
+ self.n_patches = self.feature_aligned_embedder.num_patches
157
+ self.patch_size = self.feature_aligned_embedder.patch_size[0]
158
+
159
+ self.t_embedder = TimestepEmbedder(hidden_size)
160
+ self.nvs_label_embedder = LabelEmbedder(3, hidden_size, 0.)
161
+ self.pos_embed = nn.Parameter(torch.zeros(1, 2 * self.n_patches, hidden_size), requires_grad=True)
162
+ self.y_embedder = LabelEmbedder(num_classes=1000, hidden_size=hidden_size, dropout_prob=0.1)
163
+
164
+ self.blocks = nn.ModuleList([
165
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
166
+ ])
167
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
168
+ self.initialize_weights()
169
+
170
+ def initialize_weights(self):
171
+ # Initialize transformer layers:
172
+ def _basic_init(module):
173
+ if isinstance(module, nn.Linear):
174
+ torch.nn.init.xavier_uniform_(module.weight)
175
+ if module.bias is not None:
176
+ nn.init.constant_(module.bias, 0)
177
+ self.apply(_basic_init)
178
+
179
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
180
+ grid_size = int(self.n_patches ** 0.5)
181
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (2 * grid_size, grid_size))
182
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
183
+
184
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
185
+ #w = self.x_embedder.proj.weight.data
186
+ #nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
187
+ #nn.init.constant_(self.x_embedder.proj.bias, 0)
188
+
189
+ w = self.feature_aligned_embedder.proj.weight.data
190
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
191
+ nn.init.constant_(self.feature_aligned_embedder.proj.bias, 0)
192
+
193
+ # Initialize label embedding table:
194
+ nn.init.normal_(self.nvs_label_embedder.embedding_table.weight, std=0.02)
195
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
196
+
197
+ # Initialize timestep embedding MLP:
198
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
199
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
200
+
201
+ # Zero-out adaLN modulation layers in DiT blocks:
202
+ for block in self.blocks:
203
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
204
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
205
+
206
+ # Zero-out output layers:
207
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
208
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
209
+ nn.init.constant_(self.final_layer.linear.weight, 0)
210
+ nn.init.constant_(self.final_layer.linear.bias, 0)
211
+
212
+ def unpatchify(self, x):
213
+ """
214
+ x: (N, T, patch_size**2 * C)
215
+ imgs: (N, C, H, W)
216
+ """
217
+ c = self.out_channels
218
+ p = self.patch_size
219
+ h = w = int(x.shape[1] ** 0.5)
220
+ assert h * w == x.shape[1]
221
+
222
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
223
+ x = torch.einsum('nhwpqc->nchpwq', x)
224
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
225
+ return imgs
226
+
227
+ def forward(self, x_t, t, target_cond, ref_cond, nvs, y=None):
228
+ """
229
+ Forward pass of DiT.
230
+ x_t: (N, C1, H, W) denoising latent, concatenated channel-wise with target_cond (target pose control)
231
+ ref_cond: (N, C2, H, W) reference latent + reference pose control + mask
232
+ t: (N,) tensor of diffusion timesteps
233
+ y: (N,) tensor of class labels
234
+ """
235
+ x = self.feature_aligned_embedder(torch.concat([x_t, target_cond], 1)) + self.pos_embed[:, :self.n_patches]
236
+ cond = self.feature_aligned_embedder(ref_cond) + self.pos_embed[:, self.n_patches:]
237
+ x = torch.concatenate([x, cond], 1)
238
+
239
+ t = self.t_embedder(t) # (N, D)
240
+ nvs = self.nvs_label_embedder(nvs, False)
241
+ if y is None:
242
+ y = torch.tensor([1000] * x.shape[0], device=x.device)
243
+ y = self.y_embedder(y, False) # (N, D)
244
+ c = t + y + nvs # (N, D)
245
+ for block in self.blocks:
246
+ x = block(x, c) # (N, 2T, D)
247
+ x = x[:, :x.shape[1]//2]
248
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249
+ x = self.unpatchify(x) # (N, out_channels, H, W)
250
+ return x
251
+
252
+ def forward_with_cfg(self, x, t, target_cond, ref_cond, nvs, cfg_scale):
253
+ half = x[: len(x) // 2]
254
+ combined = torch.cat([half, half], dim=0)
255
+ y_null = torch.tensor([1000] * half.shape[0], device=x.device)
256
+ y = torch.cat([y_null, y_null], 0)
257
+ model_out = self.forward(combined, t, target_cond, ref_cond, nvs, y)
258
+ eps, rest = model_out[:, :3], model_out[:, 3:]
259
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
260
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
261
+ eps = torch.cat([half_eps, half_eps], dim=0)
262
+ return torch.cat([eps, rest], dim=1)
263
+
264
+ #################################################################################
265
+ # Sine/Cosine Positional Embedding Functions #
266
+ #################################################################################
267
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
268
+
269
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
270
+ """
271
+ grid_size: int of the grid height and width
272
+ return:
273
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
274
+ """
275
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
276
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
277
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
278
+ grid = np.stack(grid, axis=0)
279
+
280
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
281
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
282
+ if cls_token and extra_tokens > 0:
283
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
284
+ return pos_embed
285
+
286
+
287
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
288
+ assert embed_dim % 2 == 0
289
+
290
+ # use half of dimensions to encode grid_h
291
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
292
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
293
+
294
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
295
+ return emb
296
+
297
+
298
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
299
+ """
300
+ embed_dim: output dimension for each position
301
+ pos: a list of positions to be encoded: size (M,)
302
+ out: (M, D)
303
+ """
304
+ assert embed_dim % 2 == 0
305
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
306
+ omega /= embed_dim / 2.
307
+ omega = 1. / 10000**omega # (D/2,)
308
+
309
+ pos = pos.reshape(-1) # (M,)
310
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
311
+
312
+ emb_sin = np.sin(out) # (M, D/2)
313
+ emb_cos = np.cos(out) # (M, D/2)
314
+
315
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
316
+ return emb
317
+
318
+
319
+ def DiT_XL_2(**kwargs):
320
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
321
+
322
+ def DiT_L_2(**kwargs):
323
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
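A shape-check sketch for the DiT above. The channel split between the noisy latent and the pose control is an assumption chosen to match the defaults (latent_dim=4, in_channels=47), and the tiny hidden size and depth are only for a quick smoke test, not the published configuration.

# Hypothetical sketch: one forward pass through a small DiT configuration.
import torch
from vit import DiT

model = DiT(input_size=32, patch_size=2, latent_dim=4, in_channels=47,
            hidden_size=64, depth=2, num_heads=4, learn_sigma=True)
x_t = torch.randn(2, 4, 32, 32)           # noisy latent (latent_dim channels)
target_cond = torch.randn(2, 43, 32, 32)  # target pose control; 4 + 43 = in_channels
ref_cond = torch.randn(2, 47, 32, 32)     # reference latent + pose control + mask
t = torch.randint(0, 1000, (2,))
nvs = torch.zeros(2, dtype=torch.long)    # view label in {0, 1, 2}
out = model(x_t, t, target_cond, ref_cond, nvs)
print(out.shape)                          # torch.Size([2, 8, 32, 32]); learn_sigma doubles the output channels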
vqvae.py ADDED
@@ -0,0 +1,507 @@
1
+ """
2
+ ---
3
+ title: Autoencoder for Stable Diffusion
4
+ summary: >
5
+ Annotated PyTorch implementation/tutorial of the autoencoder
6
+ for stable diffusion.
7
+ ---
8
+
9
+ # Autoencoder for [Stable Diffusion](../index.html)
10
+
11
+ This implements the auto-encoder model used to map between image space and latent space.
12
+
13
+ We have kept to the model definition and naming unchanged from
14
+ [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
15
+ so that we can load the checkpoints directly.
16
+ """
17
+
18
+ from typing import List
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ from torch import nn
23
+
24
+
25
+ class Autoencoder(nn.Module):
26
+ """
27
+ ## Autoencoder
28
+
29
+ This consists of the encoder and decoder modules.
30
+ """
31
+
32
+ def __init__(
33
+ self, encoder: "Encoder", decoder: "Decoder", emb_channels: int, z_channels: int
34
+ ):
35
+ """
36
+ :param encoder: is the encoder
37
+ :param decoder: is the decoder
38
+ :param emb_channels: is the number of dimensions in the quantized embedding space
39
+ :param z_channels: is the number of channels in the embedding space
40
+ """
41
+ super().__init__()
42
+ self.encoder = encoder
43
+ self.decoder = decoder
44
+ # Convolution to map from embedding space to
45
+ # quantized embedding space moments (mean and log variance)
46
+ self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
47
+ # Convolution to map from quantized embedding space back to
48
+ # embedding space
49
+ self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)
50
+
51
+ def encode(self, img: torch.Tensor) -> "GaussianDistribution":
52
+ """
53
+ ### Encode images to latent representation
54
+
55
+ :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
56
+ """
57
+ # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]`
58
+ z = self.encoder(img)
59
+ # Get the moments in the quantized embedding space
60
+ moments = self.quant_conv(z)
61
+ # Return the distribution
62
+ return GaussianDistribution(moments)
63
+
64
+ def decode(self, z: torch.Tensor):
65
+ """
66
+ ### Decode images from latent representation
67
+
68
+ :param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_height]`
69
+ """
70
+ # Map to embedding space from the quantized representation
71
+ z = self.post_quant_conv(z)
72
+ # Decode the image of shape `[batch_size, channels, height, width]`
73
+ return self.decoder(z)
74
+
75
+ def forward(self, x):
76
+ posterior = self.encode(x)
77
+ z = posterior.sample()
78
+ dec = self.decode(z)
79
+ return dec, posterior
80
+
81
+
82
+ class Encoder(nn.Module):
83
+ """
84
+ ## Encoder module
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ *,
90
+ channels: int,
91
+ channel_multipliers: List[int],
92
+ n_resnet_blocks: int,
93
+ in_channels: int,
94
+ z_channels: int
95
+ ):
96
+ """
97
+ :param channels: is the number of channels in the first convolution layer
98
+ :param channel_multipliers: are the multiplicative factors for the number of channels in the
99
+ subsequent blocks
100
+ :param n_resnet_blocks: is the number of resnet layers at each resolution
101
+ :param in_channels: is the number of channels in the image
102
+ :param z_channels: is the number of channels in the embedding space
103
+ """
104
+ super().__init__()
105
+
106
+ # Number of blocks of different resolutions.
107
+ # The resolution is halved at the end each top level block
108
+ n_resolutions = len(channel_multipliers)
109
+
110
+ # Initial $3 \times 3$ convolution layer that maps the image to `channels`
111
+ self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)
112
+
113
+ # Number of channels in each top level block
114
+ channels_list = [m * channels for m in [1] + channel_multipliers]
115
+
116
+ # List of top-level blocks
117
+ self.down = nn.ModuleList()
118
+ # Create top-level blocks
119
+ for i in range(n_resolutions):
120
+ # Each top level block consists of multiple ResNet Blocks and down-sampling
121
+ resnet_blocks = nn.ModuleList()
122
+ # Add ResNet Blocks
123
+ for _ in range(n_resnet_blocks):
124
+ resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
125
+ channels = channels_list[i + 1]
126
+ # Top-level block
127
+ down = nn.Module()
128
+ down.block = resnet_blocks
129
+ # Down-sampling at the end of each top level block except the last
130
+ if i != n_resolutions - 1:
131
+ down.downsample = DownSample(channels)
132
+ else:
133
+ down.downsample = nn.Identity()
134
+ #
135
+ self.down.append(down)
136
+
137
+ # Final ResNet blocks with attention
138
+ self.mid = nn.Module()
139
+ self.mid.block_1 = ResnetBlock(channels, channels)
140
+ self.mid.attn_1 = AttnBlock(channels)
141
+ self.mid.block_2 = ResnetBlock(channels, channels)
142
+
143
+ # Map to embedding space with a $3 \times 3$ convolution
144
+ self.norm_out = normalization(channels)
145
+ self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)
146
+
147
+ def forward(self, img: torch.Tensor):
148
+ """
149
+ :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
150
+ """
151
+
152
+ # Map to `channels` with the initial convolution
153
+ x = self.conv_in(img)
154
+
155
+ # Top-level blocks
156
+ for down in self.down:
157
+ # ResNet Blocks
158
+ for block in down.block:
159
+ x = block(x)
160
+ # Down-sampling
161
+ x = down.downsample(x)
162
+
163
+ # Final ResNet blocks with attention
164
+ x = self.mid.block_1(x)
165
+ x = self.mid.attn_1(x)
166
+ x = self.mid.block_2(x)
167
+
168
+ # Normalize and map to embedding space
169
+ x = self.norm_out(x)
170
+ x = swish(x)
171
+ x = self.conv_out(x)
172
+
173
+ #
174
+ return x
175
+
176
+
177
+ class Decoder(nn.Module):
178
+ """
179
+ ## Decoder module
180
+ """
181
+
182
+ def __init__(
183
+ self,
184
+ *,
185
+ channels: int,
186
+ channel_multipliers: List[int],
187
+ n_resnet_blocks: int,
188
+ out_channels: int,
189
+ z_channels: int
190
+ ):
191
+ """
192
+ :param channels: is the number of channels in the final convolution layer
193
+ :param channel_multipliers: are the multiplicative factors for the number of channels in the
194
+ previous blocks, in reverse order
195
+ :param n_resnet_blocks: is the number of resnet layers at each resolution
196
+ :param out_channels: is the number of channels in the image
197
+ :param z_channels: is the number of channels in the embedding space
198
+ """
199
+ super().__init__()
200
+
201
+ # Number of blocks of different resolutions.
202
+ # The resolution is halved at the end each top level block
203
+ num_resolutions = len(channel_multipliers)
204
+
205
+ # Number of channels in each top level block, in the reverse order
206
+ channels_list = [m * channels for m in channel_multipliers]
207
+
208
+ # Number of channels in the top-level block
209
+ channels = channels_list[-1]
210
+
211
+ # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
212
+ self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)
213
+
214
+ # ResNet blocks with attention
215
+ self.mid = nn.Module()
216
+ self.mid.block_1 = ResnetBlock(channels, channels)
217
+ self.mid.attn_1 = AttnBlock(channels)
218
+ self.mid.block_2 = ResnetBlock(channels, channels)
219
+
220
+ # List of top-level blocks
221
+ self.up = nn.ModuleList()
222
+ # Create top-level blocks
223
+ for i in reversed(range(num_resolutions)):
224
+ # Each top level block consists of multiple ResNet Blocks and up-sampling
225
+ resnet_blocks = nn.ModuleList()
226
+ # Add ResNet Blocks
227
+ for _ in range(n_resnet_blocks + 1):
228
+ resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
229
+ channels = channels_list[i]
230
+ # Top-level block
231
+ up = nn.Module()
232
+ up.block = resnet_blocks
233
+ # Up-sampling at the end of each top level block except the first
234
+ if i != 0:
235
+ up.upsample = UpSample(channels)
236
+ else:
237
+ up.upsample = nn.Identity()
238
+ # Prepend to be consistent with the checkpoint
239
+ self.up.insert(0, up)
240
+
241
+ # Map to image space with a $3 \times 3$ convolution
242
+ self.norm_out = normalization(channels)
243
+ self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)
244
+
245
+ def forward(self, z: torch.Tensor):
246
+ """
247
+ :param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_height]`
248
+ """
249
+
250
+ # Map to `channels` with the initial convolution
251
+ h = self.conv_in(z)
252
+
253
+ # ResNet blocks with attention
254
+ h = self.mid.block_1(h)
255
+ h = self.mid.attn_1(h)
256
+ h = self.mid.block_2(h)
257
+
258
+ # Top-level blocks
259
+ for up in reversed(self.up):
260
+ # ResNet Blocks
261
+ for block in up.block:
262
+ h = block(h)
263
+ # Up-sampling
264
+ h = up.upsample(h)
265
+
266
+ # Normalize and map to image space
267
+ h = self.norm_out(h)
268
+ h = swish(h)
269
+ img = self.conv_out(h)
270
+
271
+ #
272
+ return img
273
+
274
+
275
+ class GaussianDistribution:
276
+ """
277
+ ## Gaussian Distribution
278
+ """
279
+
280
+ def __init__(self, parameters: torch.Tensor):
281
+ """
282
+ :param parameters: are the means and log of variances of the embedding of shape
283
+ `[batch_size, z_channels * 2, z_height, z_height]`
284
+ """
285
+ # Split mean and log of variance
286
+ self.mean, log_var = torch.chunk(parameters, 2, dim=1)
287
+ # Clamp the log of variances
288
+ self.log_var = torch.clamp(log_var, -30.0, 20.0)
289
+ # Calculate standard deviation
290
+ self.std = torch.exp(0.5 * self.log_var)
291
+ self.var = torch.exp(self.log_var)
292
+
293
+ def sample(self):
294
+ # Sample from the distribution
295
+ return self.mean + self.std * torch.randn_like(self.std)
296
+
297
+ def kl(self):
298
+ return 0.5 * torch.sum(
299
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.log_var, dim=[1, 2, 3]
300
+ )
301
+
302
+
303
+ class AttnBlock(nn.Module):
304
+ """
305
+ ## Attention block
306
+ """
307
+
308
+ def __init__(self, channels: int):
309
+ """
310
+ :param channels: is the number of channels
311
+ """
312
+ super().__init__()
313
+ # Group normalization
314
+ self.norm = normalization(channels)
315
+ # Query, key and value mappings
316
+ self.q = nn.Conv2d(channels, channels, 1)
317
+ self.k = nn.Conv2d(channels, channels, 1)
318
+ self.v = nn.Conv2d(channels, channels, 1)
319
+ # Final $1 \times 1$ convolution layer
320
+ self.proj_out = nn.Conv2d(channels, channels, 1)
321
+ # Attention scaling factor
322
+ self.scale = channels**-0.5
323
+
324
+ def forward(self, x: torch.Tensor):
325
+ """
326
+ :param x: is the tensor of shape `[batch_size, channels, height, width]`
327
+ """
328
+ # Normalize `x`
329
+ x_norm = self.norm(x)
330
+ # Get query, key and value embeddings
331
+ q = self.q(x_norm)
332
+ k = self.k(x_norm)
333
+ v = self.v(x_norm)
334
+
335
+ # Reshape the query, key and value embeddings from
336
+ # `[batch_size, channels, height, width]` to
337
+ # `[batch_size, channels, height * width]`
338
+ b, c, h, w = q.shape
339
+ q = q.view(b, c, h * w)
340
+ k = k.view(b, c, h * w)
341
+ v = v.view(b, c, h * w)
342
+
343
+ # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
344
+ attn = torch.einsum("bci,bcj->bij", q, k) * self.scale
345
+ attn = F.softmax(attn, dim=2)
346
+
347
+ # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
348
+ out = torch.einsum("bij,bcj->bci", attn, v)
349
+
350
+ # Reshape back to `[batch_size, channels, height, width]`
351
+ out = out.view(b, c, h, w)
352
+ # Final $1 \times 1$ convolution layer
353
+ out = self.proj_out(out)
354
+
355
+ # Add residual connection
356
+ return x + out
357
+
358
+
359
+ class UpSample(nn.Module):
360
+ """
361
+ ## Up-sampling layer
362
+ """
363
+
364
+ def __init__(self, channels: int):
365
+ """
366
+ :param channels: is the number of channels
367
+ """
368
+ super().__init__()
369
+ # $3 \times 3$ convolution mapping
370
+ self.conv = nn.Conv2d(channels, channels, 3, padding=1)
371
+
372
+ def forward(self, x: torch.Tensor):
373
+ """
374
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
375
+ """
376
+ # Up-sample by a factor of $2$
377
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
378
+ # Apply convolution
379
+ return self.conv(x)
380
+
381
+
382
+ class DownSample(nn.Module):
383
+ """
384
+ ## Down-sampling layer
385
+ """
386
+
387
+ def __init__(self, channels: int):
388
+ """
389
+ :param channels: is the number of channels
390
+ """
391
+ super().__init__()
392
+ # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
393
+ self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)
394
+
395
+ def forward(self, x: torch.Tensor):
396
+ """
397
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
398
+ """
399
+ # Add padding
400
+ x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
401
+ # Apply convolution
402
+ return self.conv(x)
403
+
404
+
405
+ class ResnetBlock(nn.Module):
406
+ """
407
+ ## ResNet Block
408
+ """
409
+
410
+ def __init__(self, in_channels: int, out_channels: int):
411
+ """
412
+ :param in_channels: is the number of channels in the input
413
+ :param out_channels: is the number of channels in the output
414
+ """
415
+ super().__init__()
416
+ # First normalization and convolution layer
417
+ self.norm1 = normalization(in_channels)
418
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
419
+ # Second normalization and convolution layer
420
+ self.norm2 = normalization(out_channels)
421
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
422
+ # `in_channels` to `out_channels` mapping layer for residual connection
423
+ if in_channels != out_channels:
424
+ self.nin_shortcut = nn.Conv2d(
425
+ in_channels, out_channels, 1, stride=1, padding=0
426
+ )
427
+ else:
428
+ self.nin_shortcut = nn.Identity()
429
+
430
+ def forward(self, x: torch.Tensor):
431
+ """
432
+ :param x: is the input feature map with shape `[batch_size, channels, height, width]`
433
+ """
434
+
435
+ h = x
436
+
437
+ # First normalization and convolution layer
438
+ h = self.norm1(h)
439
+ h = swish(h)
440
+ h = self.conv1(h)
441
+
442
+ # Second normalization and convolution layer
443
+ h = self.norm2(h)
444
+ h = swish(h)
445
+ h = self.conv2(h)
446
+
447
+ # Map and add residual
448
+ return self.nin_shortcut(x) + h
449
+
450
+
451
+ def swish(x: torch.Tensor):
452
+ """
453
+ ### Swish activation
454
+
455
+ """
456
+ return x * torch.sigmoid(x)
457
+
458
+
459
+ def normalization(channels: int):
460
+ """
461
+ ### Group normalization
462
+
463
+ This is a helper function, with fixed number of groups and `eps`.
464
+ """
465
+ return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
466
+
467
+
468
+ def restore_ae_from_sd(model, path):
469
+
470
+ def remove_prefix(text, prefix):
471
+ if text.startswith(prefix):
472
+ return text[len(prefix) :]
473
+ return text
474
+
475
+ checkpoint = torch.load(path)
476
+ # checkpoint = torch.load(path, map_location="cpu")
477
+
478
+ ckpt_state_dict = checkpoint["state_dict"]
479
+ new_ckpt_state_dict = {}
480
+ for k, v in ckpt_state_dict.items():
481
+ new_k = remove_prefix(k, "first_stage_model.")
482
+ new_ckpt_state_dict[new_k] = v
483
+ missing_keys, extra_keys = model.load_state_dict(new_ckpt_state_dict, strict=False)
484
+ assert len(missing_keys) == 0
485
+
486
+
487
+ def create_model(in_channels, out_channels, latent_dim=4):
488
+ encoder = Encoder(
489
+ z_channels=latent_dim,
490
+ in_channels=in_channels,
491
+ channels=128,
492
+ channel_multipliers=[1, 2, 4, 4],
493
+ n_resnet_blocks=2,
494
+ )
495
+
496
+ decoder = Decoder(
497
+ out_channels=out_channels,
498
+ z_channels=latent_dim,
499
+ channels=128,
500
+ channel_multipliers=[1, 2, 4, 4],
501
+ n_resnet_blocks=2,
502
+ )
503
+
504
+ autoencoder = Autoencoder(
505
+ emb_channels=latent_dim, encoder=encoder, decoder=decoder, z_channels=latent_dim
506
+ )
507
+ return autoencoder
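A round-trip sketch for the autoencoder factory above. The weights are random here; in practice restore_ae_from_sd would load Stable Diffusion first-stage weights from a checkpoint. The input resolution is an assumption.

# Hypothetical sketch: encode an image to the 4-channel latent and decode it back.
import torch
from vqvae import create_model

ae = create_model(in_channels=3, out_channels=3, latent_dim=4)
img = torch.randn(1, 3, 256, 256)
posterior = ae.encode(img)     # GaussianDistribution over the latent
z = posterior.sample()         # (1, 4, 32, 32): 8x spatial downsampling
recon = ae.decode(z)           # (1, 3, 256, 256)
print(z.shape, recon.shape)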