Model class definition:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Block 1: 3 -> 16 -> 32 channels, 5x5 convs (stride 1, padding 2), then 2x2 max-pool
        self.conv1 = nn.Conv2d(3, 16, 5, 1, 2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(2)
        # Block 2: 32 -> 32 -> 64 channels
        self.conv3 = nn.Conv2d(32, 32, 5, 1, 2)
        self.conv4 = nn.Conv2d(32, 64, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.maxpool2 = nn.MaxPool2d(2)
        # Block 3: 64 -> 64 -> 128 channels
        self.conv5 = nn.Conv2d(64, 64, 5, 1, 2)
        self.conv6 = nn.Conv2d(64, 128, 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(128)
        self.maxpool3 = nn.MaxPool2d(2)
        # Classifier head: 4*4*128 implies a 32x32 spatial input, halved by each of the three pools
        self.dense1 = nn.Linear(4 * 4 * 128, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dense2 = nn.Linear(256, 10)

    def forward(self, x):
        # Move the last axis to the channel position: (N, H, W, C) -> (N, C, H, W)
        x = torch.einsum("ijkl->iljk", [x])
        x = F.relu(self.conv1(x))
        # Pass training=self.training so functional dropout is disabled in eval mode
        x = F.dropout(self.maxpool1(F.relu(self.bn1(self.conv2(x)))), 0.25, training=self.training)
        x = F.relu(self.conv3(x))
        x = F.dropout(self.maxpool2(F.relu(self.bn2(self.conv4(x)))), 0.25, training=self.training)
        x = F.relu(self.conv5(x))
        x = F.dropout(self.maxpool3(F.relu(self.bn3(self.conv6(x)))), 0.25, training=self.training)
        # Flatten to (N, 4*4*128) for the fully connected layers
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.bn5(self.dense1(x)))
        # Note: if training with nn.CrossEntropyLoss, return the raw logits of dense2 instead of softmax
        x = F.softmax(self.dense2(x), dim=1)
        return x
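As a quick sanity check, here is a minimal usage sketch. It assumes 32x32 channels-last RGB input (e.g. CIFAR-10-sized images), which is what the einsum permutation and the 4*4*128 linear layer imply; the batch size of 8 and the tensor names are illustrative only.

# Minimal sanity check (assumed input shape: N x 32 x 32 x 3, channels-last)
model = CNN()
model.eval()  # disables dropout and uses BatchNorm running statistics

dummy = torch.randn(8, 32, 32, 3)  # batch of 8 fake images
with torch.no_grad():
    out = model(dummy)

print(out.shape)       # torch.Size([8, 10]) -- one score per class
print(out.sum(dim=1))  # each row sums to ~1 because of the final softmax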