import math
import os
from typing import List, Literal, Optional

import safetensors.torch
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import xformers.ops
from PIL import Image
from torch import nn


class ModelUtils:
    """Mixin adding dtype/device introspection and safetensors loading.

    Every member relies on ``self.parameters()`` / ``load_state_dict`` being
    supplied by another base class (the classes in this file combine it with
    ``torch.nn.Module``).
    """

    @property
    def dtype(self):
        """Return the dtype of the first parameter yielded by ``self.parameters()``."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Return the device of the first parameter yielded by ``self.parameters()``."""
        return next(self.parameters()).device

    @classmethod
    def load(cls, load_from: str, device, overrides: Optional[List[str]] = None):
        """Instantiate ``cls()`` on the meta device and fill it from safetensors files.

        Args:
            load_from: Checkpoint file, or a directory that contains
                ``diffusion_pytorch_model.safetensors``.
            device: Forwarded unchanged to ``safetensors.torch.load_file``.
            overrides: Optional further checkpoints (files or directories) read
                after ``load_from``; for duplicate keys the later file wins.

        Returns:
            The model with the loaded tensors assigned as its parameters.
        """
        # The imported name is never referenced below.
        # NOTE(review): presumably imported for a side effect on
        # ``load_state_dict`` -- confirm in that module.
        import load_state_dict_patch  # noqa: F401

        checkpoint_paths = [load_from]
        if overrides is not None:
            checkpoint_paths.extend(overrides)

        state_dict = {}
        for checkpoint_path in checkpoint_paths:
            if os.path.isdir(checkpoint_path):
                checkpoint_path = os.path.join(checkpoint_path, "diffusion_pytorch_model.safetensors")

            # dict.update: tensors from later checkpoints replace earlier ones.
            state_dict.update(safetensors.torch.load_file(checkpoint_path, device=device))

        # Build the module tree under the meta device; ``assign=True`` then
        # installs the loaded tensors as the parameters.
        with torch.device("meta"):
            model = cls()

        model.load_state_dict(state_dict, assign=True)

        return model
