# 3DTopia-XL / models/vae3d_dib.py
# FrozenBurning
# single view to 3D init release
# 81ecb2b
import numpy as np
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from utils.typing import *
from .attention import MemEffAttention
class VolumeAttention(nn.Module):
    """Self-attention over every voxel of a 5D feature volume.

    The input ``[B, C, H, W, D]`` is group-normalised, flattened into a token
    sequence ``[B, H*W*D, C]``, passed through ``MemEffAttention`` and folded
    back to ``[B, C, H, W, D]``. When ``residual`` is true the (pre-norm) input
    is added back and the sum is multiplied by ``skip_scale``.

    Args:
        dim: channel count ``C`` of the volume (also the attention width).
        num_heads, qkv_bias, proj_bias, attn_drop, proj_drop: forwarded
            positionally to ``MemEffAttention``.
        groups: requested GroupNorm group count; clamped to ``dim`` so that
            narrow volumes (``dim < groups``) still build, matching the
            clamping done in ``ResnetBlock``.
        eps: GroupNorm epsilon.
        residual: whether to add the input back after attention.
        skip_scale: multiplier applied to the residual sum.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
    ):
        super().__init__()
        self.residual = residual
        self.skip_scale = skip_scale
        # Clamp like ResnetBlock does; GroupNorm(32, dim) would raise for dim < 32.
        self.norm = nn.GroupNorm(num_groups=min(groups, dim), num_channels=dim, eps=eps, affine=True)
        # NOTE(review): assumes MemEffAttention maps [B, N, C] -> [B, N, C] — confirm in .attention
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

    def forward(self, x):
        # x: [B, C, H, W, D]
        B, C, H, W, D = x.shape
        res = x
        x = self.norm(x)
        # Channels-last, then flatten the three spatial axes into one token axis.
        x = x.permute(0, 2, 3, 4, 1).reshape(B, -1, C)
        x = self.attn(x)
        # Undo the flattening and move channels back to dim 1.
        x = x.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)
        if self.residual:
            x = (x + res) * self.skip_scale
        return x
class DiagonalGaussianDistribution:
    """Diagonal Gaussian parameterised by a channel-stacked (mean, logvar) tensor.

    ``parameters`` is split in two halves along dim 1: the first half is the
    mean, the second the log-variance. The log-variance is clamped to
    ``[-30, 20]`` before exponentiation to keep ``std``/``var`` finite.

    Args:
        parameters: tensor ``[B, 2C, ...]`` holding mean and log-variance.
        deterministic: if true, ``std`` and ``var`` are zero, so ``sample()``
            returns the mean and ``kl``/``nll`` return a zero tensor.
    """

    def __init__(self, parameters, deterministic=False):
        # parameters: [B, 2C, ...]
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def _zero(self):
        """Return a shape-[1] zero on the parameters' device/dtype.

        A bare ``torch.Tensor([0.0])`` lives on the CPU as float32 and, being
        1-D rather than 0-D, cannot be combined with CUDA losses.
        """
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        sample = torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype)
        x = self.mean + self.std * sample
        return x

    def kl(self, other=None, dims=(1, 2, 3, 4)):
        """KL divergence to ``other`` (or to N(0, I) when ``other`` is None).

        The per-element terms are averaged (``torch.mean``) over ``dims``,
        not summed.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims)
        return 0.5 * torch.mean(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3, 4)):
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``."""
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        """Return the distribution's mode, i.e. its mean."""
        return self.mean
