# LASA/models/modules/resunet.py
import torch
import torch.nn as nn
from .unet import RollOut_Conv
from .Positional_Embedding import PositionalEmbedding
import torch.nn.functional as F
from .utils import zero_module
from .image_sampler import MultiImage_Fuse_Sampler, MultiImage_Global_Sampler,MultiImage_TriFuse_Sampler
class ResidualConv_MultiImgAtten(nn.Module):
    """Pre-activation residual conv block with optional multi-image attention.

    Pipeline: (norm -> ReLU -> conv) + timestep embedding, output conv,
    skip connection, roll-out conv over the triplane, optional
    MultiImage_Fuse_Sampler cross-attention into image features, then a
    strided down-convolution.
    """

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True,
                 triplane_padding=0.1, norm="batch"):
        super(ResidualConv_MultiImgAtten, self).__init__()
        self.use_attn = use_attn
        # Fix: compare to None with `is`, and fail loudly on unsupported norm
        # values instead of leaving `norm_layer` undefined (NameError later).
        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise ValueError(f"unsupported norm: {norm!r}")
        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Fuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                       img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                       out_channels=output_dim, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
        # Two-layer MLP projecting the timestep embedding to output_dim channels.
        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        """Run the block.

        Args:
            x: input feature map (B, input_dim, H, W).
            t_emb: timestep embedding (B, t_input_dim).
            img_feat, proj_mat, valid_frames: forwarded to the image sampler
                when attention is enabled (semantics defined by the sampler).
        """
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]  # broadcast over spatial dims
        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)
        return feature
class ResidualConv_TriMultiImgAtten(nn.Module):
    """Pre-activation residual conv block using tri-fuse multi-image attention.

    Identical structure to ResidualConv_MultiImgAtten but samples image
    features with MultiImage_TriFuse_Sampler (max_nimg=5).
    """

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True,
                 triplane_padding=0.1, norm="batch"):
        super(ResidualConv_TriMultiImgAtten, self).__init__()
        self.use_attn = use_attn
        # Fix: compare to None with `is`, and fail loudly on unsupported norm
        # values instead of leaving `norm_layer` undefined (NameError later).
        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise ValueError(f"unsupported norm: {norm!r}")
        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_TriFuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                          img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                          out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
        # Two-layer MLP projecting the timestep embedding to output_dim channels.
        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        """Run the block.

        Args:
            x: input feature map (B, input_dim, H, W).
            t_emb: timestep embedding (B, t_input_dim).
            img_feat, proj_mat, valid_frames: forwarded to the image sampler
                when attention is enabled (semantics defined by the sampler).
        """
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]  # broadcast over spatial dims
        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)
        return feature
class ResidualConv_GlobalAtten(nn.Module):
    """Pre-activation residual conv block using global multi-image attention.

    Identical structure to ResidualConv_MultiImgAtten but samples image
    features with MultiImage_Global_Sampler (max_nimg=5).
    """

    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True,
                 triplane_padding=0.1, norm="batch"):
        super(ResidualConv_GlobalAtten, self).__init__()
        self.use_attn = use_attn
        # Fix: compare to None with `is`, and fail loudly on unsupported norm
        # values instead of leaving `norm_layer` undefined (NameError later).
        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise ValueError(f"unsupported norm: {norm!r}")
        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Global_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                         img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                         out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
        # Two-layer MLP projecting the timestep embedding to output_dim channels.
        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        """Run the block.

        Args:
            x: input feature map (B, input_dim, H, W).
            t_emb: timestep embedding (B, t_input_dim).
            img_feat, proj_mat, valid_frames: forwarded to the image sampler
                when attention is enabled (semantics defined by the sampler).
        """
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]  # broadcast over spatial dims
        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)
        return feature
class ResidualConv(nn.Module):
    """Residual conv block (with roll-out conv) conditioned on a timestep embedding."""

    def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256):
        super(ResidualConv, self).__init__()
        # Main path: pre-activation conv, then roll-out conv across the triplane.
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        # Applied after the timestep embedding has been added.
        self.out_layer = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # Strided 3x3 skip projection so the residual shapes match.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )
        # Two-layer MLP mapping the timestep embedding to output_dim channels.
        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb):
        """Apply the block to `x`, conditioned on timestep embedding `t_emb`."""
        cond = F.silu(self.map_layer0(t_emb))
        cond = F.silu(self.map_layer1(cond))
        # Broadcast the per-sample condition over the spatial dimensions.
        cond = cond.unsqueeze(-1).unsqueeze(-1)
        hidden = self.out_layer(self.conv_block(x) + cond)
        return hidden + self.conv_skip(x)