# Multiplier applied to the latents sampled in ``SDXLVae.encode`` and divided
# back out at the start of ``SDXLVae.decode``.
vae_scaling_factor = 0.13025
class SDXLVae(nn.Module, ModelUtils):
    """Convolutional VAE: image -> 4-channel latent (``encode``) and back (``decode``)."""

    def __init__(self):
        """Create the encoder, decoder and the two 1x1 (post-)quantisation convolutions."""
        super().__init__()

        self.encoder = nn.ModuleDict(dict(
            # 3 -> 128
            conv_in=nn.Conv2d(3, 128, kernel_size=3, padding=1),

            down_blocks=nn.ModuleList([
                # 128 -> 128
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(128, 128, kernel_size=3, stride=2)))]),
                )),
                # 128 -> 256
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(128, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=2)))]),
                )),
                # 256 -> 512
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(256, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
                    downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=2)))]),
                )),
                # 512 -> 512, no downsampler
                nn.ModuleDict(dict(resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]))),
            ]),

            # 512 -> 512
            mid_block=nn.ModuleDict(dict(
                attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
                resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
            )),

            # 512 -> 8 (split into mean / logvar halves in ``encode``)
            conv_norm_out=nn.GroupNorm(32, 512, eps=1e-06),
            conv_act=nn.SiLU(),
            conv_out=nn.Conv2d(512, 8, kernel_size=3, padding=1)
        ))

        # 8 -> 8
        self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)

        # 4 -> 4
        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)

        self.decoder = nn.ModuleDict(dict(
            # 4 -> 512
            conv_in=nn.Conv2d(4, 512, kernel_size=3, padding=1),

            # 512 -> 512
            mid_block=nn.ModuleDict(dict(
                attentions=nn.ModuleList([VaeMidBlockAttention(512)]),
                resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
            )),

            up_blocks=nn.ModuleList([
                # 512 -> 512
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
                    upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
                )),
                # 512 -> 512
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6), ResnetBlock2D(512, 512, eps=1e-6)]),
                    upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)))]),
                )),
                # 512 -> 256
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(512, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6), ResnetBlock2D(256, 256, eps=1e-6)]),
                    upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)))]),
                )),
                # 256 -> 128, no upsampler
                nn.ModuleDict(dict(
                    resnets=nn.ModuleList([ResnetBlock2D(256, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6), ResnetBlock2D(128, 128, eps=1e-6)]),
                )),
            ]),

            # 128 -> 3
            conv_norm_out=nn.GroupNorm(32, 128, eps=1e-06),
            conv_act=nn.SiLU(),
            conv_out=nn.Conv2d(128, 3, kernel_size=3, padding=1)
        ))

    def encode(self, x, generator=None):
        """Encode ``x`` and draw one scaled latent sample.

        Args:
            x: Input to ``encoder.conv_in`` (3 input channels).
                NOTE(review): presumably the output of ``input_pil_to_tensor`` -- confirm.
            generator: Optional generator forwarded to ``torch.randn``.

        Returns:
            ``(mean + noise * std) * vae_scaling_factor``.
        """
        h = self.encoder["conv_in"](x)

        for down_block in self.encoder["down_blocks"]:
            for resnet in down_block["resnets"]:
                h = resnet(h)

            if "downsamplers" in down_block:
                # The stride-2 convs are built without padding; pad one zero
                # column on the right and one zero row at the bottom instead.
                h = F.pad(h, pad=(0, 1, 0, 1), mode="constant", value=0)
                h = down_block["downsamplers"][0]["conv"](h)

        mid_block = self.encoder["mid_block"]
        h = mid_block["resnets"][0](h)
        h = mid_block["attentions"][0](h)
        h = mid_block["resnets"][1](h)

        h = self.encoder["conv_norm_out"](h)
        h = self.encoder["conv_act"](h)
        h = self.encoder["conv_out"](h)

        # The 8 output channels are split along dim 1 into mean and log-variance.
        mean, logvar = self.quant_conv(h).chunk(2, dim=1)

        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)

        noise = torch.randn(mean.shape, device=mean.device, dtype=mean.dtype, generator=generator)
        z = mean + noise * std

        return z * vae_scaling_factor

    def decode(self, z):
        """Undo the latent scaling and run the decoder.

        Args:
            z: Latent with 4 channels (the input size of ``post_quant_conv``).

        Returns:
            The output of ``decoder.conv_out`` (3 channels).
        """
        h = z / vae_scaling_factor

        h = self.post_quant_conv(h)
        h = self.decoder["conv_in"](h)

        mid_block = self.decoder["mid_block"]
        h = mid_block["resnets"][0](h)
        h = mid_block["attentions"][0](h)
        h = mid_block["resnets"][1](h)

        for up_block in self.decoder["up_blocks"]:
            for resnet in up_block["resnets"]:
                h = resnet(h)

            if "upsamplers" in up_block:
                h = F.interpolate(h, scale_factor=2.0, mode="nearest")
                h = up_block["upsamplers"][0]["conv"](h)

        h = self.decoder["conv_norm_out"](h)
        h = self.decoder["conv_act"](h)
        x_pred = self.decoder["conv_out"](h)

        return x_pred

    @classmethod
    def input_pil_to_tensor(cls, x):
        """Convert ``x`` with ``TF.to_tensor``, normalise with mean 0.5 / std 0.5, and batch it.

        A leading batch dimension is added when the converted tensor is 3-D.
        """
        x = TF.to_tensor(x)
        x = TF.normalize(x, [0.5], [0.5])
        if x.ndim == 3:
            x = x[None, :, :, :]
        return x

    @classmethod
    def output_tensor_to_pil(cls, x_pred):
        """Turn a batch of channels-first tensors into a list of PIL images.

        Values are mapped with ``x * 0.5 + 0.5``, clamped to [0, 1], scaled by
        255 and cast to ``uint8`` (the cast truncates rather than rounds).
        """
        pixels = ((x_pred * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8)
        # Channels-last numpy arrays, one per batch element, for Image.fromarray.
        pixels = pixels.permute(0, 2, 3, 1).cpu().numpy()

        return [Image.fromarray(image) for image in pixels]

    @classmethod
    def load_fp32(cls, device=None, overrides=None):
        """Load ``./weights/sdxl_vae.safetensors`` via ``cls.load``."""
        return cls.load("./weights/sdxl_vae.safetensors", device=device, overrides=overrides)

    @classmethod
    def load_fp16(cls, device=None, overrides=None):
        """Load ``./weights/sdxl_vae.fp16.safetensors`` via ``cls.load``."""
        return cls.load("./weights/sdxl_vae.fp16.safetensors", device=device, overrides=overrides)

    @classmethod
    def load_fp16_fix(cls, device=None, overrides=None):
        """Load ``./weights/sdxl_vae_fp16_fix.safetensors`` via ``cls.load``."""
        return cls.load("./weights/sdxl_vae_fp16_fix.safetensors", device=device, overrides=overrides)
class SDXLUNet(nn.Module, ModelUtils):
    """Conditional U-Net: down blocks, mid block and up blocks with skip connections.

    Conditioning enters through a timestep embedding, an additional embedding
    built from ``micro_conditioning`` and ``pooled_encoder_hidden_states``, and
    ``encoder_hidden_states`` passed to every ``TransformerDecoder2D``.
    """

    def __init__(self):
        """Create the embedding MLPs, the convolutional stem/head and all blocks."""
        super().__init__()

        encoder_hidden_states_dim = 2048

        # timestep embedding
        time_sinusoidal_embedding_dim = 320
        time_embedding_dim = 1280

        self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)

        self.time_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # additional (micro conditioning + pooled text) embedding
        num_micro_conditioning_values = 6
        micro_conditioning_embedding_dim = 256
        additional_embedding_encoder_dim = 1280
        self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)

        self.add_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # 4 -> 320
        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([
            # 320 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 320, time_embedding_dim),
                    ResnetBlock2D(320, 320, time_embedding_dim),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 320 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 640, time_embedding_dim),
                    ResnetBlock2D(640, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 640 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640, 1280, time_embedding_dim),
                    ResnetBlock2D(1280, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
            )),
        ])

        self.mid_block = nn.ModuleDict(dict(
            resnets=nn.ModuleList([
                ResnetBlock2D(1280, 1280, time_embedding_dim),
                ResnetBlock2D(1280, 1280, time_embedding_dim),
            ]),
            attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
        ))

        # Each up resnet's input width is (current channels + skip channels).
        self.up_blocks = nn.ModuleList([
            # 1280 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
                    ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim),
                    ResnetBlock2D(1280 + 640, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
                upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]),
            )),
            # 1280 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(1280 + 640, 640, time_embedding_dim),
                    ResnetBlock2D(640 + 640, 640, time_embedding_dim),
                    ResnetBlock2D(640 + 320, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]),
            )),
            # 640 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640 + 320, 320, time_embedding_dim),
                    ResnetBlock2D(320 + 320, 320, time_embedding_dim),
                    ResnetBlock2D(320 + 320, 320, time_embedding_dim),
                ]),
            ))
        ])

        # 320 -> 4
        self.conv_norm_out = nn.GroupNorm(32, 320)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1)

    def forward(
        self,
        x_t,
        t,
        encoder_hidden_states,
        micro_conditioning,
        pooled_encoder_hidden_states,
        down_block_additional_residuals: Optional[List[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        add_to_down_block_inputs: Optional[List[torch.Tensor]] = None,
        add_to_output: Optional[torch.Tensor] = None,
    ):
        """Run the U-Net.

        Args:
            x_t: Input to ``conv_in`` (4 channels).
            t: Timesteps; embedded with ``get_sinusoidal_embedding``.
            encoder_hidden_states: Passed to every attention module.
            micro_conditioning: Values embedded sinusoidally and flattened from
                dim 1. NOTE(review): assumes 6 values per sample so the
                concatenation matches ``add_embedding.linear_1`` -- confirm.
            pooled_encoder_hidden_states: Concatenated (last dim) in front of the
                flattened micro-conditioning embedding.
            down_block_additional_residuals: Optional tensors added to the skip
                connections; consumed from the end of the list, one per up resnet.
            mid_block_additional_residual: Optional tensor added after the mid block.
            add_to_down_block_inputs: Optional tensors added to the hidden state
                before each down resnet and each downsampler; consumed from the
                front of the list.
            add_to_output: Optional tensor added to the ``conv_out`` result.

        Returns:
            The output of ``conv_out`` (plus ``add_to_output`` when given).

        Raises:
            IndexError: If one of the optional lists has too few entries.
        """
        # Work on private copies: the loops below pop() from these lists, and the
        # caller's lists must stay intact so they can be reused across calls.
        if add_to_down_block_inputs is not None:
            add_to_down_block_inputs = list(add_to_down_block_inputs)
        if down_block_additional_residuals is not None:
            down_block_additional_residuals = list(down_block_additional_residuals)

        hidden_state = x_t

        t = self.get_sinusoidal_timestep_embedding(t)
        t = t.to(dtype=hidden_state.dtype)
        t = self.time_embedding["linear_1"](t)
        t = self.time_embedding["act"](t)
        t = self.time_embedding["linear_2"](t)

        additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
        additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
        additional_conditioning = additional_conditioning.flatten(1)
        additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
        additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
        additional_conditioning = self.add_embedding["act"](additional_conditioning)
        additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)

        t = t + additional_conditioning

        hidden_state = self.conv_in(hidden_state)

        # Skip connections, pushed in down order and popped in reverse by the up path.
        residuals = [hidden_state]

        for down_block in self.down_blocks:
            for i, resnet in enumerate(down_block["resnets"]):
                if add_to_down_block_inputs is not None:
                    hidden_state = hidden_state + add_to_down_block_inputs.pop(0)

                hidden_state = resnet(hidden_state, t)

                if "attentions" in down_block:
                    hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)

                residuals.append(hidden_state)

            if "downsamplers" in down_block:
                if add_to_down_block_inputs is not None:
                    hidden_state = hidden_state + add_to_down_block_inputs.pop(0)

                hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)

                residuals.append(hidden_state)

        hidden_state = self.mid_block["resnets"][0](hidden_state, t)
        hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
        hidden_state = self.mid_block["resnets"][1](hidden_state, t)

        if mid_block_additional_residual is not None:
            hidden_state = hidden_state + mid_block_additional_residual

        for up_block in self.up_blocks:
            for i, resnet in enumerate(up_block["resnets"]):
                residual = residuals.pop()

                if down_block_additional_residuals is not None:
                    residual = residual + down_block_additional_residuals.pop()

                hidden_state = torch.concat([hidden_state, residual], dim=1)

                hidden_state = resnet(hidden_state, t)

                if "attentions" in up_block:
                    hidden_state = up_block["attentions"][i](hidden_state, encoder_hidden_states)

            if "upsamplers" in up_block:
                hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
                hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)

        hidden_state = self.conv_norm_out(hidden_state)
        hidden_state = self.conv_act(hidden_state)
        hidden_state = self.conv_out(hidden_state)

        if add_to_output is not None:
            hidden_state = hidden_state + add_to_output

        eps_hat = hidden_state

        return eps_hat

    @classmethod
    def load_fp32(cls, device=None, overrides=None):
        """Load ``./weights/sdxl_unet.safetensors`` via ``cls.load``."""
        return cls.load("./weights/sdxl_unet.safetensors", device=device, overrides=overrides)

    @classmethod
    def load_fp16(cls, device=None, overrides=None):
        """Load ``./weights/sdxl_unet.fp16.safetensors`` via ``cls.load``."""
        return cls.load("./weights/sdxl_unet.fp16.safetensors", device=device, overrides=overrides)
