# LASA/models/modules/image_sampler.py
# Author: HaolinLiu — "first commit of codes and update readme.md" (commit cc9780d)
import sys
sys.path.append('../..')
import torch
import torch.nn as nn
import math
from models.modules.unet import RollOut_Conv
from einops import rearrange, reduce
MB =1024.0*1024.0  # bytes per megabyte; presumably used for memory logging elsewhere — TODO confirm
def mask_kernel(x, sigma=1):
    """Boolean proximity kernel: True where |x| < sigma.

    Used to select ViT patches whose distance to a projected point is
    inside the kernel radius.
    """
    return x.abs() < sigma
def mask_kernel_close_false(x, sigma=1):
    """Inverted proximity kernel: True where |x| > sigma.

    Distances inside the kernel radius yield False (the complement of
    ``mask_kernel`` up to the boundary |x| == sigma, which is False in both).
    """
    return x.abs() > sigma
class Image_Local_Sampler(nn.Module):
    """Project triplane cell centers into a single image via `proj_mat` and
    bilinearly sample the (channel-reduced) image features onto the triplane.

    Output of forward() is the sampled feature map over the three planes
    stacked along the height axis in [xz, xy, yz] order.
    Assumes a 224x224 image frame (hard-coded below) — TODO confirm.
    """
    def __init__(self,reso,padding=0.1,in_channels=1280,out_channels=512):
        super().__init__()
        self.triplane_reso=reso
        self.padding=padding
        self.get_triplane_coord()
        # 1x1 conv to reduce image feature channels before sampling
        self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1)
    def get_triplane_coord(self):
        '''Precompute the 3D coordinate of every triplane cell.

        Each plane is a reso x reso grid of 3D points; the coordinate along
        the axis orthogonal to the plane is pinned to the volume center
        ((reso-1)/2). Planes are stacked along dim 0 as [xz, xy, yz], then
        mapped to roughly [-1, 1] expanded by `padding`. Result lives on GPU.
        NOTE(review): meshgrid uses indexing='xy'; the axis ordering below is
        deliberate — verify before reusing.
        '''
        '''xz plane: y fixed at the volume center'''
        x=torch.arange(self.triplane_reso)
        z=torch.arange(self.triplane_reso)
        X,Z=torch.meshgrid(x,z,indexing='xy')
        xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
        '''xy plane: z fixed at the volume center'''
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
        '''yz plane: x fixed at the volume center'''
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        Y,Z = torch.meshgrid(y,z,indexing='xy')
        yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
        triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)  # (3*reso, reso, 3)
        triplane_coords=triplane_coords/(self.triplane_reso-1)  # to [0, 1]
        # to ~[-1, 1], slightly enlarged by padding (epsilon avoids exact edge)
        triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
        self.triplane_coords=triplane_coords.float().cuda()
    def forward(self,image_feat,proj_mat):
        """image_feat: (B, in_channels, H, W); proj_mat: (B, 3, 4) projection
        mapping homogeneous world points to 224x224 pixel coordinates —
        TODO confirm against caller."""
        image_feat=self.img_proj(image_feat)
        batch_size=image_feat.shape[0]
        triplane_coords=self.triplane_coords.unsqueeze(0).expand(batch_size,-1,-1,-1) #B,192,64,3
        # homogeneous coordinates, then project into pixel space
        coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,triplane_coords.shape[1],triplane_coords.shape[2],1)).float().cuda()],dim=-1)
        coord_inimg=torch.einsum('bhwc,bck->bhwk',coord_homo,proj_mat.transpose(1,2))
        # perspective divide by depth
        x=coord_inimg[:,:,:,0]/coord_inimg[:,:,:,2]
        y=coord_inimg[:,:,:,1]/coord_inimg[:,:,:,2]
        x=(x/(224.0-1.0)-0.5)*2 #-1~1, grid_sample convention
        y=(y/(224.0-1.0)-0.5)*2 #-1~1
        dist=coord_inimg[:,:,:,2]  # depth of each cell; currently unused
        xy=torch.cat([x[:,:,:,None],y[:,:,:,None]],dim=-1)
        # bilinear lookup of image features at the projected locations
        sample_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear')
        return sample_feat
def position_encoding(d_model, length):
    """Build a sinusoidal positional-encoding table of shape (length, d_model).

    Even channels carry sin, odd channels carry cos, with frequencies in a
    geometric progression over 10000 (the Transformer convention).

    Raises:
        ValueError: if d_model is odd (sin/cos channels must pair up).
    """
    if d_model % 2 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "odd dim (got dim={:d})".format(d_model))
    positions = torch.arange(0, length).unsqueeze(1)  # (length, 1)
    # (d_model//2,) frequency terms: exp(-2i * ln(10000) / d_model)
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                      * (-math.log(10000.0) / d_model))
    angles = positions.float() * freqs  # (length, d_model//2)
    table = torch.zeros(length, d_model)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table
