# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

import math
from functools import partial
from typing import Optional, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .modeling_flax_utils import FlaxModelMixin


@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
    """
    Output of decoding method.

    Args:
        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
    """

    sample: jnp.ndarray


@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`FlaxDiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
            `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    latent_dist: "FlaxDiagonalGaussianDistribution"


class FlaxUpsample2D(nn.Module):
    """
    Flax implementation of 2D Upsample layer

    Doubles the spatial size with nearest-neighbour resizing, then applies a 3x3 convolution.

    Args:
        in_channels (`int`):
            Input channels (also used as the number of output channels of the convolution)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # channels-last input: (batch, height, width, channels)
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    """
    Flax implementation of 2D Downsample layer

    Pads one row/column at the end of the height and width axes, then applies a stride-2 3x3 "VALID" convolution.

    Args:
        in_channels (`int`):
            Input channels (also used as the number of output channels of the convolution)
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Asymmetric zero padding on the height and width dims only, so that the
        # un-padded strided convolution halves an even spatial size exactly.
        pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxResnetBlock2D(nn.Module):
    """
    Flax implementation of 2D Resnet Block.

    Args:
        in_channels (`int`):
            Input channels
        out_channels (`int`, *optional*, defaults to `None`):
            Output channels. When `None`, `in_channels` is used.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm.
        use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
            Whether to use `nin_shortcut`. This activates a new layer inside ResNet block. When `None`, the
            shortcut convolution is used only if `in_channels` differs from the output channels.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: Optional[int] = None
    dropout: float = 0.0
    groups: int = 32
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
        self.dropout_layer = nn.Dropout(self.dropout)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        # 1x1 projection of the residual, only created when requested / needed
        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual


class FlaxAttentionBlock(nn.Module):
    r"""
    Flax multi-head self-attention block over spatial positions for diffusion-based VAE.

    Parameters:
        channels (:obj:`int`):
            Input channels
        num_head_channels (:obj:`int`, *optional*, defaults to `None`):
            Number of channels per attention head. The number of heads is `channels // num_head_channels`; when
            `None`, a single head is used.
        num_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for group norm
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`

    """

    channels: int
    num_head_channels: Optional[int] = None
    num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

        dense = partial(nn.Dense, self.channels, dtype=self.dtype)

        self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
        self.query, self.key, self.value = dense(), dense(), dense()
        self.proj_attn = dense()

    def transpose_for_scores(self, projection):
        """Split the last axis into heads and move the head axis before the token axis."""
        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
        # split the channel axis into heads: (B, T, H * D) -> (B, T, H, D)
        new_projection = projection.reshape(new_projection_shape)
        # (B, T, H, D) -> (B, H, T, D)
        new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
        return new_projection

    def __call__(self, hidden_states):
        residual = hidden_states
        batch, height, width, channels = hidden_states.shape

        hidden_states = self.group_norm(hidden_states)

        # flatten the spatial grid into a sequence of height * width tokens
        hidden_states = hidden_states.reshape((batch, height * width, channels))

        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)

        # transpose
        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        # compute attentions
        # The scale is applied to both query and key, so their product is scaled by 1 / sqrt(channels per head).
        scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)

        # attend to values
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)

        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H * D)
        hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
        hidden_states = hidden_states.reshape(new_hidden_states_shape)

        hidden_states = self.proj_attn(hidden_states)
        hidden_states = hidden_states.reshape((batch, height, width, channels))
        hidden_states = hidden_states + residual
        return hidden_states


class FlaxDownEncoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Encoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of Resnet layer block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add downsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            # only the first resnet changes the channel count
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)

        return hidden_states


class FlaxUpDecoderBlock2D(nn.Module):
    r"""
    Flax Resnet blocks-based Decoder block for diffusion-based VAE.

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        out_channels (:obj:`int`):
            Output channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of Resnet layer block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet block group norm
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to add upsample layer
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: int = 32
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnets = []
        for i in range(self.num_layers):
            # only the first resnet changes the channel count
            in_channels = self.in_channels if i == 0 else self.out_channels
            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                groups=self.resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2D(nn.Module):
    r"""
    Flax Unet Mid-Block module.

    Consists of one Resnet block followed by `num_layers` pairs of (attention block, Resnet block).

    Parameters:
        in_channels (:obj:`int`):
            Input channels
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (attention, Resnet) pairs added after the first Resnet block
        resnet_groups (:obj:`int`, *optional*, defaults to `32`):
            The number of groups to use for the Resnet and Attention block group norm. When `None`,
            `min(in_channels // 4, 32)` is used.
        num_attention_heads (:obj:`int`, *optional*, defaults to `1`):
            Forwarded to each attention block as its `num_head_channels` argument; `None` results in a single
            attention head.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    resnet_groups: Optional[int] = 32
    num_attention_heads: Optional[int] = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxAttentionBlock(
                channels=self.in_channels,
                num_head_channels=self.num_attention_heads,
                num_groups=resnet_groups,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                groups=resnet_groups,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
        # remaining resnets are paired one-to-one with the attention blocks
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states)
            hidden_states = resnet(hidden_states, deterministic=deterministic)

        return hidden_states


class FlaxEncoder(nn.Module):
    r"""
    Flax Implementation of VAE Encoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            DownEncoder block type. One `FlaxDownEncoderBlock2D` is created per entry.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Number of Resnet layer for each block
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. The value is stored but not read by this module; `nn.swish` is always applied.
        double_z (:obj:`bool`, *optional*, defaults to `False`):
            Whether to double the last output channels
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    double_z: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels
        # in
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # downsampling
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, _ in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # the last block keeps the spatial resolution
            is_final_block = i == len(block_out_channels) - 1

            down_block = FlaxDownEncoderBlock2D(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block,
                resnet_groups=self.norm_num_groups,
                add_downsample=not is_final_block,
                dtype=self.dtype,
            )
            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # end
        conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            conv_out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # in
        sample = self.conv_in(sample)

        # downsampling
        for block in self.down_blocks:
            sample = block(sample, deterministic=deterministic)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # end
        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample


class FlaxDecoder(nn.Module):
    r"""
    Flax Implementation of VAE Decoder.

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:
    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (:obj:`int`, *optional*, defaults to 3):
            Input channels
        out_channels (:obj:`int`, *optional*, defaults to 3):
            Output channels
        up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            UpDecoder block type. One `FlaxUpDecoderBlock2D` is created per entry.
        block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple containing the number of output channels for each block. The blocks are built from the reversed
            tuple.
        layers_per_block (:obj:`int`, *optional*, defaults to `2`):
            Each up block is built with `layers_per_block + 1` Resnet layers
        norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
            norm num group
        act_fn (:obj:`str`, *optional*, defaults to `silu`):
            Activation function. The value is stored but not read by this module; `nn.swish` is always applied.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            parameters `dtype`
    """

    in_channels: int = 3
    out_channels: int = 3
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 2
    norm_num_groups: int = 32
    act_fn: str = "silu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        block_out_channels = self.block_out_channels

        # z to block_in
        self.conv_in = nn.Conv(
            block_out_channels[-1],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid_block = FlaxUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_groups=self.norm_num_groups,
            num_attention_heads=None,
            dtype=self.dtype,
        )

        # upsampling
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        up_blocks = []
        for i, _ in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            # the last block keeps the spatial resolution
            is_final_block = i == len(block_out_channels) - 1

            up_block = FlaxUpDecoderBlock2D(
                in_channels=prev_output_channel,
                out_channels=output_channel,
                num_layers=self.layers_per_block + 1,
                resnet_groups=self.norm_num_groups,
                add_upsample=not is_final_block,
                dtype=self.dtype,
            )
            up_blocks.append(up_block)

        self.up_blocks = up_blocks

        # end
        self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, sample, deterministic: bool = True):
        # z to block_in
        sample = self.conv_in(sample)

        # middle
        sample = self.mid_block(sample, deterministic=deterministic)

        # upsampling
        for block in self.up_blocks:
            sample = block(sample, deterministic=deterministic)

        sample = self.conv_norm_out(sample)
        sample = nn.swish(sample)
        sample = self.conv_out(sample)

        return sample


class FlaxDiagonalGaussianDistribution:
    """
    Diagonal Gaussian parameterised by a mean and a log-variance.

    Args:
        parameters (`jnp.ndarray`):
            Mean and log-variance concatenated along the last (channel) axis; the array is split into two equal
            halves. The log-variance is clipped to `[-30.0, 20.0]`.
        deterministic (`bool`, *optional*, defaults to `False`):
            When `True`, the variance and standard deviation are set to zero, so `sample` returns the mean and
            `kl` / `nll` return `jnp.array([0.0])`.
    """

    def __init__(self, parameters, deterministic=False):
        # Last axis to account for channels-last
        self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
        self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = jnp.zeros_like(self.mean)

    def sample(self, key):
        """Draw one sample `mean + std * eps` with `eps` drawn from a standard normal using PRNG `key`."""
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def kl(self, other=None):
        """
        KL divergence to `other`, or to a standard normal when `other` is `None`.

        The result is summed over axes 1, 2 and 3.
        """
        if self.deterministic:
            return jnp.array([0.0])

        if other is None:
            return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))

        return 0.5 * jnp.sum(
            jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            axis=(1, 2, 3),
        )

    def nll(self, sample, axis=(1, 2, 3)):
        """Negative log-likelihood of `sample` under this distribution, summed over `axis`."""
        if self.deterministic:
            return jnp.array([0.0])

        logtwopi = jnp.log(2.0 * jnp.pi)
        return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean


@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    Flax implementation of a VAE model with KL loss for decoding latent representations.

    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for its generic methods
    implemented for all models (such as downloading or saving).

    This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its
    general usage and behavior.

    Inherent JAX features such as the following are supported:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3):
            Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`):
            Number of ResNet layer for each block.
        act_fn (`str`, *optional*, defaults to `silu`):
            The activation function to use.
        latent_channels (`int`, *optional*, defaults to `4`):
            Number of channels in the latent space.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups for normalization.
        sample_size (`int`, *optional*, defaults to 32):
            Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
    """

    in_channels: int = 3
    out_channels: int = 3
    down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
    up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
    block_out_channels: Tuple[int] = (64,)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    scaling_factor: float = 0.18215
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.encoder = FlaxEncoder(
            in_channels=self.config.in_channels,
            out_channels=self.config.latent_channels,
            down_block_types=self.config.down_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            act_fn=self.config.act_fn,
            norm_num_groups=self.config.norm_num_groups,
            double_z=True,
            dtype=self.dtype,
        )
        self.decoder = FlaxDecoder(
            in_channels=self.config.latent_channels,
            out_channels=self.config.out_channels,
            up_block_types=self.config.up_block_types,
            block_out_channels=self.config.block_out_channels,
            layers_per_block=self.config.layers_per_block,
            norm_num_groups=self.config.norm_num_groups,
            act_fn=self.config.act_fn,
            dtype=self.dtype,
        )
        # 1x1 convolutions applied to the encoder moments and to the latents before decoding
        self.quant_conv = nn.Conv(
            2 * self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.latent_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        """Initialise the parameters by running the model on a zero channels-first sample of `sample_size`."""
        # init input tensors
        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)

        params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
        rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}

        return self.init(rngs, sample)["params"]

    def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
        """
        Encode a channels-first `sample` into a latent distribution.

        Returns a `FlaxAutoencoderKLOutput`, or a one-element tuple holding the
        `FlaxDiagonalGaussianDistribution` when `return_dict` is `False`.
        """
        # (B, C, H, W) -> (B, H, W, C): the Flax layers operate channels-last
        sample = jnp.transpose(sample, (0, 2, 3, 1))

        hidden_states = self.encoder(sample, deterministic=deterministic)
        moments = self.quant_conv(hidden_states)
        posterior = FlaxDiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return FlaxAutoencoderKLOutput(latent_dist=posterior)

    def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
        """
        Decode `latents` into a channels-first sample.

        Returns a `FlaxDecoderOutput`, or a one-element tuple holding the decoded array when `return_dict` is
        `False`.
        """
        # Latents whose last axis already equals `latent_channels` are taken to be channels-last; anything else
        # is transposed from channels-first. NOTE: a channels-first input whose width equals `latent_channels`
        # is therefore not transposed.
        if latents.shape[-1] != self.config.latent_channels:
            latents = jnp.transpose(latents, (0, 2, 3, 1))

        hidden_states = self.post_quant_conv(latents)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)

        # (B, H, W, C) -> (B, C, H, W)
        hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

        if not return_dict:
            return (hidden_states,)

        return FlaxDecoderOutput(sample=hidden_states)

    def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
        """
        Encode `sample`, take a latent from the posterior, and decode it.

        Args:
            sample: channels-first input batch.
            sample_posterior: when `True`, the latent is sampled using the `"gaussian"` RNG stream; otherwise the
                mode of the posterior is used.
            deterministic: forwarded to both the encoder and the decoder.
            return_dict: when `False`, a one-element tuple is returned instead of a `FlaxDecoderOutput`.
        """
        # Always request the structured outputs internally: the attribute accesses below would fail on the
        # plain tuples that `encode` / `decode` return for `return_dict=False`.
        posterior = self.encode(sample, deterministic=deterministic, return_dict=True)
        if sample_posterior:
            rng = self.make_rng("gaussian")
            hidden_states = posterior.latent_dist.sample(rng)
        else:
            hidden_states = posterior.latent_dist.mode()

        sample = self.decode(hidden_states, deterministic=deterministic, return_dict=True).sample

        if not return_dict:
            return (sample,)

        return FlaxDecoderOutput(sample=sample)