"""CIFAR-10 classifier network and its residual building block.

Default input: image tensors of shape ``[N, 3, 32, 32]`` (CIFAR-10).
"""
import torch.nn as nn


class BasicBlock(nn.Module):
    """Residual block: two conv-BN-ReLU-dropout stages plus an identity skip.

    Args:
        in_channel: Channels of the input tensor.
        out_channel: Channels produced by each convolution.
        dropout: Dropout probability applied after every ReLU.

    The skip connection adds the input unchanged, so ``in_channel`` must
    equal ``out_channel`` for ``forward`` to succeed. Spatial size is
    preserved (3x3 kernels, padding 1, stride 1).
    """

    def __init__(self, in_channel, out_channel, dropout):
        super().__init__()
        # Only the first stage changes the channel count; the second maps
        # out_channel -> out_channel.
        self.cblock = nn.Sequential(
            *[
                self._get_base_layer(
                    in_channel if i == 0 else out_channel, out_channel, dropout
                )
                for i in range(2)
            ]
        )

    def _get_base_layer(self, in_channel, out_channel, dropout):
        """Return one 3x3 conv -> BatchNorm -> ReLU -> Dropout stage."""
        return nn.Sequential(
            nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size=3,
                padding=1,
                padding_mode="replicate",
                bias=False,  # the following BatchNorm supplies the shift
            ),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``cblock(x) + x`` (same shape as ``x``)."""
        return self.cblock(x) + x


class DavidPageNet(nn.Module):
    """CIFAR-10 classifier: a prep layer, three pooled stages (the first and
    last followed by a residual block), then global max-pool and a linear head.

    Args:
        channels: Four output-channel widths, one each for ``block0`` ..
            ``block3``. Any indexable sequence of four ints.
        dropout: Dropout probability used inside the residual blocks.
        num_classes: Length of the output logit vector.

    For a 32x32 input the spatial size goes 32 (block0, no pooling) -> 16
    (block1) -> 8 (block2) -> 4 (block3) before the head.
    """

    def __init__(
        self, channels=(64, 128, 256, 512), dropout=0.01, num_classes=10
    ):
        super().__init__()
        self.block0 = self._get_base_layer(3, channels[0], pool=False)
        self.block1 = nn.Sequential(
            self._get_base_layer(channels[0], channels[1]),
            BasicBlock(channels[1], channels[1], dropout),
        )
        self.block2 = self._get_base_layer(channels[1], channels[2])
        self.block3 = nn.Sequential(
            self._get_base_layer(channels[2], channels[3]),
            BasicBlock(channels[3], channels[3], dropout),
        )
        self.logit = nn.Sequential(
            # Global max-pool. For a 32x32 input the final map is 4x4, so this
            # equals MaxPool2d(4); it also yields 1x1 for other input sizes.
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            # Head width follows the last stage's channel count.
            nn.Linear(channels[3], num_classes),
        )

    def _get_base_layer(self, in_channel, out_channel, pool=True):
        """Return a 3x3 conv -> (optional 2x2 max-pool) -> BatchNorm -> ReLU stage.

        With ``pool=True`` the spatial size is halved; otherwise it is kept.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channel,
                out_channel,
                stride=1,
                padding=1,
                kernel_size=3,
                bias=False,  # the following BatchNorm supplies the shift
                padding_mode="replicate",
            ),
            # Identity keeps the Sequential indices the same in both cases.
            nn.MaxPool2d(2) if pool else nn.Identity(),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map a ``[N, 3, H, W]`` batch to ``[N, num_classes]`` logits."""
        x = self.block0(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.logit(x)