Fabrice-TIERCELIN committed on
Commit 6d08643
1 Parent(s): dd6a633

Upload 13 files

sgm/modules/diffusionmodules/denoiser.py ADDED
@@ -0,0 +1,73 @@
+ import torch.nn as nn
+
+ from ...util import append_dims, instantiate_from_config
+
+
+ class Denoiser(nn.Module):
+     def __init__(self, weighting_config, scaling_config):
+         super().__init__()
+
+         self.weighting = instantiate_from_config(weighting_config)
+         self.scaling = instantiate_from_config(scaling_config)
+
+     def possibly_quantize_sigma(self, sigma):
+         return sigma
+
+     def possibly_quantize_c_noise(self, c_noise):
+         return c_noise
+
+     def w(self, sigma):
+         return self.weighting(sigma)
+
+     def __call__(self, network, input, sigma, cond):
+         sigma = self.possibly_quantize_sigma(sigma)
+         sigma_shape = sigma.shape
+         sigma = append_dims(sigma, input.ndim)
+         c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+         c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+         return network(input * c_in, c_noise, cond) * c_out + input * c_skip
+
+
+ class DiscreteDenoiser(Denoiser):
+     def __init__(
+         self,
+         weighting_config,
+         scaling_config,
+         num_idx,
+         discretization_config,
+         do_append_zero=False,
+         quantize_c_noise=True,
+         flip=True,
+     ):
+         super().__init__(weighting_config, scaling_config)
+         sigmas = instantiate_from_config(discretization_config)(
+             num_idx, do_append_zero=do_append_zero, flip=flip
+         )
+         self.register_buffer("sigmas", sigmas)
+         self.quantize_c_noise = quantize_c_noise
+
+     def sigma_to_idx(self, sigma):
+         dists = sigma - self.sigmas[:, None]
+         return dists.abs().argmin(dim=0).view(sigma.shape)
+
+     def idx_to_sigma(self, idx):
+         return self.sigmas[idx]
+
+     def possibly_quantize_sigma(self, sigma):
+         return self.idx_to_sigma(self.sigma_to_idx(sigma))
+
+     def possibly_quantize_c_noise(self, c_noise):
+         if self.quantize_c_noise:
+             return self.sigma_to_idx(c_noise)
+         else:
+             return c_noise
+
+
+ class DiscreteDenoiserWithControl(DiscreteDenoiser):
+     def __call__(self, network, input, sigma, cond, control_scale):
+         sigma = self.possibly_quantize_sigma(sigma)
+         sigma_shape = sigma.shape
+         sigma = append_dims(sigma, input.ndim)
+         c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+         c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+         return network(input * c_in, c_noise, cond, control_scale) * c_out + input * c_skip
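For orientation only (not part of the uploaded file): `Denoiser` wraps a raw network with the usual preconditioning, and its weighting/scaling configs follow the repo's `instantiate_from_config` convention. A minimal sketch, assuming the sgm package is importable and that `network` and `cond` are supplied by the caller:

    import torch
    from sgm.modules.diffusionmodules.denoiser import Denoiser

    denoiser = Denoiser(
        weighting_config={"target": "sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting"},
        scaling_config={"target": "sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling"},
    )

    x = torch.randn(2, 4, 64, 64)  # noised latents
    sigma = torch.rand(2) + 0.1    # one noise level per sample
    # denoised = denoiser(network, x, sigma, cond)  # = network(x * c_in, c_noise, cond) * c_out + x * c_skip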
sgm/modules/diffusionmodules/denoiser_scaling.py ADDED
@@ -0,0 +1,31 @@
+ import torch
+
+
+ class EDMScaling:
+     def __init__(self, sigma_data=0.5):
+         self.sigma_data = sigma_data
+
+     def __call__(self, sigma):
+         c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+         c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+         c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+         c_noise = 0.25 * sigma.log()
+         return c_skip, c_out, c_in, c_noise
+
+
+ class EpsScaling:
+     def __call__(self, sigma):
+         c_skip = torch.ones_like(sigma, device=sigma.device)
+         c_out = -sigma
+         c_in = 1 / (sigma**2 + 1.0) ** 0.5
+         c_noise = sigma.clone()
+         return c_skip, c_out, c_in, c_noise
+
+
+ class VScaling:
+     def __call__(self, sigma):
+         c_skip = 1.0 / (sigma**2 + 1.0)
+         c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+         c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+         c_noise = sigma.clone()
+         return c_skip, c_out, c_in, c_noise
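These scaling classes produce the (c_skip, c_out, c_in, c_noise) tuples consumed by `Denoiser.__call__` above; `EDMScaling` follows the EDM preconditioning of Karras et al. (2022). A small, purely illustrative check (assumes torch and the sgm package are installed):

    import torch
    from sgm.modules.diffusionmodules.denoiser_scaling import EDMScaling

    sigma = torch.tensor([0.1, 1.0, 10.0])
    c_skip, c_out, c_in, c_noise = EDMScaling(sigma_data=0.5)(sigma)
    # The denoiser then evaluates network(x * c_in, c_noise, cond) * c_out + x * c_skip.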
sgm/modules/diffusionmodules/denoiser_weighting.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+
+ class UnitWeighting:
+     def __call__(self, sigma):
+         return torch.ones_like(sigma, device=sigma.device)
+
+
+ class EDMWeighting:
+     def __init__(self, sigma_data=0.5):
+         self.sigma_data = sigma_data
+
+     def __call__(self, sigma):
+         return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+ class VWeighting(EDMWeighting):
+     def __init__(self):
+         super().__init__(sigma_data=1.0)
+
+
+ class EpsWeighting:
+     def __call__(self, sigma):
+         return sigma**-2
+
sgm/modules/diffusionmodules/discretizer.py ADDED
@@ -0,0 +1,69 @@
+ from abc import abstractmethod
+ from functools import partial
+
+ import numpy as np
+ import torch
+
+ from ...modules.diffusionmodules.util import make_beta_schedule
+ from ...util import append_zero
+
+
+ def generate_roughly_equally_spaced_steps(
+     num_substeps: int, max_step: int
+ ) -> np.ndarray:
+     return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+ class Discretization:
+     def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
+         sigmas = self.get_sigmas(n, device=device)
+         sigmas = append_zero(sigmas) if do_append_zero else sigmas
+         return sigmas if not flip else torch.flip(sigmas, (0,))
+
+     @abstractmethod
+     def get_sigmas(self, n, device):
+         pass
+
+
+ class EDMDiscretization(Discretization):
+     def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
+         self.sigma_min = sigma_min
+         self.sigma_max = sigma_max
+         self.rho = rho
+
+     def get_sigmas(self, n, device="cpu"):
+         ramp = torch.linspace(0, 1, n, device=device)
+         min_inv_rho = self.sigma_min ** (1 / self.rho)
+         max_inv_rho = self.sigma_max ** (1 / self.rho)
+         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+         return sigmas
+
+
+ class LegacyDDPMDiscretization(Discretization):
+     def __init__(
+         self,
+         linear_start=0.00085,
+         linear_end=0.0120,
+         num_timesteps=1000,
+     ):
+         super().__init__()
+         self.num_timesteps = num_timesteps
+         betas = make_beta_schedule(
+             "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
+         )
+         alphas = 1.0 - betas
+         self.alphas_cumprod = np.cumprod(alphas, axis=0)
+         self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+     def get_sigmas(self, n, device="cpu"):
+         if n < self.num_timesteps:
+             timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+             alphas_cumprod = self.alphas_cumprod[timesteps]
+         elif n == self.num_timesteps:
+             alphas_cumprod = self.alphas_cumprod
+         else:
+             raise ValueError
+
+         to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+         sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+         return torch.flip(sigmas, (0,))
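For reference (not part of the diff): each discretization is a callable returning a sigma schedule, which `DiscreteDenoiser` consumes with `flip=True`. A minimal sketch, assuming the sgm package is importable:

    from sgm.modules.diffusionmodules.discretizer import EDMDiscretization

    disc = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
    descending = disc(10)            # 10 sigmas from sigma_max down to sigma_min, plus an appended 0
    ascending = disc(10, flip=True)  # same schedule reversed, as used when registering the denoiser's buffer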
sgm/modules/diffusionmodules/guiders.py ADDED
@@ -0,0 +1,88 @@
+ from functools import partial
+
+ import torch
+
+ from ...util import default, instantiate_from_config
+
+
+ class VanillaCFG:
+     """
+     implements parallelized CFG
+     """
+
+     def __init__(self, scale, dyn_thresh_config=None):
+         scale_schedule = lambda scale, sigma: scale  # independent of step
+         self.scale_schedule = partial(scale_schedule, scale)
+         self.dyn_thresh = instantiate_from_config(
+             default(
+                 dyn_thresh_config,
+                 {
+                     "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+                 },
+             )
+         )
+
+     def __call__(self, x, sigma):
+         x_u, x_c = x.chunk(2)
+         scale_value = self.scale_schedule(sigma)
+         x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+         return x_pred
+
+     def prepare_inputs(self, x, s, c, uc):
+         c_out = dict()
+
+         for k in c:
+             if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']:
+                 c_out[k] = torch.cat((uc[k], c[k]), 0)
+             else:
+                 assert c[k] == uc[k]
+                 c_out[k] = c[k]
+         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+
+ class LinearCFG:
+     def __init__(self, scale, scale_min=None, dyn_thresh_config=None):
+         if scale_min is None:
+             scale_min = scale
+         scale_schedule = lambda scale, scale_min, sigma: (scale - scale_min) * sigma / 14.6146 + scale_min
+         self.scale_schedule = partial(scale_schedule, scale, scale_min)
+         self.dyn_thresh = instantiate_from_config(
+             default(
+                 dyn_thresh_config,
+                 {
+                     "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+                 },
+             )
+         )
+
+     def __call__(self, x, sigma):
+         x_u, x_c = x.chunk(2)
+         scale_value = self.scale_schedule(sigma)
+         x_pred = self.dyn_thresh(x_u, x_c, scale_value)
+         return x_pred
+
+     def prepare_inputs(self, x, s, c, uc):
+         c_out = dict()
+
+         for k in c:
+             if k in ["vector", "crossattn", "concat", "control", 'control_vector', 'mask_x']:
+                 c_out[k] = torch.cat((uc[k], c[k]), 0)
+             else:
+                 assert c[k] == uc[k]
+                 c_out[k] = c[k]
+         return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+
+ class IdentityGuider:
+     def __call__(self, x, sigma):
+         return x
+
+     def prepare_inputs(self, x, s, c, uc):
+         c_out = dict()
+
+         for k in c:
+             c_out[k] = c[k]
+
+         return x, s, c_out
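A rough usage sketch of the guider interface (illustrative only; `VanillaCFG` instantiates `NoDynamicThresholding` from `sgm.modules.diffusionmodules.sampling_utils`, so that module must be importable, and the random tensor below merely stands in for a real denoiser call):

    import torch
    from sgm.modules.diffusionmodules.guiders import VanillaCFG

    guider = VanillaCFG(scale=7.5)
    x, sigma = torch.randn(1, 4, 8, 8), torch.ones(1)
    cond = {"crossattn": torch.randn(1, 77, 1024)}
    uncond = {"crossattn": torch.zeros(1, 77, 1024)}

    x_in, s_in, c_in = guider.prepare_inputs(x, sigma, cond, uncond)  # batch-doubles x, sigma and the conditionings
    denoised_both = torch.randn_like(x_in)  # placeholder for denoiser(model, x_in, s_in, c_in)
    guided = guider(denoised_both, sigma)   # splits the doubled batch and applies the CFG scale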
sgm/modules/diffusionmodules/loss.py ADDED
@@ -0,0 +1,69 @@
+ from typing import List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from omegaconf import ListConfig
+
+ from ...util import append_dims, instantiate_from_config
+ from ...modules.autoencoding.lpips.loss.lpips import LPIPS
+
+
+ class StandardDiffusionLoss(nn.Module):
+     def __init__(
+         self,
+         sigma_sampler_config,
+         type="l2",
+         offset_noise_level=0.0,
+         batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
+     ):
+         super().__init__()
+
+         assert type in ["l2", "l1", "lpips"]
+
+         self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+
+         self.type = type
+         self.offset_noise_level = offset_noise_level
+
+         if type == "lpips":
+             self.lpips = LPIPS().eval()
+
+         if not batch2model_keys:
+             batch2model_keys = []
+
+         if isinstance(batch2model_keys, str):
+             batch2model_keys = [batch2model_keys]
+
+         self.batch2model_keys = set(batch2model_keys)
+
+     def __call__(self, network, denoiser, conditioner, input, batch):
+         cond = conditioner(batch)
+         additional_model_inputs = {
+             key: batch[key] for key in self.batch2model_keys.intersection(batch)
+         }
+
+         sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
+         noise = torch.randn_like(input)
+         if self.offset_noise_level > 0.0:
+             noise = noise + self.offset_noise_level * append_dims(
+                 torch.randn(input.shape[0], device=input.device), input.ndim
+             )
+         noised_input = input + noise * append_dims(sigmas, input.ndim)
+         model_output = denoiser(
+             network, noised_input, sigmas, cond, **additional_model_inputs
+         )
+         w = append_dims(denoiser.w(sigmas), input.ndim)
+         return self.get_loss(model_output, input, w)
+
+     def get_loss(self, model_output, target, w):
+         if self.type == "l2":
+             return torch.mean(
+                 (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
+             )
+         elif self.type == "l1":
+             return torch.mean(
+                 (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
+             )
+         elif self.type == "lpips":
+             loss = self.lpips(model_output, target).reshape(-1)
+             return loss
sgm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,743 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ from typing import Any, Callable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from packaging import version
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+
15
+ XFORMERS_IS_AVAILABLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILABLE = False
18
+ print("no module 'xformers'. Processing without...")
19
+
20
+ from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
21
+
22
+
23
+ def get_timestep_embedding(timesteps, embedding_dim):
24
+ """
25
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
26
+ From Fairseq.
27
+ Build sinusoidal embeddings.
28
+ This matches the implementation in tensor2tensor, but differs slightly
29
+ from the description in Section 3.5 of "Attention Is All You Need".
30
+ """
31
+ assert len(timesteps.shape) == 1
32
+
33
+ half_dim = embedding_dim // 2
34
+ emb = math.log(10000) / (half_dim - 1)
35
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
36
+ emb = emb.to(device=timesteps.device)
37
+ emb = timesteps.float()[:, None] * emb[None, :]
38
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
39
+ if embedding_dim % 2 == 1: # zero pad
40
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
41
+ return emb
42
+
43
+
44
+ def nonlinearity(x):
45
+ # swish
46
+ return x * torch.sigmoid(x)
47
+
48
+
49
+ def Normalize(in_channels, num_groups=32):
50
+ return torch.nn.GroupNorm(
51
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
52
+ )
53
+
54
+
55
+ class Upsample(nn.Module):
56
+ def __init__(self, in_channels, with_conv):
57
+ super().__init__()
58
+ self.with_conv = with_conv
59
+ if self.with_conv:
60
+ self.conv = torch.nn.Conv2d(
61
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
62
+ )
63
+
64
+ def forward(self, x):
65
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
66
+ if self.with_conv:
67
+ x = self.conv(x)
68
+ return x
69
+
70
+
71
+ class Downsample(nn.Module):
72
+ def __init__(self, in_channels, with_conv):
73
+ super().__init__()
74
+ self.with_conv = with_conv
75
+ if self.with_conv:
76
+ # no asymmetric padding in torch conv, must do it ourselves
77
+ self.conv = torch.nn.Conv2d(
78
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
79
+ )
80
+
81
+ def forward(self, x):
82
+ if self.with_conv:
83
+ pad = (0, 1, 0, 1)
84
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
85
+ x = self.conv(x)
86
+ else:
87
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
88
+ return x
89
+
90
+
91
+ class ResnetBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ *,
95
+ in_channels,
96
+ out_channels=None,
97
+ conv_shortcut=False,
98
+ dropout,
99
+ temb_channels=512,
100
+ ):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+
107
+ self.norm1 = Normalize(in_channels)
108
+ self.conv1 = torch.nn.Conv2d(
109
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
110
+ )
111
+ if temb_channels > 0:
112
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
113
+ self.norm2 = Normalize(out_channels)
114
+ self.dropout = torch.nn.Dropout(dropout)
115
+ self.conv2 = torch.nn.Conv2d(
116
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
117
+ )
118
+ if self.in_channels != self.out_channels:
119
+ if self.use_conv_shortcut:
120
+ self.conv_shortcut = torch.nn.Conv2d(
121
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
122
+ )
123
+ else:
124
+ self.nin_shortcut = torch.nn.Conv2d(
125
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
126
+ )
127
+
128
+ def forward(self, x, temb):
129
+ h = x
130
+ h = self.norm1(h)
131
+ h = nonlinearity(h)
132
+ h = self.conv1(h)
133
+
134
+ if temb is not None:
135
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
136
+
137
+ h = self.norm2(h)
138
+ h = nonlinearity(h)
139
+ h = self.dropout(h)
140
+ h = self.conv2(h)
141
+
142
+ if self.in_channels != self.out_channels:
143
+ if self.use_conv_shortcut:
144
+ x = self.conv_shortcut(x)
145
+ else:
146
+ x = self.nin_shortcut(x)
147
+
148
+ return x + h
149
+
150
+
151
+ class LinAttnBlock(LinearAttention):
152
+ """to match AttnBlock usage"""
153
+
154
+ def __init__(self, in_channels):
155
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
156
+
157
+
158
+ class AttnBlock(nn.Module):
159
+ def __init__(self, in_channels):
160
+ super().__init__()
161
+ self.in_channels = in_channels
162
+
163
+ self.norm = Normalize(in_channels)
164
+ self.q = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+ self.k = torch.nn.Conv2d(
168
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
169
+ )
170
+ self.v = torch.nn.Conv2d(
171
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
172
+ )
173
+ self.proj_out = torch.nn.Conv2d(
174
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
175
+ )
176
+
177
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
178
+ h_ = self.norm(h_)
179
+ q = self.q(h_)
180
+ k = self.k(h_)
181
+ v = self.v(h_)
182
+
183
+ b, c, h, w = q.shape
184
+ q, k, v = map(
185
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
186
+ )
187
+ h_ = torch.nn.functional.scaled_dot_product_attention(
188
+ q, k, v
189
+ ) # scale is dim ** -0.5 per default
190
+ # compute attention
191
+
192
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
193
+
194
+ def forward(self, x, **kwargs):
195
+ h_ = x
196
+ h_ = self.attention(h_)
197
+ h_ = self.proj_out(h_)
198
+ return x + h_
199
+
200
+
201
+ class MemoryEfficientAttnBlock(nn.Module):
202
+ """
203
+ Uses xformers efficient implementation,
204
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
205
+ Note: this is a single-head self-attention operation
206
+ """
207
+
208
+ #
209
+ def __init__(self, in_channels):
210
+ super().__init__()
211
+ self.in_channels = in_channels
212
+
213
+ self.norm = Normalize(in_channels)
214
+ self.q = torch.nn.Conv2d(
215
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
216
+ )
217
+ self.k = torch.nn.Conv2d(
218
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
219
+ )
220
+ self.v = torch.nn.Conv2d(
221
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
222
+ )
223
+ self.proj_out = torch.nn.Conv2d(
224
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
225
+ )
226
+ self.attention_op: Optional[Any] = None
227
+
228
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
229
+ h_ = self.norm(h_)
230
+ q = self.q(h_)
231
+ k = self.k(h_)
232
+ v = self.v(h_)
233
+
234
+ # compute attention
235
+ B, C, H, W = q.shape
236
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
237
+
238
+ q, k, v = map(
239
+ lambda t: t.unsqueeze(3)
240
+ .reshape(B, t.shape[1], 1, C)
241
+ .permute(0, 2, 1, 3)
242
+ .reshape(B * 1, t.shape[1], C)
243
+ .contiguous(),
244
+ (q, k, v),
245
+ )
246
+ out = xformers.ops.memory_efficient_attention(
247
+ q, k, v, attn_bias=None, op=self.attention_op
248
+ )
249
+
250
+ out = (
251
+ out.unsqueeze(0)
252
+ .reshape(B, 1, out.shape[1], C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B, out.shape[1], C)
255
+ )
256
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
257
+
258
+ def forward(self, x, **kwargs):
259
+ h_ = x
260
+ h_ = self.attention(h_)
261
+ h_ = self.proj_out(h_)
262
+ return x + h_
263
+
264
+
265
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
266
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
267
+ b, c, h, w = x.shape
268
+ x_flat = rearrange(x, "b c h w -> b (h w) c")  # keep x in (b, c, h, w) for the residual below
269
+ out = super().forward(x_flat, context=context, mask=mask)
270
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
271
+ return x + out
272
+
273
+
274
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
275
+ assert attn_type in [
276
+ "vanilla",
277
+ "vanilla-xformers",
278
+ "memory-efficient-cross-attn",
279
+ "linear",
280
+ "none",
281
+ ], f"attn_type {attn_type} unknown"
282
+ if (
283
+ version.parse(torch.__version__) < version.parse("2.0.0")
284
+ and attn_type != "none"
285
+ ):
286
+ assert XFORMERS_IS_AVAILABLE, (
287
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
288
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
289
+ )
290
+ attn_type = "vanilla-xformers"
291
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
292
+ if attn_type == "vanilla":
293
+ assert attn_kwargs is None
294
+ return AttnBlock(in_channels)
295
+ elif attn_type == "vanilla-xformers":
296
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
297
+ return MemoryEfficientAttnBlock(in_channels)
298
+ elif attn_type == "memory-efficient-cross-attn":
299
+ attn_kwargs["query_dim"] = in_channels
300
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
301
+ elif attn_type == "none":
302
+ return nn.Identity(in_channels)
303
+ else:
304
+ return LinAttnBlock(in_channels)
305
+
306
+
307
+ class Model(nn.Module):
308
+ def __init__(
309
+ self,
310
+ *,
311
+ ch,
312
+ out_ch,
313
+ ch_mult=(1, 2, 4, 8),
314
+ num_res_blocks,
315
+ attn_resolutions,
316
+ dropout=0.0,
317
+ resamp_with_conv=True,
318
+ in_channels,
319
+ resolution,
320
+ use_timestep=True,
321
+ use_linear_attn=False,
322
+ attn_type="vanilla",
323
+ ):
324
+ super().__init__()
325
+ if use_linear_attn:
326
+ attn_type = "linear"
327
+ self.ch = ch
328
+ self.temb_ch = self.ch * 4
329
+ self.num_resolutions = len(ch_mult)
330
+ self.num_res_blocks = num_res_blocks
331
+ self.resolution = resolution
332
+ self.in_channels = in_channels
333
+
334
+ self.use_timestep = use_timestep
335
+ if self.use_timestep:
336
+ # timestep embedding
337
+ self.temb = nn.Module()
338
+ self.temb.dense = nn.ModuleList(
339
+ [
340
+ torch.nn.Linear(self.ch, self.temb_ch),
341
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
342
+ ]
343
+ )
344
+
345
+ # downsampling
346
+ self.conv_in = torch.nn.Conv2d(
347
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
348
+ )
349
+
350
+ curr_res = resolution
351
+ in_ch_mult = (1,) + tuple(ch_mult)
352
+ self.down = nn.ModuleList()
353
+ for i_level in range(self.num_resolutions):
354
+ block = nn.ModuleList()
355
+ attn = nn.ModuleList()
356
+ block_in = ch * in_ch_mult[i_level]
357
+ block_out = ch * ch_mult[i_level]
358
+ for i_block in range(self.num_res_blocks):
359
+ block.append(
360
+ ResnetBlock(
361
+ in_channels=block_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout,
365
+ )
366
+ )
367
+ block_in = block_out
368
+ if curr_res in attn_resolutions:
369
+ attn.append(make_attn(block_in, attn_type=attn_type))
370
+ down = nn.Module()
371
+ down.block = block
372
+ down.attn = attn
373
+ if i_level != self.num_resolutions - 1:
374
+ down.downsample = Downsample(block_in, resamp_with_conv)
375
+ curr_res = curr_res // 2
376
+ self.down.append(down)
377
+
378
+ # middle
379
+ self.mid = nn.Module()
380
+ self.mid.block_1 = ResnetBlock(
381
+ in_channels=block_in,
382
+ out_channels=block_in,
383
+ temb_channels=self.temb_ch,
384
+ dropout=dropout,
385
+ )
386
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
387
+ self.mid.block_2 = ResnetBlock(
388
+ in_channels=block_in,
389
+ out_channels=block_in,
390
+ temb_channels=self.temb_ch,
391
+ dropout=dropout,
392
+ )
393
+
394
+ # upsampling
395
+ self.up = nn.ModuleList()
396
+ for i_level in reversed(range(self.num_resolutions)):
397
+ block = nn.ModuleList()
398
+ attn = nn.ModuleList()
399
+ block_out = ch * ch_mult[i_level]
400
+ skip_in = ch * ch_mult[i_level]
401
+ for i_block in range(self.num_res_blocks + 1):
402
+ if i_block == self.num_res_blocks:
403
+ skip_in = ch * in_ch_mult[i_level]
404
+ block.append(
405
+ ResnetBlock(
406
+ in_channels=block_in + skip_in,
407
+ out_channels=block_out,
408
+ temb_channels=self.temb_ch,
409
+ dropout=dropout,
410
+ )
411
+ )
412
+ block_in = block_out
413
+ if curr_res in attn_resolutions:
414
+ attn.append(make_attn(block_in, attn_type=attn_type))
415
+ up = nn.Module()
416
+ up.block = block
417
+ up.attn = attn
418
+ if i_level != 0:
419
+ up.upsample = Upsample(block_in, resamp_with_conv)
420
+ curr_res = curr_res * 2
421
+ self.up.insert(0, up) # prepend to get consistent order
422
+
423
+ # end
424
+ self.norm_out = Normalize(block_in)
425
+ self.conv_out = torch.nn.Conv2d(
426
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
427
+ )
428
+
429
+ def forward(self, x, t=None, context=None):
430
+ # assert x.shape[2] == x.shape[3] == self.resolution
431
+ if context is not None:
432
+ # assume aligned context, cat along channel axis
433
+ x = torch.cat((x, context), dim=1)
434
+ if self.use_timestep:
435
+ # timestep embedding
436
+ assert t is not None
437
+ temb = get_timestep_embedding(t, self.ch)
438
+ temb = self.temb.dense[0](temb)
439
+ temb = nonlinearity(temb)
440
+ temb = self.temb.dense[1](temb)
441
+ else:
442
+ temb = None
443
+
444
+ # downsampling
445
+ hs = [self.conv_in(x)]
446
+ for i_level in range(self.num_resolutions):
447
+ for i_block in range(self.num_res_blocks):
448
+ h = self.down[i_level].block[i_block](hs[-1], temb)
449
+ if len(self.down[i_level].attn) > 0:
450
+ h = self.down[i_level].attn[i_block](h)
451
+ hs.append(h)
452
+ if i_level != self.num_resolutions - 1:
453
+ hs.append(self.down[i_level].downsample(hs[-1]))
454
+
455
+ # middle
456
+ h = hs[-1]
457
+ h = self.mid.block_1(h, temb)
458
+ h = self.mid.attn_1(h)
459
+ h = self.mid.block_2(h, temb)
460
+
461
+ # upsampling
462
+ for i_level in reversed(range(self.num_resolutions)):
463
+ for i_block in range(self.num_res_blocks + 1):
464
+ h = self.up[i_level].block[i_block](
465
+ torch.cat([h, hs.pop()], dim=1), temb
466
+ )
467
+ if len(self.up[i_level].attn) > 0:
468
+ h = self.up[i_level].attn[i_block](h)
469
+ if i_level != 0:
470
+ h = self.up[i_level].upsample(h)
471
+
472
+ # end
473
+ h = self.norm_out(h)
474
+ h = nonlinearity(h)
475
+ h = self.conv_out(h)
476
+ return h
477
+
478
+ def get_last_layer(self):
479
+ return self.conv_out.weight
480
+
481
+
482
+ class Encoder(nn.Module):
483
+ def __init__(
484
+ self,
485
+ *,
486
+ ch,
487
+ out_ch,
488
+ ch_mult=(1, 2, 4, 8),
489
+ num_res_blocks,
490
+ attn_resolutions,
491
+ dropout=0.0,
492
+ resamp_with_conv=True,
493
+ in_channels,
494
+ resolution,
495
+ z_channels,
496
+ double_z=True,
497
+ use_linear_attn=False,
498
+ attn_type="vanilla",
499
+ **ignore_kwargs,
500
+ ):
501
+ super().__init__()
502
+ if use_linear_attn:
503
+ attn_type = "linear"
504
+ self.ch = ch
505
+ self.temb_ch = 0
506
+ self.num_resolutions = len(ch_mult)
507
+ self.num_res_blocks = num_res_blocks
508
+ self.resolution = resolution
509
+ self.in_channels = in_channels
510
+
511
+ # downsampling
512
+ self.conv_in = torch.nn.Conv2d(
513
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
514
+ )
515
+
516
+ curr_res = resolution
517
+ in_ch_mult = (1,) + tuple(ch_mult)
518
+ self.in_ch_mult = in_ch_mult
519
+ self.down = nn.ModuleList()
520
+ for i_level in range(self.num_resolutions):
521
+ block = nn.ModuleList()
522
+ attn = nn.ModuleList()
523
+ block_in = ch * in_ch_mult[i_level]
524
+ block_out = ch * ch_mult[i_level]
525
+ for i_block in range(self.num_res_blocks):
526
+ block.append(
527
+ ResnetBlock(
528
+ in_channels=block_in,
529
+ out_channels=block_out,
530
+ temb_channels=self.temb_ch,
531
+ dropout=dropout,
532
+ )
533
+ )
534
+ block_in = block_out
535
+ if curr_res in attn_resolutions:
536
+ attn.append(make_attn(block_in, attn_type=attn_type))
537
+ down = nn.Module()
538
+ down.block = block
539
+ down.attn = attn
540
+ if i_level != self.num_resolutions - 1:
541
+ down.downsample = Downsample(block_in, resamp_with_conv)
542
+ curr_res = curr_res // 2
543
+ self.down.append(down)
544
+
545
+ # middle
546
+ self.mid = nn.Module()
547
+ self.mid.block_1 = ResnetBlock(
548
+ in_channels=block_in,
549
+ out_channels=block_in,
550
+ temb_channels=self.temb_ch,
551
+ dropout=dropout,
552
+ )
553
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
554
+ self.mid.block_2 = ResnetBlock(
555
+ in_channels=block_in,
556
+ out_channels=block_in,
557
+ temb_channels=self.temb_ch,
558
+ dropout=dropout,
559
+ )
560
+
561
+ # end
562
+ self.norm_out = Normalize(block_in)
563
+ self.conv_out = torch.nn.Conv2d(
564
+ block_in,
565
+ 2 * z_channels if double_z else z_channels,
566
+ kernel_size=3,
567
+ stride=1,
568
+ padding=1,
569
+ )
570
+
571
+ def forward(self, x):
572
+ # timestep embedding
573
+ temb = None
574
+
575
+ # downsampling
576
+ hs = [self.conv_in(x)]
577
+ for i_level in range(self.num_resolutions):
578
+ for i_block in range(self.num_res_blocks):
579
+ h = self.down[i_level].block[i_block](hs[-1], temb)
580
+ if len(self.down[i_level].attn) > 0:
581
+ h = self.down[i_level].attn[i_block](h)
582
+ hs.append(h)
583
+ if i_level != self.num_resolutions - 1:
584
+ hs.append(self.down[i_level].downsample(hs[-1]))
585
+
586
+ # middle
587
+ h = hs[-1]
588
+ h = self.mid.block_1(h, temb)
589
+ h = self.mid.attn_1(h)
590
+ h = self.mid.block_2(h, temb)
591
+
592
+ # end
593
+ h = self.norm_out(h)
594
+ h = nonlinearity(h)
595
+ h = self.conv_out(h)
596
+ return h
597
+
598
+
599
+ class Decoder(nn.Module):
600
+ def __init__(
601
+ self,
602
+ *,
603
+ ch,
604
+ out_ch,
605
+ ch_mult=(1, 2, 4, 8),
606
+ num_res_blocks,
607
+ attn_resolutions,
608
+ dropout=0.0,
609
+ resamp_with_conv=True,
610
+ in_channels,
611
+ resolution,
612
+ z_channels,
613
+ give_pre_end=False,
614
+ tanh_out=False,
615
+ use_linear_attn=False,
616
+ attn_type="vanilla",
617
+ **ignorekwargs,
618
+ ):
619
+ super().__init__()
620
+ if use_linear_attn:
621
+ attn_type = "linear"
622
+ self.ch = ch
623
+ self.temb_ch = 0
624
+ self.num_resolutions = len(ch_mult)
625
+ self.num_res_blocks = num_res_blocks
626
+ self.resolution = resolution
627
+ self.in_channels = in_channels
628
+ self.give_pre_end = give_pre_end
629
+ self.tanh_out = tanh_out
630
+
631
+ # compute in_ch_mult, block_in and curr_res at lowest res
632
+ in_ch_mult = (1,) + tuple(ch_mult)
633
+ block_in = ch * ch_mult[self.num_resolutions - 1]
634
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
635
+ self.z_shape = (1, z_channels, curr_res, curr_res)
636
+ print(
637
+ "Working with z of shape {} = {} dimensions.".format(
638
+ self.z_shape, np.prod(self.z_shape)
639
+ )
640
+ )
641
+
642
+ make_attn_cls = self._make_attn()
643
+ make_resblock_cls = self._make_resblock()
644
+ make_conv_cls = self._make_conv()
645
+ # z to block_in
646
+ self.conv_in = torch.nn.Conv2d(
647
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ # middle
651
+ self.mid = nn.Module()
652
+ self.mid.block_1 = make_resblock_cls(
653
+ in_channels=block_in,
654
+ out_channels=block_in,
655
+ temb_channels=self.temb_ch,
656
+ dropout=dropout,
657
+ )
658
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
659
+ self.mid.block_2 = make_resblock_cls(
660
+ in_channels=block_in,
661
+ out_channels=block_in,
662
+ temb_channels=self.temb_ch,
663
+ dropout=dropout,
664
+ )
665
+
666
+ # upsampling
667
+ self.up = nn.ModuleList()
668
+ for i_level in reversed(range(self.num_resolutions)):
669
+ block = nn.ModuleList()
670
+ attn = nn.ModuleList()
671
+ block_out = ch * ch_mult[i_level]
672
+ for i_block in range(self.num_res_blocks + 1):
673
+ block.append(
674
+ make_resblock_cls(
675
+ in_channels=block_in,
676
+ out_channels=block_out,
677
+ temb_channels=self.temb_ch,
678
+ dropout=dropout,
679
+ )
680
+ )
681
+ block_in = block_out
682
+ if curr_res in attn_resolutions:
683
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
684
+ up = nn.Module()
685
+ up.block = block
686
+ up.attn = attn
687
+ if i_level != 0:
688
+ up.upsample = Upsample(block_in, resamp_with_conv)
689
+ curr_res = curr_res * 2
690
+ self.up.insert(0, up) # prepend to get consistent order
691
+
692
+ # end
693
+ self.norm_out = Normalize(block_in)
694
+ self.conv_out = make_conv_cls(
695
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
696
+ )
697
+
698
+ def _make_attn(self) -> Callable:
699
+ return make_attn
700
+
701
+ def _make_resblock(self) -> Callable:
702
+ return ResnetBlock
703
+
704
+ def _make_conv(self) -> Callable:
705
+ return torch.nn.Conv2d
706
+
707
+ def get_last_layer(self, **kwargs):
708
+ return self.conv_out.weight
709
+
710
+ def forward(self, z, **kwargs):
711
+ # assert z.shape[1:] == self.z_shape[1:]
712
+ self.last_z_shape = z.shape
713
+
714
+ # timestep embedding
715
+ temb = None
716
+
717
+ # z to block_in
718
+ h = self.conv_in(z)
719
+
720
+ # middle
721
+ h = self.mid.block_1(h, temb, **kwargs)
722
+ h = self.mid.attn_1(h, **kwargs)
723
+ h = self.mid.block_2(h, temb, **kwargs)
724
+
725
+ # upsampling
726
+ for i_level in reversed(range(self.num_resolutions)):
727
+ for i_block in range(self.num_res_blocks + 1):
728
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
729
+ if len(self.up[i_level].attn) > 0:
730
+ h = self.up[i_level].attn[i_block](h, **kwargs)
731
+ if i_level != 0:
732
+ h = self.up[i_level].upsample(h)
733
+
734
+ # end
735
+ if self.give_pre_end:
736
+ return h
737
+
738
+ h = self.norm_out(h)
739
+ h = nonlinearity(h)
740
+ h = self.conv_out(h, **kwargs)
741
+ if self.tanh_out:
742
+ h = torch.tanh(h)
743
+ return h
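One small usage note (not part of the commit): `get_timestep_embedding` is the standard sinusoidal embedding, e.g. (assuming the sgm package and its dependencies are installed):

    import torch
    from sgm.modules.diffusionmodules.model import get_timestep_embedding

    t = torch.arange(4)
    emb = get_timestep_embedding(t, embedding_dim=128)  # shape (4, 128): sin half, then cos half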
sgm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1272 @@
1
+ import math
2
+ from abc import abstractmethod
3
+ from functools import partial
4
+ from typing import Iterable
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ # from einops._torch_specific import allow_ops_in_compiled_graph
11
+ # allow_ops_in_compiled_graph()
12
+ from einops import rearrange
13
+
14
+ from ...modules.attention import SpatialTransformer
15
+ from ...modules.diffusionmodules.util import (
16
+ avg_pool_nd,
17
+ checkpoint,
18
+ conv_nd,
19
+ linear,
20
+ normalization,
21
+ timestep_embedding,
22
+ zero_module,
23
+ )
24
+ from ...util import default, exists
25
+
26
+
27
+ # dummy replace
28
+ def convert_module_to_f16(x):
29
+ pass
30
+
31
+
32
+ def convert_module_to_f32(x):
33
+ pass
34
+
35
+
36
+ ## go
37
+ class AttentionPool2d(nn.Module):
38
+ """
39
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ spacial_dim: int,
45
+ embed_dim: int,
46
+ num_heads_channels: int,
47
+ output_dim: int = None,
48
+ ):
49
+ super().__init__()
50
+ self.positional_embedding = nn.Parameter(
51
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
52
+ )
53
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
54
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
55
+ self.num_heads = embed_dim // num_heads_channels
56
+ self.attention = QKVAttention(self.num_heads)
57
+
58
+ def forward(self, x):
59
+ b, c, *_spatial = x.shape
60
+ x = x.reshape(b, c, -1) # NC(HW)
61
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
62
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
63
+ x = self.qkv_proj(x)
64
+ x = self.attention(x)
65
+ x = self.c_proj(x)
66
+ return x[:, :, 0]
67
+
68
+
69
+ class TimestepBlock(nn.Module):
70
+ """
71
+ Any module where forward() takes timestep embeddings as a second argument.
72
+ """
73
+
74
+ @abstractmethod
75
+ def forward(self, x, emb):
76
+ """
77
+ Apply the module to `x` given `emb` timestep embeddings.
78
+ """
79
+
80
+
81
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
82
+ """
83
+ A sequential module that passes timestep embeddings to the children that
84
+ support it as an extra input.
85
+ """
86
+
87
+ def forward(
88
+ self,
89
+ x,
90
+ emb,
91
+ context=None,
92
+ skip_time_mix=False,
93
+ time_context=None,
94
+ num_video_frames=None,
95
+ time_context_cat=None,
96
+ use_crossframe_attention_in_spatial_layers=False,
97
+ ):
98
+ for layer in self:
99
+ if isinstance(layer, TimestepBlock):
100
+ x = layer(x, emb)
101
+ elif isinstance(layer, SpatialTransformer):
102
+ x = layer(x, context)
103
+ else:
104
+ x = layer(x)
105
+ return x
106
+
107
+
108
+ class Upsample(nn.Module):
109
+ """
110
+ An upsampling layer with an optional convolution.
111
+ :param channels: channels in the inputs and outputs.
112
+ :param use_conv: a bool determining if a convolution is applied.
113
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
114
+ upsampling occurs in the inner-two dimensions.
115
+ """
116
+
117
+ def __init__(
118
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_up=False
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.out_channels = out_channels or channels
123
+ self.use_conv = use_conv
124
+ self.dims = dims
125
+ self.third_up = third_up
126
+ if use_conv:
127
+ self.conv = conv_nd(
128
+ dims, self.channels, self.out_channels, 3, padding=padding
129
+ )
130
+
131
+ def forward(self, x):
132
+ # support fp32 only
133
+ _dtype = x.dtype
134
+ x = x.to(th.float32)
135
+
136
+ assert x.shape[1] == self.channels
137
+ if self.dims == 3:
138
+ t_factor = 1 if not self.third_up else 2
139
+ x = F.interpolate(
140
+ x,
141
+ (t_factor * x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
142
+ mode="nearest",
143
+ )
144
+ else:
145
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
146
+
147
+ x = x.to(_dtype) # support fp32 only
148
+
149
+ if self.use_conv:
150
+ x = self.conv(x)
151
+ return x
152
+
153
+
154
+ class TransposedUpsample(nn.Module):
155
+ "Learned 2x upsampling without padding"
156
+
157
+ def __init__(self, channels, out_channels=None, ks=5):
158
+ super().__init__()
159
+ self.channels = channels
160
+ self.out_channels = out_channels or channels
161
+
162
+ self.up = nn.ConvTranspose2d(
163
+ self.channels, self.out_channels, kernel_size=ks, stride=2
164
+ )
165
+
166
+ def forward(self, x):
167
+ return self.up(x)
168
+
169
+
170
+ class Downsample(nn.Module):
171
+ """
172
+ A downsampling layer with an optional convolution.
173
+ :param channels: channels in the inputs and outputs.
174
+ :param use_conv: a bool determining if a convolution is applied.
175
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
176
+ downsampling occurs in the inner-two dimensions.
177
+ """
178
+
179
+ def __init__(
180
+ self, channels, use_conv, dims=2, out_channels=None, padding=1, third_down=False
181
+ ):
182
+ super().__init__()
183
+ self.channels = channels
184
+ self.out_channels = out_channels or channels
185
+ self.use_conv = use_conv
186
+ self.dims = dims
187
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
188
+ if use_conv:
189
+ print(f"Building a Downsample layer with {dims} dims.")
190
+ print(
191
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
192
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
193
+ )
194
+ if dims == 3:
195
+ print(f" --> Downsampling third axis (time): {third_down}")
196
+ self.op = conv_nd(
197
+ dims,
198
+ self.channels,
199
+ self.out_channels,
200
+ 3,
201
+ stride=stride,
202
+ padding=padding,
203
+ )
204
+ else:
205
+ assert self.channels == self.out_channels
206
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
207
+
208
+ def forward(self, x):
209
+ assert x.shape[1] == self.channels
210
+ return self.op(x)
211
+
212
+
213
+ class ResBlock(TimestepBlock):
214
+ """
215
+ A residual block that can optionally change the number of channels.
216
+ :param channels: the number of input channels.
217
+ :param emb_channels: the number of timestep embedding channels.
218
+ :param dropout: the rate of dropout.
219
+ :param out_channels: if specified, the number of out channels.
220
+ :param use_conv: if True and out_channels is specified, use a spatial
221
+ convolution instead of a smaller 1x1 convolution to change the
222
+ channels in the skip connection.
223
+ :param dims: determines if the signal is 1D, 2D, or 3D.
224
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
225
+ :param up: if True, use this block for upsampling.
226
+ :param down: if True, use this block for downsampling.
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ channels,
232
+ emb_channels,
233
+ dropout,
234
+ out_channels=None,
235
+ use_conv=False,
236
+ use_scale_shift_norm=False,
237
+ dims=2,
238
+ use_checkpoint=False,
239
+ up=False,
240
+ down=False,
241
+ kernel_size=3,
242
+ exchange_temb_dims=False,
243
+ skip_t_emb=False,
244
+ ):
245
+ super().__init__()
246
+ self.channels = channels
247
+ self.emb_channels = emb_channels
248
+ self.dropout = dropout
249
+ self.out_channels = out_channels or channels
250
+ self.use_conv = use_conv
251
+ self.use_checkpoint = use_checkpoint
252
+ self.use_scale_shift_norm = use_scale_shift_norm
253
+ self.exchange_temb_dims = exchange_temb_dims
254
+
255
+ if isinstance(kernel_size, Iterable):
256
+ padding = [k // 2 for k in kernel_size]
257
+ else:
258
+ padding = kernel_size // 2
259
+
260
+ self.in_layers = nn.Sequential(
261
+ normalization(channels),
262
+ nn.SiLU(),
263
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
264
+ )
265
+
266
+ self.updown = up or down
267
+
268
+ if up:
269
+ self.h_upd = Upsample(channels, False, dims)
270
+ self.x_upd = Upsample(channels, False, dims)
271
+ elif down:
272
+ self.h_upd = Downsample(channels, False, dims)
273
+ self.x_upd = Downsample(channels, False, dims)
274
+ else:
275
+ self.h_upd = self.x_upd = nn.Identity()
276
+
277
+ self.skip_t_emb = skip_t_emb
278
+ self.emb_out_channels = (
279
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
280
+ )
281
+ if self.skip_t_emb:
282
+ print(f"Skipping timestep embedding in {self.__class__.__name__}")
283
+ assert not self.use_scale_shift_norm
284
+ self.emb_layers = None
285
+ self.exchange_temb_dims = False
286
+ else:
287
+ self.emb_layers = nn.Sequential(
288
+ nn.SiLU(),
289
+ linear(
290
+ emb_channels,
291
+ self.emb_out_channels,
292
+ ),
293
+ )
294
+
295
+ self.out_layers = nn.Sequential(
296
+ normalization(self.out_channels),
297
+ nn.SiLU(),
298
+ nn.Dropout(p=dropout),
299
+ zero_module(
300
+ conv_nd(
301
+ dims,
302
+ self.out_channels,
303
+ self.out_channels,
304
+ kernel_size,
305
+ padding=padding,
306
+ )
307
+ ),
308
+ )
309
+
310
+ if self.out_channels == channels:
311
+ self.skip_connection = nn.Identity()
312
+ elif use_conv:
313
+ self.skip_connection = conv_nd(
314
+ dims, channels, self.out_channels, kernel_size, padding=padding
315
+ )
316
+ else:
317
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
318
+
319
+ def forward(self, x, emb):
320
+ """
321
+ Apply the block to a Tensor, conditioned on a timestep embedding.
322
+ :param x: an [N x C x ...] Tensor of features.
323
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
324
+ :return: an [N x C x ...] Tensor of outputs.
325
+ """
326
+ return checkpoint(
327
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
328
+ )
329
+
330
+ def _forward(self, x, emb):
331
+ if self.updown:
332
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
333
+ h = in_rest(x)
334
+ h = self.h_upd(h)
335
+ x = self.x_upd(x)
336
+ h = in_conv(h)
337
+ else:
338
+ h = self.in_layers(x)
339
+
340
+ if self.skip_t_emb:
341
+ emb_out = th.zeros_like(h)
342
+ else:
343
+ emb_out = self.emb_layers(emb).type(h.dtype)
344
+ while len(emb_out.shape) < len(h.shape):
345
+ emb_out = emb_out[..., None]
346
+ if self.use_scale_shift_norm:
347
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
348
+ scale, shift = th.chunk(emb_out, 2, dim=1)
349
+ h = out_norm(h) * (1 + scale) + shift
350
+ h = out_rest(h)
351
+ else:
352
+ if self.exchange_temb_dims:
353
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
354
+ h = h + emb_out
355
+ h = self.out_layers(h)
356
+ return self.skip_connection(x) + h
357
+
358
+
359
+ class AttentionBlock(nn.Module):
360
+ """
361
+ An attention block that allows spatial positions to attend to each other.
362
+ Originally ported from here, but adapted to the N-d case.
363
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
364
+ """
365
+
366
+ def __init__(
367
+ self,
368
+ channels,
369
+ num_heads=1,
370
+ num_head_channels=-1,
371
+ use_checkpoint=False,
372
+ use_new_attention_order=False,
373
+ ):
374
+ super().__init__()
375
+ self.channels = channels
376
+ if num_head_channels == -1:
377
+ self.num_heads = num_heads
378
+ else:
379
+ assert (
380
+ channels % num_head_channels == 0
381
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
382
+ self.num_heads = channels // num_head_channels
383
+ self.use_checkpoint = use_checkpoint
384
+ self.norm = normalization(channels)
385
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
386
+ if use_new_attention_order:
387
+ # split qkv before split heads
388
+ self.attention = QKVAttention(self.num_heads)
389
+ else:
390
+ # split heads before split qkv
391
+ self.attention = QKVAttentionLegacy(self.num_heads)
392
+
393
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
394
+
395
+ def forward(self, x, **kwargs):
396
+ # TODO add crossframe attention and use mixed checkpoint
397
+ return checkpoint(
398
+ self._forward, (x,), self.parameters(), True
399
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
400
+ # return pt_checkpoint(self._forward, x) # pytorch
401
+
402
+ def _forward(self, x):
403
+ b, c, *spatial = x.shape
404
+ x = x.reshape(b, c, -1)
405
+ qkv = self.qkv(self.norm(x))
406
+ h = self.attention(qkv)
407
+ h = self.proj_out(h)
408
+ return (x + h).reshape(b, c, *spatial)
409
+
410
+
411
+ def count_flops_attn(model, _x, y):
412
+ """
413
+ A counter for the `thop` package to count the operations in an
414
+ attention operation.
415
+ Meant to be used like:
416
+ macs, params = thop.profile(
417
+ model,
418
+ inputs=(inputs, timestamps),
419
+ custom_ops={QKVAttention: QKVAttention.count_flops},
420
+ )
421
+ """
422
+ b, c, *spatial = y[0].shape
423
+ num_spatial = int(np.prod(spatial))
424
+ # We perform two matmuls with the same number of ops.
425
+ # The first computes the weight matrix, the second computes
426
+ # the combination of the value vectors.
427
+ matmul_ops = 2 * b * (num_spatial**2) * c
428
+ model.total_ops += th.DoubleTensor([matmul_ops])
429
+
430
+
431
+ class QKVAttentionLegacy(nn.Module):
432
+ """
433
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
434
+ """
435
+
436
+ def __init__(self, n_heads):
437
+ super().__init__()
438
+ self.n_heads = n_heads
439
+
440
+ def forward(self, qkv):
441
+ """
442
+ Apply QKV attention.
443
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
444
+ :return: an [N x (H * C) x T] tensor after attention.
445
+ """
446
+ bs, width, length = qkv.shape
447
+ assert width % (3 * self.n_heads) == 0
448
+ ch = width // (3 * self.n_heads)
449
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
450
+ scale = 1 / math.sqrt(math.sqrt(ch))
451
+ weight = th.einsum(
452
+ "bct,bcs->bts", q * scale, k * scale
453
+ ) # More stable with f16 than dividing afterwards
454
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
455
+ a = th.einsum("bts,bcs->bct", weight, v)
456
+ return a.reshape(bs, -1, length)
457
+
458
+ @staticmethod
459
+ def count_flops(model, _x, y):
460
+ return count_flops_attn(model, _x, y)
461
+
462
+
463
+ class QKVAttention(nn.Module):
464
+ """
465
+ A module which performs QKV attention and splits in a different order.
466
+ """
467
+
468
+ def __init__(self, n_heads):
469
+ super().__init__()
470
+ self.n_heads = n_heads
471
+
472
+ def forward(self, qkv):
473
+ """
474
+ Apply QKV attention.
475
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
476
+ :return: an [N x (H * C) x T] tensor after attention.
477
+ """
478
+ bs, width, length = qkv.shape
479
+ assert width % (3 * self.n_heads) == 0
480
+ ch = width // (3 * self.n_heads)
481
+ q, k, v = qkv.chunk(3, dim=1)
482
+ scale = 1 / math.sqrt(math.sqrt(ch))
483
+ weight = th.einsum(
484
+ "bct,bcs->bts",
485
+ (q * scale).view(bs * self.n_heads, ch, length),
486
+ (k * scale).view(bs * self.n_heads, ch, length),
487
+ ) # More stable with f16 than dividing afterwards
488
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
489
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
490
+ return a.reshape(bs, -1, length)
491
+
492
+ @staticmethod
493
+ def count_flops(model, _x, y):
494
+ return count_flops_attn(model, _x, y)
495
+
496
+
497
+ class Timestep(nn.Module):
498
+ def __init__(self, dim):
499
+ super().__init__()
500
+ self.dim = dim
501
+
502
+ def forward(self, t):
503
+ return timestep_embedding(t, self.dim)
504
+
505
+
506
+ class UNetModel(nn.Module):
507
+ """
508
+ The full UNet model with attention and timestep embedding.
509
+ :param in_channels: channels in the input Tensor.
510
+ :param model_channels: base channel count for the model.
511
+ :param out_channels: channels in the output Tensor.
512
+ :param num_res_blocks: number of residual blocks per downsample.
513
+ :param attention_resolutions: a collection of downsample rates at which
514
+ attention will take place. May be a set, list, or tuple.
515
+ For example, if this contains 4, then at 4x downsampling, attention
516
+ will be used.
517
+ :param dropout: the dropout probability.
518
+ :param channel_mult: channel multiplier for each level of the UNet.
519
+ :param conv_resample: if True, use learned convolutions for upsampling and
520
+ downsampling.
521
+ :param dims: determines if the signal is 1D, 2D, or 3D.
522
+ :param num_classes: if specified (as an int), then this model will be
523
+ class-conditional with `num_classes` classes.
524
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
525
+ :param num_heads: the number of attention heads in each attention layer.
526
+ :param num_heads_channels: if specified, ignore num_heads and instead use
527
+ a fixed channel width per attention head.
528
+ :param num_heads_upsample: works with num_heads to set a different number
529
+ of heads for upsampling. Deprecated.
530
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
531
+ :param resblock_updown: use residual blocks for up/downsampling.
532
+ :param use_new_attention_order: use a different attention pattern for potentially
533
+ increased efficiency.
534
+ """
535
+
536
+ def __init__(
537
+ self,
538
+ in_channels,
539
+ model_channels,
540
+ out_channels,
541
+ num_res_blocks,
542
+ attention_resolutions,
543
+ dropout=0,
544
+ channel_mult=(1, 2, 4, 8),
545
+ conv_resample=True,
546
+ dims=2,
547
+ num_classes=None,
548
+ use_checkpoint=False,
549
+ use_fp16=False,
550
+ num_heads=-1,
551
+ num_head_channels=-1,
552
+ num_heads_upsample=-1,
553
+ use_scale_shift_norm=False,
554
+ resblock_updown=False,
555
+ use_new_attention_order=False,
556
+ use_spatial_transformer=False, # custom transformer support
557
+ transformer_depth=1, # custom transformer support
558
+ context_dim=None, # custom transformer support
559
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
560
+ legacy=True,
561
+ disable_self_attentions=None,
562
+ num_attention_blocks=None,
563
+ disable_middle_self_attn=False,
564
+ use_linear_in_transformer=False,
565
+ spatial_transformer_attn_type="softmax",
566
+ adm_in_channels=None,
567
+ use_fairscale_checkpoint=False,
568
+ offload_to_cpu=False,
569
+ transformer_depth_middle=None,
570
+ ):
571
+ super().__init__()
572
+ from omegaconf.listconfig import ListConfig
573
+
574
+ if use_spatial_transformer:
575
+ assert (
576
+ context_dim is not None
577
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
578
+
579
+ if context_dim is not None:
580
+ assert (
581
+ use_spatial_transformer
582
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
583
+ if type(context_dim) == ListConfig:
584
+ context_dim = list(context_dim)
585
+
586
+ if num_heads_upsample == -1:
587
+ num_heads_upsample = num_heads
588
+
589
+ if num_heads == -1:
590
+ assert (
591
+ num_head_channels != -1
592
+ ), "Either num_heads or num_head_channels has to be set"
593
+
594
+ if num_head_channels == -1:
595
+ assert (
596
+ num_heads != -1
597
+ ), "Either num_heads or num_head_channels has to be set"
598
+
599
+ self.in_channels = in_channels
600
+ self.model_channels = model_channels
601
+ self.out_channels = out_channels
602
+ if isinstance(transformer_depth, int):
603
+ transformer_depth = len(channel_mult) * [transformer_depth]
604
+ elif isinstance(transformer_depth, ListConfig):
605
+ transformer_depth = list(transformer_depth)
606
+ transformer_depth_middle = default(
607
+ transformer_depth_middle, transformer_depth[-1]
608
+ )
609
+
610
+ if isinstance(num_res_blocks, int):
611
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
612
+ else:
613
+ if len(num_res_blocks) != len(channel_mult):
614
+ raise ValueError(
615
+ "provide num_res_blocks either as an int (globally constant) or "
616
+ "as a list/tuple (per-level) with the same length as channel_mult"
617
+ )
618
+ self.num_res_blocks = num_res_blocks
619
+ # self.num_res_blocks = num_res_blocks
620
+ if disable_self_attentions is not None:
621
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
622
+ assert len(disable_self_attentions) == len(channel_mult)
623
+ if num_attention_blocks is not None:
624
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
625
+ assert all(
626
+ map(
627
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
628
+ range(len(num_attention_blocks)),
629
+ )
630
+ )
631
+ print(
632
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
633
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
634
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
635
+ f"attention will still not be set."
636
+ ) # todo: convert to warning
637
+
638
+ self.attention_resolutions = attention_resolutions
639
+ self.dropout = dropout
640
+ self.channel_mult = channel_mult
641
+ self.conv_resample = conv_resample
642
+ self.num_classes = num_classes
643
+ self.use_checkpoint = use_checkpoint
644
+ if use_fp16:
645
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
646
+ # self.dtype = th.float16 if use_fp16 else th.float32
647
+ self.num_heads = num_heads
648
+ self.num_head_channels = num_head_channels
649
+ self.num_heads_upsample = num_heads_upsample
650
+ self.predict_codebook_ids = n_embed is not None
651
+
652
+ assert use_fairscale_checkpoint != use_checkpoint or not (
653
+ use_checkpoint or use_fairscale_checkpoint
654
+ )
655
+
656
+ self.use_fairscale_checkpoint = False
657
+ checkpoint_wrapper_fn = (
658
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
659
+ if self.use_fairscale_checkpoint
660
+ else lambda x: x
661
+ )
662
+
663
+ time_embed_dim = model_channels * 4
664
+ self.time_embed = checkpoint_wrapper_fn(
665
+ nn.Sequential(
666
+ linear(model_channels, time_embed_dim),
667
+ nn.SiLU(),
668
+ linear(time_embed_dim, time_embed_dim),
669
+ )
670
+ )
671
+
672
+ if self.num_classes is not None:
673
+ if isinstance(self.num_classes, int):
674
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
675
+ elif self.num_classes == "continuous":
676
+ print("setting up linear c_adm embedding layer")
677
+ self.label_emb = nn.Linear(1, time_embed_dim)
678
+ elif self.num_classes == "timestep":
679
+ self.label_emb = checkpoint_wrapper_fn(
680
+ nn.Sequential(
681
+ Timestep(model_channels),
682
+ nn.Sequential(
683
+ linear(model_channels, time_embed_dim),
684
+ nn.SiLU(),
685
+ linear(time_embed_dim, time_embed_dim),
686
+ ),
687
+ )
688
+ )
689
+ elif self.num_classes == "sequential":
690
+ assert adm_in_channels is not None
691
+ self.label_emb = nn.Sequential(
692
+ nn.Sequential(
693
+ linear(adm_in_channels, time_embed_dim),
694
+ nn.SiLU(),
695
+ linear(time_embed_dim, time_embed_dim),
696
+ )
697
+ )
698
+ else:
699
+ raise ValueError()
700
+
701
+ self.input_blocks = nn.ModuleList(
702
+ [
703
+ TimestepEmbedSequential(
704
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
705
+ )
706
+ ]
707
+ )
708
+ self._feature_size = model_channels
709
+ input_block_chans = [model_channels]
710
+ ch = model_channels
711
+ ds = 1
712
+ for level, mult in enumerate(channel_mult):
713
+ for nr in range(self.num_res_blocks[level]):
714
+ layers = [
715
+ checkpoint_wrapper_fn(
716
+ ResBlock(
717
+ ch,
718
+ time_embed_dim,
719
+ dropout,
720
+ out_channels=mult * model_channels,
721
+ dims=dims,
722
+ use_checkpoint=use_checkpoint,
723
+ use_scale_shift_norm=use_scale_shift_norm,
724
+ )
725
+ )
726
+ ]
727
+ ch = mult * model_channels
728
+ if ds in attention_resolutions:
729
+ if num_head_channels == -1:
730
+ dim_head = ch // num_heads
731
+ else:
732
+ num_heads = ch // num_head_channels
733
+ dim_head = num_head_channels
734
+ if legacy:
735
+ # num_heads = 1
736
+ dim_head = (
737
+ ch // num_heads
738
+ if use_spatial_transformer
739
+ else num_head_channels
740
+ )
741
+ if exists(disable_self_attentions):
742
+ disabled_sa = disable_self_attentions[level]
743
+ else:
744
+ disabled_sa = False
745
+
746
+ if (
747
+ not exists(num_attention_blocks)
748
+ or nr < num_attention_blocks[level]
749
+ ):
750
+ layers.append(
751
+ checkpoint_wrapper_fn(
752
+ AttentionBlock(
753
+ ch,
754
+ use_checkpoint=use_checkpoint,
755
+ num_heads=num_heads,
756
+ num_head_channels=dim_head,
757
+ use_new_attention_order=use_new_attention_order,
758
+ )
759
+ )
760
+ if not use_spatial_transformer
761
+ else checkpoint_wrapper_fn(
762
+ SpatialTransformer(
763
+ ch,
764
+ num_heads,
765
+ dim_head,
766
+ depth=transformer_depth[level],
767
+ context_dim=context_dim,
768
+ disable_self_attn=disabled_sa,
769
+ use_linear=use_linear_in_transformer,
770
+ attn_type=spatial_transformer_attn_type,
771
+ use_checkpoint=use_checkpoint,
772
+ )
773
+ )
774
+ )
775
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
776
+ self._feature_size += ch
777
+ input_block_chans.append(ch)
778
+ if level != len(channel_mult) - 1:
779
+ out_ch = ch
780
+ self.input_blocks.append(
781
+ TimestepEmbedSequential(
782
+ checkpoint_wrapper_fn(
783
+ ResBlock(
784
+ ch,
785
+ time_embed_dim,
786
+ dropout,
787
+ out_channels=out_ch,
788
+ dims=dims,
789
+ use_checkpoint=use_checkpoint,
790
+ use_scale_shift_norm=use_scale_shift_norm,
791
+ down=True,
792
+ )
793
+ )
794
+ if resblock_updown
795
+ else Downsample(
796
+ ch, conv_resample, dims=dims, out_channels=out_ch
797
+ )
798
+ )
799
+ )
800
+ ch = out_ch
801
+ input_block_chans.append(ch)
802
+ ds *= 2
803
+ self._feature_size += ch
804
+
805
+ if num_head_channels == -1:
806
+ dim_head = ch // num_heads
807
+ else:
808
+ num_heads = ch // num_head_channels
809
+ dim_head = num_head_channels
810
+ if legacy:
811
+ # num_heads = 1
812
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
813
+ self.middle_block = TimestepEmbedSequential(
814
+ checkpoint_wrapper_fn(
815
+ ResBlock(
816
+ ch,
817
+ time_embed_dim,
818
+ dropout,
819
+ dims=dims,
820
+ use_checkpoint=use_checkpoint,
821
+ use_scale_shift_norm=use_scale_shift_norm,
822
+ )
823
+ ),
824
+ checkpoint_wrapper_fn(
825
+ AttentionBlock(
826
+ ch,
827
+ use_checkpoint=use_checkpoint,
828
+ num_heads=num_heads,
829
+ num_head_channels=dim_head,
830
+ use_new_attention_order=use_new_attention_order,
831
+ )
832
+ )
833
+ if not use_spatial_transformer
834
+ else checkpoint_wrapper_fn(
835
+ SpatialTransformer( # always uses a self-attn
836
+ ch,
837
+ num_heads,
838
+ dim_head,
839
+ depth=transformer_depth_middle,
840
+ context_dim=context_dim,
841
+ disable_self_attn=disable_middle_self_attn,
842
+ use_linear=use_linear_in_transformer,
843
+ attn_type=spatial_transformer_attn_type,
844
+ use_checkpoint=use_checkpoint,
845
+ )
846
+ ),
847
+ checkpoint_wrapper_fn(
848
+ ResBlock(
849
+ ch,
850
+ time_embed_dim,
851
+ dropout,
852
+ dims=dims,
853
+ use_checkpoint=use_checkpoint,
854
+ use_scale_shift_norm=use_scale_shift_norm,
855
+ )
856
+ ),
857
+ )
858
+ self._feature_size += ch
859
+
860
+ self.output_blocks = nn.ModuleList([])
861
+ for level, mult in list(enumerate(channel_mult))[::-1]:
862
+ for i in range(self.num_res_blocks[level] + 1):
863
+ ich = input_block_chans.pop()
864
+ layers = [
865
+ checkpoint_wrapper_fn(
866
+ ResBlock(
867
+ ch + ich,
868
+ time_embed_dim,
869
+ dropout,
870
+ out_channels=model_channels * mult,
871
+ dims=dims,
872
+ use_checkpoint=use_checkpoint,
873
+ use_scale_shift_norm=use_scale_shift_norm,
874
+ )
875
+ )
876
+ ]
877
+ ch = model_channels * mult
878
+ if ds in attention_resolutions:
879
+ if num_head_channels == -1:
880
+ dim_head = ch // num_heads
881
+ else:
882
+ num_heads = ch // num_head_channels
883
+ dim_head = num_head_channels
884
+ if legacy:
885
+ # num_heads = 1
886
+ dim_head = (
887
+ ch // num_heads
888
+ if use_spatial_transformer
889
+ else num_head_channels
890
+ )
891
+ if exists(disable_self_attentions):
892
+ disabled_sa = disable_self_attentions[level]
893
+ else:
894
+ disabled_sa = False
895
+
896
+ if (
897
+ not exists(num_attention_blocks)
898
+ or i < num_attention_blocks[level]
899
+ ):
900
+ layers.append(
901
+ checkpoint_wrapper_fn(
902
+ AttentionBlock(
903
+ ch,
904
+ use_checkpoint=use_checkpoint,
905
+ num_heads=num_heads_upsample,
906
+ num_head_channels=dim_head,
907
+ use_new_attention_order=use_new_attention_order,
908
+ )
909
+ )
910
+ if not use_spatial_transformer
911
+ else checkpoint_wrapper_fn(
912
+ SpatialTransformer(
913
+ ch,
914
+ num_heads,
915
+ dim_head,
916
+ depth=transformer_depth[level],
917
+ context_dim=context_dim,
918
+ disable_self_attn=disabled_sa,
919
+ use_linear=use_linear_in_transformer,
920
+ attn_type=spatial_transformer_attn_type,
921
+ use_checkpoint=use_checkpoint,
922
+ )
923
+ )
924
+ )
925
+ if level and i == self.num_res_blocks[level]:
926
+ out_ch = ch
927
+ layers.append(
928
+ checkpoint_wrapper_fn(
929
+ ResBlock(
930
+ ch,
931
+ time_embed_dim,
932
+ dropout,
933
+ out_channels=out_ch,
934
+ dims=dims,
935
+ use_checkpoint=use_checkpoint,
936
+ use_scale_shift_norm=use_scale_shift_norm,
937
+ up=True,
938
+ )
939
+ )
940
+ if resblock_updown
941
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
942
+ )
943
+ ds //= 2
944
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
945
+ self._feature_size += ch
946
+
947
+ self.out = checkpoint_wrapper_fn(
948
+ nn.Sequential(
949
+ normalization(ch),
950
+ nn.SiLU(),
951
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
952
+ )
953
+ )
954
+ if self.predict_codebook_ids:
955
+ self.id_predictor = checkpoint_wrapper_fn(
956
+ nn.Sequential(
957
+ normalization(ch),
958
+ conv_nd(dims, model_channels, n_embed, 1),
959
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
960
+ )
961
+ )
962
+
963
+ def convert_to_fp16(self):
964
+ """
965
+ Convert the torso of the model to float16.
966
+ """
967
+ self.input_blocks.apply(convert_module_to_f16)
968
+ self.middle_block.apply(convert_module_to_f16)
969
+ self.output_blocks.apply(convert_module_to_f16)
970
+
971
+ def convert_to_fp32(self):
972
+ """
973
+ Convert the torso of the model to float32.
974
+ """
975
+ self.input_blocks.apply(convert_module_to_f32)
976
+ self.middle_block.apply(convert_module_to_f32)
977
+ self.output_blocks.apply(convert_module_to_f32)
978
+
979
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
980
+ """
981
+ Apply the model to an input batch.
982
+ :param x: an [N x C x ...] Tensor of inputs.
983
+ :param timesteps: a 1-D batch of timesteps.
984
+ :param context: conditioning plugged in via crossattn
985
+ :param y: an [N] Tensor of labels, if class-conditional.
986
+ :return: an [N x C x ...] Tensor of outputs.
987
+ """
988
+ assert (y is not None) == (
989
+ self.num_classes is not None
990
+ ), "must specify y if and only if the model is class-conditional"
991
+ hs = []
992
+
993
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
994
+ emb = self.time_embed(t_emb)
995
+
996
+ if self.num_classes is not None:
997
+ assert y.shape[0] == x.shape[0]
998
+ emb = emb + self.label_emb(y)
999
+
1000
+ # h = x.type(self.dtype)
1001
+ h = x
1002
+ for module in self.input_blocks:
1003
+ h = module(h, emb, context)
1004
+ hs.append(h)
1005
+ h = self.middle_block(h, emb, context)
1006
+ for module in self.output_blocks:
1007
+ h = th.cat([h, hs.pop()], dim=1)
1008
+ h = module(h, emb, context)
1009
+ h = h.type(x.dtype)
1010
+ if self.predict_codebook_ids:
1011
+ assert False, "not supported anymore. what the f*** are you doing?"
1012
+ else:
1013
+ return self.out(h)
1014
+
1015
+
1016
+ class NoTimeUNetModel(UNetModel):
1017
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1018
+ timesteps = th.zeros_like(timesteps)
1019
+ return super().forward(x, timesteps, context, y, **kwargs)
1020
+
1021
+
1022
+ class EncoderUNetModel(nn.Module):
1023
+ """
1024
+ The half UNet model with attention and timestep embedding.
1025
+ For usage, see UNet.
1026
+ """
1027
+
1028
+ def __init__(
1029
+ self,
1030
+ image_size,
1031
+ in_channels,
1032
+ model_channels,
1033
+ out_channels,
1034
+ num_res_blocks,
1035
+ attention_resolutions,
1036
+ dropout=0,
1037
+ channel_mult=(1, 2, 4, 8),
1038
+ conv_resample=True,
1039
+ dims=2,
1040
+ use_checkpoint=False,
1041
+ use_fp16=False,
1042
+ num_heads=1,
1043
+ num_head_channels=-1,
1044
+ num_heads_upsample=-1,
1045
+ use_scale_shift_norm=False,
1046
+ resblock_updown=False,
1047
+ use_new_attention_order=False,
1048
+ pool="adaptive",
1049
+ *args,
1050
+ **kwargs,
1051
+ ):
1052
+ super().__init__()
1053
+
1054
+ if num_heads_upsample == -1:
1055
+ num_heads_upsample = num_heads
1056
+
1057
+ self.in_channels = in_channels
1058
+ self.model_channels = model_channels
1059
+ self.out_channels = out_channels
1060
+ self.num_res_blocks = num_res_blocks
1061
+ self.attention_resolutions = attention_resolutions
1062
+ self.dropout = dropout
1063
+ self.channel_mult = channel_mult
1064
+ self.conv_resample = conv_resample
1065
+ self.use_checkpoint = use_checkpoint
1066
+ self.dtype = th.float16 if use_fp16 else th.float32
1067
+ self.num_heads = num_heads
1068
+ self.num_head_channels = num_head_channels
1069
+ self.num_heads_upsample = num_heads_upsample
1070
+
1071
+ time_embed_dim = model_channels * 4
1072
+ self.time_embed = nn.Sequential(
1073
+ linear(model_channels, time_embed_dim),
1074
+ nn.SiLU(),
1075
+ linear(time_embed_dim, time_embed_dim),
1076
+ )
1077
+
1078
+ self.input_blocks = nn.ModuleList(
1079
+ [
1080
+ TimestepEmbedSequential(
1081
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1082
+ )
1083
+ ]
1084
+ )
1085
+ self._feature_size = model_channels
1086
+ input_block_chans = [model_channels]
1087
+ ch = model_channels
1088
+ ds = 1
1089
+ for level, mult in enumerate(channel_mult):
1090
+ for _ in range(num_res_blocks):
1091
+ layers = [
1092
+ ResBlock(
1093
+ ch,
1094
+ time_embed_dim,
1095
+ dropout,
1096
+ out_channels=mult * model_channels,
1097
+ dims=dims,
1098
+ use_checkpoint=use_checkpoint,
1099
+ use_scale_shift_norm=use_scale_shift_norm,
1100
+ )
1101
+ ]
1102
+ ch = mult * model_channels
1103
+ if ds in attention_resolutions:
1104
+ layers.append(
1105
+ AttentionBlock(
1106
+ ch,
1107
+ use_checkpoint=use_checkpoint,
1108
+ num_heads=num_heads,
1109
+ num_head_channels=num_head_channels,
1110
+ use_new_attention_order=use_new_attention_order,
1111
+ )
1112
+ )
1113
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1114
+ self._feature_size += ch
1115
+ input_block_chans.append(ch)
1116
+ if level != len(channel_mult) - 1:
1117
+ out_ch = ch
1118
+ self.input_blocks.append(
1119
+ TimestepEmbedSequential(
1120
+ ResBlock(
1121
+ ch,
1122
+ time_embed_dim,
1123
+ dropout,
1124
+ out_channels=out_ch,
1125
+ dims=dims,
1126
+ use_checkpoint=use_checkpoint,
1127
+ use_scale_shift_norm=use_scale_shift_norm,
1128
+ down=True,
1129
+ )
1130
+ if resblock_updown
1131
+ else Downsample(
1132
+ ch, conv_resample, dims=dims, out_channels=out_ch
1133
+ )
1134
+ )
1135
+ )
1136
+ ch = out_ch
1137
+ input_block_chans.append(ch)
1138
+ ds *= 2
1139
+ self._feature_size += ch
1140
+
1141
+ self.middle_block = TimestepEmbedSequential(
1142
+ ResBlock(
1143
+ ch,
1144
+ time_embed_dim,
1145
+ dropout,
1146
+ dims=dims,
1147
+ use_checkpoint=use_checkpoint,
1148
+ use_scale_shift_norm=use_scale_shift_norm,
1149
+ ),
1150
+ AttentionBlock(
1151
+ ch,
1152
+ use_checkpoint=use_checkpoint,
1153
+ num_heads=num_heads,
1154
+ num_head_channels=num_head_channels,
1155
+ use_new_attention_order=use_new_attention_order,
1156
+ ),
1157
+ ResBlock(
1158
+ ch,
1159
+ time_embed_dim,
1160
+ dropout,
1161
+ dims=dims,
1162
+ use_checkpoint=use_checkpoint,
1163
+ use_scale_shift_norm=use_scale_shift_norm,
1164
+ ),
1165
+ )
1166
+ self._feature_size += ch
1167
+ self.pool = pool
1168
+ if pool == "adaptive":
1169
+ self.out = nn.Sequential(
1170
+ normalization(ch),
1171
+ nn.SiLU(),
1172
+ nn.AdaptiveAvgPool2d((1, 1)),
1173
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1174
+ nn.Flatten(),
1175
+ )
1176
+ elif pool == "attention":
1177
+ assert num_head_channels != -1
1178
+ self.out = nn.Sequential(
1179
+ normalization(ch),
1180
+ nn.SiLU(),
1181
+ AttentionPool2d(
1182
+ (image_size // ds), ch, num_head_channels, out_channels
1183
+ ),
1184
+ )
1185
+ elif pool == "spatial":
1186
+ self.out = nn.Sequential(
1187
+ nn.Linear(self._feature_size, 2048),
1188
+ nn.ReLU(),
1189
+ nn.Linear(2048, self.out_channels),
1190
+ )
1191
+ elif pool == "spatial_v2":
1192
+ self.out = nn.Sequential(
1193
+ nn.Linear(self._feature_size, 2048),
1194
+ normalization(2048),
1195
+ nn.SiLU(),
1196
+ nn.Linear(2048, self.out_channels),
1197
+ )
1198
+ else:
1199
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1200
+
1201
+ def convert_to_fp16(self):
1202
+ """
1203
+ Convert the torso of the model to float16.
1204
+ """
1205
+ self.input_blocks.apply(convert_module_to_f16)
1206
+ self.middle_block.apply(convert_module_to_f16)
1207
+
1208
+ def convert_to_fp32(self):
1209
+ """
1210
+ Convert the torso of the model to float32.
1211
+ """
1212
+ self.input_blocks.apply(convert_module_to_f32)
1213
+ self.middle_block.apply(convert_module_to_f32)
1214
+
1215
+ def forward(self, x, timesteps):
1216
+ """
1217
+ Apply the model to an input batch.
1218
+ :param x: an [N x C x ...] Tensor of inputs.
1219
+ :param timesteps: a 1-D batch of timesteps.
1220
+ :return: an [N x K] Tensor of outputs.
1221
+ """
1222
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1223
+
1224
+ results = []
1225
+ # h = x.type(self.dtype)
1226
+ h = x
1227
+ for module in self.input_blocks:
1228
+ h = module(h, emb)
1229
+ if self.pool.startswith("spatial"):
1230
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1231
+ h = self.middle_block(h, emb)
1232
+ if self.pool.startswith("spatial"):
1233
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1234
+ h = th.cat(results, axis=-1)
1235
+ return self.out(h)
1236
+ else:
1237
+ h = h.type(x.dtype)
1238
+ return self.out(h)
1239
+
1240
+
1241
+ if __name__ == "__main__":
1242
+
1243
+ class Dummy(nn.Module):
1244
+ def __init__(self, in_channels=3, model_channels=64):
1245
+ super().__init__()
1246
+ self.input_blocks = nn.ModuleList(
1247
+ [
1248
+ TimestepEmbedSequential(
1249
+ conv_nd(2, in_channels, model_channels, 3, padding=1)
1250
+ )
1251
+ ]
1252
+ )
1253
+
1254
+ model = UNetModel(
1255
+ use_checkpoint=True,
1256
+ in_channels=4,
1258
+ out_channels=4,
1259
+ model_channels=128,
1260
+ attention_resolutions=[4, 2],
1261
+ num_res_blocks=2,
1262
+ channel_mult=[1, 2, 4],
1263
+ num_head_channels=64,
1264
+ use_spatial_transformer=False,
1265
+ use_linear_in_transformer=True,
1266
+ transformer_depth=1,
1267
+ legacy=False,
1268
+ ).cuda()
1269
+ x = th.randn(11, 4, 64, 64).cuda()
1270
+ t = th.randint(low=0, high=10, size=(11,), device="cuda")
1271
+ o = model(x, t)
1272
+ print("done.")
sgm/modules/diffusionmodules/sampling.py ADDED
@@ -0,0 +1,766 @@
1
+ """
2
+ Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
3
+ """
4
+
5
+
6
+ from typing import Dict, Union
7
+
8
+ import torch
9
+ from omegaconf import ListConfig, OmegaConf
10
+ from tqdm import tqdm
11
+
12
+ from ...modules.diffusionmodules.sampling_utils import (
13
+ get_ancestral_step,
14
+ linear_multistep_coeff,
15
+ to_d,
16
+ to_neg_log_sigma,
17
+ to_sigma,
18
+ )
19
+ from ...util import append_dims, default, instantiate_from_config
20
+ from k_diffusion.sampling import get_sigmas_karras, BrownianTreeNoiseSampler
21
+
22
+ DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
23
+
24
+
25
+ class BaseDiffusionSampler:
26
+ def __init__(
27
+ self,
28
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
29
+ num_steps: Union[int, None] = None,
30
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
31
+ verbose: bool = False,
32
+ device: str = "cuda",
33
+ ):
34
+ self.num_steps = num_steps
35
+ self.discretization = instantiate_from_config(discretization_config)
36
+ self.guider = instantiate_from_config(
37
+ default(
38
+ guider_config,
39
+ DEFAULT_GUIDER,
40
+ )
41
+ )
42
+ self.verbose = verbose
43
+ self.device = device
44
+
45
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
46
+ sigmas = self.discretization(
47
+ self.num_steps if num_steps is None else num_steps, device=self.device
48
+ )
49
+ uc = default(uc, cond)
50
+
51
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
52
+ num_sigmas = len(sigmas)
53
+
54
+ s_in = x.new_ones([x.shape[0]])
55
+
56
+ return x, s_in, sigmas, num_sigmas, cond, uc
57
+
58
+ def denoise(self, x, denoiser, sigma, cond, uc):
59
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
60
+ denoised = self.guider(denoised, sigma)
61
+ return denoised
62
+
63
+ def get_sigma_gen(self, num_sigmas):
64
+ sigma_generator = range(num_sigmas - 1)
65
+ if self.verbose:
66
+ print("#" * 30, " Sampling setting ", "#" * 30)
67
+ print(f"Sampler: {self.__class__.__name__}")
68
+ print(f"Discretization: {self.discretization.__class__.__name__}")
69
+ print(f"Guider: {self.guider.__class__.__name__}")
70
+ sigma_generator = tqdm(
71
+ sigma_generator,
72
+ total=num_sigmas,
73
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
74
+ )
75
+ return sigma_generator
76
+
77
+
78
+ class SingleStepDiffusionSampler(BaseDiffusionSampler):
79
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
80
+ raise NotImplementedError
81
+
82
+ def euler_step(self, x, d, dt):
83
+ return x + dt * d
84
+
85
+
86
+ class EDMSampler(SingleStepDiffusionSampler):
87
+ def __init__(
88
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
89
+ ):
90
+ super().__init__(*args, **kwargs)
91
+
92
+ self.s_churn = s_churn
93
+ self.s_tmin = s_tmin
94
+ self.s_tmax = s_tmax
95
+ self.s_noise = s_noise
96
+
97
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
98
+ sigma_hat = sigma * (gamma + 1.0)
99
+ if gamma > 0:
100
+ eps = torch.randn_like(x) * self.s_noise
101
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
102
+
103
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
104
+ # print('denoised', denoised.mean(axis=[0, 2, 3]))
105
+ d = to_d(x, sigma_hat, denoised)
106
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
107
+
108
+ euler_step = self.euler_step(x, d, dt)
109
+ x = self.possible_correction_step(
110
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
111
+ )
112
+ return x
113
+
114
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
115
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
116
+ x, cond, uc, num_steps
117
+ )
118
+
119
+ for i in self.get_sigma_gen(num_sigmas):
120
+ gamma = (
121
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
122
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
123
+ else 0.0
124
+ )
125
+ x = self.sampler_step(
126
+ s_in * sigmas[i],
127
+ s_in * sigmas[i + 1],
128
+ denoiser,
129
+ x,
130
+ cond,
131
+ uc,
132
+ gamma,
133
+ )
134
+
135
+ return x
136
+
137
+
138
+ class AncestralSampler(SingleStepDiffusionSampler):
139
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
140
+ super().__init__(*args, **kwargs)
141
+
142
+ self.eta = eta
143
+ self.s_noise = s_noise
144
+ self.noise_sampler = lambda x: torch.randn_like(x)
145
+
146
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
147
+ d = to_d(x, sigma, denoised)
148
+ dt = append_dims(sigma_down - sigma, x.ndim)
149
+
150
+ return self.euler_step(x, d, dt)
151
+
152
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
153
+ x = torch.where(
154
+ append_dims(next_sigma, x.ndim) > 0.0,
155
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
156
+ x,
157
+ )
158
+ return x
159
+
160
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
161
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
162
+ x, cond, uc, num_steps
163
+ )
164
+
165
+ for i in self.get_sigma_gen(num_sigmas):
166
+ x = self.sampler_step(
167
+ s_in * sigmas[i],
168
+ s_in * sigmas[i + 1],
169
+ denoiser,
170
+ x,
171
+ cond,
172
+ uc,
173
+ )
174
+
175
+ return x
176
+
177
+
178
+ class LinearMultistepSampler(BaseDiffusionSampler):
179
+ def __init__(
180
+ self,
181
+ order=4,
182
+ *args,
183
+ **kwargs,
184
+ ):
185
+ super().__init__(*args, **kwargs)
186
+
187
+ self.order = order
188
+
189
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
190
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
191
+ x, cond, uc, num_steps
192
+ )
193
+
194
+ ds = []
195
+ sigmas_cpu = sigmas.detach().cpu().numpy()
196
+ for i in self.get_sigma_gen(num_sigmas):
197
+ sigma = s_in * sigmas[i]
198
+ denoised = denoiser(
199
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
200
+ )
201
+ denoised = self.guider(denoised, sigma)
202
+ d = to_d(x, sigma, denoised)
203
+ ds.append(d)
204
+ if len(ds) > self.order:
205
+ ds.pop(0)
206
+ cur_order = min(i + 1, self.order)
207
+ coeffs = [
208
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
209
+ for j in range(cur_order)
210
+ ]
211
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
212
+
213
+ return x
214
+
215
+
216
+ class EulerEDMSampler(EDMSampler):
217
+ def possible_correction_step(
218
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
219
+ ):
220
+ # print("euler_step: ", euler_step.mean(axis=[0, 2, 3]))
221
+ return euler_step
222
+
223
+
224
+ class HeunEDMSampler(EDMSampler):
225
+ def possible_correction_step(
226
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
227
+ ):
228
+ if torch.sum(next_sigma) < 1e-14:
229
+ # Save a network evaluation if all noise levels are 0
230
+ return euler_step
231
+ else:
232
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
233
+ d_new = to_d(euler_step, next_sigma, denoised)
234
+ d_prime = (d + d_new) / 2.0
235
+
236
+ # apply correction if noise level is not 0
237
+ x = torch.where(
238
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
239
+ )
240
+ return x
241
+
242
+
243
+ class EulerAncestralSampler(AncestralSampler):
244
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
245
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
246
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
247
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
248
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
249
+
250
+ return x
251
+
252
+
253
+ class DPMPP2SAncestralSampler(AncestralSampler):
254
+ def get_variables(self, sigma, sigma_down):
255
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
256
+ h = t_next - t
257
+ s = t + 0.5 * h
258
+ return h, s, t, t_next
259
+
260
+ def get_mult(self, h, s, t, t_next):
261
+ mult1 = to_sigma(s) / to_sigma(t)
262
+ mult2 = (-0.5 * h).expm1()
263
+ mult3 = to_sigma(t_next) / to_sigma(t)
264
+ mult4 = (-h).expm1()
265
+
266
+ return mult1, mult2, mult3, mult4
267
+
268
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
269
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
270
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
271
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
272
+
273
+ if torch.sum(sigma_down) < 1e-14:
274
+ # Save a network evaluation if all noise levels are 0
275
+ x = x_euler
276
+ else:
277
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
278
+ mult = [
279
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
280
+ ]
281
+
282
+ x2 = mult[0] * x - mult[1] * denoised
283
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
284
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
285
+
286
+ # apply correction if noise level is not 0
287
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
288
+
289
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
290
+ return x
291
+
292
+
293
+ class DPMPP2MSampler(BaseDiffusionSampler):
294
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
295
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
296
+ h = t_next - t
297
+
298
+ if previous_sigma is not None:
299
+ h_last = t - to_neg_log_sigma(previous_sigma)
300
+ r = h_last / h
301
+ return h, r, t, t_next
302
+ else:
303
+ return h, None, t, t_next
304
+
305
+ def get_mult(self, h, r, t, t_next, previous_sigma):
306
+ mult1 = to_sigma(t_next) / to_sigma(t)
307
+ mult2 = (-h).expm1()
308
+
309
+ if previous_sigma is not None:
310
+ mult3 = 1 + 1 / (2 * r)
311
+ mult4 = 1 / (2 * r)
312
+ return mult1, mult2, mult3, mult4
313
+ else:
314
+ return mult1, mult2
315
+
316
+ def sampler_step(
317
+ self,
318
+ old_denoised,
319
+ previous_sigma,
320
+ sigma,
321
+ next_sigma,
322
+ denoiser,
323
+ x,
324
+ cond,
325
+ uc=None,
326
+ ):
327
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
328
+
329
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
330
+ mult = [
331
+ append_dims(mult, x.ndim)
332
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
333
+ ]
334
+
335
+ x_standard = mult[0] * x - mult[1] * denoised
336
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
337
+ # Save a network evaluation if all noise levels are 0 or on the first step
338
+ return x_standard, denoised
339
+ else:
340
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
341
+ x_advanced = mult[0] * x - mult[1] * denoised_d
342
+
343
+ # apply correction if noise level is not 0 and not first step
344
+ x = torch.where(
345
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
346
+ )
347
+
348
+ return x, denoised
349
+
350
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
351
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
352
+ x, cond, uc, num_steps
353
+ )
354
+
355
+ old_denoised = None
356
+ for i in self.get_sigma_gen(num_sigmas):
357
+ x, old_denoised = self.sampler_step(
358
+ old_denoised,
359
+ None if i == 0 else s_in * sigmas[i - 1],
360
+ s_in * sigmas[i],
361
+ s_in * sigmas[i + 1],
362
+ denoiser,
363
+ x,
364
+ cond,
365
+ uc=uc,
366
+ )
367
+
368
+ return x
369
+
370
+
371
+ class SubstepSampler(EulerAncestralSampler):
372
+ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, restore_cfg=4.0,
373
+ restore_cfg_s_tmin=0.05, eta=1., n_sample_steps=4, *args, **kwargs):
374
+ super().__init__(*args, **kwargs)
375
+ self.n_sample_steps = n_sample_steps
376
+ self.steps_subset = [0, 100, 200, 300, 1000]
377
+
378
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
379
+ sigmas = self.discretization(1000, device=self.device)
380
+ sigmas = sigmas[
381
+ self.steps_subset[: self.num_steps] + self.steps_subset[-1:]
382
+ ]
383
+ print(sigmas)
384
+ # uc = cond
385
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
386
+ num_sigmas = len(sigmas)
387
+ s_in = x.new_ones([x.shape[0]])
388
+ return x, s_in, sigmas, num_sigmas, cond, uc
389
+
390
+ def denoise(self, x, denoiser, sigma, cond, uc, control_scale=1.0):
391
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), control_scale)
392
+ denoised = self.guider(denoised, sigma)
393
+ return denoised
394
+
395
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, control_scale=1.0, *args, **kwargs):
396
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
397
+ x, cond, uc, num_steps
398
+ )
399
+
400
+ for i in self.get_sigma_gen(num_sigmas):
401
+ x = self.sampler_step(
402
+ s_in * sigmas[i],
403
+ s_in * sigmas[i + 1],
404
+ denoiser,
405
+ x,
406
+ cond,
407
+ uc,
408
+ control_scale=control_scale,
409
+ )
410
+
411
+ return x
412
+
413
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, control_scale=1.0):
414
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
415
+ denoised = self.denoise(x, denoiser, sigma, cond, uc, control_scale=control_scale)
416
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
417
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
418
+
419
+ return x
420
+
421
+
422
+ class RestoreDPMPP2MSampler(DPMPP2MSampler):
423
+ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, restore_cfg=4.0,
424
+ restore_cfg_s_tmin=0.05, eta=1., *args, **kwargs):
425
+ self.s_noise = s_noise
426
+ self.eta = eta
427
+ super().__init__(*args, **kwargs)
428
+
429
+ def denoise(self, x, denoiser, sigma, cond, uc, control_scale=1.0):
430
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), control_scale)
431
+ denoised = self.guider(denoised, sigma)
432
+ return denoised
433
+
434
+ def get_mult(self, h, r, t, t_next, previous_sigma):
435
+ eta_h = self.eta * h
436
+ mult1 = to_sigma(t_next) / to_sigma(t) * (-eta_h).exp()
437
+ mult2 = (-h - eta_h).expm1()
438
+
439
+ if previous_sigma is not None:
440
+ mult3 = 1 + 1 / (2 * r)
441
+ mult4 = 1 / (2 * r)
442
+ return mult1, mult2, mult3, mult4
443
+ else:
444
+ return mult1, mult2
445
+
446
+
447
+ def sampler_step(
448
+ self,
449
+ old_denoised,
450
+ previous_sigma,
451
+ sigma,
452
+ next_sigma,
453
+ denoiser,
454
+ x,
455
+ cond,
456
+ uc=None,
457
+ eps_noise=None,
458
+ control_scale=1.0,
459
+ ):
460
+ denoised = self.denoise(x, denoiser, sigma, cond, uc, control_scale=control_scale)
461
+
462
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
463
+ eta_h = self.eta * h
464
+ mult = [
465
+ append_dims(mult, x.ndim)
466
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
467
+ ]
468
+
469
+ x_standard = mult[0] * x - mult[1] * denoised
470
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
471
+ # Save a network evaluation if all noise levels are 0 or on the first step
472
+ return x_standard, denoised
473
+ else:
474
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
475
+ x_advanced = mult[0] * x - mult[1] * denoised_d
476
+
477
+ # apply correction if noise level is not 0 and not first step
478
+ x = torch.where(
479
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
480
+ )
481
+ if self.eta:
482
+ x = x + eps_noise * append_dims(next_sigma * (-2 * eta_h).expm1().neg().sqrt(), x.ndim) * self.s_noise
483
+
484
+ return x, denoised
485
+
486
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, control_scale=1.0, **kwargs):
487
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
488
+ x, cond, uc, num_steps
489
+ )
490
+ sigmas_min, sigmas_max = sigmas[-2].cpu(), sigmas[0].cpu()
491
+ sigmas_new = get_sigmas_karras(self.num_steps, sigmas_min, sigmas_max, device=x.device)
492
+ sigmas = sigmas_new
493
+
494
+ noise_sampler = BrownianTreeNoiseSampler(x, sigmas_min, sigmas_max)
495
+
496
+ old_denoised = None
497
+ for i in self.get_sigma_gen(num_sigmas):
498
+ if i > 0 and torch.sum(s_in * sigmas[i + 1]) > 1e-14:
499
+ eps_noise = noise_sampler(s_in * sigmas[i], s_in * sigmas[i + 1])
500
+ else:
501
+ eps_noise = None
502
+ x, old_denoised = self.sampler_step(
503
+ old_denoised,
504
+ None if i == 0 else s_in * sigmas[i - 1],
505
+ s_in * sigmas[i],
506
+ s_in * sigmas[i + 1],
507
+ denoiser,
508
+ x,
509
+ cond,
510
+ uc=uc,
511
+ eps_noise=eps_noise,
512
+ control_scale=control_scale,
513
+ )
514
+
515
+ return x
516
+
517
+
518
+ def to_d_center(denoised, x_center, x):
519
+ b = denoised.shape[0]
520
+ v_center = (denoised - x_center).view(b, -1)
521
+ v_denoise = (x - denoised).view(b, -1)
522
+ d_center = v_center - v_denoise * (v_center * v_denoise).sum(dim=1).view(b, 1) / \
523
+ (v_denoise * v_denoise).sum(dim=1).view(b, 1)
524
+ d_center = d_center / d_center.view(x.shape[0], -1).norm(dim=1).view(-1, 1)
525
+ return d_center.view(denoised.shape)
526
+
527
+
528
+ class RestoreEDMSampler(SingleStepDiffusionSampler):
529
+ def __init__(
530
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, restore_cfg=4.0,
531
+ restore_cfg_s_tmin=0.05, *args, **kwargs
532
+ ):
533
+ super().__init__(*args, **kwargs)
534
+
535
+ self.s_churn = s_churn
536
+ self.s_tmin = s_tmin
537
+ self.s_tmax = s_tmax
538
+ self.s_noise = s_noise
539
+ self.restore_cfg = restore_cfg
540
+ self.restore_cfg_s_tmin = restore_cfg_s_tmin
541
+ self.sigma_max = 14.6146
542
+
543
+ def denoise(self, x, denoiser, sigma, cond, uc, control_scale=1.0):
544
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), control_scale)
545
+ denoised = self.guider(denoised, sigma)
546
+ return denoised
547
+
548
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, x_center=None, eps_noise=None,
549
+ control_scale=1.0, use_linear_control_scale=False, control_scale_start=0.0):
550
+ sigma_hat = sigma * (gamma + 1.0)
551
+ if gamma > 0:
552
+ if eps_noise is not None:
553
+ eps = eps_noise * self.s_noise
554
+ else:
555
+ eps = torch.randn_like(x) * self.s_noise
556
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
557
+
558
+ if use_linear_control_scale:
559
+ control_scale = (sigma[0].item() / self.sigma_max) * (control_scale_start - control_scale) + control_scale
560
+
561
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc, control_scale=control_scale)
562
+
563
+ if (next_sigma[0] > self.restore_cfg_s_tmin) and (self.restore_cfg > 0):
564
+ d_center = (denoised - x_center)
565
+ denoised = denoised - d_center * ((sigma.view(-1, 1, 1, 1) / self.sigma_max) ** self.restore_cfg)
566
+
567
+ d = to_d(x, sigma_hat, denoised)
568
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
569
+ x = self.euler_step(x, d, dt)
570
+ return x
571
+
572
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, x_center=None, control_scale=1.0,
573
+ use_linear_control_scale=False, control_scale_start=0.0):
574
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
575
+ x, cond, uc, num_steps
576
+ )
577
+
578
+ for _idx, i in enumerate(self.get_sigma_gen(num_sigmas)):
579
+ gamma = (
580
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
581
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
582
+ else 0.0
583
+ )
584
+ x = self.sampler_step(
585
+ s_in * sigmas[i],
586
+ s_in * sigmas[i + 1],
587
+ denoiser,
588
+ x,
589
+ cond,
590
+ uc,
591
+ gamma,
592
+ x_center,
593
+ control_scale=control_scale,
594
+ use_linear_control_scale=use_linear_control_scale,
595
+ control_scale_start=control_scale_start,
596
+ )
597
+ return x
598
+
599
+
600
+ class TiledRestoreEDMSampler(RestoreEDMSampler):
601
+ def __init__(self, tile_size=128, tile_stride=64, *args, **kwargs):
602
+ super().__init__(*args, **kwargs)
603
+ self.tile_size = tile_size
604
+ self.tile_stride = tile_stride
605
+ self.tile_weights = gaussian_weights(self.tile_size, self.tile_size, 1)
606
+
607
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, x_center=None, control_scale=1.0,
608
+ use_linear_control_scale=False, control_scale_start=0.0):
609
+ use_local_prompt = isinstance(cond, list)
610
+ b, _, h, w = x.shape
611
+ latent_tiles_iterator = _sliding_windows(h, w, self.tile_size, self.tile_stride)
612
+ tile_weights = self.tile_weights.repeat(b, 1, 1, 1)
613
+ if not use_local_prompt:
614
+ LQ_latent = cond['control']
615
+ else:
616
+ assert len(cond) == len(latent_tiles_iterator), "Number of local prompts should be equal to number of tiles"
617
+ LQ_latent = cond[0]['control']
618
+ clean_LQ_latent = x_center
619
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
620
+ x, cond, uc, num_steps
621
+ )
622
+
623
+ for _idx, i in enumerate(self.get_sigma_gen(num_sigmas)):
624
+ gamma = (
625
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
626
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
627
+ else 0.0
628
+ )
629
+ x_next = torch.zeros_like(x)
630
+ count = torch.zeros_like(x)
631
+ eps_noise = torch.randn_like(x)
632
+ for j, (hi, hi_end, wi, wi_end) in enumerate(latent_tiles_iterator):
633
+ x_tile = x[:, :, hi:hi_end, wi:wi_end]
634
+ _eps_noise = eps_noise[:, :, hi:hi_end, wi:wi_end]
635
+ x_center_tile = clean_LQ_latent[:, :, hi:hi_end, wi:wi_end]
636
+ if use_local_prompt:
637
+ _cond = cond[j]
638
+ else:
639
+ _cond = cond
640
+ _cond['control'] = LQ_latent[:, :, hi:hi_end, wi:wi_end]
641
+ uc['control'] = LQ_latent[:, :, hi:hi_end, wi:wi_end]
642
+ _x = self.sampler_step(
643
+ s_in * sigmas[i],
644
+ s_in * sigmas[i + 1],
645
+ denoiser,
646
+ x_tile,
647
+ _cond,
648
+ uc,
649
+ gamma,
650
+ x_center_tile,
651
+ eps_noise=_eps_noise,
652
+ control_scale=control_scale,
653
+ use_linear_control_scale=use_linear_control_scale,
654
+ control_scale_start=control_scale_start,
655
+ )
656
+ x_next[:, :, hi:hi_end, wi:wi_end] += _x * tile_weights
657
+ count[:, :, hi:hi_end, wi:wi_end] += tile_weights
658
+ x_next /= count
659
+ x = x_next
660
+ return x
661
+
662
+
663
+ class TiledRestoreDPMPP2MSampler(RestoreDPMPP2MSampler):
664
+ def __init__(self, tile_size=128, tile_stride=64, *args, **kwargs):
665
+ super().__init__(*args, **kwargs)
666
+ self.tile_size = tile_size
667
+ self.tile_stride = tile_stride
668
+ self.tile_weights = gaussian_weights(self.tile_size, self.tile_size, 1)
669
+
670
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, control_scale=1.0, **kwargs):
671
+ use_local_prompt = isinstance(cond, list)
672
+ b, _, h, w = x.shape
673
+ latent_tiles_iterator = _sliding_windows(h, w, self.tile_size, self.tile_stride)
674
+ tile_weights = self.tile_weights.repeat(b, 1, 1, 1)
675
+ if not use_local_prompt:
676
+ LQ_latent = cond['control']
677
+ else:
678
+ assert len(cond) == len(latent_tiles_iterator), "Number of local prompts should be equal to number of tiles"
679
+ LQ_latent = cond[0]['control']
680
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
681
+ x, cond, uc, num_steps
682
+ )
683
+ sigmas_min, sigmas_max = sigmas[-2].cpu(), sigmas[0].cpu()
684
+ sigmas_new = get_sigmas_karras(self.num_steps, sigmas_min, sigmas_max, device=x.device)
685
+ sigmas = sigmas_new
686
+
687
+ noise_sampler = BrownianTreeNoiseSampler(x, sigmas_min, sigmas_max)
688
+
689
+ old_denoised = None
690
+ for _idx, i in enumerate(self.get_sigma_gen(num_sigmas)):
691
+ if i > 0 and torch.sum(s_in * sigmas[i + 1]) > 1e-14:
692
+ eps_noise = noise_sampler(s_in * sigmas[i], s_in * sigmas[i + 1])
693
+ else:
694
+ eps_noise = torch.zeros_like(x)
695
+ x_next = torch.zeros_like(x)
696
+ old_denoised_next = torch.zeros_like(x)
697
+ count = torch.zeros_like(x)
698
+ for j, (hi, hi_end, wi, wi_end) in enumerate(latent_tiles_iterator):
699
+ x_tile = x[:, :, hi:hi_end, wi:wi_end]
700
+ _eps_noise = eps_noise[:, :, hi:hi_end, wi:wi_end]
701
+ if old_denoised is not None:
702
+ old_denoised_tile = old_denoised[:, :, hi:hi_end, wi:wi_end]
703
+ else:
704
+ old_denoised_tile = None
705
+ if use_local_prompt:
706
+ _cond = cond[j]
707
+ else:
708
+ _cond = cond
709
+ _cond['control'] = LQ_latent[:, :, hi:hi_end, wi:wi_end]
710
+ uc['control'] = LQ_latent[:, :, hi:hi_end, wi:wi_end]
711
+ _x, _old_denoised = self.sampler_step(
712
+ old_denoised_tile,
713
+ None if i == 0 else s_in * sigmas[i - 1],
714
+ s_in * sigmas[i],
715
+ s_in * sigmas[i + 1],
716
+ denoiser,
717
+ x_tile,
718
+ _cond,
719
+ uc=uc,
720
+ eps_noise=_eps_noise,
721
+ control_scale=control_scale,
722
+ )
723
+ x_next[:, :, hi:hi_end, wi:wi_end] += _x * tile_weights
724
+ old_denoised_next[:, :, hi:hi_end, wi:wi_end] += _old_denoised * tile_weights
725
+ count[:, :, hi:hi_end, wi:wi_end] += tile_weights
726
+ old_denoised_next /= count
727
+ x_next /= count
728
+ x = x_next
729
+ old_denoised = old_denoised_next
730
+ return x
731
+
732
+
733
+ def gaussian_weights(tile_width, tile_height, nbatches):
734
+ """Generates a gaussian mask of weights for tile contributions"""
735
+ from numpy import pi, exp, sqrt
736
+ import numpy as np
737
+
738
+ latent_width = tile_width
739
+ latent_height = tile_height
740
+
741
+ var = 0.01
742
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
743
+ x_probs = [exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
744
+ for x in range(latent_width)]
745
+ midpoint = latent_height / 2
746
+ y_probs = [exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
747
+ for y in range(latent_height)]
748
+
749
+ weights = np.outer(y_probs, x_probs)
750
+ return torch.tile(torch.tensor(weights, device='cuda'), (nbatches, 4, 1, 1))
751
+
752
+
753
+ def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int):
754
+ hi_list = list(range(0, h - tile_size + 1, tile_stride))
755
+ if (h - tile_size) % tile_stride != 0:
756
+ hi_list.append(h - tile_size)
757
+
758
+ wi_list = list(range(0, w - tile_size + 1, tile_stride))
759
+ if (w - tile_size) % tile_stride != 0:
760
+ wi_list.append(w - tile_size)
761
+
762
+ coords = []
763
+ for hi in hi_list:
764
+ for wi in wi_list:
765
+ coords.append((hi, hi + tile_size, wi, wi + tile_size))
766
+ return coords
sgm/modules/diffusionmodules/sampling_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ from scipy import integrate
3
+
4
+ from ...util import append_dims
5
+
6
+
7
+ class NoDynamicThresholding:
8
+ def __call__(self, uncond, cond, scale):
9
+ return uncond + scale.view(-1, 1, 1, 1) * (cond - uncond)
10
+
11
+
12
+ def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13
+ if order - 1 > i:
14
+ raise ValueError(f"Order {order} too high for step {i}")
15
+
16
+ def fn(tau):
17
+ prod = 1.0
18
+ for k in range(order):
19
+ if j == k:
20
+ continue
21
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
22
+ return prod
23
+
24
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
25
+
26
+
27
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
28
+ if not eta:
29
+ return sigma_to, 0.0
30
+ sigma_up = torch.minimum(
31
+ sigma_to,
32
+ eta
33
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
34
+ )
35
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
36
+ return sigma_down, sigma_up
37
+
38
+
39
+ def to_d(x, sigma, denoised):
40
+ return (x - denoised) / append_dims(sigma, x.ndim)
41
+
42
+
43
+ def to_neg_log_sigma(sigma):
44
+ return sigma.log().neg()
45
+
46
+
47
+ def to_sigma(neg_log_sigma):
48
+ return neg_log_sigma.neg().exp()
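
Two of the helpers above carry the core sampler arithmetic: to_d turns a denoised prediction into the probability-flow ODE derivative (x - denoised) / sigma, and get_ancestral_step splits the move from sigma_from to sigma_to into a deterministic target sigma_down and a re-injected noise scale sigma_up, with sigma_down**2 + sigma_up**2 == sigma_to**2 by construction. A small numeric sketch, assuming only PyTorch and example values picked here for illustration:

import torch

sigma_from, sigma_to, eta = torch.tensor(2.0), torch.tensor(1.0), 1.0
sigma_up = torch.minimum(
    sigma_to,
    eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
print(float(sigma_down), float(sigma_up))          # 0.5 and ~0.866 for these values

x = torch.randn(1, 4, 8, 8) * sigma_from           # noisy sample at noise level sigma_from
denoised = torch.zeros_like(x)                     # stand-in for a denoiser output
d = (x - denoised) / sigma_from                    # to_d: ODE derivative
x_euler = x + (sigma_down - sigma_from) * d        # deterministic Euler step toward sigma_down
x_next = x_euler + torch.randn_like(x) * sigma_up  # ancestral step re-injects sigma_up worth of noise
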
sgm/modules/diffusionmodules/sigma_sampling.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+
3
+ from ...util import default, instantiate_from_config
4
+
5
+
6
+ class EDMSampling:
7
+ def __init__(self, p_mean=-1.2, p_std=1.2):
8
+ self.p_mean = p_mean
9
+ self.p_std = p_std
10
+
11
+ def __call__(self, n_samples, rand=None):
12
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
13
+ return log_sigma.exp()
14
+
15
+
16
+ class DiscreteSampling:
17
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True, idx_range=None):
18
+ self.num_idx = num_idx
19
+ self.sigmas = instantiate_from_config(discretization_config)(
20
+ num_idx, do_append_zero=do_append_zero, flip=flip
21
+ )
22
+ self.idx_range = idx_range
23
+
24
+ def idx_to_sigma(self, idx):
25
+ # print(self.sigmas[idx])
26
+ return self.sigmas[idx]
27
+
28
+ def __call__(self, n_samples, rand=None):
29
+ if self.idx_range is None:
30
+ idx = default(
31
+ rand,
32
+ torch.randint(0, self.num_idx, (n_samples,)),
33
+ )
34
+ else:
35
+ idx = default(
36
+ rand,
37
+ torch.randint(self.idx_range[0], self.idx_range[1], (n_samples,)),
38
+ )
39
+ return self.idx_to_sigma(idx)
40
+
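
EDMSampling above draws training noise levels log-normally: sigma = exp(p_mean + p_std * z) with z ~ N(0, 1), so log sigma has mean p_mean and standard deviation p_std. A tiny standalone check of that, assuming only PyTorch:

import torch

p_mean, p_std = -1.2, 1.2
sigma = (p_mean + p_std * torch.randn(10_000)).exp()
print(sigma.log().mean().item(), sigma.log().std().item())  # approximately -1.2 and 1.2
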
sgm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,309 @@
1
+ """
2
+ adapted from
3
+ https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
4
+ and
5
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
6
+ and
7
+ https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
8
+
9
+ thanks!
10
+ """
11
+
12
+ import math
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import repeat
17
+
18
+
19
+ def make_beta_schedule(
20
+ schedule,
21
+ n_timestep,
22
+ linear_start=1e-4,
23
+ linear_end=2e-2,
24
+ ):
25
+ if schedule == "linear":
26
+ betas = (
27
+ torch.linspace(
28
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
29
+ )
30
+ ** 2
31
+ )
32
+ return betas.numpy()
33
+
34
+
35
+ def extract_into_tensor(a, t, x_shape):
36
+ b, *_ = t.shape
37
+ out = a.gather(-1, t)
38
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
39
+
40
+
41
+ def mixed_checkpoint(func, inputs: dict, params, flag):
42
+ """
43
+ Evaluate a function without caching intermediate activations, allowing for
44
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
45
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
46
+ it also works with non-tensor inputs
47
+ :param func: the function to evaluate.
48
+ :param inputs: the argument dictionary to pass to `func`.
49
+ :param params: a sequence of parameters `func` depends on but does not
50
+ explicitly take as arguments.
51
+ :param flag: if False, disable gradient checkpointing.
52
+ """
53
+ if flag:
54
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
55
+ tensor_inputs = [
56
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
57
+ ]
58
+ non_tensor_keys = [
59
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
60
+ ]
61
+ non_tensor_inputs = [
62
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
63
+ ]
64
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
65
+ return MixedCheckpointFunction.apply(
66
+ func,
67
+ len(tensor_inputs),
68
+ len(non_tensor_inputs),
69
+ tensor_keys,
70
+ non_tensor_keys,
71
+ *args,
72
+ )
73
+ else:
74
+ return func(**inputs)
75
+
76
+
77
+ class MixedCheckpointFunction(torch.autograd.Function):
78
+ @staticmethod
79
+ def forward(
80
+ ctx,
81
+ run_function,
82
+ length_tensors,
83
+ length_non_tensors,
84
+ tensor_keys,
85
+ non_tensor_keys,
86
+ *args,
87
+ ):
88
+ ctx.end_tensors = length_tensors
89
+ ctx.end_non_tensors = length_tensors + length_non_tensors
90
+ ctx.gpu_autocast_kwargs = {
91
+ "enabled": torch.is_autocast_enabled(),
92
+ "dtype": torch.get_autocast_gpu_dtype(),
93
+ "cache_enabled": torch.is_autocast_cache_enabled(),
94
+ }
95
+ assert (
96
+ len(tensor_keys) == length_tensors
97
+ and len(non_tensor_keys) == length_non_tensors
98
+ )
99
+
100
+ ctx.input_tensors = {
101
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
102
+ }
103
+ ctx.input_non_tensors = {
104
+ key: val
105
+ for (key, val) in zip(
106
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
107
+ )
108
+ }
109
+ ctx.run_function = run_function
110
+ ctx.input_params = list(args[ctx.end_non_tensors :])
111
+
112
+ with torch.no_grad():
113
+ output_tensors = ctx.run_function(
114
+ **ctx.input_tensors, **ctx.input_non_tensors
115
+ )
116
+ return output_tensors
117
+
118
+ @staticmethod
119
+ def backward(ctx, *output_grads):
120
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
121
+ ctx.input_tensors = {
122
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
123
+ for key in ctx.input_tensors
124
+ }
125
+
126
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
127
+ # Fixes a bug where the first op in run_function modifies the
128
+ # Tensor storage in place, which is not allowed for detach()'d
129
+ # Tensors.
130
+ shallow_copies = {
131
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
132
+ for key in ctx.input_tensors
133
+ }
134
+ # shallow_copies.update(additional_args)
135
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
136
+ input_grads = torch.autograd.grad(
137
+ output_tensors,
138
+ list(ctx.input_tensors.values()) + ctx.input_params,
139
+ output_grads,
140
+ allow_unused=True,
141
+ )
142
+ del ctx.input_tensors
143
+ del ctx.input_params
144
+ del output_tensors
145
+ return (
146
+ (None, None, None, None, None)
147
+ + input_grads[: ctx.end_tensors]
148
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
149
+ + input_grads[ctx.end_tensors :]
150
+ )
151
+
152
+
153
+ def checkpoint(func, inputs, params, flag):
154
+ """
155
+ Evaluate a function without caching intermediate activations, allowing for
156
+ reduced memory at the expense of extra compute in the backward pass.
157
+ :param func: the function to evaluate.
158
+ :param inputs: the argument sequence to pass to `func`.
159
+ :param params: a sequence of parameters `func` depends on but does not
160
+ explicitly take as arguments.
161
+ :param flag: if False, disable gradient checkpointing.
162
+ """
163
+ if flag:
164
+ args = tuple(inputs) + tuple(params)
165
+ return CheckpointFunction.apply(func, len(inputs), *args)
166
+ else:
167
+ return func(*inputs)
168
+
169
+
170
+ class CheckpointFunction(torch.autograd.Function):
171
+ @staticmethod
172
+ def forward(ctx, run_function, length, *args):
173
+ ctx.run_function = run_function
174
+ ctx.input_tensors = list(args[:length])
175
+ ctx.input_params = list(args[length:])
176
+ ctx.gpu_autocast_kwargs = {
177
+ "enabled": torch.is_autocast_enabled(),
178
+ "dtype": torch.get_autocast_gpu_dtype(),
179
+ "cache_enabled": torch.is_autocast_cache_enabled(),
180
+ }
181
+ with torch.no_grad():
182
+ output_tensors = ctx.run_function(*ctx.input_tensors)
183
+ return output_tensors
184
+
185
+ @staticmethod
186
+ def backward(ctx, *output_grads):
187
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
188
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
189
+ # Fixes a bug where the first op in run_function modifies the
190
+ # Tensor storage in place, which is not allowed for detach()'d
191
+ # Tensors.
192
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
193
+ output_tensors = ctx.run_function(*shallow_copies)
194
+ input_grads = torch.autograd.grad(
195
+ output_tensors,
196
+ ctx.input_tensors + ctx.input_params,
197
+ output_grads,
198
+ allow_unused=True,
199
+ )
200
+ del ctx.input_tensors
201
+ del ctx.input_params
202
+ del output_tensors
203
+ return (None, None) + input_grads
204
+
205
+
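Editorial note: the plain `checkpoint` variant below the mixed one takes a positional sequence of tensors rather than a dict. A small sketch under the same assumption that `checkpoint` from this file is in scope; `proj` and `block` are illustrative only.

    import torch
    import torch.nn as nn

    proj = nn.Linear(8, 8)  # illustrative module

    def block(h, emb):
        # Both arguments are tensors; their intermediates are recomputed in backward.
        return torch.nn.functional.gelu(proj(h) + emb)

    h = torch.randn(2, 8, requires_grad=True)
    emb = torch.randn(2, 8)

    out = checkpoint(block, (h, emb), list(proj.parameters()), flag=True)
    out.mean().backward()
    assert h.grad is not None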
206
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
207
+ """
208
+ Create sinusoidal timestep embeddings.
209
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
210
+ These may be fractional.
211
+ :param dim: the dimension of the output.
212
+ :param max_period: controls the minimum frequency of the embeddings.
+ :param repeat_only: if True, skip the sinusoidal projection and simply
+ repeat each timestep value `dim` times.
213
+ :return: an [N x dim] Tensor of positional embeddings.
214
+ """
215
+ if not repeat_only:
216
+ half = dim // 2
217
+ freqs = torch.exp(
218
+ -math.log(max_period)
219
+ * torch.arange(start=0, end=half, dtype=torch.float32)
220
+ / half
221
+ ).to(device=timesteps.device)
222
+ args = timesteps[:, None].float() * freqs[None]
223
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
224
+ if dim % 2:
225
+ embedding = torch.cat(
226
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
227
+ )
228
+ else:
229
+ embedding = repeat(timesteps, "b -> b d", d=dim)
230
+ return embedding
231
+
232
+
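Editorial note: `timestep_embedding` produces the usual [cos | sin] features at `dim // 2` geometrically spaced frequencies (periods from 1 up to `max_period`). A quick sanity check, assuming the function and the `math` import at the top of this file are in scope:

    import torch

    t = torch.tensor([0.0, 1.0, 10.0, 999.0])
    emb = timestep_embedding(t, dim=128)
    assert emb.shape == (4, 128)

    # The first half holds cosine terms, the second half sine terms;
    # at t = 0 the cosines are 1 and the sines are 0.
    assert torch.allclose(emb[0, :64], torch.ones(64))
    assert torch.allclose(emb[0, 64:], torch.zeros(64))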
233
+ def zero_module(module):
234
+ """
235
+ Zero out the parameters of a module and return it.
236
+ """
237
+ for p in module.parameters():
238
+ p.detach().zero_()
239
+ return module
240
+
241
+
242
+ def scale_module(module, scale):
243
+ """
244
+ Scale the parameters of a module and return it.
245
+ """
246
+ for p in module.parameters():
247
+ p.detach().mul_(scale)
248
+ return module
249
+
250
+
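Editorial note: `zero_module` is the usual trick for zero-initialising the final layer of a residual branch so the branch starts as a no-op, and `scale_module` rescales parameters in place. A short sketch using plain `nn.Conv2d` / `nn.Linear` for illustration:

    import torch
    import torch.nn as nn

    # A conv whose weights and bias are zeroed outputs exactly zero.
    out_conv = zero_module(nn.Conv2d(64, 64, 3, padding=1))
    x = torch.randn(1, 64, 8, 8)
    assert torch.all(out_conv(x) == 0)

    # scale_module multiplies existing parameters in place.
    scaled = scale_module(nn.Linear(4, 4), 0.0)
    assert torch.all(scaled.weight == 0)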
251
+ def mean_flat(tensor):
252
+ """
253
+ Take the mean over all non-batch dimensions.
254
+ """
255
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
256
+
257
+
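Editorial note: `mean_flat` averages over every non-batch dimension, which is how per-sample losses are typically formed, e.g.:

    import torch

    pred = torch.randn(4, 3, 8, 8)
    target = torch.randn(4, 3, 8, 8)
    # Squared error reduced over channels and spatial dims -> one value per sample.
    loss_per_sample = mean_flat((pred - target) ** 2)
    assert loss_per_sample.shape == (4,)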
258
+ def normalization(channels):
259
+ """
260
+ Make a standard normalization layer.
261
+ :param channels: number of input channels.
262
+ :return: an nn.Module for normalization.
263
+ """
264
+ return GroupNorm32(32, channels)
265
+
266
+
267
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
268
+ class SiLU(nn.Module):
269
+ def forward(self, x):
270
+ return x * torch.sigmoid(x)
271
+
272
+
273
+ class GroupNorm32(nn.GroupNorm):
274
+ def forward(self, x):
275
+ # return super().forward(x.float()).type(x.dtype)
276
+ return super().forward(x)
277
+
278
+
279
+ def conv_nd(dims, *args, **kwargs):
280
+ """
281
+ Create a 1D, 2D, or 3D convolution module.
282
+ """
283
+ if dims == 1:
284
+ return nn.Conv1d(*args, **kwargs)
285
+ elif dims == 2:
286
+ return nn.Conv2d(*args, **kwargs)
287
+ elif dims == 3:
288
+ return nn.Conv3d(*args, **kwargs)
289
+ raise ValueError(f"unsupported dimensions: {dims}")
290
+
291
+
292
+ def linear(*args, **kwargs):
293
+ """
294
+ Create a linear module.
295
+ """
296
+ return nn.Linear(*args, **kwargs)
297
+
298
+
299
+ def avg_pool_nd(dims, *args, **kwargs):
300
+ """
301
+ Create a 1D, 2D, or 3D average pooling module.
302
+ """
303
+ if dims == 1:
304
+ return nn.AvgPool1d(*args, **kwargs)
305
+ elif dims == 2:
306
+ return nn.AvgPool2d(*args, **kwargs)
307
+ elif dims == 3:
308
+ return nn.AvgPool3d(*args, **kwargs)
309
+ raise ValueError(f"unsupported dimensions: {dims}")
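Editorial note on the helpers above: `conv_nd`, `normalization` (GroupNorm32 with 32 groups) and `SiLU` are the small factories the UNet composes its blocks from. A hypothetical dimension-agnostic block built from them, assuming those names are in scope from this file:

    import torch
    import torch.nn as nn

    def make_block(dims, in_ch, out_ch):
        # conv -> group norm -> SiLU; `dims` selects Conv1d/2d/3d.
        return nn.Sequential(
            conv_nd(dims, in_ch, out_ch, 3, padding=1),
            normalization(out_ch),
            SiLU(),
        )

    block = make_block(2, 32, 64)
    y = block(torch.randn(1, 32, 16, 16))
    assert y.shape == (1, 64, 16, 16)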
sgm/modules/diffusionmodules/wrappers.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from packaging import version
4
+ # import torch._dynamo
5
+ # torch._dynamo.config.suppress_errors = True
6
+ # torch._dynamo.config.cache_size_limit = 512
7
+
8
+ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
9
+
10
+
11
+ class IdentityWrapper(nn.Module):
12
+ def __init__(self, diffusion_model, compile_model: bool = False):
13
+ super().__init__()
14
+ compile = (
15
+ torch.compile
16
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
17
+ and compile_model
18
+ else lambda x: x
19
+ )
20
+ self.diffusion_model = compile(diffusion_model)
21
+
22
+ def forward(self, *args, **kwargs):
23
+ return self.diffusion_model(*args, **kwargs)
24
+
25
+
26
+ class OpenAIWrapper(IdentityWrapper):
27
+ def forward(
28
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
29
+ ) -> torch.Tensor:
30
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
31
+ return self.diffusion_model(
32
+ x,
33
+ timesteps=t,
34
+ context=c.get("crossattn", None),
35
+ y=c.get("vector", None),
36
+ **kwargs,
37
+ )
38
+
39
+
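Editorial note (not part of the commit): `OpenAIWrapper` channel-concatenates `c["concat"]` onto the latent and routes `crossattn` / `vector` to the UNet's `context` / `y` arguments. A sketch with a stub standing in for the real UNetModel; all shapes below are illustrative assumptions.

    import torch
    import torch.nn as nn

    class StubUNet(nn.Module):
        # Stand-in for the real UNetModel; it simply echoes its input back.
        def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
            return x

    wrapper = OpenAIWrapper(StubUNet())
    z = torch.randn(2, 4, 32, 32)                 # latent
    t = torch.randint(0, 1000, (2,))              # timesteps
    cond = {
        "concat": torch.randn(2, 1, 32, 32),      # concatenated along channels
        "crossattn": torch.randn(2, 77, 1024),    # cross-attention context (illustrative size)
        "vector": torch.randn(2, 2816),           # pooled conditioning (illustrative size)
    }
    out = wrapper(z, t, cond)
    assert out.shape == (2, 5, 32, 32)            # 4 latent + 1 concat channels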
40
+ class OpenAIHalfWrapper(IdentityWrapper):
41
+ def __init__(self, *args, **kwargs):
42
+ super().__init__(*args, **kwargs)
43
+ self.diffusion_model = self.diffusion_model.half()
44
+
45
+ def forward(
46
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
47
+ ) -> torch.Tensor:
48
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
49
+ _context = c.get("crossattn", None)
50
+ _y = c.get("vector", None)
51
+ if _context is not None:
52
+ _context = _context.half()
53
+ if _y is not None:
54
+ _y = _y.half()
55
+ x = x.half()
56
+ t = t.half()
57
+
58
+ out = self.diffusion_model(
59
+ x,
60
+ timesteps=t,
61
+ context=_context,
62
+ y=_y,
63
+ **kwargs,
64
+ )
65
+ return out.float()
66
+
67
+
68
+ class ControlWrapper(nn.Module):
69
+ def __init__(self, diffusion_model, compile_model: bool = False, dtype=torch.float32):
70
+ super().__init__()
71
+ self.compile = (
72
+ torch.compile
73
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
74
+ and compile_model
75
+ else lambda x: x
76
+ )
77
+ self.diffusion_model = self.compile(diffusion_model)
78
+ self.control_model = None
79
+ self.dtype = dtype
80
+
81
+ def load_control_model(self, control_model):
82
+ self.control_model = self.compile(control_model)
83
+
84
+ def forward(
85
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, control_scale=1, **kwargs
86
+ ) -> torch.Tensor:
87
+ with torch.autocast("cuda", dtype=self.dtype):
88
+ control = self.control_model(x=c.get("control", None), timesteps=t, xt=x,
89
+ control_vector=c.get("control_vector", None),
90
+ mask_x=c.get("mask_x", None),
91
+ context=c.get("crossattn", None),
92
+ y=c.get("vector", None))
93
+ out = self.diffusion_model(
94
+ x,
95
+ timesteps=t,
96
+ context=c.get("crossattn", None),
97
+ y=c.get("vector", None),
98
+ control=control,
99
+ control_scale=control_scale,
100
+ **kwargs,
101
+ )
102
+ return out.float()
103
+
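Editorial note: `ControlWrapper` first runs the control network on `c["control"]` (together with the noisy latent `xt` and the other conditioning keys) and then feeds its output into the main UNet via `control` / `control_scale`. A sketch with stub modules whose interfaces merely mirror the keyword arguments used above; the real control network and UNet signatures live elsewhere in the repo.

    import torch
    import torch.nn as nn

    class StubControlNet(nn.Module):
        # Stand-in accepting the keywords ControlWrapper.forward passes in.
        def forward(self, x=None, timesteps=None, xt=None, control_vector=None,
                    mask_x=None, context=None, y=None):
            return [torch.zeros_like(xt)]         # e.g. one residual per injection point

    class StubUNet(nn.Module):
        def forward(self, x, timesteps=None, context=None, y=None,
                    control=None, control_scale=1, **kwargs):
            return x + control_scale * control[0]

    wrapper = ControlWrapper(StubUNet())
    wrapper.load_control_model(StubControlNet())

    z = torch.randn(2, 4, 32, 32)
    t = torch.randint(0, 1000, (2,))
    cond = {"control": torch.randn(2, 4, 32, 32),
            "crossattn": torch.randn(2, 77, 1024),
            "vector": torch.randn(2, 2816)}
    out = wrapper(z, t, cond, control_scale=0.5)  # returned as float32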