d-edit / controller.py
afeng's picture
first
d807efd
raw
history blame
No virus
7.24 kB
import numpy as np
import torch
import math
import xformers
class DummyController:
    """Pass-through controller: hands back whatever it is called with, untouched."""

    def __init__(self):
        # Count of attention layers attached to this controller; starts at zero and is
        # overwritten when the controller is registered on a model.
        self.num_att_layers = 0

    def __call__(self, *args):
        """Return the first positional argument unchanged (IndexError if none given)."""
        first_arg = args[0]
        return first_arg
class GroupedCAController:
    """Controller that splits cross-attention into spatially masked groups.

    Group ``g`` pairs ``mask_list[g]`` with the ``g``-th key/value set: query positions
    where the (resized) mask is zero get no contribution from that group, and the
    per-group attention outputs are summed.
    """

    def __init__(self, mask_list=None):
        """
        Args:
            mask_list: sequence of 2-D masks, one per conditioning group, or ``None``.
                NOTE(review): presumably binary masks in image space — confirm with callers.
        """
        self.mask_list = mask_list
        # True when masks were supplied, i.e. the controller runs in decomposed mode.
        self.is_decom = mask_list is not None

    def mask_img_to_mask_vec(self, mask, length):
        """Resize a 2-D ``mask`` to ``length x length`` and flatten it row-major.

        Uses ``F.interpolate``'s default (nearest-neighbour) mode.

        Returns:
            1-D tensor with ``length * length`` entries.
        """
        if not mask.is_floating_point():
            # interpolate only supports floating-point inputs; bool/int masks would raise.
            mask = mask.float()
        mask_vec = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0), (length, length)
        ).squeeze()
        return mask_vec.flatten()

    def _query_mask(self, mask, q):
        """Resize ``mask`` onto the square query grid of ``q``; return it as ``[1, N, 1]`` on ``q.device``."""
        n_queries = q.shape[1]
        side = math.isqrt(n_queries)
        if side * side != n_queries:
            # A non-square grid would otherwise fail later with an opaque broadcast error.
            raise ValueError(
                f"cannot map {n_queries} query positions onto a square mask grid"
            )
        mask_vec = self.mask_img_to_mask_vec(mask.to(q.device), side)
        # [N] -> [1, N, 1]: broadcasts over the batch*heads and key dimensions of attn.
        return mask_vec.unsqueeze(0).unsqueeze(-1)

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        """Masked, grouped cross-attention.

        For each group: ``softmax(q @ k^T * scale)``, rows zeroed where the group's mask
        is 0, multiplied by ``v``; the group outputs are summed.

        Args:
            q: ``[B*h, N, d]`` queries; ``N`` must be a perfect square.
            k_list: per-group keys ``[B*h, P, d]``, aligned with ``self.mask_list``.
            v_list: per-group values ``[B*h, P, d]``, aligned with ``self.mask_list``.
            scale: factor applied to the attention logits before the softmax.
            place_in_unet: unused here; kept for interface compatibility.

        Returns:
            ``[B*h, N, d]`` tensor (the integer ``0`` if there are no groups).

        Raises:
            ValueError: if ``N`` is not a perfect square.
        """
        out = 0
        for mask, k, v in zip(self.mask_list, k_list, v_list):
            mask_vec = self._query_mask(mask, q)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # [B*h, N, P]
            attn = sim.softmax(dim=-1).masked_fill(mask_vec == 0, 0)
            out += torch.einsum("b i j, b j d -> b i d", attn, v)  # [B*h, N, d]
        return out
def reshape_heads_to_batch_dim(self):
    """Build a function that folds attention heads into the batch dimension.

    ``self`` is any object exposing ``num_heads`` (read each time the function runs).
    The returned function maps ``[B, S, D]`` to ``[B * num_heads, S, D // num_heads]``,
    with batch index ``b * num_heads + h`` holding head ``h`` of sample ``b``.
    """
    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor
    return func