class ResnetBlock(nn.Module):
    """Pre-activation 3D residual block with optional 2x resampling.

    Output is ``(conv2(silu(norm2(conv1(R(silu(norm1(x))))))) + shortcut(R(x))) * skip_scale``,
    where ``R`` is nearest-neighbour 2x upsampling (``resample='up'``),
    2x average pooling (``resample='down'``) or the identity otherwise.
    ``shortcut`` is a 1x1x1 conv when the channel count changes, else identity.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        resample: 'up', 'down', or anything else for no resampling.
        groups: upper bound on GroupNorm groups (clamped to the channel count).
        eps: GroupNorm epsilon.
        skip_scale: multiplier applied to the residual sum.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,  # multiplied to output
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        # Main branch; construction order is kept so parameter init matches.
        self.norm1 = nn.GroupNorm(num_groups=min(groups, in_channels), num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=min(groups, out_channels), num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.act = F.silu

        # Lazily build the resampler so only the requested one is created.
        resampler_factories = {
            'up': lambda: partial(F.interpolate, scale_factor=2.0, mode="nearest"),
            'down': lambda: nn.AvgPool3d(kernel_size=2, stride=2),
        }
        make_resampler = resampler_factories.get(resample)
        self.resample = None if make_resampler is None else make_resampler()

        # Project the skip path only when the channel count changes.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x):
        skip = x
        h = self.act(self.norm1(x))
        if self.resample is not None:
            # Resample both paths so their spatial sizes agree at the sum.
            skip = self.resample(skip)
            h = self.resample(h)
        h = self.conv1(h)
        h = self.conv2(self.act(self.norm2(h)))
        return (h + self.shortcut(skip)) * self.skip_scale
class DownBlock(nn.Module):
    """A stack of ``ResnetBlock``s followed by an optional stride-2 conv.

    The first ResnetBlock maps ``in_channels -> out_channels``; the rest keep
    ``out_channels``. When ``downsample`` is true, a 3x3x3 stride-2
    ``Conv3d`` halves each spatial dimension (rounding up) at the end.
    With ``gradient_checkpointing`` the whole stack is recomputed in the
    backward pass while training.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        self.nets = nn.ModuleList(
            ResnetBlock(out_channels if idx else in_channels, out_channels, skip_scale=skip_scale)
            for idx in range(num_layers)
        )
        self.downsample = (
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
            if downsample
            else None
        )

    def forward(self, x):
        if not (self.training and self.gradient_checkpointing):
            return self._forward(x)
        return checkpoint(self._forward, x, use_reentrant=False)

    def _forward(self, x):
        h = x
        for block in self.nets:
            h = block(h)
        return h if self.downsample is None else self.downsample(h)
class MidBlock(nn.Module):
    """Bottleneck block: one ResnetBlock, then ``num_layers`` (attention?, ResnetBlock) pairs.

    All blocks keep ``in_channels``. When ``attention`` is false, the
    attention slots hold ``None`` and are skipped in the forward pass. With
    ``gradient_checkpointing`` the whole block is recomputed in the backward
    pass while training.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 8,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        resnets = [ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)]
        attentions = []
        for _ in range(num_layers):
            # ResnetBlock is built before its paired attention, as in the
            # original, so parameter initialisation order is unchanged.
            resnets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            attentions.append(
                VolumeAttention(in_channels, attention_heads, skip_scale=skip_scale) if attention else None
            )
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

    def forward(self, x):
        use_checkpoint = self.training and self.gradient_checkpointing
        if use_checkpoint:
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

    def _forward(self, x):
        head, *tail = self.nets
        h = head(x)
        for attn, block in zip(self.attns, tail):
            if attn is not None:
                h = attn(h)
            h = block(h)
        return h
class UpBlock(nn.Module):
    """A stack of ``ResnetBlock``s followed by an optional 2x transposed-conv upsample.

    The first ResnetBlock maps ``in_channels -> out_channels``; the rest keep
    ``out_channels``. When ``upsample`` is true, a kernel-2 stride-2
    ``ConvTranspose3d`` doubles each spatial dimension at the end. With
    ``gradient_checkpointing`` the whole stack is recomputed in the backward
    pass while training.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        skip_scale: float = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        self.gradient_checkpointing = gradient_checkpointing
        blocks = []
        width = in_channels
        for _ in range(num_layers):
            blocks.append(ResnetBlock(width, out_channels, skip_scale=skip_scale))
            width = out_channels
        self.nets = nn.ModuleList(blocks)
        self.upsample = (
            nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
            if upsample
            else None
        )

    def forward(self, x):
        should_checkpoint = self.training and self.gradient_checkpointing
        return checkpoint(self._forward, x, use_reentrant=False) if should_checkpoint else self._forward(x)

    def _forward(self, x):
        out = x
        for block in self.nets:
            out = block(out)
        if self.upsample is not None:
            out = self.upsample(out)
        return out
