Spaces:
Running
Running
# JAX implementation of VQGAN from taming-transformers: https://github.com/CompVis/taming-transformers
from functools import partial | |
from typing import Tuple | |
import math | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import flax.linen as nn | |
from flax.core.frozen_dict import FrozenDict | |
from transformers.modeling_flax_utils import FlaxPreTrainedModel | |
from .configuration_vqgan import VQGANConfig | |
class Upsample(nn.Module):
    """Double the two middle axes with a nearest-neighbour resize, then optionally apply a 3x3 conv.

    The input shape is unpacked as (batch, height, width, channels).
    """

    in_channels: int
    with_conv: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The convolution only exists when requested.
        if not self.with_conv:
            return
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        n, h, w, c = hidden_states.shape
        upsampled = jax.image.resize(
            hidden_states,
            shape=(n, h * 2, w * 2, c),
            method="nearest",
        )
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """Reduce the spatial axes with a stride-2 3x3 conv, or with 2x2 average pooling."""

    in_channels: int
    with_conv: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The convolution only exists when requested.
        if not self.with_conv:
            return
        self.conv = nn.Conv(
            self.in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        if not self.with_conv:
            return nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")

        # Add one trailing element of padding on axes 1 and 2 (height and
        # width) before the unpadded ("VALID") strided convolution.
        trailing_pad = ((0, 0), (0, 1), (0, 1), (0, 0))
        padded = jnp.pad(hidden_states, pad_width=trailing_pad)
        return self.conv(padded)
class ResnetBlock(nn.Module):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) twice, added to a shortcut.

    Attributes:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced; falls back to ``in_channels`` when None.
        use_conv_shortcut: when the channel counts differ, project the residual
            with a 3x3 conv instead of a 1x1 conv.
        temb_channels: a falsy value means no embedding projection layer is created.
        dropout_prob: rate handed to ``nn.Dropout``.
        dtype: dtype passed to the conv and dense layers.
    """

    in_channels: int
    out_channels: int = None
    use_conv_shortcut: bool = False
    temb_channels: int = 512
    dropout_prob: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv1 = nn.Conv(
            self.out_channels_,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        if self.temb_channels:
            self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            self.out_channels_,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # The residual only needs a projection when the channel count changes.
        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    padding=((1, 1), (1, 1)),
                    dtype=self.dtype,
                )
            else:
                self.nin_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    padding="VALID",
                    dtype=self.dtype,
                )

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        """Apply the block.

        Args:
            hidden_states: feature map with channels on the last axis.
            temb: optional embedding added after the first conv; requires a
                truthy ``temb_channels`` so that ``temb_proj`` exists.
                Assumed to be (batch, features) -- TODO confirm with callers.
            deterministic: forwarded to dropout; True disables it.

        Returns:
            The transformed feature map plus the (possibly projected) residual.
        """
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # temb_proj yields (batch, out_channels). The feature map is
            # channels-last, so the singleton axes must sit in the two middle
            # (spatial) positions; `[:, :, None, None]` would line the channels
            # up with the height axis instead.
            hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, None, None, :]

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return hidden_states + residual
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual connection."""

    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        def pointwise_conv():
            # 1x1 convolution that keeps the channel count.
            return nn.Conv(
                self.in_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def __call__(self, hidden_states):
        skip = hidden_states
        normed = self.norm(hidden_states)

        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        # Collapse the two spatial axes into one token axis.
        batch, height, width, channels = query.shape
        token_shape = (batch, height * width, channels)
        query = query.reshape(token_shape)
        key = key.reshape(token_shape)

        # Scaled dot-product scores, normalised over the key axis.
        scale = int(channels) ** -0.5
        scores = jnp.einsum("...qc,...kc->...qk", query, key) * scale
        probs = nn.softmax(scores, axis=2)

        # Weighted sum of the values, restored to the spatial layout.
        value = value.reshape(token_shape)
        attended = jnp.einsum("...kc,...qk->...qc", value, probs)
        attended = attended.reshape((batch, height, width, channels))

        return self.proj_out(attended) + skip
class UpsamplingBlock(nn.Module):
    """One decoder resolution level.

    Holds ``num_res_blocks + 1`` ResNet blocks, one attention block per ResNet
    block when ``curr_res`` is listed in ``config.attn_resolutions``, and a
    final ``Upsample`` for every level except ``block_idx == 0``.
    """

    config: VQGANConfig
    curr_res: int
    block_idx: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The last level takes its input width from its own multiplier; every
        # other level takes it from the level after it in ch_mult.
        if self.block_idx == self.config.num_resolutions - 1:
            block_in = self.config.ch * self.config.ch_mult[-1]
        else:
            block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]

        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
        self.temb_ch = 0

        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks + 1):
            res_blocks.append(
                ResnetBlock(
                    block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
                )
            )
            block_in = block_out
            # One attention block per ResNet block at attention resolutions.
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.block = res_blocks
        self.attn = attn_blocks

        self.upsample = None
        if self.block_idx != 0:
            self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        """Run the level's ResNet/attention pairs, then upsample if configured.

        Args:
            hidden_states: incoming feature map.
            temb: forwarded unchanged to every ResNet block.
            deterministic: forwarded to every ResNet block.

        Returns:
            The processed (and possibly upsampled) feature map.
        """
        for idx, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
            # Each attention block belongs to the ResNet block with the same
            # index (the pairing used by the taming-transformers reference),
            # so apply exactly that one rather than looping over all of them.
            if self.attn:
                hidden_states = self.attn[idx](hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states
class DownsamplingBlock(nn.Module):
    """One encoder resolution level.

    Holds ``num_res_blocks`` ResNet blocks, one attention block per ResNet
    block when ``curr_res`` is listed in ``config.attn_resolutions``, and a
    final ``Downsample`` for every level except the last one.
    """

    config: VQGANConfig
    curr_res: int
    block_idx: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Level 0 is fed by conv_in (multiplier 1); level i > 0 is fed by the
        # output width of level i - 1.
        in_ch_mult = (1,) + tuple(self.config.ch_mult)
        block_in = self.config.ch * in_ch_mult[self.block_idx]
        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
        self.temb_ch = 0

        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
                )
            )
            block_in = block_out
            # One attention block per ResNet block at attention resolutions.
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.block = res_blocks
        self.attn = attn_blocks

        self.downsample = None
        if self.block_idx != self.config.num_resolutions - 1:
            self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        """Run the level's ResNet/attention pairs, then downsample if configured.

        Args:
            hidden_states: incoming feature map.
            temb: forwarded unchanged to every ResNet block.
            deterministic: forwarded to every ResNet block.

        Returns:
            The processed (and possibly downsampled) feature map.
        """
        for idx, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
            # Each attention block belongs to the ResNet block with the same
            # index (the pairing used by the taming-transformers reference),
            # so apply exactly that one rather than looping over all of them.
            if self.attn:
                hidden_states = self.attn[idx](hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states
class MidBlock(nn.Module):
    """ResNet block -> attention -> ResNet block, all at a constant channel width."""

    in_channels: int
    temb_channels: int
    dropout: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Both ResNet blocks share one configuration.
        make_resnet = partial(
            ResnetBlock,
            self.in_channels,
            self.in_channels,
            temb_channels=self.temb_channels,
            dropout_prob=self.dropout,
            dtype=self.dtype,
        )
        self.block_1 = make_resnet()
        self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
        self.block_2 = make_resnet()

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        out = self.block_1(hidden_states, temb, deterministic=deterministic)
        out = self.attn_1(out)
        return self.block_2(out, temb, deterministic=deterministic)
class Encoder(nn.Module):
    """Input conv, a stack of downsampling levels, a mid block, then norm/swish/output conv."""

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.temb_ch = 0

        # Stem.
        self.conv_in = nn.Conv(
            cfg.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # Level i is told the resolution after i halvings of config.resolution.
        self.down = [
            DownsamplingBlock(cfg, cfg.resolution // 2 ** level, block_idx=level, dtype=self.dtype)
            for level in range(cfg.num_resolutions)
        ]

        # Bottleneck at the width of the last level.
        self.mid = MidBlock(cfg.ch * cfg.ch_mult[-1], self.temb_ch, cfg.dropout, dtype=self.dtype)

        # Head: the channel count doubles when double_z is set.
        out_features = 2 * cfg.z_channels if cfg.double_z else cfg.z_channels
        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv_out = nn.Conv(
            out_features,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, pixel_values, deterministic: bool = True):
        # No timestep embedding is used; None is passed through everywhere.
        temb = None

        features = self.conv_in(pixel_values)
        for level in self.down:
            features = level(features, temb, deterministic=deterministic)

        features = self.mid(features, temb, deterministic=deterministic)

        features = nn.swish(self.norm_out(features))
        return self.conv_out(features)
class Decoder(nn.Module):
    """Input conv, a mid block, a stack of upsampling levels, then norm/swish/output conv."""

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        levels = cfg.num_resolutions
        self.temb_ch = 0

        # Channel width and resolution at the lowest-resolution level.
        lowest_width = cfg.ch * cfg.ch_mult[levels - 1]
        lowest_res = cfg.resolution // 2 ** (levels - 1)
        z_shape = (1, cfg.z_channels, lowest_res, lowest_res)
        self.z_shape = z_shape
        print("Working with z of shape {} = {} dimensions.".format(z_shape, np.prod(z_shape)))

        # Latent -> lowest-level width.
        self.conv_in = nn.Conv(
            lowest_width,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.mid = MidBlock(lowest_width, self.temb_ch, cfg.dropout, dtype=self.dtype)

        # Stored in ascending block_idx order; the resolution handed to each
        # level doubles with every step away from the last index.
        self.up = [
            UpsamplingBlock(cfg, lowest_res * 2 ** (levels - 1 - level), block_idx=level, dtype=self.dtype)
            for level in range(levels)
        ]

        # Head.
        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv_out = nn.Conv(
            cfg.out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, deterministic: bool = True):
        # No timestep embedding is used; None is passed through everywhere.
        temb = None

        features = self.conv_in(hidden_states)
        features = self.mid(features, temb, deterministic=deterministic)

        # Levels are stored in ascending index order, so run them last-to-first.
        for level in reversed(self.up):
            features = level(features, temb, deterministic=deterministic)

        # Optionally return the features before the output head.
        if self.config.give_pre_end:
            return features

        features = nn.swish(self.norm_out(features))
        return self.conv_out(features)
class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    ____________________________________________
    Discretization bottleneck part of the VQ-VAE.
    Config fields read here:
    - n_embed : number of embeddings
    - embed_dim : dimension of embedding
    _____________________________________________
    """

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype)  # TODO: init

    def __call__(self, hidden_states):
        """
        Map each encoder vector to the closest codebook embedding.

        The input is flattened to rows of ``embed_dim`` values, so the
        embedding dimension must be the last axis (channels-last).

        Args:
            hidden_states: encoder output z whose last axis has size ``embed_dim``.

        Returns:
            z_q: the selected embeddings, reshaped to ``hidden_states.shape``.
            min_encoding_indices: the selected codebook indices, reshaped to
                (hidden_states.shape[0], -1).
        """
        # flatten to (num_vectors, embed_dim)
        flat_states = hidden_states.reshape((-1, self.config.embed_dim))

        # dummy op to init the weights, so we can access them below
        self.embedding(jnp.ones((1, 1), dtype="i4"))

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        emb_weights = self.variables["params"]["embedding"]["embedding"]
        distance = (
            jnp.sum(flat_states ** 2, axis=1, keepdims=True)
            + jnp.sum(emb_weights ** 2, axis=1)
            - 2 * jnp.dot(flat_states, emb_weights.T)
        )

        # get quantized latent vectors
        min_encoding_indices = jnp.argmin(distance, axis=1)
        z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

        # compute the codebook_loss (q_loss) outside the model
        # here we return the embeddings and indices
        return z_q, min_encoding_indices

    def get_codebook_entry(self, indices, shape=None):
        """
        Look up the embeddings for ``indices``.

        Args:
            indices: integer array of shape (batch, num_tokens).
            shape: optional target shape for the result, channels-last, e.g.
                (batch, height, width, embed_dim). When omitted, the tokens
                are laid out on a square grid.

        Returns:
            The embeddings, reshaped to ``shape`` or to
            (batch, side, side, embed_dim) with side = sqrt(num_tokens).

        Raises:
            ValueError: if ``shape`` is omitted and ``num_tokens`` is not a
                perfect square.
        """
        batch, num_tokens = indices.shape
        z_q = self.embedding(indices)

        if shape is not None:
            return z_q.reshape(shape)

        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            # Without this check, a truncated square root combined with a -1
            # wildcard could fold leftover tokens into the channel axis
            # without any error.
            raise ValueError(
                f"num_tokens={num_tokens} is not a perfect square; pass `shape` explicitly for non-square grids."
            )
        return z_q.reshape(batch, side, side, z_q.shape[-1])
class VQModule(nn.Module):
    """Encoder + vector quantizer + decoder, joined by 1x1 convs on either side of the quantizer."""

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        # Both projection convs are 1x1 and differ only in output width.
        pointwise = partial(nn.Conv, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype)

        self.encoder = Encoder(cfg, dtype=self.dtype)
        self.decoder = Decoder(cfg, dtype=self.dtype)
        self.quantize = VectorQuantizer(cfg, dtype=self.dtype)
        self.quant_conv = pointwise(cfg.embed_dim)
        self.post_quant_conv = pointwise(cfg.z_channels)

    def encode(self, pixel_values, deterministic: bool = True):
        """Return (quantized states, codebook indices) for ``pixel_values``."""
        encoded = self.encoder(pixel_values, deterministic=deterministic)
        projected = self.quant_conv(encoded)
        quant_states, indices = self.quantize(projected)
        return quant_states, indices

    def decode(self, hidden_states, deterministic: bool = True):
        """Run quantized states through the post-quantization conv and the decoder."""
        projected = self.post_quant_conv(hidden_states)
        return self.decoder(projected, deterministic=deterministic)

    def decode_code(self, code_b):
        """Look up the embeddings for the codebook indices ``code_b`` and decode them."""
        return self.decode(self.quantize.get_codebook_entry(code_b))

    def __call__(self, pixel_values, deterministic: bool = True):
        """Encode then decode; return (decoder output, codebook indices)."""
        quant_states, indices = self.encode(pixel_values, deterministic)
        return self.decode(quant_states, deterministic), indices
class VQGANPreTrainedModel(FlaxPreTrainedModel):
    """
    Abstract base that takes care of weight initialization and offers a simple
    interface for downloading and loading pretrained models.
    """

    config_class = VQGANConfig
    base_model_prefix = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: VQGANConfig,
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        super().__init__(
            config,
            self.module_class(config=config, dtype=dtype, **kwargs),
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module on an all-zero input and return its "params" collection."""
        dummy_pixels = jnp.zeros(input_shape, dtype=jnp.float32)
        params_key, dropout_key = jax.random.split(rng)
        variables = self.module.init({"params": params_key, "dropout": dropout_key}, dummy_pixels)
        return variables["params"]

    def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
        """Apply the module's ``encode`` method; ``train=False`` means deterministic."""
        # A dropout stream is only supplied when the caller provides a key.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        variables = {"params": params or self.params}
        return self.module.apply(
            variables, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
        )

    def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
        """Apply the module's ``decode`` method; ``train=False`` means deterministic."""
        # A dropout stream is only supplied when the caller provides a key.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        variables = {"params": params or self.params}
        return self.module.apply(
            variables,
            jnp.array(hidden_states),
            not train,
            rngs=rngs,
            method=self.module.decode,
        )

    def decode_code(self, indices, params: dict = None):
        """Apply the module's ``decode_code`` method to ``indices`` cast to int32."""
        variables = {"params": params or self.params}
        codes = jnp.array(indices, dtype="i4")
        return self.module.apply(variables, codes, method=self.module.decode_code)

    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
    ):
        """Apply the module's ``__call__``; ``train=False`` means deterministic."""
        # A dropout stream is only supplied when the caller provides a key.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        variables = {"params": params or self.params}
        return self.module.apply(
            variables,
            jnp.array(pixel_values),
            not train,
            rngs=rngs,
        )
class VQModel(VQGANPreTrainedModel):
    """Concrete VQGAN model: sets `VQModule` as the module class instantiated by the base class."""

    module_class = VQModule