First
- .gitignore +2 -0
- app.py +108 -0
- lib/cond_gen.py +10 -0
- lib/embedding.py +77 -0
- lib/encoders.py +171 -0
- lib/get_model.py +13 -0
- lib/networks_edm2.py +360 -0
- lib/sampling.py +63 -0
- model_weights/1girl-edm-xs-test-1.safetensors +3 -0
- model_weights/condgen_vae_decoder.safetensors +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__/
*.ipynb
app.py
ADDED
@@ -0,0 +1,108 @@
from lib.get_model import get_model, device
from lib.sampling import edm_sampler
from lib.embedding import extract_features
from lib.encoders import StabilityVAEEncoder
from lib.cond_gen import get_vae_decoder
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage
import torch
import gradio as gr
import json


torch.set_grad_enabled(False)
net = get_model()
net.load_state_dict(load_file("model_weights/1girl-edm-xs-test-1.safetensors"))

cond_gen = get_vae_decoder().to(device)
cond_gen.load_state_dict(load_file("model_weights/condgen_vae_decoder.safetensors"))

stability_encoder = StabilityVAEEncoder()


def guided(net, scale=1):
    def f(x, t, label):
        if scale == 1:
            return net(x, t, label)
        return torch.lerp(net(x, t, net.uncond_emb), net(x, t, label), float(scale))

    return f


@torch.no_grad()
def generate_image(label, guidance_scale, n_steps, seed):
    label = torch.tensor(label)[None].to(device)
    gen = torch.Generator(device).manual_seed(seed)
    x = torch.randn((1, 4, 88, 64), device=device, generator=gen)
    randn_like = lambda *a, **ka: torch.zeros_like(*a, **ka).normal_(generator=gen)
    im = edm_sampler(
        guided(net, guidance_scale), x, label, num_steps=n_steps, randn_like=randn_like
    )
    im = stability_encoder.decode(im)
    return ToPILImage()(im[0])


with gr.Blocks() as demo:
    selected = [0]
    with gr.Row():
        gr.Markdown(
            """# 1girl-EDM2-XS-test-1 Demo
Demo of a 125M param model trained in 1 GPU-day for generating `1girl solo` images.
"""
        )
    with gr.Row():
        with gr.Column():
            with gr.Group():
                btn = gr.Button("Generate", variant="primary")
                guidance = gr.Slider(1, 15, 5, step=0.1, label="Guidance scale")
                n_steps = gr.Slider(2, 35, 24, step=1, label="Inference steps")
                seed = gr.Slider(
                    -1, 2147483647, -1, step=1, label="Random seed (-1: randomize)"
                )
            with gr.Tab("Condition: auto") as auto_tab:
                gr.Markdown("Conditioning is generated with an external model")
            with gr.Tab("Condition: from image") as img_tab:
                gr.Markdown(
                    "Conditioning is extracted from the image with a [tagger](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)"
                )
                ref_im = gr.Image(label="Reference image", type="pil")
            with gr.Tab("Condition: precomputed") as txt_tab:
                gr.Markdown("Use a precomputed 1024D vector as the condition")
                ref_txt = gr.TextArea(
                    label="Precomputed conditioning",
                    placeholder="Copy & Paste from the output",
                )
        with gr.Column():
            out_im = gr.Image(label="Generated Image", show_download_button=True)
            out_seed = gr.Textbox(label="Seed", show_copy_button=True)
            out_emb = gr.TextArea(label="Condition vector", show_copy_button=True)

    @torch.no_grad()
    def get_label(tab_index, cond_img=None, cond_txt=None):
        if tab_index == 0:
            return cond_gen(torch.randn((1, 512), device=device))[0].detach().cpu()
        if tab_index == 1:
            return extract_features(cond_img, device)
        return torch.tensor(json.loads(cond_txt))

    def on_select(e: gr.SelectData):
        selected[0] = e.index

    for t in [auto_tab, img_tab, txt_tab]:
        t.select(on_select)

    def main(guidance, n_steps, seed, cond_img=None, cond_txt=None):
        if seed < 0:
            seed = torch.randint(0, 2147483647, ()).item()
        label = get_label(selected[0], cond_img, cond_txt)
        im = generate_image(label, guidance, n_steps, seed)
        label_txt = json.dumps(label.numpy().astype(float).round(3).tolist())
        return im, seed, label_txt

    btn.click(
        main, [guidance, n_steps, seed, ref_im, ref_txt], [out_im, out_seed, out_emb]
    )

demo.launch()
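The guided() wrapper above is classifier-free guidance: torch.lerp(uncond, cond, scale) equals uncond + scale * (cond - uncond), so a scale above 1 extrapolates past the conditional prediction. A minimal sketch of that arithmetic on dummy tensors (the values and names below are illustrative, not part of the app):

import torch

uncond = torch.tensor([0.0, 0.0])  # denoiser output under the learned uncond embedding
cond = torch.tensor([1.0, 2.0])    # denoiser output under the actual condition vector
scale = 5.0                        # value of the "Guidance scale" slider

# torch.lerp(a, b, w) = a + w * (b - a); with w > 1 this extrapolates
# past the conditional prediction, as in classifier-free guidance.
out = torch.lerp(uncond, cond, scale)
print(out)  # tensor([ 5., 10.])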
lib/cond_gen.py
ADDED
@@ -0,0 +1,10 @@
from torch import nn

def get_vae_decoder():
    return nn.Sequential(
        nn.Linear(512, 512),
        nn.SiLU(),
        nn.Linear(512, 768),
        nn.SiLU(),
        nn.Linear(768, 1024),
    )
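For context, a minimal sketch of how this MLP is used by the "Condition: auto" path in app.py: a 512-D standard-normal latent is decoded into the 1024-D conditioning vector the diffusion model expects (shapes only, untrained weights):

import torch
from lib.cond_gen import get_vae_decoder

decoder = get_vae_decoder()  # 512 -> 512 -> 768 -> 1024 MLP
z = torch.randn(1, 512)      # latent drawn from a standard normal, as in app.py
cond = decoder(z)            # 1024-D conditioning vector for the diffusion model
print(cond.shape)            # torch.Size([1, 1024])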
lib/embedding.py
ADDED
@@ -0,0 +1,77 @@
# Modified from https://huggingface.co/spaces/SmilingWolf/wd-tagger/blob/main/app.py

import os
from PIL import Image
import timm
import torch

if torch.cuda.is_available():
    os.environ["ONNX_MODE"] = "cuda"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"


class Predictor:
    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def load_model(self, model_repo):
        if model_repo == self.last_loaded_repo:
            return

        model = timm.create_model("hf-hub:" + model_repo).eval()
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        model.load_state_dict(state_dict)
        self.transform = timm.data.create_transform(
            **timm.data.resolve_data_config(model.pretrained_cfg, model=model)
        )

        self.model_target_size = self.transform.transforms[0].size
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        target_size = self.model_target_size

        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        return self.transform(padded_image)[[2, 1, 0]].clone()

    @torch.no_grad()
    def predict(
        self,
        images,
        model_repo,
    ):
        self.load_model(model_repo)
        feat = self.model.forward_features(images)
        feat_pooled = feat[:, self.model.num_prefix_tokens :].mean(dim=1)
        return feat_pooled


predictor = Predictor()
predictor.load_model(EVA02_LARGE_MODEL_DSV3_REPO)


def extract_features(im: Image.Image, device):
    predictor.model.to(device)
    ims = predictor.prepare_image(im.convert("RGBA"))[None]
    feat = predictor.predict(ims.to(device), EVA02_LARGE_MODEL_DSV3_REPO)
    return feat[0].cpu()
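A small usage sketch, assuming timm can download the tagger weights at import time; the image path is a placeholder, and the 1024-D output is what the rest of the pipeline expects from the EVA02-Large tagger:

import torch
from PIL import Image
from lib.embedding import extract_features

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
im = Image.open("reference.png")     # placeholder path; any RGB/RGBA image works
feat = extract_features(im, device)  # pooled EVA02 tagger features
print(feat.shape)                    # expected: torch.Size([1024]) for the EVA02-Large tagger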
lib/encoders.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Converting between pixel and latent representations of image data."""

import warnings
import numpy as np
import torch

warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
warnings.filterwarnings('ignore', '`resume_download` is deprecated')


_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)

#----------------------------------------------------------------------------
# Abstract base class for encoders/decoders that convert back and forth
# between pixel and latent representations of image data.
#
# Logically, "raw pixels" are first encoded into "raw latents" that are
# then further encoded into "final latents". Decoding, on the other hand,
# goes directly from the final latents to raw pixels. The final latents are
# used as inputs and outputs of the model, whereas the raw latents are
# stored in the dataset. This separation provides added flexibility in terms
# of performing just-in-time adjustments, such as data whitening, without
# having to construct a new dataset.
#
# All image data is represented as PyTorch tensors in NCHW order.
# Raw pixels are represented as 3-channel uint8.

class Encoder:
    def __init__(self):
        pass

    def init(self, device): # force lazy init to happen now
        pass

    def __getstate__(self):
        return self.__dict__

    def encode(self, x): # raw pixels => final latents
        return self.encode_latents(self.encode_pixels(x))

    def encode_pixels(self, x): # raw pixels => raw latents
        raise NotImplementedError # to be overridden by subclass

    def encode_latents(self, x): # raw latents => final latents
        raise NotImplementedError # to be overridden by subclass

    def decode(self, x): # final latents => raw pixels
        raise NotImplementedError # to be overridden by subclass

#----------------------------------------------------------------------------
# Standard RGB encoder that scales the pixel data into [-1, +1].

class StandardRGBEncoder(Encoder):
    def __init__(self):
        super().__init__()

    def encode_pixels(self, x): # raw pixels => raw latents
        return x

    def encode_latents(self, x): # raw latents => final latents
        return x.to(torch.float32) / 127.5 - 1

    def decode(self, x): # final latents => raw pixels
        return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)

#----------------------------------------------------------------------------
# Pre-trained VAE encoder from Stability AI.

class StabilityVAEEncoder(Encoder):
    def __init__(self,
        vae_name   = 'stabilityai/sd-vae-ft-mse',   # Name of the VAE to use.
        raw_mean   = [5.81, 3.25, 0.12, -2.15],     # Assumed mean of the raw latents.
        raw_std    = [4.17, 4.62, 3.71, 3.28],      # Assumed standard deviation of the raw latents.
        final_mean = 0,                             # Desired mean of the final latents.
        final_std  = 0.5,                           # Desired standard deviation of the final latents.
        batch_size = 8,                             # Batch size to use when running the VAE.
    ):
        super().__init__()
        self.vae_name = vae_name
        self.scale = np.float32(final_std) / np.float32(raw_std)
        self.bias = np.float32(final_mean) - np.float32(raw_mean) * self.scale
        self.batch_size = int(batch_size)
        self._vae = None

    def init(self, device): # force lazy init to happen now
        super().init(device)
        if self._vae is None:
            self._vae = load_stability_vae(self.vae_name, device=device)
        else:
            self._vae.to(device)

    def __getstate__(self):
        return dict(super().__getstate__(), _vae=None) # do not pickle the vae

    def _run_vae_encoder(self, x):
        d = self._vae.encode(x)['latent_dist']
        return torch.cat([d.mean, d.std], dim=1)

    def _run_vae_decoder(self, x):
        return self._vae.decode(x)['sample']

    def encode_pixels(self, x): # raw pixels => raw latents
        self.init(x.device)
        x = x.to(torch.float32) / 255
        x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
        return x

    def encode_latents(self, x): # raw latents => final latents
        mean, std = x.to(torch.float32).chunk(2, dim=1)
        x = mean + torch.randn_like(mean) * std
        x = x * const_like(x, self.scale).reshape(1, -1, 1, 1)
        x = x + const_like(x, self.bias).reshape(1, -1, 1, 1)
        return x

    def decode(self, x): # final latents => raw pixels
        self.init(x.device)
        x = x.to(torch.float32)
        x = x - const_like(x, self.bias).reshape(1, -1, 1, 1)
        x = x / const_like(x, self.scale).reshape(1, -1, 1, 1)
        x = torch.cat([self._run_vae_decoder(batch) for batch in x.split(self.batch_size)])
        x = x.clamp(0, 1).mul(255).to(torch.uint8)
        return x

#----------------------------------------------------------------------------

def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
    import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
    vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name)
    return vae.eval().requires_grad_(False).to(device)

#----------------------------------------------------------------------------
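A decode-side sketch, assuming diffusers is installed and the stabilityai/sd-vae-ft-mse weights can be fetched; the 88x64 latent grid used by the app should come back as an 8x larger uint8 image:

import torch
from lib.encoders import StabilityVAEEncoder

enc = StabilityVAEEncoder()
latents = torch.randn(1, 4, 88, 64)  # "final" latents in the shape the sampler produces
pixels = enc.decode(latents)         # lazily loads the VAE, then decodes to uint8 NCHW pixels
print(pixels.shape, pixels.dtype)    # torch.Size([1, 3, 704, 512]) torch.uint8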
lib/get_model.py
ADDED
@@ -0,0 +1,13 @@
import torch
from lib.networks_edm2 import Precond

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_model():
    return Precond(
        img_resolution=64,  # actually 88x64
        img_channels=4,
        label_dim=1024,
        use_fp16=False,
        model_channels=128,
    ).to(device)
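A quick sanity-check sketch for this configuration; the parameter count printed here should be roughly the 125M quoted in the demo text:

from lib.get_model import get_model

net = get_model()
n_params = sum(p.numel() for p in net.parameters())
print(f"{n_params / 1e6:.1f}M parameters")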
lib/networks_edm2.py
ADDED
@@ -0,0 +1,360 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Improved diffusion model architecture proposed in the paper
"Analyzing and Improving the Training Dynamics of Diffusion Models"."""

import numpy as np
import torch


#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
    if dtype is None:
        dtype = ref.dtype
    if device is None:
        device = ref.device
    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)

#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

#----------------------------------------------------------------------------
# Upsample or downsample the given tensor with the given filter,
# or keep it as is.

def resample(x, f=[1,1], mode='keep'):
    if mode == 'keep':
        return x
    f = np.float32(f)
    assert f.ndim == 1 and len(f) % 2 == 0
    pad = (len(f) - 1) // 2
    f = f / f.sum()
    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
    f = const_like(x, f)
    c = x.shape[1]
    if mode == 'down':
        return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
    assert mode == 'up'
    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))

#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).

def mp_silu(x):
    return torch.nn.functional.silu(x) / 0.596

#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).

def mp_sum(a, b, t=0.5):
    return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t ** 2)

#----------------------------------------------------------------------------
# Magnitude-preserving concatenation (Equation 103).

def mp_cat(a, b, dim=1, t=0.5):
    Na = a.shape[dim]
    Nb = b.shape[dim]
    C = np.sqrt((Na + Nb) / ((1 - t) ** 2 + t ** 2))
    wa = C / np.sqrt(Na) * (1 - t)
    wb = C / np.sqrt(Nb) * t
    return torch.cat([wa * a , wb * b], dim=dim)

#----------------------------------------------------------------------------
# Magnitude-preserving Fourier features (Equation 75).

class MPFourier(torch.nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)

#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with force weight normalization (Equation 66).

class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))

#----------------------------------------------------------------------------
# U-Net encoder/decoder block with optional self-attention (Figure 21).

class Block(torch.nn.Module):
    def __init__(self,
        in_channels,                  # Number of input channels.
        out_channels,                 # Number of output channels.
        emb_channels,                 # Number of embedding channels.
        flavor            = 'enc',    # Flavor: 'enc' or 'dec'.
        resample_mode     = 'keep',   # Resampling: 'keep', 'up', or 'down'.
        resample_filter   = [1,1],    # Resampling filter.
        attention         = False,    # Include self-attention?
        channels_per_head = 64,       # Number of channels per attention head.
        dropout           = 0,        # Dropout probability.
        res_balance       = 0.3,      # Balance between main branch (0) and residual branch (1).
        attn_balance      = 0.3,      # Balance between main branch (0) and self-attention (1).
        clip_act          = 256,      # Clip output activations. None = do not clip.
    ):
        super().__init__()
        self.out_channels = out_channels
        self.flavor = flavor
        self.resample_filter = resample_filter
        self.resample_mode = resample_mode
        self.num_heads = out_channels // channels_per_head if attention else 0
        self.dropout = dropout
        self.res_balance = res_balance
        self.attn_balance = attn_balance
        self.clip_act = clip_act
        self.emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.conv_res0 = MPConv(out_channels if flavor == 'enc' else in_channels, out_channels, kernel=[3,3])
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
        self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])
        self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1]) if in_channels != out_channels else None
        self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1]) if self.num_heads != 0 else None
        self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1]) if self.num_heads != 0 else None

    def forward(self, x, emb):
        # Main branch.
        x = resample(x, f=self.resample_filter, mode=self.resample_mode)
        if self.flavor == 'enc':
            if self.conv_skip is not None:
                x = self.conv_skip(x)
            x = normalize(x, dim=1) # pixel norm

        # Residual branch.
        y = self.conv_res0(mp_silu(x))
        c = self.emb_linear(emb, gain=self.emb_gain) + 1
        y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
        if self.training and self.dropout != 0:
            y = torch.nn.functional.dropout(y, p=self.dropout)
        y = self.conv_res1(y)

        # Connect the branches.
        if self.flavor == 'dec' and self.conv_skip is not None:
            x = self.conv_skip(x)
        x = mp_sum(x, y, t=self.res_balance)

        # Self-attention.
        # Note: torch.nn.functional.scaled_dot_product_attention() could be used here,
        # but we haven't done sufficient testing to verify that it produces identical results.
        if self.num_heads != 0:
            y = self.attn_qkv(x)
            y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
            q, k, v = normalize(y, dim=2).unbind(3) # pixel norm & split
            w = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2])).softmax(dim=3)
            y = torch.einsum('nhqk,nhck->nhcq', w, v)
            y = self.attn_proj(y.reshape(*x.shape))
            x = mp_sum(x, y, t=self.attn_balance)

        # Clip activations.
        if self.clip_act is not None:
            x = x.clip_(-self.clip_act, self.clip_act)
        return x


class TagEncoder(torch.nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.din = din
        self.linear1 = MPConv(din, dout, [])
        self.linear2 = MPConv(dout, dout, [])
        self.out_gain = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        x = mp_silu(self.linear1(x))
        return self.din**-0.5 * self.linear2(x, gain=self.out_gain)


#----------------------------------------------------------------------------
# EDM2 U-Net model (Figure 21).

class UNet(torch.nn.Module):
    def __init__(self,
        img_resolution,                    # Image resolution.
        img_channels,                      # Image channels.
        label_dim,                         # Class label dimensionality. 0 = unconditional.
        model_channels     = 192,          # Base multiplier for the number of channels.
        channel_mult       = [1,2,3,4],    # Per-resolution multipliers for the number of channels.
        channel_mult_noise = None,         # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
        channel_mult_emb   = None,         # Multiplier for final embedding dimensionality. None = select based on channel_mult.
        num_blocks         = 3,            # Number of residual blocks per resolution.
        attn_resolutions   = [16,8],       # List of resolutions with self-attention.
        label_balance      = 0.5,          # Balance between noise embedding (0) and class embedding (1).
        concat_balance     = 0.5,          # Balance between skip connections (0) and main path (1).
        **block_kwargs,                    # Arguments for Block.
    ):
        super().__init__()
        cblock = [model_channels * x for x in channel_mult]
        cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
        cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
        self.label_balance = label_balance
        self.concat_balance = concat_balance
        self.out_gain = torch.nn.Parameter(torch.zeros([]))

        # Embedding.
        self.emb_fourier = MPFourier(cnoise)
        self.emb_noise = MPConv(cnoise, cemb, kernel=[])
        self.emb_label = TagEncoder(label_dim, cemb)
        if type(num_blocks) is int:
            num_blocks = [num_blocks for _ in cblock]
        # Encoder.
        self.enc = torch.nn.ModuleDict()
        cout = img_channels + 1
        for level, channels in enumerate(cblock):
            res = img_resolution >> level
            if level == 0:
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_conv'] = MPConv(cin, cout, kernel=[3,3])
            else:
                self.enc[f'{res}x{res}_down'] = Block(cout, cout, cemb, flavor='enc', resample_mode='down', **block_kwargs)
            for idx in range(num_blocks[level]):
                cin = cout
                cout = channels
                self.enc[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='enc', attention=(res in attn_resolutions), **block_kwargs)

        # Decoder.
        self.dec = torch.nn.ModuleDict()
        skips = [block.out_channels for block in self.enc.values()]
        for level, channels in reversed(list(enumerate(cblock))):
            res = img_resolution >> level
            if level == len(cblock) - 1:
                self.dec[f'{res}x{res}_in0'] = Block(cout, cout, cemb, flavor='dec', attention=True, **block_kwargs)
                self.dec[f'{res}x{res}_in1'] = Block(cout, cout, cemb, flavor='dec', **block_kwargs)
            else:
                self.dec[f'{res}x{res}_up'] = Block(cout, cout, cemb, flavor='dec', resample_mode='up', **block_kwargs)
            for idx in range(num_blocks[level] + 1):
                cin = cout + skips.pop()
                cout = channels
                self.dec[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='dec', attention=(res in attn_resolutions), **block_kwargs)
        self.out_conv = MPConv(cout, img_channels, kernel=[3,3])

    def forward(self, x, noise_labels, class_labels):
        # Embedding.
        emb = self.emb_noise(self.emb_fourier(noise_labels))
        if self.emb_label is not None:
            emb = mp_sum(emb, self.emb_label(class_labels * np.sqrt(class_labels.shape[1])), t=self.label_balance)
        emb = mp_silu(emb)

        # Encoder.
        x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
        skips = []
        for name, block in self.enc.items():
            x = block(x) if 'conv' in name else block(x, emb)
            skips.append(x)

        # Decoder.
        for name, block in self.dec.items():
            if 'block' in name:
                x = mp_cat(x, skips.pop(), t=self.concat_balance)
            x = block(x, emb)
        x = self.out_conv(x, gain=self.out_gain)
        return x

#----------------------------------------------------------------------------
# Preconditioning and uncertainty estimation.

class Precond(torch.nn.Module):
    def __init__(self,
        img_resolution,        # Image resolution.
        img_channels,          # Image channels.
        label_dim,             # Class label dimensionality. 0 = unconditional.
        use_fp16   = True,     # Run the model at FP16 precision?
        sigma_data = 0.5,      # Expected standard deviation of the training data.
        **unet_kwargs,         # Keyword arguments for UNet.
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_data = sigma_data
        self.unet = UNet(img_resolution=img_resolution, img_channels=img_channels, label_dim=label_dim, **unet_kwargs)
        self.uncond_emb = torch.nn.Parameter(torch.randn((1024,)))

    def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        # Preconditioning weights.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.flatten().log() / 4

        # Run the model.
        x_in = (c_in * x).to(dtype)
        F_x = self.unet(x_in, c_noise, class_labels, **unet_kwargs)
        D_x = c_skip * x + c_out * F_x.to(torch.float32)

        return D_x

#----------------------------------------------------------------------------
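A minimal sketch of calling the preconditioned model directly, with dummy inputs shaped like the app's latents; the comment restates the EDM preconditioning that forward() applies:

import torch
from lib.get_model import get_model, device

net = get_model().eval()                       # random weights here; shapes are what matter
x = torch.randn(1, 4, 88, 64, device=device)   # noisy latents
sigma = torch.full((1,), 10.0, device=device)  # current noise level
label = torch.randn(1, 1024, device=device)    # 1024-D condition vector
with torch.no_grad():
    denoised = net(x, sigma, label)            # D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)
print(denoised.shape)                          # torch.Size([1, 4, 88, 64])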
lib/sampling.py
ADDED
@@ -0,0 +1,63 @@
import torch
import numpy as np


def edm_sampler(
    net,
    noise,
    labels=None,
    gnet=None,
    num_steps=32,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    guidance=1,
    S_churn=0,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
    dtype=torch.float32,
    randn_like=torch.randn_like,
):
    # Guided denoiser.
    def denoise(x, t):
        Dx = net(x, t, labels).to(dtype)
        if guidance == 1:
            return Dx
        ref_Dx = gnet(x, t).to(dtype)
        return ref_Dx.lerp(Dx, guidance)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
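The time-step discretization above is the Karras rho-schedule. A standalone sketch of just that schedule with this sampler's defaults, confirming it runs from sigma_max down to sigma_min and then to 0:

import torch

num_steps, sigma_min, sigma_max, rho = 32, 0.002, 80.0, 7
i = torch.arange(num_steps, dtype=torch.float64)
t = (sigma_max ** (1 / rho) + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t = torch.cat([t, torch.zeros(1, dtype=torch.float64)])  # append t_N = 0
print(t[0].item(), t[num_steps - 1].item(), t[-1].item())  # 80.0, 0.002, 0.0 (up to float error)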
model_weights/1girl-edm-xs-test-1.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:745cdbc08a10dea8a2bffa51559f741f7010b81faa194cfcc6ea62b98ef329bf
size 499977592
model_weights/condgen_vae_decoder.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3fc2e77e8584fdb207b768398fdc96693772cfbb378f4b7c6adc58fa08116cef
size 5776840