def reshape_batch_dim_to_heads(self):
    """Build a function that unfolds attention heads from the batch dimension.

    ``self`` is any object exposing ``num_heads`` (read each time the function runs).
    The returned function maps ``[B * num_heads, S, d]`` to ``[B, S, d * num_heads]``,
    concatenating the heads of each sample along the last dimension.
    """
    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor
    return func
def _full_attention(attn, x, encoder_hidden_states):
    """Ordinary attention for layer ``attn`` via xformers' memory-efficient kernel.

    Self-attention when ``encoder_hidden_states`` is None, cross-attention otherwise.
    Returns the head-merged output, before the output projection.
    """
    # `import xformers` alone does not necessarily load the `ops` submodule.
    import xformers.ops

    context = x if encoder_hidden_states is None else encoder_hidden_states
    q = attn.head_to_batch_dim(attn.to_q(x))
    k = attn.head_to_batch_dim(attn.to_k(context))
    v = attn.head_to_batch_dim(attn.to_v(context))
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=None, scale=attn.scale
    )
    return attn.batch_to_head_dim(out)


def _decomposed_cross_attention(attn, controller, x, encoder_hidden_states_list, place_in_unet):
    """Grouped cross-attention: one key/value set per conditioning tensor, combined by ``controller``.

    Raises:
        TypeError: if ``encoder_hidden_states_list`` is not a list.
    """
    if not isinstance(encoder_hidden_states_list, list):
        raise TypeError(
            "decomposed cross-attention expects encoder_hidden_states to be a list, "
            f"got {type(encoder_hidden_states_list).__name__}"
        )
    q = attn.head_to_batch_dim(attn.to_q(x))
    k_list = [attn.head_to_batch_dim(attn.to_k(h)) for h in encoder_hidden_states_list]
    v_list = [attn.head_to_batch_dim(attn.to_v(h)) for h in encoder_hidden_states_list]
    out = controller.ca_forward_decom(q, k_list, v_list, attn.scale, place_in_unet)
    return attn.batch_to_head_dim(out)


def register_attention_disentangled_control(unet, controller):
    """Patch ``unet``'s cross-attention layers so their forward pass routes through ``controller``.

    A sub-module is patched when its class is named ``'Attention'`` and its
    ``to_k.in_features`` equals ``unet.ca_dim``. Only children of ``unet`` whose names
    contain "down", "up" or "mid" are searched; that word is passed on as ``place_in_unet``.

    Patched forward behaviour:
      * ``DummyController``: plain attention computed with xformers
        (self-attention when ``encoder_hidden_states`` is None).
      * any other controller: ``encoder_hidden_states`` must be a list of conditioning
        tensors; each gets its own key/value projection and the controller's
        ``ca_forward_decom`` combines them. Self-attention raises ``NotImplementedError``.

    Args:
        unet: model providing ``named_children()`` and a ``ca_dim`` attribute.
        controller: controller object, or ``None`` for a ``DummyController``.

    Side effects:
        Replaces ``forward`` on the matched modules and sets
        ``controller.num_att_layers`` to the number of patched layers.
    """
    if controller is None:
        controller = DummyController()

    def ca_forward(self, place_in_unet):
        # When to_out is a ModuleList only its first module is applied (presumably the
        # output projection; anything after it, e.g. dropout, is skipped).
        to_out = self.to_out[0] if isinstance(self.to_out, torch.nn.ModuleList) else self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            # NOTE(review): attention_mask is accepted for signature compatibility but ignored.
            if isinstance(controller, DummyController):
                out = _full_attention(self, x, encoder_hidden_states)
            elif encoder_hidden_states is not None:
                out = _decomposed_cross_attention(
                    self, controller, x, encoder_hidden_states, place_in_unet
                )
            else:
                # Decomposed self-attention is not implemented; fail loudly instead of
                # terminating the interpreter.
                raise NotImplementedError("decomposing SA!")
            return to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Depth-first patch of matching attention modules; returns the running count."""
        if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, child in unet.named_children():
        # Checked in this order, first match wins (same precedence as an if/elif chain).
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(child, 0, place)
                break
    controller.num_att_layers = cross_att_count