|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from models import Net,NetConv |
|
|
|
# Load the trained model object. The checkpoint is a whole pickled module
# (it is called and .eval()'d directly below/in predict), which presumably is
# why `models` (Net, NetConv) is imported: unpickling needs those classes.
# NOTE(review): torch.load unpickles arbitrary objects -- only ever point this
# at a checkpoint you trust.
#
# map_location='cpu': predict() builds CPU tensors, so the model must live on
# the CPU too, even if it was saved from a GPU session.
try:
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module; opt out explicitly for this trusted local file.
    net = torch.load('mnist.pth', map_location='cpu', weights_only=False)
except TypeError:
    # Older torch releases do not understand the weights_only keyword.
    net = torch.load('mnist.pth', map_location='cpu')

# Switch to inference mode (affects layers such as dropout / batch norm).
net.eval()
|
|
|
def predict(img):
    """Return the model's two highest-scoring classes for a sketchpad drawing.

    Args:
        img: Value delivered by the ``gr.Sketchpad`` component, or ``None``
            when the canvas has been cleared / never drawn on. Anything
            ``np.array`` accepts works; values are divided by 255, so 0-255
            pixel intensities are presumably expected (TODO confirm against
            the Gradio version in use).

    Returns:
        A two-element list feeding the two ``gr.Label`` outputs: the index of
        the highest-scoring output and the runner-up, each as a string
        (for MNIST these are presumably the digits themselves). When ``img``
        is ``None`` both entries are ``None``, which clears the labels instead
        of raising.
    """
    if img is None:
        # The Clear button sets the sketchpad to None; np.array(None) / 255
        # would raise TypeError, so short-circuit with empty labels.
        return [None, None]

    pixels = np.array(img) / 255
    # Prepend a batch dimension of size 1 for the network.
    pixels = np.expand_dims(pixels, axis=0)
    batch = torch.from_numpy(pixels).float()

    # Pure inference: no need to record the autograd graph.
    with torch.no_grad():
        output = net(batch)

    _, topk_indices = torch.topk(output, 2)
    return [str(k) for k in topk_indices[0].tolist()]
|
|
|
# Build the UI: a 28x28 sketchpad with Predict / Clear buttons on the left and
# two labels showing predict()'s top-2 results on the right.
with gr.Blocks() as iface:
    gr.Markdown("# MNIST + Gradio End to End")
    gr.HTML("Shows end to end MNIST training with Gradio interface")

    with gr.Row():
        with gr.Column():
            # shape=(28, 28): Gradio 3.x crops/resizes the drawing to this
            # size before passing it to predict().
            sp = gr.Sketchpad(shape=(28, 28))
            with gr.Row():
                with gr.Column():
                    pred_button = gr.Button("Predict")
                with gr.Column():
                    clear = gr.Button("Clear")
        with gr.Column():
            label1 = gr.Label(label='1st Pred')
            label2 = gr.Label(label='2nd Pred')

    pred_button.click(predict, inputs=sp, outputs=[label1, label2])
    # Clear resets the canvas AND both prediction labels, so results for the
    # previous drawing are not left on screen next to an empty sketchpad.
    clear.click(
        lambda: (None, None, None),
        None,
        [sp, label1, label2],
        queue=False,
    )

iface.launch()