import torch
import torch.nn as nn

from models import register


@register('mlp_pw')
class MLPPW(nn.Module):
    """MLP whose layer weights and biases are assembled per sample as linear
    combinations of shared basis tensors, weighted by `Coeff`. Expects five
    basis/bias tensors (four hidden layers plus the output layer); all hidden
    layers share the width `hidden_list[0]`."""

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden = hidden_list[0]
        self.relu_0 = nn.ReLU(inplace=True)
        self.relu_1 = nn.ReLU(inplace=True)
        self.relu_2 = nn.ReLU(inplace=True)
        self.relu_3 = nn.ReLU(inplace=True)

    def forward(self, x, Coeff, basis, bias):
        # Move the coefficients and basis tensors onto the input's device once.
        device = x.device
        Coeff = Coeff.to(device)
        basis = [b.to(device) for b in basis]
        bias = [b.to(device) for b in bias]

        # Layer 0: mix the weight basis into one (batch, hidden, in_dim)
        # weight per sample, contract it with the input, add the mixed bias.
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[0]).view(-1, self.hidden, self.in_dim), dim=2)
        x = x + torch.matmul(Coeff, bias[0])
        x = self.relu_0(x)

        # Hidden layers 1-3: the same construction with square
        # (hidden, hidden) weights.
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[1]).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[1])
        x = self.relu_1(x)

        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[2]).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[2])
        x = self.relu_2(x)

        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[3]).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[3])
        x = self.relu_3(x)

        # Output layer: (batch, out_dim, hidden) weight, no activation.
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[4]).view(-1, self.out_dim, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[4])

        return x
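
# Shape sketch for MLPPW.forward (the letters below are illustrative labels,
# not names from the original code): with B samples, K basis elements, hidden
# width H = hidden_list[0], D = in_dim and O = out_dim, the inputs are
#   x:        (B, D)
#   Coeff:    (B, K)
#   basis[0]: (K, H * D)   bias[0]: (K, H)
#   basis[i]: (K, H * H)   bias[i]: (K, H)   for i = 1..3
#   basis[4]: (K, O * H)   bias[4]: (K, O)
# torch.matmul(Coeff, basis[i]) then yields one flattened weight matrix per
# sample, which `.view` reshapes before the contraction over the input dim.
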
@register('mlp')
class MLP(nn.Module):
    """Plain fully connected network: Linear + ReLU blocks followed by a
    final Linear projection to `out_dim`."""

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten all leading dims, run the MLP, then restore the shape.
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)
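

if __name__ == '__main__':
    # Minimal smoke test under assumed sizes (K=10 basis elements, H=256,
    # D=580, O=3); these numbers are illustrative, not from a real config.
    # Running this file directly still needs the `models` package on the
    # path for the `register` import above.
    B, K, H, D, O = 4, 10, 256, 580, 3
    basis = ([torch.randn(K, H * D)]
             + [torch.randn(K, H * H) for _ in range(3)]
             + [torch.randn(K, O * H)])
    bias = [torch.randn(K, H) for _ in range(4)] + [torch.randn(K, O)]
    coeff = torch.randn(B, K)

    pw = MLPPW(in_dim=D, out_dim=O, hidden_list=[H])
    print(pw(torch.randn(B, D), coeff, basis, bias).shape)  # torch.Size([4, 3])

    mlp = MLP(in_dim=D, out_dim=O, hidden_list=[H, H])
    print(mlp(torch.randn(B, D)).shape)  # torch.Size([4, 3])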