# makiisthebes's picture
# Upload 9 files
# 61f0100 verified
# raw
# history blame
# 2.47 kB
import idx2numpy, torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from PIL import Image
class ApplyEnhancementFilter:
    """Sharpen an image with a fixed 3x3 kernel; usable as a torchvision transform.

    Calling an instance on a PIL image convolves every channel of the image
    independently with the (untrained) kernel

        [[ 0, -1,  0],
         [-1,  5, -1],
         [ 0, -1,  0]]

    clamps the result to [0, 1] and returns it as a PIL image.

    NOTE(review): the constructor overwrites the convolution weight with a
    single-output, single-input 3x3 kernel, so ``out_channels`` and
    ``kernel_size`` are stored and passed to nn.Conv2d but do not change the
    kernel that is actually applied. The bias (when ``bias=True``) is never
    set here and keeps whatever nn.Conv2d initialised it to.
    """

    def __init__(self, out_channels, kernel_size, stride=1, padding=0, bias=False):
        """
        Build the convolution and install the sharpening kernel.

        Args:
            out_channels: Forwarded to nn.Conv2d; the installed kernel has a
                single output channel regardless.
            kernel_size: Forwarded to nn.Conv2d; the installed kernel is 3x3
                regardless.
            stride: Convolution stride.
            padding: Zero padding on each spatial side. With the defaults
                (padding 0, stride 1) the output is 2 pixels smaller than the
                input in height and width.
            bias: Whether nn.Conv2d allocates a bias term.
        """
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias
        # One input channel: __call__ feeds each image channel through this
        # convolution separately, so grayscale and RGB images both work.
        self.conv = nn.Conv2d(in_channels=1,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              bias=bias)
        # Sharpening kernel: identity plus a negated 4-neighbour Laplacian.
        # The weights sum to 1, so flat regions keep their intensity while
        # intensity changes (edges) are amplified. It is not a pure edge detector.
        sharpen_kernel = torch.tensor([[0., -1., 0.],
                                       [-1., 5., -1.],
                                       [0., -1., 0.]]).unsqueeze(0).unsqueeze(0)
        # Shape (1, 1, 3, 3): one output channel, one input channel.
        self.conv.weight = nn.Parameter(sharpen_kernel.float())

    def _filter_tensor(self, img_tensor):
        """
        Sharpen a (C, H, W) float tensor channel by channel.

        Returns a (C, H', W') tensor clamped to [0, 1] that carries no
        autograd history.
        """
        # Treat each channel as its own batch item, (C, H, W) -> (C, 1, H, W),
        # to match the convolution's single input channel.
        with torch.no_grad():  # fixed kernel: no gradient graph is needed
            filtered = self.conv(img_tensor.unsqueeze(1))
        # The kernel can push values below 0 or above 1. Clamp them so the
        # conversion back to an 8-bit image stays in range instead of wrapping.
        return filtered.squeeze(1).clamp(0.0, 1.0)

    def __call__(self, img):
        """
        Sharpen a PIL image and return the result as a PIL image.

        Each channel of ``img`` is filtered independently, so both grayscale
        and multi-channel (e.g. RGB) images are accepted.
        """
        # PIL image -> (C, H, W) float tensor.
        img_tensor = transforms.functional.to_tensor(img)
        return transforms.functional.to_pil_image(self._filter_tensor(img_tensor))
if __name__ == "__main__":
    # Inputs are z-score normalised (via transforms) so that no input feature
    # affects the model more than another purely because of its raw magnitude;
    # raw pixel values range from 0 to 255. Normalising requires the mean and
    # standard deviation of the training set, which are computed here.
    raw_pixels = idx2numpy.convert_from_file("mnist_dataset/train-images.idx3-ubyte")
    # Rescale to [0, 1] as float32 before taking the statistics.
    scaled_pixels = torch.tensor(raw_pixels, dtype=torch.float32).div(255.0)
    mean_value, std_value = scaled_pixels.mean(), scaled_pixels.std()
    print(f"Mean: {mean_value}, Std: {std_value}")
    # Recorded output: Mean: 0.13066047430038452, Std: 0.30810782313346863