class Image_Vox_Local_Sampler(nn.Module):
    """Lift projected image features onto a dense reso^3 voxel grid, then update
    each triplane cell by cross-attending over the voxel column orthogonal to
    its plane.

    forward() returns a (B, out_channels, 3*reso, reso) stacked triplane.
    """
    def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
        super().__init__()
        self.triplane_reso=reso
        self.padding=padding
        self.get_vox_coord()
        self.out_channels=out_channels
        # 1x1 conv reducing image channels to the attention width
        self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
        self.vox_process=nn.Sequential(
            nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1,),
        )
        self.k=nn.Linear(in_features=inner_channel,out_features=inner_channel)
        self.q=nn.Linear(in_features=inner_channel,out_features=inner_channel)
        self.v=nn.Linear(in_features=inner_channel,out_features=inner_channel)
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
        # positional encoding along the depth axis of a voxel column (length reso);
        # built on CPU and moved to the active device inside forward()
        self.condition_pe = position_encoding(inner_channel, self.triplane_reso).unsqueeze(0)
    def get_vox_coord(self):
        """Precompute padding-expanded, ~[-1,1]-normalized xyz centers of all
        reso^3 voxels, flattened to (reso^3, 3) on the GPU."""
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
        vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
        vox_coor=vox_coor/(self.triplane_reso-1)
        vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
        self.vox_coor=vox_coor.view(-1,3).float().cuda()
    def forward(self,triplane_feat,image_feat,proj_mat):
        """triplane_feat: (B, C, 3*reso, reso) stacked [xz, xy, yz] planes;
        image_feat: (B, in_channels, H, W); proj_mat: (B, 3, 4) — assumes a
        224x224 image frame, TODO confirm."""
        xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
        image_feat=self.img_proj(image_feat)
        batch_size=image_feat.shape[0]
        vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
        # project every voxel center into the image (homogeneous coords + divide)
        vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
        coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
        x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
        y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
        x=(x/(224.0-1.0)-0.5)*2 #-1~1
        y=(y/(224.0-1.0)-0.5)*2 #-1~1
        xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
        # one image feature per voxel, folded back into a (reso, reso, reso) grid
        grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
            view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
        grid_feat=self.vox_process(grid_feat)
        # For each plane: each cell is a 1-token query; the keys/values are the
        # reso voxels along the axis orthogonal to that plane.
        # NOTE(review): the permute orders below pair each plane with its
        # perpendicular axis — verify against get_vox_coord's 'ij' layout.
        xzy_grid=grid_feat.permute(0,4,2,3,1)
        xz_as_query=xz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
        xz_as_key=xzy_grid.reshape(batch_size*self.triplane_reso**2,self.triplane_reso,-1)
        xyz_grid=grid_feat.permute(0,3,2,4,1)
        xy_as_query=xy_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
        xy_as_key = xyz_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
        yzx_grid = grid_feat.permute(0, 4, 3, 2, 1)
        yz_as_query = yz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
        yz_as_key = yzx_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
        # all three planes batched into a single attention call
        query=self.q(torch.cat([xz_as_query,xy_as_query,yz_as_query],dim=0))
        key=self.k(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
        value=self.v(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
        attn,_=self.attn(query,key,value)
        xz_plane,xy_plane,yz_plane=torch.split(attn,dim=0,split_size_or_sections=batch_size*self.triplane_reso**2)
        xz_plane=xz_plane.reshape(batch_size,self.triplane_reso,self.triplane_reso,-1).permute(0,3,1,2)
        xy_plane = xy_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
        yz_plane = yz_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
        triplane_wImg=torch.cat([xz_plane,xy_plane,yz_plane],dim=2)  # restack planes
        triplane_wImg=self.proj_out(triplane_wImg)
        return triplane_wImg
class Image_Direct_AttenwMask_Sampler(nn.Module):
    """Cross-attention from triplane cells to ViT image tokens, restricted by a
    projection-derived attention mask: a cell may attend only to the patches
    that points along its perpendicular ray project onto, plus the global token.
    """
    def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
                 img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
        super().__init__()
        self.triplane_reso=reso
        self.vit_reso=vit_reso
        self.padding=padding
        self.n_heads=n_heads
        self.get_plane_expand_coord()
        self.get_vit_coords()
        self.out_channels=out_channels
        self.kernel_func=mask_kernel  # |dist| < sigma => patch is visible
        self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
        self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
        # +1: ViT sequence carries a global token ahead of the reso^2 patches
        self.image_pe = position_encoding(inner_channel, self.vit_reso**2+1).unsqueeze(0).cuda().float() #1,vit_reso**2+1,channel
        self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
    def get_plane_expand_coord(self):
        """For each plane, build the reso^3 sample points grouped so that every
        plane cell owns a contiguous run of reso points along its perpendicular
        axis; stacked as [xz, xy, yz] into (3*reso^3, 3) on the GPU.
        NOTE(review): meshgrid uses indexing='xy'; the permutes encode the
        per-plane axis ordering — verify before reuse."""
        x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        first,second,third=torch.meshgrid(x,y,z,indexing='xy')
        xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
        xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
        xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
        yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
        xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
        xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
        yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
        coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
        self.plane_coords=coords.cuda().float()
    def get_vit_coords(self):
        """Integer (x, y) centers of the vit_reso x vit_reso patch grid,
        flattened to (vit_reso^2, 2) on the GPU."""
        x=torch.arange(self.vit_reso)
        y=torch.arange(self.vit_reso)
        X,Y=torch.meshgrid(x,y,indexing='xy')
        vit_coords=torch.stack([X,Y],dim=-1)
        self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
    def get_attn_mask(self,coords_proj,vit_coords,kernel_size=1.0):
        '''
        :param coords_proj: B,3*reso**3,2, in range of 0~1 (scaled to patch units below)
        :param vit_coords: B,vit_reso**2,2, in range of 0~vit_reso
        :param kernel_size: patch-distance radius; forward() passes 1.0
        :return: boolean mask (B, 3*reso**2, vit_reso**2); True marks patches a
                 cell must NOT attend to (no ray point lands within the kernel)
        '''
        bs=coords_proj.shape[0]
        coords_proj=coords_proj*(self.vit_reso-1)  # to patch-grid units
        dist=torch.cdist(coords_proj.float(),vit_coords.float())
        mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B,3*reso**3,vit_reso**2
        # collapse the reso ray points of each plane cell: a patch is visible
        # if ANY point along the cell's perpendicular ray hits it
        mask=mask.reshape(bs,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
        mask=torch.sum(mask,dim=2)
        attn_mask=(mask==0)  # invert: True = blocked, per MultiheadAttention convention
        return attn_mask
    def forward(self,triplane_feat,image_feat,proj_mat):
        """triplane_feat: (B, C, 3*reso, reso); image_feat: (B, vit_reso^2+1, C_img)
        ViT tokens (token 0 global); proj_mat: (B, 3, 4) onto a 224x224 frame."""
        batch_size=image_feat.shape[0]
        coords=self.plane_coords.unsqueeze(0).expand(batch_size,-1,-1)
        coords_homo=torch.cat([coords,torch.ones(batch_size,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
        coords_inimg=torch.einsum('bhc,bck->bhk',coords_homo,proj_mat.transpose(1,2))
        coords_x=coords_inimg[:,:,0]/coords_inimg[:,:,2]/(224.0-1) #0~1
        coords_y=coords_inimg[:,:,1]/coords_inimg[:,:,2]/(224.0-1) #0~1
        coords_x=torch.clamp(coords_x,min=0.0,max=1.0)
        coords_y=torch.clamp(coords_y,min=0.0,max=1.0)
        coords_proj=torch.stack([coords_x,coords_y],dim=-1)
        vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1)
        # replicate the per-sample mask once per attention head
        attn_mask=torch.repeat_interleave(
            self.get_attn_mask(coords_proj,vit_coords,kernel_size=1.0),self.n_heads, 0
        )
        # prepend an always-visible column for the ViT global token
        attn_mask = torch.cat([torch.zeros([attn_mask.shape[0], attn_mask.shape[1], 1]).cuda().bool(), attn_mask],
                              dim=-1) # add global token
        triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
        query=self.q(triplane_feat)+self.triplane_pe
        key=self.k(image_feat)+self.image_pe
        value=self.v(image_feat)+self.image_pe
        attn,_=self.attn(query,key,value,attn_mask=attn_mask)
        output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
        return output
class MultiImage_Direct_AttenwMask_Sampler(nn.Module):
    """Multi-view variant of Image_Direct_AttenwMask_Sampler: triplane cells
    attend over the concatenated tokens of up to max_nimg views, with a
    visibility mask per view and all tokens of invalid frames blocked.
    """
    def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
                 img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
        super().__init__()
        self.triplane_reso=reso
        self.vit_reso=vit_reso
        self.padding=padding
        self.n_heads=n_heads
        self.get_plane_expand_coord()
        self.get_vit_coords()
        self.out_channels=out_channels
        self.kernel_func=mask_kernel
        self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
        self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
        # one PE table over the flattened multi-view token sequence
        # (each view contributes vit_reso^2 patches + 1 global token)
        self.image_pe = position_encoding(inner_channel, max_nimg*(self.vit_reso**2+1)).unsqueeze(0).cuda().float()
        self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
    def get_plane_expand_coord(self):
        """Same construction as the single-view sampler: (3*reso^3, 3) sample
        points, each plane cell owning reso points along its perpendicular axis.
        NOTE(review): indexing='xy' plus the permutes encode the axis ordering —
        verify before reuse."""
        x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
        first,second,third=torch.meshgrid(x,y,z,indexing='xy')
        xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
        xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
        xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
        yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
        xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
        xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
        yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
        coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
        self.plane_coords=coords.cuda().float()
    def get_vit_coords(self):
        """Integer (x, y) centers of the ViT patch grid, (vit_reso^2, 2) on GPU."""
        x=torch.arange(self.vit_reso)
        y=torch.arange(self.vit_reso)
        X,Y=torch.meshgrid(x,y,indexing='xy')
        vit_coords=torch.stack([X,Y],dim=-1)
        self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
    def get_attn_mask(self,coords_proj,vit_coords,valid_frames,kernel_size=1.0):
        '''
        :param coords_proj: B,n_img,3*reso**3,2, in range of 0~vit_reso-1
        :param vit_coords: B,n_img,vit_reso**2,2, in range of 0~vit_reso
        :param valid_frames: per-view validity flags; views flagged 0 are fully blocked
        :param kernel_size: patch-distance radius; forward() passes 1.0
        :return: boolean mask (B, 3*reso**2, n_img*(vit_reso**2+1));
                 True marks tokens a cell must NOT attend to
        '''
        bs,n_img=coords_proj.shape[0],coords_proj.shape[1]
        coords_proj_flat=coords_proj.reshape(bs*n_img,3*self.triplane_reso**3,2)
        vit_coords_flat=vit_coords.reshape(bs*n_img,self.vit_reso**2,2)
        dist=torch.cdist(coords_proj_flat.float(),vit_coords_flat.float())
        mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B*n_img,3*reso**3,vit_reso**2
        # collapse the reso ray points of each plane cell (any hit => visible)
        mask=mask.reshape(bs,n_img,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
        mask=torch.sum(mask,dim=3) #B,n_img,3*reso**2,vit_reso**2
        # prepend an always-on column per view for the ViT global token
        mask=torch.cat([torch.ones(size=mask.shape[0:3]).unsqueeze(3).float().cuda(),mask],dim=-1) #B,n_img,3*reso**2,vit_reso**2+1, add global mask
        # float tensor, so False is written as 0.0: invalid frames lose all
        # their token columns (global token included)
        mask[valid_frames == 0, :, :] = False
        mask=mask.permute(0,2,1,3).reshape(bs,3*self.triplane_reso**2,-1) #B,3*reso**2,n_img*(vit_resso**2+1)
        attn_mask=(mask==0) #invert the mask, False indicates valid, True indicates invalid
        return attn_mask
    def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
        '''image feat is bs,n_img,length,channel; proj_mat is (B, n_img, 3, 4)
        onto a 224x224 frame; valid_frames flags usable views.'''
        batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
        img_length=image_feat.shape[2]
        image_feat_flat=image_feat.view(batch_size,n_img*img_length,-1)
        coords=self.plane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
        coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
        coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
        x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
        y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
        x = x/(224.0-1)
        y = y/(224.0-1)
        coords_x=torch.clamp(x,min=0.0,max=1.0)*(self.vit_reso-1)
        coords_y=torch.clamp(y,min=0.0,max=1.0)*(self.vit_reso-1)
        coords_proj=torch.stack([coords_x,coords_y],dim=-1)
        vit_coords=self.vit_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
        # replicate per-sample mask once per attention head
        attn_mask=torch.repeat_interleave(
            self.get_attn_mask(coords_proj,vit_coords,valid_frames,kernel_size=1.0),self.n_heads, 0
        )
        triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
        query=self.q(triplane_feat)+self.triplane_pe
        key=self.k(image_feat_flat)+self.image_pe
        value=self.v(image_feat_flat)+self.image_pe
        attn,_=self.attn(query,key,value,attn_mask=attn_mask)
        output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
        return output
class MultiImage_Fuse_Sampler(nn.Module):
    """Fuse multi-view image features per voxel: sample per-view k/v features at
    each voxel's projection, attend across views (query = the concatenation of
    the voxel's three plane features), then mean-reduce the attended volume
    back into a stacked triplane.
    """
    def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
                 img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
        super().__init__()
        self.triplane_reso=reso
        self.vit_reso=vit_reso
        self.inner_channel=inner_channel
        self.padding=padding
        self.n_heads=n_heads
        self.get_vox_coord()
        self.get_vit_coords()
        self.out_channels=out_channels
        self.kernel_func=mask_kernel
        # token sequence -> (vit_reso, vit_reso) patch grid
        self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
        self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
        self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.cross_attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
        self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].cuda().float() #1,1,length,channel
        self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
    def get_vit_coords(self):
        """(vit_reso, vit_reso, 2) integer patch-center grid on GPU (currently
        only kept for reference; forward() does not use it)."""
        x = torch.arange(self.vit_reso)
        y = torch.arange(self.vit_reso)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        vit_coords = torch.stack([X, Y], dim=-1)
        self.vit_coords = vit_coords.cuda().float() #reso,reso,2
    def get_vox_coord(self):
        """Precompute integer voxel indices (reso^3, 3) and their ~[-1,1]
        padding-expanded xyz centers, both on GPU."""
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
        vox_coor = torch.cat([X[:, :, :, None], Y[:, :, :, None], Z[:, :, :, None]], dim=-1)
        self.vox_index = vox_coor.view(-1, 3).long().cuda()
        vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
        vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
        self.vox_coor = vox_coor.view(-1, 3).float().cuda()
    def get_attn_mask(self,valid_frames):
        '''
        :param valid_frames: of shape B,n_img; True is emitted where a frame is invalid
        :return: (B*reso^3, 1, n_img) boolean mask
        NOTE(review): forward() currently does NOT use this mask — the
        attn_mask argument below is commented out, so invalid views still
        contribute to the fusion. Confirm whether that is intentional.
        '''
        attn_mask=(valid_frames.float()==0)
        attn_mask=torch.repeat_interleave(attn_mask.unsqueeze(1),self.triplane_reso**3,0)
        return attn_mask
    def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
        '''image feat is bs,n_img,length,channel (token 0 is the ViT global
        token and is discarded); triplane_feat is (B, C, 3*reso, reso);
        proj_mat is (B, n_img, 3, 4) onto a 224x224 frame.
        NOTE(review): assumes n_img >= 2 — n_img == 1 divides by zero below.'''
        batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
        image_feat=image_feat[:,:,1:,:] #discard global feature
        image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c
        image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c
        # concatenate k and v along channels so ONE grid_sample gathers both
        image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h*w,2c
        unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
        coords=self.vox_coor.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
        coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3,1).float().cuda()],dim=-1)
        coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
        x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
        y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
        x = x/(224.0-1) #0~1
        y = y/(224.0-1)
        coords_proj=torch.stack([x,y],dim=-1)
        coords_proj=(coords_proj-0.5)*2
        # third grid coordinate selects the view; with align_corners=True the
        # values i/(n_img-1) rescaled to [-1,1] land exactly on view slices
        img_index=((torch.arange(n_img)[None,:,None,None].expand(
            batch_size,-1,self.triplane_reso**3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
        grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
        grid=torch.clamp(grid,min=-1.0,max=1.0)
        # 5D grid_sample: one k/v feature per (view, voxel) pair
        sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
        xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3,
                                                dim=2) # B,C,64,64
        # broadcast each plane along its perpendicular axis, then flatten all
        # three to a common (x y z) voxel ordering so they can be concatenated
        xz_vox_feat=xz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)
        xz_vox_feat=rearrange(xz_vox_feat, 'b c z x y -> b (x y z) c')
        xy_vox_feat=xy_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)
        xy_vox_feat=rearrange(xy_vox_feat, 'b c y x z -> b (x y z) c')
        yz_vox_feat=yz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)
        yz_vox_feat=rearrange(yz_vox_feat, 'b c z y x -> b (x y z) c')
        triplane_expand_feat = torch.cat([xz_vox_feat, xy_vox_feat, yz_vox_feat], dim=-1) # B,64*64*64,3*C
        triplane_query = self.q(triplane_expand_feat) + self.triplane_pe
        # per-voxel attention: 1 query token over n_img key/value tokens
        k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
        k=k_v[:,:,0:self.inner_channel]
        v=k_v[:,:,self.inner_channel:]
        q=rearrange(triplane_query,'b k c -> (b k) 1 c')
        attn_out,_=self.cross_attn(q,k,v)#attn_mask=attn_mask) #fuse multi-view feature; see get_attn_mask NOTE
        volume=rearrange(attn_out,'(b x y z) 1 c -> b x y z c',x=self.triplane_reso,y=self.triplane_reso,z=self.triplane_reso)
        # project the fused volume back onto the three planes by averaging
        # along each plane's perpendicular axis
        xz_feat = reduce(volume, "b x y z c -> b z x c", 'mean')
        xy_feat= reduce(volume, 'b x y z c -> b y x c', 'mean')
        yz_feat=reduce(volume, 'b x y z c -> b z y c', 'mean')
        triplane_out = torch.cat([xz_feat, xy_feat, yz_feat], dim=1) #B,reso*3,reso,C
        triplane_out = self.proj_out(triplane_out)
        triplane_out = triplane_out.permute(0,3,1,2)
        return triplane_out
class MultiImage_TriFuse_Sampler(nn.Module):
    """Like MultiImage_Fuse_Sampler, but queries are the 3*reso^2 triplane cells
    directly (no voxel expansion): each cell's 3D center is projected into every
    view, per-view k/v features are sampled there, and attention fuses across
    views per cell. proj_out here is a 1x1 Conv2d (vs Linear in the Fuse variant).
    """
    def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
                 img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
        super().__init__()
        self.triplane_reso=reso
        self.vit_reso=vit_reso
        self.inner_channel=inner_channel
        self.padding=padding
        self.n_heads=n_heads
        self.get_triplane_coord()
        self.get_vit_coords()
        self.out_channels=out_channels
        self.kernel_func=mask_kernel
        # token sequence -> (vit_reso, vit_reso) patch grid
        self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
        self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
        self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
        self.cross_attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
        self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
        self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 2*3).unsqueeze(0).cuda().float()
    def get_vit_coords(self):
        """(vit_reso, vit_reso, 2) integer patch-center grid on GPU (kept for
        reference; forward() does not use it)."""
        x = torch.arange(self.vit_reso)
        y = torch.arange(self.vit_reso)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        vit_coords = torch.stack([X, Y], dim=-1)
        self.vit_coords = vit_coords.cuda().float() #reso,reso,2
    def get_triplane_coord(self):
        '''Precompute 3D centers of all triplane cells, stacked [xz, xy, yz];
        the axis orthogonal to each plane is pinned to the volume center.
        Flattened to (3*reso^2, 3), mapped to ~[-1,1] with padding, on GPU.'''
        '''xz plane: y fixed at the volume center'''
        x = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        X, Z = torch.meshgrid(x, z, indexing='xy')
        xz_coords = torch.cat(
            [X[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2, Z[:, :, None]],
            dim=-1) # in xyz order
        '''xy plane: z fixed at the volume center'''
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        xy_coords = torch.cat(
            [X[:, :, None], Y[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2],
            dim=-1) # in xyz order
        '''yz plane: x fixed at the volume center'''
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        Y, Z = torch.meshgrid(y, z, indexing='xy')
        yz_coords = torch.cat(
            [torch.ones_like(Y[:, :, None]) * (self.triplane_reso - 1) / 2, Y[:, :, None], Z[:, :, None]], dim=-1)
        triplane_coords = torch.cat([xz_coords, xy_coords, yz_coords], dim=0)
        triplane_coords = triplane_coords / (self.triplane_reso - 1)
        triplane_coords = (triplane_coords - 0.5) * 2 * (1 + self.padding + 10e-6)
        self.triplane_coords = triplane_coords.view(-1,3).float().cuda()
    def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
        '''image feat is bs,n_img,length,channel (token 0 = ViT global token,
        discarded); triplane_feat is (B, C, 3*reso, reso); proj_mat is
        (B, n_img, 3, 4) onto a 224x224 frame. valid_frames is unused here.
        NOTE(review): assumes n_img >= 2 — n_img == 1 divides by zero below.'''
        batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
        image_feat=image_feat[:,:,1:,:] #discard global feature
        image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c
        image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c
        # concatenate k and v along channels so ONE grid_sample gathers both
        image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h*w,2c
        unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
        coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
        coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**2*3,1).float().cuda()],dim=-1)
        coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
        x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
        y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
        x = x/(224.0-1) #0~1
        y = y/(224.0-1)
        coords_proj=torch.stack([x,y],dim=-1)
        coords_proj=(coords_proj-0.5)*2
        # third grid coordinate selects the view (exact slices w/ align_corners)
        img_index=((torch.arange(n_img)[None,:,None,None].expand(
            batch_size,-1,self.triplane_reso**2*3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,3*reso**2,1
        grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
        grid=torch.clamp(grid,min=-1.0,max=1.0)
        sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,3*reso**2
        triplane_flat_feat=rearrange(triplane_feat,'b c h w -> b (h w) c')
        triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
        # per-cell attention: 1 query token over n_img key/value tokens
        k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
        k=k_v[:,:,0:self.inner_channel]
        v=k_v[:,:,self.inner_channel:]
        q=rearrange(triplane_query,'b k c -> (b k) 1 c')
        attn_out,_=self.cross_attn(q,k,v)
        triplane_out=rearrange(attn_out,'(b h w) 1 c -> b c h w',b=batch_size,h=self.triplane_reso*3,w=self.triplane_reso)
        triplane_out = self.proj_out(triplane_out)
        return triplane_out
class MultiImage_Global_Sampler(nn.Module):
    """Fuse multi-view image tokens into a triplane with one global
    cross-attention: every triplane location attends over the patch tokens of
    all views at once (no projection geometry, no masking).
    """
    def __init__(self, reso, vit_reso=16, padding=0.1, triplane_in_channels=64,
                 img_in_channels=1280, inner_channel=128, out_channels=64, n_heads=8, max_nimg=5):
        super().__init__()
        self.triplane_reso = reso
        self.vit_reso = vit_reso
        self.inner_channel = inner_channel
        self.padding = padding
        self.n_heads = n_heads
        self.out_channels = out_channels
        self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
        self.q = nn.Linear(in_features=triplane_in_channels, out_features=inner_channel)
        self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
        self.cross_attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.proj_out = nn.Linear(in_features=inner_channel, out_features=out_channels)
        # fixed sinusoidal PEs: per-view image-token table and triplane table
        self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None, None, :, :].expand(-1, max_nimg, -1, -1).cuda().float()
        self.triplane_pe = position_encoding(inner_channel, self.triplane_reso**2*3).unsqueeze(0).cuda().float()

    def forward(self, triplane_feat, image_feat, proj_mat, valid_frames):
        """image_feat: (B, n_img, L+1, C) ViT tokens, token 0 being the global
        token (dropped); triplane_feat: (B, C, 3*reso, reso).
        proj_mat and valid_frames are accepted but unused by this sampler.
        Returns (B, out_channels, 3*reso, reso)."""
        bs = image_feat.shape[0]
        n_views = image_feat.shape[1]
        n_patch = image_feat.shape[2] - 1
        patch_tokens = image_feat[:, :, 1:, :]  # drop the global token
        keys = self.k(patch_tokens) + self.image_pe
        values = self.v(patch_tokens) + self.image_pe
        # flatten all views into one long key/value sequence
        keys = keys.reshape(bs, n_views * n_patch, -1)
        values = values.reshape(bs, n_views * n_patch, -1)
        # (B, C, 3*reso, reso) -> (B, 3*reso*reso, C)
        flat_triplane = triplane_feat.flatten(2).transpose(1, 2)
        queries = self.q(flat_triplane) + self.triplane_pe
        fused, _ = self.cross_attn(queries, keys, values)
        projected = self.proj_out(fused)
        # back to the spatial triplane layout
        return projected.transpose(1, 2).reshape(bs, -1, self.triplane_reso * 3, self.triplane_reso)
class CrossAttention(nn.Module):
    """Multi-head scaled-dot-product attention over pre-projected q/k/v.

    Unlike the usual formulation there are no internal q/k/v projections: the
    caller supplies tensors whose last dimension is already heads * dim_head.
    This module only splits heads, attends, merges the heads back, and applies
    the output projection (Linear + Dropout).

    Fixes vs. the original: the dead ``context_dim = query_dim`` fallback is
    removed (nothing ever read it — there is no to_k/to_v projection), and the
    einops dependency is replaced with equivalent view/permute calls so the
    module is self-contained. ``context_dim`` stays in the signature for
    caller compatibility.
    """
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d) softmax temperature
        self.heads = heads
        # project the merged heads back to query_dim
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, q, k, v):
        """q: (B, Nq, heads*dim_head); k, v: (B, Nk, heads*dim_head).
        Returns (B, Nq, query_dim)."""
        h = self.heads

        def split_heads(t):
            # 'b n (h d) -> (b h) n d'
            b, n, c = t.shape
            return t.view(b, n, h, c // h).permute(0, 2, 1, 3).reshape(b * h, n, c // h)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        # '(b h) n d -> b n (h d)'
        bh, n, d = out.shape
        out = out.view(bh // h, h, n, d).permute(0, 2, 1, 3).reshape(bh // h, n, h * d)
        return self.to_out(out)
class Image_Vox_Local_Sampler_Pooling(nn.Module):
    """Back-project 2D image features into a dense reso^3 voxel volume, then
    pool the volume per axis into a rollout triplane (xz / xy / yz stacked
    along height) and fuse the planes with RollOut_Conv.

    Pipeline: 1x1 conv channel reduction -> perspective-project every voxel
    center into the image -> bilinear grid_sample -> 3D conv stacks with
    anisotropic average pooling (each branch pools only the axis it will
    reduce) -> mean over the off-plane axis -> RollOut_Conv.
    """
    def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,stride=4):
        # reso: voxel/triplane resolution per axis.
        # padding: extra normalized margin around the unit cube (matches the
        # other samplers in this file).
        # stride: pooling factor applied twice along the reduced axis.
        super().__init__()
        self.triplane_reso=reso
        self.padding=padding
        self.get_vox_coord()
        self.out_channels=out_channels
        # 1x1 conv shrinks image feature channels before voxel sampling
        self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
        self.vox_process=nn.Sequential(
            nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1)
        )
        # volume layout is B,C,X,Y,Z; the xz branch pools Y (dim 3) only
        self.xz_conv=nn.Sequential(
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((1,stride,1),stride=(1,stride,1)), #8
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((1,stride,1), stride=(1,stride,1)), #2
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
        )
        # xy branch pools Z (dim 4) only
        self.xy_conv = nn.Sequential(
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 8
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 2
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
        )
        # yz branch pools X (dim 2) only
        self.yz_conv = nn.Sequential(
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 8
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
            nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 2
            nn.BatchNorm3d(inner_channel),
            nn.ReLU(),
            nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
        )
        self.roll_out_conv=RollOut_Conv(in_channels=inner_channel,out_channels=out_channels)
        #self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
    def get_vox_coord(self):
        """Precompute the normalized centers of all reso^3 voxels, flattened in
        X-major ('ij') order, scaled to [-(1+padding), 1+padding]."""
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
        vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
        vox_coor=vox_coor/(self.triplane_reso-1)
        vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
        self.vox_coor=vox_coor.view(-1,3).float().cuda()
    def forward(self,image_feat,proj_mat):
        """image_feat: B,C,H,W image feature map; proj_mat: B,4,4 homogeneous
        projection to 224x224 pixel coordinates (assumed — the feature map is
        taken to cover the same extent as the 224x224 input; TODO confirm).
        Returns the RollOut_Conv-fused triplane B,out_channels,3*reso,reso."""
        image_feat=self.img_proj(image_feat)
        batch_size=image_feat.shape[0]
        vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
        # homogeneous coords, then pixel = [x,y,z,1] @ P^T
        vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
        coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
        # perspective divide
        x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
        y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
        # normalize 224-pixel coordinates to [-1,1] for grid_sample
        x=(x/(224.0-1.0)-0.5)*2 #-1~1
        y=(y/(224.0-1.0)-0.5)*2 #-1~1
        xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
        #print(image_feat.shape,xy.shape)
        # bilinearly sample one feature per voxel, then reshape to B,C,X,Y,Z
        grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
            view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
        grid_feat=self.vox_process(grid_feat)
        # mean-reduce the off-plane axis, then swap to (row, col) plane layout
        xz_feat=torch.mean(self.xz_conv(grid_feat),dim=3).transpose(2,3)
        xy_feat=torch.mean(self.xy_conv(grid_feat),dim=4).transpose(2,3)
        yz_feat=torch.mean(self.yz_conv(grid_feat),dim=2).transpose(2,3)
        # stack planes along height into the rollout layout B,C,3*reso,reso
        triplane_wImg=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
        #print(triplane_wImg.shape)
        return self.roll_out_conv(triplane_wImg)