class Upsample(nn.Module):
    """Learned upsampling implemented as a single transposed convolution."""

    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()
        self.upsample = nn.ConvTranspose2d(
            input_dim,
            output_dim,
            kernel_size=kernel,
            stride=stride,
        )

    def forward(self, x):
        """Upsample feature map `x` (N, input_dim, H, W)."""
        return self.upsample(x)
class ResUnet_Par_cond(nn.Module):
    """Triplane ResUNet conditioned on timestep + category code, with optional
    partial-point features concatenated as extra input channels."""

    def __init__(self, channel, filters=[64, 128, 256, 512, 1024], output_channel=32, par_channel=32):
        super(ResUnet_Par_cond, self).__init__()
        # Stem: conv stack plus a 3x3 conv skip from the raw input.
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )
        # Encoder: stride-2 residual blocks. The first one also takes the
        # par_channel features concatenated onto the stem output.
        self.residual_conv_1 = ResidualConv(filters[0] + par_channel, filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
        self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1)
        self.bridge = ResidualConv(filters[3], filters[4], 2, 1)
        # Decoder: transposed-conv upsample, concat the matching encoder
        # feature, then a stride-1 residual block.
        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1)
        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        # The last decoder block also receives the par features via the x1 skip.
        self.up_residual_conv4 = ResidualConv(filters[1] + filters[0] + par_channel, filters[0], 1, 1)
        self.output_layer = nn.Sequential(
            #nn.LayerNorm(filters[0]),
            nn.LayerNorm(64),  # normalize along width dimension, usually it should normalize along channel dimension,
            # I don't know why, but the finetuning performance increase significantly
            # zero_module initializes the head to zeros (standard diffusion trick).
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.par_channel = par_channel
        self.par_conv = nn.Sequential(
            nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        # Category code is a 6-dim vector mapped into the 256-dim t-embedding space.
        self.cat_emb = nn.Linear(
            in_features=6,
            out_features=256,
        )

    def forward(self, x, t, category_code, par_point_feat):
        """Denoise triplane `x` at timestep `t`.

        Args:
            x: input feature map (B, channel, H, W).
            t: per-sample timestep values fed to the positional embedding.
            category_code: (B, 6) category vector (cat_emb expects 6 features).
            par_point_feat: (B, par_channel, H, W) partial-point features, or
                None to substitute zeros.
        Returns:
            (B, output_channel, H, W) prediction.
        """
        # Encode
        t_emb = self.t_emb_layer(t)
        cat_emb = self.cat_emb(category_code)
        t_emb = t_emb + cat_emb  # fuse category conditioning into the timestep embedding
        #print(t_emb.shape)
        x1 = self.input_layer(x) + self.input_skip(x)
        if par_point_feat is not None:
            par_point_feat = self.par_conv(par_point_feat)
        else:
            # No partial-point conditioning: use an all-zero placeholder.
            bs, _, H, W = x1.shape
            #print(x1.shape)
            par_point_feat = torch.zeros((bs, self.par_channel, H, W)).float().to(x1.device)
        x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb)
        x3 = self.residual_conv_2(x2, t_emb)
        # Bridge
        x4 = self.residual_conv_3(x3, t_emb)
        x5 = self.bridge(x4, t_emb)
        # Decode: upsample, concat encoder skip, residual block.
        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb)
        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb)
        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        #print(x8.shape)
        x9 = self.up_residual_conv3(x8, t_emb)
        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb)
        output = self.output_layer(x10)
        return output
