import torch
import torch.nn as nn

# Module-wide default for the ``bias`` flag of the convolutions built by
# SimpleCNN and ResNet1dBlock (MLP takes its own ``biases`` argument instead).
biases = False


class Pool2BN(nn.Module):
    """Global average + max pooling over the last axis, followed by BatchNorm.

    ``forward`` expects ``x`` of shape ``(batch, num_channels, length)`` and
    returns ``(batch, 2 * num_channels)``: the average-pooled features followed
    by the max-pooled features, batch-normalised together.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_channels * 2)

    def forward(self, x):
        # adaptive_*_pool1d(x, 1) yields (batch, C, 1); drop the trailing axis.
        avgp = torch.nn.functional.adaptive_avg_pool1d(x, 1)[:, :, 0]
        maxp = torch.nn.functional.adaptive_max_pool1d(x, 1)[:, :, 0]
        x = torch.cat((avgp, maxp), dim=1)
        return self.bn(x)


class MLP(torch.nn.Module):
    """Fully connected stack built from a list of layer widths.

    Args:
        layer_sizes: widths ``[in, hidden..., out]``; one ``Linear`` layer is
            created for each consecutive pair.
        biases: whether the ``Linear`` layers have a bias term.
        sigmoid: if true, hidden layers use ``tanh`` instead of ``ReLU``
            (the flag name is historical; no sigmoid is applied).
        dropout: if not ``None``, dropout probability applied before every
            ``Linear`` layer except the first.

    Every ``Linear`` except the last is followed by the activation and a
    ``BatchNorm1d``; the output of the final ``Linear`` is returned as is.
    """

    def __init__(self, layer_sizes, biases=False, sigmoid=False, dropout=None):
        super().__init__()
        layers = []
        last = len(layer_sizes) - 2  # index of the final Linear layer
        prev_size = layer_sizes[0]
        for i, size in enumerate(layer_sizes[1:]):
            if i != 0 and dropout is not None:
                layers.append(torch.nn.Dropout(dropout))
            layers.append(
                torch.nn.Linear(in_features=prev_size, out_features=size, bias=biases)
            )
            if i != last:
                layers.append(torch.nn.Tanh() if sigmoid else torch.nn.ReLU())
                layers.append(torch.nn.BatchNorm1d(size))
            prev_size = size
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


class SimpleCNN(torch.nn.Module):
    """One convolution (optionally plus a 1x1 convolution) with global pooling.

    Args:
        k: kernel size of the first convolution. It has no padding, so the
            sequence shortens by ``k - 1``.
        num_filters: output channels of each convolution.
        sigmoid: if true, use ``tanh`` activations instead of ``ReLU``.
        additional_layer: if true, the first activation is followed by
            BatchNorm, a 1x1 convolution, and a second activation.
        in_channels: number of input channels. The default of 4 is the value
            that was previously hard-coded.

    ``forward`` maps ``(batch, in_channels, length)`` to
    ``(batch, 2 * num_filters)`` through :class:`Pool2BN`.
    """

    def __init__(self, k, num_filters, sigmoid=False, additional_layer=False,
                 in_channels=4):
        super().__init__()
        self.sigmoid = sigmoid
        self.cnn = torch.nn.Conv1d(in_channels=in_channels, out_channels=num_filters,
                                   kernel_size=k, bias=biases)
        self.additional_layer = additional_layer
        if additional_layer:
            self.bn = nn.BatchNorm1d(num_filters)
            self.cnn2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters,
                                  kernel_size=1, bias=biases)
        self.post = Pool2BN(num_filters)

    def _activation(self, x):
        """Apply tanh or ReLU according to ``self.sigmoid``."""
        return torch.tanh(x) if self.sigmoid else torch.relu(x)

    def forward(self, x):
        x = self._activation(self.cnn(x))
        if self.additional_layer:
            x = self.bn(x)
            x = self._activation(self.cnn2(x))
        return self.post(x)


class ResNet1dBlock(torch.nn.Module):
    """Residual block: BN -> ReLU -> (Dropout) -> Conv -> BN -> ReLU -> Conv, plus skip.

    Both convolutions are padded so that the sequence length is preserved.
    This lets the block output be added to its input.

    Args:
        num_filters: channels of the block input and output.
        k1: kernel size of the first convolution; must be odd.
        internal_filters: channels between the two convolutions.
        k2: kernel size of the second convolution; must be odd.
        dropout: if not ``None``, dropout probability applied before the first
            convolution.
        dilation: dilation of the first convolution (``None`` means 1).

    Raises:
        ValueError: if ``k1`` or ``k2`` is even. Symmetric padding cannot
            preserve the length in that case, so the residual addition would
            fail in ``forward``.
    """

    def __init__(self, num_filters, k1, internal_filters, k2, dropout=None, dilation=None):
        super().__init__()
        if k1 % 2 == 0 or k2 % 2 == 0:
            raise ValueError(
                f"ResNet1dBlock kernel sizes must be odd to preserve sequence "
                f"length for the residual connection (got k1={k1}, k2={k2})"
            )
        self.init_do = torch.nn.Dropout(dropout) if dropout is not None else None
        self.bn1 = torch.nn.BatchNorm1d(num_filters)
        if dilation is None:
            dilation = 1
        # For odd k the output length is L + 2*p - dilation*(k-1) == L
        # when p = (k // 2) * dilation.
        self.cnn1 = torch.nn.Conv1d(in_channels=num_filters, out_channels=internal_filters,
                                    kernel_size=k1, bias=biases, dilation=dilation,
                                    padding=(k1 // 2) * dilation)
        self.bn2 = torch.nn.BatchNorm1d(internal_filters)
        self.cnn2 = torch.nn.Conv1d(in_channels=internal_filters, out_channels=num_filters,
                                    kernel_size=k2, bias=biases, padding=k2 // 2)

    def forward(self, x):
        x_orig = x
        x = torch.relu(self.bn1(x))
        if self.init_do is not None:
            x = self.init_do(x)
        x = self.cnn1(x)
        x = torch.relu(self.bn2(x))
        x = self.cnn2(x)
        return x + x_orig


class ResNet1d(torch.nn.Module):
    """Sequential stack of :class:`ResNet1dBlock` modules.

    Args:
        num_filters: channels of the input and of every block.
        block_spec: iterable of ``(k1, internal_filters, k2)`` tuples, one per
            block. Each is unpacked positionally after ``num_filters``.
        dropout: forwarded to every block.
        dilation: forwarded to every block (applies to its first convolution).
    """

    def __init__(self, num_filters, block_spec, dropout=None, dilation=None):
        super().__init__()
        blocks = [ResNet1dBlock(num_filters, *spec, dropout=dropout, dilation=dilation)
                  for spec in block_spec]
        self.blocks = torch.nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)