class Image_ExpandVox_attn_Sampler(nn.Module):
    """Refine a rollout triplane with ViT image tokens via projection-masked
    cross-attention over an expanded voxel grid.

    Every cell of the implied reso^3 voxel grid gathers its three plane
    features (concatenated to form the query); each voxel query may only
    attend ViT tokens whose grid position lies within a small radius of the
    voxel's image projection (plus the always-visible global token). The
    attended volume is mean-pooled back into a rollout triplane.
    """
    def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
        # reso: triplane/voxel resolution; vit_reso: ViT patch-grid side
        # length; padding: normalized margin around the unit cube.
        super().__init__()
        self.triplane_reso=reso
        self.padding=padding
        self.vit_reso=vit_reso
        self.get_vox_coord()
        self.get_vit_coords()
        self.out_channels=out_channels
        self.n_heads=n_heads
        # returns True where dist > sigma; nn.MultiheadAttention treats True
        # as "do not attend", so only tokens near the projection stay visible
        self.kernel_func = mask_kernel_close_false
        self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
        # self.q_xz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
        # self.q_xy = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
        # self.q_yz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
        # query is built from the concatenation of the three plane features
        self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
        self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
        self.out_proj=nn.Linear(in_features=inner_channel,out_features=out_channels)
        # fixed positional encodings (position_encoding is defined elsewhere in
        # this file); the image PE covers vit_reso^2 patches + 1 global token
        self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
        self.image_pe = position_encoding(inner_channel, self.vit_reso ** 2+1).unsqueeze(0).cuda().float()
    def get_vox_coord(self):
        """Precompute integer voxel indices (for gathering from the planes) and
        their normalized centers in [-(1+padding), 1+padding] (for projection),
        both flattened in X-major ('ij') order."""
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
        vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
        self.vox_index=vox_coor.view(-1,3).long().cuda()
        vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
        vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
        self.vox_coor = vox_coor.view(-1, 3).float().cuda()
        # print(self.vox_coor[0])
        # print(self.vox_coor[self.triplane_reso**2])#x should increase
        # print(self.vox_coor[self.triplane_reso]) #y should increase
        # print(self.vox_coor[1])#z should increase
    def get_vit_coords(self):
        """Token-center coordinates of the vit_reso x vit_reso patch grid,
        flattened to (vit_reso^2, 2) in (x, y) order."""
        x=torch.arange(self.vit_reso)
        y=torch.arange(self.vit_reso)
        X,Y=torch.meshgrid(x,y,indexing='xy')
        vit_coords=torch.stack([X,Y],dim=-1)
        self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
    def compute_attn_mask(self,proj_coords,vit_coords,kernel_size=1.0):
        """Mask out every token farther than kernel_size (in token-grid units)
        from a voxel's projection. True == masked out for MultiheadAttention."""
        dist = torch.cdist(proj_coords.float(), vit_coords.float())
        mask = self.kernel_func(dist, sigma=kernel_size) # True == masked out; B,reso**3,vit_reso**2
        return mask
    def forward(self,triplane_feat,image_feat,proj_mat):
        """triplane_feat: B,C,3*reso,reso rollout triplane (xz/xy/yz stacked);
        image_feat: B,1+vit_reso^2,C_img ViT tokens, global token first
        (assumed — the mask prepends one unmasked column for it; TODO confirm);
        proj_mat: B,4,4 projection to 224x224 pixel coords.
        Returns B,out_channels,3*reso,reso."""
        xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, dim=2) # B,C,64,64
        #xz_feat=self.q_xz(xz_feat)
        #xy_feat=self.q_xy(xy_feat)
        #yz_feat=self.q_yz(yz_feat)
        batch_size=image_feat.shape[0]
        vox_index=self.vox_index #64*64*64,3
        # gather each voxel's feature from the three planes (plane layout is [row, col])
        xz_vox_feat=xz_feat[:,:,vox_index[:,2],vox_index[:,0]].transpose(1,2) #B,reso**3,C
        xy_vox_feat=xy_feat[:,:,vox_index[:,1],vox_index[:,0]].transpose(1,2)
        yz_vox_feat=yz_feat[:,:,vox_index[:,2],vox_index[:,1]].transpose(1,2)
        triplane_expand_feat=torch.cat([xz_vox_feat,xy_vox_feat,yz_vox_feat],dim=-1)#B,reso**3,3*C
        triplane_query=self.q(triplane_expand_feat)+self.triplane_pe
        '''compute projection'''
        vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #
        vox_homo = torch.cat([vox_coords, torch.ones((batch_size, self.triplane_reso ** 3, 1)).float().cuda()], dim=-1)
        coord_inimg = torch.einsum('bhc,bck->bhk', vox_homo, proj_mat.transpose(1, 2))
        # perspective divide to pixel coordinates
        x = coord_inimg[:, :, 0] / coord_inimg[:, :, 2]
        y = coord_inimg[:, :, 1] / coord_inimg[:, :, 2]
        #
        # rescale 224-pixel coordinates into ViT token-grid units
        x = x / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1
        y = y / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 #B,N
        xy=torch.stack([x,y],dim=-1) #B,64*64*64,2
        xy=torch.clamp(xy,min=0,max=self.vit_reso-1)
        vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) #B, 16*16,2
        # MultiheadAttention expects a per-head mask: repeat along batch
        attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,vit_coords,kernel_size=0.5),
                                          self.n_heads,0) #B*n_heads, reso**3, vit_reso**2
        k=self.k(image_feat)+self.image_pe
        v=self.v(image_feat)+self.image_pe
        # prepend an always-False (visible) column so every query can attend the global token
        attn_mask=torch.cat([torch.zeros([attn_mask.shape[0],attn_mask.shape[1],1]).cuda().bool(),attn_mask],dim=-1) #add empty token to each key and value
        vox_feat,_=self.attn(triplane_query,k,v,attn_mask=attn_mask) #B,reso**3,C
        feat_volume=self.out_proj(vox_feat).transpose(1,2).reshape(batch_size,-1,self.triplane_reso,
                                                                   self.triplane_reso,self.triplane_reso)
        # collapse the attended B,C,X,Y,Z volume back to three planes
        # (mean over the off-plane axis, then swap to [row, col] layout)
        xz_feat=torch.mean(feat_volume,dim=3).transpose(2,3)
        xy_feat=torch.mean(feat_volume,dim=4).transpose(2,3)
        yz_feat=torch.mean(feat_volume,dim=2).transpose(2,3)
        triplane_out=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
        return triplane_out
