File size: 1,372 Bytes
c80d125 2b25469 c726874 2b25469 c80d125 c726874 c80d125 c912b6b c726874 c80d125 c726874 c912b6b c726874 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from PIL import Image
import gradio as gr
import torch
import torchvision.transforms as transforms
from model import *
# Text shown in the UI and the Gradio components used by the classifier.
title = "Garment Classifier"
description = (
    "Trained on the Fashion MNIST dataset (28x28 pixels). "
    "The model expects images containing only one garment article as in the examples."
)
# `gr.Image` / `gr.Label` are the same classes as
# `gr.components.Image` / `gr.components.Label`.
inputs = gr.Image()
outputs = gr.Label()
# Path of the examples directory handed to `gr.Interface(examples=...)`.
examples = "examples"
# Load the full pickled model object (not just a state_dict): the code below
# calls the custom `model.predictions` method on it. Unpickling needs the
# model's class definitions in scope, which is why the file does
# `from model import *`.
# `weights_only=False` is required on PyTorch >= 2.6, where the default
# became True and refuses to unpickle arbitrary model objects.
# NOTE: full unpickling can execute arbitrary code -- only load trusted files.
model = torch.load(
    "model/fashion.mnist.base.pt",
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Images need to be transformed to the `Fashion MNIST` dataset format
# see https://arxiv.org/abs/1708.07747
# Fashion MNIST images are 28x28 single-channel with light garments on a dark
# background; user images are presumably dark-on-light, hence the inversion.
# The inversion must happen BEFORE normalization: applying `1 - x` after
# Normalize((0.5,), (0.5,)) maps [-1, 1] to [0, 2] instead of keeping the
# normalized [-1, 1] range.
# NOTE(review): assumes the model was trained on inputs normalized to [-1, 1]
# with Normalize((0.5,), (0.5,)) -- TODO confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),  # single output channel
        transforms.ToTensor(),  # shape (1, 28, 28), values in [0, 1]
        transforms.Lambda(lambda x: 1.0 - x),  # Invert colors, still in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1]
    ]
)
def predict(img):
    """Classify a single garment image.

    Args:
        img: Image array from the Gradio `Image` input, or None when the
            user submits without providing an image.

    Returns:
        The result of `model.predictions`, passed straight to the Gradio
        `Label` output.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of an internal traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image of a garment.")
    tensor = transform(Image.fromarray(img))
    # Inference only: no need for autograd bookkeeping.
    with torch.no_grad():
        return model.predictions(tensor)
# Build the UI: a single tab hosting the prediction Interface.
with gr.Blocks() as demo:
    with gr.Tab("Garment Prediction"):
        gr.Interface(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            examples=examples,
            title=title,
            description=description,
        )
# Configure the queue on the outer Blocks: `demo` is the app that is launched,
# so its queue settings are the ones used at runtime (a queue configured on
# the nested Interface is not the one the launched app uses).
demo.queue(default_concurrency_limit=5)
demo.launch()
|