import idx2numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from PIL import Image


class ApplyEnhancementFilter:
    """Callable image transform that applies a fixed 3x3 sharpening convolution.

    The instance can be used like any other transform object: call it with a
    PIL image and it returns a filtered PIL image.

    NOTE: the hand-written 3x3 kernel replaces the weight of the underlying
    ``nn.Conv2d`` layer, so the filter that is actually applied always has
    shape ``(1, 1, 3, 3)``. ``out_channels`` and ``kernel_size`` are stored and
    used to construct the layer, but they do not change the applied kernel.
    ``stride`` and ``padding`` are honoured as given.
    NOTE(review): with ``bias=True`` the bias keeps ``out_channels`` entries,
    which presumably only matches the single-output kernel when
    ``out_channels == 1`` -- confirm before enabling the bias.
    """

    def __init__(self, out_channels, kernel_size, stride=1, padding=0, bias=False):
        """Store the convolution parameters and build the fixed filter.

        Args:
            out_channels: Number of output channels requested for the layer.
            kernel_size: Kernel size requested for the layer.
            stride: Stride of the convolution.
            padding: Zero-padding added to both sides of the input. With the
                default of 0 a 3x3 kernel shrinks each spatial dimension by 2.
            bias: Whether the layer carries a (randomly initialised) bias term.
        """
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        # Define the convolutional layer (not trained here).
        self.conv = nn.Conv2d(
            in_channels=1,  # Adjust to the image channels (1 for grayscale, 3 for RGB).
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

        # Manually defined kernel: centre weight 5 with -1 on the four direct
        # neighbours (weights sum to 1), i.e. a sharpening filter. For a real
        # use-case the weights would be learned or chosen per filter.
        sharpen_kernel = torch.tensor(
            [
                [0.0, -1.0, 0.0],
                [-1.0, 5.0, -1.0],
                [0.0, -1.0, 0.0],
            ]
        ).unsqueeze(0).unsqueeze(0)  # -> shape (1, 1, 3, 3)

        # The filter is fixed, so it does not need gradients.
        self.conv.weight = nn.Parameter(sharpen_kernel.float(), requires_grad=False)

    def __call__(self, img):
        """Apply the convolution to ``img`` and return the result as a PIL image.

        Args:
            img: Input image accepted by ``transforms.functional.to_tensor``.
                It must have a single channel, because the layer is built with
                ``in_channels=1``.

        Returns:
            The filtered image as a PIL image.
        """
        # Convert the PIL image to a tensor and add the batch dimension.
        img_tensor = transforms.functional.to_tensor(img).unsqueeze(0)

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            conv_img = self.conv(img_tensor)

        # The sharpening kernel can push values below 0 or above 1. Clamp them
        # so the float -> 8-bit conversion in to_pil_image cannot wrap around
        # and produce speckle artefacts.
        conv_img = conv_img.squeeze(0).clamp(0.0, 1.0)

        # Batch dimension removed above; convert back to a PIL image for
        # further transformations or visualisation.
        return transforms.functional.to_pil_image(conv_img)


if __name__ == "__main__":
    # It is important to normalise the dataset, so that no specific input
    # affects the model more than another based purely on its input values.
    # As values can range from 0-255, this can cause problems, so z-score
    # normalisation will be used via Transforms.
    # First we need the mean and standard deviation of the train dataset.
    train_images = idx2numpy.convert_from_file("mnist_dataset/train-images.idx3-ubyte")

    # Convert the training images to a PyTorch tensor and scale values to [0, 1].
    train_images_tensor = torch.tensor(train_images, dtype=torch.float32) / 255.0

    # Statistics over every pixel of every training image.
    train_mean = train_images_tensor.mean()
    train_std = train_images_tensor.std()

    print(f"Mean: {train_mean}, Std: {train_std}")
    # Mean: 0.13066047430038452, Std: 0.30810782313346863