class Multi_Image_Fusion(nn.Module):
    """Fuse multi-view ViT image features into a triplane via masked
    cross-attention.

    Every rollout-triplane cell is projected into each of the K views; the
    attention mask keeps only tokens selected by ``kernel_func`` around the
    projection, and fully masks out frames flagged invalid.
    """
    def __init__(self,reso,image_reso=16,padding=0.1,img_channels=1280,triplane_channel=64,inner_channels=128,output_channel=64,n_heads=8):
        super().__init__()
        self.triplane_reso=reso
        self.image_reso=image_reso
        self.padding=padding
        self.get_triplane_coord()
        self.get_vit_coords()
        # kernel-1 3D conv == per-token channel projection on B,C,K,H,W features
        self.img_proj=nn.Conv3d(in_channels=img_channels,out_channels=512,kernel_size=1)
        # NOTE(review): mask_kernel returns True where dist < sigma, and
        # nn.MultiheadAttention treats True as "do not attend" — so this masks
        # out the CLOSE tokens. Image_ExpandVox_attn_Sampler uses
        # mask_kernel_close_false for the same purpose; confirm this inversion
        # is intentional before training with this module.
        self.kernel_func=mask_kernel
        self.q = nn.Linear(in_features=triplane_channel, out_features=inner_channels, bias=False)
        self.k = nn.Linear(in_features=512, out_features=inner_channels)
        self.v = nn.Linear(in_features=512, out_features=inner_channels)
        self.attn = torch.nn.MultiheadAttention(
            embed_dim=inner_channels, num_heads=n_heads, batch_first=True)
        self.out_proj=nn.Linear(in_features=inner_channels,out_features=output_channel)
        self.n_heads=n_heads
    def get_triplane_coord(self):
        """Normalized 3D centers of every rollout-triplane cell (xz, xy, yz
        stacked along height); the off-plane coordinate sits at the cube
        center, and coords are scaled to [-(1+padding), 1+padding]."""
        '''xz plane firstly, z is at the '''
        x=torch.arange(self.triplane_reso)
        z=torch.arange(self.triplane_reso)
        X,Z=torch.meshgrid(x,z,indexing='xy')
        xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
        '''xy plane'''
        x = torch.arange(self.triplane_reso)
        y = torch.arange(self.triplane_reso)
        X, Y = torch.meshgrid(x, y, indexing='xy')
        xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
        '''yz plane'''
        y = torch.arange(self.triplane_reso)
        z = torch.arange(self.triplane_reso)
        Y,Z = torch.meshgrid(y,z,indexing='xy')
        yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
        triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
        triplane_coords=triplane_coords/(self.triplane_reso-1)
        triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
        self.triplane_coords=triplane_coords.float().cuda()
    def get_vit_coords(self):
        """Token-center coordinates of the image_reso x image_reso patch grid,
        kept as an (H, W, 2) grid in (x, y) order."""
        x=torch.arange(self.image_reso)
        y=torch.arange(self.image_reso)
        X,Y=torch.meshgrid(x,y,indexing='xy')
        vit_coords=torch.cat([X[:,:,None],Y[:,:,None]],dim=-1)
        self.vit_coords=vit_coords.float().cuda() #in x,y order
    def compute_attn_mask(self,proj_coord,vit_coords,valid_frames,kernel_size=2.0):
        '''
        :param proj_coord: B,K,H,W,2 projected triplane coords in token-grid units
        :param vit_coords: H,W,2 token-center coordinates
        :param valid_frames: B,K; >0 marks an INVALID frame (caller pre-inverts)
        :return: B, H*W, K*image_reso^2 boolean mask (True = do not attend),
                 frame-major along the key axis
        '''
        B,K=proj_coord.shape[0:2]
        vit_coords_expand=vit_coords[None,None,:,:,:].expand(B,K,-1,-1,-1)
        proj_coord=proj_coord.view(B*K,proj_coord.shape[2]*proj_coord.shape[3],proj_coord.shape[4])
        vit_coords_expand=vit_coords_expand.view(B*K,self.image_reso*self.image_reso,2)
        attn_mask=self.kernel_func(torch.cdist(proj_coord,vit_coords_expand),sigma=float(kernel_size))
        attn_mask=attn_mask.reshape(B,K,proj_coord.shape[1],vit_coords_expand.shape[1])
        # frames flagged invalid are masked out entirely
        attn_mask[valid_frames>0,:,:]=True
        # B,K,Q,L -> B,Q,K,L -> B,Q,K*L to match the flattened key sequence
        attn_mask=attn_mask.permute(0,2,1,3)
        attn_mask=attn_mask.reshape(B,proj_coord.shape[1],K*vit_coords_expand.shape[1])
        return attn_mask
    def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
        '''
        :param triplane_feat: B,C,3*reso,reso rollout triplane queries
        :param image_feat: B,C,K,16,16
        :param proj_mat: B,K,4,4
        :param valid_frames: B,K, true if have image, used to compute attn_mask for transformer
        :return: B,output_channel,3*reso,reso fused triplane
        '''
        image_feat=self.img_proj(image_feat)
        batch_size=image_feat.shape[0] #K is number of frames
        K=image_feat.shape[2]
        triplane_coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,K,-1,-1,-1) #B,K,192,64,3
        #print(torch.amin(triplane_coords),torch.amax(triplane_coords))
        coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,K,triplane_coords.shape[2],triplane_coords.shape[3],1)).float().cuda()],dim=-1)
        #print(coord_homo.shape,proj_mat.shape)
        # project every cell into each view, then perspective divide
        coord_inimg=torch.einsum('bjhwc,bjck->bjhwk',coord_homo,proj_mat.transpose(2,3))
        x=coord_inimg[:,:,:,:,0]/coord_inimg[:,:,:,:,2]
        y=coord_inimg[:,:,:,:,1]/coord_inimg[:,:,:,:,2]
        # rescale 224-pixel coordinates into token-grid units
        x=x/(224.0-1.0)*(self.image_reso-1)
        y=y/(224.0-1.0)*(self.image_reso-1)
        xy=torch.cat([x[...,None],y[...,None]],dim=-1) #B,K,H,W,2
        # flatten views and tokens frame-major to match the attn_mask layout
        image_value=image_feat.view(image_feat.shape[0],image_feat.shape[1],-1).transpose(1,2)
        triplane_query=triplane_feat.view(triplane_feat.shape[0],triplane_feat.shape[1],-1).transpose(1,2)
        # invert: compute_attn_mask treats >0 as "invalid frame, mask out"
        valid_frames=1.0-valid_frames.float()
        attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,self.vit_coords,valid_frames),
                                          self.n_heads,dim=0)
        q=self.q(triplane_query)
        k=self.k(image_value)
        v=self.v(image_value)
        #print(q.shape,k.shape,v.shape)
        attn,_=self.attn(q,k,v,attn_mask=attn_mask)
        #print(attn.shape)
        output=self.out_proj(attn).transpose(1,2).reshape(batch_size,-1,triplane_feat.shape[2],triplane_feat.shape[3])
        #print(output.shape)
        return output
