|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
class Base(nn.Module):
    """Shared training/validation helpers for image classifiers.

    Subclasses must implement ``forward`` so that ``self(images)`` returns
    unnormalised class scores, because they are fed straight into
    ``F.cross_entropy``.
    """

    def _forward_batch(self, batch):
        """Run the model on one ``(images, labels)`` batch.

        Returns:
            ``(scores, labels, loss)`` where ``loss`` is the cross-entropy
            between the model's ``scores`` and ``labels``.
        """
        images, labels = batch
        scores = self(images)
        return scores, labels, F.cross_entropy(scores, labels)

    def training_step(self, batch):
        """Return the (differentiable) cross-entropy loss for a training batch."""
        _, _, loss = self._forward_batch(batch)
        return loss

    def validation_step(self, batch):
        """Compute loss and accuracy for one validation batch.

        Returns:
            ``{'val_loss': ..., 'val_acc': ...}``. The loss is detached from
            the autograd graph. The accuracy is whatever the module-level
            ``accuracy`` helper returns.
        """
        scores, labels, loss = self._forward_batch(batch)
        return {'val_loss': loss.detach(), 'val_acc': accuracy(scores, labels)}

    def validation_epoch_end(self, outputs):
        """Average per-batch validation metrics into epoch-level Python floats.

        Every batch is weighted equally regardless of its size, so a short
        final batch counts as much as a full one.

        Args:
            outputs: re-iterable sequence of ``validation_step`` results.

        Raises:
            ValueError: if ``outputs`` is empty (there is nothing to average,
                and ``torch.stack`` cannot stack an empty list).
        """
        if not outputs:
            raise ValueError("validation_epoch_end() received no batch outputs")
        epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print a one-line epoch summary.

        ``result`` must provide ``'train_loss'``, ``'val_loss'`` and
        ``'val_acc'`` (the last two come from ``validation_epoch_end``; the
        caller is expected to add ``'train_loss'``).
        """
        print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['train_loss'], result['val_loss'], result['val_acc']))
|
|
|
|
|
|
|
def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max along ``dim=1`` equals the label.

    Args:
        outputs: score tensor with classes along ``dim=1`` (presumably
            ``(batch, num_classes)`` — confirm against callers).
        labels: tensor of target class indices, one per row of ``outputs``.

    Returns:
        A 0-dim ``float32`` tensor in ``[0, 1]`` on the same device as
        ``outputs``.

    Raises:
        ValueError: if the batch is empty (accuracy is undefined).
    """
    if outputs.shape[0] == 0:
        raise ValueError("accuracy() is undefined for an empty batch")
    preds = outputs.argmax(dim=1)
    # Average on-device instead of calling .item(): that would force a
    # device->host sync on every call and always return a CPU tensor.
    return (preds == labels).float().mean()
|
|
|
|
|
class PotatoDiseaseDetectionModel(Base):
    """VGG-style CNN image classifier using the training helpers from ``Base``.

    The feature extractor has three stages. Each stage is two padded 3x3
    conv -> BatchNorm2d -> ReLU layers followed by a 2x2, stride-2 max-pool,
    and the stages grow the channels 64 -> 128 -> 256. The flattened feature
    maps feed an MLP head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of scores produced per image.
        input_size: spatial size of the input images, either an ``int`` for
            square images or a ``(height, width)`` pair. The first classifier
            layer's width depends on it. The default of 224 gives the
            ``256 * 28 * 28`` width previously hard-coded here.

    Raises:
        ValueError: if either spatial dimension is below 8, which would leave
            no features after the three max-pools.
    """

    def __init__(self, in_channels=3, num_classes=3, input_size=224):
        super().__init__()

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size
        # The padded 3x3 convs keep the spatial size; each of the three
        # 2x2/stride-2 max-pools floors it by half, i.e. size // 8 overall.
        feat_h, feat_w = height // 8, width // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input_size must be at least 8 in each dimension, got {!r}".format(input_size))

        # Layer order is kept identical so state_dict keys (network.<i>.*)
        # stay compatible with existing checkpoints.
        self.network = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Flatten(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=256 * feat_h * feat_w, out_features=128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features=128, out_features=num_classes),
        )

    def forward(self, x):
        """Return one row of class scores per image in ``x``.

        ``x`` is expected to be ``(batch, in_channels, height, width)``, with
        height/width matching ``input_size``. In training mode, BatchNorm1d
        in the head needs more than one sample per batch.
        """
        x = self.network(x)
        return self.classifier(x)
|
|
|
|
|
# Module-level default instance: RGB input (3 channels), 3 output classes.
model = PotatoDiseaseDetectionModel(in_channels=3, num_classes=3)