|
from typing import Tuple |
|
|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class VQGANConfig(PretrainedConfig): |
|
    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        # Number of resolution levels is derived from the channel multipliers.
        self.num_resolutions = len(ch_mult)
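

# Minimal usage sketch, assuming a standard transformers install: construct the
# config with defaults and round-trip it through `save_pretrained` /
# `from_pretrained`, both inherited from `PretrainedConfig`. The
# "./vqgan-config" directory is an arbitrary example path.
if __name__ == "__main__":
    config = VQGANConfig()
    print(config.num_resolutions)  # 5, the length of ch_mult
    config.save_pretrained("./vqgan-config")
    reloaded = VQGANConfig.from_pretrained("./vqgan-config")
    assert reloaded.ch_mult == config.ch_mult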