Spaces:
Running
Running
File size: 1,507 Bytes
17c015f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import transforms
class simpleCNN(nn.Module):
    """Small two-convolution CNN classifier for 3-channel 32x32 images.

    Layout: conv(3->5, k=5) -> ReLU -> maxpool(2) -> conv(5->10, k=5) -> ReLU
    -> maxpool(2) -> flatten -> fc(250->32) -> ReLU -> fc(32->num_classes).

    The attribute names ``conv1``, ``conv2``, ``fc1`` and ``fc2`` are the
    parameter keys used by ``load_state_dict``, so renaming them would break
    loading of existing checkpoints.
    """

    def __init__(self, num_classes=3):
        """Build the layers.

        Args:
            num_classes: number of output logits produced by ``forward``.
        """
        super().__init__()
        self.name = "simpleCNN"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        # 32x32 input -> 28 (conv1) -> 14 (pool) -> 10 (conv2) -> 5 (pool),
        # so the flattened feature size is 10 channels * 5 * 5.
        self.fc1 = nn.Linear(10 * 5 * 5, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits.

        Args:
            x: float tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, num_classes).

        Raises:
            RuntimeError: if the spatial size does not flatten to the 250
                features ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample while keeping the batch dimension. A
        # view(-1, 250) would silently regroup a wrongly sized input into a
        # different batch size; flattening makes fc1 reject it instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Output index -> human-readable class name; the order must match the
# network's output units (predict pairs them up by position).
class_labels = ["other", "car", "truck"]

# Instantiate the classifier and restore its trained weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (newer PyTorch versions offer weights_only=True for this).
net = simpleCNN(num_classes=len(class_labels))
net.load_state_dict(
    torch.load("./ckpt.pth", map_location="cpu")
)
net.eval()  # switch the module to inference mode

# Preprocessing applied to every input image before it reaches the network.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # spatial size that fc1 is dimensioned for
    transforms.ToTensor(),        # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),  # [0, 1] -> [-1, 1]
])
@torch.no_grad()
def predict(img):
    """Classify one image and return per-class probabilities.

    Args:
        img: image array supplied by the Gradio ``"image"`` input —
            presumably an (H, W, C) or (H, W) numpy array; TODO confirm for
            the Gradio version in use.

    Returns:
        dict mapping each name in ``class_labels`` to its softmax
        probability as a Python float.
    """
    # Let PIL infer the mode from the array shape and then convert, so that
    # grayscale or RGBA arrays become proper 3-channel RGB instead of being
    # misread under a forced "RGB" mode.
    pil_img = Image.fromarray(img.astype("uint8")).convert("RGB")
    batch = transform(pil_img).unsqueeze(0)
    # torch.softmax avoids the overflow to inf/nan that exp(logits)/sum
    # suffers for large logits.
    probs = torch.softmax(net(batch), dim=1)[0]
    return {label: float(p) for label, p in zip(class_labels, probs)}
# Web UI: an image upload feeding predict, rendered as a label/confidence view.
iface = gr.Interface(fn=predict, inputs="image", outputs="label")

if __name__ == "__main__":
    # Launch only when executed as a script (e.g. `python app.py`), so that
    # importing this module for tests or tooling does not start the server.
    iface.launch()
|