# NeuralBody: lib/networks/latent_xyzc.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import spconv

from lib.config import cfg
from . import embedder
class Network(nn.Module):
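    """Neural Body network (latent_xyzc): per-vertex latent codes on the SMPL mesh
    are diffused by a sparse 3D ConvNet and decoded into density and color."""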
def __init__(self):
super(Network, self).__init__()
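        # one 16-dim structured latent code per SMPL vertex (SMPL has 6890 vertices)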
self.c = nn.Embedding(6890, 16)
self.xyzc_net = SparseConvNet()
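        # per-frame latent code that conditions the color branch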
self.latent = nn.Embedding(cfg.num_train_frame, 128)
self.actvn = nn.ReLU()
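        # 352 = 32 + 64 + 128 + 128 channels from the four multi-scale feature volumes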
self.fc_0 = nn.Conv1d(352, 256, 1)
self.fc_1 = nn.Conv1d(256, 256, 1)
self.fc_2 = nn.Conv1d(256, 256, 1)
self.alpha_fc = nn.Conv1d(256, 1, 1)
self.feature_fc = nn.Conv1d(256, 256, 1)
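        # 384 = 256 point features + 128 per-frame latent code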
self.latent_fc = nn.Conv1d(384, 256, 1)
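        # 346 = 256 features + embedded view direction + embedded xyz
        # (27 + 63 dims with the default frequency embedder settings)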
self.view_fc = nn.Conv1d(346, 128, 1)
self.rgb_fc = nn.Conv1d(128, 3, 1)
def encode_sparse_voxels(self, sp_input):
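        """Scatter the per-vertex codes into a sparse voxel grid and run the
        sparse ConvNet to obtain multi-scale dense feature volumes."""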
coord = sp_input['coord']
out_sh = sp_input['out_sh']
batch_size = sp_input['batch_size']
code = self.c(torch.arange(0, 6890).to(coord.device))
xyzc = spconv.SparseConvTensor(code, coord, out_sh, batch_size)
feature_volume = self.xyzc_net(xyzc)
return feature_volume
def pts_to_can_pts(self, pts, sp_input):
"""transform pts from the world coordinate to the smpl coordinate"""
Th = sp_input['Th']
pts = pts - Th
R = sp_input['R']
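        # row-vector convention: pts @ R applies the inverse rotation R^T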
pts = torch.matmul(pts, R)
return pts
def get_grid_coords(self, pts, sp_input):
# convert xyz to the voxel coordinate dhw
dhw = pts[..., [2, 1, 0]]
min_dhw = sp_input['bounds'][:, 0, [2, 1, 0]]
dhw = dhw - min_dhw[:, None]
dhw = dhw / torch.tensor(cfg.voxel_size).to(dhw)
# convert the voxel coordinate to [-1, 1]
out_sh = torch.tensor(sp_input['out_sh']).to(dhw)
dhw = dhw / out_sh * 2 - 1
        # reorder dhw -> whd: F.grid_sample expects coordinates ordered (x, y, z) = (w, h, d)
grid_coords = dhw[..., [2, 1, 0]]
return grid_coords
def interpolate_features(self, grid_coords, feature_volume):
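        """Trilinearly sample each feature volume at the query points and
        concatenate the sampled features along the channel dimension."""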
features = []
for volume in feature_volume:
feature = F.grid_sample(volume,
grid_coords,
padding_mode='zeros',
align_corners=True)
features.append(feature)
features = torch.cat(features, dim=1)
features = features.view(features.size(0), -1, features.size(4))
return features
def calculate_density(self, wpts, feature_volume, sp_input):
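        """Predict density for world-space points wpts."""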
# interpolate features
ppts = self.pts_to_can_pts(wpts, sp_input)
grid_coords = self.get_grid_coords(ppts, sp_input)
grid_coords = grid_coords[:, None, None]
xyzc_features = self.interpolate_features(grid_coords, feature_volume)
# calculate density
net = self.actvn(self.fc_0(xyzc_features))
net = self.actvn(self.fc_1(net))
net = self.actvn(self.fc_2(net))
alpha = self.alpha_fc(net)
alpha = alpha.transpose(1, 2)
return alpha
def calculate_density_color(self, wpts, viewdir, feature_volume, sp_input):
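        """Predict density and view-dependent color for world-space points wpts."""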
# interpolate features
ppts = self.pts_to_can_pts(wpts, sp_input)
grid_coords = self.get_grid_coords(ppts, sp_input)
grid_coords = grid_coords[:, None, None]
xyzc_features = self.interpolate_features(grid_coords, feature_volume)
# calculate density
net = self.actvn(self.fc_0(xyzc_features))
net = self.actvn(self.fc_1(net))
net = self.actvn(self.fc_2(net))
alpha = self.alpha_fc(net)
# calculate color
features = self.feature_fc(net)
latent = self.latent(sp_input['latent_index'])
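        # broadcast the per-frame latent code across all sampled points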
latent = latent[..., None].expand(*latent.shape, net.size(2))
features = torch.cat((features, latent), dim=1)
features = self.latent_fc(features)
viewdir = embedder.view_embedder(viewdir)
viewdir = viewdir.transpose(1, 2)
light_pts = embedder.xyz_embedder(wpts)
light_pts = light_pts.transpose(1, 2)
features = torch.cat((features, viewdir, light_pts), dim=1)
net = self.actvn(self.view_fc(features))
rgb = self.rgb_fc(net)
raw = torch.cat((rgb, alpha), dim=1)
raw = raw.transpose(1, 2)
return raw
def forward(self, sp_input, grid_coords, viewdir, light_pts):
coord = sp_input['coord']
out_sh = sp_input['out_sh']
batch_size = sp_input['batch_size']
p_features = grid_coords.transpose(1, 2)
grid_coords = grid_coords[:, None, None]
code = self.c(torch.arange(0, 6890).to(p_features.device))
xyzc = spconv.SparseConvTensor(code, coord, out_sh, batch_size)
        feature_volume = self.xyzc_net(xyzc)
        xyzc_features = self.interpolate_features(grid_coords, feature_volume)
net = self.actvn(self.fc_0(xyzc_features))
net = self.actvn(self.fc_1(net))
net = self.actvn(self.fc_2(net))
alpha = self.alpha_fc(net)
features = self.feature_fc(net)
latent = self.latent(sp_input['latent_index'])
latent = latent[..., None].expand(*latent.shape, net.size(2))
features = torch.cat((features, latent), dim=1)
features = self.latent_fc(features)
viewdir = viewdir.transpose(1, 2)
light_pts = light_pts.transpose(1, 2)
features = torch.cat((features, viewdir, light_pts), dim=1)
net = self.actvn(self.view_fc(features))
rgb = self.rgb_fc(net)
raw = torch.cat((rgb, alpha), dim=1)
raw = raw.transpose(1, 2)
return raw
class SparseConvNet(nn.Module):
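    """Sparse 3D ConvNet that turns the voxelized vertex codes into a pyramid of
    dense feature volumes (32, 64, 128 and 128 channels)."""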
def __init__(self):
super(SparseConvNet, self).__init__()
self.conv0 = double_conv(16, 16, 'subm0')
self.down0 = stride_conv(16, 32, 'down0')
self.conv1 = double_conv(32, 32, 'subm1')
self.down1 = stride_conv(32, 64, 'down1')
self.conv2 = triple_conv(64, 64, 'subm2')
self.down2 = stride_conv(64, 128, 'down2')
self.conv3 = triple_conv(128, 128, 'subm3')
self.down3 = stride_conv(128, 128, 'down3')
self.conv4 = triple_conv(128, 128, 'subm4')
def forward(self, x):
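        # dense feature volumes at 1/2, 1/4, 1/8 and 1/16 of the input resolution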
net = self.conv0(x)
net = self.down0(net)
net = self.conv1(net)
net1 = net.dense()
net = self.down1(net)
net = self.conv2(net)
net2 = net.dense()
net = self.down2(net)
net = self.conv3(net)
net3 = net.dense()
net = self.down3(net)
net = self.conv4(net)
net4 = net.dense()
volumes = [net1, net2, net3, net4]
return volumes
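
# submanifold convolution block (kernel size 1): conv -> BN -> ReLU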
def single_conv(in_channels, out_channels, indice_key=None):
return spconv.SparseSequential(
spconv.SubMConv3d(in_channels,
out_channels,
1,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
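
# two 3x3x3 submanifold convolutions sharing an indice key: (conv -> BN -> ReLU) x 2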
def double_conv(in_channels, out_channels, indice_key=None):
return spconv.SparseSequential(
spconv.SubMConv3d(in_channels,
out_channels,
3,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
spconv.SubMConv3d(out_channels,
out_channels,
3,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
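
# three 3x3x3 submanifold convolutions sharing an indice key: (conv -> BN -> ReLU) x 3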
def triple_conv(in_channels, out_channels, indice_key=None):
return spconv.SparseSequential(
spconv.SubMConv3d(in_channels,
out_channels,
3,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
spconv.SubMConv3d(out_channels,
out_channels,
3,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
spconv.SubMConv3d(out_channels,
out_channels,
3,
bias=False,
indice_key=indice_key),
nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
nn.ReLU(),
)
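
# strided sparse convolution: downsamples the grid by 2 between pyramid levels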
def stride_conv(in_channels, out_channels, indice_key=None):
return spconv.SparseSequential(
spconv.SparseConv3d(in_channels,
out_channels,
3,
2,
padding=1,
bias=False,
indice_key=indice_key),
        nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01),
        nn.ReLU(),
    )
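
# Usage sketch (illustrative only; in the repo, `sp_input` is assembled by the
# dataset/renderer pipeline and is assumed here to carry 'coord', 'out_sh',
# 'batch_size', 'R', 'Th', 'bounds' and 'latent_index'):
#
#   net = Network()
#   feature_volume = net.encode_sparse_voxels(sp_input)
#   raw = net.calculate_density_color(wpts, viewdir, feature_volume, sp_input)
#   # raw has shape (batch, n_points, 4): RGB followed by density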