# ctrl-x/ctrl_x/utils/sdxl.py
# Hugging Face Hub page residue: uploaded by multimodalart ("Upload folder using huggingface_hub"),
# commit 3aff77a (verified), file size 12.2 kB.
from types import MethodType
from typing import Optional
from diffusers.models.attention_processor import Attention
import torch
import torch.nn.functional as F
from .feature import *
from .utils import *
def convolution_forward(  # From <class 'diffusers.models.resnet.ResnetBlock2D'>, forward (diffusers==0.28.0)
    self,
    input_tensor: torch.Tensor,
    temb: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Residual-block forward pass with optional structure feature injection.

    Intended to be bound onto a resnet block as its ``forward``. In addition to
    the block's own layers it reads ``self.do_control``, ``self.t``,
    ``self.structure_schedule``, ``self.structure_target`` and
    ``self.batch_order``, which are expected to be set on the module beforehand.

    Args:
        input_tensor: Block input; also feeds the residual (shortcut) path.
        temb: Time embedding. May be ``None`` only when the block has no
            ``time_emb_proj`` and does not use ``"scale_shift"`` normalisation.
        *args, **kwargs: Accepted for call compatibility and ignored.

    Returns:
        ``(shortcut + features) / self.output_scale_factor``, optionally passed
        through ``feature_injection``.

    Raises:
        ValueError: If ``time_embedding_norm`` is ``"scale_shift"`` and ``temb``
            is ``None``.
    """
    # Injection happens only when control is enabled and this timestep is scheduled.
    inject_structure = self.do_control and self.t in self.structure_schedule

    shortcut = input_tensor
    features = self.nonlinearity(self.norm1(input_tensor))

    # Both paths go through the same resampler (if any) so their shapes stay aligned.
    resample = None
    if self.upsample is not None:
        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if features.shape[0] >= 64:
            shortcut = shortcut.contiguous()
            features = features.contiguous()
        resample = self.upsample
    elif self.downsample is not None:
        resample = self.downsample
    if resample is not None:
        shortcut = resample(shortcut)
        features = resample(features)

    features = self.conv1(features)

    if self.time_emb_proj is not None:
        activated = temb if self.skip_time_act else self.nonlinearity(temb)
        # Trailing singleton axes let the embedding broadcast over the two spatial dims.
        temb = self.time_emb_proj(activated)[:, :, None, None]

    norm_mode = self.time_embedding_norm
    if norm_mode == "scale_shift":
        if temb is None:
            raise ValueError(
                f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
            )
        time_scale, time_shift = torch.chunk(temb, 2, dim=1)
        features = self.norm2(features) * (1 + time_scale) + time_shift
    else:
        # "default" adds the embedding before normalising; any other mode only normalises.
        if norm_mode == "default" and temb is not None:
            features = features + temb
        features = self.norm2(features)

    features = self.conv2(self.dropout(self.nonlinearity(features)))

    # Feature injection and AdaIN (hidden_states)
    if inject_structure and "hidden_states" in self.structure_target:
        features = feature_injection(features, batch_order=self.batch_order)

    if self.conv_shortcut is not None:
        shortcut = self.conv_shortcut(shortcut)

    output_tensor = (shortcut + features) / self.output_scale_factor

    # Feature injection and AdaIN (output_tensor)
    if inject_structure and "output_tensor" in self.structure_target:
        output_tensor = feature_injection(output_tensor, batch_order=self.batch_order)

    return output_tensor
class AttnProcessor2_0:  # From <class 'diffusers.models.attention_processor.AttnProcessor2_0'> (diffusers==0.28.0)
    """Scaled-dot-product attention processor with structure/appearance control hooks.

    Follows the upstream diffusers processor named in the header comment and adds
    calls to ``feature_injection`` and ``appearance_transfer``. On every call it
    reads these attributes from the ``attn`` module, which must be set beforehand:
    ``do_control``, ``t``, ``batch_order``, ``structure_schedule``,
    ``structure_target``, ``appearance_schedule`` and ``appearance_target``.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _split_heads(tensor, batch_size, heads, head_dim):
        """View ``(batch, seq, heads * head_dim)`` as ``(batch, heads, seq, head_dim)``."""
        return tensor.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    @staticmethod
    def _merge_heads(tensor, batch_size, heads, head_dim):
        """Reshape ``(batch, heads, seq, head_dim)`` back to ``(batch, seq, heads * head_dim)``."""
        return tensor.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """Run attention for ``attn``, applying control when scheduled for ``attn.t``.

        Args:
            attn: Attention module providing the projections, norms and the
                control attributes listed in the class docstring.
            hidden_states: Query-side input, 3-D or 4-D; 4-D input is flattened
                to a token sequence and restored to its original shape on return.
            encoder_hidden_states: Key/value-side input; ``None`` means
                self-attention (``hidden_states`` is used instead).
            attention_mask: Optional mask, prepared via ``attn.prepare_attention_mask``.
            temb: Passed to ``attn.spatial_norm`` when that norm exists.
            *args, **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The attended tensor, with the residual added when
            ``attn.residual_connection`` is set, divided by
            ``attn.rescale_output_factor``.
        """
        # Control is active only when enabled and the current timestep is scheduled.
        do_structure_control = attn.do_control and attn.t in attn.structure_schedule
        do_appearance_control = attn.do_control and attn.t in attn.appearance_schedule

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial dims into one token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        no_encoder_hidden_states = encoder_hidden_states is None
        if no_encoder_hidden_states:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if do_appearance_control:  # Assume we only have this for self attention
            # `normalize` is a project helper; the original note here read "B H D C"
            # -- TODO confirm which axis dim=-2 is meant to normalise over.
            hidden_states_normed = normalize(hidden_states, dim=-2)
            encoder_hidden_states_normed = normalize(encoder_hidden_states, dim=-2)

            query_normed = attn.to_q(hidden_states_normed)
            key_normed = attn.to_k(encoder_hidden_states_normed)

            inner_dim = key_normed.shape[-1]
            head_dim = inner_dim // attn.heads
            query_normed = self._split_heads(query_normed, batch_size, attn.heads, head_dim)
            key_normed = self._split_heads(key_normed, batch_size, attn.heads, head_dim)

            # Match query and key injection with structure injection (if injection is happening this layer)
            if do_structure_control:
                if "query" in attn.structure_target:
                    query_normed = feature_injection(query_normed, batch_order=attn.batch_order)
                if "key" in attn.structure_target:
                    key_normed = feature_injection(key_normed, batch_order=attn.batch_order)

        # Appearance transfer (before)
        if do_appearance_control and "before" in attn.appearance_target:
            hidden_states = self._split_heads(hidden_states, batch_size, attn.heads, head_dim)
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)
            hidden_states = self._merge_heads(hidden_states, batch_size, attn.heads, head_dim)

            # In self-attention the key/value source must follow the transferred
            # hidden states. Provided encoder states were already normalised above,
            # so they are deliberately not passed through norm_encoder_hidden_states
            # a second time.
            if no_encoder_hidden_states:
                encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = self._split_heads(query, batch_size, attn.heads, head_dim)
        key = self._split_heads(key, batch_size, attn.heads, head_dim)
        value = self._split_heads(value, batch_size, attn.heads, head_dim)

        # Feature injection (query, key, and/or value)
        if do_structure_control:
            if "query" in attn.structure_target:
                query = feature_injection(query, batch_order=attn.batch_order)
            if "key" in attn.structure_target:
                key = feature_injection(key, batch_order=attn.batch_order)
            if "value" in attn.structure_target:
                value = feature_injection(value, batch_order=attn.batch_order)

        # Appearance transfer (value)
        if do_appearance_control and "value" in attn.appearance_target:
            value = appearance_transfer(value, query_normed, key_normed, batch_order=attn.batch_order)

        # The output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Appearance transfer (after)
        if do_appearance_control and "after" in attn.appearance_target:
            hidden_states = appearance_transfer(hidden_states, query_normed, key_normed, batch_order=attn.batch_order)

        hidden_states = self._merge_heads(hidden_states, batch_size, attn.heads, head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Linear projection. Extra positional arguments (e.g. a legacy LoRA `scale`)
        # are not forwarded: the projection is presumably a plain nn.Linear in
        # diffusers 0.28, which takes a single input -- verify if targeting another version.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
def register_control(
    model,
    timesteps,
    control_schedule,  # structure_conv, structure_attn, appearance_attn
    control_target=None,
):
    """Patch the UNet's resnet and self-attention modules for controlled sampling.

    For each layer listed in ``control_schedule``, every resnet gets
    ``structure_target`` / ``structure_schedule`` attributes and has its
    ``forward`` rebound to ``convolution_forward``. Every ``attn1`` module gets
    structure and appearance targets/schedules and an ``AttnProcessor2_0``
    processor.

    Args:
        model: Object exposing ``unet.down_blocks``, ``unet.up_blocks`` and
            ``unet.mid_block``.
        timesteps: Sampling timesteps, passed to ``get_schedule``.
            Assumed to be in reverse order (T -> 0).
        control_schedule: Mapping with keys ``"encoder"``, ``"decoder"`` and
            ``"middle"``. The first two hold one entry per layer; ``"middle"``
            holds a single entry. Each entry is indexed ``[0]`` (conv structure),
            ``[1]`` (attention structure) and ``[2]`` (attention appearance);
            the exact element format is whatever ``get_elem`` accepts.
        control_target: Three target lists in the same order as above. Defaults
            to ``[["output_tensor"], ["query", "key"], ["before"]]``, built
            fresh on each call so modules never share state through the default.
    """
    if control_target is None:
        # A list literal in the signature would be one shared object whose inner
        # lists get stored on modules; mutating any of them would leak into
        # later calls that rely on the default.
        control_target = [["output_tensor"], ["query", "key"], ["before"]]

    unet = model.unet
    blocks_by_type = {
        "encoder": unet.down_blocks,
        "decoder": unet.up_blocks,
        "middle": [unet.mid_block],
    }

    # Assume timesteps in reverse order (T -> 0)
    for block_type, blocks in blocks_by_type.items():
        layer_schedules = control_schedule[block_type]
        if block_type == "middle":
            # The middle block is a single layer; wrap it to share the loop below.
            layer_schedules = [layer_schedules]

        for layer_index, layer_schedule in enumerate(layer_schedules):
            layer = blocks[layer_index]

            # Convolution
            for block_index, convolution in enumerate(getattr(layer, "resnets", ())):
                convolution.structure_target = control_target[0]
                convolution.structure_schedule = get_schedule(
                    timesteps, get_elem(layer_schedule[0], block_index)
                )
                convolution.forward = MethodType(convolution_forward, convolution)

            # Self-attention
            for block_index, attention_block in enumerate(getattr(layer, "attentions", ())):
                for transformer_block in attention_block.transformer_blocks:
                    attention = transformer_block.attn1
                    attention.structure_target = control_target[1]
                    attention.structure_schedule = get_schedule(
                        timesteps, get_elem(layer_schedule[1], block_index)
                    )
                    attention.appearance_target = control_target[2]
                    attention.appearance_schedule = get_schedule(
                        timesteps, get_elem(layer_schedule[2], block_index)
                    )
                    attention.processor = AttnProcessor2_0()
def register_attr(model, t, do_control, batch_order):
    """Set the per-step control attributes on every resnet and ``attn1`` module.

    Walks ``model.unet.down_blocks``, ``model.unet.up_blocks`` and
    ``model.unet.mid_block`` (in that order) and assigns ``t``, ``do_control``
    and ``batch_order`` to each module found in a layer's ``resnets`` and to
    the ``attn1`` of each transformer block in a layer's ``attentions``.

    Layers lacking either attribute (including a ``mid_block`` of ``None``)
    are skipped for that part instead of raising ``AttributeError``.

    Args:
        model: Object exposing ``unet.down_blocks``, ``unet.up_blocks`` and
            ``unet.mid_block``.
        t: Value stored as ``module.t`` (the current timestep).
        do_control: Value stored as ``module.do_control``.
        batch_order: Value stored as ``module.batch_order``; the same object is
            shared by every module.
    """
    unet = model.unet
    layers = [*unet.down_blocks, *unet.up_blocks, unet.mid_block]

    for layer in layers:
        # Convolution
        for module in getattr(layer, "resnets", ()):
            module.t = t
            module.do_control = do_control
            module.batch_order = batch_order

        # Self-attention
        for block in getattr(layer, "attentions", ()):
            for transformer_block in block.transformer_blocks:
                attention = transformer_block.attn1
                attention.t = t
                attention.do_control = do_control
                attention.batch_order = batch_order