kohya_ss / library /hypernetwork.py
Ateras's picture
Upload folder using huggingface_hub
fe6327d
raw
history blame
No virus
6.74 kB
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
Attention,
AttnProcessor2_0,
SlicedAttnProcessor,
XFormersAttnProcessor
)
try:
import xformers.ops
except:
xformers = None
loaded_networks = []
def apply_single_hypernetwork(
hypernetwork, hidden_states, encoder_hidden_states
):
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
return context_k, context_v
def apply_hypernetworks(context_k, context_v, layer=None):
if len(loaded_networks) == 0:
return context_v, context_v
for hypernetwork in loaded_networks:
context_k, context_v = hypernetwork.forward(context_k, context_v)
context_k = context_k.to(dtype=context_k.dtype)
context_v = context_v.to(dtype=context_k.dtype)
return context_k, context_v
def xformers_forward(
self: XFormersAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def sliced_attn_forward(
self: SlicedAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads),
device=query.device,
dtype=query.dtype,
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = (
attention_mask[start_idx:end_idx] if attention_mask is not None else None
)
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def v2_0_forward(
self: AttnProcessor2_0,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def replace_attentions_for_hypernetwork():
import diffusers.models.attention_processor
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
xformers_forward
)
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
sliced_attn_forward
)
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward