|
import torch.nn as nn |
|
import torch |
|
|
|
class LeNNon(nn.Module):
    """Define a CNN architecture used for image classification.

    This class defines the LeNNon architecture as a PyTorch module: two
    ``conv(3x3, stride 1, padding 1) -> ReLU -> maxpool(2x2)`` stages
    followed by two fully connected layers. Each convolution preserves the
    spatial size and each pooling halves it (rounding down), so the feature
    map entering ``fc1`` has side ``image_size // 4``.

    Parameters
    ----------
    num_classes : int, default 10
        Number of outputs produced by ``fc2``.
    in_channels : int, default 3
        Number of channels of the input images.
    image_size : int, default 100
        Height and width of the (square) input images. The defaults
        reproduce the original fixed configuration (3x100x100 inputs,
        ``fc1`` with ``32 * 25 * 25`` input features).

    Raises
    ------
    ValueError
        If ``image_size`` is smaller than 4 (the flattened feature map
        would be empty).
    """

    def __init__(
        self,
        num_classes: int = 10,
        in_channels: int = 3,
        image_size: int = 100,
    ):
        super().__init__()

        if image_size < 4:
            raise ValueError(
                f"image_size must be at least 4, got {image_size}"
            )

        # Two 2x2 max-pools with floor rounding: floor(floor(n/2)/2) == n//4.
        side = image_size // 4

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)

        # A single pooling module is reused after both conv layers.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.Linear(32 * side * side, 128)

        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform a forward pass through the LeNNon architecture.

        This method applies the convolutional layers, max pooling layers,
        and fully connected layers to the input tensor x.

        Parameters
        ----------
        x : torch.Tensor
            Input images of shape ``(N, C, H, W)``, or a single unbatched
            image of shape ``(C, H, W)``. ``C`` must equal ``in_channels``
            and ``H == W == image_size``.

        Returns
        -------
        torch.Tensor
            Raw scores (no softmax applied) of shape ``(N, num_classes)``;
            an unbatched input yields shape ``(1, num_classes)``.

        Raises
        ------
        RuntimeError
            If the spatial size of ``x`` does not match ``image_size``.
        """
        x = self.pool(torch.relu(self.conv1(x)))

        x = self.pool(torch.relu(self.conv2(x)))

        # Flatten only channel + spatial dims so the batch dimension is never
        # silently reshaped (view(-1, ...) would turn e.g. a 200x200 input
        # into 4x as many "samples"). flatten also copes with non-contiguous
        # tensors, unlike view.
        x = torch.flatten(x, start_dim=-3)
        if x.dim() == 1:
            # Unbatched (C, H, W) input: treat it as a batch of one.
            x = x.unsqueeze(0)

        if x.shape[-1] != self.fc1.in_features:
            raise RuntimeError(
                f"LeNNon expected {self.fc1.in_features} flattened features "
                f"but got {x.shape[-1]}; check that the input spatial size "
                f"matches image_size"
            )

        x = torch.relu(self.fc1(x))

        x = self.fc2(x)

        return x