from typing import Optional
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
#from flax_memory_efficient_attention import jax_memory_efficient_attention
#from flax_attention import FlaxAttention
from diffusers.models.attention_flax import FlaxAttention
class TransformerPseudo3DModel(nn.Module):
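    """Spatio-temporal transformer stack operating on channels-last feature maps.

    Accepts image inputs of shape (b, h, w, c) or video inputs of shape
    (b, f, h, w, c). Video frames are folded into the batch dimension for the
    spatial layers, and the frame count is forwarded to each block so it can
    additionally run temporal attention.
    """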
in_channels: int
num_attention_heads: int
attention_head_dim: int
num_layers: int = 1
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
inner_dim = self.num_attention_heads * self.attention_head_dim
self.norm = nn.GroupNorm(
num_groups = 32,
epsilon = 1e-5
)
self.proj_in = nn.Conv(
inner_dim,
kernel_size = (1, 1),
strides = (1, 1),
padding = 'VALID',
dtype = self.dtype
)
transformer_blocks = []
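        # gradient checkpointing (the nn.checkpoint wrapper below) is currently disabled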
#CheckpointTransformerBlock = nn.checkpoint(
# BasicTransformerBlockPseudo3D,
# static_argnums = (2,3,4)
# #prevent_cse = False
#)
CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
for _ in range(self.num_layers):
transformer_blocks.append(CheckpointTransformerBlock(
dim = inner_dim,
num_attention_heads = self.num_attention_heads,
attention_head_dim = self.attention_head_dim,
use_memory_efficient_attention = self.use_memory_efficient_attention,
dtype = self.dtype
))
self.transformer_blocks = transformer_blocks
self.proj_out = nn.Conv(
inner_dim,
kernel_size = (1, 1),
strides = (1, 1),
padding = 'VALID',
dtype = self.dtype
)
def __call__(self,
hidden_states: jax.Array,
encoder_hidden_states: Optional[jax.Array] = None
) -> jax.Array:
is_video = hidden_states.ndim == 5
f: Optional[int] = None
if is_video:
            # JAX/Flax feature maps are channels-last: video input arrives as
            # (b, f, h, w, c), not the PyTorch-style (b, c, f, h, w)
b, f, h, w, c = hidden_states.shape
hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states,
f,
height,
width
)
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
if is_video:
hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
return hidden_states
class BasicTransformerBlockPseudo3D(nn.Module):
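    """Transformer block with self attention, cross attention against the
    encoder hidden states, an optional temporal attention pass over the frame
    axis (pseudo-3D attention), and a feed-forward layer.
    """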
dim: int
num_attention_heads: int
attention_head_dim: int
use_memory_efficient_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.attn1 = FlaxAttention(
query_dim = self.dim,
heads = self.num_attention_heads,
dim_head = self.attention_head_dim,
use_memory_efficient_attention = self.use_memory_efficient_attention,
dtype = self.dtype
)
self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
self.attn2 = FlaxAttention(
query_dim = self.dim,
heads = self.num_attention_heads,
dim_head = self.attention_head_dim,
use_memory_efficient_attention = self.use_memory_efficient_attention,
dtype = self.dtype
)
self.attn_temporal = FlaxAttention(
query_dim = self.dim,
heads = self.num_attention_heads,
dim_head = self.attention_head_dim,
use_memory_efficient_attention = self.use_memory_efficient_attention,
dtype = self.dtype
)
self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
def __call__(self,
hidden_states: jax.Array,
context: Optional[jax.Array] = None,
frames_length: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None
) -> jax.Array:
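        # hidden_states arrive as (batch, height * width, channels) tokens;
        # for video inputs the frame axis is already folded into the batch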
if context is not None and frames_length is not None:
context = context.repeat(frames_length, axis = 0)
# self attention
norm_hidden_states = self.norm1(hidden_states)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# cross attention
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn2(
norm_hidden_states,
context = context
) + hidden_states
# temporal attention
if frames_length is not None:
            # fold spatial positions into the batch so attention runs over frames:
            # (b f) (h w) c -> (b h w) f c
hidden_states = einops.rearrange(
hidden_states,
'(b f) (h w) c -> (b h w) f c',
f = frames_length,
h = height,
w = width
)
norm_hidden_states = self.norm_temporal(hidden_states)
hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
            # restore the spatial token layout:
            # (b h w) f c -> (b f) (h w) c
hidden_states = einops.rearrange(
hidden_states,
'(b h w) f c -> (b f) (h w) c',
f = frames_length,
h = height,
w = width
)
norm_hidden_states = self.norm3(hidden_states)
hidden_states = self.ff(norm_hidden_states) + hidden_states
return hidden_states
class FeedForward(nn.Module):
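    """GEGLU projection followed by a linear projection back to `dim`."""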
dim: int
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.net_0 = GEGLU(self.dim, self.dtype)
self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
def __call__(self, hidden_states: jax.Array) -> jax.Array:
hidden_states = self.net_0(hidden_states)
hidden_states = self.net_2(hidden_states)
return hidden_states
class GEGLU(nn.Module):
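    """Gated GELU: projects to twice the hidden width (4 * dim), then gates
    one half with GELU of the other half."""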
dim: int
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
def __call__(self, hidden_states: jax.Array) -> jax.Array:
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
return hidden_linear * nn.gelu(hidden_gelu)
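
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it only shows how
    # the model can be initialised and applied to a dummy channels-last video
    # batch. All sizes below (channels, heads, frame count, resolution, context
    # length/width) are illustrative assumptions, not values taken from the
    # surrounding project.
    model = TransformerPseudo3DModel(
        in_channels = 64,
        num_attention_heads = 8,
        attention_head_dim = 8
    )
    rng = jax.random.PRNGKey(0)
    video = jnp.zeros((1, 4, 16, 16, 64), dtype = jnp.float32)  # b, f, h, w, c (channels last)
    context = jnp.zeros((1, 77, 768), dtype = jnp.float32)      # e.g. text encoder hidden states
    params = model.init(rng, video, context)
    out = model.apply(params, video, context)
    print(out.shape)  # (1, 4, 16, 16, 64)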