class SDXLControlNet(nn.Module, ModelUtils):
    """ControlNet built from the U-Net's down and mid blocks plus a conditioning-image encoder.

    ``forward`` returns one tensor per ``controlnet_down_blocks`` entry and one
    mid-block tensor, each produced by a 1x1 convolution wrapped in ``zero_module``.
    """

    def __init__(self):
        """Create the embeddings, the conditioning encoder, the down/mid blocks and the 1x1 output convs."""
        super().__init__()

        encoder_hidden_states_dim = 2048

        # timestep embedding
        time_sinusoidal_embedding_dim = 320
        time_embedding_dim = 1280

        self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)

        self.time_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # additional (micro conditioning + pooled text) embedding
        num_micro_conditioning_values = 6
        micro_conditioning_embedding_dim = 256
        additional_embedding_encoder_dim = 1280
        self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)

        self.add_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # Conditioning-image encoder: 3 -> 320 channels through three stride-2 convs.
        self.controlnet_cond_embedding = nn.ModuleDict(dict(
            conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1),
            blocks=nn.ModuleList([
                # 16 -> 32
                nn.Conv2d(16, 16, kernel_size=3, padding=1),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
                # 32 -> 96
                nn.Conv2d(32, 32, kernel_size=3, padding=1),
                nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2),
                # 96 -> 256
                nn.Conv2d(96, 96, kernel_size=3, padding=1),
                nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2),
            ]),
            conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)),
        ))

        # 4 -> 320
        self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([
            # 320 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 320, time_embedding_dim),
                    ResnetBlock2D(320, 320, time_embedding_dim),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 320 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 640, time_embedding_dim),
                    ResnetBlock2D(640, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 640 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640, 1280, time_embedding_dim),
                    ResnetBlock2D(1280, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
            )),
        ])

        # One 1x1 conv per tensor emitted on the down path (stem + 8 block outputs).
        self.controlnet_down_blocks = nn.ModuleList([
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
        ])

        self.mid_block = nn.ModuleDict(dict(
            resnets=nn.ModuleList([
                ResnetBlock2D(1280, 1280, time_embedding_dim),
                ResnetBlock2D(1280, 1280, time_embedding_dim),
            ]),
            attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
        ))

        self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))

    def forward(
        self,
        x_t,
        t,
        encoder_hidden_states,
        micro_conditioning,
        pooled_encoder_hidden_states,
        controlnet_cond,
    ):
        """Compute the ControlNet outputs.

        Args:
            x_t: Input to ``conv_in`` (4 channels).
            t: Timesteps; embedded with ``get_sinusoidal_embedding``.
            encoder_hidden_states: Passed to every attention module.
            micro_conditioning: Values embedded sinusoidally and flattened from dim 1.
            pooled_encoder_hidden_states: Concatenated (last dim) in front of the
                flattened micro-conditioning embedding.
            controlnet_cond: 3-channel conditioning input; its encoding is added
                to the ``conv_in`` output.
                NOTE(review): presumably pixel-space, 8x the latent resolution -- confirm.

        Returns:
            ``dict(down_block_res_samples=<list of 9 tensors>, mid_block_res_sample=<tensor>)``.
        """
        hidden_state = x_t

        t = self.get_sinusoidal_timestep_embedding(t)
        t = t.to(dtype=hidden_state.dtype)
        t = self.time_embedding["linear_1"](t)
        t = self.time_embedding["act"](t)
        t = self.time_embedding["linear_2"](t)

        additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
        additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
        additional_conditioning = additional_conditioning.flatten(1)
        additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
        additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
        additional_conditioning = self.add_embedding["act"](additional_conditioning)
        additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)

        t = t + additional_conditioning

        controlnet_cond = self.controlnet_cond_embedding["conv_in"](controlnet_cond)
        controlnet_cond = F.silu(controlnet_cond)

        for block in self.controlnet_cond_embedding["blocks"]:
            controlnet_cond = F.silu(block(controlnet_cond))

        controlnet_cond = self.controlnet_cond_embedding["conv_out"](controlnet_cond)

        hidden_state = self.conv_in(hidden_state)

        hidden_state = hidden_state + controlnet_cond

        down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
        down_block_res_samples = [down_block_res_sample]

        for down_block in self.down_blocks:
            for i, resnet in enumerate(down_block["resnets"]):
                hidden_state = resnet(hidden_state, t)

                if "attentions" in down_block:
                    hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)

                # The number of samples emitted so far selects the next 1x1 conv.
                down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
                down_block_res_samples.append(down_block_res_sample)

            if "downsamplers" in down_block:
                hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)

                down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
                down_block_res_samples.append(down_block_res_sample)

        hidden_state = self.mid_block["resnets"][0](hidden_state, t)
        hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
        hidden_state = self.mid_block["resnets"][1](hidden_state, t)

        mid_block_res_sample = self.controlnet_mid_block(hidden_state)

        return dict(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )

    @classmethod
    def from_unet(cls, unet):
        """Create a ControlNet whose shared sub-modules are copied from ``unet``.

        Copies ``time_embedding``, ``add_embedding``, ``conv_in``, ``down_blocks``
        and ``mid_block``; the ControlNet-only modules keep their constructor
        initialisation.
        """
        controlnet = cls()

        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
        controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

        controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet
