EditAnything / utils /stable_diffusion_reference.py
Shanghua Gao
udpate
78acc58
# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
import torch.fft as fft
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
import PIL.Image
import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UpBlock2D,
)
from diffusers.utils import PIL_INTERPOLATION, logging
import torch.nn.functional as F
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Usage example for the reference pipeline. The scheduler line previously read
# its config from an undefined `pipe_controlnet`; it must use `pipe` itself.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import UniPCMultistepScheduler
>>> from diffusers.utils import load_image
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
>>> pipe = StableDiffusionReferencePipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
safety_checker=None,
torch_dtype=torch.float16
).to('cuda:0')
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> result_img = pipe(ref_image=input_image,
prompt="1girl",
num_inference_steps=20,
reference_attn=True,
reference_adain=True).images[0]
>>> result_img.show()
```
"""
def torch_dfs(model: torch.nn.Module):
    """Return ``model`` and all of its descendants in depth-first pre-order.

    The first element is ``model`` itself; every module is followed by the
    flattened subtree of each of its children, in ``children()`` order.
    """
    ordered = []
    pending = [model]
    while pending:
        node = pending.pop()
        ordered.append(node)
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(node.children())))
    return ordered
@torch.no_grad()
def add_freq_feature(feature1, feature2, ref_ratio):
    """Blend the spectral magnitude of a reference feature into a target feature.

    A 2-D FFT is taken over the last two dimensions of both inputs. The
    magnitudes are linearly mixed while the phase of the target is kept, and
    the result is transformed back.

    Args:
        feature1: reference feature.
        feature2: target feature (must be FFT-compatible in shape with
            ``feature1``); its dtype is the dtype of the result.
        ref_ratio: weight of the reference magnitude; a larger ratio means
            more of the reference spectrum (0 keeps the target unchanged up
            to numerical error).

    Returns:
        The mixed feature, same shape as ``feature2``, in ``feature2``'s dtype.
    """
    data_type = feature2.dtype
    # The FFT is computed in float32 regardless of the incoming precision.
    feature1 = feature1.to(torch.float32)
    feature2 = feature2.to(torch.float32)

    spectrum1 = fft.fftn(feature1, dim=(-2, -1))
    spectrum2 = fft.fftn(feature2, dim=(-2, -1))

    # Mix magnitudes only; the phase comes entirely from the target.
    magnitude = torch.abs(spectrum2) * (1 - ref_ratio) + torch.abs(spectrum1) * ref_ratio
    phase = torch.angle(spectrum2)
    mixed_spectrum = torch.polar(magnitude, phase)

    mixed_feature = fft.ifftn(mixed_spectrum, dim=(-2, -1))
    # ifftn returns a complex tensor. Take the real part explicitly instead of
    # relying on an implicit complex->real cast, which discards the imaginary
    # part with a warning.
    if not data_type.is_complex:
        mixed_feature = mixed_feature.real
    return mixed_feature.to(data_type)
@torch.no_grad()
def save_ref_feature(feature, mask):
    """Zero out the parts of ``feature`` that lie outside ``mask``.

    Args:
        feature: tensor laid out as (n, c, h, w).
        mask: tensor laid out as (n, 1, h, w); broadcast across channels.

    Returns:
        The element-wise product, laid out as (n, c, h, w).
    """
    masked_feature = feature * mask
    return masked_feature
@torch.no_grad()
def mix_ref_feature(feature, ref_fea_bank, cfg=True, ref_scale=0.0, dim3=False):
    """Mix stored reference features into ``feature`` in the frequency domain.

    Args:
        feature: target feature, (n, l, c) when ``dim3`` is True, otherwise
            (n, c, h, w).
        ref_fea_bank: list of reference features, each (n, c, h, w). They are
            concatenated along the batch dimension.
        cfg: when True the bank is duplicated before concatenation so that it
            covers both halves of a classifier-free-guidance batch.
        ref_scale: weight of the reference magnitude, forwarded to
            ``add_freq_feature``.
        dim3: set when ``feature`` is a token sequence (n, l, c); it is
            reshaped to the reference layout for mixing and restored after.

    Returns:
        The mixed feature in the same layout as the incoming ``feature``.
    """
    bank = list(ref_fea_bank)
    if cfg:
        bank = bank + bank
    # Always build a tensor: previously the non-cfg path kept the raw list,
    # which failed as soon as `.shape` / `.to` was used on it.
    ref_fea = torch.cat(bank, dim=0)

    if dim3:
        # (n, l, c) -> (n, c, l) -> (n, c, h, w)
        feature = feature.permute(0, 2, 1).reshape(ref_fea.shape)
    mixed_feature = add_freq_feature(ref_fea, feature, ref_scale)
    if dim3:
        # (n, c, h, w) -> (n, c, l) -> (n, l, c)
        mixed_feature = mixed_feature.reshape(
            ref_fea.shape[0], ref_fea.shape[1], -1
        ).permute(0, 2, 1)
    return mixed_feature
def mix_norm_feature(x, inpaint_mask, mean_bank, var_bank, do_classifier_free_guidance, style_fidelity, uc_mask, eps=1e-6):
    """Re-normalise the inpainting region of ``x`` to reference statistics.

    Only the positions selected by ``inpaint_mask`` are touched: they are
    standardised with their own per-(sample, channel) mean/std and then
    rescaled to the averaged reference mean/std. ``x`` is modified in place
    and also returned.

    Args:
        x: input feature, (n, c, h, w).
        inpaint_mask: mask of the region to inpaint; resized to the spatial
            size of ``x`` and repeated over batch and channel, so it is
            expected to be (1, 1, H, W) — NOTE(review): confirm with caller.
        mean_bank: reference means; averaged via ``sum(...) / len(...)``, so a
            list of tensors or a tensor iterated along dim 0 both work.
        var_bank: reference variances, same convention as ``mean_bank``.
        do_classifier_free_guidance: whether the batch holds a CFG pair.
        style_fidelity: blend weight between the variant whose ``uc_mask``
            entries keep their original values (1.0) and the fully
            re-normalised variant (0.0).
        uc_mask: index selecting the batch entries that keep their original
            values when CFG is on and ``style_fidelity > 0``.
        eps: lower bound applied to variances before the square root.

    Returns:
        ``x`` with the masked region replaced.
    """
    n, c = x.shape[0], x.shape[1]

    # Resize to the exact spatial size of x. A reciprocal float scale factor
    # can round to a size that is off by one, which breaks the boolean
    # indexing below.
    this_inpaint_mask = F.interpolate(
        inpaint_mask.to(x.device), size=tuple(x.shape[2:])
    )
    this_inpaint_mask = this_inpaint_mask.repeat(n, c, 1, 1).bool()

    # Gather the masked positions as (n, c, k, 1) so statistics are taken
    # over the masked region only.
    masked_x = x[this_inpaint_mask].detach().clone().view(n, c, -1, 1)
    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5

    mean_acc = sum(mean_bank) / float(len(mean_bank))
    var_acc = sum(var_bank) / float(len(var_bank))
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5

    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
    x_c = x_uc.clone()
    if do_classifier_free_guidance and style_fidelity > 0:
        x_c[uc_mask] = masked_x[uc_mask]
    masked_x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc

    # Scatter back in the same row-major order the values were gathered in.
    x[this_inpaint_mask] = masked_x.view(-1)
    return x
class StableDiffusionReferencePipeline:
    """Helpers that add reference-only control (attention + AdaIN) to a pipeline.

    The methods read attributes that are not defined in this class and are
    expected to be set by the host pipeline before use: ``vae``,
    ``attention_auto_machine_weight``, ``gn_auto_machine_weight``,
    ``do_classifier_free_guidance``, ``uc_mask``, ``style_fidelity``,
    ``ref_mask``, ``inpaint_mask`` and ``ref_scale``.
    """

    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        """Convert the reference image(s) into a batched tensor.

        Args:
            image: a tensor, a PIL image, or a list of PIL images / tensors.
                PIL images are converted to RGB, resized to
                ``(width, height)`` and scaled to [-1, 1] in NCHW layout;
                tensors are used as given (lists are concatenated on dim 0).
            width: target width for PIL inputs.
            height: target height for PIL inputs.
            batch_size: repeat count used when a single image is given.
            num_images_per_prompt: repeat count used when several images are
                given.
            device: target device.
            dtype: target dtype.
            do_classifier_free_guidance: duplicate the batch for CFG ...
            guess_mode: ... unless guess mode is enabled.

        Returns:
            The prepared image tensor on ``device`` with ``dtype``.
        """
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                frames = [
                    np.array(
                        pil_image.convert("RGB").resize(
                            (width, height), resample=PIL_INTERPOLATION["lanczos"]
                        )
                    )[None, :]
                    for pil_image in image
                ]
                image = np.concatenate(frames, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5  # [0, 1] -> [-1, 1]
                image = torch.from_numpy(image.transpose(0, 3, 1, 2))  # NHWC -> NCHW
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)

        # A single reference is repeated for the whole batch; otherwise the
        # image batch is assumed to match the prompt batch.
        repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        return image

    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        """Encode the reference image with ``self.vae`` and batch the latents.

        Args:
            refimage: reference image tensor.
            batch_size: total number of latents required.
            dtype: dtype for the image and the returned latents.
            device: device for the image and the returned latents.
            generator: a generator, or a list with one generator per sample
                (then each sample is encoded separately).
            do_classifier_free_guidance: duplicate the latents for CFG.

        Returns:
            Latents scaled by ``self.vae.config.scaling_factor``.

        Raises:
            ValueError: if fewer latents than ``batch_size`` were produced and
                ``batch_size`` is not a multiple of that number.
        """
        refimage = refimage.to(device=device, dtype=dtype)

        if isinstance(generator, list):
            per_sample = [
                self.vae.encode(refimage[i: i + 1]).latent_dist.sample(
                    generator=generator[i]
                )
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(per_sample, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(
                generator=generator
            )
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # Duplicate the latents for each generation per prompt.
        if ref_image_latents.shape[0] < batch_size:
            if not batch_size % ref_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(
                batch_size // ref_image_latents.shape[0], 1, 1, 1
            )

        if do_classifier_free_guidance:
            ref_image_latents = torch.cat([ref_image_latents] * 2)

        # Align device/dtype to avoid errors when combined with model inputs.
        return ref_image_latents.to(device=device, dtype=dtype)

    def check_ref_input(self, reference_attn, reference_adain):
        """Assert that at least one reference mechanism is enabled."""
        assert (
            reference_attn or reference_adain
        ), "`reference_attn` or `reference_adain` must be True."

    def redefine_ref_model(
        self, model, reference_attn, reference_adain, model_type="unet"
    ):
        """Monkey-patch ``model`` so its blocks can write/read reference features.

        Each patched module gets a replacement ``forward`` plus the state it
        needs (banks, weights, masks). A module's behaviour is selected later
        through its ``MODE`` attribute ("write" stores reference features,
        "read" mixes them in); see ``change_module_mode``.

        Args:
            model: the network to patch. For ``model_type="controlnet"`` only
                ``model.nets[-1]`` is patched.
            reference_attn: patch every ``BasicTransformerBlock``.
            reference_adain: patch the mid block, ``DownBlock2D`` blocks and
                (unet only) ``UpBlock2D`` blocks.
            model_type: ``"unet"`` or ``"controlnet"``.

        Returns:
            ``(attn_modules, gn_modules)``; each is the list of modules that
            were prepared, or ``None`` when the mechanism is disabled.

        Raises:
            ValueError: if ``model_type`` is not supported.
        """
        if model_type not in ("unet", "controlnet"):
            # Previously this fell through to an UnboundLocalError at `return`.
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected 'unet' or 'controlnet'."
            )

        pipeline = self  # the nested forwards shadow `self` with the module

        # ------------------------------------------------------------------
        # Shared helpers for the group-norm (AdaIN) hooks
        # ------------------------------------------------------------------
        def _record_reference_stats(block, hidden_states, suffix=""):
            """In "write" mode, store masked features and their mean/var."""
            if not (
                block.MODE == "write"
                and block.gn_auto_machine_weight >= block.gn_weight
            ):
                return
            n, c = hidden_states.shape[0], hidden_states.shape[1]
            # Resize the mask to the exact feature size so the boolean
            # indexing below always matches.
            this_ref_mask = F.interpolate(
                block.ref_mask.to(hidden_states.device),
                size=tuple(hidden_states.shape[2:]),
            )
            getattr(block, "fea_bank" + suffix).append(
                save_ref_feature(hidden_states, this_ref_mask)
            )
            this_ref_mask = this_ref_mask.repeat(n, c, 1, 1).bool()
            masked_hidden_states = (
                hidden_states[this_ref_mask].detach().clone().view(n, c, -1, 1)
            )
            var, mean = torch.var_mean(
                masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
            )
            # Statistics are stored twice along the batch dimension.
            getattr(block, "mean_bank" + suffix).append(torch.cat([mean] * 2, dim=0))
            getattr(block, "var_bank" + suffix).append(torch.cat([var] * 2, dim=0))

        def _apply_reference_stats(block, hidden_states, index, suffix=""):
            """In "read" mode, mix in the ``index``-th stored reference entry."""
            if block.MODE != "read":
                return hidden_states
            fea_bank = getattr(block, "fea_bank" + suffix)
            mean_bank = getattr(block, "mean_bank" + suffix)
            var_bank = getattr(block, "var_bank" + suffix)
            if (
                block.gn_auto_machine_weight >= block.gn_weight
                and len(mean_bank) > 0
                and len(var_bank) > 0
            ):
                hidden_states = mix_ref_feature(
                    hidden_states,
                    [fea_bank[index]],
                    cfg=block.do_classifier_free_guidance,
                    ref_scale=block.ref_scale,
                )
                hidden_states = mix_norm_feature(
                    hidden_states,
                    block.inpaint_mask,
                    mean_bank[index],
                    var_bank[index],
                    block.do_classifier_free_guidance,
                    block.style_fidelity,
                    block.uc_mask,
                )
            return hidden_states

        def _reset_banks(block, suffixes=("",)):
            """After a "read" pass, drop everything stored by the last write."""
            if block.MODE != "read":
                return
            for suffix in suffixes:
                setattr(block, "mean_bank" + suffix, [])
                setattr(block, "var_bank" + suffix, [])
                setattr(block, "fea_bank" + suffix, [])

        # ------------------------------------------------------------------
        # Helpers for the attention hook
        # ------------------------------------------------------------------
        def _store_attention_reference(block, norm_hidden_states):
            """Store masked reference tokens and features for a later read."""
            n = norm_hidden_states.shape[0]
            # Tokens are a flattened h*w grid; recover the grid by scaling the
            # mask down by the square root of the area ratio.
            scale_ratio = (
                (block.ref_mask.shape[2] * block.ref_mask.shape[3])
                / norm_hidden_states.shape[1]
            ) ** 0.5
            this_ref_mask = F.interpolate(
                block.ref_mask.to(norm_hidden_states.device),
                scale_factor=1 / scale_ratio,
            )
            # (n, l, c) -> (n, h, w, c) -> (n, c, h, w)
            resize_norm_hidden_states = norm_hidden_states.view(
                n, this_ref_mask.shape[2], this_ref_mask.shape[3], -1
            ).permute(0, 3, 1, 2)

            ref_scale = 1.0  # spatial rescale of the stored reference
            resize_norm_hidden_states = F.interpolate(
                resize_norm_hidden_states, scale_factor=ref_scale, mode="bilinear"
            )
            this_ref_mask = F.interpolate(this_ref_mask, scale_factor=ref_scale)

            block.fea_bank.append(
                save_ref_feature(resize_norm_hidden_states, this_ref_mask)
            )

            c = resize_norm_hidden_states.shape[1]
            this_ref_mask = this_ref_mask.repeat(n, c, 1, 1).bool()
            # Keep only the masked tokens, back in (n, k, c) layout.
            masked_norm_hidden_states = (
                resize_norm_hidden_states[this_ref_mask]
                .detach()
                .clone()
                .view(n, c, -1)
                .permute(0, 2, 1)
            )
            block.bank.append(masked_norm_hidden_states)

        def _attend_with_reference(block, norm_hidden_states, cross_attention_kwargs):
            """Self-attention whose keys/values also include stored reference tokens."""
            freq_norm_hidden_states = mix_ref_feature(
                norm_hidden_states,
                block.fea_bank,
                cfg=block.do_classifier_free_guidance,
                ref_scale=block.ref_scale,
                dim3=True,
            )
            block.fea_bank.clear()

            # NOTE(review): the token bank is always duplicated here, whereas
            # mix_ref_feature only duplicates when cfg is on — confirm that
            # running without classifier-free guidance is not expected.
            this_bank = torch.cat(block.bank + block.bank, dim=0)
            ref_hidden_states = torch.cat(
                (freq_norm_hidden_states, this_bank), dim=1
            )
            del this_bank  # free before the attention call
            block.bank.clear()

            attn_output_uc = block.attn1(
                freq_norm_hidden_states,
                encoder_hidden_states=ref_hidden_states,
                **cross_attention_kwargs,
            )
            del ref_hidden_states

            attn_output_c = attn_output_uc.clone()
            if block.do_classifier_free_guidance and block.style_fidelity > 0:
                # Entries selected by uc_mask attend to themselves only.
                attn_output_c[block.uc_mask] = block.attn1(
                    norm_hidden_states[block.uc_mask],
                    encoder_hidden_states=norm_hidden_states[block.uc_mask],
                    **cross_attention_kwargs,
                )
            return (
                block.style_fidelity * attn_output_c
                + (1.0 - block.style_fidelity) * attn_output_uc
            )

        # ------------------------------------------------------------------
        # Replacement forwards
        # ------------------------------------------------------------------
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            """Transformer block forward with reference-aware self-attention."""
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)

            # 1. Self-Attention
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        _store_attention_reference(self, norm_hidden_states)
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        attn_output = _attend_with_reference(
                            self, norm_hidden_states, cross_attention_kwargs
                        )
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                    self.bank.clear()
                    self.fea_bank.clear()

            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )
                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )
            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            hidden_states = ff_output + hidden_states
            return hidden_states

        def hacked_mid_forward(self, *args, **kwargs):
            """Mid block forward; uses the whole bank rather than one entry."""
            x = self.original_forward(*args, **kwargs)
            _record_reference_stats(self, x)
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    x = mix_ref_feature(
                        x,
                        self.fea_bank,
                        cfg=self.do_classifier_free_guidance,
                        ref_scale=self.ref_scale,
                    )
                    self.fea_bank = []
                    x = mix_norm_feature(
                        x,
                        self.inpaint_mask,
                        self.mean_bank,
                        self.var_bank,
                        self.do_classifier_free_guidance,
                        self.style_fidelity,
                        self.uc_mask,
                    )
                self.mean_bank = []
                self.var_bank = []
            return x

        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            """Cross-attn down block forward (defined but currently not installed)."""
            # TODO(Patrick, William) - attention mask is not used
            output_states = ()
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                # The "0" banks hold the post-resnet features ...
                _record_reference_stats(self, hidden_states, suffix="0")
                hidden_states = _apply_reference_stats(
                    self, hidden_states, i, suffix="0"
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                # ... and the plain banks hold the post-attention features.
                _record_reference_stats(self, hidden_states)
                hidden_states = _apply_reference_stats(self, hidden_states, i)
                output_states = output_states + (hidden_states,)

            _reset_banks(self, suffixes=("0", ""))

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states

        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            """Down block forward with reference write/read after each resnet."""
            output_states = ()
            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)
                _record_reference_stats(self, hidden_states)
                hidden_states = _apply_reference_stats(self, hidden_states, i)
                output_states = output_states + (hidden_states,)

            _reset_banks(self)

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states

        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            """Cross-attn up block forward (defined but currently not installed)."""
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                _record_reference_stats(self, hidden_states, suffix="0")
                hidden_states = _apply_reference_stats(
                    self, hidden_states, i, suffix="0"
                )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                _record_reference_stats(self, hidden_states)
                hidden_states = _apply_reference_stats(self, hidden_states, i)

            _reset_banks(self, suffixes=("0", ""))

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states

        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
        ):
            """Up block forward with reference write/read after each resnet."""
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)

                _record_reference_stats(self, hidden_states)
                hidden_states = _apply_reference_stats(self, hidden_states, i)

            _reset_banks(self)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states

        # ------------------------------------------------------------------
        # Installation
        # ------------------------------------------------------------------
        def _share_reference_state(module, with_inpaint_mask):
            """Copy the pipeline's reference settings onto a patched module."""
            module.attention_auto_machine_weight = (
                pipeline.attention_auto_machine_weight
            )
            module.gn_auto_machine_weight = pipeline.gn_auto_machine_weight
            module.do_classifier_free_guidance = pipeline.do_classifier_free_guidance
            module.uc_mask = pipeline.uc_mask
            module.style_fidelity = pipeline.style_fidelity
            module.ref_mask = pipeline.ref_mask
            if with_inpaint_mask:
                module.inpaint_mask = pipeline.inpaint_mask
            module.ref_scale = pipeline.ref_scale

        def _install_attention_hooks(net, fixed_weight=None):
            """Patch every BasicTransformerBlock of ``net``; return them."""
            blocks = [
                module
                for module in torch_dfs(net)
                if isinstance(module, BasicTransformerBlock)
            ]
            # Widest blocks first, so they receive the smallest attn_weight.
            blocks = sorted(blocks, key=lambda x: -x.norm1.normalized_shape[0])
            for i, module in enumerate(blocks):
                module._original_inner_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(
                    module, BasicTransformerBlock
                )
                module.bank = []
                module.fea_bank = []
                module.attn_weight = (
                    float(i) / float(len(blocks))
                    if fixed_weight is None
                    else fixed_weight
                )
                _share_reference_state(module, with_inpaint_mask=False)
            return blocks

        def _install_group_norm_hooks(net, include_up_blocks):
            """Patch the mid / down / (optionally) up blocks; return them."""
            gn_modules = [net.mid_block]
            net.mid_block.gn_weight = 0

            down_blocks = net.down_blocks
            for w, module in enumerate(down_blocks):
                module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                gn_modules.append(module)

            if include_up_blocks:
                up_blocks = net.up_blocks
                for w, module in enumerate(up_blocks):
                    module.gn_weight = float(w) / float(len(up_blocks))
                    gn_modules.append(module)

            for i, module in enumerate(gn_modules):
                if getattr(module, "original_forward", None) is None:
                    module.original_forward = module.forward
                if i == 0:
                    # mid_block
                    module.forward = hacked_mid_forward.__get__(
                        module, torch.nn.Module
                    )
                elif isinstance(module, DownBlock2D):
                    module.forward = hacked_DownBlock2D_forward.__get__(
                        module, DownBlock2D
                    )
                elif isinstance(module, UpBlock2D):
                    module.forward = hacked_UpBlock2D_forward.__get__(
                        module, UpBlock2D
                    )
                # NOTE: CrossAttnDownBlock2D / CrossAttnUpBlock2D keep their
                # stock forward; the hacked variants above are not installed.

                # The "0" banks are only read by the cross-attn variants; they
                # are initialised everywhere so every module has the same state.
                module.mean_bank0 = []
                module.var_bank0 = []
                module.fea_bank0 = []
                module.mean_bank = []
                module.var_bank = []
                module.fea_bank = []
                _share_reference_state(module, with_inpaint_mask=True)
            return gn_modules

        is_controlnet = model_type == "controlnet"
        if is_controlnet:
            model = model.nets[-1]  # only hack the inpainting controlnet

        if reference_attn:
            # Controlnet blocks all get weight 0.0; unet blocks are graded.
            attn_modules = _install_attention_hooks(
                model, fixed_weight=0.0 if is_controlnet else None
            )
        else:
            attn_modules = None

        if reference_adain:
            gn_modules = _install_group_norm_hooks(
                model, include_up_blocks=not is_controlnet
            )
        else:
            gn_modules = None

        return attn_modules, gn_modules

    def change_module_mode(self, mode, attn_modules, gn_modules):
        """Set ``MODE`` on every module of both lists (``None`` lists are skipped)."""
        for modules in (attn_modules, gn_modules):
            if modules is None:
                continue
            for module in modules:
                module.MODE = mode