Disty0 committed on
Commit
558dd79
1 Parent(s): ba740ca

Upload stable_cascade.py

Files changed (1)
  1. dataset/stable_cascade.py +1789 -0
dataset/stable_cascade.py ADDED
@@ -0,0 +1,1789 @@
1
+ # コードは Stable Cascade からコピーし、一部修正しています。元ライセンスは MIT です。
2
+ # The code is copied from Stable Cascade and modified. The original license is MIT.
3
+ # https://github.com/Stability-AI/StableCascade
4
+
5
+ import math
6
+ from types import SimpleNamespace
7
+ from typing import List, Optional
8
+ from einops import rearrange
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+ import torchvision
14
+
15
+ # Put this .py file into sd-scripts/library and run the training.
16
+ # It will run 1 step and then apply the FP16 fix to the model.
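+ # In short, the "FP16 fix" rescales weights so activations stay within half-precision range:
+ # blocks whose output scale (mean absolute value, see check_scale below) grows past ~5 get
+ # their output weights divided by a correction factor via set_factor(), and the rescaled
+ # state dict is written under fp16_fix_save_path (see StageC.forward below).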
17
+
18
+ fp16_fix_save_path = "/mnt/DataSSD/AI/SoteDiffusion/Wuerstchen3"
19
+
20
+ MODEL_VERSION_STABLE_CASCADE = "stable_cascade"
21
+
22
+ EFFNET_PREPROCESS = torchvision.transforms.Compose(
23
+ [torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]
24
+ )
25
+
26
+ def check_scale(tensor):
27
+ return torch.mean(torch.abs(tensor))
28
+
29
+ def convert_state_dict_normal_attn_to_mha(state_dict):
30
+ # convert to_q/k/v and out_proj to nn.MultiheadAttention
31
+ for key in list(state_dict.keys()):
32
+ if "attention.attn." in key:
33
+ if "to_q.bias" in key:
34
+ q = state_dict.pop(key)
35
+ k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
36
+ v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
37
+ state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
38
+ elif "to_q.weight" in key:
39
+ q = state_dict.pop(key)
40
+ k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
41
+ v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
42
+ state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
43
+ elif "out_proj.bias" in key:
44
+ v = state_dict.pop(key)
45
+ state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
46
+ elif "out_proj.weight" in key:
47
+ v = state_dict.pop(key)
48
+ state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
49
+ return state_dict
50
+
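+ # Key-mapping sketch: for every layer whose keys contain "attention.attn.", the separate
+ # projection weights
+ #     <prefix>.attention.attn.to_q.weight / to_k.weight / to_v.weight
+ # are concatenated (in q, k, v order) into
+ #     <prefix>.attention.attn.in_proj_weight
+ # (and likewise for the biases), which is the layout nn.MultiheadAttention expects;
+ # out_proj.weight / out_proj.bias keep their names.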
51
+
52
+ # region VectorQuantize
53
+
54
+ # from torchtools https://github.com/pabloppp/pytorch-tools
55
+ # Copied here to avoid adding another dependency.
56
+
57
+
58
+ class vector_quantize(torch.autograd.Function):
59
+ @staticmethod
60
+ def forward(ctx, x, codebook):
61
+ with torch.no_grad():
62
+ codebook_sqr = torch.sum(codebook**2, dim=1)
63
+ x_sqr = torch.sum(x**2, dim=1, keepdim=True)
64
+
65
+ dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
66
+ _, indices = dist.min(dim=1)
67
+
68
+ ctx.save_for_backward(indices, codebook)
69
+ ctx.mark_non_differentiable(indices)
70
+
71
+ nn = torch.index_select(codebook, 0, indices)
72
+ return nn, indices
73
+
74
+ @staticmethod
75
+ def backward(ctx, grad_output, grad_indices):
76
+ grad_inputs, grad_codebook = None, None
77
+
78
+ if ctx.needs_input_grad[0]:
79
+ grad_inputs = grad_output.clone()
80
+ if ctx.needs_input_grad[1]:
81
+ # Gradient wrt. the codebook
82
+ indices, codebook = ctx.saved_tensors
83
+
84
+ grad_codebook = torch.zeros_like(codebook)
85
+ grad_codebook.index_add_(0, indices, grad_output)
86
+
87
+ return (grad_inputs, grad_codebook)
88
+
89
+
90
+ class VectorQuantize(nn.Module):
91
+ def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
92
+ """
93
+ Takes an input of variable size (as long as the last dimension matches the embedding size).
94
+ Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
95
+ with the same size as the input, vq and commitment components for the loss as a tuple
96
+ in the second output and the indices of the quantized vectors in the third:
97
+ quantized, (vq_loss, commit_loss), indices
98
+ """
99
+ super(VectorQuantize, self).__init__()
100
+
101
+ self.codebook = nn.Embedding(k, embedding_size)
102
+ self.codebook.weight.data.uniform_(-1.0 / k, 1.0 / k)
103
+ self.vq = vector_quantize.apply
104
+
105
+ self.ema_decay = ema_decay
106
+ self.ema_loss = ema_loss
107
+ if ema_loss:
108
+ self.register_buffer("ema_element_count", torch.ones(k))
109
+ self.register_buffer("ema_weight_sum", torch.zeros_like(self.codebook.weight))
110
+
111
+ def _laplace_smoothing(self, x, epsilon):
112
+ n = torch.sum(x)
113
+ return (x + epsilon) / (n + x.size(0) * epsilon) * n
114
+
115
+ def _updateEMA(self, z_e_x, indices):
116
+ mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
117
+ elem_count = mask.sum(dim=0)
118
+ weight_sum = torch.mm(mask.t(), z_e_x)
119
+
120
+ self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count)
121
+ self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
122
+ self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum)
123
+
124
+ self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
125
+
126
+ def idx2vq(self, idx, dim=-1):
127
+ q_idx = self.codebook(idx)
128
+ if dim != -1:
129
+ q_idx = q_idx.movedim(-1, dim)
130
+ return q_idx
131
+
132
+ def forward(self, x, get_losses=True, dim=-1):
133
+ if dim != -1:
134
+ x = x.movedim(dim, -1)
135
+ z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
136
+ z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
137
+ vq_loss, commit_loss = None, None
138
+ if self.ema_loss and self.training:
139
+ self._updateEMA(z_e_x.detach(), indices.detach())
140
+ # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
141
+ z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
142
+ if get_losses:
143
+ vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
144
+ commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
145
+
146
+ z_q_x = z_q_x.view(x.shape)
147
+ if dim != -1:
148
+ z_q_x = z_q_x.movedim(-1, dim)
149
+ return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
150
+
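+ # Usage sketch (mirrors how StageA calls it below; shapes are illustrative):
+ #   vq = VectorQuantize(embedding_size=4, k=8192)
+ #   z_q, (vq_loss, commit_loss), indices = vq(x, dim=1)  # x: BxCxHxW with C == embedding_size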
151
+
152
+ # endregion
153
+
154
+
155
+ class EfficientNetEncoder(nn.Module):
156
+ def __init__(self, c_latent=16):
157
+ super().__init__()
158
+ self.backbone = torchvision.models.efficientnet_v2_s(weights="DEFAULT").features.eval()
159
+ self.mapper = nn.Sequential(
160
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
161
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
162
+ )
163
+
164
+ def forward(self, x):
165
+ return self.mapper(self.backbone(x))
166
+
167
+ @property
168
+ def dtype(self) -> torch.dtype:
169
+ return next(self.parameters()).dtype
170
+
171
+ @property
172
+ def device(self) -> torch.device:
173
+ return next(self.parameters()).device
174
+
175
+ def encode(self, x):
176
+ """
177
+ VAE と同じように使えるようにするためのメソッド。正しくはちゃんと呼び出し側で分けるべきだが、暫定的な対応。
178
+ A method that lets this encoder be used like a VAE. Properly the caller should handle the difference, but this is a temporary workaround.
179
+ """
180
+ # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
181
+
182
+ # x is -1 to 1, so we need to convert it to 0 to 1, and then preprocess it with EfficientNet's preprocessing.
183
+ x = (x + 1) / 2
184
+ x = EFFNET_PREPROCESS(x)
185
+
186
+ x = self(x)
187
+ return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: x))
188
+
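+ # Usage sketch (same call shape as a diffusers VAE; img_tensors in [-1, 1] is assumed):
+ #   effnet = EfficientNetEncoder()
+ #   latents = effnet.encode(img_tensors).latent_dist.sample()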
189
+
190
+ # A fairly rough implementation (;'∀')
191
+ # We probably won't train from scratch, so this is left disabled.
192
+
193
+ # class Linear(torch.nn.Linear):
194
+ # def reset_parameters(self):
195
+ # return None
196
+
197
+ # class Conv2d(torch.nn.Conv2d):
198
+ # def reset_parameters(self):
199
+ # return None
200
+
201
+ from torch.nn import Conv2d
202
+ from torch.nn import Linear
203
+
204
+
205
+ r"""
206
+ class Attention2D(nn.Module):
207
+ def __init__(self, c, nhead, dropout=0.0):
208
+ super().__init__()
209
+ self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
210
+
211
+ def forward(self, x, kv, self_attn=False):
212
+ orig_shape = x.shape
213
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
214
+ if self_attn:
215
+ kv = torch.cat([x, kv], dim=1)
216
+ x = self.attn(x, kv, kv, need_weights=False)[0]
217
+ x = x.permute(0, 2, 1).view(*orig_shape)
218
+ return x
219
+ """
220
+
221
+
222
+ class Attention(nn.Module):
223
+ def __init__(self, c, nhead, dropout=0.0):
224
+ # dropout is for attn_output_weights, so we may not need it. however, if we use sdpa, we enable it.
225
+ # xformers and normal attn are not affected by dropout
226
+ super().__init__()
227
+
228
+ self.to_q = Linear(c, c, bias=True)
229
+ self.to_k = Linear(c, c, bias=True)
230
+ self.to_v = Linear(c, c, bias=True)
231
+ self.out_proj = Linear(c, c, bias=True)
232
+ self.nhead = nhead
233
+ self.dropout = dropout
234
+ self.scale = (c // nhead) ** -0.5
235
+
236
+ # default is to use sdpa
237
+ self.use_memory_efficient_attention_xformers = False
238
+ self.use_sdpa = True
239
+
240
+ def set_use_xformers_or_sdpa(self, xformers, sdpa):
241
+ # print(f"Attention: set_use_xformers_or_sdpa: xformers={xformers}, sdpa={sdpa}")
242
+ self.use_memory_efficient_attention_xformers = xformers
243
+ self.use_sdpa = sdpa
244
+
245
+ def forward(self, q_in, k_in, v_in):
246
+ q_in = self.to_q(q_in)
247
+ k_in = self.to_k(k_in)
248
+ v_in = self.to_v(v_in)
249
+
250
+ if self.use_memory_efficient_attention_xformers:
251
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.nhead), (q_in, k_in, v_in))
252
+ del q_in, k_in, v_in
253
+ out = self.forward_memory_efficient_xformers(q, k, v)
254
+ del q, k, v
255
+ out = rearrange(out, "b n h d -> b n (h d)", h=self.nhead)
256
+ elif self.use_sdpa:
257
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.nhead), (q_in, k_in, v_in))
258
+ del q_in, k_in, v_in
259
+ out = self.forward_sdpa(q, k, v)
260
+ del q, k, v
261
+ out = rearrange(out, "b h n d -> b n (h d)", h=self.nhead)
262
+ else:
263
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.nhead), (q_in, k_in, v_in))
264
+ del q_in, k_in, v_in
265
+ out = self._attention(q, k, v)
266
+ del q, k, v
267
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=self.nhead)
268
+
269
+ return self.out_proj(out)
270
+
271
+ def _attention(self, query, key, value):
272
+ # if self.upcast_attention:
273
+ # query = query.float()
274
+ # key = key.float()
275
+
276
+ attention_scores = torch.baddbmm(
277
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
278
+ query,
279
+ key.transpose(-1, -2),
280
+ beta=0,
281
+ alpha=self.scale,
282
+ )
283
+ attention_probs = attention_scores.softmax(dim=-1)
284
+
285
+ # cast back to the original dtype
286
+ attention_probs = attention_probs.to(value.dtype)
287
+
288
+ # compute attention output
289
+ hidden_states = torch.bmm(attention_probs, value)
290
+
291
+ return hidden_states
292
+
293
+ def forward_memory_efficient_xformers(self, q, k, v):
294
+ import xformers.ops
295
+
296
+ q = q.contiguous()
297
+ k = k.contiguous()
298
+ v = v.contiguous()
299
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # picks the best available kernel automatically
300
+ del q, k, v
301
+
302
+ return out
303
+
304
+ def forward_sdpa(self, q, k, v):
305
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=False)
306
+ return out
307
+
308
+
309
+ class Attention2D(nn.Module):
310
+ r"""
311
+ to_q/k/v を個別に重みをもつように変更
312
+ modified to have separate weights for to_q/k/v
313
+ """
314
+
315
+ def __init__(self, c, nhead, dropout=0.0):
316
+ super().__init__()
317
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)
318
+ self.attn = Attention(c, nhead, dropout=dropout) # , bias=True, batch_first=True)
319
+
320
+ def forward(self, x, kv, self_attn=False):
321
+ orig_shape = x.shape
322
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
323
+ if self_attn:
324
+ kv = torch.cat([x, kv], dim=1)
325
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
326
+ x = self.attn(x, kv, kv)
327
+ x = x.permute(0, 2, 1).view(*orig_shape)
328
+ return x
329
+
330
+ def set_use_xformers_or_sdpa(self, xformers, sdpa):
331
+ self.attn.set_use_xformers_or_sdpa(xformers, sdpa)
332
+
333
+
334
+ class LayerNorm2d(nn.LayerNorm):
335
+ def __init__(self, *args, **kwargs):
336
+ super().__init__(*args, **kwargs)
337
+
338
+ def forward(self, x):
339
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
340
+
341
+
342
+ class GlobalResponseNorm(nn.Module):
343
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
344
+
345
+ def __init__(self, dim):
346
+ super().__init__()
347
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
348
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
349
+
350
+ def forward(self, x):
351
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
352
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
353
+ return self.gamma * (x * Nx) + self.beta + x
354
+
355
+
356
+ class ResBlock(nn.Module):
357
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2):
358
+ super().__init__()
359
+ self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
360
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
361
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
362
+ self.channelwise = nn.Sequential(
363
+ Linear(c + c_skip, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
364
+ )
365
+
366
+ self.gradient_checkpointing = False
367
+ self.factor = 1
368
+
369
+ def set_factor(self, k):
370
+ if self.factor != 1:
371
+ return
372
+ self.factor = k
373
+ self.depthwise.bias.data /= k
374
+ self.channelwise[4].weight.data /= k
375
+ self.channelwise[4].bias.data /= k
376
+
377
+ def set_gradient_checkpointing(self, value):
378
+ self.gradient_checkpointing = value
379
+
380
+ def forward_body(self, x, x_skip=None):
381
+ x_res = x
382
+ x = x / self.factor
383
+ x = self.depthwise(x)
384
+ x = self.norm(x)
385
+ if torch.any(torch.isnan(x)):
386
+ print("nan in first norm")
387
+ if x_skip is not None:
388
+ x = torch.cat([x, x_skip], dim=1)
389
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * self.factor
390
+ if torch.any(torch.isnan(x)):
391
+ print("nan in second norm")
392
+ result = x + x_res
393
+ if check_scale(x) > 5:
394
+ self.scale = 0.1
395
+ return x + x_res
396
+
397
+ def forward(self, x, x_skip=None):
398
+ if self.factor > 1:
399
+ print("ResBlock: factor > 1")
400
+ if self.training and self.gradient_checkpointing:
401
+ # logger.info("ResnetBlock2D: gradient_checkpointing")
402
+
403
+ def create_custom_forward(func):
404
+ def custom_forward(*inputs):
405
+ return func(*inputs)
406
+
407
+ return custom_forward
408
+
409
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, x_skip)
410
+ else:
411
+ x = self.forward_body(x, x_skip)
412
+
413
+ return x
414
+
415
+
416
+ class AttnBlock(nn.Module):
417
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
418
+ super().__init__()
419
+ self.self_attn = self_attn
420
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
421
+ self.attention = Attention2D(c, nhead, dropout)
422
+ self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c))
423
+
424
+ self.gradient_checkpointing = False
425
+ self.factor = 1
426
+
427
+ def set_factor(self, k):
428
+ if self.factor != 1:
429
+ return
430
+ self.factor = k
431
+ self.attention.attn.out_proj.weight.data /= k
432
+ if self.attention.attn.out_proj.bias is not None:
433
+ self.attention.attn.out_proj.bias.data /= k
434
+
435
+ def set_gradient_checkpointing(self, value):
436
+ self.gradient_checkpointing = value
437
+
438
+ def set_use_xformers_or_sdpa(self, xformers, sdpa):
439
+ self.attention.set_use_xformers_or_sdpa(xformers, sdpa)
440
+
441
+ def forward_body(self, x, kv):
442
+ kv = self.kv_mapper(kv)
443
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) * self.factor
444
+ return x
445
+
446
+ def forward(self, x, kv):
447
+ if self.factor > 1:
448
+ print("AttnBlock: factor > 1")
449
+ if self.training and self.gradient_checkpointing:
450
+ # logger.info("AttnBlock: gradient_checkpointing")
451
+
452
+ def create_custom_forward(func):
453
+ def custom_forward(*inputs):
454
+ return func(*inputs)
455
+
456
+ return custom_forward
457
+
458
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, kv)
459
+ else:
460
+ x = self.forward_body(x, kv)
461
+
462
+ return x
463
+
464
+
465
+ class FeedForwardBlock(nn.Module):
466
+ def __init__(self, c, dropout=0.0):
467
+ super().__init__()
468
+ self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
469
+ self.channelwise = nn.Sequential(
470
+ Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c)
471
+ )
472
+
473
+ self.gradient_checkpointing = False
474
+
475
+ def set_gradient_checkpointing(self, value):
476
+ self.gradient_checkpointing = value
477
+
478
+ def forward_body(self, x):
479
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
480
+ return x
481
+
482
+ def forward(self, x):
483
+ if self.training and self.gradient_checkpointing:
484
+ # logger.info("FeedForwardBlock: gradient_checkpointing")
485
+
486
+ def create_custom_forward(func):
487
+ def custom_forward(*inputs):
488
+ return func(*inputs)
489
+
490
+ return custom_forward
491
+
492
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
493
+ else:
494
+ x = self.forward_body(x)
495
+
496
+ return x
497
+
498
+
499
+ class TimestepBlock(nn.Module):
500
+ def __init__(self, c, c_timestep, conds=["sca"]):
501
+ super().__init__()
502
+ self.mapper = Linear(c_timestep, c * 2)
503
+ self.conds = conds
504
+ for cname in conds:
505
+ setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
506
+ self.factor = 1
507
+
508
+ def set_factor(self, k, ext_k):
509
+ if self.factor != 1:
510
+ return
511
+ print(f"TimestepBlock: factor = {k}, ext_k = {ext_k}")
512
+ self.factor = k
513
+ k_factor = k/ext_k
514
+ a_weight_factor = 1/k_factor
515
+ b_weight_factor = 1/k
516
+ a_bias_offset = - ((k_factor - 1)/(k_factor))/(len(self.conds) + 1)
517
+
518
+ for module in [self.mapper, *(getattr(self, f"mapper_{cname}") for cname in self.conds)]:
519
+ a_bias, b_bias = module.bias.data.chunk(2, dim=0)
520
+ a_weight, b_weight = module.weight.data.chunk(2, dim=0)
521
+ module.weight.data.copy_(
522
+ torch.concat([
523
+ a_weight * a_weight_factor,
524
+ b_weight * b_weight_factor
525
+ ])
526
+ )
527
+ module.bias.data.copy_(
528
+ torch.concat([
529
+ a_bias * a_weight_factor + a_bias_offset,
530
+ b_bias * b_weight_factor
531
+ ])
532
+ )
533
+
534
+ def forward(self, x, t):
535
+ if self.factor > 1:
536
+ print("TimestepBlock: factor > 1")
537
+ t = t.chunk(len(self.conds) + 1, dim=1)
538
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
539
+ for i, c in enumerate(self.conds):
540
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
541
+ a, b = a + ac, b + bc
542
+ return (x * (1 + a) + b) * self.factor
543
+
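+ # In effect this is FiLM-style modulation: t is split into len(conds) + 1 chunks, each chunk
+ # is mapped to a per-channel (a, b) pair, the pairs are summed, and the block applies
+ # x * (1 + a) + b (scaled by self.factor once the FP16 fix has been applied).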
544
+
545
+ class UpDownBlock2d(nn.Module):
546
+ def __init__(self, c_in, c_out, mode, enabled=True):
547
+ super().__init__()
548
+ assert mode in ["up", "down"]
549
+ interpolation = (
550
+ nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) if enabled else nn.Identity()
551
+ )
552
+ mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
553
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation])
554
+
555
+ self.mode = mode
556
+
557
+ self.gradient_checkpointing = False
558
+
559
+ def set_gradient_checkpointing(self, value):
560
+ self.gradient_checkpointing = value
561
+
562
+ def forward_body(self, x):
563
+ org_dtype = x.dtype
564
+ for i, block in enumerate(self.blocks):
565
+ # 公式の実装では、常に float で計算しているが、すこしでもメモリを節約するために bfloat16 + Upsample のみ float に変換する
566
+ # The official implementation always computes in float; to save a little memory, we convert to float only for the bfloat16 + Upsample case.
567
+ if x.dtype == torch.bfloat16 and (self.mode == "up" and i == 0 or self.mode != "up" and i == 1):
568
+ x = x.float()
569
+ x = block(x)
570
+ x = x.to(org_dtype)
571
+ return x
572
+
573
+ def forward(self, x):
574
+ if self.training and self.gradient_checkpointing:
575
+ # logger.info("UpDownBlock2d: gradient_checkpointing")
576
+
577
+ def create_custom_forward(func):
578
+ def custom_forward(*inputs):
579
+ return func(*inputs)
580
+
581
+ return custom_forward
582
+
583
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x)
584
+ else:
585
+ x = self.forward_body(x)
586
+
587
+ return x
588
+
589
+
590
+ class StageAResBlock(nn.Module):
591
+ def __init__(self, c, c_hidden):
592
+ super().__init__()
593
+ # depthwise/attention
594
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
595
+ self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c))
596
+
597
+ # channelwise
598
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
599
+ self.channelwise = nn.Sequential(
600
+ nn.Linear(c, c_hidden),
601
+ nn.GELU(),
602
+ nn.Linear(c_hidden, c),
603
+ )
604
+
605
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
606
+
607
+ # Init weights
608
+ def _basic_init(module):
609
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
610
+ torch.nn.init.xavier_uniform_(module.weight)
611
+ if module.bias is not None:
612
+ nn.init.constant_(module.bias, 0)
613
+
614
+ self.apply(_basic_init)
615
+
616
+ def _norm(self, x, norm):
617
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
618
+
619
+ def forward(self, x):
620
+ mods = self.gammas
621
+
622
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
623
+ x = x + self.depthwise(x_temp) * mods[2]
624
+
625
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
626
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
627
+
628
+ return x
629
+
630
+
631
+ class StageA(nn.Module):
632
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43): # 0.3764
633
+ super().__init__()
634
+ self.c_latent = c_latent
635
+ self.scale_factor = scale_factor
636
+ c_levels = [c_hidden // (2**i) for i in reversed(range(levels))]
637
+
638
+ # Encoder blocks
639
+ self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1))
640
+ down_blocks = []
641
+ for i in range(levels):
642
+ if i > 0:
643
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
644
+ block = StageAResBlock(c_levels[i], c_levels[i] * 4)
645
+ down_blocks.append(block)
646
+ down_blocks.append(
647
+ nn.Sequential(
648
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
649
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
650
+ )
651
+ )
652
+ self.down_blocks = nn.Sequential(*down_blocks)
653
+ self.down_blocks[0]
654
+
655
+ self.codebook_size = codebook_size
656
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
657
+
658
+ # Decoder blocks
659
+ up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))]
660
+ for i in range(levels):
661
+ for j in range(bottleneck_blocks if i == 0 else 1):
662
+ block = StageAResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
663
+ up_blocks.append(block)
664
+ if i < levels - 1:
665
+ up_blocks.append(
666
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1)
667
+ )
668
+ self.up_blocks = nn.Sequential(*up_blocks)
669
+ self.out_block = nn.Sequential(
670
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
671
+ nn.PixelShuffle(2),
672
+ )
673
+
674
+ def encode(self, x, quantize=False):
675
+ x = self.in_block(x)
676
+ x = self.down_blocks(x)
677
+ if quantize:
678
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
679
+ return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
680
+ else:
681
+ return x / self.scale_factor, None, None, None
682
+
683
+ def decode(self, x):
684
+ x = x * self.scale_factor
685
+ x = self.up_blocks(x)
686
+ x = self.out_block(x)
687
+ return x
688
+
689
+ def forward(self, x, quantize=False):
690
+ qe, x, _, vq_loss = self.encode(x, quantize)
691
+ x = self.decode(qe)
692
+ return x, vq_loss
693
+
694
+
695
+ r"""
696
+
697
+ https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_b_3b.yaml
698
+
699
+ # GLOBAL STUFF
700
+ model_version: 3B
701
+ dtype: bfloat16
702
+
703
+ # For demonstration purposes in reconstruct_images.ipynb
704
+ webdataset_path: file:inference/imagenet_1024.tar
705
+ batch_size: 4
706
+ image_size: 1024
707
+ grad_accum_steps: 1
708
+
709
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
710
+ stage_a_checkpoint_path: models/stage_a.safetensors
711
+ generator_checkpoint_path: models/stage_b_bf16.safetensors
712
+ """
713
+
714
+
715
+ class StageB(nn.Module):
716
+ def __init__(
717
+ self,
718
+ c_in=4,
719
+ c_out=4,
720
+ c_r=64,
721
+ patch_size=2,
722
+ c_cond=1280,
723
+ c_hidden=[320, 640, 1280, 1280],
724
+ nhead=[-1, -1, 20, 20],
725
+ blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
726
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]],
727
+ level_config=["CT", "CT", "CTA", "CTA"],
728
+ c_clip=1280,
729
+ c_clip_seq=4,
730
+ c_effnet=16,
731
+ c_pixels=3,
732
+ kernel_size=3,
733
+ dropout=[0, 0, 0.1, 0.1],
734
+ self_attn=True,
735
+ t_conds=["sca"],
736
+ ):
737
+ super().__init__()
738
+ self.c_r = c_r
739
+ self.t_conds = t_conds
740
+ self.c_clip_seq = c_clip_seq
741
+ if not isinstance(dropout, list):
742
+ dropout = [dropout] * len(c_hidden)
743
+ if not isinstance(self_attn, list):
744
+ self_attn = [self_attn] * len(c_hidden)
745
+
746
+ # CONDITIONING
747
+ self.effnet_mapper = nn.Sequential(
748
+ nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1),
749
+ nn.GELU(),
750
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
751
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
752
+ )
753
+ self.pixels_mapper = nn.Sequential(
754
+ nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1),
755
+ nn.GELU(),
756
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
757
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
758
+ )
759
+ self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq)
760
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
761
+
762
+ self.embedding = nn.Sequential(
763
+ nn.PixelUnshuffle(patch_size),
764
+ nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
765
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
766
+ )
767
+
768
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
769
+ if block_type == "C":
770
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
771
+ elif block_type == "A":
772
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
773
+ elif block_type == "F":
774
+ return FeedForwardBlock(c_hidden, dropout=dropout)
775
+ elif block_type == "T":
776
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
777
+ else:
778
+ raise Exception(f"Block type {block_type} not supported")
779
+
780
+ # BLOCKS
781
+ # -- down blocks
782
+ self.down_blocks = nn.ModuleList()
783
+ self.down_downscalers = nn.ModuleList()
784
+ self.down_repeat_mappers = nn.ModuleList()
785
+ for i in range(len(c_hidden)):
786
+ if i > 0:
787
+ self.down_downscalers.append(
788
+ nn.Sequential(
789
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
790
+ nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
791
+ )
792
+ )
793
+ else:
794
+ self.down_downscalers.append(nn.Identity())
795
+ down_block = nn.ModuleList()
796
+ for _ in range(blocks[0][i]):
797
+ for block_type in level_config[i]:
798
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
799
+ down_block.append(block)
800
+ self.down_blocks.append(down_block)
801
+ if block_repeat is not None:
802
+ block_repeat_mappers = nn.ModuleList()
803
+ for _ in range(block_repeat[0][i] - 1):
804
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
805
+ self.down_repeat_mappers.append(block_repeat_mappers)
806
+
807
+ # -- up blocks
808
+ self.up_blocks = nn.ModuleList()
809
+ self.up_upscalers = nn.ModuleList()
810
+ self.up_repeat_mappers = nn.ModuleList()
811
+ for i in reversed(range(len(c_hidden))):
812
+ if i > 0:
813
+ self.up_upscalers.append(
814
+ nn.Sequential(
815
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
816
+ nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
817
+ )
818
+ )
819
+ else:
820
+ self.up_upscalers.append(nn.Identity())
821
+ up_block = nn.ModuleList()
822
+ for j in range(blocks[1][::-1][i]):
823
+ for k, block_type in enumerate(level_config[i]):
824
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
825
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
826
+ up_block.append(block)
827
+ self.up_blocks.append(up_block)
828
+ if block_repeat is not None:
829
+ block_repeat_mappers = nn.ModuleList()
830
+ for _ in range(block_repeat[1][::-1][i] - 1):
831
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
832
+ self.up_repeat_mappers.append(block_repeat_mappers)
833
+
834
+ # OUTPUT
835
+ self.clf = nn.Sequential(
836
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
837
+ nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
838
+ nn.PixelShuffle(patch_size),
839
+ )
840
+
841
+ # --- WEIGHT INIT ---
842
+ self.apply(self._init_weights) # General init
843
+ nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
844
+ nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
845
+ nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
846
+ nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
847
+ nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
848
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
849
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
850
+
851
+ # blocks
852
+ for level_block in self.down_blocks + self.up_blocks:
853
+ for block in level_block:
854
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
855
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
856
+ elif isinstance(block, TimestepBlock):
857
+ for layer in block.modules():
858
+ if isinstance(layer, nn.Linear):
859
+ nn.init.constant_(layer.weight, 0)
860
+
861
+ def _init_weights(self, m):
862
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
863
+ torch.nn.init.xavier_uniform_(m.weight)
864
+ if m.bias is not None:
865
+ nn.init.constant_(m.bias, 0)
866
+
867
+ def set_use_xformers_or_sdpa(self, xformers, sdpa):
868
+ for block in self.down_blocks + self.up_blocks:
869
+ for layer in block:
870
+ if hasattr(layer, "set_use_xformers_or_sdpa"):
871
+ layer.set_use_xformers_or_sdpa(xformers, sdpa)
872
+
873
+ def gen_r_embedding(self, r, max_positions=10000):
874
+ r = r * max_positions
875
+ half_dim = self.c_r // 2
876
+ emb = math.log(max_positions) / (half_dim - 1)
877
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
878
+ emb = r[:, None] * emb[None, :]
879
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
880
+ if self.c_r % 2 == 1: # zero pad
881
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
882
+ return emb
883
+
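+ # gen_r_embedding is the standard sinusoidal embedding of the continuous noise level r:
+ # with d = c_r // 2, component i uses frequency w_i = max_positions ** (1 - i / (d - 1)),
+ # and the embedding is [sin(r * w_i), cos(r * w_i)] over i = 0 .. d-1 (zero-padded if c_r is odd).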
884
+ def gen_c_embeddings(self, clip):
885
+ if len(clip.shape) == 2:
886
+ clip = clip.unsqueeze(1)
887
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
888
+ clip = self.clip_norm(clip)
889
+ return clip
890
+
891
+ def _down_encode(self, x, r_embed, clip):
892
+ level_outputs = []
893
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
894
+ for down_block, downscaler, repmap in block_group:
895
+ x = downscaler(x)
896
+ for i in range(len(repmap) + 1):
897
+ for block in down_block:
898
+ if isinstance(block, ResBlock) or (
899
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
900
+ ):
901
+ x = block(x)
902
+ elif isinstance(block, AttnBlock) or (
903
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
904
+ ):
905
+ x = block(x, clip)
906
+ elif isinstance(block, TimestepBlock) or (
907
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
908
+ ):
909
+ x = block(x, r_embed)
910
+ else:
911
+ x = block(x)
912
+ if i < len(repmap):
913
+ x = repmap[i](x)
914
+ level_outputs.insert(0, x)
915
+ return level_outputs
916
+
917
+ def _up_decode(self, level_outputs, r_embed, clip):
918
+ x = level_outputs[0]
919
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
920
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
921
+ for j in range(len(repmap) + 1):
922
+ for k, block in enumerate(up_block):
923
+ if isinstance(block, ResBlock) or (
924
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
925
+ ):
926
+ skip = level_outputs[i] if k == 0 and i > 0 else None
927
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
928
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
929
+ x = block(x, skip)
930
+ elif isinstance(block, AttnBlock) or (
931
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
932
+ ):
933
+ x = block(x, clip)
934
+ elif isinstance(block, TimestepBlock) or (
935
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
936
+ ):
937
+ x = block(x, r_embed)
938
+ else:
939
+ x = block(x)
940
+ if j < len(repmap):
941
+ x = repmap[j](x)
942
+ x = upscaler(x)
943
+ return x
944
+
945
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
946
+ if pixels is None:
947
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
948
+
949
+ # Process the conditioning embeddings
950
+ r_embed = self.gen_r_embedding(r)
951
+ for c in self.t_conds:
952
+ t_cond = kwargs.get(c, torch.zeros_like(r))
953
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
954
+ clip = self.gen_c_embeddings(clip)
955
+
956
+ # Model Blocks
957
+ x = self.embedding(x)
958
+ x = x + self.effnet_mapper(
959
+ nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
960
+ )
961
+ x = x + nn.functional.interpolate(
962
+ self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True
963
+ )
964
+ level_outputs = self._down_encode(x, r_embed, clip)
965
+ x = self._up_decode(level_outputs, r_embed, clip)
966
+ return self.clf(x)
967
+
968
+ def update_weights_ema(self, src_model, beta=0.999):
969
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
970
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
971
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
972
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
973
+
974
+
975
+ r"""
976
+
977
+ https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_c_3b.yaml
978
+
979
+ # GLOBAL STUFF
980
+ model_version: 3.6B
981
+ dtype: bfloat16
982
+
983
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
984
+ previewer_checkpoint_path: models/previewer.safetensors
985
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
986
+ """
987
+
988
+
989
+ class StageC(nn.Module):
990
+ def __init__(
991
+ self,
992
+ c_in=16,
993
+ c_out=16,
994
+ c_r=64,
995
+ patch_size=1,
996
+ c_cond=2048,
997
+ c_hidden=[2048, 2048],
998
+ nhead=[32, 32],
999
+ blocks=[[8, 24], [24, 8]],
1000
+ block_repeat=[[1, 1], [1, 1]],
1001
+ level_config=["CTA", "CTA"],
1002
+ c_clip_text=1280,
1003
+ c_clip_text_pooled=1280,
1004
+ c_clip_img=768,
1005
+ c_clip_seq=4,
1006
+ kernel_size=3,
1007
+ dropout=[0.1, 0.1],
1008
+ self_attn=True,
1009
+ t_conds=["sca", "crp"],
1010
+ switch_level=[False],
1011
+ ):
1012
+ super().__init__()
1013
+ self.c_r = c_r
1014
+ self.t_conds = t_conds
1015
+ self.c_clip_seq = c_clip_seq
1016
+ if not isinstance(dropout, list):
1017
+ dropout = [dropout] * len(c_hidden)
1018
+ if not isinstance(self_attn, list):
1019
+ self_attn = [self_attn] * len(c_hidden)
1020
+
1021
+ # CONDITIONING
1022
+ self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
1023
+ self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
1024
+ self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
1025
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
1026
+
1027
+ self.embedding = nn.Sequential(
1028
+ nn.PixelUnshuffle(patch_size),
1029
+ nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
1030
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
1031
+ )
1032
+
1033
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
1034
+ if block_type == "C":
1035
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
1036
+ elif block_type == "A":
1037
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
1038
+ elif block_type == "F":
1039
+ return FeedForwardBlock(c_hidden, dropout=dropout)
1040
+ elif block_type == "T":
1041
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
1042
+ else:
1043
+ raise Exception(f"Block type {block_type} not supported")
1044
+
1045
+ # BLOCKS
1046
+ # -- down blocks
1047
+ self.down_blocks = nn.ModuleList()
1048
+ self.down_downscalers = nn.ModuleList()
1049
+ self.down_repeat_mappers = nn.ModuleList()
1050
+ for i in range(len(c_hidden)):
1051
+ if i > 0:
1052
+ self.down_downscalers.append(
1053
+ nn.Sequential(
1054
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
1055
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode="down", enabled=switch_level[i - 1]),
1056
+ )
1057
+ )
1058
+ else:
1059
+ self.down_downscalers.append(nn.Identity())
1060
+ down_block = nn.ModuleList()
1061
+ for _ in range(blocks[0][i]):
1062
+ for block_type in level_config[i]:
1063
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
1064
+ down_block.append(block)
1065
+ self.down_blocks.append(down_block)
1066
+ if block_repeat is not None:
1067
+ block_repeat_mappers = nn.ModuleList()
1068
+ for _ in range(block_repeat[0][i] - 1):
1069
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
1070
+ self.down_repeat_mappers.append(block_repeat_mappers)
1071
+
1072
+ # -- up blocks
1073
+ self.up_blocks = nn.ModuleList()
1074
+ self.up_upscalers = nn.ModuleList()
1075
+ self.up_repeat_mappers = nn.ModuleList()
1076
+ for i in reversed(range(len(c_hidden))):
1077
+ if i > 0:
1078
+ self.up_upscalers.append(
1079
+ nn.Sequential(
1080
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
1081
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode="up", enabled=switch_level[i - 1]),
1082
+ )
1083
+ )
1084
+ else:
1085
+ self.up_upscalers.append(nn.Identity())
1086
+ up_block = nn.ModuleList()
1087
+ for j in range(blocks[1][::-1][i]):
1088
+ for k, block_type in enumerate(level_config[i]):
1089
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
1090
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
1091
+ up_block.append(block)
1092
+ self.up_blocks.append(up_block)
1093
+ if block_repeat is not None:
1094
+ block_repeat_mappers = nn.ModuleList()
1095
+ for _ in range(block_repeat[1][::-1][i] - 1):
1096
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
1097
+ self.up_repeat_mappers.append(block_repeat_mappers)
1098
+
1099
+ # OUTPUT
1100
+ self.clf = nn.Sequential(
1101
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
1102
+ nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
1103
+ nn.PixelShuffle(patch_size),
1104
+ )
1105
+
1106
+ # --- WEIGHT INIT ---
1107
+ self.apply(self._init_weights) # General init
1108
+ nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
1109
+ nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
1110
+ nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
1111
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
1112
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
1113
+
1114
+ # blocks
1115
+ for level_block in self.down_blocks + self.up_blocks:
1116
+ for block in level_block:
1117
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
1118
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
1119
+ elif isinstance(block, TimestepBlock):
1120
+ for layer in block.modules():
1121
+ if isinstance(layer, nn.Linear):
1122
+ nn.init.constant_(layer.weight, 0)
1123
+
1124
+ def _init_weights(self, m):
1125
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
1126
+ torch.nn.init.xavier_uniform_(m.weight)
1127
+ if m.bias is not None:
1128
+ nn.init.constant_(m.bias, 0)
1129
+
1130
+ def set_gradient_checkpointing(self, value):
1131
+ for block in self.down_blocks + self.up_blocks:
1132
+ for layer in block:
1133
+ if hasattr(layer, "set_gradient_checkpointing"):
1134
+ layer.set_gradient_checkpointing(value)
1135
+
1136
+ def set_use_xformers_or_sdpa(self, xformers, sdpa):
1137
+ for block in self.down_blocks + self.up_blocks:
1138
+ for layer in block:
1139
+ if hasattr(layer, "set_use_xformers_or_sdpa"):
1140
+ layer.set_use_xformers_or_sdpa(xformers, sdpa)
1141
+
1142
+ def gen_r_embedding(self, r, max_positions=10000):
1143
+ r = r * max_positions
1144
+ half_dim = self.c_r // 2
1145
+ emb = math.log(max_positions) / (half_dim - 1)
1146
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
1147
+ emb = r[:, None] * emb[None, :]
1148
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
1149
+ if self.c_r % 2 == 1: # zero pad
1150
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
1151
+ return emb
1152
+
1153
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
1154
+ clip_txt = self.clip_txt_mapper(clip_txt)
1155
+ if len(clip_txt_pooled.shape) == 2:
1156
+ clip_txt_pool = clip_txt_pooled.unsqueeze(1)
1157
+ if len(clip_img.shape) == 2:
1158
+ clip_img = clip_img.unsqueeze(1)
1159
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
1160
+ clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1
1161
+ )
1162
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
1163
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
1164
+ clip = self.clip_norm(clip)
1165
+ return clip
1166
+
1167
+ def _down_encode(self, x, r_embed, clip, cnet=None):
1168
+ level_outputs = []
1169
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
1170
+ for down_block, downscaler, repmap in block_group:
1171
+ x = downscaler(x)
1172
+ for i in range(len(repmap) + 1):
1173
+ for block in down_block:
1174
+ if isinstance(block, ResBlock) or (
1175
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
1176
+ ):
1177
+ if cnet is not None:
1178
+ next_cnet = cnet()
1179
+ if next_cnet is not None:
1180
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
1181
+ x = block(x)
1182
+ elif isinstance(block, AttnBlock) or (
1183
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
1184
+ ):
1185
+ x = block(x, clip)
1186
+ elif isinstance(block, TimestepBlock) or (
1187
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
1188
+ ):
1189
+ x = block(x, r_embed)
1190
+ else:
1191
+ x = block(x)
1192
+ if i < len(repmap):
1193
+ x = repmap[i](x)
1194
+ level_outputs.insert(0, x)
1195
+ return level_outputs
1196
+
1197
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
1198
+ x = level_outputs[0]
1199
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
1200
+ now_factor = 1
1201
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
1202
+ for j in range(len(repmap) + 1):
1203
+ for k, block in enumerate(up_block):
1204
+ if getattr(block, "factor", 1) > 1:
1205
+ now_factor = -getattr(block, "factor", 1)
1206
+ scale = check_scale(x)
1207
+ if scale > 5 or (now_factor < 0 and scale > (5/-now_factor)):
1208
+ print('='*55)
1209
+ print(f"in: {i} {j} {k}")
1210
+ print("up", scale)
1211
+ if isinstance(block, ResBlock) or (
1212
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
1213
+ ):
1214
+ skip = level_outputs[i] if k == 0 and i > 0 else None
1215
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
1216
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
1217
+ if cnet is not None:
1218
+ next_cnet = cnet()
1219
+ if next_cnet is not None:
1220
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
1221
+ x = block(x, skip)
1222
+ if now_factor > 1 and block.factor == 1:
1223
+ block.set_factor(now_factor)
1224
+ elif isinstance(block, AttnBlock) or (
1225
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
1226
+ ):
1227
+ x = block(x, clip)
1228
+ if now_factor > 1 and block.factor == 1:
1229
+ block.set_factor(now_factor)
1230
+ elif isinstance(block, TimestepBlock) or (
1231
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
1232
+ ):
1233
+ x = block(x, r_embed)
1234
+ scale = check_scale(x)
1235
+ if now_factor > 1 and block.factor == 1:
1236
+ block.set_factor(now_factor, now_factor)
1237
+ pass
1238
+ elif i == 1:
1239
+ now_factor = 5
1240
+ block.set_factor(now_factor, 1)
1241
+ else:
1242
+ x = block(x)
1243
+ scale = check_scale(x)
1244
+ if scale > 5 or (now_factor < 0 and scale > (5/-now_factor)):
1245
+ print(f"out: {i} {j} {k}", '='*50)
1246
+ print("up", scale)
1247
+ print(block.__class__.__name__, torch.sum(torch.isnan(x)))
1248
+ if j < len(repmap):
1249
+ x = repmap[j](x)
1250
+ print('-- pre upscaler ---')
1251
+ print(check_scale(x))
1252
+ x = upscaler(x)
1253
+ print('-- post upscaler ---')
1254
+ print(check_scale(x))
1255
+ if now_factor > 1:
1256
+ if isinstance(upscaler, UpDownBlock2d):
1257
+ upscaler.blocks[1].weight.data /= now_factor
1258
+ upscaler.blocks[1].bias.data /= now_factor
1259
+ scale = check_scale(x)
1260
+ if scale > 5:
1261
+ print('='*50)
1262
+ print("upscaler", check_scale(x))
1263
+ return x
1264
+
1265
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
1266
+ # Process the conditioning embeddings
1267
+ r_embed = self.gen_r_embedding(r)
1268
+ for c in self.t_conds:
1269
+ t_cond = kwargs.get(c, torch.zeros_like(r))
1270
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
1271
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
1272
+
1273
+ # Model Blocks
1274
+ x = self.embedding(x)
1275
+ print(check_scale(x))
1276
+ # ControlNet is not supported yet
1277
+ # if cnet is not None:
1278
+ # cnet = ControlNetDeliverer(cnet)
1279
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
1280
+ x1 = self._up_decode(level_outputs, r_embed, clip, cnet)
1281
+ result1 = self.clf(x1)
1282
+ #return result1
1283
+ self.half()
1284
+ sd = convert_state_dict_normal_attn_to_mha(self.state_dict())
1285
+ x2 = self._up_decode(level_outputs, r_embed, clip, cnet)
1286
+ result2 = self.clf(x2)
1287
+ print(torch.nn.functional.mse_loss(result1, result2))
1288
+ from safetensors.torch import save_file
1289
+ save_file(sd, f'{fp16_fix_save_path}/factor5_pass4.safetensors')
1290
+ raise Exception("Early Stop")
1291
+
1292
+ def update_weights_ema(self, src_model, beta=0.999):
1293
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
1294
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
1295
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
1296
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
1297
+
1298
+ @property
1299
+ def device(self):
1300
+ return next(self.parameters()).device
1301
+
1302
+ @property
1303
+ def dtype(self):
1304
+ return next(self.parameters()).dtype
1305
+
1306
+
1307
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
1308
+ class Previewer(nn.Module):
1309
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
1310
+ super().__init__()
1311
+ self.blocks = nn.Sequential(
1312
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
1313
+ nn.GELU(),
1314
+ nn.BatchNorm2d(c_hidden),
1315
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
1316
+ nn.GELU(),
1317
+ nn.BatchNorm2d(c_hidden),
1318
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
1319
+ nn.GELU(),
1320
+ nn.BatchNorm2d(c_hidden // 2),
1321
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
1322
+ nn.GELU(),
1323
+ nn.BatchNorm2d(c_hidden // 2),
1324
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
1325
+ nn.GELU(),
1326
+ nn.BatchNorm2d(c_hidden // 4),
1327
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
1328
+ nn.GELU(),
1329
+ nn.BatchNorm2d(c_hidden // 4),
1330
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
1331
+ nn.GELU(),
1332
+ nn.BatchNorm2d(c_hidden // 4),
1333
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
1334
+ nn.GELU(),
1335
+ nn.BatchNorm2d(c_hidden // 4),
1336
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
1337
+ )
1338
+
1339
+ def forward(self, x):
1340
+ return self.blocks(x)
1341
+
1342
+ @property
1343
+ def device(self):
1344
+ return next(self.parameters()).device
1345
+
1346
+ @property
1347
+ def dtype(self):
1348
+ return next(self.parameters()).dtype
1349
+
1350
+
1351
+ def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
1352
+ # deprecated
1353
+
1354
+ # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
1355
+ # handling is_eval here would be awkward, so it is done elsewhere
1356
+ # is_unconditional is likewise handled elsewhere
1357
+ # clip_image is not supported for now
1358
+ if captions is not None:
1359
+ clip_tokens_unpooled = tokenizer(
1360
+ captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
1361
+ ).to(text_model.device)
1362
+ text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
1363
+ else:
1364
+ text_encoder_output = text_model(input_ids, output_hidden_states=True)
1365
+
1366
+ text_embeddings = text_encoder_output.hidden_states[-1]
1367
+ text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)
1368
+
1369
+ return text_embeddings, text_pooled_embeddings
1370
+ # return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings} # , "clip_img": image_embeddings}
1371
+
1372
+
1373
+ # region gdf
1374
+
1375
+
1376
+ class SimpleSampler:
1377
+ def __init__(self, gdf):
1378
+ self.gdf = gdf
1379
+ self.current_step = -1
1380
+
1381
+ def __call__(self, *args, **kwargs):
1382
+ self.current_step += 1
1383
+ return self.step(*args, **kwargs)
1384
+
1385
+ def init_x(self, shape):
1386
+ return torch.randn(*shape)
1387
+
1388
+ def step(self, x, x0, epsilon, logSNR, logSNR_prev):
1389
+ raise NotImplementedError("You should override the 'step' function.")
1390
+
1391
+
1392
+ class DDIMSampler(SimpleSampler):
+     def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
+         a, b = self.gdf.input_scaler(logSNR)
+         if len(a.shape) == 1:
+             a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1))
+ 
+         a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
+         if len(a_prev.shape) == 1:
+             a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1))
+ 
+         sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
+         # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
+         x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
+         return x
+ 
+ 
+ class DDPMSampler(DDIMSampler):
+     def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
+         return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta)
+ 
+ 
+ class LCMSampler(SimpleSampler):
+     def step(self, x, x0, epsilon, logSNR, logSNR_prev):
+         a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
+         if len(a_prev.shape) == 1:
+             a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1))
+         return x0 * a_prev + torch.randn_like(epsilon) * b_prev
+ 
+ 
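+ # Note (added commentary, not from the original file): the DDIM step above implements
+ # x_prev = a_prev * x0 + sqrt(b_prev^2 - sigma_tau^2) * epsilon + sigma_tau * z, z ~ N(0, I),
+ # where (a, b) come from the GDF input scaler. eta=0 gives the deterministic DDIM update,
+ # and DDPMSampler simply reuses it with eta=1 for fully stochastic sampling.
+ 
+ 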
+ class GDF:
+     def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
+         self.schedule = schedule
+         self.input_scaler = input_scaler
+         self.target = target
+         self.noise_cond = noise_cond
+         self.loss_weight = loss_weight
+         self.offset_noise = offset_noise
+ 
+     def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
+         stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
+         return stretched_limits
+ 
+     def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
+         if epsilon is None:
+             epsilon = torch.randn_like(x0)
+         if self.offset_noise > 0:
+             if offset is None:
+                 offset = torch.randn([x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)).to(x0.device)
+             epsilon = epsilon + offset * self.offset_noise
+         logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
+         a, b = self.input_scaler(logSNR)  # B
+         if len(a.shape) == 1:
+             a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1))  # BxCxHxW
+         target = self.target(x0, epsilon, logSNR, a, b)
+ 
+         # noised, noise, target, logSNR, noise_cond, loss_weight
+         return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)
+ 
+     def undiffuse(self, x, logSNR, pred):
+         a, b = self.input_scaler(logSNR)
+         if len(a.shape) == 1:
+             a, b = a.view(-1, *[1] * (len(x.shape) - 1)), b.view(-1, *[1] * (len(x.shape) - 1))
+         return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)
+ 
+     def sample(
+         self,
+         model,
+         model_inputs,
+         shape,
+         unconditional_inputs=None,
+         sampler=None,
+         schedule=None,
+         t_start=1.0,
+         t_end=0.0,
+         timesteps=20,
+         x_init=None,
+         cfg=3.0,
+         cfg_t_stop=None,
+         cfg_t_start=None,
+         cfg_rho=0.7,
+         sampler_params=None,
+         shift=1,
+         device="cpu",
+     ):
+         sampler_params = {} if sampler_params is None else sampler_params
+         if sampler is None:
+             sampler = DDPMSampler(self)
+         r_range = torch.linspace(t_start, t_end, timesteps + 1)
+         schedule = self.schedule if schedule is None else schedule
+         logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(device)
+ 
+         x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()
+         if cfg is not None:
+             if unconditional_inputs is None:
+                 unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
+             model_inputs = {
+                 k: (
+                     torch.cat([v, v_u], dim=0)
+                     if isinstance(v, torch.Tensor)
+                     else (
+                         [
+                             (
+                                 torch.cat([vi, vi_u], dim=0)
+                                 if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor)
+                                 else None
+                             )
+                             for vi, vi_u in zip(v, v_u)
+                         ]
+                         if isinstance(v, list)
+                         else (
+                             {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v}
+                             if isinstance(v, dict)
+                             else None
+                         )
+                     )
+                 )
+                 for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
+             }
+         for i in range(0, timesteps):
+             noise_cond = self.noise_cond(logSNR_range[i])
+             if (
+                 cfg is not None
+                 and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop)
+                 and (cfg_t_start is None or r_range[i].item() <= cfg_t_start)
+             ):
+                 cfg_val = cfg
+                 if isinstance(cfg_val, (list, tuple)):
+                     assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
+                     cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1 - r_range[i].item())
+                 pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
+                 pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
+                 if cfg_rho > 0:
+                     std_pos, std_cfg = pred.std(), pred_cfg.std()
+                     pred = cfg_rho * (pred_cfg * std_pos / (std_cfg + 1e-9)) + pred_cfg * (1 - cfg_rho)
+                 else:
+                     pred = pred_cfg
+             else:
+                 pred = model(x, noise_cond, **model_inputs)
+             x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
+             x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i + 1], **sampler_params)
+             altered_vars = yield (x0, x, pred)
+ 
+             # Update some running variables if the user wants
+             if altered_vars is not None:
+                 cfg = altered_vars.get("cfg", cfg)
+                 cfg_rho = altered_vars.get("cfg_rho", cfg_rho)
+                 sampler = altered_vars.get("sampler", sampler)
+                 model_inputs = altered_vars.get("model_inputs", model_inputs)
+                 x = altered_vars.get("x", x)
+                 x_init = altered_vars.get("x_init", x_init)
+ 
+ 
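+ # Illustrative sampling sketch (added commentary, not from the original file): GDF.sample is a
+ # generator that yields (x0, x, pred) once per step, so inference drains it in a loop. The names
+ # `gdf`, `generator_c` and `conditions` are hypothetical; see the training sketch at the end of
+ # this region for how a GDF instance can be assembled.
+ #
+ #   sampling_loop = gdf.sample(generator_c, conditions, shape=(1, 16, 24, 24),
+ #                              cfg=4.0, timesteps=20, device="cuda")
+ #   for x0, x, pred in sampling_loop:
+ #       pass  # x is the current (partially denoised) latent; keep the last x0 / x as the result
+ 
+ 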
+ class BaseSchedule:
+     def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
+         self.setup(*args, **kwargs)
+         self.limits = None
+         self.discrete_steps = discrete_steps
+         self.shift = shift
+         if force_limits:
+             self.reset_limits()
+ 
+     def reset_limits(self, shift=1, disable=False):
+         try:
+             self.limits = None if disable else self(torch.tensor([1.0, 0.0]), shift=shift).tolist()  # min, max
+             return self.limits
+         except Exception:
+             print("WARNING: this schedule doesn't support t and will be unbounded")
+             return None
+ 
+     def setup(self, *args, **kwargs):
+         raise NotImplementedError("this method needs to be overridden")
+ 
+     def schedule(self, *args, **kwargs):
+         raise NotImplementedError("this method needs to be overridden")
+ 
+     def __call__(self, t, *args, shift=1, **kwargs):
+         if isinstance(t, torch.Tensor):
+             batch_size = None
+             if self.discrete_steps is not None:
+                 if t.dtype != torch.long:
+                     t = (t * (self.discrete_steps - 1)).round().long()
+                 t = t / (self.discrete_steps - 1)
+             t = t.clamp(0, 1)
+         else:
+             batch_size = t
+             t = None
+         logSNR = self.schedule(t, batch_size, *args, **kwargs)
+         if shift * self.shift != 1:
+             logSNR += 2 * np.log(1 / (shift * self.shift))
+         if self.limits is not None:
+             logSNR = logSNR.clamp(*self.limits)
+         return logSNR
+ 
+ 
+ class CosineSchedule(BaseSchedule):
+     def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False):
+         self.s = torch.tensor([s])
+         self.clamp_range = clamp_range
+         self.norm_instead = norm_instead
+         self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+ 
+     def schedule(self, t, batch_size):
+         if t is None:
+             t = (1 - torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
+         s, min_var = self.s.to(t.device), self.min_var.to(t.device)
+         var = torch.cos((s + t) / (1 + s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var
+         if self.norm_instead:
+             var = var * (self.clamp_range[1] - self.clamp_range[0]) + self.clamp_range[0]
+         else:
+             var = var.clamp(*self.clamp_range)
+         logSNR = (var / (1 - var)).log()
+         return logSNR
+ 
+ 
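+ # Note (added commentary, not from the original file): CosineSchedule maps t in [0, 1] to
+ # logSNR = log(var / (1 - var)) with var = cos((s + t) / (1 + s) * pi / 2)^2 / cos(s / (1 + s) * pi / 2)^2.
+ # With the default clamp_range [0.0001, 0.9999], t=0 gives logSNR ~ +9.2 (almost clean data)
+ # and t=1 gives logSNR ~ -9.2 (almost pure noise).
+ 
+ 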
+ class BaseScaler:
+     def __init__(self):
+         self.stretched_limits = None
+ 
+     def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
+         min_logSNR = schedule(torch.ones(1), shift=shift)
+         max_logSNR = schedule(torch.zeros(1), shift=shift)
+ 
+         min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1]
+         max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0]
+         self.stretched_limits = [min_a, max_a, min_b, max_b]
+         return self.stretched_limits
+ 
+     def stretch_limits(self, a, b):
+         min_a, max_a, min_b, max_b = self.stretched_limits
+         return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b)
+ 
+     def scalers(self, logSNR):
+         raise NotImplementedError("this method needs to be overridden")
+ 
+     def __call__(self, logSNR):
+         a, b = self.scalers(logSNR)
+         if self.stretched_limits is not None:
+             a, b = self.stretch_limits(a, b)
+         return a, b
+ 
+ 
+ class VPScaler(BaseScaler):
+     def scalers(self, logSNR):
+         a_squared = logSNR.sigmoid()
+         a = a_squared.sqrt()
+         b = (1 - a_squared).sqrt()
+         return a, b
+ 
+ 
+ class EpsilonTarget:
+     def __call__(self, x0, epsilon, logSNR, a, b):
+         return epsilon
+ 
+     def x0(self, noised, pred, logSNR, a, b):
+         return (noised - pred * b) / a
+ 
+     def epsilon(self, noised, pred, logSNR, a, b):
+         return pred
+ 
+ 
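+ # Note (added commentary, not from the original file): VPScaler is variance preserving,
+ # a^2 = sigmoid(logSNR) and b^2 = 1 - a^2, so a^2 + b^2 = 1 and a^2 / b^2 = exp(logSNR) = SNR.
+ # The forward process is x = a * x0 + b * epsilon, and EpsilonTarget inverts it as
+ # x0 = (x - b * pred) / a once the model predicts epsilon.
+ 
+ 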
+ class BaseNoiseCond:
+     def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
+         clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
+         self.shift = shift
+         self.clamp_range = clamp_range
+         self.setup(*args, **kwargs)
+ 
+     def setup(self, *args, **kwargs):
+         pass  # this method is optional, override it if required
+ 
+     def cond(self, logSNR):
+         raise NotImplementedError("this method needs to be overridden")
+ 
+     def __call__(self, logSNR):
+         if self.shift != 1:
+             logSNR = logSNR.clone() + 2 * np.log(self.shift)
+         return self.cond(logSNR).clamp(*self.clamp_range)
+ 
+ 
+ class CosineTNoiseCond(BaseNoiseCond):
+     def setup(self, s=0.008, clamp_range=[0, 1]):  # [0.0001, 0.9999]
+         self.s = torch.tensor([s])
+         self.clamp_range = clamp_range
+         self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+ 
+     def cond(self, logSNR):
+         var = logSNR.sigmoid()
+         var = var.clamp(*self.clamp_range)
+         s, min_var = self.s.to(var.device), self.min_var.to(var.device)
+         t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+         return t
+ 
+ 
+ # --- Loss Weighting
+ class BaseLossWeight:
+     def weight(self, logSNR):
+         raise NotImplementedError("this method needs to be overridden")
+ 
+     def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
+         clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
+         if shift != 1:
+             logSNR = logSNR.clone() + 2 * np.log(shift)
+         return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range)
+ 
+ 
+ # class ComposedLossWeight(BaseLossWeight):
+ #     def __init__(self, div, mul):
+ #         self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
+ #         self.div = [div] if isinstance(div, BaseLossWeight) else div
+ 
+ #     def weight(self, logSNR):
+ #         prod, div = 1, 1
+ #         for m in self.mul:
+ #             prod *= m.weight(logSNR)
+ #         for d in self.div:
+ #             div *= d.weight(logSNR)
+ #         return prod/div
+ 
+ # class ConstantLossWeight(BaseLossWeight):
+ #     def __init__(self, v=1):
+ #         self.v = v
+ 
+ #     def weight(self, logSNR):
+ #         return torch.ones_like(logSNR) * self.v
+ 
+ # class SNRLossWeight(BaseLossWeight):
+ #     def weight(self, logSNR):
+ #         return logSNR.exp()
+ 
+ 
+ class P2LossWeight(BaseLossWeight):
+     def __init__(self, k=1.0, gamma=1.0, s=1.0):
+         self.k, self.gamma, self.s = k, gamma, s
+ 
+     def weight(self, logSNR):
+         return (self.k + (logSNR * self.s).exp()) ** -self.gamma
+ 
+ 
+ # class SNRPlusOneLossWeight(BaseLossWeight):
+ #     def weight(self, logSNR):
+ #         return logSNR.exp() + 1
+ 
+ # class MinSNRLossWeight(BaseLossWeight):
+ #     def __init__(self, max_snr=5):
+ #         self.max_snr = max_snr
+ 
+ #     def weight(self, logSNR):
+ #         return logSNR.exp().clamp(max=self.max_snr)
+ 
+ # class MinSNRPlusOneLossWeight(BaseLossWeight):
+ #     def __init__(self, max_snr=5):
+ #         self.max_snr = max_snr
+ 
+ #     def weight(self, logSNR):
+ #         return (logSNR.exp() + 1).clamp(max=self.max_snr)
+ 
+ # class TruncatedSNRLossWeight(BaseLossWeight):
+ #     def __init__(self, min_snr=1):
+ #         self.min_snr = min_snr
+ 
+ #     def weight(self, logSNR):
+ #         return logSNR.exp().clamp(min=self.min_snr)
+ 
+ # class SechLossWeight(BaseLossWeight):
+ #     def __init__(self, div=2):
+ #         self.div = div
+ 
+ #     def weight(self, logSNR):
+ #         return 1/(logSNR/self.div).cosh()
+ 
+ # class DebiasedLossWeight(BaseLossWeight):
+ #     def weight(self, logSNR):
+ #         return 1/logSNR.exp().sqrt()
+ 
+ # class SigmoidLossWeight(BaseLossWeight):
+ #     def __init__(self, s=1):
+ #         self.s = s
+ 
+ #     def weight(self, logSNR):
+ #         return (logSNR * self.s).sigmoid()
+ 
+ 
+ class AdaptiveLossWeight(BaseLossWeight):
+     def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]):
+         self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
+         self.bucket_losses = torch.ones(buckets)
+         self.weight_range = weight_range
+ 
+     def weight(self, logSNR):
+         indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
+         return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)
+ 
+     def update_buckets(self, logSNR, loss, beta=0.99):
+         indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
+         self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)
+ 
+ 
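+ # Illustrative training-time sketch (added commentary, not from the original file); the exact
+ # configuration used for Stable Cascade training is not asserted here, and `latents`, `model`,
+ # and `conditions` are hypothetical names.
+ #
+ #   gdf = GDF(
+ #       schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+ #       input_scaler=VPScaler(),
+ #       target=EpsilonTarget(),
+ #       noise_cond=CosineTNoiseCond(),
+ #       loss_weight=AdaptiveLossWeight(),
+ #   )
+ #   noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1)
+ #   pred = model(noised, noise_cond, **conditions)
+ #   loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
+ #   loss_adjusted = (loss * loss_weight).mean()
+ #   if isinstance(gdf.loss_weight, AdaptiveLossWeight):
+ #       gdf.loss_weight.update_buckets(logSNR, loss)
+ 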
+ # endregion gdf