Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import models | |
from models.utils import get_activation | |
from models.network_utils import get_encoding, get_mlp | |
from systems.utils import update_module_step | |
class VolumeRadiance(nn.Module): | |
def __init__(self, config): | |
super(VolumeRadiance, self).__init__() | |
self.config = config | |
self.with_viewdir = False #self.config.get('wo_viewdir', False) | |
self.n_dir_dims = self.config.get('n_dir_dims', 3) | |
self.n_output_dims = 3 | |
if self.with_viewdir: | |
encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config) | |
self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims | |
# self.network_base = get_mlp(self.config.input_feature_dim, self.n_output_dims, self.config.mlp_network_config) | |
else: | |
encoding = None | |
self.n_input_dims = self.config.input_feature_dim | |
network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) | |
self.encoding = encoding | |
self.network = network | |
def forward(self, features, dirs, *args): | |
# features = features.detach() | |
if self.with_viewdir: | |
dirs = (dirs + 1.) / 2. # (-1, 1) => (0, 1) | |
dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims)) | |
network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) | |
# network_inp_base = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) | |
color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() | |
# color_base = self.network_base(network_inp_base).view(*features.shape[:-1], self.n_output_dims).float() | |
# color = color + color_base | |
else: | |
network_inp = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) | |
color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() | |
if 'color_activation' in self.config: | |
color = get_activation(self.config.color_activation)(color) | |
return color | |
def update_step(self, epoch, global_step): | |
update_module_step(self.encoding, epoch, global_step) | |
def regularizations(self, out): | |
return {} | |
class VolumeColor(nn.Module): | |
def __init__(self, config): | |
super(VolumeColor, self).__init__() | |
self.config = config | |
self.n_output_dims = 3 | |
self.n_input_dims = self.config.input_feature_dim | |
network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) | |
self.network = network | |
def forward(self, features, *args): | |
network_inp = features.view(-1, features.shape[-1]) | |
color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() | |
if 'color_activation' in self.config: | |
color = get_activation(self.config.color_activation)(color) | |
return color | |
def regularizations(self, out): | |
return {} | |