# cifar10 / model.py
# dawn17's picture
# Upload model.py
# fe8ce56
# raw
# history blame contribute delete
# No virus
# 2.39 kB
"""
CIFAR 10
INPUT - [3, 32, 32]
"""
import torch.nn as nn
class BasicBlock(nn.Module):
    """Residual block: two conv-BN-ReLU-dropout layers plus a skip connection.

    The input goes through two stacked 3x3 convolution layers (each followed
    by batch normalisation, ReLU and dropout) and the result is added to the
    skip path. Height and width are preserved (3x3 kernel, padding 1, default
    stride).

    When ``in_channel == out_channel`` the skip path is the input itself.
    When they differ, the input is projected with a 1x1 convolution followed
    by batch normalisation so that the residual addition is shape-compatible.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by the block.
        dropout: Dropout probability applied after each ReLU.
    """

    def __init__(self, in_channel: int, out_channel: int, dropout: float):
        super(BasicBlock, self).__init__()
        # Only the first layer sees ``in_channel``; the second maps
        # ``out_channel`` -> ``out_channel``.
        self.cblock = nn.Sequential(
            *[
                self._get_base_layer(
                    in_channel if i == 0 else out_channel, out_channel, dropout
                )
                for i in range(2)
            ]
        )
        if in_channel == out_channel:
            # Plain attribute, not a registered submodule: the module tree
            # and state_dict keys stay exactly ``cblock.*`` in this case.
            self.shortcut = None
        else:
            # Projection shortcut so ``cblock(x) + skip`` has matching shapes.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channel),
            )

    def _get_base_layer(self, in_channel: int, out_channel: int, dropout: float):
        """Build one conv(3x3, replicate padding) -> BN -> ReLU -> dropout layer."""
        return nn.Sequential(
            nn.Conv2d(
                in_channel,
                out_channel,
                kernel_size=3,
                padding=1,
                padding_mode="replicate",
                # The BatchNorm2d that follows has its own learnable shift.
                bias=False,
            ),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``cblock(x)`` plus the (possibly projected) input."""
        skip = x if self.shortcut is None else self.shortcut(x)
        return self.cblock(x) + skip
class DavidPageNet(nn.Module):
    """Small residual CNN classifier.

    Layout (in order):

    * ``block0``: conv-BN-ReLU, 3 -> ``channels[0]``, no pooling.
    * ``block1``: conv-pool-BN-ReLU, ``channels[0]`` -> ``channels[1]``,
      followed by a ``BasicBlock`` at ``channels[1]``.
    * ``block2``: conv-pool-BN-ReLU, ``channels[1]`` -> ``channels[2]``.
    * ``block3``: conv-pool-BN-ReLU, ``channels[2]`` -> ``channels[3]``,
      followed by a ``BasicBlock`` at ``channels[3]``.
    * ``logit``: 4x4 max-pool, flatten, linear layer to ``num_classes``.

    The three 2x2 poolings shrink a 32x32 input to 4x4, which the final 4x4
    max-pool reduces to 1x1, so the flattened feature size equals
    ``channels[3]``. Other spatial input sizes will not match the linear
    layer.

    Args:
        channels: Four channel widths, one per block. Only indices 0-3 are
            read.
        dropout: Dropout probability passed to the residual ``BasicBlock``s.
        num_classes: Number of output logits.
    """

    def __init__(self, channels=(64, 128, 256, 512), dropout=0.01, num_classes=10):
        super(DavidPageNet, self).__init__()
        self.block0 = self._get_base_layer(3, channels[0], pool=False)
        self.block1 = nn.Sequential(
            *[
                self._get_base_layer(channels[0], channels[1]),
                BasicBlock(channels[1], channels[1], dropout),
            ]
        )
        self.block2 = self._get_base_layer(channels[1], channels[2])
        self.block3 = nn.Sequential(
            *[
                self._get_base_layer(channels[2], channels[3]),
                BasicBlock(channels[3], channels[3], dropout),
            ]
        )
        self.logit = nn.Sequential(
            nn.MaxPool2d(4),
            nn.Flatten(),
            # Input width follows the last block instead of a fixed 512 so
            # that non-default ``channels`` still produce a working head.
            nn.Linear(channels[3], num_classes),
        )

    def _get_base_layer(self, in_channel, out_channel, pool=True):
        """Build a conv(3x3, replicate padding) -> [2x2 max-pool] -> BN -> ReLU layer.

        With ``pool=False`` an ``nn.Identity`` takes the pooling slot, so the
        positions of the submodules inside the ``nn.Sequential`` are the same
        either way.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channel,
                out_channel,
                stride=1,
                padding=1,
                kernel_size=3,
                # The BatchNorm2d later in the stack has its own learnable shift.
                bias=False,
                padding_mode="replicate",
            ),
            nn.MaxPool2d(2) if pool else nn.Identity(),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the four blocks in order and return the class logits."""
        x = self.block0(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return self.logit(x)