import torch

try:
    # Nothing in this module references torchvision. The import is kept for
    # any code that expects the name to exist here, but it is optional so the
    # model definition can still be imported where torchvision is missing.
    import torchvision  # noqa: F401
except ImportError:
    torchvision = None

from torch import nn


def _conv_block(in_channels: int, mid_channels: int, out_channels: int) -> nn.Sequential:
    """Build one Conv-ReLU-Conv-BatchNorm-ReLU-MaxPool-Dropout stage.

    Both convolutions are 3x3, stride 1, without padding, so each one trims
    two pixels from the height and width; the 2x2 max-pool then halves them.
    """
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=3,
            stride=1,
            padding=0,
        ),
        nn.ReLU(),
        nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=0,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Dropout(p=0.25),
    )


class TinyCNN(nn.Module):
    """Four-stage convolutional network with a fully connected classifier head.

    The trunk is four convolutional stages (``conv_block_1`` .. ``conv_block_4``)
    ending with 512 channels. Its output is brought to a fixed 4x4 grid,
    flattened, and passed through two hidden linear layers (512 and 256 units)
    followed by a final linear layer that emits ``output_shape`` raw scores
    (no softmax is applied).

    Args:
        input_shape: Number of channels of the input images. Despite the
            name this is a channel count, not a spatial shape.
        hidden_units: Number of output channels of the very first
            convolution only; every other layer width is fixed.
        output_shape: Number of values produced per sample by the final
            linear layer.

    Note:
        The head expects 512 * 4 * 4 features. Inputs of 124-139 pixels per
        side (for example 128x128) naturally yield a 4x4 feature map, and for
        those the adaptive pooling step is an exact identity. Other sizes are
        resampled to 4x4 instead of failing with a shape mismatch. Inputs
        smaller than 76 pixels per side are still rejected because the
        unpadded convolutions run out of pixels.
    """

    # Spatial size the classifier head was designed for (512 * 4 * 4 features).
    _FEATURE_MAP_SIZE = (4, 4)

    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
        super().__init__()

        self.conv_block_1 = _conv_block(input_shape, hidden_units, 128)
        self.conv_block_2 = _conv_block(128, 128, 128)
        self.conv_block_3 = _conv_block(128, 128, 512)
        self.conv_block_4 = _conv_block(512, 512, 512)

        # Parameter-free, so it adds nothing to the state_dict. It is kept
        # outside ``fc_1`` so the layer indices inside that Sequential (and
        # therefore existing checkpoint keys) stay the same.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(self._FEATURE_MAP_SIZE)

        pooled_h, pooled_w = self._FEATURE_MAP_SIZE
        self.fc_1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=512 * pooled_h * pooled_w, out_features=512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(p=0.25),
        )
        self.fc_2 = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.25),
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=256, out_features=output_shape),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Image batch, expected to be shaped
                ``(batch, input_shape, height, width)``.

        Returns:
            Tensor of shape ``(batch, output_shape)`` holding raw scores.
        """
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        x = self.conv_block_4(x)
        # Identity when the trunk already produced a 4x4 map; otherwise
        # resamples so the flattened size matches what fc_1 expects.
        x = self.adaptive_pool(x)
        x = self.fc_1(x)
        x = self.fc_2(x)
        x = self.classifier(x)
        return x