import torch import gradio as gr from torchvision import transforms from PIL import ImageOps import os from dotenv import load_dotenv from torch import nn import torch.nn.functional as F class Lenet(nn.Module): def __init__(self, args=None): super().__init__() self.conv1 = nn.Conv2d(1, 6, 5, padding=2) # -> 6 channels, 28x28 self.pool1 = nn.MaxPool2d(2) # -> 6 channels, 14x14 self.conv2 = nn.Conv2d(6, 16, 5) #-> 16 images, 10x10 self.pool2 = nn.MaxPool2d(2) # -> 16 channels, 5x5 self.conv3 = nn.Conv2d(16, 120, 5) #-> 16 images, 1x1 self.fc1 = nn.Linear(120, 84) self.fc2 = nn.Linear(84, 10) def __call__(self, x): xx = F.relu(self.conv1(x)) xx = F.relu(self.pool1(xx)) xx = F.relu(self.conv2(xx)) xx = F.relu(self.pool2(xx)) xx = F.relu(self.conv3(xx)) xx = xx.flatten(1) xx = F.relu(self.fc1(xx)) return self.fc2(xx) load_dotenv() hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging") def load_model(): model = Lenet() model.load_state_dict(torch.load('model.pt')) model.eval() return model model = load_model() convert_tensor = transforms.ToTensor() def predict(img): img = ImageOps.grayscale(img).resize((28,28)) image_tensor = convert_tensor(img).view(1, 1, 28, 28) logits = model(image_tensor) pred = torch.argmax(logits, dim=1) return pred.tolist()[0] title = "Handwritten digit recognition" description = '