class SDXLControlNetPreEncodedControlnetCond(nn.Module, ModelUtils):
    """ControlNet variant whose conditioning is concatenated to ``x_t`` channel-wise.

    There is no conditioning-image encoder here: ``conv_in`` takes 9 input
    channels (``x_t`` concatenated with ``controlnet_cond`` on dim 1).
    """

    def __init__(self):
        """Create the embeddings, the 9-channel stem, the down/mid blocks and the 1x1 output convs."""
        super().__init__()

        encoder_hidden_states_dim = 2048

        # timestep embedding
        time_sinusoidal_embedding_dim = 320
        time_embedding_dim = 1280

        self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim)

        self.time_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # additional (micro conditioning + pooled text) embedding
        num_micro_conditioning_values = 6
        micro_conditioning_embedding_dim = 256
        additional_embedding_encoder_dim = 1280
        self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim)

        self.add_embedding = nn.ModuleDict(dict(
            linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim),
            act=nn.SiLU(),
            linear_2=nn.Linear(time_embedding_dim, time_embedding_dim),
        ))

        # 9 -> 320 (x_t channels followed by controlnet_cond channels)
        self.conv_in = nn.Conv2d(9, 320, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([
            # 320 -> 320
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 320, time_embedding_dim),
                    ResnetBlock2D(320, 320, time_embedding_dim),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 320 -> 640
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(320, 640, time_embedding_dim),
                    ResnetBlock2D(640, 640, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                    TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2),
                ]),
                downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]),
            )),
            # 640 -> 1280
            nn.ModuleDict(dict(
                resnets=nn.ModuleList([
                    ResnetBlock2D(640, 1280, time_embedding_dim),
                    ResnetBlock2D(1280, 1280, time_embedding_dim),
                ]),
                attentions=nn.ModuleList([
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                    TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10),
                ]),
            )),
        ])

        # One 1x1 conv per tensor emitted on the down path (stem + 8 block outputs).
        self.controlnet_down_blocks = nn.ModuleList([
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(320, 320, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(640, 640, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
            zero_module(nn.Conv2d(1280, 1280, kernel_size=1)),
        ])

        self.mid_block = nn.ModuleDict(dict(
            resnets=nn.ModuleList([
                ResnetBlock2D(1280, 1280, time_embedding_dim),
                ResnetBlock2D(1280, 1280, time_embedding_dim),
            ]),
            attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]),
        ))

        self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1))

    def forward(
        self,
        x_t,
        t,
        encoder_hidden_states,
        micro_conditioning,
        pooled_encoder_hidden_states,
        controlnet_cond,
    ):
        """Compute the ControlNet outputs from ``x_t`` concatenated with ``controlnet_cond``.

        Args:
            x_t: First part of the ``conv_in`` input.
            t: Timesteps; embedded with ``get_sinusoidal_embedding``.
            encoder_hidden_states: Passed to every attention module.
            micro_conditioning: Values embedded sinusoidally and flattened from dim 1.
            pooled_encoder_hidden_states: Concatenated (last dim) in front of the
                flattened micro-conditioning embedding.
            controlnet_cond: Concatenated to ``x_t`` on dim 1; together they must
                provide the 9 channels ``conv_in`` expects.
                NOTE(review): presumably 4 latent + 5 conditioning channels -- confirm.

        Returns:
            ``dict(down_block_res_samples=<list of 9 tensors>, mid_block_res_sample=<tensor>)``.
        """
        hidden_state = x_t

        t = self.get_sinusoidal_timestep_embedding(t)
        t = t.to(dtype=hidden_state.dtype)
        t = self.time_embedding["linear_1"](t)
        t = self.time_embedding["act"](t)
        t = self.time_embedding["linear_2"](t)

        additional_conditioning = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
        additional_conditioning = additional_conditioning.to(dtype=hidden_state.dtype)
        additional_conditioning = additional_conditioning.flatten(1)
        additional_conditioning = torch.concat([pooled_encoder_hidden_states, additional_conditioning], dim=-1)
        additional_conditioning = self.add_embedding["linear_1"](additional_conditioning)
        additional_conditioning = self.add_embedding["act"](additional_conditioning)
        additional_conditioning = self.add_embedding["linear_2"](additional_conditioning)

        t = t + additional_conditioning

        hidden_state = torch.concat((hidden_state, controlnet_cond), dim=1)

        hidden_state = self.conv_in(hidden_state)

        down_block_res_sample = self.controlnet_down_blocks[0](hidden_state)
        down_block_res_samples = [down_block_res_sample]

        for down_block in self.down_blocks:
            for i, resnet in enumerate(down_block["resnets"]):
                hidden_state = resnet(hidden_state, t)

                if "attentions" in down_block:
                    hidden_state = down_block["attentions"][i](hidden_state, encoder_hidden_states)

                # The number of samples emitted so far selects the next 1x1 conv.
                down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
                down_block_res_samples.append(down_block_res_sample)

            if "downsamplers" in down_block:
                hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)

                down_block_res_sample = self.controlnet_down_blocks[len(down_block_res_samples)](hidden_state)
                down_block_res_samples.append(down_block_res_sample)

        hidden_state = self.mid_block["resnets"][0](hidden_state, t)
        hidden_state = self.mid_block["attentions"][0](hidden_state, encoder_hidden_states)
        hidden_state = self.mid_block["resnets"][1](hidden_state, t)

        mid_block_res_sample = self.controlnet_mid_block(hidden_state)

        return dict(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )

    @classmethod
    def from_unet(cls, unet):
        """Create a ControlNet whose shared sub-modules are copied from ``unet``.

        ``unet.conv_in`` has fewer input channels than this model's ``conv_in``,
        so its weight is extended with zeros along the input-channel axis; the
        extra (conditioning) channels therefore contribute nothing until trained.
        """
        controlnet = cls()

        controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
        controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())

        unet_conv_in = unet.conv_in.state_dict()
        conv_in_weight = unet_conv_in["weight"]
        conv_in_bias = unet_conv_in["bias"]

        # Derive the padding shape from the actual weights instead of
        # hard-coding it; for the modules in this file it is (320, 5, 3, 3).
        out_channels, in_channels, kernel_h, kernel_w = conv_in_weight.shape
        extra_channels = controlnet.conv_in.in_channels - in_channels
        padding = torch.zeros(
            (out_channels, extra_channels, kernel_h, kernel_w),
            device=conv_in_weight.device,
            dtype=conv_in_weight.dtype,
        )
        conv_in_weight = torch.concat((conv_in_weight, padding), dim=1)

        controlnet.conv_in.load_state_dict({"weight": conv_in_weight, "bias": conv_in_bias})

        controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
        controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())

        return controlnet
