import torch
import torch.nn as nn

from models import register


@register('mlp_pw')
class MLP_PW(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        self.relu_0 = nn.ReLU(inplace=True)
        self.relu_1 = nn.ReLU(inplace=True)
        self.relu_2 = nn.ReLU(inplace=True)
        self.relu_3 = nn.ReLU(inplace=True)
        # All four hidden layers share the same width; only hidden_list[0] is used.
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden = hidden_list[0]

    def forward(self, x, Coeff, basis, bias):
        # Per-sample weights and biases are mixed from a shared basis:
        #   W = (Coeff @ basis[k]).view(-1, out_k, in_k),  b = Coeff @ bias[k]
        # so each layer applies a per-sample linear transformation y = W x + b.
        #
        # Expected shapes (b*h*w query points, 10 basis elements,
        # in_dim = 580, hidden = 16, out_dim = 3):
        #   x          (b*h*w, 580)
        #   Coeff      (b*h*w, 10)
        #   basis[0]   (10, 16*580)
        #   basis[1:4] (10, 16*16)
        #   basis[4]   (10, 3*16)
        #   bias[0:4]  (10, 16)
        #   bias[4]    (10, 3)
        device = x.device
        Coeff = Coeff.to(device)
        # Layer 0: (b*h*w, in_dim) -> (b*h*w, hidden)
        x = x.unsqueeze(1)
        # sum( (b*h*w, 1, in_dim) * (b*h*w, hidden, in_dim), dim=2 ) -> (b*h*w, hidden)
        x = torch.sum(x * torch.matmul(Coeff, basis[0].to(device)).view(-1, self.hidden, self.in_dim), dim=2)
        # (b*h*w, hidden) + (b*h*w, hidden) -> (b*h*w, hidden)
        x = x + torch.matmul(Coeff, bias[0].to(device))
        x = self.relu_0(x)
        # Layer 1: (b*h*w, hidden) -> (b*h*w, hidden)
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[1].to(device)).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[1].to(device))
        x = self.relu_1(x)
        # Layer 2: (b*h*w, hidden) -> (b*h*w, hidden)
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[2].to(device)).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[2].to(device))
        x = self.relu_2(x)
        # Layer 3: (b*h*w, hidden) -> (b*h*w, hidden)
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[3].to(device)).view(-1, self.hidden, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[3].to(device))
        x = self.relu_3(x)
        # Layer 4 (output, no activation): (b*h*w, hidden) -> (b*h*w, out_dim)
        x = x.unsqueeze(1)
        x = torch.sum(x * torch.matmul(Coeff, basis[4].to(device)).view(-1, self.out_dim, self.hidden), dim=2)
        x = x + torch.matmul(Coeff, bias[4].to(device))

        return x
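

# A minimal sketch (an assumption for illustration, not part of the original module):
# each layer of MLP_PW.forward builds per-sample weights
# W_n = (Coeff[n] @ basis[k]).view(out_k, in_k) and biases b_n = Coeff[n] @ bias[k],
# then applies y_n = W_n @ x_n + b_n. The broadcast-multiply-and-sum above is
# equivalent to the einsum below; the helper name and argument names are hypothetical
# and are not used by the classes in this file.
def _mixed_linear_sketch(x, coeff, basis_k, bias_k, out_dim, in_dim):
    # x: (N, in_dim), coeff: (N, K), basis_k: (K, out_dim * in_dim), bias_k: (K, out_dim)
    w = torch.matmul(coeff, basis_k).view(-1, out_dim, in_dim)  # (N, out_dim, in_dim)
    b = torch.matmul(coeff, bias_k)                             # (N, out_dim)
    return torch.einsum('ni,noi->no', x, w) + b                 # (N, out_dim)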


@register('mlp')
class MLP(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)
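

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original module;
    # it assumes the surrounding `models` package is importable so that `register`
    # resolves). Shapes follow the comments in MLP_PW.forward: 580-d input, 16 hidden
    # units, 10 basis elements, 3 output channels; the batch size N is arbitrary.
    N, K, D_in, H, D_out = 4, 10, 580, 16, 3
    x = torch.randn(N, D_in)
    coeff = torch.randn(N, K)
    basis = ([torch.randn(K, H * D_in)]
             + [torch.randn(K, H * H) for _ in range(3)]
             + [torch.randn(K, D_out * H)])
    bias = [torch.randn(K, H) for _ in range(4)] + [torch.randn(K, D_out)]
    pw = MLP_PW(in_dim=D_in, out_dim=D_out, hidden_list=[H])
    print(pw(x, coeff, basis, bias).shape)    # expected: torch.Size([4, 3])

    mlp = MLP(in_dim=D_in, out_dim=D_out, hidden_list=[H, H, H, H])
    print(mlp(torch.randn(N, D_in)).shape)    # expected: torch.Size([4, 3])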