import pandas as pd
import torch.nn as nn

# Each of the four conv stages ends in a 2x2 max-pool that floor-halves H and W,
# and the 3x3 / padding=1 convolutions preserve spatial size, so the feature map
# leaving ``CNN.conv_layers`` is (image_size // 16) x (image_size // 16).
_TOTAL_DOWNSAMPLE = 2 ** 4
_FINAL_CHANNELS = 256


def _conv_block(in_channels, out_channels):
    """Return the layers of one conv stage as a flat list.

    Two (Conv3x3 -> ReLU -> BatchNorm) units followed by a 2x2 max-pool. A list
    (not a nested ``nn.Sequential``) is returned so callers can splat it into a
    single Sequential and keep the original module indices / state_dict keys.
    """
    return [
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(2),
    ]


class CNN(nn.Module):
    """Four-stage convolutional classifier producing K unnormalized class scores.

    Args:
        K: number of output classes (size of the final Linear layer).
        image_size: side length of the square RGB input the model is built for.
            Defaults to 224, which reproduces the original 50176-wide flattened
            feature vector (256 * 14 * 14). Inputs of shape
            (N, 3, image_size, image_size) are expected.

    Raises:
        ValueError: if ``image_size`` is smaller than the total downsampling
            factor (16), which would leave no spatial features to classify.
    """

    def __init__(self, K, image_size=224):
        super().__init__()
        spatial = image_size // _TOTAL_DOWNSAMPLE
        if spatial < 1:
            raise ValueError(
                f"image_size must be at least {_TOTAL_DOWNSAMPLE}, got {image_size}"
            )
        flat_features = _FINAL_CHANNELS * spatial * spatial

        # One flat Sequential so parameter names (conv_layers.0 ... conv_layers.27)
        # match checkpoints saved from the original hand-written layer list.
        self.conv_layers = nn.Sequential(
            *_conv_block(3, 32),     # conv1
            *_conv_block(32, 64),    # conv2
            *_conv_block(64, 128),   # conv3
            *_conv_block(128, _FINAL_CHANNELS),  # conv4
        )
        self.dense_layers = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(flat_features, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, K),
        )

    def forward(self, X):
        """Return raw class scores of shape (N, K) for an image batch X.

        X is expected to be (N, 3, image_size, image_size); no softmax is applied.
        """
        out = self.conv_layers(X)
        # Flatten per sample. Unlike view(-1, 50176), this never silently folds
        # extra spatial elements into the batch dimension: a wrongly sized input
        # now fails with a shape error in the first Linear layer instead of
        # returning the wrong number of rows.
        out = out.flatten(start_dim=1)
        # Fully connected head.
        return self.dense_layers(out)


# Maps the model's output index to its class label string.
idx_to_classes = {
    0: 'Apple___Apple_scab',
    1: 'Apple___Black_rot',
    2: 'Apple___Cedar_apple_rust',
    3: 'Apple___healthy',
    4: 'Background_without_leaves',
    5: 'Blueberry___healthy',
    6: 'Cherry___Powdery_mildew',
    7: 'Cherry___healthy',
    8: 'Corn___Cercospora_leaf_spot Gray_leaf_spot',
    9: 'Corn___Common_rust',
    10: 'Corn___Northern_Leaf_Blight',
    11: 'Corn___healthy',
    12: 'Grape___Black_rot',
    13: 'Grape___Esca_(Black_Measles)',
    14: 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
    15: 'Grape___healthy',
    16: 'Orange___Haunglongbing_(Citrus_greening)',
    17: 'Peach___Bacterial_spot',
    18: 'Peach___healthy',
    19: 'Pepper,_bell___Bacterial_spot',
    20: 'Pepper,_bell___healthy',
    21: 'Potato___Early_blight',
    22: 'Potato___Late_blight',
    23: 'Potato___healthy',
    24: 'Raspberry___healthy',
    25: 'Soybean___healthy',
    26: 'Squash___Powdery_mildew',
    27: 'Strawberry___Leaf_scorch',
    28: 'Strawberry___healthy',
    29: 'Tomato___Bacterial_spot',
    30: 'Tomato___Early_blight',
    31: 'Tomato___Late_blight',
    32: 'Tomato___Leaf_Mold',
    33: 'Tomato___Septoria_leaf_spot',
    34: 'Tomato___Spider_mites Two-spotted_spider_mite',
    35: 'Tomato___Target_Spot',
    36: 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
    37: 'Tomato___Tomato_mosaic_virus',
    38: 'Tomato___healthy',
}