class SDXLControlNetFull(nn.Module, ModelUtils): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
|
|
|
|
encoder_hidden_states_dim = 2048 |
|
|
|
|
|
|
|
time_sinusoidal_embedding_dim = 320 |
|
time_embedding_dim = 1280 |
|
|
|
self.get_sinusoidal_timestep_embedding = lambda timesteps: get_sinusoidal_embedding(timesteps, time_sinusoidal_embedding_dim) |
|
|
|
self.time_embedding = nn.ModuleDict(dict( |
|
linear_1=nn.Linear(time_sinusoidal_embedding_dim, time_embedding_dim), |
|
act=nn.SiLU(), |
|
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim), |
|
)) |
|
|
|
|
|
|
|
num_micro_conditioning_values = 6 |
|
micro_conditioning_embedding_dim = 256 |
|
additional_embedding_encoder_dim = 1280 |
|
self.get_sinusoidal_micro_conditioning_embedding = lambda micro_conditioning: get_sinusoidal_embedding(micro_conditioning, micro_conditioning_embedding_dim) |
|
|
|
self.add_embedding = nn.ModuleDict(dict( |
|
linear_1=nn.Linear(additional_embedding_encoder_dim + num_micro_conditioning_values * micro_conditioning_embedding_dim, time_embedding_dim), |
|
act=nn.SiLU(), |
|
linear_2=nn.Linear(time_embedding_dim, time_embedding_dim), |
|
)) |
|
|
|
|
|
self.controlnet_cond_embedding = nn.ModuleDict(dict( |
|
conv_in=nn.Conv2d(3, 16, kernel_size=3, padding=1), |
|
blocks=nn.ModuleList([ |
|
|
|
nn.Conv2d(16, 16, kernel_size=3, padding=1), |
|
nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2), |
|
|
|
nn.Conv2d(32, 32, kernel_size=3, padding=1), |
|
nn.Conv2d(32, 96, kernel_size=3, padding=1, stride=2), |
|
|
|
nn.Conv2d(96, 96, kernel_size=3, padding=1), |
|
nn.Conv2d(96, 256, kernel_size=3, padding=1, stride=2), |
|
]), |
|
conv_out=zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1)), |
|
)) |
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1) |
|
|
|
self.down_blocks = nn.ModuleList([ |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(320, 320, time_embedding_dim), |
|
ResnetBlock2D(320, 320, time_embedding_dim), |
|
]), |
|
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)))]), |
|
)), |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(320, 640, time_embedding_dim), |
|
ResnetBlock2D(640, 640, time_embedding_dim), |
|
]), |
|
attentions=nn.ModuleList([ |
|
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2), |
|
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2), |
|
]), |
|
downsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)))]), |
|
)), |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(640, 1280, time_embedding_dim), |
|
ResnetBlock2D(1280, 1280, time_embedding_dim), |
|
]), |
|
attentions=nn.ModuleList([ |
|
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10), |
|
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10), |
|
]), |
|
)), |
|
]) |
|
|
|
self.controlnet_down_blocks = nn.ModuleList([ |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)), |
|
]) |
|
|
|
self.mid_block = nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(1280, 1280, time_embedding_dim), |
|
ResnetBlock2D(1280, 1280, time_embedding_dim), |
|
]), |
|
attentions=nn.ModuleList([TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10)]), |
|
)) |
|
|
|
self.controlnet_mid_block = zero_module(nn.Conv2d(1280, 1280, kernel_size=1)) |
|
|
|
self.up_blocks = nn.ModuleList([ |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim), |
|
ResnetBlock2D(1280 + 1280, 1280, time_embedding_dim), |
|
ResnetBlock2D(1280 + 640, 1280, time_embedding_dim), |
|
]), |
|
attentions=nn.ModuleList([ |
|
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10), |
|
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10), |
|
TransformerDecoder2D(1280, encoder_hidden_states_dim, num_transformer_blocks=10), |
|
]), |
|
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(1280, 1280, kernel_size=3, padding=1)))]), |
|
)), |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(1280 + 640, 640, time_embedding_dim), |
|
ResnetBlock2D(640 + 640, 640, time_embedding_dim), |
|
ResnetBlock2D(640 + 320, 640, time_embedding_dim), |
|
]), |
|
attentions=nn.ModuleList([ |
|
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2), |
|
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2), |
|
TransformerDecoder2D(640, encoder_hidden_states_dim, num_transformer_blocks=2), |
|
]), |
|
upsamplers=nn.ModuleList([nn.ModuleDict(dict(conv=nn.Conv2d(640, 640, kernel_size=3, padding=1)))]), |
|
)), |
|
|
|
nn.ModuleDict(dict( |
|
resnets=nn.ModuleList([ |
|
ResnetBlock2D(640 + 320, 320, time_embedding_dim), |
|
ResnetBlock2D(320 + 320, 320, time_embedding_dim), |
|
ResnetBlock2D(320 + 320, 320, time_embedding_dim), |
|
]), |
|
)) |
|
]) |
|
|
|
|
|
|
|
self.controlnet_up_blocks = nn.ModuleList([ |
|
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)), |
|
zero_module(nn.Conv2d(1280, 1280, kernel_size=1)), |
|
zero_module(nn.Conv2d(1280, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 640, kernel_size=1)), |
|
zero_module(nn.Conv2d(640, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
zero_module(nn.Conv2d(320, 320, kernel_size=1)), |
|
]) |
|
|
|
self.conv_norm_out = nn.GroupNorm(32, 320) |
|
self.conv_act = nn.SiLU() |
|
self.conv_out = nn.Conv2d(320, 4, kernel_size=3, padding=1) |
|
|
|
self.controlnet_conv_out = zero_module(nn.Conv2d(4, 4, kernel_size=1)) |
|
|
|
|
|
|
|
def forward(
    self,
    x_t,
    t,
    encoder_hidden_states,
    micro_conditioning,
    pooled_encoder_hidden_states,
    controlnet_cond,
):
    """Run the control network and collect every tensor it exposes.

    ``controlnet_cond`` is embedded by ``self.controlnet_cond_embedding`` and
    added to ``self.conv_in(x_t)``. The sum then flows through the down, mid
    and up blocks while the ``controlnet_*`` projection layers tap the
    intermediate activations.

    Args:
        x_t: input fed to ``self.conv_in``; its dtype is used for the
            conditioning embeddings (presumably the noisy latent - TODO confirm).
        t: timestep indices, embedded with
            ``self.get_sinusoidal_timestep_embedding``.
        encoder_hidden_states: passed as second argument to every attention
            module.
        micro_conditioning: embedded with
            ``self.get_sinusoidal_micro_conditioning_embedding`` and flattened
            from dim 1 onwards.
        pooled_encoder_hidden_states: concatenated (last dim) in front of the
            flattened micro-conditioning embedding.
        controlnet_cond: conditioning input for the control embedding stack.

    Returns:
        dict with keys ``down_block_res_samples`` (list, newest tap first),
        ``mid_block_res_sample``, ``add_to_down_block_inputs`` (list) and
        ``add_to_output``.
    """
    compute_dtype = x_t.dtype

    # Timestep embedding followed by its two-layer MLP.
    time_mlp = self.time_embedding
    t = self.get_sinusoidal_timestep_embedding(t).to(dtype=compute_dtype)
    t = time_mlp["linear_2"](time_mlp["act"](time_mlp["linear_1"](t)))

    # Micro-conditioning embedding joined with the pooled states, then its MLP.
    add_mlp = self.add_embedding
    micro = self.get_sinusoidal_micro_conditioning_embedding(micro_conditioning)
    micro = micro.to(dtype=compute_dtype).flatten(1)
    extra = torch.concat([pooled_encoder_hidden_states, micro], dim=-1)
    extra = add_mlp["linear_2"](add_mlp["act"](add_mlp["linear_1"](extra)))

    t = t + extra

    # Embed the conditioning input: SiLU after conv_in and after each block,
    # but not after conv_out.
    cond_embedder = self.controlnet_cond_embedding
    cond = F.silu(cond_embedder["conv_in"](controlnet_cond))
    for cond_block in cond_embedder["blocks"]:
        cond = F.silu(cond_block(cond))
    cond = cond_embedder["conv_out"](cond)

    hidden_state = self.conv_in(x_t) + cond

    down_taps = self.controlnet_down_blocks
    residuals = [hidden_state]
    add_to_down_block_inputs = [down_taps[0](hidden_state)]

    def record(state):
        # Tap the activation while unused down projections remain, then keep
        # the activation as a skip connection for the up path.
        slot = len(add_to_down_block_inputs)
        if slot < len(down_taps):
            add_to_down_block_inputs.append(down_taps[slot](state))
        residuals.append(state)

    for down_block in self.down_blocks:
        has_attention = "attentions" in down_block
        for idx, resnet in enumerate(down_block["resnets"]):
            hidden_state = resnet(hidden_state, t)
            if has_attention:
                hidden_state = down_block["attentions"][idx](hidden_state, encoder_hidden_states)
            record(hidden_state)

        if "downsamplers" in down_block:
            hidden_state = down_block["downsamplers"][0]["conv"](hidden_state)
            record(hidden_state)

    mid = self.mid_block
    hidden_state = mid["resnets"][0](hidden_state, t)
    hidden_state = mid["attentions"][0](hidden_state, encoder_hidden_states)
    hidden_state = mid["resnets"][1](hidden_state, t)

    mid_block_res_sample = self.controlnet_mid_block(hidden_state)

    up_taps = self.controlnet_up_blocks
    tapped = []  # in creation order; reversed once the up path is finished

    for up_block in self.up_blocks:
        has_attention = "attentions" in up_block
        for idx, resnet in enumerate(up_block["resnets"]):
            skip = residuals.pop()
            hidden_state = resnet(torch.concat([hidden_state, skip], dim=1), t)
            if has_attention:
                hidden_state = up_block["attentions"][idx](hidden_state, encoder_hidden_states)
            tapped.append(up_taps[len(tapped)](hidden_state))

        if "upsamplers" in up_block:
            hidden_state = F.interpolate(hidden_state, scale_factor=2.0, mode="nearest")
            hidden_state = up_block["upsamplers"][0]["conv"](hidden_state)

    down_block_res_samples = tapped[::-1]

    hidden_state = self.conv_out(self.conv_act(self.conv_norm_out(hidden_state)))

    return {
        "down_block_res_samples": down_block_res_samples,
        "mid_block_res_sample": mid_block_res_sample,
        "add_to_down_block_inputs": add_to_down_block_inputs,
        "add_to_output": self.controlnet_conv_out(hidden_state),
    }
