import math
import warnings
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from torch import nn

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def _to_rowwise_layout(tensor, num_views, height, cd_attention_mid=False):
    """Regroup per-view tokens so attention runs across views one image row at a time.

    ``tensor`` is expected as ``((b v), (h w), c)`` with ``h == height`` (assumed
    layout -- TODO confirm against callers). The result is ``((b h), (v w), c)``:
    each row index becomes its own batch entry whose tokens are that row taken
    from all ``num_views`` views.

    When ``cd_attention_mid`` is True, the leading ``b`` axis is first split into
    two equal halves which are placed side by side along the width axis, giving
    ``((b/2 h), (v 2w), c)`` so each row also attends across both halves.
    """
    if not cd_attention_mid:
        return rearrange(tensor, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
    grid = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
    first_half, second_half = torch.chunk(grid, dim=0, chunks=2)
    grid = torch.cat([first_half, second_half], dim=3)  # b v h 2w c
    return rearrange(grid, "b v h w c -> (b h) (v w) c", v=num_views, h=height)


def _from_rowwise_layout(tensor, num_views, height, cd_attention_mid=False):
    """Invert ``_to_rowwise_layout``, restoring the ``((b v), (h w), c)`` layout."""
    if not cd_attention_mid:
        return rearrange(tensor, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
    grid = rearrange(tensor, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
    first_half, second_half = torch.chunk(grid, dim=3, chunks=2)
    grid = torch.cat([first_half, second_half], dim=0)  # 2b v h w c
    return rearrange(grid, "b v h w c -> (b v) (h w) c", v=num_views, h=height)


def _blend_face_into_first_view(tensor, num_views, height):
    """Blend a downsampled copy of the last view into the top-centre of view 0.

    ``tensor`` is ``((b v), (h w), c)`` with ``v == num_views + 1``. The last view
    (index -1) is detached, bilinearly resized to the window spanning rows
    ``[0, height // 3)`` and columns ``[height // 3, height - height // 3)`` of
    view 0, and averaged 50/50 with that window. All other positions are left as is.
    """
    region = height // 3
    col_end = height - region
    grid = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views + 1, h=height)
    face = rearrange(grid[:, -1], "b h w c -> b c h w").detach()
    # The window is `region` rows by `height - 2 * region` columns; that equals
    # `region` only when height is divisible by 3, so resize to the exact window
    # instead of a square to avoid a shape mismatch for other heights.
    face = F.interpolate(face, size=(region, col_end - region), mode='bilinear')
    face = rearrange(face, "b c h w -> b h w c")
    grid = grid.clone()  # do not write into the attention output in place
    window = grid[:, 0, :region, region:col_end, :]
    grid[:, 0, :region, region:col_end, :] = 0.5 * window + 0.5 * face
    return rearrange(grid, "b v h w c -> (b v) (h w) c")


class RowwiseMVAttention(Attention):
    """Attention layer whose xFormers processor is ``XFormersMVAttnProcessor``."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install ``XFormersMVAttnProcessor``.

        NOTE(review): the boolean flag is ignored, so the xFormers processor is
        installed even when called with False -- confirm this is intended.
        """
        processor = XFormersMVAttnProcessor()
        self.set_processor(processor)


class IPCDAttention(Attention):
    """Attention layer whose xFormers processor is ``XFormersIPCDAttnProcessor``."""

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install ``XFormersIPCDAttnProcessor``.

        NOTE(review): the boolean flag is ignored, so the xFormers processor is
        installed even when called with False -- confirm this is intended.
        """
        processor = XFormersIPCDAttnProcessor()
        self.set_processor(processor)


class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by xFormers.

    Each image row attends jointly over the same row in all ``num_views`` views
    (see ``_to_rowwise_layout``).
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cd_attention_mid=False
    ):
        """Run row-wise multi-view attention.

        Args:
            attn: Attention module supplying the projections and norms.
            hidden_states: ``((b v), (h w), c)`` tokens or a 4-D
                ``(batch, channel, height, width)`` map (assumed layout).
            encoder_hidden_states: keys/values source; defaults to ``hidden_states``.
            attention_mask: optional mask (per the original author's note it
                is None in practice).
            temb: passed to ``attn.spatial_norm`` when that norm exists.
            num_views: number of views per sample in the batch axis.
            multiview_attention: accepted for interface compatibility; unused.
            cd_attention_mid: also join the two batch halves during attention.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # The row-wise layout assumes a square token grid. Kept separate from the
        # spatial `height` so a 4-D input is reshaped back with its real dims.
        grid_height = int(math.sqrt(sequence_length))

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # Expand [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens]:
            # xFormers does not broadcast the singleton query dimension.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            warnings.warn("using group norm, pay attention to use it in row-wise attention")
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = _to_rowwise_layout(query, num_views, grid_height, cd_attention_mid)
        key = _to_rowwise_layout(key, num_views, grid_height, cd_attention_mid)
        value = _to_rowwise_layout(value, num_views, grid_height, cd_attention_mid)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = _from_rowwise_layout(hidden_states, num_views, grid_height, cd_attention_mid)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class XFormersIPCDAttnProcessor:
    r"""
    xFormers processor for joint two-task attention with face-to-view blending.
    """

    def process(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                temb=None, num_tasks=2, num_views=6):
        """Joint attention over two task halves, then blend the last view into view 0.

        The batch is split into two equal halves along dim 0 (named "normal" and
        "color" below, judging by the original variable names). The halves are
        concatenated along the token axis so each sample attends over both tasks'
        tokens. After the output projection the halves are separated again. In each
        half, the last of the ``num_views + 1`` views is blended into view 0 (see
        ``_blend_face_into_first_view``). Finally the halves are restacked along
        the batch axis, normal first.

        Args:
            attn: Attention module supplying the projections and norms.
            hidden_states: ``(2 * b * (num_views + 1), h * w, c)`` tokens or a 4-D
                map (assumed layout -- TODO confirm against callers).
            encoder_hidden_states: keys/values source; defaults to ``hidden_states``.
            attention_mask: optional mask (per the original author's note it is
                None in practice).
            temb: passed to ``attn.spatial_norm`` when that norm exists.
            num_tasks: must be 2.
            num_views: views per sample excluding the trailing face view.

        Returns:
            The processed tensor, reshaped back to 4-D when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Square token grid assumed; kept apart from the real spatial `height`.
        grid_height = int(math.sqrt(sequence_length))

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # Expand [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens]:
            # xFormers does not broadcast the singleton query dimension.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        def stack_tasks(tensor):
            # (2bv, hw, c) -> (bv, 2hw, c): first task's tokens, then second's.
            first_task, second_task = torch.chunk(tensor, dim=0, chunks=2)  # bv hw c
            return torch.cat([first_task, second_task], dim=1)  # bv 2hw c

        key = stack_tasks(key)
        value = stack_tasks(value)
        query = stack_tasks(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)  # bv, hw, c
        hidden_states_normal = _blend_face_into_first_view(hidden_states_normal, num_views, grid_height)
        hidden_states_color = _blend_face_into_first_view(hidden_states_color, num_views, grid_height)

        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        """Delegate to ``process``.

        NOTE(review): ``num_views`` is not accepted or forwarded here, so
        ``process`` always runs with its default of 6 -- confirm this is intended.
        """
        hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
        return hidden_states


class IPCrossAttn(Attention):
    r"""
    Cross-attention layer intended for IP-Adapter style image-prompt conditioning.

    Args:
        query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention:
            Forwarded positionally to ``Attention.__init__``.
        ip_scale (`float`, defaults to 1.0):
            Weight scale of the image prompt, stored as ``self.ip_scale``. It is
            not read by ``XFormersIPCrossAttnProcessor``.
    """

    def __init__(self, query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
        super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention)
        self.ip_scale = ip_scale

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install ``XFormersIPCrossAttnProcessor`` (the boolean flag is ignored)."""
        processor = XFormersIPCrossAttnProcessor()
        self.set_processor(processor)


class XFormersIPCrossAttnProcessor:
    r"""
    xFormers attention processor used by ``IPCrossAttn``.

    Queries come from ``hidden_states``. Keys and values come from
    ``encoder_hidden_states``, or from ``hidden_states`` when no encoder states
    are given.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1
    ):
        """Run (cross-)attention.

        Args:
            attn: Attention module supplying the projections and norms.
            hidden_states: query tokens ``(batch, tokens, c)`` or a 4-D map.
            encoder_hidden_states: keys/values source; defaults to ``hidden_states``.
            attention_mask: optional mask, expanded for xFormers.
            temb: passed to ``attn.spatial_norm`` when that norm exists.
            num_views: accepted for interface compatibility; unused.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xFormers does not broadcast the singleton query dimension of the mask.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Same fallback/normalisation as the other processors in this module;
        # previously a None encoder input crashed in attn.to_k.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # NOTE(review): the output projection attn.to_out[0] is commented out here,
        # whereas every other processor in this module applies it before
        # attn.to_out[1]. Confirm this is intentional (e.g. matches trained weights).
        # hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        # TODO: region control
        return hidden_states


class RowwiseMVProcessor:
    r"""
    Row-wise multi-view attention using ``attn.get_attention_scores`` (no xFormers).

    Each image row attends jointly over the same row in all ``num_views`` views
    (see ``_to_rowwise_layout``).
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        cd_attention_mid=False
    ):
        """Run row-wise multi-view attention.

        Args:
            attn: Attention module supplying the projections and norms.
            hidden_states: ``((b v), (h w), c)`` tokens or a 4-D map (assumed layout).
            encoder_hidden_states: keys/values source; defaults to ``hidden_states``.
            attention_mask: optional mask passed to ``attn.get_attention_scores``.
            temb: passed to ``attn.spatial_norm`` when that norm exists.
            num_views: number of views per sample in the batch axis.
            cd_attention_mid: also join the two batch halves during attention.

        Returns:
            The attention output, reshaped back to 4-D when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Square token grid assumed; kept apart from the real spatial `height`.
        grid_height = int(math.sqrt(sequence_length))

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = _to_rowwise_layout(query, num_views, grid_height, cd_attention_mid)
        key = _to_rowwise_layout(key, num_views, grid_height, cd_attention_mid)
        value = _to_rowwise_layout(value, num_views, grid_height, cd_attention_mid)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = _from_rowwise_layout(hidden_states, num_views, grid_height, cd_attention_mid)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class CDAttention(Attention):
    """Attention layer whose xFormers processor is ``XFormersCDAttnProcessor``.

    Construction is inherited unchanged from ``Attention``.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install ``XFormersCDAttnProcessor``.

        NOTE(review): the boolean flag is ignored, so the xFormers processor is
        installed even when called with False -- confirm this is intended.
        """
        processor = XFormersCDAttnProcessor()
        self.set_processor(processor)


class XFormersCDAttnProcessor:
    r"""
    xFormers processor for joint attention across two tasks.

    The batch is assumed to hold the samples of task 0 followed by those of task 1
    (two equal halves along dim 0 -- TODO confirm against callers). The halves are
    placed side by side along the token axis so every sample attends over its
    tokens from both tasks. The output is then split back into the original
    ``(2 * bv, hw, c)`` batch layout.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        """Run joint two-task attention.

        Args:
            attn: Attention module supplying the projections and norms.
            hidden_states: ``(2 * bv, hw, c)`` tokens or a 4-D map (assumed layout).
            encoder_hidden_states: keys/values source; defaults to ``hidden_states``.
            attention_mask: optional mask. NOTE(review): it is prepared for the
                un-stacked batch (2 * bv) but attention runs on the stacked batch
                (bv), so a non-None mask would not line up -- confirm masks are
                never used here.
            temb: passed to ``attn.spatial_norm`` when that norm exists.
            num_tasks: must be 2.

        Returns:
            A tensor with the same batch/token layout as the input, reshaped back
            to 4-D when the input was 4-D.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        assert num_tasks == 2  # only support two tasks now

        def stack_tasks(tensor):
            # (2bv, hw, c) -> (bv, 2hw, c): first task's tokens, then second's.
            first_task, second_task = torch.chunk(tensor, dim=0, chunks=2)  # bv hw c
            return torch.cat([first_task, second_task], dim=1)  # bv 2hw c

        key = stack_tasks(key)
        value = stack_tasks(value)
        query = stack_tasks(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo stack_tasks: split the 2hw token axis into the two task halves and
        # restack them along the batch axis, (bv, 2hw, c) -> (2bv, hw, c).
        # Indexing hidden_states[:, 0] would select a single token instead.
        first_task, second_task = torch.chunk(hidden_states, dim=1, chunks=2)  # bv hw c
        hidden_states = torch.cat([first_task, second_task], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states