|
import idx2numpy, torch |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import transforms, datasets |
|
from PIL import Image |
|
|
|
|
|
class ApplyEnhancementFilter:
    """Callable image transform that sharpens a single-channel image.

    The transform wraps an ``nn.Conv2d`` layer (``in_channels=1``) whose
    weight is replaced by a fixed 3x3 sharpening kernel, so it can be dropped
    into a ``transforms.Compose`` pipeline ahead of ``ToTensor``.

    Note:
        Because the layer's weight is overwritten with a ``(1, 1, 3, 3)``
        kernel, the filter that is actually applied is always that single
        3x3 kernel; ``out_channels`` and ``kernel_size`` are still stored and
        passed to ``nn.Conv2d`` but do not change the kernel. ``stride`` and
        ``padding`` are honoured (use ``padding=1`` to keep the image size).
    """

    def __init__(self, out_channels, kernel_size, stride=1, padding=0, bias=False):
        """
        Initialize the convolution parameters.

        Args:
            out_channels: Forwarded to ``nn.Conv2d`` and stored on the
                instance; see the class note about the fixed kernel.
            kernel_size: Forwarded to ``nn.Conv2d`` and stored on the
                instance; see the class note about the fixed kernel.
            stride: Stride of the convolution.
            padding: Zero-padding added to each side of the input.
            bias: Whether the layer carries a bias term. When ``True`` the
                bias is whatever ``nn.Conv2d`` initialises it to, so the
                result is no longer a pure sharpening filter.
        """
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        self.conv = nn.Conv2d(in_channels=1,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              bias=bias)

        # Sharpening kernel: boost the centre pixel and subtract its four
        # direct neighbours. The coefficients sum to 1, so flat regions keep
        # their brightness while edges are accentuated.
        sharpen_kernel = torch.tensor([[0., -1., 0.],
                                       [-1., 5., -1.],
                                       [0., -1., 0.]]).unsqueeze(0).unsqueeze(0)
        self.conv.weight = nn.Parameter(sharpen_kernel.float())

    def __call__(self, img):
        """
        Apply the convolution transformation.

        Args:
            img: Image accepted by ``transforms.functional.to_tensor``. It
                must be single-channel, because the convolution is built
                with ``in_channels=1``.

        Returns:
            The sharpened image as a PIL image.
        """
        # to_tensor yields (C, H, W); the convolution expects a batch axis.
        img_tensor = transforms.functional.to_tensor(img).unsqueeze(0)

        # This is a fixed preprocessing filter, not a trained layer, so there
        # is no reason to record an autograd graph for every image.
        with torch.no_grad():
            conv_img = self.conv(img_tensor)

        # Sharpening overshoots the [0, 1] range around edges. to_pil_image
        # scales floats by 255 and casts to uint8 without clamping, so
        # out-of-range values would wrap around instead of saturating.
        conv_img = conv_img.squeeze(0).clamp(0.0, 1.0)

        return transforms.functional.to_pil_image(conv_img)
|
|
|
|
|
if __name__ == "__main__":
    # Read the raw MNIST training images from the idx3-ubyte file.
    raw_images = idx2numpy.convert_from_file("mnist_dataset/train-images.idx3-ubyte")

    # Convert to float32 and rescale by 1/255 before taking statistics.
    scaled_images = torch.tensor(raw_images, dtype=torch.float32).div(255.0)

    # Global statistics over every pixel of every image.
    train_mean = torch.mean(scaled_images)
    train_std = torch.std(scaled_images)

    print(f"Mean: {train_mean}, Std: {train_std}")
|
|