|
|
|
@classmethod
def from_unet(cls, unet):
    """Build a new instance and copy the weights it shares with ``unet``.

    Every sub-module listed below is loaded from the state dict of the
    same-named attribute on ``unet``; all other sub-modules keep the values
    given to them by ``cls()``.

    Args:
        unet: object exposing the same-named sub-modules (presumably the
            SDXL UNet defined in this file - TODO confirm).

    Returns:
        The newly constructed instance.
    """
    shared_module_names = (
        "time_embedding",
        "add_embedding",
        "conv_in",
        "down_blocks",
        "mid_block",
        "up_blocks",
        "conv_norm_out",
        "conv_out",
    )

    controlnet = cls()

    for module_name in shared_module_names:
        target = getattr(controlnet, module_name)
        source = getattr(unet, module_name)
        target.load_state_dict(source.state_dict())

    return controlnet
|
|
|
|
|
class SDXLAdapter(nn.Module, ModelUtils):
    """Convolutional adapter that turns an image into four feature maps.

    All layers live under the ``adapter`` ``ModuleDict`` so parameter names
    keep the ``adapter.<...>`` prefix (e.g. ``adapter.conv_in.weight``,
    ``adapter.body.0.resnets.0.block1.weight``).
    """

    def __init__(self):
        super().__init__()

        self.adapter = nn.ModuleDict(dict(
            # Folds each 16x16 spatial patch into channels. conv_in expects
            # 768 = 3 * 16 * 16 channels, i.e. a 3-channel input image.
            unshuffle=nn.PixelUnshuffle(16),

            conv_in=nn.Conv2d(768, 320, kernel_size=3, padding=1),

            body=nn.ModuleList([
                # 320 -> 320
                nn.ModuleDict(dict(
                    resnets=self._make_resnets(320),
                )),

                # 320 -> 640
                nn.ModuleDict(dict(
                    in_conv=nn.Conv2d(320, 640, kernel_size=1),
                    resnets=self._make_resnets(640),
                )),

                # 640 -> 1280, spatial size halved
                nn.ModuleDict(dict(
                    downsample=nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                    in_conv=nn.Conv2d(640, 1280, kernel_size=1),
                    resnets=self._make_resnets(1280),
                )),

                # 1280 -> 1280
                nn.ModuleDict(dict(
                    resnets=self._make_resnets(1280),
                )),
            ]),
        ))

    @staticmethod
    def _make_resnet(channels):
        """Return one residual unit: 3x3 conv -> ReLU -> 1x1 conv."""
        return nn.ModuleDict(dict(
            block1=nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            act=nn.ReLU(),
            block2=nn.Conv2d(channels, channels, kernel_size=1),
        ))

    @classmethod
    def _make_resnets(cls, channels):
        """Return the pair of residual units used by every body block.

        ``nn.ModuleList`` takes a single iterable of modules, so the units
        are wrapped in a list rather than passed as separate arguments.
        """
        return nn.ModuleList([cls._make_resnet(channels), cls._make_resnet(channels)])

    def forward(self, x):
        """Compute the adapter features for ``x``.

        Args:
            x: image batch; must have 3 channels for ``conv_in`` to accept the
                unshuffled tensor.

        Returns:
            list with one tensor per body block (four in total), in block
            order.
        """
        # The layers are registered inside ``self.adapter``, not directly on
        # the module, so they have to be looked up there.
        adapter = self.adapter

        x = adapter["unshuffle"](x)
        x = adapter["conv_in"](x)

        features = []

        for block in adapter["body"]:
            if "downsample" in block:
                x = block["downsample"](x)

            if "in_conv" in block:
                x = block["in_conv"](x)

            for resnet in block["resnets"]:
                residual = x
                x = resnet["block1"](x)
                x = resnet["act"](x)
                x = resnet["block2"](x)
                x = residual + x

            features.append(x)

        return features
