Spaces:
Running
on
L40S
Running
on
L40S
import math | |
import torch | |
import torch.nn as nn | |
from ..attention import ImgToTriplaneTransformer | |
import math | |
from einops import rearrange | |
class ImgToTriplaneModel(nn.Module):
    """
    Decode conditioning tokens into a three-plane (triplane) feature tensor.

    A learned positional embedding holding one query token per triplane cell
    (3 * pos_emb_size * pos_emb_size tokens, each of width pos_emb_dim) is fed
    as the query sequence to an ``ImgToTriplaneTransformer`` that receives the
    input ``x`` as its ``context``. The resulting tokens are reshaped into
    three pos_emb_size x pos_emb_size planes and upsampled to ``triplane_dim``
    channels. Upsampling uses one of two paths:

    * linear (default): a ``nn.Linear`` predicts an
      ``upsample_ratio x upsample_ratio`` patch of ``triplane_dim`` channels
      for every token, and the patches are unfolded spatially;
    * conv (``is_conv_upsampler=True``): ``upsample_time`` stages of 2x
      ``nn.ConvTranspose2d``, or a single 3x3 ``nn.Conv2d`` channel
      projection when ``upsample_time`` is 0.

    :param pos_emb_size: height and width of each plane at the transformer
        resolution; also forwarded to the transformer as ``triplane_size``.
    :param pos_emb_dim: width of each query token; forwarded to the
        transformer as ``query_dim``.
    :param cam_cond_dim: accepted for config compatibility; unused here.
    :param n_heads: forwarded to the transformer.
    :param d_head: forwarded to the transformer.
    :param depth: forwarded to the transformer.
    :param context_dim: forwarded to the transformer (presumably the channel
        width of ``x`` -- TODO confirm against ImgToTriplaneTransformer).
    :param triplane_dim: number of output channels per plane.
    :param upsample_time: number of 2x transposed-conv stages; only used on
        the conv path.
    :param use_fp16: accepted for config compatibility; unused here.
    :param use_bf16: accepted for config compatibility; unused here.
    :param upsample_ratio: spatial upsampling factor of the linear path
        (default 4, the previously hard-coded value). Must be >= 1.
    :param is_conv_upsampler: select the transposed-conv path instead of the
        linear path (default False, the previously hard-coded value).
    :raises ValueError: if the linear path is selected and
        ``upsample_ratio < 1``.
    """

    def __init__(
        self,
        pos_emb_size=32,
        pos_emb_dim=1024,
        cam_cond_dim=20,
        n_heads=16,
        d_head=64,
        depth=16,
        context_dim=768,
        triplane_dim=80,
        upsample_time=1,
        use_fp16=False,
        use_bf16=True,
        upsample_ratio=4,
        is_conv_upsampler=False,
    ):
        super().__init__()
        self.pos_emb_size = pos_emb_size
        self.pos_emb_dim = pos_emb_dim
        self.triplane_dim = triplane_dim
        self.is_conv_upsampler = is_conv_upsampler

        # One learned query token per triplane cell. forward() reshapes this
        # sequence as (plane, row, col), so the layout is plane-major.
        # NOTE: registration order (pos_emb, decoder, upsampler) is kept so
        # parameters() order and state_dict keys match existing checkpoints.
        self.pos_emb = nn.Parameter(
            torch.zeros(1, 3 * pos_emb_size * pos_emb_size, pos_emb_dim)
        )
        # TODO initialize pos_emb with a Gaussian random of zero-mean and std
        # of 1/sqrt(1024) (1024 is the default pos_emb_dim).

        # Transformer that turns the query tokens into triplane tokens while
        # attending to the conditioning input.
        self.img_to_triplane_decoder = ImgToTriplaneTransformer(
            query_dim=pos_emb_dim,
            n_heads=n_heads,
            d_head=d_head,
            depth=depth,
            context_dim=context_dim,
            triplane_size=pos_emb_size,
        )

        if is_conv_upsampler:
            # Each stage doubles H and W; only the first stage changes the
            # channel count (pos_emb_dim -> triplane_dim).
            stages = [
                nn.ConvTranspose2d(
                    in_channels=pos_emb_dim if i == 0 else triplane_dim,
                    out_channels=triplane_dim,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    output_padding=0,
                )
                for i in range(upsample_time)
            ]
            if stages:
                self.upsampler = nn.Sequential(*stages)
            else:
                # No spatial upsampling requested: only project channels.
                self.upsampler = nn.Conv2d(
                    in_channels=pos_emb_dim,
                    out_channels=triplane_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            # Effective spatial factor of this path (informational only).
            self.upsample_ratio = 2 ** len(stages)
        else:
            if upsample_ratio < 1:
                raise ValueError(
                    f"upsample_ratio must be >= 1, got {upsample_ratio}"
                )
            self.upsample_ratio = upsample_ratio
            # Each token predicts a (ratio x ratio) patch of triplane_dim
            # channels, flattened as (channel, patch_row, patch_col).
            self.upsampler = nn.Linear(
                in_features=pos_emb_dim,
                out_features=triplane_dim * upsample_ratio ** 2,
            )

    def forward(self, x, cam_cond=None, **kwargs):
        """
        Decode the conditioning input ``x`` into triplane features.

        :param x: conditioning tensor passed to the transformer as
            ``context``; its first dimension is the batch size. Presumably
            ``[B, L, context_dim]`` -- TODO confirm against
            ImgToTriplaneTransformer.
        :param cam_cond: accepted for interface compatibility; unused.
        :param kwargs: ignored.
        :return: a ``[B, 3, triplane_dim, H, W]`` tensor with ``x``'s dtype.
            On the linear path ``H = W = pos_emb_size * upsample_ratio``; on
            the conv path ``H = W = pos_emb_size * 2 ** upsample_time``.
        """
        batch = x.shape[0]
        size = self.pos_emb_size
        # The same learned queries are shared by every sample; expand does
        # not copy memory.
        queries = self.pos_emb.expand(batch, -1, -1)
        # Assumes the decoder returns one token per query, i.e.
        # [B, 3*S*S, pos_emb_dim] -- required by the reshape below.
        h = self.img_to_triplane_decoder(queries, context=x)
        # reshape (rather than view) also tolerates a non-contiguous decoder
        # output. Planes are folded into the batch axis for the upsampler.
        h = h.reshape(batch * 3, size, size, self.pos_emb_dim)

        if self.is_conv_upsampler:
            h = rearrange(h, 'b h w c -> b c h w')
            h = self.upsampler(h)
            h = rearrange(h, '(b d) c h w -> b d c h w', d=3)
        else:
            ratio = self.upsample_ratio
            h = self.upsampler(h)  # [B*3, S, S, triplane_dim * ratio**2]
            # Unfold each token's (ratio x ratio) patch into space and split
            # the planes back out of the batch axis:
            # [B*3, S, S, C*r*r] -> [B, 3, C, S*r, S*r]
            h = rearrange(
                h,
                '(b d) h w (c r1 r2) -> b d c (h r1) (w r2)',
                d=3,
                r1=ratio,
                r2=ratio,
            )
        # Match the caller's dtype regardless of the dtype the layers ran in.
        return h.type(x.dtype)