init

- README.md +1 -13
- app.py +1581 -0
- diffusion/__init__.py +46 -0
- diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
- diffusion/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
- diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc +0 -0
- diffusion/__pycache__/respace.cpython-38.pyc +0 -0
- diffusion/__pycache__/scheduler.cpython-38.pyc +0 -0
- diffusion/diffusion_utils.py +88 -0
- diffusion/gaussian_diffusion.py +1118 -0
- diffusion/respace.py +129 -0
- diffusion/scheduler.py +224 -0
- diffusion/timestep_sampler.py +150 -0
- packages.txt +1 -0
- requirements.txt +223 -0
- segment_hoi.py +111 -0
- utils.py +289 -0
- vit.py +323 -0
- vqvae.py +507 -0
README.md
CHANGED
@@ -1,13 +1 @@
----
-title: FoundHand
-emoji: 🏆
-colorFrom: gray
-colorTo: purple
-sdk: gradio
-sdk_version: 5.9.1
-app_file: app.py
-pinned: false
-short_description: FoundHand
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation
app.py
ADDED
@@ -0,0 +1,1581 @@
import torch
from dataclasses import dataclass
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import cv2
import mediapipe as mp
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import vqvae
import vit
from typing import Literal
from diffusion import create_diffusion
from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
from segment_hoi import init_sam
from io import BytesIO
from PIL import Image
import random
from copy import deepcopy
from typing import Optional
from huggingface_hub import hf_hub_download  # missing in the original; hf_hub_download is called below

MAX_N = 6
FIX_MAX_N = 6

placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
NEW_MODEL = True
MODEL_EPOCH = 6
REF_POSE_MASK = True


def set_seed(seed):
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def unnormalize(x):
    return (((x + 1) / 2) * 255).astype(np.uint8)

def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), "red"),
        ((1, 2), "green"),
        ((2, 3), "blue"),
        ((3, 4), "purple"),
        ((0, 5), "orange"),
        ((5, 6), "pink"),
        ((6, 7), "brown"),
        ((7, 8), "cyan"),
        ((0, 9), "yellow"),
        ((9, 10), "magenta"),
        ((10, 11), "lime"),
        ((11, 12), "indigo"),
        ((0, 13), "olive"),
        ((13, 14), "teal"),
        ((14, 15), "navy"),
        ((15, 16), "gray"),
        ((0, 17), "lavender"),
        ((17, 18), "silver"),
        ((18, 19), "maroon"),
        ((19, 20), "fuchsia"),
    ]
    H, W, C = img.shape

    # Create a figure and axis
    plt.figure()
    ax = plt.gca()
    # Plot joints as points
    ax.imshow(img)
    start_is = []
    if "right" in side:
        start_is.append(0)
    if "left" in side:
        start_is.append(21)
    for start_i in start_is:
        joints = all_joints[start_i : start_i + n_avail_joints]
        if len(joints) == 1:
            ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
        else:
            for connection, color in connections[: len(joints) - 1]:
                joint1 = joints[connection[0]]
                joint2 = joints[connection[1]]
                ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    # plt.subplots_adjust(wspace=0.01)
    # plt.show()
    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()

    # Convert BytesIO object to numpy array
    buf.seek(0)
    img_pil = Image.open(buf)
    img_pil = img_pil.resize((W, H))  # PIL expects (width, height); the original passed (H, W)
    numpy_img = np.array(img_pil)

    return numpy_img


def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
    """Overlay mask on image for visualization purposes.
    Args:
        image (H, W, 3) or (H, W): input image
        mask (H, W): mask to be overlaid
        color: the color of the overlaid mask
        alpha: the transparency of the mask
    """
    out = deepcopy(image)
    img = deepcopy(image)
    img[mask == 1] = color
    if transparent:
        out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
    else:
        out = img
    return out


def scale_keypoint(keypoint, original_size, target_size):
    """Scale a keypoint based on the resizing of the image."""
    keypoint_copy = keypoint.copy()
    keypoint_copy[:, 0] *= target_size[0] / original_size[0]
    keypoint_copy[:, 1] *= target_size[1] / original_size[1]
    return keypoint_copy


print("Configure...")

@dataclass
class HandDiffOpts:
    run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
    sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
    log_dir: str = "/users/kchen157/scratch/log"
    data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
    image_size: tuple = (256, 256)
    latent_size: tuple = (32, 32)
    latent_dim: int = 4
    mask_bg: bool = False
    kpts_form: str = "heatmap"
    n_keypoints: int = 42
    n_mask: int = 1
    noise_steps: int = 1000
    test_sampling_steps: int = 250
    ddim_steps: int = 100
    ddim_discretize: str = "uniform"
    ddim_eta: float = 0.0
    beta_start: float = 8.5e-4
    beta_end: float = 0.012
    latent_scaling_factor: float = 0.18215
    cfg_pose: float = 5.0
    cfg_appearance: float = 3.5
    batch_size: int = 25
    lr: float = 1e-5
    max_epochs: int = 500
    log_every_n_steps: int = 100
    limit_val_batches: int = 1
    n_gpu: int = 8
    num_nodes: int = 1
    precision: str = "16-mixed"
    profiler: str = "simple"
    swa_epoch_start: int = 10
    swa_lrs: float = 1e-3
    num_workers: int = 10
    n_val_samples: int = 4


if not torch.cuda.is_available():
    raise ValueError("No GPU")

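# Note: latent_scaling_factor = 0.18215 above is the standard Stable Diffusion
# VAE scaling constant, which brings encoded latents to roughly unit variance;
# accordingly, the code below multiplies by it on encode and divides by it on
# decode (see sample_diff and sample_inpaint).
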
# load models
if NEW_MODEL:
    opts = HandDiffOpts()
    if MODEL_EPOCH == 7:
        model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
    elif MODEL_EPOCH == 6:
        # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
        model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt")
    elif MODEL_EPOCH == 4:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
    elif MODEL_EPOCH == 10:
        model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
    else:
        raise ValueError(f"MODEL_EPOCH should be one of 4, 6, 7, or 10, got {MODEL_EPOCH}")
    vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
    # sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim + opts.n_keypoints + opts.n_mask,
        learn_sigma=True,
    ).cuda()
    # ckpt_state_dict = torch.load(model_path)['model_state_dict']
    ckpt_state_dict = torch.load(model_path, map_location=torch.device('cuda'))['ema_state_dict']
    missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
    model.eval()
    print(missing_keys, extra_keys)
    assert len(missing_keys) == 0
    vae_state_dict = torch.load(vae_path)['state_dict']
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    autoencoder.eval()
    assert len(missing_keys) == 0
else:
    opts = HandDiffOpts()
    model_path = './finetune_epoch=5-step=130000.ckpt'
    sd_path = './sd-v1-4.ckpt'
    print('Load diffusion model...')
    diffusion = create_diffusion(str(opts.test_sampling_steps))
    model = vit.DiT_XL_2(
        input_size=opts.latent_size[0],
        latent_dim=opts.latent_dim,
        in_channels=opts.latent_dim + opts.n_keypoints + opts.n_mask,
        learn_sigma=True,
    ).cuda()
    ckpt_state_dict = torch.load(model_path)['state_dict']
    dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
    vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
    missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
    model.eval()
    assert len(missing_keys) == 0 and len(extra_keys) == 0
    autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).cuda()
    missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
    autoencoder.eval()
    assert len(missing_keys) == 0 and len(extra_keys) == 0
sam_predictor = init_sam(ckpt_path="./sam_vit_h_4b8939.pth")


print("Mediapipe hand detector and SAM ready...")
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=True,  # Use False if the image is part of a video stream
    max_num_hands=2,  # Maximum number of hands to detect
    min_detection_confidence=0.1,
)

def get_ref_anno(ref):
    if ref is None:
        return (
            None,
            None,
            None,
            None,
            None,
        )
    img = ref["composite"][..., :3]
    img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
    keypts = np.zeros((42, 2))
    if REF_POSE_MASK:
        mp_pose = hands.process(img)
        detected = np.array([0, 0])
        start_idx = 0
        if mp_pose.multi_hand_landmarks:
            # handedness is flipped because MediaPipe assumes the input image is mirrored
            for hand_landmarks, handedness in zip(
                mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
            ):
                # actually the right hand
                if handedness.classification[0].label == "Left":
                    start_idx = 0
                    detected[0] = 1
                # actually the left hand
                elif handedness.classification[0].label == "Right":
                    start_idx = 21
                    detected[1] = 1
                for i, landmark in enumerate(hand_landmarks.landmark):
                    keypts[start_idx + i] = [
                        landmark.x * opts.image_size[1],
                        landmark.y * opts.image_size[0],
                    ]

            sam_predictor.set_image(img)
            l = keypts[:21].shape[0]
            if keypts[0].sum() != 0 and keypts[21].sum() != 0:
                input_point = np.array([keypts[0], keypts[21]])
                input_label = np.array([1, 1])
            elif keypts[0].sum() != 0:
                input_point = np.array(keypts[:1])
                input_label = np.array([1])
            elif keypts[21].sum() != 0:
                input_point = np.array(keypts[21:22])
                input_label = np.array([1])
            masks, _, _ = sam_predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=False,
            )
            hand_mask = masks[0]
            masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
            ref_pose = visualize_hand(keypts, masked_img)
        else:
            raise gr.Error("No hands detected in the reference image.")
    else:
        hand_mask = np.zeros_like(img[:, :, 0])
        ref_pose = np.zeros_like(img)

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    if not REF_POSE_MASK:
        heatmaps = torch.zeros_like(heatmaps)
        mask = torch.zeros_like(mask)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)

    return img, ref_pose, ref_cond

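# The conditioning layout assembled above: ref_cond stacks the 4 VAE latent
# channels, 42 keypoint heatmap channels (21 per hand), and 1 hand-mask
# channel along dim 1, i.e. 4 + 42 + 1 = 47 channels, matching the DiT's
# in_channels = latent_dim + n_keypoints + n_mask configured at load time.
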
def get_target_anno(target):
    if target is None:
        return (
            gr.State.update(value=None),
            gr.Image.update(value=None),
            gr.State.update(value=None),
            gr.State.update(value=None),
        )
    pose_img = target["composite"][..., :3]
    pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
    # detect keypoints
    mp_pose = hands.process(pose_img)
    target_keypts = np.zeros((42, 2))
    detected = np.array([0, 0])
    start_idx = 0
    if mp_pose.multi_hand_landmarks:
        # handedness is flipped because MediaPipe assumes the input image is mirrored
        for hand_landmarks, handedness in zip(
            mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
        ):
            # actually the right hand
            if handedness.classification[0].label == "Left":
                start_idx = 0
                detected[0] = 1
            # actually the left hand
            elif handedness.classification[0].label == "Right":
                start_idx = 21
                detected[1] = 1
            for i, landmark in enumerate(hand_landmarks.landmark):
                target_keypts[start_idx + i] = [
                    landmark.x * opts.image_size[1],
                    landmark.y * opts.image_size[0],
                ]

        target_pose = visualize_hand(target_keypts, pose_img)
        kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
        target_heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
                opts.latent_size,
                var=1.0,
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device="cuda",
        )[None, ...]
        target_cond = torch.cat(
            [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
        )
    else:
        raise gr.Error("No hands detected in the target image.")

    return pose_img, target_pose, target_cond, target_keypts


# def draw_grid(ref):
#     if ref is None or ref["composite"] is None:  # or len(ref["layers"])==0:
#         return ref

#     # if len(ref["layers"]) == 1:
#     #     need_draw = True
#     # # elif ref["composite"].shape[0] != size_memory[0] or ref["composite"].shape[1] != size_memory[1]:
#     # #     need_draw = True
#     # else:
#     #     need_draw = False

#     # size_memory = ref["composite"].shape[0], ref["composite"].shape[1]
#     # if not need_draw:
#     #     return size_memory, ref

#     h, w = ref["composite"].shape[:2]
#     grid_h, grid_w = h // 32, w // 32
#     # grid = np.zeros((h, w, 4), dtype=np.uint8)
#     for i in range(1, grid_h):
#         ref["composite"][i * 32, :, :3] = 255  # 0.5 * ref["composite"][i * 32, :, :3] +
#     for i in range(1, grid_w):
#         ref["composite"][:, i * 32, :3] = 255  # 0.5 * ref["composite"][:, i * 32, :3] +
#     # if len(ref["layers"]) == 1:
#     #     ref["layers"].append(grid)
#     # else:
#     #     ref["layers"][1] = grid
#     return ref["composite"]

def get_mask_inpaint(ref):
    inpaint_mask = np.array(ref["layers"][0])[..., -1]
    inpaint_mask = cv2.resize(
        inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
    )
    inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
    return inpaint_mask


def visualize_ref(crop, brush):
    if crop is None or brush is None:
        return None
    inpainted = brush["layers"][0][..., -1]
    img = crop["background"][..., :3]
    img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
    mask = inpainted < 128
    # img = img.astype(np.int32)
    # img[mask, :] = img[mask, :] - 50
    # img[np.any(img<0, axis=-1)]=0
    # img = img.astype(np.uint8)
    img = mask_image(img, mask)
    return img


def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
    if keypoints is None:
        keypoints = [[], []]
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 21:
            gr.Info("21 keypoints for the right hand are already selected. Try reset if something looks wrong.")
        else:
            keypoints[0].append(list(evt.index))
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 21:
            gr.Info("21 keypoints for the left hand are already selected. Try reset if something looks wrong.")
        else:
            keypoints[1].append(list(evt.index))
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def undo_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    kps = np.zeros((42, 2))
    if side == "right":
        if len(keypoints[0]) == 0:
            return img, keypoints
        keypoints[0].pop()
        len_kps = len(keypoints[0])
        kps[:len_kps] = np.array(keypoints[0])
    elif side == "left":
        if len(keypoints[1]) == 0:
            return img, keypoints
        keypoints[1].pop()
        len_kps = len(keypoints[1])
        kps[21 : 21 + len_kps] = np.array(keypoints[1])
    vis_hand = visualize_hand(kps, img, side, len_kps)
    return vis_hand, keypoints


def reset_kps(img, keypoints, side: Literal["right", "left"]):
    if keypoints is None:
        return img, None
    if side == "right":
        keypoints[0] = []
    elif side == "left":
        keypoints[1] = []
    return img, keypoints

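# Keypoint bookkeeping used by the three handlers above: `keypoints` is a pair
# of Python lists, keypoints[0] for the right hand and keypoints[1] for the
# left, each holding up to 21 (x, y) clicks; they are flattened into a fixed
# (42, 2) array where rows 0-20 are the right hand and rows 21-41 the left,
# with all-zero rows meaning "not selected".
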
def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
    set_seed(seed)
    z = torch.randn(
        (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
        device="cuda",
    )
    target_cond = target_cond.repeat(num_gen, 1, 1, 1)
    ref_cond = ref_cond.repeat(num_gen, 1, 1, 1)
    # novel view synthesis mode = off
    nvs = torch.zeros(num_gen, dtype=torch.int, device="cuda")
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
        ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg,
    )

    samples, _ = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device="cuda",
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    results = []
    results_pose = []
    for i in range(MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    return results, results_pose

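# sample_diff implements classifier-free guidance by batch duplication: the
# first half of the doubled batch gets the real conditions, the second half
# zeroed-out conditions (with nvs=2 marking the unconditional branch), and
# model.forward_with_cfg blends the two predictions using cfg_scale. A minimal
# sketch of the usual blend, assuming an epsilon-prediction model (the names
# here are illustrative, not necessarily FoundHand's internals):
#
#     cond_eps, uncond_eps = eps.chunk(2)
#     guided_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
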
def ready_sample(img_ori, inpaint_mask, keypts):
    img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
    sam_predictor.set_image(img)
    if len(keypts[0]) == 0:
        keypts[0] = np.zeros((21, 2))
    elif len(keypts[0]) == 21:
        keypts[0] = np.array(keypts[0], dtype=np.float32)
    else:
        gr.Info("The number of right-hand keypoints should be either 0 or 21.")
        return None, None

    if len(keypts[1]) == 0:
        keypts[1] = np.zeros((21, 2))
    elif len(keypts[1]) == 21:
        keypts[1] = np.array(keypts[1], dtype=np.float32)
    else:
        gr.Info("The number of left-hand keypoints should be either 0 or 21.")
        return None, None

    keypts = np.concatenate(keypts, axis=0)
    keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
    # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
    #     input_point = np.array([keypts[0], keypts[21]])
    #     # input_point = keypts
    #     input_label = np.array([1, 1])
    #     # input_label = np.ones_like(input_point[:, 0])
    # elif keypts[0].sum() != 0:
    #     input_point = np.array(keypts[:1])
    #     # input_point = keypts[:21]
    #     input_label = np.array([1])
    #     # input_label = np.ones_like(input_point[:21, 0])
    # elif keypts[21].sum() != 0:
    #     input_point = np.array(keypts[21:22])
    #     # input_point = keypts[21:]
    #     input_label = np.array([1])
    #     # input_label = np.ones_like(input_point[21:, 0])

    box_shift_ratio = 0.5
    box_size_factor = 1.2

    if keypts[0].sum() != 0 and keypts[21].sum() != 0:
        input_point = np.array(keypts)
        input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
    elif keypts[0].sum() != 0:
        input_point = np.array(keypts[:21])
        input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
    elif keypts[21].sum() != 0:
        input_point = np.array(keypts[21:])
        input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
    else:
        raise ValueError(
            "Something is wrong: if no hand was detected, this point should never be reached."
        )

    input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
    box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
    input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)

    masks, _, _ = sam_predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    hand_mask = masks[0]

    inpaint_latent_mask = torch.tensor(
        cv2.resize(
            inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
        ),
        dtype=torch.float,
        device="cuda",
    ).unsqueeze(0)[None, ...]

    def make_ref_cond(
        img,
        keypts,
        hand_mask,
        device="cuda",
        target_size=(256, 256),
        latent_size=(32, 32),
    ):
        image_transform = Compose(
            [
                ToTensor(),
                Resize(target_size),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
        image = image_transform(img).to(device)
        kpts_valid = check_keypoints_validity(keypts, target_size)
        heatmaps = torch.tensor(
            keypoint_heatmap(
                scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
            )
            * kpts_valid[:, None, None],
            dtype=torch.float,
            device=device,
        )[None, ...]
        mask = torch.tensor(
            cv2.resize(
                hand_mask.astype(int),
                dsize=latent_size,
                interpolation=cv2.INTER_NEAREST,
            ),
            dtype=torch.float,
            device=device,
        ).unsqueeze(0)[None, ...]
        return image[None, ...], heatmaps, mask

    image, heatmaps, mask = make_ref_cond(
        img,
        keypts,
        hand_mask * (1 - inpaint_mask),
        device="cuda",
        target_size=opts.image_size,
        latent_size=opts.latent_size,
    )
    latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
    target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
    ref_cond = torch.cat([latent, heatmaps, mask], 1)
    ref_cond = torch.zeros_like(ref_cond)

    img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
    assert mask.max() == 1
    vis_mask32 = mask_image(
        img32, inpaint_latent_mask[0, 0].cpu().numpy(), (255, 255, 255), transparent=False
    ).astype(np.uint8)  # 1.0 - mask[0, 0].cpu().numpy()

    assert np.unique(inpaint_mask).shape[0] <= 2
    assert hand_mask.dtype == bool
    mask256 = inpaint_mask  # hand_mask * (1 - inpaint_mask)
    vis_mask256 = mask_image(img, mask256, (255, 255, 255), transparent=False).astype(
        np.uint8
    )  # 1 - mask256

    return (
        ref_cond,
        target_cond,
        latent,
        inpaint_latent_mask,
        keypts,
        vis_mask32,
        vis_mask256,
    )


def switch_mask_size(radio):
    if radio == "256x256":
        out = (gr.update(visible=False), gr.update(visible=True))
    elif radio == "latent size (32x32)":
        out = (gr.update(visible=True), gr.update(visible=False))
    return out

def sample_inpaint(
    ref_cond,
    target_cond,
    latent,
    inpaint_latent_mask,
    keypts,
    num_gen,
    seed,
    cfg,
    quality,
):
    set_seed(seed)
    N = num_gen
    jump_length = 10
    jump_n_sample = quality
    cfg_scale = cfg
    z = torch.randn(
        (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device="cuda"
    )
    target_cond_N = target_cond.repeat(N, 1, 1, 1)
    ref_cond_N = ref_cond.repeat(N, 1, 1, 1)
    # novel view synthesis mode = off
    nvs = torch.zeros(N, dtype=torch.int, device="cuda")
    z = torch.cat([z, z], 0)
    model_kwargs = dict(
        target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
        ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
        nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
        cfg_scale=cfg_scale,
    )

    samples, _ = diffusion.inpaint_p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        latent,
        inpaint_latent_mask,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device="cuda",
        jump_length=jump_length,
        jump_n_sample=jump_n_sample,
    ).chunk(2)
    sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
    sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
    sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())

    # visualize
    results = []
    results_pose = []
    for i in range(FIX_MAX_N):
        if i < num_gen:
            results.append(sampled_images[i])
            results_pose.append(visualize_hand(keypts, sampled_images[i]))
        else:
            results.append(placeholder)
            results_pose.append(placeholder)
    return results, results_pose

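# The jump_length / jump_n_sample arguments passed to inpaint_p_sample_loop
# above read like a RePaint-style resampling schedule (periodically jumping
# back jump_length timesteps and re-denoising, repeated jump_n_sample times,
# here driven by the "Quality" slider). That reading of the parameter names is
# an assumption; the authoritative schedule lives in the bundled diffusion/
# modules (gaussian_diffusion.py, respace.py).
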
def flip_hand(
    img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None
):
    if cond is None:  # clear clicked
        return None, None, None, None
    img["composite"] = img["composite"][:, ::-1, :]
    img["background"] = img["background"][:, ::-1, :]
    img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
    pose_img = pose_img[:, ::-1, :]
    cond = cond.flip(-1)
    if keypts is not None:  # cond is target_cond
        if keypts[:21, :].sum() != 0:
            keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
            # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
        if keypts[21:, :].sum() != 0:
            keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
            # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
    return img, pose_img, cond, keypts


def resize_to_full(img):
    img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
    img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
    img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
    return img


def clear_all():
    return (
        None,
        None,
        False,
        None,
        None,
        False,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        42,
        3.0,
    )


def fix_clear_all():
    return (
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        1,
        # (0,0),
        42,
        3.0,
        10,
    )


def enable_component(image1, image2):
    if image1 is None or image2 is None:
        return gr.update(interactive=False)
    if "background" in image1 and "layers" in image1 and "composite" in image1:
        if (
            image1["background"].sum() == 0
            and (sum([im.sum() for im in image1["layers"]]) == 0)
            and image1["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    if "background" in image2 and "layers" in image2 and "composite" in image2:
        if (
            image2["background"].sum() == 0
            and (sum([im.sum() for im in image2["layers"]]) == 0)
            and image2["composite"].sum() == 0
        ):
            return gr.update(interactive=False)
    return gr.update(interactive=True)


def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left):
    if kpts is None:
        kpts = [[], []]
    if "Right hand" not in checkbox:
        kpts[0] = []
        vis_right = img_clean
        update_right = gr.update(visible=False)
        update_r_info = gr.update(visible=False)
    else:
        vis_right = img_pose_right
        update_right = gr.update(visible=True)
        update_r_info = gr.update(visible=True)

    if "Left hand" not in checkbox:
        kpts[1] = []
        vis_left = img_clean
        update_left = gr.update(visible=False)
        update_l_info = gr.update(visible=False)
    else:
        vis_left = img_pose_left
        update_left = gr.update(visible=True)
        update_l_info = gr.update(visible=True)

    return (
        kpts,
        vis_right,
        vis_left,
        update_right,
        update_right,
        update_right,
        update_left,
        update_left,
        update_left,
        update_r_info,
        update_l_info,
    )


# def parse_fix_example(ex_img, ex_masked):
#     original_img = ex_img
#     # ex_img = cv2.resize(ex_img, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
#     # ex_masked = cv2.resize(ex_masked, (LENGTH, LENGTH), interpolation=cv2.INTER_AREA)
#     inpaint_mask = np.all(ex_masked > 250, axis=-1).astype(np.uint8)
#     layer = np.ones_like(ex_img) * 255
#     layer = np.concatenate([layer, np.zeros_like(ex_img[..., 0:1])], axis=-1)
#     layer[inpaint_mask == 1, 3] = 255
#     ref_value = {
#         "composite": ex_masked,
#         "background": ex_img,
#         "layers": [layer],
#     }
#     inpaint_mask = cv2.resize(
#         inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
#     )
#     kp_img = visualize_ref(ref_value)
#     return (
#         original_img,
#         gr.update(value=ref_value),
#         kp_img,
#         inpaint_mask,
#     )

LENGTH = 480

example_imgs = [
    ["sample_images/sample1.jpg"],
    ["sample_images/sample2.jpg"],
    ["sample_images/sample3.jpg"],
    ["sample_images/sample4.jpg"],
    ["sample_images/sample5.jpg"],
    ["sample_images/sample6.jpg"],
    ["sample_images/sample7.jpg"],
    ["sample_images/sample8.jpg"],
    ["sample_images/sample9.jpg"],
    ["sample_images/sample10.jpg"],
    ["sample_images/sample11.jpg"],
    ["pose_images/pose1.jpg"],
    ["pose_images/pose2.jpg"],
    ["pose_images/pose3.jpg"],
    ["pose_images/pose4.jpg"],
    ["pose_images/pose5.jpg"],
    ["pose_images/pose6.jpg"],
    ["pose_images/pose7.jpg"],
    ["pose_images/pose8.jpg"],
]

fix_example_imgs = [
    ["bad_hands/1.jpg"],  # "bad_hands/1_mask.jpg"],
    ["bad_hands/2.jpg"],  # "bad_hands/2_mask.jpg"],
    ["bad_hands/3.jpg"],  # "bad_hands/3_mask.jpg"],
    ["bad_hands/4.jpg"],  # "bad_hands/4_mask.jpg"],
    ["bad_hands/5.jpg"],  # "bad_hands/5_mask.jpg"],
    ["bad_hands/6.jpg"],  # "bad_hands/6_mask.jpg"],
    ["bad_hands/7.jpg"],  # "bad_hands/7_mask.jpg"],
    ["bad_hands/8.jpg"],  # "bad_hands/8_mask.jpg"],
    ["bad_hands/9.jpg"],  # "bad_hands/9_mask.jpg"],
    ["bad_hands/10.jpg"],  # "bad_hands/10_mask.jpg"],
    ["bad_hands/11.jpg"],  # "bad_hands/11_mask.jpg"],
    ["bad_hands/12.jpg"],  # "bad_hands/12_mask.jpg"],
    ["bad_hands/13.jpg"],  # "bad_hands/13_mask.jpg"],
]
custom_css = """
.gradio-container .examples img {
    width: 240px !important;
    height: 240px !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Tab("Edit Hand Poses"):
        ref_img = gr.State(value=None)
        ref_cond = gr.State(value=None)
        keypts = gr.State(value=None)
        target_img = gr.State(value=None)
        target_cond = gr.State(value=None)
        target_keypts = gr.State(value=None)
        dump = gr.State(value=None)
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Reference</p>"""
                )
                gr.Markdown("""<p style="text-align: center;"><br></p>""")
                ref = gr.ImageEditor(
                    type="numpy",
                    label="Reference",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
                ref_pose = gr.Image(
                    type="numpy",
                    label="Reference Pose",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                )
                ref_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Reference)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold;">2. Target</p>"""
                )
                target = gr.ImageEditor(
                    type="numpy",
                    label="Target",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    brush=False,
                    layers=False,
                    crop_size="1:1",
                )
                target_finish_crop = gr.Button(
                    value="Finish Cropping", interactive=False
                )
                target_pose = gr.Image(
                    type="numpy",
                    label="Target Pose",
                    show_label=True,
                    height=LENGTH,
                    width=LENGTH,
                    interactive=False,
                )
                target_flip = gr.Checkbox(
                    value=False, label="Flip Handedness (Target)", interactive=False
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold;">3. Result</p>"""
                )
                gr.Markdown(
                    """<p style="text-align: center;">Run is enabled after the images have been processed</p>"""
                )
                run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">~20s per generation. <br>(For example, if you set Number of generations to 2, it would take around 40s)</p>"""
                )
                results = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                results_pose = gr.Gallery(
                    type="numpy",
                    label="Results Pose",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=MAX_N,
                    interactive=False,
                    preview=True,
                )
                clear = gr.ClearButton()

        with gr.Row():
            n_generation = gr.Slider(
                label="Number of generations",
                value=1,
                minimum=1,
                maximum=MAX_N,
                step=1,
                randomize=False,
                interactive=True,
            )
            seed = gr.Slider(
                label="Seed",
                value=42,
                minimum=0,
                maximum=10000,
                step=1,
                randomize=False,
                interactive=True,
            )
            cfg = gr.Slider(
                label="Classifier free guidance scale",
                value=2.5,
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                randomize=False,
                interactive=True,
            )

        ref.change(enable_component, [ref, ref], ref_finish_crop)
        ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
        ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
        ref_flip.select(
            flip_hand, [ref, ref_pose, ref_cond], [ref, ref_pose, ref_cond, dump]
        )
        target.change(enable_component, [target, target], target_finish_crop)
        target_finish_crop.click(
            get_target_anno,
            [target],
            [target_img, target_pose, target_cond, target_keypts],
        )
        target_pose.change(enable_component, [target_img, target_pose], target_flip)
        target_flip.select(
            flip_hand,
            [target, target_pose, target_cond, target_keypts],
            [target, target_pose, target_cond, target_keypts],
        )
        ref_pose.change(enable_component, [ref_pose, target_pose], run)
        target_pose.change(enable_component, [ref_pose, target_pose], run)
        run.click(
            sample_diff,
            [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
            [results, results_pose],
        )
        clear.click(
            clear_all,
            [],
            [
                ref,
                ref_pose,
                ref_flip,
                target,
                target_pose,
                target_flip,
                results,
                results_pose,
                ref_img,
                ref_cond,
                # mask,
                target_img,
                target_cond,
                target_keypts,
                n_generation,
                seed,
                cfg,
            ],
        )

1186 |
+
gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
|
1187 |
+
with gr.Tab("Reference"):
|
1188 |
+
with gr.Row():
|
1189 |
+
gr.Examples(example_imgs, [ref], examples_per_page=20)
|
1190 |
+
with gr.Tab("Target"):
|
1191 |
+
with gr.Row():
|
1192 |
+
gr.Examples(example_imgs, [target], examples_per_page=20)
|
1193 |
+
with gr.Tab("Fix Hands"):
|
1194 |
+
fix_inpaint_mask = gr.State(value=None)
|
1195 |
+
fix_original = gr.State(value=None)
|
1196 |
+
fix_img = gr.State(value=None)
|
1197 |
+
fix_kpts = gr.State(value=None)
|
1198 |
+
fix_kpts_np = gr.State(value=None)
|
1199 |
+
fix_ref_cond = gr.State(value=None)
|
1200 |
+
fix_target_cond = gr.State(value=None)
|
1201 |
+
fix_latent = gr.State(value=None)
|
1202 |
+
fix_inpaint_latent = gr.State(value=None)
|
1203 |
+
# fix_size_memory = gr.State(value=(0, 0))
|
1204 |
+
with gr.Row():
|
1205 |
+
with gr.Column():
|
1206 |
+
gr.Markdown(
|
1207 |
+
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>"""
|
1208 |
+
)
|
1209 |
+
gr.Markdown(
|
1210 |
+
"""<p style="text-align: center;">Crop the image around the hand.<br>Then, brush area (e.g., wrong finger) that needs to be fixed.</p>"""
|
1211 |
+
)
|
1212 |
+
gr.Markdown(
|
1213 |
+
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>"""
|
1214 |
+
)
|
1215 |
+
fix_crop = gr.ImageEditor(
|
1216 |
+
type="numpy",
|
1217 |
+
sources=["upload", "webcam", "clipboard"],
|
1218 |
+
label="Image crop",
|
1219 |
+
show_label=True,
|
1220 |
+
height=LENGTH,
|
1221 |
+
width=LENGTH,
|
1222 |
+
layers=False,
|
1223 |
+
crop_size="1:1",
|
1224 |
+
brush=False,
|
1225 |
+
image_mode="RGBA",
|
1226 |
+
container=False,
|
1227 |
+
)
|
1228 |
+
gr.Markdown(
|
1229 |
+
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>"""
|
1230 |
+
)
|
1231 |
+
fix_ref = gr.ImageEditor(
|
1232 |
+
type="numpy",
|
1233 |
+
label="Image brush",
|
1234 |
+
sources=(),
|
1235 |
+
show_label=True,
|
1236 |
+
height=LENGTH,
|
1237 |
+
width=LENGTH,
|
1238 |
+
layers=False,
|
1239 |
+
transforms=("brush"),
|
1240 |
+
brush=gr.Brush(
|
1241 |
+
colors=["rgb(255, 255, 255)"], default_size=20
|
1242 |
+
), # 204, 50, 50
|
1243 |
+
image_mode="RGBA",
|
1244 |
+
container=False,
|
1245 |
+
interactive=False,
|
1246 |
+
)
|
1247 |
+
fix_finish_crop = gr.Button(
|
1248 |
+
value="Finish Croping & Brushing", interactive=False
|
1249 |
+
)
|
1250 |
+
gr.Markdown(
|
1251 |
+
"""<p style="text-align: left; font-size: 20px; font-weight: bold; ">OpenPose keypoints convention</p>"""
|
1252 |
+
)
|
1253 |
+
fix_openpose = gr.Image(
|
1254 |
+
value="openpose.png",
|
1255 |
+
type="numpy",
|
1256 |
+
label="OpenPose keypoints convention",
|
1257 |
+
show_label=True,
|
1258 |
+
height=LENGTH // 3 * 2,
|
1259 |
+
width=LENGTH // 3 * 2,
|
1260 |
+
interactive=False,
|
1261 |
+
)
|
1262 |
+
with gr.Column():
|
1263 |
+
gr.Markdown(
|
1264 |
+
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>"""
|
1265 |
+
)
|
1266 |
+
gr.Markdown(
|
1267 |
+
"""<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\" on the bottom left.</p>"""
|
1268 |
+
)
|
1269 |
+
fix_checkbox = gr.CheckboxGroup(
|
1270 |
+
["Right hand", "Left hand"],
|
1271 |
+
# value=["Right hand", "Left hand"],
|
1272 |
+
label="Hand side",
|
1273 |
+
info="Which side this hand is? Could be both.",
|
1274 |
+
interactive=False,
|
1275 |
+
)
|
1276 |
+
fix_kp_r_info = gr.Markdown(
|
1277 |
+
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
|
1278 |
+
visible=False,
|
1279 |
+
)
|
1280 |
+
fix_kp_right = gr.Image(
|
1281 |
+
type="numpy",
|
1282 |
+
label="Keypoint Selection (right hand)",
|
1283 |
+
show_label=True,
|
1284 |
+
height=LENGTH,
|
1285 |
+
width=LENGTH,
|
1286 |
+
interactive=False,
|
1287 |
+
visible=False,
|
1288 |
+
sources=[],
|
1289 |
+
)
|
1290 |
+
with gr.Row():
|
1291 |
+
fix_undo_right = gr.Button(
|
1292 |
+
value="Undo", interactive=False, visible=False
|
1293 |
+
)
|
1294 |
+
fix_reset_right = gr.Button(
|
1295 |
+
value="Reset", interactive=False, visible=False
|
1296 |
+
)
|
1297 |
+
fix_kp_l_info = gr.Markdown(
|
1298 |
+
"""<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
|
1299 |
+
visible=False
|
1300 |
+
)
|
1301 |
+
fix_kp_left = gr.Image(
|
1302 |
+
type="numpy",
|
1303 |
+
label="Keypoint Selection (left hand)",
|
1304 |
+
show_label=True,
|
1305 |
+
height=LENGTH,
|
1306 |
+
width=LENGTH,
|
1307 |
+
interactive=False,
|
1308 |
+
visible=False,
|
1309 |
+
sources=[],
|
1310 |
+
)
|
1311 |
+
with gr.Row():
|
1312 |
+
fix_undo_left = gr.Button(
|
1313 |
+
value="Undo", interactive=False, visible=False
|
1314 |
+
)
|
1315 |
+
fix_reset_left = gr.Button(
|
1316 |
+
value="Reset", interactive=False, visible=False
|
1317 |
+
)
|
1318 |
+
with gr.Column():
|
1319 |
+
gr.Markdown(
|
1320 |
+
"""<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>"""
|
1321 |
+
)
|
1322 |
+
gr.Markdown(
|
1323 |
+
"""<p style="text-align: center;">In Fix Hands, not segmentation mask, but only inpaint mask is used.</p>"""
|
1324 |
+
)
                fix_ready = gr.Button(value="Ready", interactive=False)
                fix_mask_size = gr.Radio(
                    ["256x256", "latent size (32x32)"],
                    label="Visualized inpaint mask size",
                    interactive=False,
                    value="256x256",
                )
                gr.Markdown(
                    """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Visualized inpaint masks</p>"""
                )
                fix_vis_mask32 = gr.Image(
                    type="numpy",
                    label=f"Visualized {opts.latent_size} Inpaint Mask",
                    show_label=True,
                    height=opts.latent_size,
                    width=opts.latent_size,
                    interactive=False,
                    visible=False,
                )
                fix_vis_mask256 = gr.Image(
                    type="numpy",
                    label=f"Visualized {opts.image_size} Inpaint Mask",
                    visible=True,
                    show_label=True,
                    height=opts.image_size,
                    width=opts.image_size,
                    interactive=False,
                )
            with gr.Column():
                gr.Markdown(
                    """<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>"""
                )
                fix_run = gr.Button(value="Run", interactive=False)
                gr.Markdown(
                    """<p style="text-align: center;">&gt;3 min and ~24 GB per generation</p>"""
                )
                fix_result = gr.Gallery(
                    type="numpy",
                    label="Results",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=FIX_MAX_N,
                    interactive=False,
                    preview=True,
                )
                fix_result_pose = gr.Gallery(
                    type="numpy",
                    label="Results Pose",
                    show_label=True,
                    height=LENGTH,
                    min_width=LENGTH,
                    columns=FIX_MAX_N,
                    interactive=False,
                    preview=True,
                )
                fix_clear = gr.ClearButton()
                gr.Markdown(
                    "[NOTE] Currently, setting the number of generations above 1 can lead to out-of-memory errors."
                )
                with gr.Row():
                    fix_n_generation = gr.Slider(
                        label="Number of generations",
                        value=1,
                        minimum=1,
                        maximum=FIX_MAX_N,
                        step=1,
                        randomize=False,
                        interactive=True,
                    )
                    fix_seed = gr.Slider(
                        label="Seed",
                        value=42,
                        minimum=0,
                        maximum=10000,
                        step=1,
                        randomize=False,
                        interactive=True,
                    )
                    fix_cfg = gr.Slider(
                        label="Classifier free guidance scale",
                        value=3.0,
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        randomize=False,
                        interactive=True,
                    )
                    fix_quality = gr.Slider(
                        label="Quality",
                        value=10,
                        minimum=1,
                        maximum=10,
                        step=1,
                        randomize=False,
                        interactive=True,
                    )
        fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
        fix_crop.change(resize_to_full, fix_crop, fix_ref)
        fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
        fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
        # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
        # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
        fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
        fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
        fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
        )
        fix_inpaint_mask.change(
            enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
        )
        # fix_inpaint_mask.change(
        #     enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
        # )
        fix_checkbox.select(
            set_visible,
            [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
            [
                fix_kpts,
                fix_kp_right,
                fix_kp_left,
                fix_kp_right,
                fix_undo_right,
                fix_reset_right,
                fix_kp_left,
                fix_undo_left,
                fix_reset_left,
                fix_kp_r_info,
                fix_kp_l_info,
            ],
        )
        fix_kp_right.select(
            get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_undo_right.click(
            undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_reset_right.click(
            reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
        )
        fix_kp_left.select(
            get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_undo_left.click(
            undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        fix_reset_left.click(
            reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
        )
        # fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
        # fix_run.click(lambda x: gr.update(value=None), [], [fix_result, fix_result_pose])
        fix_vis_mask32.change(
            enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
        )
        fix_vis_mask32.change(
            enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size
        )
        fix_ready.click(
            ready_sample,
            [fix_original, fix_inpaint_mask, fix_kpts],
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_vis_mask32,
                fix_vis_mask256,
            ],
        )
        fix_mask_size.select(
            switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256]
        )
        fix_run.click(
            sample_inpaint,
            [
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_kpts_np,
                fix_n_generation,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
            [fix_result, fix_result_pose],
        )
        fix_clear.click(
            fix_clear_all,
            [],
            [
                fix_crop,
                fix_ref,
                fix_kp_right,
                fix_kp_left,
                fix_result,
                fix_result_pose,
                fix_inpaint_mask,
                fix_original,
                fix_img,
                fix_vis_mask32,
                fix_vis_mask256,
                fix_kpts,
                fix_kpts_np,
                fix_ref_cond,
                fix_target_cond,
                fix_latent,
                fix_inpaint_latent,
                fix_n_generation,
                # fix_size_memory,
                fix_seed,
                fix_cfg,
                fix_quality,
            ],
        )

        gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
        fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False)
        fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False)
        with gr.Column():
            fix_example = gr.Examples(
                fix_example_imgs,
                # run_on_click=True,
                # fn=parse_fix_example,
                # inputs=[fix_dump_ex, fix_dump_ex_masked],
                # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
                inputs=[fix_crop],
                examples_per_page=20,
            )


print("Ready to launch..")
_, _, shared_url = demo.queue().launch(
    share=True, server_name="0.0.0.0", server_port=7739
)
demo.block()
diffusion/__init__.py
ADDED
@@ -0,0 +1,46 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from . import gaussian_diffusion as gd
from .respace import SpacedDiffusion, space_timesteps


def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
        ),
        model_var_type=(
            (
                gd.ModelVarType.FIXED_LARGE
                if not sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )
            if not learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),
        loss_type=loss_type
        # rescale_timesteps=rescale_timesteps,
    )
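A minimal usage sketch (an illustration, not part of the committed files): passing a respacing string to create_diffusion yields a SpacedDiffusion whose sampling loop runs over the reduced set of timesteps.

from diffusion import create_diffusion

# 250 evenly respaced steps out of the 1000 training steps; learn_sigma=True
# by default, matching the LEARNED_RANGE branch above.
diffusion = create_diffusion(timestep_respacing="250")
assert diffusion.num_timesteps == 250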
diffusion/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (986 Bytes).
diffusion/__pycache__/diffusion_utils.cpython-38.pyc
ADDED
Binary file (2.86 kB).
diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc
ADDED
Binary file (27.6 kB).
diffusion/__pycache__/respace.cpython-38.pyc
ADDED
Binary file (5.04 kB).
diffusion/__pycache__/scheduler.cpython-38.pyc
ADDED
Binary file (3.99 kB).
diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,88 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import torch as th
import numpy as np


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.
    :param x: the targets
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
    return log_probs


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.
    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
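A small sanity check (an illustration, not part of the committed files): the KL between two identical Gaussians is zero, and scalar arguments broadcast against tensors exactly as the normal_kl docstring promises.

import torch as th
from diffusion.diffusion_utils import normal_kl

# KL(N(0, 1) || N(0, 1)) = 0 elementwise; 0.0 scalars broadcast to tensors.
kl = normal_kl(th.zeros(4), th.zeros(4), 0.0, 0.0)
assert th.allclose(kl, th.zeros(4))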
diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1118 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py


import math

import numpy as np
import torch as th
import enum
from collections import defaultdict
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
from .scheduler import get_schedule_jump


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    PREVIOUS_X = enum.auto()  # the model predicts x_{t-1}
    START_X = enum.auto()  # the model predicts x_0
    EPSILON = enum.auto()  # the model predicts epsilon


class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.
    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    LEARNED = enum.auto()
    FIXED_SMALL = enum.auto()
    FIXED_LARGE = enum.auto()
    LEARNED_RANGE = enum.auto()


class LossType(enum.Enum):
    MSE = enum.auto()  # use raw MSE loss (and KL when learning variances)
    RESCALED_MSE = (
        enum.auto()
    )  # use raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()  # use the variational lower-bound
    RESCALED_KL = enum.auto()  # like KL, but rescale to estimate the full VLB

    def is_vb(self):
        return self == LossType.KL or self == LossType.RESCALED_KL


def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
    betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    warmup_time = int(num_diffusion_timesteps * warmup_frac)
    betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
    return betas


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.
    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "warmup10":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
    elif beta_schedule == "warmup50":
        betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.
    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "squaredcos_cap_v2":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.
    Original ported from this codebase:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type
    ):

        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        ) if len(self.posterior_variance) > 1 else np.array([])

        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0).
        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
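
    # Note (an illustrative gloss, not in the upstream code): q_sample
    # implements the closed-form marginal
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
    # so a single call jumps from x_0 to any noise level t without simulating
    # the chain step by step.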

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:
            q(x_{t-1} | x_t, x_0)
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.
        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
            max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
            # The model_var_values is [-1, 1] for [min_var, max_var].
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.START_X:
            pred_xstart = process_xstart(model_output)
        else:
            pred_xstart = process_xstart(
                self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
            )
        model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
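
    # Note (an illustrative gloss, not in the upstream code): the two helpers
    # above invert the q_sample parameterization. Rearranging
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    # gives
    #     x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps,
    # which is exactly _predict_xstart_from_eps; solving for eps instead
    # yields _predict_eps_from_xstart.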

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **model_kwargs)
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.
        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
        ):
            final = sample
        return final["sample"]

    def inpaint_p_sample_loop(
        self,
        model,
        shape,
        x0,
        mask,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        jump_length=10,
        jump_n_sample=10,
    ):
        """
        Generate samples from the model.
        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.inpaint_p_sample_loop_progressive(
            model,
            shape,
            x0,
            mask,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            jump_length=jump_length,
            jump_n_sample=jump_n_sample,
        ):
            final = sample
        return final["sample"]

    def inpaint_p_sample_loop_progressive(
        self,
        model,
        shape,
        x0,
        mask,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        jump_length=10,
        jump_n_sample=10,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        # if device is None:
        #     device = next(model.parameters()).device
        # assert isinstance(shape, (tuple, list))
        # if noise is not None:
        #     img = noise
        # else:
        #     img = th.randn(*shape, device=device)
        # indices = list(range(self.num_timesteps))[::-1]

        # if progress:
        #     # Lazy import so that we don't depend on tqdm.
        #     from tqdm.auto import tqdm

        #     indices = tqdm(indices)
        # pred_xstart = None
        # for i in indices:
        #     t = th.tensor([i] * shape[0], device=device)
        #     with th.no_grad():
        #         out = self.inpaint_p_sample(
        #             model,
        #             img,
        #             t,
        #             x0,
        #             mask,
        #             clip_denoised=clip_denoised,
        #             denoised_fn=denoised_fn,
        #             cond_fn=cond_fn,
        #             model_kwargs=model_kwargs,
        #             pred_xstart=pred_xstart,
        #         )
        #         yield out
        #         img = out["sample"]
        #         pred_xstart = out["pred_xstart"]

        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            image_after_step = noise
        else:
            image_after_step = th.randn(*shape, device=device)

        self.gt_noises = None  # reset for next image

        pred_xstart = None

        idx_wall = -1
        sample_idxs = defaultdict(lambda: 0)

        times = get_schedule_jump(t_T=250, n_sample=1, jump_length=jump_length, jump_n_sample=jump_n_sample)
        time_pairs = list(zip(times[:-1], times[1:]))

        if progress:
            from tqdm.auto import tqdm
            time_pairs = tqdm(time_pairs)

        for t_last, t_cur in time_pairs:
            idx_wall += 1
            t_last_t = th.tensor([t_last] * shape[0],  # pylint: disable=not-callable
                                 device=device)

            if t_cur < t_last:  # reverse
                with th.no_grad():
                    image_before_step = image_after_step.clone()
                    out = self.inpaint_p_sample(
                        model,
                        image_after_step,
                        t_last_t,
                        x0,
                        mask,
                        clip_denoised=clip_denoised,
                        denoised_fn=denoised_fn,
                        cond_fn=cond_fn,
                        model_kwargs=model_kwargs,
                        pred_xstart=pred_xstart
                    )
                    image_after_step = out["sample"]
                    pred_xstart = out["pred_xstart"]

                    sample_idxs[t_cur] += 1

                    yield out

            else:
                t_shift = 1
                image_before_step = image_after_step.clone()
                image_after_step = self.undo(
                    image_before_step, image_after_step,
                    est_x_0=out['pred_xstart'], t=t_last_t + t_shift, debug=False)
                pred_xstart = out["pred_xstart"]

    def inpaint_p_sample(
        self,
        model,
        x,
        t,
        x0,
        mask,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        pred_xstart=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.
        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        noise = th.randn_like(x)

        if pred_xstart is not None:
            alpha_cumprod = _extract_into_tensor(
                self.alphas_cumprod, t, x.shape)
            weighed_gt = th.sqrt(alpha_cumprod) * x0 + th.sqrt((1 - alpha_cumprod)) * th.randn_like(x)

            x = (1 - mask) * weighed_gt + mask * x

        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )

        if cond_fn is not None:
            out["mean"] = self.condition_mean(
                cond_fn, out, x, t, model_kwargs=model_kwargs
            )

        sample = out["mean"] + nonzero_mask * \
            th.exp(0.5 * out["log_variance"]) * noise

        result = {"sample": sample,
                  "pred_xstart": out["pred_xstart"], 'gt': model_kwargs.get('gt')}

        return result

    def undo(self, image_before_step, img_after_model, est_x_0, t, debug=False):
        return self._undo(img_after_model, t)

    def _undo(self, img_out, t):
        beta = _extract_into_tensor(self.betas, t, img_out.shape)

        img_in_est = th.sqrt(1 - beta) * img_out + \
            th.sqrt(beta) * th.randn_like(img_out)

        return img_in_est
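
    # Note (an illustrative gloss, not in the upstream code): undo()/_undo()
    # are the "jump back" half of RePaint-style resampling. One forward
    # diffusion step re-noises a partially denoised image,
    #     x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps,
    # after which the loop above denoises again, letting the known (unmasked)
    # region and the inpainted (masked) region blend over repeated passes.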

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.
        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield out
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.
        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Generate samples from the model using DDIM.
        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.
        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.
        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, t, **model_kwargs)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape
            terms["mse"] = mean_flat((target - model_output) ** 2)
            if "vb" in terms:
                terms["loss"] = terms["mse"] + terms["vb"]
            else:
                terms["loss"] = terms["mse"]
        else:
            raise NotImplementedError(self.loss_type)

        return terms
|
1033 |
+
|
    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.
        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.
    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + th.zeros(broadcast_shape, device=timesteps.device)
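
A brief usage sketch for the two entry points above (a sketch, not code from the commit): training_losses drives one optimization step, and calc_bpd_loop evaluates the full bound. It assumes a diffusion process object and a model built elsewhere (e.g. in app.py); the helper names diffusion_train_step and eval_bpd are hypothetical.

import torch as th

def diffusion_train_step(diffusion, model, x_start, optimizer):
    # Uniformly sampled timesteps; see timestep_sampler.py for importance sampling.
    t = th.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=x_start.device)
    loss = diffusion.training_losses(model, x_start, t)["loss"].mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def eval_bpd(diffusion, model, x_start):
    # Full variational bound in bits-per-dim; runs one model call per timestep.
    with th.no_grad():
        return diffusion.calc_bpd_loop(model, x_start)["total_bpd"].mean().item()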
diffusion/respace.py
ADDED
@@ -0,0 +1,129 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            # Report the requested count, not the original step count, so the
            # error message names the value that could not be satisfied.
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(
        self, model, *args, **kwargs
    ):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)

    def _wrap_model(self, model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(
            model, self.timestep_map, self.original_num_steps
        )

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        # self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        # if self.rescale_timesteps:
        #     new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
        return self.model(x, new_ts, **kwargs)
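
For orientation, a hedged sketch of how these pieces combine; the linear beta schedule and the mean/variance/loss settings below are illustrative assumptions, not necessarily the configuration app.py uses.

import numpy as np
from diffusion.respace import SpacedDiffusion, space_timesteps
from diffusion.gaussian_diffusion import ModelMeanType, ModelVarType, LossType

betas = np.linspace(1e-4, 0.02, 1000)               # assumed linear schedule
spaced = SpacedDiffusion(
    use_timesteps=space_timesteps(1000, "ddim50"),  # keep 50 DDIM-strided steps
    betas=betas,
    model_mean_type=ModelMeanType.EPSILON,          # assumed settings
    model_var_type=ModelVarType.LEARNED_RANGE,
    loss_type=LossType.MSE,
)
# Sampling then runs over indices 0..49; _WrappedModel maps them back to the
# original 0..999 timesteps via timestep_map.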
diffusion/scheduler.py
ADDED
@@ -0,0 +1,224 @@
# Copyright (c) 2022 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license

def get_schedule(t_T, t_0, n_sample, n_steplength, debug=0):
    if n_steplength > 1:
        if not n_sample > 1:
            raise RuntimeError('n_steplength has no effect if n_sample=1')

    t = t_T
    times = [t]
    while t >= 0:
        t = t - 1
        times.append(t)
        n_steplength_cur = min(n_steplength, t_T - t)

        for _ in range(n_sample - 1):
            for _ in range(n_steplength_cur):
                t = t + 1
                times.append(t)
            for _ in range(n_steplength_cur):
                t = t - 1
                times.append(t)

    _check_times(times, t_0, t_T)

    if debug == 2:
        for x in [list(range(0, 50)), list(range(-1, -50, -1))]:
            _plot_times(x=x, times=[times[i] for i in x])

    return times


def _check_times(times, t_0, t_T):
    # Check end
    assert times[0] > times[1], (times[0], times[1])

    # Check beginning
    assert times[-1] == -1, times[-1]

    # Steplength = 1
    for t_last, t_cur in zip(times[:-1], times[1:]):
        assert abs(t_last - t_cur) == 1, (t_last, t_cur)

    # Value range
    for t in times:
        assert t >= t_0, (t, t_0)
        assert t <= t_T, (t, t_T)


def _plot_times(x, times):
    import matplotlib.pyplot as plt
    plt.plot(x, times)
    plt.show()


def get_schedule_jump(t_T, n_sample, jump_length, jump_n_sample,
                      jump2_length=1, jump2_n_sample=1,
                      jump3_length=1, jump3_n_sample=1,
                      start_resampling=100000000):

    jumps = {}
    for j in range(0, t_T - jump_length, jump_length):
        jumps[j] = jump_n_sample - 1

    jumps2 = {}
    for j in range(0, t_T - jump2_length, jump2_length):
        jumps2[j] = jump2_n_sample - 1

    jumps3 = {}
    for j in range(0, t_T - jump3_length, jump3_length):
        jumps3[j] = jump3_n_sample - 1

    t = t_T
    ts = []

    while t >= 1:
        t = t-1
        ts.append(t)

        if (
            t + 1 < t_T - 1 and
            t <= start_resampling
        ):
            for _ in range(n_sample - 1):
                t = t + 1
                ts.append(t)

                if t >= 0:
                    t = t - 1
                    ts.append(t)

        if (
            jumps3.get(t, 0) > 0 and
            t <= start_resampling - jump3_length
        ):
            jumps3[t] = jumps3[t] - 1
            for _ in range(jump3_length):
                t = t + 1
                ts.append(t)

        if (
            jumps2.get(t, 0) > 0 and
            t <= start_resampling - jump2_length
        ):
            jumps2[t] = jumps2[t] - 1
            for _ in range(jump2_length):
                t = t + 1
                ts.append(t)
            jumps3 = {}
            for j in range(0, t_T - jump3_length, jump3_length):
                jumps3[j] = jump3_n_sample - 1

        if (
            jumps.get(t, 0) > 0 and
            t <= start_resampling - jump_length
        ):
            jumps[t] = jumps[t] - 1
            for _ in range(jump_length):
                t = t + 1
                ts.append(t)
            jumps2 = {}
            for j in range(0, t_T - jump2_length, jump2_length):
                jumps2[j] = jump2_n_sample - 1

            jumps3 = {}
            for j in range(0, t_T - jump3_length, jump3_length):
                jumps3[j] = jump3_n_sample - 1

    ts.append(-1)

    _check_times(ts, -1, t_T)

    return ts


def get_schedule_jump_paper():
    t_T = 250
    jump_length = 10
    jump_n_sample = 10

    jumps = {}
    for j in range(0, t_T - jump_length, jump_length):
        jumps[j] = jump_n_sample - 1

    t = t_T
    ts = []

    while t >= 1:
        t = t-1
        ts.append(t)

        if jumps.get(t, 0) > 0:
            jumps[t] = jumps[t] - 1
            for _ in range(jump_length):
                t = t + 1
                ts.append(t)

    ts.append(-1)

    _check_times(ts, -1, t_T)

    return ts


def get_schedule_jump_test(to_supplement=False):
    ts = get_schedule_jump(t_T=250, n_sample=1,
                           jump_length=10, jump_n_sample=10,
                           jump2_length=1, jump2_n_sample=1,
                           jump3_length=1, jump3_n_sample=1,
                           start_resampling=250)

    import matplotlib.pyplot as plt
    SMALL_SIZE = 8*3
    MEDIUM_SIZE = 10*3
    BIGGER_SIZE = 12*3

    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

    plt.plot(ts)

    fig = plt.gcf()
    fig.set_size_inches(20, 10)

    ax = plt.gca()
    ax.set_xlabel('Number of Transitions')
    ax.set_ylabel('Diffusion time $t$')

    fig.tight_layout()

    if to_supplement:
        out_path = "/cluster/home/alugmayr/gdiff/paper/supplement/figures/jump_sched.pdf"
        plt.savefig(out_path)

    out_path = "./schedule.png"
    plt.savefig(out_path)
    print(out_path)


def main():
    get_schedule_jump_test()


if __name__ == "__main__":
    main()
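
The lists returned by get_schedule_jump are meant to be consumed pairwise: a falling pair is one reverse (denoising) transition and a rising pair re-noises forward, which is how RePaint-style resampling walks the chain. A minimal consumption sketch, with denoise_step and renoise_step as hypothetical callbacks:

def walk_schedule(ts, denoise_step, renoise_step):
    # e.g. ts = get_schedule_jump(t_T=250, n_sample=1, jump_length=10, jump_n_sample=10)
    for t_last, t_cur in zip(ts[:-1], ts[1:]):
        if t_cur < t_last:
            denoise_step(t_cur)   # one reverse transition t_last -> t_cur
        else:
            renoise_step(t_cur)   # jump forward by adding noise back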
diffusion/timestep_sampler.py
ADDED
@@ -0,0 +1,150 @@
# Modified from OpenAI's diffusion repos
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.
    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.
    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.
        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.
        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.
        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.
        Sub-classes should override this method to update the reweighting
        using losses from the model.
        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.
        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # np.int was removed in NumPy 1.24 (the version pinned in
        # requirements.txt), so use the builtin int dtype instead.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()
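
A sketch of how a sampler typically plugs into the training loss (assumed wiring, mirroring OpenAI's guided-diffusion training loop rather than code from this commit); note that LossAwareSampler.update_with_local_losses additionally requires an initialized torch.distributed process group.

from diffusion.timestep_sampler import create_named_schedule_sampler, LossAwareSampler

def weighted_loss(diffusion, model, x_start, sampler):
    # sampler = create_named_schedule_sampler("uniform", diffusion)
    t, weights = sampler.sample(x_start.shape[0], x_start.device)
    losses = diffusion.training_losses(model, x_start, t)
    if isinstance(sampler, LossAwareSampler):
        sampler.update_with_local_losses(t, losses["loss"].detach())
    # The importance weights keep the expected objective unchanged.
    return (losses["loss"] * weights).mean()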
packages.txt
ADDED
@@ -0,0 +1 @@
python3-opencv
requirements.txt
ADDED
@@ -0,0 +1,223 @@
absl-py==2.1.0
accelerate==0.34.2
aiofiles==23.2.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
albumentations==0.5.2
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.4.0
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
boto==2.49.0
boto3==1.28.57
botocore==1.34.131
cachetools==5.5.0
certifi==2022.12.7
cffi==1.16.0
chardet==5.2.0
charset-normalizer==2.1.1
click==8.1.7
click-default-group==1.2.4
clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
cmake==3.30.3
colorlog==6.8.2
commonmark==0.9.1
contourpy==1.1.1
cycler==0.12.1
decord==0.6.0
deepspeed==0.15.1
diffusers==0.25.0
docker-pycreds==0.4.0
ego4d==1.3.2
einops==0.8.0
embreex==2.17.7.post5
envlight @ git+https://github.com/ashawkey/envlight.git@05b5851e854429d72ecaf5b206ed64ce55fae677
exceptiongroup==1.2.2
fastapi==0.112.0
ffmpy==0.4.0
filelock==3.13.1
flatbuffers==24.3.25
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.2.0
ftfy==6.2.3
gast==0.4.0
gdown==5.2.0
gevent==23.9.1
gevent-websocket==0.10.1
gitdb==4.0.11
GitPython==3.1.43
google-auth==2.35.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
gradio==4.40.0
gradio_client==1.2.0
greenlet==2.0.2
grpcio==1.66.1
h11==0.14.0
h5py==3.11.0
hjson==3.1.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.24.5
idna==3.4
imageio==2.34.2
imageio-ffmpeg==0.5.1
imgaug==0.4.0
importlib_metadata==8.2.0
importlib_resources==6.4.0
jax==0.4.13
jaxlib==0.4.13
jaxtyping==0.2.19
Jinja2==3.1.3
jmespath==1.0.1
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
keras==2.13.1
kiwisolver==1.4.5
kornia==0.7.3
kornia_rs==0.1.5
lazy_loader==0.4
libclang==18.1.1
libigl==2.5.1
lightning-utilities==0.11.8
lit==18.1.8
lxml==5.3.0
manifold3d==2.5.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.5
mdurl==0.1.2
mediapipe==0.10.11
ml-dtypes==0.2.0
mpmath==1.3.0
multidict==6.1.0
mypy-extensions==1.0.0
nerfacc @ git+https://github.com/KAIR-BAIR/nerfacc.git@d84cdf3afd7dcfc42150e0f0506db58a5ce62812
networkx==3.0
ninja==1.11.1.1
numpy==1.24.1
nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@729261dc64c4241ea36efda84fbf532cc8b425b8
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauthlib==3.2.2
omegaconf==2.3.0
open-clip-torch==2.7.0
opencv-contrib-python==4.10.0.84
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
opt-einsum==3.3.0
orjson==3.10.6
packaging==24.1
pandas==2.0.3
pillow==10.2.0
pkgutil_resolve_name==1.3.10
platformdirs==4.3.6
prometheus-client==0.13.1
propcache==0.2.0
protobuf==3.20.3
psutil==6.0.0
py-cpuinfo==9.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycollada==0.8
pycparser==2.22
pydantic==2.8.2
pydantic_core==2.20.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
pyre-extensions==0.0.29
pysdf==0.1.9
PySocks==1.7.1
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-lightning==2.1.0
pytz==2024.1
PyWavelets==1.4.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
requests-oauthlib==2.0.0
rich==13.7.1
rich-click==1.6.1
rpds-py==0.20.0
rsa==4.9
Rtree==1.3.0
ruff==0.5.6
s3transfer==0.7.0
safetensors==0.4.3
scikit-image==0.21.0
scipy==1.10.1
segment-anything==1.0
semantic-version==2.10.0
sentencepiece==0.1.99
sentry-sdk==2.17.0
setproctitle==1.3.3
sh==1.14.3
shapely==2.0.6
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
sounddevice==0.4.7
soupsieve==2.6
starlette==0.37.2
svg.path==6.3
sympy==1.12
taming-transformers-rom1504==0.0.6
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorflow==2.13.1
tensorflow-estimator==2.13.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.4.0
tifffile==2023.7.10
timm==0.9.12
tinycudann @ git+https://github.com/NVlabs/tiny-cuda-nn/@c91138bcd4c6877c8d5e60e483c0581aafc70cce#subdirectory=bindings/torch
tokenizers==0.20.0
tomlkit==0.12.0
torch==2.0.1+cu118
torchaudio==2.0.2+cu118
torchmetrics==1.5.0
torchvision==0.15.2+cu118
tqdm==4.66.4
transformers==4.45.1
trimesh==4.5.0
triton==2.0.0
typeguard==4.3.0
typer==0.12.3
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.3
uvicorn==0.30.5
wandb==0.18.5
wcwidth==0.2.13
websockets==10.4
Werkzeug==3.0.4
wrapt==1.16.0
xatlas==0.0.9
xformers==0.0.20
xmltodict==0.12.0
xxhash==3.5.0
yarl==1.15.2
zipp==3.19.2
zope.event==5.0
zope.interface==6.0
segment_hoi.py
ADDED
@@ -0,0 +1,111 @@
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))


def merge_bounding_boxes(bbox1, bbox2):
    xmin1, ymin1, xmax1, ymax1 = bbox1
    xmin2, ymin2, xmax2, ymax2 = bbox2

    xmin_merged = min(xmin1, xmin2)
    ymin_merged = min(ymin1, ymin2)
    xmax_merged = max(xmax1, xmax2)
    ymax_merged = max(ymax1, ymax2)

    return np.array([xmin_merged, ymin_merged, xmax_merged, ymax_merged])


def init_sam(
    device="cuda",
    ckpt_path='/users/kchen157/scratch/weights/SAM/sam_vit_h_4b8939.pth'
):
    sam = sam_model_registry['vit_h'](checkpoint=ckpt_path)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    return predictor


def segment_hand_and_object(
        predictor,
        image,
        hand_kpts,
        hand_mask=None,
        box_shift_ratio=0.3,
        box_size_factor=2.,
        area_threshold=0.2,
        overlap_threshold=200):
    # Find bounding box for HOI
    input_box = {}
    for hand_type in ['right', 'left']:
        if hand_type not in hand_kpts:
            continue
        input_box[hand_type] = np.stack([hand_kpts[hand_type].min(axis=0), hand_kpts[hand_type].max(axis=0)])
        box_trans = input_box[hand_type][0] * box_shift_ratio + input_box[hand_type][1] * (1 - box_shift_ratio)
        input_box[hand_type] = ((input_box[hand_type] - box_trans) * box_size_factor + box_trans).reshape(-1)

    if len(input_box) == 2:
        input_box = merge_bounding_boxes(input_box['right'], input_box['left'])
        input_point = np.array([hand_kpts['right'][0], hand_kpts['left'][0]])
        input_label = np.array([1, 1])
    elif 'right' in input_box:
        input_box = input_box['right']
        input_point = np.array([hand_kpts['right'][0]])
        input_label = np.array([1])
    elif 'left' in input_box:
        input_box = input_box['left']
        input_point = np.array([hand_kpts['left'][0]])
        input_label = np.array([1])

    box_area = (input_box[2] - input_box[0]) * (input_box[3] - input_box[1])

    # segment hand using the wrist point
    predictor.set_image(image)
    if hand_mask is None:
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        hand_mask = masks[0]

    # segment object in hand
    input_label = np.zeros_like(input_label)
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    object_mask = masks[0]

    if (masks[0].astype(int) * hand_mask).sum() > overlap_threshold:
        # print('False positive: The mask overlaps the hand.')
        object_mask = np.zeros_like(object_mask)
    elif object_mask.astype(int).sum() / box_area > area_threshold:
        # print('False positive: The area is very big, probably the background')
        object_mask = np.zeros_like(object_mask)

    return object_mask, hand_mask
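
Typical call pattern, sketched with placeholder paths (the keypoint array is assumed to be 21 MediaPipe-style 2D joints with the wrist at index 0, which is what the wrist-point prompt above relies on):

import cv2
import numpy as np

# predictor = init_sam(ckpt_path="sam_vit_h_4b8939.pth")           # placeholder path
# image = cv2.cvtColor(cv2.imread("frame.jpg"), cv2.COLOR_BGR2RGB)
# hand_kpts = {"right": right_kpts_21x2}                           # (21, 2) pixels, wrist first
# obj_mask, hand_mask = segment_hand_and_object(predictor, image, hand_kpts)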
utils.py
ADDED
@@ -0,0 +1,289 @@
import io
import cv2
import numpy as np
from PIL import Image
from skimage.transform import resize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def draw_hand3d(keypoints):
    # Define the connections between keypoints as tuples (start, end)
    bones = [
        ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
        ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
        ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'blueviolet'),
        ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'crimson'), ((15, 16), 'cornsilk'),
        ((0, 17), 'aqua'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
    ]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot the bones
    for bone, color in bones:
        start_point = keypoints[bone[0], :]
        end_point = keypoints[bone[1], :]

        ax.plot([start_point[0], end_point[0]],
                [start_point[1], end_point[1]],
                [start_point[2], end_point[2]], color=color)

    ax.scatter(keypoints[:, 0], keypoints[:, 1], keypoints[:, 2], color='gray', s=15)

    # Set the aspect ratio to be equal
    max_range = np.array([keypoints[:,0].max()-keypoints[:,0].min(),
                          keypoints[:,1].max()-keypoints[:,1].min(),
                          keypoints[:,2].max()-keypoints[:,2].min()]).max() / 2.0

    mid_x = (keypoints[:,0].max()+keypoints[:,0].min()) * 0.5
    mid_y = (keypoints[:,1].max()+keypoints[:,1].min()) * 0.5
    mid_z = (keypoints[:,2].max()+keypoints[:,2].min()) * 0.5

    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    # Set labels for axes
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    plt.show()


def visualize_hand(joints, img):
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
        ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
        ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
        ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
        ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
    ]
    H, W, C = img.shape

    # Create a figure and axis
    plt.figure()
    ax = plt.gca()
    # Plot joints as points
    ax.imshow(img)
    ax.scatter(joints[:, 0], joints[:, 1], color='white', s=15)
    # Plot lines connecting joints with different colors for each bone
    for connection, color in connections:
        joint1 = joints[connection[0]]
        joint2 = joints[connection[1]]
        ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    plt.subplots_adjust(wspace=0.01)
    plt.show()


def draw_hand_skeleton(joints, image_size, thickness=5):
    # Create a blank (all-zeros) single-channel image; lines are drawn with value 1
    image = np.zeros((image_size[0], image_size[1]), dtype=np.uint8)

    # Define the connections between joints, one finger per row
    connections = [
        (0, 1), (1, 2), (2, 3), (3, 4),
        (0, 5), (5, 6), (6, 7), (7, 8),
        (0, 9), (9, 10), (10, 11), (11, 12),
        (0, 13), (13, 14), (14, 15), (15, 16),
        (0, 17), (17, 18), (18, 19), (19, 20),
    ]

    # Draw lines connecting joints
    for connection in connections:
        joint1 = joints[connection[0]].astype("int")
        joint2 = joints[connection[1]].astype("int")
        cv2.line(image, tuple(joint1), tuple(joint2), color=1, thickness=thickness)

    return image


def draw_hand(joints, img):
    # Define the connections between joints for drawing lines and their corresponding colors
    connections = [
        ((0, 1), 'red'), ((1, 2), 'green'), ((2, 3), 'blue'), ((3, 4), 'purple'),
        ((0, 5), 'orange'), ((5, 6), 'pink'), ((6, 7), 'brown'), ((7, 8), 'cyan'),
        ((0, 9), 'yellow'), ((9, 10), 'magenta'), ((10, 11), 'lime'), ((11, 12), 'indigo'),
        ((0, 13), 'olive'), ((13, 14), 'teal'), ((14, 15), 'navy'), ((15, 16), 'gray'),
        ((0, 17), 'lavender'), ((17, 18), 'silver'), ((18, 19), 'maroon'), ((19, 20), 'fuchsia')
    ]
    H, W, C = img.shape

    # Create a figure and axis with the same size as the input image
    fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100)
    # Plot joints as points
    ax.imshow(img)
    ax.scatter(joints[:, 0], joints[:, 1], color='white', s=15)
    # Plot lines connecting joints with different colors for each bone
    for connection, color in connections:
        joint1 = joints[connection[0]]
        joint2 = joints[connection[1]]
        ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)

    ax.set_xlim([0, W])
    ax.set_ylim([0, H])
    ax.grid(False)
    ax.set_axis_off()
    ax.invert_yaxis()
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.01, hspace=0.01)

    # Save the plot to a buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # Close the figure to free memory

    # Load the image from the buffer into a PIL image and then into a numpy array
    buf.seek(0)
    img_arr = np.array(Image.open(buf))

    return img_arr[..., :3]


def keypoint_heatmap(pts, size, var=1.0):
    H, W = size
    x = np.linspace(0, W - 1, W)
    y = np.linspace(0, H - 1, H)
    xv, yv = np.meshgrid(x, y)
    grid = np.stack((xv, yv), axis=-1)

    # Expanding dims for broadcasting subtraction between pts and every grid position
    modes_exp = np.expand_dims(np.expand_dims(pts, axis=1), axis=1)

    # Calculating squared difference
    diff = grid - modes_exp
    normal = np.exp(-np.sum(diff**2, axis=-1) / (2 * var)) / (
        2.0 * np.pi * var
    )
    return normal


def check_keypoints_validity(keypoints, image_size):
    H, W = image_size
    # Check if x coordinates are valid: 0 < x < W
    valid_x = (keypoints[:, 0] > 0) & (keypoints[:, 0] < W)

    # Check if y coordinates are valid: 0 < y < H
    valid_y = (keypoints[:, 1] > 0) & (keypoints[:, 1] < H)

    # Combine the validity checks for both x and y
    valid_keypoints = valid_x & valid_y

    # Convert boolean array to integer (1 for True, 0 for False)
    return valid_keypoints.astype(int)


def find_bounding_box(mask, margin=30):
    """Find the bounding box of a binary mask. Return None if the mask is empty."""
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not rows.any() or not cols.any():  # Mask is empty
        return None
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    xmin -= margin
    xmax += margin
    ymin -= margin
    ymax += margin
    return xmin, ymin, xmax, ymax


def adjust_box_to_image(xmin, ymin, xmax, ymax, image_width, image_height):
    """Adjust the bounding box to fit within the image boundaries."""
    box_width = xmax - xmin
    box_height = ymax - ymin
    # Determine the side length of the square (the larger of the two dimensions)
    side_length = max(box_width, box_height)

    # Adjust to maintain a square by expanding or contracting sides
    xmin = max(0, xmin - (side_length - box_width) // 2)
    xmax = xmin + side_length
    ymin = max(0, ymin - (side_length - box_height) // 2)
    ymax = ymin + side_length

    # Ensure the box is still within the image boundaries after adjustments
    if xmax > image_width:
        shift = xmax - image_width
        xmin -= shift
        xmax -= shift
    if ymax > image_height:
        shift = ymax - image_height
        ymin -= shift
        ymax -= shift

    # After shifting, double-check if any side is out-of-bounds and adjust if necessary
    xmin = max(0, xmin)
    ymin = max(0, ymin)
    xmax = min(image_width, xmax)
    ymax = min(image_height, ymax)

    # It's possible the adjustments made the box not square (due to boundary constraints),
    # so we might need to slightly adjust the size to keep it as square as possible.
    # This could involve a final adjustment based on the specific requirements,
    # like reducing the side length to fit or deciding which dimension to prioritize.

    return xmin, ymin, xmax, ymax


def scale_keypoint(keypoint, original_size, target_size):
    """Scale a keypoint based on the resizing of the image."""
    keypoint_copy = keypoint.copy()
    keypoint_copy[:, 0] *= target_size[0] / original_size[0]
    keypoint_copy[:, 1] *= target_size[1] / original_size[1]
    return keypoint_copy


def crop_and_adjust_image_and_annotations(image, hand_mask, obj_mask, hand_pose, intrinsics, target_size=(512, 512)):
    # Find the hand bounding box; find_bounding_box returns None for an empty
    # mask, so unpack only after checking (the original one-line unpacking
    # would raise a TypeError whenever the mask was empty).
    bbox = find_bounding_box(hand_mask) if np.any(hand_mask) else None
    if bbox is None:
        raise ValueError("crop_and_adjust_image_and_annotations: empty hand mask")
    xmin, ymin, xmax, ymax = bbox

    # Adjust bounding box to fit within the image and be square
    xmin, ymin, xmax, ymax = adjust_box_to_image(xmin, ymin, xmax, ymax, image.shape[1], image.shape[0])

    # Crop the image and mask
    # masked_hand_image = (image * np.maximum(hand_mask, obj_mask)[..., None].astype(float)).astype(np.uint8)
    cropped_hand_image = image[ymin:ymax, xmin:xmax]
    cropped_hand_mask = hand_mask[ymin:ymax, xmin:xmax].astype(np.uint8)
    cropped_obj_mask = obj_mask[ymin:ymax, xmin:xmax].astype(np.uint8)

    # Resize the image
    resized_image = resize(cropped_hand_image, target_size, anti_aliasing=True)
    resized_hand_mask = cv2.resize(cropped_hand_mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)
    resized_obj_mask = cv2.resize(cropped_obj_mask, dsize=target_size, interpolation=cv2.INTER_NEAREST)

    # Adjust and scale 2D keypoints
    for hand_type, kps2d in hand_pose.items():
        kps2d[:, 0] -= xmin
        kps2d[:, 1] -= ymin
        hand_pose[hand_type] = scale_keypoint(kps2d, (xmax - xmin, ymax - ymin), target_size)

    # Adjust intrinsics
    resized_intrinsics = np.array(intrinsics, copy=True)
    resized_intrinsics[0, 2] -= xmin
    resized_intrinsics[1, 2] -= ymin
    resized_intrinsics[0, :] *= target_size[0] / (xmax - xmin)
    resized_intrinsics[1, :] *= target_size[1] / (ymax - ymin)

    return (resized_image, resized_hand_mask, resized_obj_mask, hand_pose, resized_intrinsics)
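
Two shape notes on the helpers above, written as a hedged sketch (all sizes are illustrative): keypoint_heatmap broadcasts K points against the full pixel grid, and find_bounding_box plus adjust_box_to_image compose into the square-crop parameters that the crop function consumes.

import numpy as np

pts = np.random.rand(21, 2) * 256                     # (K, 2) keypoints in pixels
hm = keypoint_heatmap(pts, size=(256, 256), var=2.0)  # -> (K, 256, 256) Gaussians

mask = np.zeros((480, 640), dtype=bool)
mask[100:300, 200:350] = True
bbox = find_bounding_box(mask)                        # margin-padded, may exceed the image
if bbox is not None:
    xmin, ymin, xmax, ymax = adjust_box_to_image(*bbox, image_width=640, image_height=480)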
vit.py
ADDED
@@ -0,0 +1,323 @@
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        latent_dim=4,
        in_channels=47,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        learn_sigma=True,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = latent_dim * 2 if learn_sigma else latent_dim
        self.patch_size = patch_size
        self.num_heads = num_heads

        # self.x_embedder = PatchEmbed(input_size, patch_size, latent_dim, hidden_size, bias=True)
        self.feature_aligned_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)

        self.n_patches = self.feature_aligned_embedder.num_patches
        self.patch_size = self.feature_aligned_embedder.patch_size[0]

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.nvs_label_embedder = LabelEmbedder(3, hidden_size, 0.)
        self.pos_embed = nn.Parameter(torch.zeros(1, 2 * self.n_patches, hidden_size), requires_grad=True)
        self.y_embedder = LabelEmbedder(num_classes=1000, hidden_size=hidden_size, dropout_prob=0.1)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize pos_embed with a sin-cos embedding (kept trainable, since requires_grad=True):
        grid_size = int(self.n_patches ** 0.5)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (2 * grid_size, grid_size))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        # w = self.x_embedder.proj.weight.data
        # nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # nn.init.constant_(self.x_embedder.proj.bias, 0)

        w = self.feature_aligned_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.feature_aligned_embedder.proj.bias, 0)

        # Initialize label embedding tables:
        nn.init.normal_(self.nvs_label_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x_t, t, target_cond, ref_cond, nvs, y=None):
        """
        Forward pass of DiT.
        x_t: (N, C1, H, W) denoising latent, concatenated with target_cond along channels
        target_cond: (N, C2, H, W) target pose control
        ref_cond: (N, C1 + C2, H, W) source latent + source pose control + mask
        t: (N,) tensor of diffusion timesteps
        nvs: (N,) tensor of view-condition labels
        y: (N,) tensor of class labels
        """
        x = self.feature_aligned_embedder(torch.cat([x_t, target_cond], 1)) + self.pos_embed[:, :self.n_patches]
        cond = self.feature_aligned_embedder(ref_cond) + self.pos_embed[:, self.n_patches:]
        x = torch.cat([x, cond], 1)

        t = self.t_embedder(t)                     # (N, D)
        nvs = self.nvs_label_embedder(nvs, False)
        if y is None:
            y = torch.tensor([1000] * x.shape[0], device=x.device)
        y = self.y_embedder(y, False)              # (N, D)
        c = t + y + nvs                            # (N, D)
        for block in self.blocks:
            x = block(x, c)                        # (N, 2T, D)
        x = x[:, :x.shape[1] // 2]
        x = self.final_layer(x, c)                 # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                     # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, target_cond, ref_cond, nvs, cfg_scale):
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        y_null = torch.tensor([1000] * half.shape[0], device=x.device)
        y = torch.cat([y_null, y_null], 0)
        model_out = self.forward(combined, t, target_cond, ref_cond, nvs, y)
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)


#################################################################################
#                  Sine/Cosine Positional Embedding Functions                   #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: (height, width) tuple of the grid
    return:
    pos_embed: [grid_h * grid_w, embed_dim] or [extra_tokens + grid_h * grid_w, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
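For reference, here is a minimal smoke test for the DiT backbone above, using the DiT_L_2 defaults (input_size=32, in_channels=47, latent_dim=4). The split of the 47 input channels into 4 latent + 43 control channels is an assumption for illustration; only the 47-channel total is fixed by the constructor defaults.

import torch

model = DiT_L_2()  # depth=24, hidden_size=1024, patch_size=2
model.eval()

N = 2
x_t = torch.randn(N, 4, 32, 32)            # noisy latent (latent_dim=4)
target_cond = torch.randn(N, 43, 32, 32)   # target-side control channels (assumed split)
ref_cond = torch.randn(N, 47, 32, 32)      # reference latent + controls + mask
t = torch.randint(0, 1000, (N,))           # diffusion timesteps
nvs = torch.zeros(N, dtype=torch.long)     # view-condition label in {0, 1, 2}

with torch.no_grad():
    out = model(x_t, t, target_cond, ref_cond, nvs)
print(out.shape)  # torch.Size([2, 8, 32, 32]): mean + variance channels since learn_sigma=True

Passing y=None falls back to the null class index 1000, the same index forward_with_cfg uses for the unconditional branch of classifier-free guidance.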
vqvae.py ADDED @@ -0,0 +1,507 @@
"""
---
title: Autoencoder for Stable Diffusion
summary: >
 Annotated PyTorch implementation/tutorial of the autoencoder
 for stable diffusion.
---

# Autoencoder for [Stable Diffusion](../index.html)

This implements the auto-encoder model used to map between image space and latent space.

We have kept the model definition and naming unchanged from
[CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
so that we can load the checkpoints directly.
"""

from typing import List

import torch
import torch.nn.functional as F
from torch import nn


class Autoencoder(nn.Module):
    """
    ## Autoencoder

    This consists of the encoder and decoder modules.
    """

    def __init__(
        self, encoder: "Encoder", decoder: "Decoder", emb_channels: int, z_channels: int
    ):
        """
        :param encoder: is the encoder
        :param decoder: is the decoder
        :param emb_channels: is the number of dimensions in the quantized embedding space
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Convolution to map from embedding space to
        # quantized embedding space moments (mean and log variance)
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1)
        # Convolution to map from quantized embedding space back to
        # embedding space
        self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1)

    def encode(self, img: torch.Tensor) -> "GaussianDistribution":
        """
        ### Encode images to latent representation

        :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """
        # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_width]`
        z = self.encoder(img)
        # Get the moments in the quantized embedding space
        moments = self.quant_conv(z)
        # Return the distribution
        return GaussianDistribution(moments)

    def decode(self, z: torch.Tensor):
        """
        ### Decode images from latent representation

        :param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_width]`
        """
        # Map to embedding space from the quantized representation
        z = self.post_quant_conv(z)
        # Decode the image of shape `[batch_size, channels, height, width]`
        return self.decoder(z)

    def forward(self, x):
        posterior = self.encode(x)
        z = posterior.sample()
        dec = self.decode(z)
        return dec, posterior


class Encoder(nn.Module):
    """
    ## Encoder module
    """

    def __init__(
        self,
        *,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        in_channels: int,
        z_channels: int
    ):
        """
        :param channels: is the number of channels in the first convolution layer
        :param channel_multipliers: are the multiplicative factors for the number of channels in the
            subsequent blocks
        :param n_resnet_blocks: is the number of resnet layers at each resolution
        :param in_channels: is the number of channels in the image
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()

        # Number of blocks of different resolutions.
        # The resolution is halved at the end of each top level block
        n_resolutions = len(channel_multipliers)

        # Initial $3 \times 3$ convolution layer that maps the image to `channels`
        self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1)

        # Number of channels in each top level block
        channels_list = [m * channels for m in [1] + channel_multipliers]

        # List of top-level blocks
        self.down = nn.ModuleList()
        # Create top-level blocks
        for i in range(n_resolutions):
            # Each top level block consists of multiple ResNet Blocks and down-sampling
            resnet_blocks = nn.ModuleList()
            # Add ResNet Blocks
            for _ in range(n_resnet_blocks):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1]))
                channels = channels_list[i + 1]
            # Top-level block
            down = nn.Module()
            down.block = resnet_blocks
            # Down-sampling at the end of each top level block except the last
            if i != n_resolutions - 1:
                down.downsample = DownSample(channels)
            else:
                down.downsample = nn.Identity()
            #
            self.down.append(down)

        # Final ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # Map to embedding space with a $3 \times 3$ convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1)

    def forward(self, img: torch.Tensor):
        """
        :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]`
        """

        # Map to `channels` with the initial convolution
        x = self.conv_in(img)

        # Top-level blocks
        for down in self.down:
            # ResNet Blocks
            for block in down.block:
                x = block(x)
            # Down-sampling
            x = down.downsample(x)

        # Final ResNet blocks with attention
        x = self.mid.block_1(x)
        x = self.mid.attn_1(x)
        x = self.mid.block_2(x)

        # Normalize and map to embedding space
        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)

        #
        return x


class Decoder(nn.Module):
    """
    ## Decoder module
    """

    def __init__(
        self,
        *,
        channels: int,
        channel_multipliers: List[int],
        n_resnet_blocks: int,
        out_channels: int,
        z_channels: int
    ):
        """
        :param channels: is the number of channels in the final convolution layer
        :param channel_multipliers: are the multiplicative factors for the number of channels in the
            previous blocks, in reverse order
        :param n_resnet_blocks: is the number of resnet layers at each resolution
        :param out_channels: is the number of channels in the image
        :param z_channels: is the number of channels in the embedding space
        """
        super().__init__()

        # Number of blocks of different resolutions.
        # The resolution is doubled at the end of each top level block
        num_resolutions = len(channel_multipliers)

        # Number of channels in each top level block, in the reverse order
        channels_list = [m * channels for m in channel_multipliers]

        # Number of channels in the top-level block
        channels = channels_list[-1]

        # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels`
        self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1)

        # ResNet blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(channels, channels)
        self.mid.attn_1 = AttnBlock(channels)
        self.mid.block_2 = ResnetBlock(channels, channels)

        # List of top-level blocks
        self.up = nn.ModuleList()
        # Create top-level blocks
        for i in reversed(range(num_resolutions)):
            # Each top level block consists of multiple ResNet Blocks and up-sampling
            resnet_blocks = nn.ModuleList()
            # Add ResNet Blocks
            for _ in range(n_resnet_blocks + 1):
                resnet_blocks.append(ResnetBlock(channels, channels_list[i]))
                channels = channels_list[i]
            # Top-level block
            up = nn.Module()
            up.block = resnet_blocks
            # Up-sampling at the end of each top level block except the first
            if i != 0:
                up.upsample = UpSample(channels)
            else:
                up.upsample = nn.Identity()
            # Prepend to be consistent with the checkpoint
            self.up.insert(0, up)

        # Map to image space with a $3 \times 3$ convolution
        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor):
        """
        :param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_width]`
        """

        # Map to `channels` with the initial convolution
        h = self.conv_in(z)

        # ResNet blocks with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Top-level blocks
        for up in reversed(self.up):
            # ResNet Blocks
            for block in up.block:
                h = block(h)
            # Up-sampling
            h = up.upsample(h)

        # Normalize and map to image space
        h = self.norm_out(h)
        h = swish(h)
        img = self.conv_out(h)

        #
        return img


class GaussianDistribution:
    """
    ## Gaussian Distribution
    """

    def __init__(self, parameters: torch.Tensor):
        """
        :param parameters: are the means and log of variances of the embedding of shape
            `[batch_size, z_channels * 2, z_height, z_width]`
        """
        # Split mean and log of variance
        self.mean, log_var = torch.chunk(parameters, 2, dim=1)
        # Clamp the log of variances
        self.log_var = torch.clamp(log_var, -30.0, 20.0)
        # Calculate standard deviation
        self.std = torch.exp(0.5 * self.log_var)
        self.var = torch.exp(self.log_var)

    def sample(self):
        # Sample from the distribution
        return self.mean + self.std * torch.randn_like(self.std)

    def kl(self):
        return 0.5 * torch.sum(
            torch.pow(self.mean, 2) + self.var - 1.0 - self.log_var, dim=[1, 2, 3]
        )


class AttnBlock(nn.Module):
    """
    ## Attention block
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # Group normalization
        self.norm = normalization(channels)
        # Query, key and value mappings
        self.q = nn.Conv2d(channels, channels, 1)
        self.k = nn.Conv2d(channels, channels, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        # Final $1 \times 1$ convolution layer
        self.proj_out = nn.Conv2d(channels, channels, 1)
        # Attention scaling factor
        self.scale = channels**-0.5

    def forward(self, x: torch.Tensor):
        """
        :param x: is the tensor of shape `[batch_size, channels, height, width]`
        """
        # Normalize `x`
        x_norm = self.norm(x)
        # Get query, key and value embeddings
        q = self.q(x_norm)
        k = self.k(x_norm)
        v = self.v(x_norm)

        # Reshape the query, key and value embeddings from
        # `[batch_size, channels, height, width]` to
        # `[batch_size, channels, height * width]`
        b, c, h, w = q.shape
        q = q.view(b, c, h * w)
        k = k.view(b, c, h * w)
        v = v.view(b, c, h * w)

        # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$
        attn = torch.einsum("bci,bcj->bij", q, k) * self.scale
        attn = F.softmax(attn, dim=2)

        # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$
        out = torch.einsum("bij,bcj->bci", attn, v)

        # Reshape back to `[batch_size, channels, height, width]`
        out = out.view(b, c, h, w)
        # Final $1 \times 1$ convolution layer
        out = self.proj_out(out)

        # Add residual connection
        return x + out


class UpSample(nn.Module):
    """
    ## Up-sampling layer
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # $3 \times 3$ convolution mapping
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Up-sample by a factor of $2$
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        # Apply convolution
        return self.conv(x)


class DownSample(nn.Module):
    """
    ## Down-sampling layer
    """

    def __init__(self, channels: int):
        """
        :param channels: is the number of channels
        """
        super().__init__()
        # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0)

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """
        # Add padding
        x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        # Apply convolution
        return self.conv(x)


class ResnetBlock(nn.Module):
    """
    ## ResNet Block
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        :param in_channels: is the number of channels in the input
        :param out_channels: is the number of channels in the output
        """
        super().__init__()
        # First normalization and convolution layer
        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        # Second normalization and convolution layer
        self.norm2 = normalization(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        # `in_channels` to `out_channels` mapping layer for residual connection
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, 1, stride=1, padding=0
            )
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x: torch.Tensor):
        """
        :param x: is the input feature map with shape `[batch_size, channels, height, width]`
        """

        h = x

        # First normalization and convolution layer
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        # Second normalization and convolution layer
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        # Map and add residual
        return self.nin_shortcut(x) + h


def swish(x: torch.Tensor):
    """
    ### Swish activation
    """
    return x * torch.sigmoid(x)


def normalization(channels: int):
    """
    ### Group normalization

    This is a helper function, with fixed number of groups and `eps`.
    """
    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)


def restore_ae_from_sd(model, path):
    """
    Restore the first-stage autoencoder weights from a Stable Diffusion checkpoint.
    """

    def remove_prefix(text, prefix):
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    checkpoint = torch.load(path)
    # checkpoint = torch.load(path, map_location="cpu")

    ckpt_state_dict = checkpoint["state_dict"]
    new_ckpt_state_dict = {}
    for k, v in ckpt_state_dict.items():
        new_k = remove_prefix(k, "first_stage_model.")
        new_ckpt_state_dict[new_k] = v
    missing_keys, extra_keys = model.load_state_dict(new_ckpt_state_dict, strict=False)
    assert len(missing_keys) == 0


def create_model(in_channels, out_channels, latent_dim=4):
    encoder = Encoder(
        z_channels=latent_dim,
        in_channels=in_channels,
        channels=128,
        channel_multipliers=[1, 2, 4, 4],
        n_resnet_blocks=2,
    )

    decoder = Decoder(
        out_channels=out_channels,
        z_channels=latent_dim,
        channels=128,
        channel_multipliers=[1, 2, 4, 4],
        n_resnet_blocks=2,
    )

    autoencoder = Autoencoder(
        emb_channels=latent_dim, encoder=encoder, decoder=decoder, z_channels=latent_dim
    )
    return autoencoder
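A minimal round-trip sketch for the autoencoder above, assuming a 3-channel RGB input and the latent_dim=4 default. With channel_multipliers=[1, 2, 4, 4] the encoder applies three DownSample layers, so a 256x256 image maps to a 32x32 latent.

import torch

ae = create_model(in_channels=3, out_channels=3, latent_dim=4)
ae.eval()

img = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    posterior = ae.encode(img)   # GaussianDistribution over latents
    z = posterior.sample()       # (1, 4, 32, 32)
    recon = ae.decode(z)         # (1, 3, 256, 256)
print(z.shape, recon.shape)

Since the layer names match CompVis/stable-diffusion, weights from a Stable Diffusion checkpoint can then be loaded with restore_ae_from_sd(ae, path), which strips the "first_stage_model." prefix and loads non-strictly.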