# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.autograd import grad
# from fightingcv_attention.attention.SelfAttention import ScaledDotProductAttention


class SDF2Density(pl.LightningModule):
    def __init__(self):
        super(SDF2Density, self).__init__()
        # learnable parameter beta, initialized to 0.1
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        # use the Laplace CDF to convert an SDF value into a density;
        # temporarily approximated with a sigmoid
        return 1.0 / (self.beta + 1e-6) * torch.sigmoid(-sdf / (self.beta + 1e-6))


class SDF2Occ(pl.LightningModule):
    def __init__(self):
        super(SDF2Occ, self).__init__()
        # learnable parameter beta, initialized to 0.1
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        # use the Laplace CDF to convert an SDF value into an occupancy probability;
        # temporarily approximated with a sigmoid
        return torch.sigmoid(-sdf / (self.beta + 1e-6))


class DeformationMLP(pl.LightningModule):
    def __init__(self, input_dim=64, output_dim=3, activation='LeakyReLU', name=None, opt=None):
        super(DeformationMLP, self).__init__()
        self.name = name
        self.activation = activation
        self.activate = nn.LeakyReLU(inplace=True)

        # input = point feature + per-point code (8) + visibility (1) + xyz (3)
        channels = [input_dim + 8 + 1 + 3, 128, 64, output_dim]
        self.deform_mlp = MLP(filter_channels=channels,
                              name="if",
                              res_layers=opt.res_layers,
                              norm=opt.norm_mlp,
                              last_op=None)

        # one learnable 8-dim code per SMPL-X vertex
        smplx_dim = 10475
        k = 8
        self.per_pt_code = nn.Embedding(smplx_dim, k)

    def forward(self, feature, smpl_vis, pts_id, xyz):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        # a code that distinguishes each point on different parts of the body
        e_code = self.per_pt_code(pts_id).permute(0, 2, 1)
        y = torch.cat([y, xyz, smpl_vis, e_code], 1)
        y = self.deform_mlp(y)
        return y


class MLP(pl.LightningModule):
    def __init__(self, filter_channels, name=None, res_layers=[], norm='group', last_op=None):
        super(MLP, self).__init__()
        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.res_layers = res_layers
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        for l in range(0, len(filter_channels) - 1):
            if l in self.res_layers:
                # residual layers re-concatenate the raw input features
                self.filters.append(
                    nn.Conv1d(filter_channels[l] + filter_channels[0],
                              filter_channels[l + 1], 1))
            else:
                self.filters.append(
                    nn.Conv1d(filter_channels[l], filter_channels[l + 1], 1))

            if l != len(filter_channels) - 2:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))
                elif norm == 'instance':
                    self.norms.append(nn.InstanceNorm1d(filter_channels[l + 1]))
                elif norm == 'weight':
                    self.filters[l] = nn.utils.weight_norm(self.filters[l], name='weight')

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        tmpy = feature
        for i, f in enumerate(self.filters):
            y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
            if i != len(self.filters) - 1:
                if self.norm not in ['batch', 'group', 'instance']:
                    y = self.activate(y)
                else:
                    y = self.activate(self.norms[i](y))

        if self.last_op is not None:
            y = self.last_op(y)

        return y
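# A minimal usage sketch of the MLP block above (illustration only, not part of the
# original file; the channel sizes and point count below are made-up values):
#
# mlp = MLP(filter_channels=[75, 128, 64, 1], res_layers=[2], norm='group',
#           last_op=nn.Sigmoid())
# feats = torch.randn(2, 75, 1024)   # [B, C_in, N] per-point features
# occ = mlp(feats)                   # [B, 1, N]; layer 2 re-concatenates the raw input (skip)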
# Positional encoding (section 5.1)
class Embedder(pl.LightningModule):
    def __init__(self, **kwargs):
        super(Embedder, self).__init__()
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires=6, i=0):
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim


# Transformer encoder layer
# uses Embedder to add positional encoding to input points
# uses query points as query, deformed points as key, point features as value for attention
class TransformerEncoderLayer(pl.LightningModule):
    def __init__(self, d_model=256, skips=4, multires=6, num_mlp_layers=8, dropout=0.1, opt=None):
        super(TransformerEncoderLayer, self).__init__()
        embed_fn, input_ch = get_embedder(multires=multires)
        self.skips = skips
        self.dropout = dropout
        D = num_mlp_layers
        self.positional_encoding = embed_fn
        self.d_model = d_model

        triplane_dim = 64
        opt.mlp_dim[0] = triplane_dim + 6 + 8
        opt.mlp_dim_color[0] = triplane_dim + 6 + 8

        self.geo_mlp = MLP(filter_channels=opt.mlp_dim,
                           name="if",
                           res_layers=opt.res_layers,
                           norm=opt.norm_mlp,
                           last_op=nn.Sigmoid())  # occupancy
        self.color_mlp = MLP(filter_channels=opt.mlp_dim_color,
                             name="color_if",
                             res_layers=opt.res_layers,
                             norm=opt.norm_mlp,
                             last_op=nn.Tanh())  # color
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_points, key_points, point_features, smpl_feat, training=True, type='shape'):
        # Cross-attention variant (currently disabled): query points attend to the
        # deformed key points and aggregate their features.
        # Q = self.positional_encoding(query_points)   # [B, N, 39]
        # K = self.positional_encoding(key_points)     # [B, N', 39]
        # V = point_features.permute(0, 2, 1)          # [B, N', 192]
        # t = 0.1
        # attn_output_weights = torch.bmm(Q, K.transpose(1, 2))        # [B, N, N']
        # attn_output_weights = self.softmax(attn_output_weights / t)  # [B, N, N']
        # attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=True)
        # attn_output = torch.bmm(attn_output_weights, V)              # [B, N, 192]

        attn_output = point_features  # barycentric-interpolated point features
        feature = torch.cat([attn_output, smpl_feat], dim=1)

        if type == 'shape':
            return self.geo_mlp(feature)    # [B, 1, N]
        elif type == 'color':
            return self.color_mlp(feature)  # [B, 3, N]
        elif type == 'shape_color':
            h_s = self.geo_mlp(feature)     # [B, 1, N]
            h_c = self.color_mlp(feature)   # [B, 3, N]
            return h_s, h_c


class Swish(pl.LightningModule):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
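# A small sanity-check sketch for the positional encoder above (illustration only,
# not part of the original file): with the default multires=6, get_embedder yields a
# 39-dim encoding per 3-D point (3 raw coords + 3 coords * 2 periodic fns * 6
# frequencies), matching the [B, N, 39] shapes noted in TransformerEncoderLayer.
#
# embed_fn, out_dim = get_embedder(multires=6)
# pts = torch.randn(2, 1024, 3)        # [B, N, 3] query points
# enc = embed_fn(pts)                  # [B, N, 39] positionally encoded points
# assert out_dim == 39 and enc.shape[-1] == out_dim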
# Define positional encoding class
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super(PositionalEncoding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x


# # Define model parameters
# d_model = 256           # output size of MLP
# nhead = 8               # number of attention heads
# dim_feedforward = 512   # hidden size of MLP
# num_layers = 2          # number of MLP layers
# num_frequencies = 6     # number of frequencies for positional encoding
# dropout = 0.1           # dropout rate

# # Define model components
# pos_encoder = PositionalEncoding(d_model, num_frequencies)                            # positional encoding layer
# encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)  # transformer encoder layer
# encoder = nn.TransformerEncoder(encoder_layer, num_layers)                            # transformer encoder
# mlp_geo = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model))   # MLP for geometry
# mlp_alb = nn.Sequential(nn.Linear(3, d_model), nn.ReLU(), nn.Linear(d_model, d_model))   # MLP for albedo
# head_geo = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3))  # geometry head
# head_alb = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 3), nn.Sigmoid())  # albedo head

# # Define input tensors
# # deformed body points: (batch_size, num_points, 3)
# x = torch.randn(batch_size, num_points, 3)
# # query point positions: (batch_size, num_queries, 3)
# y = torch.randn(batch_size, num_queries, 3)

# # Map both d
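# A minimal usage sketch of PositionalEncoding above (illustration only; the batch
# and sequence sizes are made-up values, not from this repository):
#
# pos_enc = PositionalEncoding(d_model=256, max_len=1000)
# tokens = torch.randn(4, 500, 256)    # [B, seq_len, d_model] token/point features
# tokens = pos_enc(tokens)             # same shape, sinusoidal encodings added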