Edit model card

Model class definition (requires `import torch`, `import torch.nn as nn`, and `import torch.nn.functional as F`):

class CNN(nn.Module):
    """Six-conv / two-dense image classifier taking channels-last 32x32 input.

    Three stages of (conv -> ReLU -> conv -> BatchNorm -> ReLU -> 2x2 max-pool
    -> dropout) grow 3 input channels to 128 feature maps. Two fully connected
    layers then map the flattened 4*4*128 features to class probabilities.

    Args:
        num_classes: Size of the output layer. Defaults to 10, the value the
            original definition hard-coded.
        dropout_p: Drop probability applied after each pooling stage.
            Defaults to 0.25, the value the original definition hard-coded.

    Submodule attribute names (``conv1``..``conv6``, ``bn1``..``bn3``, ``bn5``,
    ``dense1``, ``dense2``) are unchanged so existing ``state_dict``
    checkpoints still load.
    """

    def __init__(self, num_classes: int = 10, dropout_p: float = 0.25):
        super().__init__()
        # Every conv is 5x5, stride 1, padding 2, which preserves height and
        # width; only the max-pools halve the spatial size.
        self.conv1 = nn.Conv2d(3, 16, 5, 1, 2)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(32, 32, 5, 1, 2)
        self.conv4 = nn.Conv2d(32, 64, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(64)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv5 = nn.Conv2d(64, 64, 5, 1, 2)
        self.conv6 = nn.Conv2d(64, 128, 5, 1, 2)
        self.bn3 = nn.BatchNorm2d(128)
        self.maxpool3 = nn.MaxPool2d(2)

        # After three 2x pools the feature map must be 4x4, so the model
        # expects 32x32 input (32 -> 16 -> 8 -> 4).
        self.dense1 = nn.Linear(4 * 4 * 128, 256)
        # Name kept as "bn5" (there is no bn4) for checkpoint compatibility.
        self.bn5 = nn.BatchNorm1d(256)
        self.dense2 = nn.Linear(256, num_classes)

        # Plain float (not a Parameter/buffer), so it is not saved in state_dict.
        self.dropout_p = dropout_p

    def _stage(self, x, conv_a, conv_b, bn, pool):
        """Apply one conv -> ReLU -> conv -> BN -> ReLU -> pool -> dropout stage."""
        x = F.relu(conv_a(x))
        x = F.relu(bn(conv_b(x)))
        # F.dropout defaults to training=True; tie it to the module's mode so
        # model.eval() disables dropout and predictions become deterministic.
        return F.dropout(pool(x), self.dropout_p, training=self.training)

    def forward(self, x):
        """Classify a batch of channels-last images.

        Args:
            x: Tensor presumably of shape (batch, height, width, 3). The last
                axis is moved into the channel position for ``conv1``, which
                takes 3 input channels. Height and width must pool down to
                4x4, i.e. 32x32 input.

        Returns:
            Tensor of shape (batch, num_classes) holding softmax probabilities
            (each row sums to 1), not logits.
        """
        # (N, H, W, C) -> (N, C, H, W); same permutation as einsum "ijkl->iljk".
        x = x.permute(0, 3, 1, 2)
        x = self._stage(x, self.conv1, self.conv2, self.bn1, self.maxpool1)
        x = self._stage(x, self.conv3, self.conv4, self.bn2, self.maxpool2)
        x = self._stage(x, self.conv5, self.conv6, self.bn3, self.maxpool3)
        x = x.flatten(1)
        x = F.relu(self.bn5(self.dense1(x)))
        # NOTE(review): the output is probabilities. nn.CrossEntropyLoss expects
        # unnormalized logits, so training this output with it would apply
        # softmax twice; use nn.NLLLoss on log-probabilities instead if so.
        return F.softmax(self.dense2(x), dim=1)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs.

Dataset used to train Sparklizm/CIFAR10_0.8048