import math

import torch
import torch.nn.functional as F
from einops import rearrange

from diffusers.models.attention import Attention
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
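
# The classes below provide the attention variants used for multi-view /
# cross-domain generation: row-wise attention across views, joint
# normal+color ("cross-domain") attention, and IP-Adapter style image-prompt
# cross-attention, each paired with an xformers-backed processor.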


class RowwiseMVAttention(Attention):
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Only an xformers-backed processor is provided for this layer, so the
        # flag is ignored and the row-wise multi-view processor is installed.
        processor = XFormersMVAttnProcessor()
        self.set_processor(processor)


class IPCDAttention(Attention):
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # As above, the flag is ignored; the cross-domain processor with
        # image-prompt (face) handling is always installed.
        processor = XFormersIPCDAttnProcessor()
        self.set_processor(processor)


class XFormersMVAttnProcessor:
    r"""
    Row-wise multi-view attention processor backed by xformers memory-efficient
    attention. Tokens that share the same image row attend across all views
    (and across the two domains when `cd_attention_mid` is set).
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,  # unused below; kept in the signature
        cd_attention_mid=False,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Assumes a square spatial feature map (sequence_length == height * width).
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # xformers expects the mask to cover every query token, so expand
            # the singleton query dimension from prepare_attention_mask.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: group norm is applied before the row-wise rearrange; verify this is intended')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key_raw = attn.to_k(encoder_hidden_states)
        value_raw = attn.to_v(encoder_hidden_states)

        def transpose(tensor):
            # Interleave the two batch halves (the two generation domains)
            # along the width axis so their rows attend jointly.
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)
            tensor = torch.cat([tensor_0, tensor_1], dim=3)  # concat along width
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor

        if cd_attention_mid:
            key = transpose(key_raw)
            value = transpose(value_raw)
            query = transpose(query)
        else:
            key = rearrange(key_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value_raw, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
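
        # Every batch element now holds one image row concatenated across all
        # views ("(b h) (v w) c"), so the attention below is row-wise.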

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the row-wise packing (and the cross-domain interleave when
        # cd_attention_mid is set) to restore the per-view token layout.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
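
# Shape walkthrough for XFormersMVAttnProcessor (illustrative numbers): with
# num_views = 6 and a 32x32 feature map, hidden_states enters as
# (b*6, 1024, c); the row-wise rearrange yields (b*32, 6*32, c), so each query
# attends to the matching row of all six views.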


class XFormersIPCDAttnProcessor:
    r"""
    Cross-domain attention processor with an extra image-prompt (face) view.
    The normal and color batches attend jointly, and the face view is then
    blended back into the front view of each domain.
    """

    def process(self,
                attn: Attention,
                hidden_states,
                encoder_hidden_states=None,
                attention_mask=None,
                temb=None,
                num_tasks=2,
                num_views=6):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Assumes a square feature map whose side is divisible by 3; the face
        # patch is pasted into the top-center (height_st x height_st) region.
        height = int(math.sqrt(sequence_length))
        height_st = height // 3
        height_end = height - height_st

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attention_mask is not None:
            # Expand the singleton query dimension for xformers.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Only two domains (normal + color) are supported.
        assert num_tasks == 2

        def transpose(tensor):
            # Concatenate the normal and color halves along the token axis so
            # the two domains attend jointly.
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)
            tensor = torch.cat([tensor_0, tensor_1], dim=1)
            return tensor

        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split the joint token axis back into the two domains.
        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, dim=1, chunks=2)

        # Blend the (detached) face view into the top-center crop of the front
        # view: resize the face features to the crop size and average 50/50.
        hidden_states_normal = rearrange(hidden_states_normal, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
        face_normal = rearrange(hidden_states_normal[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_normal = rearrange(F.interpolate(face_normal, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_normal = hidden_states_normal.clone()
        hidden_states_normal[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_normal[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_normal
        hidden_states_normal = rearrange(hidden_states_normal, "b v h w c -> (b v) (h w) c")

        # Same blending for the color domain.
        hidden_states_color = rearrange(hidden_states_color, "(b v) (h w) c -> b v h w c", v=num_views+1, h=height)
        face_color = rearrange(hidden_states_color[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        face_color = rearrange(F.interpolate(face_color, size=(height_st, height_st), mode='bilinear'), 'b c h w -> b h w c')
        hidden_states_color = hidden_states_color.clone()
        hidden_states_color[:, 0, :height_st, height_st:height_end, :] = 0.5 * hidden_states_color[:, 0, :height_st, height_st:height_end, :] + 0.5 * face_color
        hidden_states_color = rearrange(hidden_states_color, "b v h w c -> (b v) (h w) c")

        # Re-stack the two domains along the batch dimension.
        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        hidden_states = self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
        return hidden_states
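
# Illustrative batch layout for XFormersIPCDAttnProcessor: with num_views = 6
# plus one face view, the incoming batch holds 2 * b * 7 feature maps, normals
# first and colors second, with the face view last within each group of 7.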


class IPCrossAttn(Attention):
    r"""
    Attention layer for IP-Adapter style image-prompt conditioning.

    Args:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        ip_scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
    """

    def __init__(self,
                 query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
        super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention)

        self.ip_scale = ip_scale

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Only an xformers-backed processor is provided, so the flag is ignored.
        processor = XFormersIPCrossAttnProcessor()
        self.set_processor(processor)


class XFormersIPCrossAttnProcessor:
    r"""
    Cross-attention processor for image-prompt conditioning, backed by xformers
    memory-efficient attention.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,  # unused below; kept in the signature
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # Match the sibling processors: xformers needs the mask expanded
            # over the query tokens.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
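
# A minimal wiring sketch (hypothetical dims): construct
# IPCrossAttn(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40,
# dropout=0.0, bias=False, upcast_attention=False) in place of a UNet
# cross-attention layer; enabling memory-efficient attention on the model then
# installs XFormersIPCrossAttnProcessor.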


class RowwiseMVProcessor:
    r"""
    Plain-PyTorch fallback for row-wise multi-view attention; mirrors
    XFormersMVAttnProcessor but computes standard attention scores.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        cd_attention_mid=False,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Assumes a square spatial feature map.
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def transpose(tensor):
            # Interleave the two batch halves along the width axis so their
            # rows attend jointly (cross-domain row-wise attention).
            tensor = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)
            tensor = torch.cat([tensor_0, tensor_1], dim=3)
            tensor = rearrange(tensor, "b v h w c -> (b h) (v w) c", v=num_views, h=height)
            return tensor

        if cd_attention_mid:
            key = transpose(key)
            value = transpose(value)
            query = transpose(query)
        else:
            key = rearrange(key, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            value = rearrange(value, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            query = rearrange(query, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        # Standard attention path (no xformers): softmax(Q K^T / sqrt(d)) V.
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the row-wise packing to restore the per-view token layout.
        if cd_attention_mid:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            hidden_states_0, hidden_states_1 = torch.chunk(hidden_states, dim=3, chunks=2)
            hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=0)
            hidden_states = rearrange(hidden_states, "b v h w c -> (b v) (h w) c", v=num_views, h=height)
        else:
            hidden_states = rearrange(hidden_states, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CDAttention(Attention):
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # Only an xformers-backed processor is provided, so the flag is ignored.
        processor = XFormersCDAttnProcessor()
        self.set_processor(processor)


class XFormersCDAttnProcessor:
    r"""
    Cross-domain attention processor: the two task batches (e.g. normal and
    color) are joined along the token axis so they attend to each other.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Only two tasks/domains are supported.
        assert num_tasks == 2

        def transpose(tensor):
            # Concatenate the two batch halves along the token axis.
            tensor_0, tensor_1 = torch.chunk(tensor, dim=0, chunks=2)
            tensor = torch.cat([tensor_0, tensor_1], dim=1)
            return tensor

        key = transpose(key)
        value = transpose(value)
        query = transpose(query)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Split the joint token axis back into the two tasks and re-stack them
        # along the batch dimension; without this rearrange the indexing below
        # would pick single tokens instead of whole sequences.
        hidden_states = rearrange(hidden_states, "b (t l) c -> b t l c", t=2)
        hidden_states = torch.cat([hidden_states[:, 0], hidden_states[:, 1]], dim=0)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
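

if __name__ == "__main__":
    # Minimal smoke test; the sizes below are illustrative assumptions, not
    # values from the model. Uses the plain-PyTorch row-wise processor so it
    # runs without xformers: 2 views of an 8x8 feature map, 32 channels.
    attn = RowwiseMVAttention(query_dim=32, heads=4, dim_head=8)
    attn.set_processor(RowwiseMVProcessor())
    x = torch.randn(2 * 2, 8 * 8, 32)  # (batch * num_views, tokens, channels)
    out = attn(x, num_views=2)
    print(out.shape)  # expected: torch.Size([4, 64, 32])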