import torch
from torch import nn


class NeuralNetwork(nn.Module):
    """Convolutional image classifier: four downsampling conv stages plus an MLP head.

    Each stage is ``Conv(stride=1) -> BatchNorm -> ReLU -> Conv(stride=2)``,
    with ``Dropout2d`` between stages. Every stride-2 convolution
    (kernel 3, padding 1) maps a spatial size ``n`` to ``(n - 1) // 2 + 1``,
    so the default 64x64 input reaches the head as a 256x4x4 feature map.

    Args:
        num_classes: Number of output logits. Defaults to 151, the value the
            original model hard-coded.
        in_channels: Number of channels in the input image. Defaults to 3.
        input_size: Spatial size of the input, either an ``int`` (square input)
            or a ``(height, width)`` pair. Defaults to 64, which reproduces the
            original ``256 * 4 * 4`` flatten size.

    Raises:
        ValueError: If ``input_size`` has a non-positive dimension.

    The order of layers inside ``self.conv`` and ``self.fc`` is deliberately
    unchanged so that ``state_dict`` keys (``conv.0.weight`` ... ``fc.4.bias``)
    match checkpoints produced by the original definition.
    """

    # Number of stride-2 convolutions in ``self.conv``; used to derive the
    # spatial size of the feature map fed to the fully connected head.
    _NUM_DOWNSAMPLES = 4

    def __init__(self, num_classes: int = 151, in_channels: int = 3, input_size=64):
        super().__init__()
        # NOTE(review): the stride-2 convs below are not followed by BatchNorm
        # or ReLU, so each one composes linearly with the next conv (and the
        # last one with the first Linear layer). Adding activations would shift
        # the Sequential indices and break existing checkpoints, so it is left
        # as-is here — revisit when retraining from scratch.
        self.conv = nn.Sequential(
            # Stage 1: in_channels -> 64 channels, then halve H and W.
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.Dropout2d(p=0.3),
            # Stage 2: 64 -> 128 channels, then halve H and W.
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.Dropout2d(p=0.3),
            # Stage 3: 128 -> 256 channels, then halve H and W.
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.Dropout2d(p=0.4),
            # Stage 4: 256 -> 256 channels, then halve H and W.
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
        )
        height, width = self._feature_map_size(input_size)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(256 * height * width, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, num_classes),
        )

    @classmethod
    def _feature_map_size(cls, input_size):
        """Return the (height, width) of ``self.conv``'s output for ``input_size``."""
        if isinstance(input_size, (tuple, list)):
            height, width = input_size
        else:
            height = width = input_size
        if height < 1 or width < 1:
            raise ValueError(f"input_size must be positive, got {input_size!r}")
        for _ in range(cls._NUM_DOWNSAMPLES):
            # Conv2d output size: (n + 2*padding - kernel_size) // stride + 1
            # with kernel_size=3, stride=2, padding=1.
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1
        return height, width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Args:
            x: Image batch, assumed to be ``(N, in_channels, H, W)`` with
                ``H``/``W`` matching the ``input_size`` given at construction
                (otherwise the first Linear layer's shape will not match).

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        out = self.conv(x)
        out = self.flatten(out)
        out = self.fc(out)
        return out