Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
from PIL import Image | |
from model import get_model, apply_weights, copy_weight | |
from vocab import vocab | |
from transforms import resized_crop_pad, gpu_crop | |
from torchvision.transforms import Normalize, ToTensor | |
# Preprocessing transforms, created once and reused for every request.
# NOTE(review): these mean/std values look like the standard ImageNet channel
# statistics -- confirm they match what the model was trained with.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
to_tensor = ToTensor()
norm = Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

# Build the network at import time and fill it with the saved weights.
# map_location="cpu" makes torch.load place the stored tensors on the CPU.
model = get_model()
state = torch.load("./vit_saved.pth", map_location="cpu")
apply_weights(model, state, copy_weight)
def decode_pred(pred: torch.Tensor, threshold: float = 0.95) -> str:
    """Turn a vector of per-class scores into a label from ``vocab``.

    Args:
        pred: One score per class. In this file the caller passes sigmoid
            outputs with the batch dimension squeezed out.
        threshold: A class counts as a match only when its score is
            strictly greater than this value.

    Returns:
        The ``vocab`` entry of the first (lowest-index) class whose score
        exceeds ``threshold``, or a fixed fallback message when no class
        does.
    """
    above = pred > threshold
    if not above.any():
        return "I don't know what this is, ¡páharo!"
    # nonzero() yields one row of coordinates per match. Take the first row
    # and convert it to a plain int so the lookup does not depend on
    # ``vocab`` accepting a tensor as an index.
    return vocab[above.nonzero()[0].item()]
def classify_image(inp):
    """Classify one image and return the predicted label as text.

    Args:
        inp: Array-like image data accepted by ``Image.fromarray``
            (presumably what the Gradio image input delivers -- confirm).

    Returns:
        The string produced by ``decode_pred`` for the model's
        sigmoid-activated output.
    """
    image = Image.fromarray(inp)

    # Preprocess: resize/crop/pad to 460x460, convert to a tensor with a
    # leading batch dimension, crop to 224x224, then normalise.
    padded = resized_crop_pad(image, (460, 460))
    batch = to_tensor(padded).unsqueeze(0)
    batch = norm(gpu_crop(batch, (224, 224)))

    # Inference only: switch to eval mode and skip gradient tracking.
    model.eval()
    with torch.no_grad():
        logits = model(batch)

    scores = torch.sigmoid(logits).squeeze(dim=0)
    return decode_pred(scores)
# Web UI: a single image input whose classification is shown as plain text.
iface = gr.Interface(
    fn=classify_image,
    # NOTE(review): ``gr.inputs`` is Gradio's legacy input namespace; newer
    # releases expose ``gr.Image`` instead -- confirm against the pinned
    # Gradio version before changing it.
    inputs=gr.inputs.Image(),
    outputs="text",
    title="Birds Classifier without Fastai",
    description=(
        "A birds classifier over 200 species trained with Fastai"
        " and deployed with plain pytorch in Gradio."
    ),
)

# Launch as a separate statement so ``iface`` keeps referring to the
# Interface object instead of being rebound to the value launch() returns.
iface.launch()