# Source: ai_dreamer/vqgan_jax/configuration_vqgan.py
# Author: gsurma — commit cfa800c ("initial commit")
from typing import Tuple
from transformers import PretrainedConfig
class VQGANConfig(PretrainedConfig):
    """Configuration container for a VQGAN model.

    Every constructor argument is stored as an attribute of the same name.
    ``ch_mult`` and ``attn_resolutions`` are the exceptions: they are
    normalised to ``list``, presumably so the config round-trips through
    JSON as arrays (TODO confirm). One derived attribute is also set:
    ``num_resolutions``, equal to the number of entries in ``ch_mult``.
    Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    NOTE(review): the meaning of each field is defined by the model code that
    consumes this config, which is not in this file. The per-argument notes
    below are inferred from the names and should be confirmed against that
    code.

    Args:
        ch: presumably the base channel width of the encoder/decoder.
        out_ch: presumably the number of output image channels.
        in_channels: presumably the number of input image channels.
        num_res_blocks: presumably the number of residual blocks per level.
        resolution: presumably the input image side length.
        z_channels: presumably the number of latent channels.
        ch_mult: per-level channel multipliers. Stored as a list; its length
            becomes ``num_resolutions``.
        attn_resolutions: presumably the resolutions at which attention is
            used. Stored as a list.
        n_embed: presumably the codebook size.
        embed_dim: presumably the codebook vector dimensionality.
        dropout: presumably the dropout probability.
        double_z: boolean flag interpreted by the model code.
        resamp_with_conv: boolean flag interpreted by the model code.
        give_pre_end: boolean flag interpreted by the model code.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple[int, ...] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        # Let the base class consume its own keyword arguments first.
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # Use the stored list, not the raw argument. ``ch_mult`` may be any
        # iterable (e.g. a generator). ``list()`` above has already consumed
        # it, and it may not support ``len()`` at all.
        self.num_resolutions = len(self.ch_mult)