"""Autoencoder with architecture from version 2."""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
def swish(x):
    """Swish/SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels, num_groups=32):
    """Build a GroupNorm layer for ``in_channels`` feature maps.

    Args:
        in_channels: number of channels to normalize; must be divisible
            by ``num_groups`` (a ``nn.GroupNorm`` requirement).
        num_groups: number of normalization groups. Defaults to 32, the
            value previously hard-coded, so existing callers are unchanged.

    Returns:
        ``nn.GroupNorm`` with ``eps=1e-6`` and learnable affine parameters.
    """
    return nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
class Upsample(nn.Module):
    """Double every spatial dimension: nearest-neighbour resize, then a
    3x3x3 convolution that keeps the channel count unchanged."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
class Downsample(nn.Module):
    """Halve every spatial dimension with a stride-2 3x3x3 convolution.

    Padding is applied manually and asymmetrically (one voxel on the
    high side of each spatial axis) before the conv, which itself uses
    ``padding=0``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def forward(self, x):
        # (left, right) pairs per spatial axis, innermost axis first.
        padded = F.pad(x, (0, 1, 0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResBlock(nn.Module):
    """Pre-activation 3D residual block: (GroupNorm -> SiLU -> Conv3d) x 2.

    When ``out_channels`` differs from ``in_channels`` the skip path is
    projected with a 1x1x1 convolution so the residual addition is
    well-defined.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output; defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # BUG FIX: the original passed the raw ``out_channels`` argument
        # (possibly None) straight to nn.Conv3d, so ``ResBlock(c)`` with the
        # default crashed. Use the resolved ``self.out_channels`` instead.
        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv3d(
            in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.norm2 = Normalize(self.out_channels)
        self.conv2 = nn.Conv3d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        if self.in_channels != self.out_channels:
            # 1x1x1 projection so the skip connection matches the residual.
            self.nin_shortcut = nn.Conv3d(
                in_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, x):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
class Encoder(nn.Module):
    """3D convolutional encoder mapping inputs to a latent feature map.

    Architecture: one stem conv, then ``num_res_blocks`` ResBlocks per
    resolution level with a Downsample between levels, and finally a
    GroupNorm plus a conv projecting to ``z_channels``.

    NOTE(review): ``attn_resolutions`` is stored but no attention blocks
    are built in this version; ``ignorekwargs`` swallows extra hparams.
    """

    def __init__(
        self,
        in_channels: int,
        n_channels: int,
        z_channels: int,
        ch_mult: Tuple[int],
        num_res_blocks: int,
        resolution: Tuple[int],
        attn_resolutions: Tuple[int],
        **ignorekwargs,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.n_channels = n_channels
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        # Stem convolution from input channels to the base width.
        layers = [
            nn.Conv3d(in_channels, n_channels, kernel_size=3, stride=1, padding=1)
        ]
        # Residual stacks, downsampling between resolution levels.
        for level in range(self.num_resolutions):
            ch_in = n_channels * in_ch_mult[level]
            ch_out = n_channels * ch_mult[level]
            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(ch_in, ch_out))
                ch_in = ch_out
            if level < self.num_resolutions - 1:
                layers.append(Downsample(ch_in))
                curr_res = tuple(side // 2 for side in curr_res)
        # Normalize and project to the latent channel count.
        layers.append(Normalize(ch_in))
        layers.append(
            nn.Conv3d(ch_in, z_channels, kernel_size=3, stride=1, padding=1)
        )
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.blocks:
            x = layer(x)
        return x
class Decoder(nn.Module):
    """3D convolutional decoder mapping latent feature maps back to
    ``out_channels`` outputs.

    Mirrors ``Encoder``: a stem conv from ``z_channels``, then ResBlock
    stacks per resolution level (highest multiplier first) with an
    Upsample between levels, and a final GroupNorm + projection conv.

    NOTE(review): ``attn_resolutions`` is stored but unused here.
    """

    def __init__(
        self,
        n_channels: int,
        z_channels: int,
        out_channels: int,
        ch_mult: Tuple[int],
        num_res_blocks: int,
        resolution: Tuple[int],
        attn_resolutions: Tuple[int],
        **ignorekwargs,
    ) -> None:
        super().__init__()
        self.n_channels = n_channels
        self.z_channels = z_channels
        self.out_channels = out_channels
        self.ch_mult = ch_mult
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        ch_in = n_channels * self.ch_mult[-1]
        # Spatial size of the latent at the deepest level.
        curr_res = tuple(
            side // 2 ** (self.num_resolutions - 1) for side in resolution
        )

        # Stem convolution from the latent channels to the deepest width.
        layers = [
            nn.Conv3d(z_channels, ch_in, kernel_size=3, stride=1, padding=1)
        ]
        # Residual stacks, upsampling between resolution levels.
        for level in reversed(range(self.num_resolutions)):
            ch_out = n_channels * self.ch_mult[level]
            for _ in range(self.num_res_blocks):
                layers.append(ResBlock(ch_in, ch_out))
                ch_in = ch_out
            if level > 0:
                layers.append(Upsample(ch_in))
                curr_res = tuple(side * 2 for side in curr_res)
        # Normalize and project to the output channel count.
        layers.append(Normalize(ch_in))
        layers.append(
            nn.Conv3d(ch_in, out_channels, kernel_size=3, stride=1, padding=1)
        )
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.blocks:
            x = layer(x)
        return x
class AutoencoderKL(nn.Module):
    """KL-regularized 3D autoencoder assembled from Encoder and Decoder.

    ``hparams`` is unpacked into both the encoder and the decoder and
    must therefore contain the keys both constructors expect, including
    ``"z_channels"`` which sizes the latent feature maps here.
    """

    def __init__(self, embed_dim: int, hparams) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)
        z_ch = hparams["z_channels"]
        # 1x1x1 convs: encoder features -> posterior mean / log-sigma,
        # and embedding -> decoder input.
        self.quant_conv_mu = torch.nn.Conv3d(z_ch, embed_dim, 1)
        self.quant_conv_log_sigma = torch.nn.Conv3d(z_ch, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv3d(embed_dim, z_ch, 1)

    def decode(self, z):
        """Project an embedding back to ``z_channels`` and run the decoder."""
        projected = self.post_quant_conv(z)
        return self.decoder(projected)

    def reconstruct_ldm_outputs(self, z):
        """Decode latents produced by an LDM; thin alias for ``decode``."""
        return self.decode(z)