if __name__=="__main__":
    # Standalone smoke test: load one batch from the multi-view dataset, run
    # the fusion sampler once, and report peak CUDA memory. Requires the LASA
    # dataset layout under dataset_config["data_path"] and a CUDA device.
    import sys
    sys.path.append("../..")
    from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
    from datasets.transforms import Aug_with_Tran

    dataset_config = {
        "data_path": "/data1/haolin/datasets",
        "surface_size": 20000,
        "par_pc_size": 4096,
        "load_proj_mat": True,
    }
    transform = Aug_with_Tran()
    dataset = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename="train_par_img.json", split='train',
                                            transform=transform, sampling=False,
                                            num_samples=1024, return_surface=True, ret_sample=True,
                                            surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
                                            surface_size=dataset_config['surface_size'],
                                            load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
                                            load_org_img=False, load_triplane=True, replica=1)
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=False
    )
    # `iterator.next()` was Python 2 only; use the builtin next() instead.
    data_batch = next(iter(dataloader))
    image = data_batch['image'].cuda().float()
    proj_mat = data_batch['proj_mat'].cuda().float()
    valid_frames = data_batch['valid_frames'].cuda().float()
    triplane_feat = torch.randn((10, 128, 32 * 3, 32)).cuda().float()
    fusion_module = MultiImage_Global_Sampler(reso=32, vit_reso=16, padding=0.1, img_in_channels=1280,
                                              triplane_in_channels=128, inner_channel=128,
                                              out_channels=64, n_heads=8).cuda().float()
    fusion_module(triplane_feat, image, proj_mat, valid_frames)
    memory_usage = torch.cuda.max_memory_allocated() / MB
    print("memory usage %f mb" % (memory_usage))