LASA / models /TriplaneVAE.py
HaolinLiu's picture
first commit of codes and update readme.md
cc9780d
raw
history blame
4.17 kB
import torch.nn as nn
import sys,os
sys.path.append("..")
import torch
from datasets import build_dataset
from configs.config_utils import CONFIG
from torch.utils.data import DataLoader
from models.modules import PointEmbed
from models.modules import ConvPointnet_Encoder,ConvPointnet_Decoder
import numpy as np
class TriplaneVAE(nn.Module):
def __init__(self,opt):
super().__init__()
self.point_embedder=PointEmbed(hidden_dim=opt['point_emb_dim'])
encoder_args=opt['encoder']
decoder_args=opt['decoder']
self.encoder=ConvPointnet_Encoder(c_dim=encoder_args['plane_latent_dim'],dim=opt['point_emb_dim'],latent_dim=encoder_args['latent_dim'],
plane_resolution=encoder_args['plane_reso'],unet_kwargs=encoder_args['unet'],unet=True,padding=opt['padding'])
self.decoder=ConvPointnet_Decoder(latent_dim=decoder_args['latent_dim'],query_emb_dim=decoder_args['query_emb_dim'],
hidden_dim=decoder_args['hidden_dim'],unet_kwargs=decoder_args['unet'],n_blocks=decoder_args['n_blocks'],
plane_resolution=decoder_args['plane_reso'],padding=opt['padding'])
def forward(self,p,query):
'''
:param p: surface points cloud of shape B,N,3
:param query: sample points of shape B,N,3
:return:
'''
point_emb=self.point_embedder(p)
query_emb=self.point_embedder(query)
kl,plane_feat,means,logvars=self.encoder(p,point_emb)
if self.training:
if np.random.random()<0.5:
'''randomly sacle the triplane, and conduct triplane diffusion on 64x64x64 plane, promote robustness'''
plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode="bilinear")
plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode="bilinear")
# if self.training:
# if np.random.random()<0.5:
# means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
# vars=torch.exp(logvars)
# vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")
# new_logvars=torch.log(vars)
# posterior = DiagonalGaussianDistribution(means, new_logvars)
# plane_feat=posterior.sample()
# plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear')
# mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
# vars = torch.exp(logvars)
# vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4
# logvars_scale=torch.log(vars_scale)
# scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device)
# plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise
# plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear')
o=self.decoder(plane_feat,query,query_emb)
return {'logits':o,'kl':kl}
def decode(self,plane_feature,query):
query_embedding=self.point_embedder(query)
o=self.decoder(plane_feature,query,query_embedding)
return o
def encode(self,p):
point_emb = self.point_embedder(p)
kl, plane_feat,mean,logvar = self.encoder(p, point_emb)
'''p is point cloud of B,N,3'''
return plane_feat,kl,mean,logvar
if __name__=="__main__":
configs=CONFIG("../configs/train_triplane_vae_64.yaml")
config=configs.config
dataset_config=config['datasets']
model_config=config["model"]
dataset=build_dataset("train",dataset_config)
dataset.__getitem__(0)
dataloader=DataLoader(
dataset=dataset,
batch_size=10,
shuffle=True,
num_workers=2,
)
net=TriplaneVAE(model_config).float().cuda()
for idx,data_batch in enumerate(dataloader):
if idx==1:
break
surface=data_batch['surface'].float().cuda()
query=data_batch['points'].float().cuda()
net(surface,query)