from collections import OrderedDict

import torch.nn as nn


class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier with 10 output classes.

    Layer shapes for a batch of 1x28x28 inputs (channels x height x width):

        Input - 1x28x28
        C1    - 6@24x24   5x5 conv, tanh
        S2    - 6@13x13   2x2 max pool, stride 2, padding 1
        C3    - 16@9x9    5x5 conv, tanh (every C3 map sees all 6 S2 maps;
                          the original LeNet-5 used a sparse connection
                          table here)
        S4    - 16@5x5    2x2 max pool, stride 2, padding 1
        C5    - 120@1x1   5x5 conv, tanh
        F6    - 84        linear, tanh
        F7    - 10        linear, log-softmax (output)

    The ``padding=1`` on both pooling layers is what reduces a 28x28 input
    to 1x1 at C5. A classic 32x32 input would instead produce 120@2x2 at C5
    (480 flattened features), which does not match F6's 120 inputs and
    raises a shape error.
    """

    def __init__(self):
        super().__init__()
        self.convnet = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
            ('tanh1', nn.Tanh()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)),
            ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('tanh3', nn.Tanh()),
            ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)),
            ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
            ('tanh5', nn.Tanh()),
        ]))
        self.fc = nn.Sequential(OrderedDict([
            ('f6', nn.Linear(120, 84)),
            ('tanh6', nn.Tanh()),
            ('f7', nn.Linear(84, 10)),
            # Despite the 'sig7' key, this is a log-softmax. The key is kept
            # unchanged so existing references such as ``model.fc.sig7``
            # keep working.
            ('sig7', nn.LogSoftmax(dim=-1)),
        ]))

    def _flat_conv_features(self, img):
        """Run the convolutional stack and flatten all non-batch dims.

        For a batch of 1x28x28 images the result has shape (batch, 120).
        """
        output = self.convnet(img)
        # Keep img's leading dimension as the batch size; C5's 120@1x1
        # output collapses to 120 features per sample.
        return output.view(img.size(0), -1)

    def forward(self, img):
        """Return per-class log-probabilities of shape (batch, 10).

        ``img`` must already have the same dtype as the model's parameters;
        unlike ``extract_features``, no cast is applied here.
        """
        return self.fc(self._flat_conv_features(img))

    def extract_features(self, img):
        """Return the 84-dimensional tanh activations of layer F6.

        ``img`` is first cast to the dtype of the model's parameters
        (float32 by default), so integer image tensors are accepted. A
        model converted with ``.double()`` or ``.half()`` also works,
        which a hard-coded ``float()`` cast would break.

        Returns a tensor of shape (batch, 84).
        """
        img = img.to(dtype=self.convnet.c1.weight.dtype)
        output = self._flat_conv_features(img)
        # Stop after F6 + tanh, before F7 and the log-softmax.
        return self.fc.tanh6(self.fc.f6(output))