# Source: FunSR / models/siren_modulation.py
# Author: KyanChen; commit 02c5426 ("add"); original size 2.19 kB
from collections import OrderedDict
import torch
import torch.nn as nn
from models import register
@register('sirens')
class Sirens(nn.Module):
    """Stack of 1x1-conv sine layers whose pre-activations are shifted by modulations.

    Each layer computes ``sin(coord_conv(h) + modulation)``:

    * first layer: ``modulation = ReLU(conv(ori_modulations))``;
    * inner layer ``i``: ``modulation = ReLU(conv(cat(ori_modulations,
      previous_modulation, h)))``, so the raw conditioning tensor is
      re-injected at every depth;
    * a final 1x1 conv (no activation) projects to ``out_dim`` channels.

    Note: this class applies no SIREN frequency factor (w0) and no
    SIREN-specific weight initialisation; the sine acts directly on the
    outputs of default-initialised ``nn.Conv2d`` layers.

    Args:
        num_inner_layers: number of modulated hidden layers after the first.
        in_dim: channel count of the coordinate input ``x``.
        modulation_dim: channel count of the conditioning tensor.
        out_dim: channel count of the output.
        base_channels: hidden width used by every layer.
        is_residual: if True, each inner layer adds its output to its input
            (``h = h + layer(h)``) instead of replacing it.
    """

    def __init__(self,
                 num_inner_layers,
                 in_dim,
                 modulation_dim,
                 out_dim=3,
                 base_channels=256,
                 is_residual=False,
                 ):
        super().__init__()
        self.in_dim = in_dim
        self.num_inner_layers = num_inner_layers
        self.is_residual = is_residual

        # NOTE: attribute names and the nn.Sequential wrappers below define the
        # state_dict keys (e.g. "first_mod.0.weight", "last_coord.0.weight");
        # keep them unchanged so existing checkpoints still load. Creation
        # order is also preserved so seeded initialisation stays reproducible.
        self.first_mod = nn.Sequential(
            nn.Conv2d(modulation_dim, base_channels, 1),
            nn.ReLU(),
        )
        self.first_coord = nn.Conv2d(in_dim, base_channels, 1)

        # Inner modulation convs consume [ori_modulations, previous modulation,
        # hidden features] concatenated along the channel axis.
        inner_mod_in = modulation_dim + 2 * base_channels
        self.inner_mods = nn.ModuleList()
        self.inner_coords = nn.ModuleList()
        for _ in range(num_inner_layers):
            self.inner_mods.append(
                nn.Sequential(
                    nn.Conv2d(inner_mod_in, base_channels, 1),
                    nn.ReLU(),
                )
            )
            self.inner_coords.append(nn.Conv2d(base_channels, base_channels, 1))

        # One-element Sequential kept so the checkpoint key stays "last_coord.0.*".
        self.last_coord = nn.Sequential(
            nn.Conv2d(base_channels, out_dim, 1),
        )

    def forward(self, x, ori_modulations=None):
        """Evaluate the modulated network.

        Args:
            x: coordinate features, shape (B, in_dim, H, W).
            ori_modulations: conditioning features, shape
                (B, modulation_dim, H, W). Its batch and spatial sizes must
                match ``x`` because it is concatenated with hidden features
                along dim 1. The ``None`` default exists only for signature
                compatibility; the argument is required.

        Returns:
            Tensor of shape (B, out_dim, H, W). The 1x1 convs preserve H and W.

        Raises:
            TypeError: if ``ori_modulations`` is None.
        """
        if ori_modulations is None:
            # Previously this failed deep inside Conv2d with an opaque
            # TypeError; fail early, with the same exception type.
            raise TypeError("Sirens.forward requires ori_modulations; got None")

        modulations = self.first_mod(ori_modulations)
        h = torch.sin(self.first_coord(x) + modulations)

        for mod_layer, coord_layer in zip(self.inner_mods, self.inner_coords):
            # Each modulation is conditioned on the raw input modulations, the
            # previous layer's modulation, and the current hidden features.
            modulations = mod_layer(
                torch.cat((ori_modulations, modulations, h), dim=1))
            out = torch.sin(coord_layer(h) + modulations)
            h = h + out if self.is_residual else out

        return self.last_coord(h)