# Author: dnth
# Commit message: "Update app.py"
# Commit: cfd2a9d (verified)
import time
from urllib.request import urlopen
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image
from imagenet_classes import IMAGENET2012_CLASSES
def read_image(image: Image.Image):
    """Convert a PIL image into a float32 array shaped (1, 3, H, W).

    The image is forced to 3-channel RGB first. Pixel values keep their
    0-255 range. This helper does no resizing, scaling or normalisation, so
    any such preprocessing must happen elsewhere (presumably inside the
    ONNX graph -- TODO confirm).
    """
    rgb_hwc = np.array(image.convert("RGB")).astype(np.float32)  # (H, W, 3)
    chw = np.transpose(rgb_hwc, (2, 0, 1))  # channels first: (3, H, W)
    return chw[np.newaxis, ...]  # prepend a batch axis of size 1
# Inference is pinned to the CPU execution provider.
providers = ["CPUExecutionProvider"]

# Load the exported ONNX graph once at import time; `predict` reuses this
# single session for every call.
session = ort.InferenceSession("merged_model_compose.onnx", providers=providers)

# Names of the graph's first input and first output tensors, used as the
# feed key and fetch name in `session.run`.
input_name, output_name = (
    session.get_inputs()[0].name,
    session.get_outputs()[0].name,
)
# Class names in dictionary order, built once instead of on every request.
# NOTE(review): assumes IMAGENET2012_CLASSES is ordered by the model's output
# index (relies on dict insertion order) -- confirm against the export.
_IMAGENET_LABELS = list(IMAGENET2012_CLASSES.values())


def predict(img):
    """Classify ``img`` with the ONNX session and return the best labels.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        Dict mapping class name -> softmax probability for the (up to) five
        highest-scoring classes, the mapping shape ``gr.Label`` renders.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty image input; show a readable
        # message in the UI instead of an AttributeError from read_image.
        raise gr.Error("Please upload an image.")

    logits = session.run([output_name], {input_name: read_image(img)})[0]
    # NOTE(review): assumes the graph emits raw logits shaped
    # (batch, num_classes); if the export already applies softmax, this
    # normalizes twice (ranking is unchanged, probabilities are not).
    probabilities = torch.from_numpy(logits).softmax(dim=1)

    # Clamp k so topk cannot fail on a head with fewer than five classes.
    k = min(5, probabilities.shape[1])
    top_probs, top_indices = torch.topk(probabilities, k=k)

    # Only the first (and only) batch item is reported. .tolist() yields plain
    # Python ints and floats for list indexing and JSON serialisation.
    return {
        _IMAGENET_LABELS[idx]: prob
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    }
# Sample picture offered as a one-click example in the UI (path is resolved
# relative to the working directory).
example_image = "beignets-task-guide.png"

# Gradio front end: an uploaded image is passed to `predict`, and the returned
# {label: probability} dict is shown by a Label widget listing five classes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    title="Image Classification with ONNX using EVA02 model",
    description="Blog post: https://dicksonneoh.com/portfolio/supercharge_your_pytorch_image_models/",
    examples=[example_image],
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()