"""Triplane diffusion UNet backbones.

Residual conv blocks with optional multi-image cross-attention, plus two
ResUNet variants conditioned on timestep, category code, and image features.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .unet import RollOut_Conv
from .Positional_Embedding import PositionalEmbedding
from .utils import zero_module
from .image_sampler import (
    MultiImage_Fuse_Sampler,
    MultiImage_Global_Sampler,
    MultiImage_TriFuse_Sampler,
)


class ResidualConv_MultiImgAtten(nn.Module):
    """Residual conv block with optional multi-view local cross-attention
    (MultiImage_Fuse_Sampler) and timestep conditioning."""

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280,
                 use_attn=True, triplane_padding=0.1, norm="batch"):
        super(ResidualConv_MultiImgAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise NotImplementedError("unsupported norm: %s" % norm)

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=padding),
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Fuse_Sampler(
                inner_channel=output_dim, triplane_in_channels=output_dim,
                img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                out_channels=output_dim, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3,
                                   stride=stride, padding=padding)

        # Two-layer MLP mapping the timestep embedding to this block's width.
        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]  # broadcast over the spatial dims

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature
        feature = self.down_conv(feature)

        return feature
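

# Hedged usage sketch (not from the original repo): shapes are inferred from
# the rolled-out triplane convention used in this file, (B, C, 3 * reso, reso).
# The expected output shape assumes RollOut_Conv preserves spatial size. With
# use_attn=False the image arguments are ignored, so they can be None.
def _demo_residual_conv_multi_img_atten():
    block = ResidualConv_MultiImgAtten(64, 128, stride=2, padding=1,
                                       reso=64, use_attn=False)
    x = torch.randn(2, 64, 64 * 3, 64)   # rolled-out triplane
    t_emb = torch.randn(2, 256)          # timestep embedding (t_input_dim)
    out = block(x, t_emb, None, None, None)
    print(out.shape)                     # expected: (2, 128, 96, 32)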


class ResidualConv_TriMultiImgAtten(nn.Module):
    """Identical to ResidualConv_MultiImgAtten except the attention module is
    MultiImage_TriFuse_Sampler (max_nimg=5)."""

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280,
                 use_attn=True, triplane_padding=0.1, norm="batch"):
        super(ResidualConv_TriMultiImgAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise NotImplementedError("unsupported norm: %s" % norm)

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=padding),
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_TriFuse_Sampler(
                inner_channel=output_dim, triplane_in_channels=output_dim,
                img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3,
                                   stride=stride, padding=padding)

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature
        feature = self.down_conv(feature)

        return feature


class ResidualConv_GlobalAtten(nn.Module):
    """Identical to ResidualConv_MultiImgAtten except the attention module is
    MultiImage_Global_Sampler (max_nimg=5)."""

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280,
                 use_attn=True, triplane_padding=0.1, norm="batch"):
        super(ResidualConv_GlobalAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise NotImplementedError("unsupported norm: %s" % norm)

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=padding),
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Global_Sampler(
                inner_channel=output_dim, triplane_in_channels=output_dim,
                img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3,
                                   stride=stride, padding=padding)

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature
        feature = self.down_conv(feature)

        return feature


class ResidualConv(nn.Module):
    """Plain residual conv block with timestep conditioning (no attention)."""

    def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        self.out_layer = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)

        return out + self.conv_skip(x)
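

# Hedged sketch (not from the original repo): the timestep-conditioning
# pattern shared by the residual blocks above. A (B, C) embedding is pushed
# through two SiLU-activated linear maps, then broadcast-added over the
# spatial dimensions of a (B, C, H, W) feature map.
def _demo_t_emb_broadcast():
    feat = torch.randn(2, 128, 48, 16)
    t_emb = torch.randn(2, 128)           # already mapped to C channels
    out = feat + t_emb[:, :, None, None]  # broadcast over H and W
    print(out.shape)                      # (2, 128, 48, 16)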


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)
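

# Hedged sketch (not from the original repo): with kernel == stride == 2 and
# no padding, ConvTranspose2d doubles each spatial dimension exactly,
# out = (in - 1) * 2 + 2 = 2 * in. Quick check:
def _demo_upsample():
    up = Upsample(64, 32, kernel=2, stride=2)
    print(up(torch.randn(1, 64, 24, 8)).shape)  # (1, 32, 48, 16)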


class ResUnet_Par_cond(nn.Module):
    """ResUNet over rolled-out triplanes, conditioned on timestep, category
    code, and (optionally) partial-point features."""

    def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
                 output_channel=32, par_channel=32):
        super(ResUnet_Par_cond, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        # Encoder: resolution halves at every block.
        self.residual_conv_1 = ResidualConv(filters[0] + par_channel, filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
        self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1)
        self.bridge = ResidualConv(filters[3], filters[4], 2, 1)

        # Decoder with skip connections from the encoder.
        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1)

        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv4 = ResidualConv(filters[1] + filters[0] + par_channel, filters[0], 1, 1)

        self.output_layer = nn.Sequential(
            # LayerNorm(64) normalizes over the last (width) dimension of the
            # rolled-out triplane; the zero-initialized 1x1 conv makes the
            # block start as an identity-free residual branch.
            nn.LayerNorm(64),
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.par_channel = par_channel
        self.par_conv = nn.Sequential(
            nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        self.cat_emb = nn.Linear(
            in_features=6,
            out_features=256,
        )

    def forward(self, x, t, category_code, par_point_feat):
        t_emb = self.t_emb_layer(t)
        cat_emb = self.cat_emb(category_code)
        t_emb = t_emb + cat_emb

        x1 = self.input_layer(x) + self.input_skip(x)
        if par_point_feat is not None:
            par_point_feat = self.par_conv(par_point_feat)
        else:
            # No partial-point conditioning: fall back to an all-zero feature map.
            bs, _, H, W = x1.shape
            par_point_feat = torch.zeros((bs, self.par_channel, H, W)).float().to(x1.device)
        x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb)
        x3 = self.residual_conv_2(x2, t_emb)
        x4 = self.residual_conv_3(x3, t_emb)
        x5 = self.bridge(x4, t_emb)

        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb)

        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb)

        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        x9 = self.up_residual_conv3(x8, t_emb)

        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb)

        output = self.output_layer(x10)

        return output


class ResUnet_DirectAttenMultiImg_Cond(nn.Module):
    """ResUNet whose decoder blocks cross-attend to multi-view image features."""

    def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
                 img_in_channels=1024, vit_reso=16, output_channel=32,
                 use_par=False, par_channel=32, triplane_padding=0.1,
                 norm='batch', use_cat_embedding=False,
                 block_type="multiview_local"):
        super(ResUnet_DirectAttenMultiImg_Cond, self).__init__()

        if block_type == "multiview_local":
            block = ResidualConv_MultiImgAtten
        elif block_type == "multiview_global":
            block = ResidualConv_GlobalAtten
        elif block_type == "multiview_tri":
            block = ResidualConv_TriMultiImgAtten
        else:
            raise NotImplementedError

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise NotImplementedError("unsupported norm: %s" % norm)

        self.use_cat_embedding = use_cat_embedding
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            norm_layer(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )
        self.use_par = use_par
        input_1_channels = filters[0]
        if self.use_par:
            self.par_conv = nn.Sequential(
                nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
            )
            input_1_channels = filters[0] + par_channel

        # Encoder: attention disabled, resolution halves at every block.
        self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1, reso=64,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_3 = block(filters[2], filters[3], 2, 1, reso=16,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.bridge = block(filters[3], filters[4], 2, 1, reso=8,
                            use_attn=False, triplane_padding=triplane_padding, norm=norm)

        # Decoder: image cross-attention enabled on the three coarsest levels.
        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1, reso=8,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1, reso=16,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1, reso=32,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64,
                                       use_attn=False, triplane_padding=triplane_padding, norm=norm)

        self.output_layer = nn.Sequential(
            nn.LayerNorm(64),
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        if use_cat_embedding:
            self.cat_emb = nn.Linear(
                in_features=6,
                out_features=256,
            )

    def forward(self, x, t, image_emb, proj_mat, valid_frames, category_code, par_point_feat=None):
        t_emb = self.t_emb_layer(t)
        if self.use_cat_embedding:
            cat_emb = self.cat_emb(category_code)
            t_emb = t_emb + cat_emb
        x1 = self.input_layer(x) + self.input_skip(x)
        if self.use_par:
            par_point_feat = self.par_conv(par_point_feat)
            x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb, image_emb, proj_mat, valid_frames)
        x3 = self.residual_conv_2(x2, t_emb, image_emb, proj_mat, valid_frames)
        x4 = self.residual_conv_3(x3, t_emb, image_emb, proj_mat, valid_frames)
        x5 = self.bridge(x4, t_emb, image_emb, proj_mat, valid_frames)

        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb, image_emb, proj_mat, valid_frames)

        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb, image_emb, proj_mat, valid_frames)

        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        x9 = self.up_residual_conv3(x8, t_emb, image_emb, proj_mat, valid_frames)

        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb, image_emb, proj_mat, valid_frames)

        output = self.output_layer(x10)

        return output
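

# Hedged construction sketch (not from the original repo). The shapes of
# image_emb / proj_mat / valid_frames expected by forward() depend on the
# samplers in .image_sampler and are not reproduced here; only construction
# and a parameter count are shown.
def _demo_direct_atten_unet():
    net = ResUnet_DirectAttenMultiImg_Cond(32, output_channel=32,
                                           block_type="multiview_local")
    n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("params (M): %.2f" % (n_parameters / 1.e6))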


if __name__ == "__main__":
    # Smoke test (assumes a CUDA device). ResUnet_Par_cond expects a rolled-out
    # triplane (B, C, 3 * reso, reso), a flat timestep vector, a 6-dim category
    # code, and optional partial-point features.
    net = ResUnet_Par_cond(32, output_channel=32).float().cuda()
    n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Model = %s" % str(net))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))
    par_point_feat = torch.randn((10, 32, 64 * 3, 64)).float().cuda()
    input = torch.randn((10, 32, 64 * 3, 64)).float().cuda()
    t = torch.randn((10, 1, 1, 1)).float().cuda()
    category_code = torch.randn((10, 6)).float().cuda()
    output = net(input, t.flatten(), category_code, par_point_feat)
    print(output.shape)