# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math

import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.autograd import grad
# from fightingcv_attention.attention.SelfAttention import ScaledDotProductAttention
import numpy as np
class SDF2Density(pl.LightningModule):
    def __init__(self):
        super(SDF2Density, self).__init__()
        # learnable sharpness parameter beta, initialized to 0.1
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        # Convert an SDF value to a density via a Laplace-style CDF;
        # a sigmoid temporarily stands in for the true Laplace CDF.
        return 1.0 / (self.beta + 1e-6) * torch.sigmoid(-sdf / (self.beta + 1e-6))
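# A minimal sketch of the exact Laplace-CDF density that the sigmoid above
# approximates (the VolSDF formulation). `laplace_cdf_density` is a
# hypothetical drop-in for illustration, not part of the original module.
def laplace_cdf_density(sdf, beta):
    # CDF of a zero-mean Laplace distribution evaluated at -sdf, scaled by 1/beta
    b = beta + 1e-6
    return (1.0 / b) * torch.where(
        sdf <= 0,
        1.0 - 0.5 * torch.exp(sdf / b),
        0.5 * torch.exp(-sdf / b),
    )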
class SDF2Occ(pl.LightningModule):
    def __init__(self):
        super(SDF2Occ, self).__init__()
        # learnable sharpness parameter beta, initialized to 0.1
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        # Convert an SDF value to an occupancy probability;
        # a sigmoid temporarily stands in for the true Laplace CDF.
        return torch.sigmoid(-sdf / (self.beta + 1e-6))
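# Usage sketch (illustrative only): both modules map per-point SDF values
# to density / occupancy and preserve the input shape.
# sdf = torch.randn(2, 1, 1024)        # [B, 1, N] signed distances
# density = SDF2Density()(sdf)         # [B, 1, N], non-negative scale
# occ = SDF2Occ()(sdf)                 # [B, 1, N], in (0, 1)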
class DeformationMLP(pl.LightningModule):
    def __init__(self, input_dim=64, output_dim=3, activation='LeakyReLU', name=None, opt=None):
        super(DeformationMLP, self).__init__()
        self.name = name
        self.activation = activation
        self.activate = nn.LeakyReLU(inplace=True)
        # self.mlp = nn.Sequential(
        #     nn.Conv1d(input_dim + 8 + 1 + 3, 64, 1),
        #     nn.LeakyReLU(inplace=True),
        #     nn.Conv1d(64, output_dim, 1),
        # )
        # input channels: point feature + 8-d per-point code + 1-d visibility + 3-d xyz
        channels = [input_dim + 8 + 1 + 3, 128, 64, output_dim]
        self.deform_mlp = MLP(filter_channels=channels,
                              name="if",
                              res_layers=opt.res_layers,
                              norm=opt.norm_mlp,
                              last_op=None)
        smplx_dim = 10475  # number of SMPL-X vertices
        k = 8
        self.per_pt_code = nn.Embedding(smplx_dim, k)
    def forward(self, feature, smpl_vis, pts_id, xyz):
        '''
        Predict a per-point deformation offset.
        args:
            feature: [B, C_in, N] per-point features (may include multiple views)
            smpl_vis: [B, 1, N] per-point visibility
            pts_id: [B, N] SMPL-X vertex indices
            xyz: [B, 3, N] point positions
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        # a learned code that distinguishes points on different parts of the body
        e_code = self.per_pt_code(pts_id).permute(0, 2, 1)
        y = torch.cat([y, xyz, smpl_vis, e_code], 1)
        y = self.deform_mlp(y)
        return y
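# Usage sketch (illustrative; the `opt` fields are assumptions inferred from
# how they are read in __init__ above):
# from types import SimpleNamespace
# opt = SimpleNamespace(res_layers=[2], norm_mlp='group')
# deform = DeformationMLP(input_dim=64, output_dim=3, opt=opt)
# feat = torch.randn(2, 64, 512)             # [B, C_in, N]
# vis = torch.rand(2, 1, 512)                # [B, 1, N]
# ids = torch.randint(0, 10475, (2, 512))    # [B, N] SMPL-X vertex ids
# xyz = torch.randn(2, 3, 512)               # [B, 3, N]
# offsets = deform(feat, vis, ids, xyz)      # [B, 3, N]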
class MLP(pl.LightningModule):
    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=[],
                 norm='group',
                 last_op=None):
        super(MLP, self).__init__()
        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.res_layers = res_layers  # layers that re-concatenate the input features
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        for l in range(0, len(filter_channels) - 1):
            if l in self.res_layers:
                # residual layer: the raw input is concatenated back in
                self.filters.append(
                    nn.Conv1d(filter_channels[l] + filter_channels[0],
                              filter_channels[l + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))

            if l != len(filter_channels) - 2:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))
                elif norm == 'instance':
                    self.norms.append(nn.InstanceNorm1d(filter_channels[l + 1]))
                elif norm == 'weight':
                    self.filters[l] = nn.utils.weight_norm(self.filters[l],
                                                           name='weight')
                    # print(self.filters[l].weight_g.size(),
                    #       self.filters[l].weight_v.size())
    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        tmpy = feature
        for i, f in enumerate(self.filters):
            y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                if self.norm not in ['batch', 'group', 'instance']:
                    y = self.activate(y)
                else:
                    y = self.activate(self.norms[i](y))
        if self.last_op is not None:
            y = self.last_op(y)
        return y
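# Usage sketch (illustrative only): a three-layer point-wise MLP with a
# residual concatenation at layer 1 and group normalization.
# mlp = MLP(filter_channels=[64, 128, 64, 1], res_layers=[1],
#           norm='group', last_op=nn.Sigmoid())
# out = mlp(torch.randn(2, 64, 512))   # [B, 1, N], values in (0, 1)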
# Positional encoding (NeRF, Section 5.1)
class Embedder(pl.LightningModule):
    def __init__(self, **kwargs):
        super(Embedder, self).__init__()
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # bind p_fn and freq as defaults so each lambda keeps its own values
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires=6, i=0):
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim
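# Usage sketch (illustrative only): with multires=6 and 3-d inputs the output
# dimension is 3 + 3 * 2 * 6 = 39 (identity + sin/cos per frequency band).
# embed_fn, out_dim = get_embedder(multires=6)
# pts = torch.randn(2, 512, 3)
# enc = embed_fn(pts)      # [2, 512, 39]; out_dim == 39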
# Transformer encoder layer
# uses Embedder to add positional encoding to input points
# uses query points as query, deformed points as key, point features as value for attention
class TransformerEncoderLayer(pl.LightningModule):
    def __init__(self, d_model=256, skips=4, multires=6, num_mlp_layers=8, dropout=0.1, opt=None):
        super(TransformerEncoderLayer, self).__init__()
        embed_fn, input_ch = get_embedder(multires=multires)
        self.skips = skips
        self.dropout = dropout
        D = num_mlp_layers
        self.positional_encoding = embed_fn
        self.d_model = d_model
        # overwrite the MLP input width to match the concatenated feature:
        # 64-d triplane feature plus 6 + 8 extra channels from smpl_feat
        triplane_dim = 64
        opt.mlp_dim[0] = triplane_dim + 6 + 8
        opt.mlp_dim_color[0] = triplane_dim + 6 + 8
        self.geo_mlp = MLP(filter_channels=opt.mlp_dim,
                           name="if",
                           res_layers=opt.res_layers,
                           norm=opt.norm_mlp,
                           last_op=nn.Sigmoid())  # occupancy
        self.color_mlp = MLP(filter_channels=opt.mlp_dim_color,
                             name="color_if",
                             res_layers=opt.res_layers,
                             norm=opt.norm_mlp,
                             last_op=nn.Tanh())  # color
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, query_points, key_points, point_features, smpl_feat, training=True, type='shape'):
        # The attention path below is currently disabled; point_features
        # already come from barycentric interpolation.
        # Q = self.positional_encoding(query_points)   # [B, N, 39]
        # K = self.positional_encoding(key_points)     # [B, N', 39]
        # V = point_features.permute(0, 2, 1)          # [B, N', 192]
        # t = 0.1
        # attn_output, attn_output_weights = self.attention(Q.permute(1, 0, 2), K.permute(1, 0, 2), V.permute(1, 0, 2))
        # attn_output_weights = torch.bmm(Q, K.transpose(1, 2))        # [B, N, N']
        # attn_output_weights = self.softmax(attn_output_weights / t)  # [B, N, N']
        # attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True)
        # attn_output = torch.bmm(attn_output_weights, V)              # [B, N, 192]
        attn_output = point_features  # [B, C, N], from barycentric interpolation
        feature = torch.cat([attn_output, smpl_feat], dim=1)
        if type == 'shape':
            h = self.geo_mlp(feature)     # [B, 1, N]
            return h
        elif type == 'color':
            h = self.color_mlp(feature)   # [B, 3, N]
            return h
        elif type == 'shape_color':
            h_s = self.geo_mlp(feature)   # [B, 1, N]
            h_c = self.color_mlp(feature) # [B, 3, N]
            return h_s, h_c
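# Usage sketch (illustrative; the `opt` fields are assumptions inferred from
# how they are read above -- mlp_dim[0] is overwritten in __init__):
# from types import SimpleNamespace
# opt = SimpleNamespace(mlp_dim=[-1, 128, 64, 1], mlp_dim_color=[-1, 128, 64, 3],
#                       res_layers=[2], norm_mlp='group')
# layer = TransformerEncoderLayer(opt=opt)
# feats = torch.randn(2, 64, 512)      # [B, 64, N] triplane features
# smpl = torch.randn(2, 14, 512)       # [B, 6 + 8, N] SMPL-side features
# occ, rgb = layer(None, None, feats, smpl, type='shape_color')
# # occ: [2, 1, 512] in (0, 1); rgb: [2, 3, 512] in (-1, 1)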
class Swish(pl.LightningModule):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        # swish(x) = x * sigmoid(x); equivalent to nn.SiLU in recent PyTorch
        x = x * torch.sigmoid(x)
        return x
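# Equivalence check sketch (illustrative only):
# x = torch.randn(4)
# assert torch.allclose(Swish()(x), nn.SiLU()(x))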
# Define positional encoding class (standard sinusoidal encoding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super(PositionalEncoding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x
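# Usage sketch (illustrative only):
# pe = PositionalEncoding(d_model=256)
# tokens = torch.randn(2, 100, 256)   # [B, seq_len, d_model]
# out = pe(tokens)                    # same shape, positional terms added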
# # Define model parameters
# d_model = 256           # output size of MLP
# nhead = 8               # number of attention heads
# dim_feedforward = 512   # hidden size of MLP
# num_layers = 2          # number of MLP layers
# num_frequencies = 6     # number of frequencies for positional encoding
# dropout = 0.1           # dropout rate

# # Define model components
# pos_encoder = PositionalEncoding(d_model, num_frequencies)  # positional encoding layer
# encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)  # transformer encoder layer
# encoder = nn.TransformerEncoder(encoder_layer, num_layers)  # transformer encoder
# mlp_geo = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model))  # MLP for geometry
# mlp_alb = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model))  # MLP for albedo
# head_geo = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3))  # geometry head
# head_alb = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3), nn.Sigmoid())  # albedo head

# # Define input tensors
# # deformed body points: (batch_size, num_points, 3)
# x = torch.randn(batch_size, num_points, 3)
# # query point positions: (batch_size, num_queries, 3)
# y = torch.randn(batch_size, num_queries, 3)
# # Map both d