|
|
|
|
|
def get_sinusoidal_embedding(
    indices: torch.Tensor,
    embedding_dim: int,
) -> torch.Tensor:
    """Return sinusoidal embeddings of ``indices``.

    A trailing axis of size ``embedding_dim`` is appended to ``indices``. Its
    first ``embedding_dim // 2`` entries are cosines and the next
    ``embedding_dim // 2`` entries are sines of ``index * 10000 ** (-i / half)``
    for ``i`` in ``range(half)``.

    Args:
        indices: tensor of positions/timesteps of any shape; converted to
            float32 before the frequencies are applied.
        embedding_dim: size of the appended axis.

    Returns:
        float32 tensor of shape ``indices.shape + (embedding_dim,)`` on the
        device of ``indices``. When ``embedding_dim`` is odd the final channel
        is zero padding, so the result always has the requested width.
    """
    half_dim = embedding_dim // 2

    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=indices.device)
    exponent = exponent / half_dim

    emb = torch.exp(exponent)
    emb = indices.unsqueeze(-1).float() * emb
    emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)

    # cos/sin halves only cover 2 * half_dim channels; pad the last axis with
    # one zero so odd sizes are not silently returned one channel short.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))

    return emb
|
|
|
|
|
class ResnetBlock2D(nn.Module):
    """Residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, plus a skip path.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
        time_embedding_dim: when given, a linear projection of the time
            embedding is added after the first convolution.
        eps: epsilon used by both group norms.
    """

    def __init__(self, in_channels, out_channels, time_embedding_dim=None, eps=1e-5):
        super().__init__()

        uses_time_embedding = time_embedding_dim is not None
        self.time_emb_proj = nn.Linear(time_embedding_dim, out_channels) if uses_time_embedding else None

        self.norm1 = nn.GroupNorm(32, in_channels, eps=eps)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.norm2 = nn.GroupNorm(32, out_channels, eps=eps)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.nonlinearity = nn.SiLU()

        # A 1x1 conv aligns the skip path only when the channel count changes.
        changes_width = in_channels != out_channels
        self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) if changes_width else None

    def forward(self, hidden_states, temb=None):
        """Apply the block; ``temb`` is required when a time projection exists."""
        skip = hidden_states

        out = self.conv1(self.nonlinearity(self.norm1(hidden_states)))

        if self.time_emb_proj is not None:
            assert temb is not None
            projected_time = self.time_emb_proj(self.nonlinearity(temb))
            # Broadcast the per-channel bias over the two spatial axes.
            out = out + projected_time[:, :, None, None]

        out = self.conv2(self.dropout(self.nonlinearity(self.norm2(out))))

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return out + skip
|
|
|
|
|
class TransformerDecoder2D(nn.Module):
    """Applies a stack of transformer blocks to a 2D feature map.

    The map is group-normalised, flattened to a token sequence, projected,
    processed by ``num_transformer_blocks`` blocks, projected back, reshaped
    to its original layout and added to the input.
    """

    def __init__(self, channels, encoder_hidden_states_dim, num_transformer_blocks):
        super().__init__()

        self.norm = nn.GroupNorm(32, channels, eps=1e-06)
        self.proj_in = nn.Linear(channels, channels)

        stack = [TransformerDecoderBlock(channels, encoder_hidden_states_dim) for _ in range(num_transformer_blocks)]
        self.transformer_blocks = nn.ModuleList(stack)

        self.proj_out = nn.Linear(channels, channels)

    def forward(self, hidden_states, encoder_hidden_states):
        """Return ``hidden_states`` plus the transformer stack's output."""
        batch_size, channels, height, width = hidden_states.shape
        num_tokens = height * width

        # (B, C, H, W) -> (B, H*W, C)
        tokens = self.norm(hidden_states)
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch_size, num_tokens, channels)
        tokens = self.proj_in(tokens)

        for transformer_block in self.transformer_blocks:
            tokens = transformer_block(tokens, encoder_hidden_states)

        tokens = self.proj_out(tokens)

        # (B, H*W, C) -> (B, C, H, W)
        feature_map = tokens.reshape(batch_size, height, width, channels)
        feature_map = feature_map.permute(0, 3, 1, 2).contiguous()

        return feature_map + hidden_states
