Spaces:
Runtime error
Runtime error
File size: 10,821 Bytes
149cc2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
from typing import Tuple, Union
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict
from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.utils import BaseOutput
from .flax_unet_pseudo3d_blocks import (
CrossAttnDownBlockPseudo3D,
CrossAttnUpBlockPseudo3D,
DownBlockPseudo3D,
UpBlockPseudo3D,
UNetMidBlockPseudo3DCrossAttn
)
#from flax_embeddings import (
# TimestepEmbedding,
# Timesteps
#)
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .flax_resnet_pseudo3d import ConvPseudo3D
class UNetPseudo3DConditionOutput(BaseOutput):
    """Output container returned by ``UNetPseudo3DConditionModel.__call__``.

    It carries a single field, ``sample``, and is only produced when the
    model is called with ``return_dict = True``.
    """
    # Model prediction; the model transposes it back to b,c,f,h,w before
    # wrapping it in this container.
    sample: jax.Array
@flax_register_to_config
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    """Conditional UNet built from pseudo-3D down, mid and up blocks.

    The model receives a sample laid out as b,c,f,h,w, a batch of timesteps
    and encoder hidden states for cross attention, and returns a prediction
    in the same b,c,f,h,w layout. Internally everything runs channels-last
    (b,f,h,w,c).

    Block type names ending in ``2D`` are accepted as aliases for their
    ``Pseudo3D`` counterparts so that configs written with the old layer
    names still load.
    """
    # Spatial size (height, width) of the dummy sample used by `init_weights`;
    # a single int is expanded to a square.
    sample_size: Union[int, Tuple[int, int]] = (64, 64)
    # Channel count of the dummy input sample used by `init_weights`.
    in_channels: int = 4
    # Number of features produced by the final convolution.
    out_channels: int = 4
    # One entry per down block, in execution order.
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "DownBlockPseudo3D"
    )
    # One entry per up block, in execution order.
    up_block_types: Tuple[str, ...] = (
        "UpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D"
    )
    # Output channels of each down block; traversed in reverse for up blocks.
    block_out_channels: Tuple[int, ...] = (
        320,
        640,
        1280,
        1280
    )
    # Layers per down block; up blocks are built with one more layer.
    layers_per_block: int = 2
    # Passed to the attention blocks as `attn_num_head_channels`; an int is
    # repeated once per down block.
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    # Last dimension of the dummy encoder hidden states in `init_weights`.
    cross_attention_dim: int = 768
    # Forwarded unchanged to `FlaxTimesteps`.
    flip_sin_to_cos: bool = True
    # Forwarded unchanged to `FlaxTimesteps`.
    freq_shift: int = 0
    # Forwarded unchanged to every cross-attention block.
    use_memory_efficient_attention: bool = False
    # Forwarded as `dtype` to the convolutions, time embedding and blocks.
    dtype: jnp.dtype = jnp.float32
    # Name of the dtype used for the dummy inputs in `init_weights`;
    # one of 'bfloat16', 'float16' or 'float32'.
    param_dtype: str = 'float32'

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        """Initialise the model parameters by tracing it on dummy inputs.

        Args:
            rng: PRNG key. It is split into one key for the "params" stream
                and one for the "dropout" stream. Annotated as `jax.Array`
                because the `jax.random.KeyArray` alias no longer exists in
                current JAX releases and would fail when this class is
                defined.

        Returns:
            The "params" collection produced by `self.init`.

        Raises:
            ValueError: if `self.param_dtype` is not one of the supported
                dtype names.
        """
        supported_dtypes = {
            'bfloat16': jnp.bfloat16,
            'float16': jnp.float16,
            'float32': jnp.float32
        }
        param_dtype = supported_dtypes.get(self.param_dtype)
        if param_dtype is None:
            raise ValueError(f'unknown parameter type: {self.param_dtype}')
        sample_size = self.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        # dummy inputs: one batch element, one frame, one conditioning token
        sample_shape = (1, self.in_channels, 1, *sample_size)
        sample = jnp.zeros(sample_shape, dtype = param_dtype)
        timesteps = jnp.ones((1, ), dtype = jnp.int32)
        encoder_hidden_states = jnp.zeros(
                (1, 1, self.cross_attention_dim),
                dtype = param_dtype
        )
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = { "params": params_rng, "dropout": dropout_rng }
        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self) -> None:
        """Build the submodules described by the configuration fields.

        Raises:
            NotImplementedError: if a name in `down_block_types` or
                `up_block_types` is not a known block type.
        """
        if isinstance(self.attention_head_dim, int):
            attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
        else:
            attention_head_dim = self.attention_head_dim
        time_embed_dim = self.block_out_channels[0] * 4
        # index of the last resolution level; that block does no resampling
        final_block_index = len(self.block_out_channels) - 1
        self.conv_in = ConvPseudo3D(
                features = self.block_out_channels[0],
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )
        self.time_proj = FlaxTimesteps(
                dim = self.block_out_channels[0],
                flip_sin_to_cos = self.flip_sin_to_cos,
                freq_shift = self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(
                time_embed_dim = time_embed_dim,
                dtype = self.dtype
        )

        down_blocks = []
        output_channels = self.block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            # each block consumes what the previous one produced
            input_channels = output_channels
            output_channels = self.block_out_channels[i]
            is_final_block = i == final_block_index
            # allows loading 3d models with old layer type names in their configs
            # eg. 2D instead of Pseudo3D, like lxj's timelapse model
            if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
                down_block = CrossAttnDownBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        num_layers = self.layers_per_block,
                        attn_num_head_channels = attention_head_dim[i],
                        add_downsample = not is_final_block,
                        use_memory_efficient_attention = self.use_memory_efficient_attention,
                        dtype = self.dtype
                )
            elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
                down_block = DownBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        num_layers = self.layers_per_block,
                        add_downsample = not is_final_block,
                        dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        self.mid_block = UNetMidBlockPseudo3DCrossAttn(
                in_channels = self.block_out_channels[-1],
                attn_num_head_channels = attention_head_dim[-1],
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )

        up_blocks = []
        # the up path walks the resolution levels in the opposite order
        reversed_block_out_channels = list(reversed(self.block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            # channel count of the next level, clamped at the last level
            input_channels = reversed_block_out_channels[min(i + 1, final_block_index)]
            is_final_block = i == final_block_index
            if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
                up_block = CrossAttnUpBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        prev_output_channels = prev_output_channels,
                        num_layers = self.layers_per_block + 1,
                        attn_num_head_channels = reversed_attention_head_dim[i],
                        add_upsample = not is_final_block,
                        use_memory_efficient_attention = self.use_memory_efficient_attention,
                        dtype = self.dtype
                )
            elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
                up_block = UpBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        prev_output_channels = prev_output_channels,
                        num_layers = self.layers_per_block + 1,
                        add_upsample = not is_final_block,
                        dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
            up_blocks.append(up_block)
        self.up_blocks = up_blocks

        self.conv_norm_out = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.conv_out = ConvPseudo3D(
                features = self.out_channels,
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )

    def __call__(self,
            sample: jax.Array,
            timesteps: jax.Array,
            encoder_hidden_states: jax.Array,
            return_dict: bool = True
    ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
        """Run the UNet on a sample.

        Args:
            sample: input laid out as b,c,f,h,w.
            timesteps: timestep array; cast to float32 and, when 0-d,
                expanded to one dimension.
            encoder_hidden_states: conditioning passed to every
                cross-attention block and to the mid block.
            return_dict: when False, return a 1-tuple instead of
                `UNetPseudo3DConditionOutput`.

        Returns:
            `UNetPseudo3DConditionOutput` holding the result laid out as
            b,c,f,h,w, or `(sample, )` when `return_dict` is False.

        Raises:
            NotImplementedError: if a block of an unexpected class is found
                in `down_blocks` or `up_blocks`.
        """
        if timesteps.dtype != jnp.float32:
            timesteps = timesteps.astype(dtype = jnp.float32)
        if len(timesteps.shape) == 0:
            timesteps = jnp.expand_dims(timesteps, 0)
        # b,c,f,h,w -> b,f,h,w,c
        sample = sample.transpose((0, 2, 3, 4, 1))

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)
        sample = self.conv_in(sample)

        # skip connections, oldest first; seeded with the conv_in output
        down_block_res_samples = (sample, )
        for down_block in self.down_blocks:
            if isinstance(down_block, CrossAttnDownBlockPseudo3D):
                sample, res_samples = down_block(
                        hidden_states = sample,
                        temb = t_emb,
                        encoder_hidden_states = encoder_hidden_states
                )
            elif isinstance(down_block, DownBlockPseudo3D):
                sample, res_samples = down_block(
                        hidden_states = sample,
                        temb = t_emb
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
            down_block_res_samples += res_samples

        sample = self.mid_block(
                hidden_states = sample,
                temb = t_emb,
                encoder_hidden_states = encoder_hidden_states
        )

        # every up block has layers_per_block + 1 layers (see setup) and is
        # handed that many of the most recent skip connections
        skips_per_block = self.layers_per_block + 1
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-skips_per_block:]
            down_block_res_samples = down_block_res_samples[:-skips_per_block]
            if isinstance(up_block, CrossAttnUpBlockPseudo3D):
                sample = up_block(
                        hidden_states = sample,
                        temb = t_emb,
                        encoder_hidden_states = encoder_hidden_states,
                        res_hidden_states_tuple = res_samples
                )
            elif isinstance(up_block, UpBlockPseudo3D):
                sample = up_block(
                        hidden_states = sample,
                        temb = t_emb,
                        res_hidden_states_tuple = res_samples
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')

        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # b,f,h,w,c -> b,c,f,h,w
        sample = sample.transpose((0, 4, 1, 2, 3))
        if not return_dict:
            return (sample, )
        return UNetPseudo3DConditionOutput(sample = sample)
|