mawairon commited on
Commit
778a9fa
1 Parent(s): ad5e55a

Create model_archs

Browse files
Files changed (1) hide show
  1. model_archs +119 -0
model_archs ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ biases = False
5
+
6
+
7
+ class Pool2BN(nn.Module):
8
+ def __init__(self, num_channels):
9
+ super().__init__()
10
+ self.bn = torch.nn.BatchNorm1d(num_channels * 2)
11
+
12
+ def forward(self, x):
13
+ avgp = torch.nn.functional.adaptive_avg_pool1d(x, 1)[:, :, 0]
14
+ maxp = torch.nn.functional.adaptive_max_pool1d(x, 1)[:, :, 0]
15
+ x = torch.cat((avgp, maxp), axis=1)
16
+ x = self.bn(x)
17
+ return x
18
+
19
+ class MLP(torch.nn.Module):
20
+ def __init__(self, layer_sizes, biases=False, sigmoid=False, dropout=None):
21
+ super().__init__()
22
+ layers = []
23
+ prev_size = layer_sizes[0]
24
+ for i, s in enumerate(layer_sizes[1:]):
25
+ if i != 0 and dropout is not None:
26
+ layers.append(torch.nn.Dropout(dropout))
27
+
28
+ layers.append(torch.nn.Linear(in_features=prev_size, out_features=s, bias=biases))
29
+ if i != len(layer_sizes) - 2:
30
+ if sigmoid:
31
+ # layers.append(torch.nn.Sigmoid())
32
+ layers.append(torch.nn.Tanh())
33
+ else:
34
+ layers.append(torch.nn.ReLU())
35
+
36
+ layers.append(torch.nn.BatchNorm1d(s))
37
+
38
+ prev_size = s
39
+
40
+ self.mlp = torch.nn.Sequential(*layers)
41
+
42
+ def forward(self, x):
43
+
44
+ return self.mlp(x)
45
+
46
+
47
+ class SimpleCNN(torch.nn.Module):
48
+ def __init__(self, k, num_filters, sigmoid=False, additional_layer=False):
49
+ super(SimpleCNN, self).__init__()
50
+ self.sigmoid = sigmoid
51
+ self.cnn = torch.nn.Conv1d(in_channels=4, out_channels=num_filters, kernel_size=k, bias=biases)
52
+
53
+ self.additional_layer = additional_layer
54
+ if additional_layer:
55
+ self.bn = nn.BatchNorm1d(num_filters)
56
+ # self.do = nn.Dropout(0.5)
57
+ self.cnn2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=1, bias=biases)
58
+
59
+ self.post = Pool2BN(num_filters)
60
+
61
+ def forward(self, x):
62
+ x = self.cnn(x)
63
+ x = (torch.tanh if self.sigmoid else torch.relu)(x)
64
+
65
+ if self.additional_layer:
66
+ x = self.bn(x)
67
+ # x = self.do(x)
68
+ x = self.cnn2(x)
69
+ x = (torch.tanh if self.sigmoid else torch.relu)(x)
70
+
71
+ x = self.post(x)
72
+ #print(f'x shape at CNN output: {x.shape}')
73
+ return x
74
+
75
+
76
+ class ResNet1dBlock(torch.nn.Module):
77
+ def __init__(self, num_filters, k1, internal_filters, k2, dropout=None, dilation=None):
78
+ super().__init__()
79
+
80
+ self.init_do = torch.nn.Dropout(dropout) if dropout is not None else None
81
+ self.bn1 = torch.nn.BatchNorm1d(num_filters)
82
+ if dilation is None:
83
+ dilation = 1
84
+
85
+ self.cnn1 = torch.nn.Conv1d(in_channels=num_filters, out_channels=internal_filters, kernel_size=k1, bias=biases,
86
+ dilation=dilation,
87
+ padding=(k1 // 2) * dilation)
88
+
89
+ self.bn2 = torch.nn.BatchNorm1d(internal_filters)
90
+ self.cnn2 = torch.nn.Conv1d(in_channels=internal_filters, out_channels=num_filters, kernel_size=k2, bias=biases,
91
+ padding=k2 // 2)
92
+
93
+ def forward(self, x):
94
+ x_orig = x
95
+
96
+ x = self.bn1(x)
97
+ x = torch.relu(x)
98
+ if self.init_do is not None:
99
+ x = self.init_do(x)
100
+
101
+ x = self.cnn1(x)
102
+
103
+ x = self.bn2(x)
104
+ x = torch.relu(x)
105
+ x = self.cnn2(x)
106
+
107
+ return x + x_orig
108
+
109
+
110
+ class ResNet1d(torch.nn.Module):
111
+ def __init__(self, num_filters, block_spec, dropout=None, dilation=None):
112
+ super().__init__()
113
+
114
+ blocks = [ResNet1dBlock(num_filters, *spec, dropout=dropout, dilation=dilation) for spec in block_spec]
115
+ self.blocks = torch.nn.Sequential(*blocks)
116
+
117
+ def forward(self, x):
118
+ return self.blocks(x)
119
+