gustproof committed
Commit fb3e84a · 1 Parent(s): 131e57e
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ *.ipynb
app.py ADDED
@@ -0,0 +1,108 @@
+ from lib.get_model import get_model, device
+ from lib.sampling import edm_sampler
+ from lib.embedding import extract_features
+ from lib.encoders import StabilityVAEEncoder
+ from lib.cond_gen import get_vae_decoder
+ from safetensors.torch import load_file
+ from torchvision.transforms import ToPILImage
+ import torch
+ import gradio as gr
+ import json
+
+
+ torch.set_grad_enabled(False)
+ net = get_model()
+
+ net.load_state_dict(load_file("model_weights/1girl-edm-xs-test-1.safetensors"))
+
+
+ cond_gen = get_vae_decoder().to(device)
+ cond_gen.load_state_dict(load_file("model_weights/condgen_vae_decoder.safetensors"))
+
+ stability_encoder = StabilityVAEEncoder()
+
+
+ def guided(net, scale=1):
+     def f(x, t, label):
+         if scale == 1:
+             return net(x, t, label)
+         return torch.lerp(net(x, t, net.uncond_emb), net(x, t, label), float(scale))
+
+     return f
+
+
+ @torch.no_grad()
+ def generate_image(label, guidance_scale, n_steps, seed):
+     label = torch.tensor(label)[None].to(device)
+     gen = torch.Generator(device).manual_seed(seed)
+     x = torch.randn((1, 4, 88, 64), device=device, generator=gen)
+     randn_like = lambda *a, **ka: torch.zeros_like(*a, **ka).normal_(generator=gen)
+     im = edm_sampler(
+         guided(net, guidance_scale), x, label, num_steps=n_steps, randn_like=randn_like
+     )
+     im = stability_encoder.decode(im)
+     return ToPILImage()(im[0])
+
+
+ with gr.Blocks() as demo:
+     selected = [0]
+     with gr.Row():
+         gr.Markdown(
+             """# 1girl-EDM2-XS-test-1 Demo
+ Demo of a 125M param model trained in 1 GPU-day for generating `1girl solo` images.
+ """
+         )
+     with gr.Row():
+         with gr.Column():
+             with gr.Group():
+                 btn = gr.Button("Generate", variant="primary")
+                 guidance = gr.Slider(1, 15, 5, step=0.1, label="Guidance scale")
+                 n_steps = gr.Slider(2, 35, 24, step=1, label="Inference steps")
+                 seed = gr.Slider(
+                     -1, 2147483647, -1, step=1, label="Random seed (-1: randomize)"
+                 )
+             with gr.Tab("Condition: auto") as auto_tab:
+                 gr.Markdown("Conditioning is generated with an external model")
+             with gr.Tab("Condition: from image") as img_tab:
+                 gr.Markdown(
+                     "Conditioning is extracted from the image with a [tagger](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)"
+                 )
+                 ref_im = gr.Image(label="Reference image", type="pil")
+             with gr.Tab("Condition: precomputed") as txt_tab:
+                 gr.Markdown("Use a precomputed 1024D vector as the condition")
+                 ref_txt = gr.TextArea(
+                     label="Precomputed conditioning",
+                     placeholder="Copy & Paste from the output",
+                 )
+         with gr.Column():
+             out_im = gr.Image(label="Generated Image", show_download_button=True)
+             out_seed = gr.Textbox(label="Seed", show_copy_button=True)
+             out_emb = gr.TextArea(label="Condition vector", show_copy_button=True)
+
+     @torch.no_grad()
+     def get_label(tab_index, cond_img=None, cond_txt=None):
+         if tab_index == 0:
+             return cond_gen(torch.randn((1, 512), device=device))[0].detach().cpu()
+         if tab_index == 1:
+             return extract_features(cond_img, device)
+         return torch.tensor(json.loads(cond_txt))
+
+     def on_select(e: gr.SelectData):
+         selected[0] = e.index
+
+     for t in [auto_tab, img_tab, txt_tab]:
+         t.select(on_select)
+
+     def main(guidance, n_steps, seed, cond_img=None, cond_txt=None):
+         if seed < 0:
+             seed = torch.randint(0, 2147483647, ()).item()
+         label = get_label(selected[0], cond_img, cond_txt)
+         im = generate_image(label, guidance, n_steps, seed)
+         label_txt = json.dumps(label.numpy().astype(float).round(3).tolist())
+         return im, seed, label_txt
+
+     btn.click(
+         main, [guidance, n_steps, seed, ref_im, ref_txt], [out_im, out_seed, out_emb]
+     )
+
+ demo.launch()
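Note on `guided()`: it is a classifier-free-guidance wrapper, since `torch.lerp(uncond, cond, s)` equals `uncond + s * (cond - uncond)`; scale 1 returns the conditional prediction unchanged, and larger scales extrapolate away from the prediction made with the learned `net.uncond_emb`. A minimal sketch of that identity, using hypothetical stand-in tensors (not part of the committed code):

    import torch

    uncond = torch.tensor([0.10, -0.20])  # hypothetical unconditional denoiser output
    cond = torch.tensor([0.30, 0.40])     # hypothetical conditional denoiser output
    scale = 5.0
    assert torch.allclose(
        torch.lerp(uncond, cond, scale),   # what guided() computes
        uncond + scale * (cond - uncond),  # the usual CFG formulation
    )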
lib/cond_gen.py ADDED
@@ -0,0 +1,10 @@
+ from torch import nn
+
+ def get_vae_decoder():
+     return nn.Sequential(
+         nn.Linear(512, 512),
+         nn.SiLU(),
+         nn.Linear(512, 768),
+         nn.SiLU(),
+         nn.Linear(768, 1024),
+     )
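For context, `app.py` uses this MLP as the "Condition: auto" generator: it decodes a 512-d standard-normal latent into the 1024-d tag-embedding space the diffusion model is conditioned on (`label_dim=1024` in `lib/get_model.py`). A rough shape-only sketch, with untrained weights:

    import torch
    from lib.cond_gen import get_vae_decoder

    decoder = get_vae_decoder().eval()
    z = torch.randn(1, 512)  # VAE-style latent sampled from N(0, I)
    cond = decoder(z)        # conditioning vector for the diffusion model
    print(cond.shape)        # torch.Size([1, 1024])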
lib/embedding.py ADDED
@@ -0,0 +1,77 @@
+ # Modified from https://huggingface.co/spaces/SmilingWolf/wd-tagger/blob/main/app.py
+
+ import os
+ from PIL import Image
+ import timm
+ import torch
+
+ if torch.cuda.is_available():
+     os.environ["ONNX_MODE"] = "cuda"
+ EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
+
+
+ class Predictor:
+     def __init__(self):
+         self.model_target_size = None
+         self.last_loaded_repo = None
+
+     def load_model(self, model_repo):
+         if model_repo == self.last_loaded_repo:
+             return
+
+         model = timm.create_model("hf-hub:" + model_repo).eval()
+         state_dict = timm.models.load_state_dict_from_hf(model_repo)
+         model.load_state_dict(state_dict)
+         self.transform = timm.data.create_transform(
+             **timm.data.resolve_data_config(model.pretrained_cfg, model=model)
+         )
+
+         self.model_target_size = self.transform.transforms[0].size
+         self.last_loaded_repo = model_repo
+         self.model = model
+
+     def prepare_image(self, image):
+         target_size = self.model_target_size
+
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+
+         # Pad image to square
+         image_shape = image.size
+         max_dim = max(image_shape)
+         pad_left = (max_dim - image_shape[0]) // 2
+         pad_top = (max_dim - image_shape[1]) // 2
+
+         padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
+         padded_image.paste(image, (pad_left, pad_top))
+
+         # Resize
+         if max_dim != target_size:
+             padded_image = padded_image.resize(
+                 (target_size, target_size),
+                 Image.BICUBIC,
+             )
+         return self.transform(padded_image)[[2, 1, 0]].clone()
+
+     @torch.no_grad()
+     def predict(
+         self,
+         images,
+         model_repo,
+     ):
+         self.load_model(model_repo)
+         feat = self.model.forward_features(images)
+         feat_pooled = feat[:, self.model.num_prefix_tokens :].mean(dim=1)
+         return feat_pooled
+
+
+ predictor = Predictor()
+ predictor.load_model(EVA02_LARGE_MODEL_DSV3_REPO)
+
+
+ def extract_features(im: Image.Image, device):
+     predictor.model.to(device)
+     ims = predictor.prepare_image(im.convert("RGBA"))[None]
+     feat = predictor.predict(ims.to(device), EVA02_LARGE_MODEL_DSV3_REPO)
+     return feat[0].cpu()
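The tagger is used purely as a feature extractor: the classification head is never called, the non-prefix tokens from `forward_features` are mean-pooled into one vector per image, and the `[[2, 1, 0]]` index in `prepare_image` swaps RGB to BGR, presumably to mirror the tagger's original preprocessing. A minimal usage sketch (the file name is hypothetical, and the 1024-d output size is an assumption based on the EVA02-Large embedding width):

    import torch
    from PIL import Image
    from lib.embedding import extract_features

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    im = Image.open("reference.png")     # hypothetical input image
    feat = extract_features(im, device)  # pooled feature vector, expected shape (1024,)
    print(feat.shape)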
lib/encoders.py ADDED
@@ -0,0 +1,171 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Converting between pixel and latent representations of image data."""
+
+ import warnings
+ import numpy as np
+ import torch
+
+ warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
+ warnings.filterwarnings('ignore', '`resume_download` is deprecated')
+
+
+
+ _constant_cache = dict()
+
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+     value = np.asarray(value)
+     if shape is not None:
+         shape = tuple(shape)
+     if dtype is None:
+         dtype = torch.get_default_dtype()
+     if device is None:
+         device = torch.device('cpu')
+     if memory_format is None:
+         memory_format = torch.contiguous_format
+
+     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+     tensor = _constant_cache.get(key, None)
+     if tensor is None:
+         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+         if shape is not None:
+             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+         tensor = tensor.contiguous(memory_format=memory_format)
+         _constant_cache[key] = tensor
+     return tensor
+
+ #----------------------------------------------------------------------------
+ # Variant of constant() that inherits dtype and device from the given
+ # reference tensor by default.
+
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
+     if dtype is None:
+         dtype = ref.dtype
+     if device is None:
+         device = ref.device
+     return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
+
+ #----------------------------------------------------------------------------
+ # Abstract base class for encoders/decoders that convert back and forth
+ # between pixel and latent representations of image data.
+ #
+ # Logically, "raw pixels" are first encoded into "raw latents" that are
+ # then further encoded into "final latents". Decoding, on the other hand,
+ # goes directly from the final latents to raw pixels. The final latents are
+ # used as inputs and outputs of the model, whereas the raw latents are
+ # stored in the dataset. This separation provides added flexibility in terms
+ # of performing just-in-time adjustments, such as data whitening, without
+ # having to construct a new dataset.
+ #
+ # All image data is represented as PyTorch tensors in NCHW order.
+ # Raw pixels are represented as 3-channel uint8.
+
+ class Encoder:
+     def __init__(self):
+         pass
+
+     def init(self, device): # force lazy init to happen now
+         pass
+
+     def __getstate__(self):
+         return self.__dict__
+
+     def encode(self, x): # raw pixels => final latents
+         return self.encode_latents(self.encode_pixels(x))
+
+     def encode_pixels(self, x): # raw pixels => raw latents
+         raise NotImplementedError # to be overridden by subclass
+
+     def encode_latents(self, x): # raw latents => final latents
+         raise NotImplementedError # to be overridden by subclass
+
+     def decode(self, x): # final latents => raw pixels
+         raise NotImplementedError # to be overridden by subclass
+
+ #----------------------------------------------------------------------------
+ # Standard RGB encoder that scales the pixel data into [-1, +1].
+
+ class StandardRGBEncoder(Encoder):
+     def __init__(self):
+         super().__init__()
+
+     def encode_pixels(self, x): # raw pixels => raw latents
+         return x
+
+     def encode_latents(self, x): # raw latents => final latents
+         return x.to(torch.float32) / 127.5 - 1
+
+     def decode(self, x): # final latents => raw pixels
+         return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
+
+ #----------------------------------------------------------------------------
+ # Pre-trained VAE encoder from Stability AI.
+
+ class StabilityVAEEncoder(Encoder):
+     def __init__(self,
+         vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
+         raw_mean = [5.81, 3.25, 0.12, -2.15], # Assumed mean of the raw latents.
+         raw_std = [4.17, 4.62, 3.71, 3.28], # Assumed standard deviation of the raw latents.
+         final_mean = 0, # Desired mean of the final latents.
+         final_std = 0.5, # Desired standard deviation of the final latents.
+         batch_size = 8, # Batch size to use when running the VAE.
+     ):
+         super().__init__()
+         self.vae_name = vae_name
+         self.scale = np.float32(final_std) / np.float32(raw_std)
+         self.bias = np.float32(final_mean) - np.float32(raw_mean) * self.scale
+         self.batch_size = int(batch_size)
+         self._vae = None
+
+     def init(self, device): # force lazy init to happen now
+         super().init(device)
+         if self._vae is None:
+             self._vae = load_stability_vae(self.vae_name, device=device)
+         else:
+             self._vae.to(device)
+
+     def __getstate__(self):
+         return dict(super().__getstate__(), _vae=None) # do not pickle the vae
+
+     def _run_vae_encoder(self, x):
+         d = self._vae.encode(x)['latent_dist']
+         return torch.cat([d.mean, d.std], dim=1)
+
+     def _run_vae_decoder(self, x):
+         return self._vae.decode(x)['sample']
+
+     def encode_pixels(self, x): # raw pixels => raw latents
+         self.init(x.device)
+         x = x.to(torch.float32) / 255
+         x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
+         return x
+
+     def encode_latents(self, x): # raw latents => final latents
+         mean, std = x.to(torch.float32).chunk(2, dim=1)
+         x = mean + torch.randn_like(mean) * std
+         x = x * const_like(x, self.scale).reshape(1, -1, 1, 1)
+         x = x + const_like(x, self.bias).reshape(1, -1, 1, 1)
+         return x
+
+     def decode(self, x): # final latents => raw pixels
+         self.init(x.device)
+         x = x.to(torch.float32)
+         x = x - const_like(x, self.bias).reshape(1, -1, 1, 1)
+         x = x / const_like(x, self.scale).reshape(1, -1, 1, 1)
+         x = torch.cat([self._run_vae_decoder(batch) for batch in x.split(self.batch_size)])
+         x = x.clamp(0, 1).mul(255).to(torch.uint8)
+         return x
+
+ #----------------------------------------------------------------------------
+
+ def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
+     import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
+     vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name)
+     return vae.eval().requires_grad_(False).to(device)
+
+ #----------------------------------------------------------------------------
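`StabilityVAEEncoder` stores raw latents as concatenated (mean, std) channels and whitens them with fixed per-channel statistics, i.e. `final = raw * scale + bias` with `scale = final_std / raw_std` and `bias = final_mean - raw_mean * scale`; `decode()` applies the inverse affine map before running the VAE decoder. A small numeric sketch of that round trip using the class defaults:

    import numpy as np

    raw_mean = np.float32([5.81, 3.25, 0.12, -2.15])
    raw_std = np.float32([4.17, 4.62, 3.71, 3.28])
    final_mean, final_std = np.float32(0), np.float32(0.5)

    scale = final_std / raw_std
    bias = final_mean - raw_mean * scale

    raw = np.float32([5.81, 3.25, 0.12, -2.15])      # a latent sitting exactly at the assumed mean
    final = raw * scale + bias                        # -> all zeros, the desired final mean
    assert np.allclose((final - bias) / scale, raw)   # decode() undoes the whitening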
lib/get_model.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+ from lib.networks_edm2 import Precond
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def get_model():
+     return Precond(
+         img_resolution=64,  # actually 88x64
+         img_channels=4,
+         label_dim=1024,
+         use_fp16=False,
+         model_channels=128,
+     ).to(device)
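The `# actually 88x64` comment refers to the latent grid: the sampler in `app.py` draws `torch.randn((1, 4, 88, 64))`, and `img_resolution` here mainly determines the per-level resolution labels and which levels get self-attention, since the U-Net itself is fully convolutional. Assuming the usual 8x spatial downsampling of the Stability VAE (not stated in this repo), an 88x64 latent decodes to roughly a 704x512 image:

    latent_h, latent_w = 88, 64
    vae_downscale = 8  # assumed factor for stabilityai/sd-vae-ft-mse
    print(latent_h * vae_downscale, latent_w * vae_downscale)  # 704 512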
lib/networks_edm2.py ADDED
@@ -0,0 +1,360 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
+ # You should have received a copy of the license along with this
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+
+ """Improved diffusion model architecture proposed in the paper
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
+
+ import numpy as np
+ import torch
+
+
+ #----------------------------------------------------------------------------
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+ # same constant is used multiple times.
+
+ _constant_cache = dict()
+
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+     value = np.asarray(value)
+     if shape is not None:
+         shape = tuple(shape)
+     if dtype is None:
+         dtype = torch.get_default_dtype()
+     if device is None:
+         device = torch.device('cpu')
+     if memory_format is None:
+         memory_format = torch.contiguous_format
+
+     key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+     tensor = _constant_cache.get(key, None)
+     if tensor is None:
+         tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+         if shape is not None:
+             tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+         tensor = tensor.contiguous(memory_format=memory_format)
+         _constant_cache[key] = tensor
+     return tensor
+
+ #----------------------------------------------------------------------------
+ # Variant of constant() that inherits dtype and device from the given
+ # reference tensor by default.
+
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
+     if dtype is None:
+         dtype = ref.dtype
+     if device is None:
+         device = ref.device
+     return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
+
+ #----------------------------------------------------------------------------
+ # Normalize given tensor to unit magnitude with respect to the given
+ # dimensions. Default = all dimensions except the first.
+
+ def normalize(x, dim=None, eps=1e-4):
+     if dim is None:
+         dim = list(range(1, x.ndim))
+     norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+     norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
+     return x / norm.to(x.dtype)
+
+ #----------------------------------------------------------------------------
+ # Upsample or downsample the given tensor with the given filter,
+ # or keep it as is.
+
+ def resample(x, f=[1,1], mode='keep'):
+     if mode == 'keep':
+         return x
+     f = np.float32(f)
+     assert f.ndim == 1 and len(f) % 2 == 0
+     pad = (len(f) - 1) // 2
+     f = f / f.sum()
+     f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
+     f = const_like(x, f)
+     c = x.shape[1]
+     if mode == 'down':
+         return torch.nn.functional.conv2d(x, f.tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+     assert mode == 'up'
+     return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]), groups=c, stride=2, padding=(pad,))
+
+ #----------------------------------------------------------------------------
+ # Magnitude-preserving SiLU (Equation 81).
+
+ def mp_silu(x):
+     return torch.nn.functional.silu(x) / 0.596
+
+ #----------------------------------------------------------------------------
+ # Magnitude-preserving sum (Equation 88).
+
+ def mp_sum(a, b, t=0.5):
+     return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t ** 2)
+
+ #----------------------------------------------------------------------------
+ # Magnitude-preserving concatenation (Equation 103).
+
+ def mp_cat(a, b, dim=1, t=0.5):
+     Na = a.shape[dim]
+     Nb = b.shape[dim]
+     C = np.sqrt((Na + Nb) / ((1 - t) ** 2 + t ** 2))
+     wa = C / np.sqrt(Na) * (1 - t)
+     wb = C / np.sqrt(Nb) * t
+     return torch.cat([wa * a , wb * b], dim=dim)
+
+ #----------------------------------------------------------------------------
+ # Magnitude-preserving Fourier features (Equation 75).
+
+ class MPFourier(torch.nn.Module):
+     def __init__(self, num_channels, bandwidth=1):
+         super().__init__()
+         self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
+         self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))
+
+     def forward(self, x):
+         y = x.to(torch.float32)
+         y = y.ger(self.freqs.to(torch.float32))
+         y = y + self.phases.to(torch.float32)
+         y = y.cos() * np.sqrt(2)
+         return y.to(x.dtype)
+
+ #----------------------------------------------------------------------------
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
+ # with force weight normalization (Equation 66).
+
+ class MPConv(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel):
+         super().__init__()
+         self.out_channels = out_channels
+         self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))
+
+     def forward(self, x, gain=1):
+         w = self.weight.to(torch.float32)
+         if self.training:
+             with torch.no_grad():
+                 self.weight.copy_(normalize(w)) # forced weight normalization
+         w = normalize(w) # traditional weight normalization
+         w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
+         w = w.to(x.dtype)
+         if w.ndim == 2:
+             return x @ w.t()
+         assert w.ndim == 4
+         return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))
+
+ #----------------------------------------------------------------------------
+ # U-Net encoder/decoder block with optional self-attention (Figure 21).
+
+ class Block(torch.nn.Module):
+     def __init__(self,
+         in_channels, # Number of input channels.
+         out_channels, # Number of output channels.
+         emb_channels, # Number of embedding channels.
+         flavor = 'enc', # Flavor: 'enc' or 'dec'.
+         resample_mode = 'keep', # Resampling: 'keep', 'up', or 'down'.
+         resample_filter = [1,1], # Resampling filter.
+         attention = False, # Include self-attention?
+         channels_per_head = 64, # Number of channels per attention head.
+         dropout = 0, # Dropout probability.
+         res_balance = 0.3, # Balance between main branch (0) and residual branch (1).
+         attn_balance = 0.3, # Balance between main branch (0) and self-attention (1).
+         clip_act = 256, # Clip output activations. None = do not clip.
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.flavor = flavor
+         self.resample_filter = resample_filter
+         self.resample_mode = resample_mode
+         self.num_heads = out_channels // channels_per_head if attention else 0
+         self.dropout = dropout
+         self.res_balance = res_balance
+         self.attn_balance = attn_balance
+         self.clip_act = clip_act
+         self.emb_gain = torch.nn.Parameter(torch.zeros([]))
+         self.conv_res0 = MPConv(out_channels if flavor == 'enc' else in_channels, out_channels, kernel=[3,3])
+         self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])
+         self.conv_res1 = MPConv(out_channels, out_channels, kernel=[3,3])
+         self.conv_skip = MPConv(in_channels, out_channels, kernel=[1,1]) if in_channels != out_channels else None
+         self.attn_qkv = MPConv(out_channels, out_channels * 3, kernel=[1,1]) if self.num_heads != 0 else None
+         self.attn_proj = MPConv(out_channels, out_channels, kernel=[1,1]) if self.num_heads != 0 else None
+
+     def forward(self, x, emb):
+         # Main branch.
+         x = resample(x, f=self.resample_filter, mode=self.resample_mode)
+         if self.flavor == 'enc':
+             if self.conv_skip is not None:
+                 x = self.conv_skip(x)
+             x = normalize(x, dim=1) # pixel norm
+
+         # Residual branch.
+         y = self.conv_res0(mp_silu(x))
+         c = self.emb_linear(emb, gain=self.emb_gain) + 1
+         y = mp_silu(y * c.unsqueeze(2).unsqueeze(3).to(y.dtype))
+         if self.training and self.dropout != 0:
+             y = torch.nn.functional.dropout(y, p=self.dropout)
+         y = self.conv_res1(y)
+
+         # Connect the branches.
+         if self.flavor == 'dec' and self.conv_skip is not None:
+             x = self.conv_skip(x)
+         x = mp_sum(x, y, t=self.res_balance)
+
+         # Self-attention.
+         # Note: torch.nn.functional.scaled_dot_product_attention() could be used here,
+         # but we haven't done sufficient testing to verify that it produces identical results.
+         if self.num_heads != 0:
+             y = self.attn_qkv(x)
+             y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[2] * y.shape[3])
+             q, k, v = normalize(y, dim=2).unbind(3) # pixel norm & split
+             w = torch.einsum('nhcq,nhck->nhqk', q, k / np.sqrt(q.shape[2])).softmax(dim=3)
+             y = torch.einsum('nhqk,nhck->nhcq', w, v)
+             y = self.attn_proj(y.reshape(*x.shape))
+             x = mp_sum(x, y, t=self.attn_balance)
+
+         # Clip activations.
+         if self.clip_act is not None:
+             x = x.clip_(-self.clip_act, self.clip_act)
+         return x
+
+
+ class TagEncoder(torch.nn.Module):
+     def __init__(self, din, dout):
+         super().__init__()
+         self.din = din
+         self.linear1 = MPConv(din, dout, [])
+         self.linear2 = MPConv(dout, dout, [])
+         self.out_gain = torch.nn.Parameter(torch.tensor(0.0))
+
+     def forward(self, x):
+         x = mp_silu(self.linear1(x))
+         return self.din**-0.5 * self.linear2(x, gain=self.out_gain)
+
+
+ #----------------------------------------------------------------------------
+ # EDM2 U-Net model (Figure 21).
+
+ class UNet(torch.nn.Module):
+     def __init__(self,
+         img_resolution, # Image resolution.
+         img_channels, # Image channels.
+         label_dim, # Class label dimensionality. 0 = unconditional.
+         model_channels = 192, # Base multiplier for the number of channels.
+         channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels.
+         channel_mult_noise = None, # Multiplier for noise embedding dimensionality. None = select based on channel_mult.
+         channel_mult_emb = None, # Multiplier for final embedding dimensionality. None = select based on channel_mult.
+         num_blocks = 3, # Number of residual blocks per resolution.
+         attn_resolutions = [16,8], # List of resolutions with self-attention.
+         label_balance = 0.5, # Balance between noise embedding (0) and class embedding (1).
+         concat_balance = 0.5, # Balance between skip connections (0) and main path (1).
+         **block_kwargs, # Arguments for Block.
+     ):
+         super().__init__()
+         cblock = [model_channels * x for x in channel_mult]
+         cnoise = model_channels * channel_mult_noise if channel_mult_noise is not None else cblock[0]
+         cemb = model_channels * channel_mult_emb if channel_mult_emb is not None else max(cblock)
+         self.label_balance = label_balance
+         self.concat_balance = concat_balance
+         self.out_gain = torch.nn.Parameter(torch.zeros([]))
+
+         # Embedding.
+         self.emb_fourier = MPFourier(cnoise)
+         self.emb_noise = MPConv(cnoise, cemb, kernel=[])
+         self.emb_label = TagEncoder(label_dim, cemb)
+         if type(num_blocks) is int:
+             num_blocks = [num_blocks for _ in cblock]
+         # Encoder.
+         self.enc = torch.nn.ModuleDict()
+         cout = img_channels + 1
+         for level, channels in enumerate(cblock):
+             res = img_resolution >> level
+             if level == 0:
+                 cin = cout
+                 cout = channels
+                 self.enc[f'{res}x{res}_conv'] = MPConv(cin, cout, kernel=[3,3])
+             else:
+                 self.enc[f'{res}x{res}_down'] = Block(cout, cout, cemb, flavor='enc', resample_mode='down', **block_kwargs)
+             for idx in range(num_blocks[level]):
+                 cin = cout
+                 cout = channels
+                 self.enc[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='enc', attention=(res in attn_resolutions), **block_kwargs)
+
+         # Decoder.
+         self.dec = torch.nn.ModuleDict()
+         skips = [block.out_channels for block in self.enc.values()]
+         for level, channels in reversed(list(enumerate(cblock))):
+             res = img_resolution >> level
+             if level == len(cblock) - 1:
+                 self.dec[f'{res}x{res}_in0'] = Block(cout, cout, cemb, flavor='dec', attention=True, **block_kwargs)
+                 self.dec[f'{res}x{res}_in1'] = Block(cout, cout, cemb, flavor='dec', **block_kwargs)
+             else:
+                 self.dec[f'{res}x{res}_up'] = Block(cout, cout, cemb, flavor='dec', resample_mode='up', **block_kwargs)
+             for idx in range(num_blocks[level] + 1):
+                 cin = cout + skips.pop()
+                 cout = channels
+                 self.dec[f'{res}x{res}_block{idx}'] = Block(cin, cout, cemb, flavor='dec', attention=(res in attn_resolutions), **block_kwargs)
+         self.out_conv = MPConv(cout, img_channels, kernel=[3,3])
+
+     def forward(self, x, noise_labels, class_labels):
+         # Embedding.
+         emb = self.emb_noise(self.emb_fourier(noise_labels))
+         if self.emb_label is not None:
+             emb = mp_sum(emb, self.emb_label(class_labels * np.sqrt(class_labels.shape[1])), t=self.label_balance)
+         emb = mp_silu(emb)
+
+         # Encoder.
+         x = torch.cat([x, torch.ones_like(x[:, :1])], dim=1)
+         skips = []
+         for name, block in self.enc.items():
+             x = block(x) if 'conv' in name else block(x, emb)
+             skips.append(x)
+
+         # Decoder.
+         for name, block in self.dec.items():
+             if 'block' in name:
+                 x = mp_cat(x, skips.pop(), t=self.concat_balance)
+             x = block(x, emb)
+         x = self.out_conv(x, gain=self.out_gain)
+         return x
+
+ #----------------------------------------------------------------------------
+ # Preconditioning and uncertainty estimation.
+
+ class Precond(torch.nn.Module):
+     def __init__(self,
+         img_resolution, # Image resolution.
+         img_channels, # Image channels.
+         label_dim, # Class label dimensionality. 0 = unconditional.
+         use_fp16 = True, # Run the model at FP16 precision?
+         sigma_data = 0.5, # Expected standard deviation of the training data.
+         **unet_kwargs, # Keyword arguments for UNet.
+     ):
+         super().__init__()
+         self.img_resolution = img_resolution
+         self.img_channels = img_channels
+         self.label_dim = label_dim
+         self.use_fp16 = use_fp16
+         self.sigma_data = sigma_data
+         self.unet = UNet(img_resolution=img_resolution, img_channels=img_channels, label_dim=label_dim, **unet_kwargs)
+         self.uncond_emb = torch.nn.Parameter(torch.randn((1024,)))
+
+
+     def forward(self, x, sigma, class_labels=None, force_fp32=False, return_logvar=False, **unet_kwargs):
+         x = x.to(torch.float32)
+         sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
+         class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
+         dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
+
+         # Preconditioning weights.
+         c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+         c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
+         c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
+         c_noise = sigma.flatten().log() / 4
+
+         # Run the model.
+         x_in = (c_in * x).to(dtype)
+         F_x = self.unet(x_in, c_noise, class_labels, **unet_kwargs)
+         D_x = c_skip * x + c_out * F_x.to(torch.float32)
+
+         return D_x
+
+ #----------------------------------------------------------------------------
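`Precond.forward` is the standard EDM preconditioning around the raw U-Net output F: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise), with c_skip = sigma_data^2 / (sigma^2 + sigma_data^2), c_out = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2), c_in = 1 / sqrt(sigma^2 + sigma_data^2), and c_noise = ln(sigma) / 4. A quick numeric look at the limits (illustrative sigma values only):

    import numpy as np

    sigma_data = 0.5
    for sigma in (0.002, 0.5, 80.0):
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / np.sqrt(sigma**2 + sigma_data**2)
        c_in = 1 / np.sqrt(sigma_data**2 + sigma**2)
        # Low noise: c_skip ~ 1, so the input passes through almost unchanged.
        # High noise: c_skip ~ 0 and c_in rescales the input toward unit variance.
        print(f"sigma={sigma:g} c_skip={c_skip:.4f} c_out={c_out:.4f} c_in={c_in:.4f}")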
lib/sampling.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ import numpy as np
+
+
+ def edm_sampler(
+     net,
+     noise,
+     labels=None,
+     gnet=None,
+     num_steps=32,
+     sigma_min=0.002,
+     sigma_max=80,
+     rho=7,
+     guidance=1,
+     S_churn=0,
+     S_min=0,
+     S_max=float("inf"),
+     S_noise=1,
+     dtype=torch.float32,
+     randn_like=torch.randn_like,
+ ):
+     # Guided denoiser.
+     def denoise(x, t):
+         Dx = net(x, t, labels).to(dtype)
+         if guidance == 1:
+             return Dx
+         ref_Dx = gnet(x, t).to(dtype)
+         return ref_Dx.lerp(Dx, guidance)
+
+     # Time step discretization.
+     step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
+     t_steps = (
+         sigma_max ** (1 / rho)
+         + step_indices
+         / (num_steps - 1)
+         * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
+     ) ** rho
+     t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
+
+     # Main sampling loop.
+     x_next = noise.to(dtype) * t_steps[0]
+     for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
+         x_cur = x_next
+
+         # Increase noise temporarily.
+         if S_churn > 0 and S_min <= t_cur <= S_max:
+             gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
+             t_hat = t_cur + gamma * t_cur
+             x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
+         else:
+             t_hat = t_cur
+             x_hat = x_cur
+
+         # Euler step.
+         d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
+         x_next = x_hat + (t_next - t_hat) * d_cur
+
+         # Apply 2nd order correction.
+         if i < num_steps - 1:
+             d_prime = (x_next - denoise(x_next, t_next)) / t_next
+             x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
+
+     return x_next
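The time-step discretization above is the Karras rho-schedule: sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0..N-1, with a final sigma_N = 0 appended so the last Euler step lands on the clean sample. A standalone sketch of the same schedule using the function's default parameters:

    import torch

    num_steps, sigma_min, sigma_max, rho = 32, 0.002, 80.0, 7
    i = torch.arange(num_steps, dtype=torch.float64)
    t_steps = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    print(t_steps[0].item(), t_steps[-1].item())  # starts at 80.0, ends near 0.002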
model_weights/1girl-edm-xs-test-1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:745cdbc08a10dea8a2bffa51559f741f7010b81faa194cfcc6ea62b98ef329bf
+ size 499977592
model_weights/condgen_vae_decoder.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3fc2e77e8584fdb207b768398fdc96693772cfbb378f4b7c6adc58fa08116cef
+ size 5776840