class Encoder(nn.Module):
    """3D convolutional encoder.

    Pipeline: ``conv_in`` (3x3x3, stride 1) -> one ``DownBlock`` per entry of
    ``down_channels`` (all but the last end with a stride-2 conv) ->
    ``MidBlock`` -> GroupNorm -> SiLU -> ``conv_out``.

    Args:
        in_channels: channels of the input volume.
        out_channels: output channels; ``2 * latent`` when the result is later
            split into mean and log-variance ("double_z").
        down_channels: channel width of each resolution level.
        mid_attention: whether the mid block uses volume attention.
        layers_per_block: ResnetBlocks per DownBlock.
        skip_scale: multiplier applied to every residual sum.
        gradient_checkpointing: checkpoint the down blocks AND the mid block
            while training (previously the mid block was left out).
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 2 * 16,  # double_z
        down_channels: Tuple[int, ...] = (8, 16, 32, 64),
        mid_attention: bool = True,
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        # input projection (stride 1: no downsampling happens here)
        self.conv_in = nn.Conv3d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down: every level except the last halves the resolution
        down_blocks = []
        cin = down_channels[0]
        last = len(down_channels) - 1
        for i, cout in enumerate(down_channels):
            down_blocks.append(DownBlock(
                cin, cout,
                num_layers=layers_per_block,
                downsample=(i != last),  # not final layer
                skip_scale=skip_scale,
                gradient_checkpointing=gradient_checkpointing,
            ))
            cin = cout
        self.down_blocks = nn.ModuleList(down_blocks)

        # mid: forward the checkpointing flag like the down blocks get it
        self.mid_block = MidBlock(
            down_channels[-1],
            attention=mid_attention,
            skip_scale=skip_scale,
            gradient_checkpointing=gradient_checkpointing,
        )

        # last: clamp groups like Decoder does so narrow final widths still build
        self.norm_out = nn.GroupNorm(num_channels=down_channels[-1], num_groups=min(32, down_channels[-1]), eps=1e-5)
        self.conv_out = nn.Conv3d(down_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W, D]
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
class Decoder(nn.Module):
    """3D convolutional decoder, mirroring ``Encoder``.

    Pipeline: ``conv_in`` -> ``MidBlock`` -> one ``UpBlock`` per entry of
    ``up_channels`` (all but the last end with a 2x transposed conv) ->
    GroupNorm -> SiLU -> ``conv_out``.

    Args:
        in_channels: latent channel count.
        out_channels: channels of the reconstructed volume.
        up_channels: channel width of each resolution level.
        mid_attention: whether the mid block uses volume attention.
        layers_per_block: ResnetBlocks per UpBlock.
        skip_scale: multiplier applied to every residual sum.
        gradient_checkpointing: checkpoint the mid block AND the up blocks
            while training (previously the mid block was left out).
    """

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 1,
        up_channels: Tuple[int, ...] = (64, 32, 16, 8),
        mid_attention: bool = True,
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        # first
        self.conv_in = nn.Conv3d(in_channels, up_channels[0], kernel_size=3, stride=1, padding=1)

        # mid: forward the checkpointing flag like the up blocks get it
        self.mid_block = MidBlock(
            up_channels[0],
            attention=mid_attention,
            skip_scale=skip_scale,
            gradient_checkpointing=gradient_checkpointing,
        )

        # up: every level except the last doubles the resolution
        up_blocks = []
        cin = up_channels[0]
        last = len(up_channels) - 1
        for i, cout in enumerate(up_channels):
            up_blocks.append(UpBlock(
                cin, cout,
                num_layers=layers_per_block,
                upsample=(i != last),  # not final layer
                skip_scale=skip_scale,
                gradient_checkpointing=gradient_checkpointing,
            ))
            cin = cout
        self.up_blocks = nn.ModuleList(up_blocks)

        # last: kernel 3 / stride 1 / padding 1 keeps the spatial size (no upsampling
        # here despite the ConvTranspose3d); kept as-is for checkpoint compatibility.
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=min(32, up_channels[-1]), eps=1e-5)
        self.conv_out = nn.ConvTranspose3d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W, D]
        x = self.conv_in(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
class VAE(nn.Module):
    """3D variational autoencoder: Encoder -> diagonal Gaussian posterior -> Decoder.

    The encoder emits ``2 * latent_channels`` moments, mixed by a 1x1x1
    ``quant_conv`` and wrapped in ``DiagonalGaussianDistribution``. A latent
    (sampled or the mode) goes through a 1x1x1 ``post_quant_conv`` and then
    the decoder.
    """

    def __init__(
        self,
        in_channels: int = 1,
        latent_channels: int = 16,
        out_channels: int = 1,
        down_channels: Tuple[int, ...] = (16, 32, 64, 128, 256),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        gradient_checkpointing: bool = False,
    ):
        super().__init__()
        # Settings shared verbatim by encoder and decoder.
        shared = dict(
            mid_attention=mid_attention,
            layers_per_block=layers_per_block,
            skip_scale=skip_scale,
            gradient_checkpointing=gradient_checkpointing,
        )
        moment_channels = 2 * latent_channels  # mean + log-variance ("double_z")

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=moment_channels,
            down_channels=down_channels,
            **shared,
        )
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_channels=up_channels,
            **shared,
        )
        self.quant_conv = nn.Conv3d(moment_channels, moment_channels, 1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1)

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` posterior."""
        moments = self.quant_conv(self.encoder(x))
        return DiagonalGaussianDistribution(moments)

    def decode(self, x):
        """Decode latent ``x`` back to the output volume."""
        return self.decoder(self.post_quant_conv(x))

    def forward(self, x, sample=True):
        """Reconstruct ``x``; returns ``(reconstruction, posterior)``.

        ``sample`` selects a stochastic draw from the posterior (True) or its
        mode (False) as the latent.
        """
        # x: [B, Cin, H, W, D]
        posterior = self.encode(x)
        latent = posterior.sample() if sample else posterior.mode()
        return self.decode(latent), posterior