from typing import Tuple, Union

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.utils import BaseOutput
from .flax_unet_pseudo3d_blocks import (
    CrossAttnDownBlockPseudo3D,
    CrossAttnUpBlockPseudo3D,
    DownBlockPseudo3D,
    UpBlockPseudo3D,
    UNetMidBlockPseudo3DCrossAttn
)
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .flax_resnet_pseudo3d import ConvPseudo3D


@flax.struct.dataclass
class UNetPseudo3DConditionOutput(BaseOutput):
    sample: jax.Array


@flax_register_to_config
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    sample_size: Union[int, Tuple[int, int]] = (64, 64)
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "CrossAttnDownBlockPseudo3D",
        "DownBlockPseudo3D"
    )
    up_block_types: Tuple[str, ...] = (
        "UpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D",
        "CrossAttnUpBlockPseudo3D"
    )
    block_out_channels: Tuple[int, ...] = (
        320,
        640,
        1280,
        1280
    )
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    cross_attention_dim: int = 768
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    param_dtype: str = 'float32'

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        if self.param_dtype == 'bfloat16':
            param_dtype = jnp.bfloat16
        elif self.param_dtype == 'float16':
            param_dtype = jnp.float16
        elif self.param_dtype == 'float32':
            param_dtype = jnp.float32
        else:
            raise ValueError(f'unknown parameter type: {self.param_dtype}')
        sample_size = self.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        # dummy inputs: (batch, channels, frames, height, width)
        sample_shape = (1, self.in_channels, 1, *sample_size)
        sample = jnp.zeros(sample_shape, dtype = param_dtype)
        timesteps = jnp.ones((1, ), dtype = jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = { "params": params_rng, "dropout": dropout_rng }
        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
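
    # A minimal initialization sketch (hypothetical usage, not part of the model):
    #
    #     model = UNetPseudo3DConditionModel()
    #     params = model.init_weights(jax.random.PRNGKey(0))
    #
    # The resulting `params` tree is what `model.apply` expects; see the example
    # at the bottom of this file.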

    def setup(self) -> None:
        if isinstance(self.attention_head_dim, int):
            attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
        else:
            attention_head_dim = self.attention_head_dim
        time_embed_dim = self.block_out_channels[0] * 4
        self.conv_in = ConvPseudo3D(
            features = self.block_out_channels[0],
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        # sinusoidal timestep projection followed by a learned MLP embedding
        self.time_proj = FlaxTimesteps(
            dim = self.block_out_channels[0],
            flip_sin_to_cos = self.flip_sin_to_cos,
            freq_shift = self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(
            time_embed_dim = time_embed_dim,
            dtype = self.dtype
        )
        down_blocks = []
        output_channels = self.block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channels = output_channels
            output_channels = self.block_out_channels[i]
            is_final_block = i == len(self.block_out_channels) - 1
            # allows loading 3d models with old layer type names in their configs,
            # e.g. 2D instead of Pseudo3D, like lxj's timelapse model
            if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
                down_block = CrossAttnDownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    attn_num_head_channels = attention_head_dim[i],
                    add_downsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
                down_block = DownBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    num_layers = self.layers_per_block,
                    add_downsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
            down_blocks.append(down_block)
        self.down_blocks = down_blocks
        self.mid_block = UNetMidBlockPseudo3DCrossAttn(
            in_channels = self.block_out_channels[-1],
            attn_num_head_channels = attention_head_dim[-1],
            use_memory_efficient_attention = self.use_memory_efficient_attention,
            dtype = self.dtype
        )
        up_blocks = []
        # the decoder mirrors the encoder, so channel and head lists are reversed
        reversed_block_out_channels = list(reversed(self.block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
            is_final_block = i == len(self.block_out_channels) - 1
            if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
                up_block = CrossAttnUpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    attn_num_head_channels = reversed_attention_head_dim[i],
                    add_upsample = not is_final_block,
                    use_memory_efficient_attention = self.use_memory_efficient_attention,
                    dtype = self.dtype
                )
            elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
                up_block = UpBlockPseudo3D(
                    in_channels = input_channels,
                    out_channels = output_channels,
                    prev_output_channels = prev_output_channels,
                    num_layers = self.layers_per_block + 1,
                    add_upsample = not is_final_block,
                    dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
            up_blocks.append(up_block)
        self.up_blocks = up_blocks
        self.conv_norm_out = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv_out = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
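
    # With the default config (block_out_channels = (320, 640, 1280, 1280)) the
    # channel flow works out as follows:
    #
    #     conv_in:  4 -> 320
    #     down:     320 -> 320 -> 640 -> 1280 -> 1280  (downsample after all but the last block)
    #     up:       1280 -> 1280 -> 640 -> 320         (upsample after all but the last block)
    #     conv_out: 320 -> 4
    #
    # Each up block consumes layers_per_block + 1 skip tensors popped off the
    # residual stack that the down blocks push in __call__ below.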

    def __call__(
        self,
        sample: jax.Array,
        timesteps: jax.Array,
        encoder_hidden_states: jax.Array,
        return_dict: bool = True
    ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
        if timesteps.dtype != jnp.float32:
            timesteps = timesteps.astype(dtype = jnp.float32)
        if len(timesteps.shape) == 0:
            # broadcast a scalar timestep to a batch of one
            timesteps = jnp.expand_dims(timesteps, 0)
        # b,c,f,h,w -> b,f,h,w,c (channels-last, as Flax convolutions expect)
        sample = sample.transpose((0, 2, 3, 4, 1))
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)
        sample = self.conv_in(sample)
        down_block_res_samples = (sample, )
        for down_block in self.down_blocks:
            if isinstance(down_block, CrossAttnDownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states
                )
            elif isinstance(down_block, DownBlockPseudo3D):
                sample, res_samples = down_block(
                    hidden_states = sample,
                    temb = t_emb
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
            down_block_res_samples += res_samples
        sample = self.mid_block(
            hidden_states = sample,
            temb = t_emb,
            encoder_hidden_states = encoder_hidden_states
        )
        for up_block in self.up_blocks:
            # pop layers_per_block + 1 skip connections for this resolution level
            res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
            down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
            if isinstance(up_block, CrossAttnUpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    encoder_hidden_states = encoder_hidden_states,
                    res_hidden_states_tuple = res_samples
                )
            elif isinstance(up_block, UpBlockPseudo3D):
                sample = up_block(
                    hidden_states = sample,
                    temb = t_emb,
                    res_hidden_states_tuple = res_samples
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
        # b,f,h,w,c -> b,c,f,h,w
        sample = sample.transpose((0, 4, 1, 2, 3))
        if not return_dict:
            return (sample, )
        return UNetPseudo3DConditionOutput(sample = sample)
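

# A minimal forward-pass sketch, not part of the model. The shapes below are
# illustrative assumptions: 4 latent channels, 8 frames, 64x64 latents, and
# CLIP-sized text states (77 tokens, 768 dims). Because of the relative imports
# above, run this as a module from inside the package (python -m ...).
if __name__ == '__main__':
    model = UNetPseudo3DConditionModel(sample_size = (64, 64))
    params = model.init_weights(jax.random.PRNGKey(0))
    sample = jnp.zeros((1, 4, 8, 64, 64), dtype = jnp.float32)  # b,c,f,h,w
    timesteps = jnp.array([999], dtype = jnp.int32)
    encoder_hidden_states = jnp.zeros((1, 77, 768), dtype = jnp.float32)
    out = model.apply(
        { 'params': params },
        sample,
        timesteps,
        encoder_hidden_states,
        rngs = { 'dropout': jax.random.PRNGKey(1) }
    )
    # the output keeps the input layout: b,c,f,h,w
    assert out.sample.shape == sample.shape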