Commit aa67e5e · williamberman committed · Parent(s): e4ea387

vae working

Files changed:
- sdxl.py +1 -1
- sdxl_models.py +18 -28
sdxl.py
CHANGED
@@ -9,7 +9,6 @@ import torch
 import torch.nn.functional as F
 import torchvision.transforms
 import torchvision.transforms.functional as TF
-import wandb
 import webdataset as wds
 from PIL import Image
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -17,6 +16,7 @@ from torch.utils.data import default_collate
 from transformers import (CLIPTextModel, CLIPTextModelWithProjection,
                           CLIPTokenizerFast)
 
+import wandb
 from diffusion import (default_num_train_timesteps,
                        euler_ode_solver_diffusion_loop, make_sigmas)
 from sdxl_models import (SDXLAdapter, SDXLControlNet, SDXLControlNetFull,
sdxl_models.py
CHANGED
@@ -6,7 +6,7 @@ import safetensors.torch
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
-import xformers
+import xformers.ops
 from PIL import Image
 from torch import nn
 
@@ -62,17 +62,17 @@ class SDXLVae(nn.Module, ModelUtils):
                 # 128 -> 128
                 nn.ModuleDict(dict(
                     resnets=nn.ModuleList([ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
-                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2
+                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2)))]),
                 )),
                 # 128 -> 256
                 nn.ModuleDict(dict(
                     resnets=nn.ModuleList([ResnetBlock2D(128, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
-                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2
+                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2)))]),
                 )),
                 # 256 -> 512
                 nn.ModuleDict(dict(
                     resnets=nn.ModuleList([ResnetBlock2D(256, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
-                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2
+                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2)))]),
                 )),
                 # 512 -> 512
                 nn.ModuleDict(dict(resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]))),
@@ -151,6 +151,7 @@ class SDXLVae(nn.Module, ModelUtils):
                 h = resnet(h)
 
             if "downsamplers" in down_block:
+                h = F.pad(h, pad=(0, 1, 0, 1), mode="constant", value=0)
                 h = down_block["downsamplers"][0]["conv"](h)
 
         h = self.encoder["mid_block"]["resnets"][0](h)
@@ -1333,49 +1334,38 @@ class Attention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
 
     def forward(self, hidden_states, encoder_hidden_states=None):
-
-
-        if input_ndim == 4:
-            batch_size, channels, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
-
-        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states, encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
-
-        return hidden_states
+        return attention(self.to_q, self.to_k, self.to_v, self.to_out, 64, hidden_states, encoder_hidden_states)
 
 
 class VaeMidBlockAttention(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)
-        self.to_q = nn.Linear(channels, channels
-        self.to_k = nn.Linear(channels, channels
-        self.to_v = nn.Linear(channels, channels
+        self.to_q = nn.Linear(channels, channels)
+        self.to_k = nn.Linear(channels, channels)
+        self.to_v = nn.Linear(channels, channels)
         self.to_out = nn.Sequential(nn.Linear(channels, channels), nn.Dropout(0.0))
+        self.head_dim = channels
 
     def forward(self, hidden_states):
-
+        residual = hidden_states
 
-
-
-        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+        batch_size, channels, height, width = hidden_states.shape
+        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
 
         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, hidden_states)
+        hidden_states = attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, hidden_states)
 
-
-
+        hidden_states = hidden_states.transpose(1, 2).view(batch_size, channels, height, width)
+
+        hidden_states = hidden_states + residual
 
         return hidden_states
 
 
-def attention(to_q, to_k, to_v, to_out, hidden_states, encoder_hidden_states=None):
+def attention(to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
     batch_size, q_seq_len, channels = hidden_states.shape
-    head_dim = 64
 
     if encoder_hidden_states is not None:
         kv = encoder_hidden_states
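Note: the substantive VAE encoder change here is the downsampling path. The stride-2 downsampler convs no longer carry their own padding argument, and the encoder forward now applies an explicit asymmetric F.pad(h, (0, 1, 0, 1)) (pad right and bottom by one) before each downsampling conv, which presumably is what makes the checkpoint weights line up ("vae working"). Below is a minimal sketch of the shape behavior of that pattern, contrasted with a symmetric padding=1 conv; the channel count and input size are hypothetical and chosen only for illustration.

import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical sizes for illustration only.
channels, size = 128, 512
x = torch.randn(1, channels, size, size)

# What the diff moves to: pad only on the right/bottom, then a stride-2 conv
# with no built-in padding (as in the new SDXLVae encoder downsampler).
conv_no_pad = nn.Conv2d(channels, channels, kernel_size=3, stride=2)
y_asym = conv_no_pad(F.pad(x, pad=(0, 1, 0, 1), mode="constant", value=0))

# A symmetric alternative with padding baked into the conv, shown for contrast.
conv_sym = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
y_sym = conv_sym(x)

# Both halve the spatial resolution; only the placement of the zero
# rows/columns (and hence alignment with pretrained weights) differs.
print(y_asym.shape)  # torch.Size([1, 128, 256, 256])
print(y_sym.shape)   # torch.Size([1, 128, 256, 256])

Since both variants produce the same output size, the fix is about matching the padding convention the pretrained VAE weights were trained with, not about changing the downsampled resolution.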