|
|
|
|
|
class TransformerDecoderBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention, then a GEGLU MLP.

    Each of the three sub-layers is applied to a layer-normalised copy of the
    hidden states and its output is added back to them.
    """

    def __init__(self, channels, encoder_hidden_states_dim):
        super().__init__()

        self.norm1 = nn.LayerNorm(channels)
        self.attn1 = Attention(channels, channels)

        self.norm2 = nn.LayerNorm(channels)
        self.attn2 = Attention(channels, encoder_hidden_states_dim)

        self.norm3 = nn.LayerNorm(channels)
        inner_dim = 4 * channels
        feed_forward = nn.Sequential(
            GEGLU(channels, inner_dim),
            nn.Dropout(0.0),
            nn.Linear(inner_dim, channels),
        )
        # Kept under the "net" key so parameter names stay ``ff.net.<i>...``.
        self.ff = nn.ModuleDict({"net": feed_forward})

    def forward(self, hidden_states, encoder_hidden_states):
        """Return the hidden states after all three residual sub-layers."""
        self_attended = self.attn1(self.norm1(hidden_states))
        hidden_states = self_attended + hidden_states

        cross_attended = self.attn2(self.norm2(hidden_states), encoder_hidden_states)
        hidden_states = cross_attended + hidden_states

        fed_forward = self.ff["net"](self.norm3(hidden_states))
        hidden_states = fed_forward + hidden_states

        return hidden_states
|
|
|
|
|
class AttentionMixin:
    """Shared multi-head attention routine with a switchable backend.

    ``attention_implementation`` selects the kernel. It is always read from
    ``AttentionMixin`` itself, so assigning it on this class changes the
    backend for every user of the mixin.
    """

    attention_implementation: Literal["xformers", "torch_2.0_scaled_dot_product"] = "xformers"

    @classmethod
    def attention(cls, to_q, to_k, to_v, to_out, head_dim, hidden_states, encoder_hidden_states=None):
        """Project, attend and project back.

        Args:
            to_q, to_k, to_v: callables producing query/key/value projections.
            to_out: callable applied to the merged attention output.
            head_dim: size of one attention head; the number of heads is
                ``channels // head_dim``.
            hidden_states: 3-D tensor ``(batch, q_seq_len, channels)`` used
                for the queries.
            encoder_hidden_states: optional source for keys/values; when
                ``None`` the keys/values come from ``hidden_states``.

        Returns:
            ``to_out`` applied to the attention result of shape
            ``(batch, q_seq_len, channels)``.

        Raises:
            ValueError: if ``AttentionMixin.attention_implementation`` is not
                one of the supported backends.
        """
        implementation = AttentionMixin.attention_implementation

        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would let un-attended projections fall through
        # to ``to_out`` silently.
        if implementation not in ("xformers", "torch_2.0_scaled_dot_product"):
            raise ValueError(f"unknown attention implementation: {implementation!r}")

        batch_size, q_seq_len, channels = hidden_states.shape

        kv = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        kv_seq_len = kv.shape[1]
        num_heads = channels // head_dim

        query = to_q(hidden_states)
        key = to_k(kv)
        value = to_v(kv)

        if implementation == "xformers":
            # xformers layout: (batch, seq, heads, head_dim)
            query = query.reshape(batch_size, q_seq_len, num_heads, head_dim).contiguous()
            key = key.reshape(batch_size, kv_seq_len, num_heads, head_dim).contiguous()
            value = value.reshape(batch_size, kv_seq_len, num_heads, head_dim).contiguous()

            hidden_states = xformers.ops.memory_efficient_attention(query, key, value)

            hidden_states = hidden_states.to(query.dtype)
            hidden_states = hidden_states.reshape(batch_size, q_seq_len, channels).contiguous()
        else:
            # torch SDPA layout: (batch, heads, seq, head_dim)
            query = query.reshape(batch_size, q_seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
            key = key.reshape(batch_size, kv_seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
            value = value.reshape(batch_size, kv_seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

            hidden_states = F.scaled_dot_product_attention(query, key, value)

            hidden_states = hidden_states.to(query.dtype)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, q_seq_len, channels).contiguous()

        hidden_states = to_out(hidden_states)

        return hidden_states
|
|
|
|
|
class Attention(nn.Module, AttentionMixin):
    """Attention layer with bias-free q/k/v projections and 64-wide heads.

    Args:
        channels: width of the queries and of the output.
        encoder_hidden_states_dim: input width of the key/value projections.
    """

    def __init__(self, channels, encoder_hidden_states_dim):
        super().__init__()

        kv_in_features = encoder_hidden_states_dim

        self.to_q = nn.Linear(channels, channels, bias=False)
        self.to_k = nn.Linear(kv_in_features, channels, bias=False)
        self.to_v = nn.Linear(kv_in_features, channels, bias=False)

        output_projection = nn.Linear(channels, channels)
        self.to_out = nn.Sequential(output_projection, nn.Dropout(0.0))

    def forward(self, hidden_states, encoder_hidden_states=None):
        """Attend over ``encoder_hidden_states``, or over ``hidden_states`` if omitted."""
        head_dim = 64
        return self.attention(
            self.to_q,
            self.to_k,
            self.to_v,
            self.to_out,
            head_dim,
            hidden_states,
            encoder_hidden_states,
        )
|
|
|
|
|
class VaeMidBlockAttention(nn.Module, AttentionMixin):
    """Single-head self-attention over the spatial positions of a feature map.

    ``head_dim`` equals ``channels``, so ``channels // head_dim`` yields one
    head. The input is group-normalised before attention and added back to
    the result afterwards.
    """

    def __init__(self, channels):
        super().__init__()

        self.group_norm = nn.GroupNorm(32, channels, eps=1e-06)

        self.to_q = nn.Linear(channels, channels)
        self.to_k = nn.Linear(channels, channels)
        self.to_v = nn.Linear(channels, channels)

        output_projection = nn.Linear(channels, channels)
        self.to_out = nn.Sequential(output_projection, nn.Dropout(0.0))

        self.head_dim = channels

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus its spatial self-attention output."""
        batch_size, channels, height, width = hidden_states.shape

        # Normalise in channel-first layout (B, C, H*W), then move channels
        # last so every spatial position becomes one token: (B, H*W, C).
        flat = hidden_states.view(batch_size, channels, height * width)
        tokens = self.group_norm(flat).transpose(1, 2)

        attended = self.attention(self.to_q, self.to_k, self.to_v, self.to_out, self.head_dim, tokens)

        # Back to (B, C, H, W).
        attended = attended.transpose(1, 2).view(batch_size, channels, height, width)

        return attended + hidden_states
|
|
|
|
|
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    One linear layer produces ``2 * dim_out`` features; the first half is the
    value and the second half, passed through GELU, is the gate.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()

        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, hidden_states):
        """Return ``value * gelu(gate)`` with ``dim_out`` features on the last axis."""
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        gated = value * F.gelu(gate)
        return gated
|
|
|
|
|
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``."""
    parameters = module.parameters()

    for parameter in parameters:
        nn.init.zeros_(parameter)

    return module
|
|