class ResUnet_DirectAttenMultiImg_Cond(nn.Module):
    """Triplane ResUNet whose decoder blocks cross-attend to multi-view image
    features; encoder blocks run without attention."""

    def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
                 img_in_channels=1024, vit_reso=16, output_channel=32,
                 use_par=False, par_channel=32, triplane_padding=0.1, norm='batch',
                 use_cat_embedding=False,
                 block_type="multiview_local"):
        super(ResUnet_DirectAttenMultiImg_Cond, self).__init__()
        # Select the residual-block variant implementing the requested attention.
        if block_type == "multiview_local":
            block = ResidualConv_MultiImgAtten
        elif block_type == "multiview_global":
            block = ResidualConv_GlobalAtten
        elif block_type == "multiview_tri":
            block = ResidualConv_TriMultiImgAtten
        else:
            raise NotImplementedError
        # Fix: compare to None with `is`, and fail loudly on unsupported norm
        # values instead of leaving `norm_layer` undefined (NameError later).
        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity
        else:
            raise ValueError(f"unsupported norm: {norm!r}")
        self.use_cat_embedding = use_cat_embedding
        # Stem: conv stack plus a 3x3 conv skip from the raw input.
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            norm_layer(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )
        self.use_par = use_par
        input_1_channels = filters[0]
        if self.use_par:
            # Optional partial-point features concatenated onto the stem output.
            self.par_conv = nn.Sequential(
                nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
            )
            input_1_channels = filters[0] + par_channel
        # Encoder: stride-2 blocks, no image attention.
        self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1, reso=64,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_3 = block(filters[2], filters[3], 2, 1, reso=16,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.bridge = block(filters[3], filters[4], 2, 1, reso=8,
                            use_attn=False, triplane_padding=triplane_padding, norm=norm)  # input reso is 8, output reso is 4
        # Decoder: upsample, concat encoder skip, attention-enabled block
        # (except the last, which runs without attention).
        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1, reso=8,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)
        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1, reso=16,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)
        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1, reso=32,
                                       img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)
        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64,
                                       use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.output_layer = nn.Sequential(
            nn.LayerNorm(64),  # normalize along width dimension, usually it should normalize along channel dimension,
            # I don't know why, but the finetuning performance increase significantly
            #nn.LayerNorm([filters[0], 192, 64]),
            # zero_module initializes the head to zeros (standard diffusion trick).
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        if use_cat_embedding:
            # Category code is a 6-dim vector mapped into the 256-dim t-embedding space.
            self.cat_emb = nn.Linear(
                in_features=6,
                out_features=256,
            )

    def forward(self, x, t, image_emb, proj_mat, valid_frames, category_code, par_point_feat=None):
        """Denoise triplane `x` at timestep `t`, attending to image features.

        Args:
            x: input feature map (B, channel, H, W).
            t: per-sample timestep values fed to the positional embedding.
            image_emb, proj_mat, valid_frames: multi-view conditioning passed
                through to the attention blocks (semantics defined by samplers).
            category_code: (B, 6) vector; used only when use_cat_embedding.
            par_point_feat: partial-point features; required when use_par.
        Returns:
            (B, output_channel, H, W) prediction.
        """
        # Encode
        t_emb = self.t_emb_layer(t)
        if self.use_cat_embedding:
            cat_emb = self.cat_emb(category_code)
            t_emb = t_emb + cat_emb
        x1 = self.input_layer(x) + self.input_skip(x)
        if self.use_par:
            par_point_feat = self.par_conv(par_point_feat)
            x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb, image_emb, proj_mat, valid_frames)
        x3 = self.residual_conv_2(x2, t_emb, image_emb, proj_mat, valid_frames)
        x4 = self.residual_conv_3(x3, t_emb, image_emb, proj_mat, valid_frames)
        x5 = self.bridge(x4, t_emb, image_emb, proj_mat, valid_frames)
        # Decode: upsample, concat encoder skip, residual/attention block.
        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb, image_emb, proj_mat, valid_frames)
        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb, image_emb, proj_mat, valid_frames)
        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        x9 = self.up_residual_conv3(x8, t_emb, image_emb, proj_mat, valid_frames)
        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb, image_emb, proj_mat, valid_frames)
        output = self.output_layer(x10)
        return output
if __name__ == "__main__":
    # Smoke test. Fixes two bugs in the original demo:
    #   1. `ResUnet` is not defined in this module — use ResUnet_Par_cond.
    #   2. ResUnet_Par_cond.forward takes (x, t, category_code, par_point_feat);
    #      the category code (6 features, per cat_emb) was missing.
    # Also fall back to CPU when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = ResUnet_Par_cond(32, output_channel=32).float().to(device)
    n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Model = %s" % str(net))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))
    par_point_feat = torch.randn((10, 32, 64 * 3, 64)).float().to(device)
    input = torch.randn((10, 32, 64 * 3, 64)).float().to(device)
    t = torch.randn((10, 1, 1, 1)).float().to(device)
    category_code = torch.randn((10, 6)).float().to(device)
    output = net(input, t.flatten(), category_code, par_point_feat)
    print(output.shape)