dnth's picture
Rename 10_gradio_app.py to app.py
833c9a4 verified
raw
history blame
1.4 kB
import time
from urllib.request import urlopen
import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from PIL import Image
from imagenet_classes import IMAGENET2012_CLASSES
def read_image(image: Image.Image):
image = image.convert("RGB")
img_numpy = np.array(image).astype(np.float32)
img_numpy = img_numpy.transpose(2, 0, 1)
img_numpy = np.expand_dims(img_numpy, axis=0)
return img_numpy
providers = ["CPUExecutionProvider"]
session = ort.InferenceSession("merged_model_compose.onnx", providers=providers)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
def predict(img):
output = session.run([output_name], {input_name: read_image(img)})
output = torch.from_numpy(output[0])
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)
im_classes = list(IMAGENET2012_CLASSES.values())
class_names = [im_classes[i] for i in top5_class_indices[0]]
results = {
name: float(prob) for name, prob in zip(class_names, top5_probabilities[0])
}
return results
iface = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
outputs=gr.Label(num_top_classes=5),
title="Image Classification with ONNX TensorRT",
description="Upload an image to classify it using the ONNX TensorRT model.",
)
if __name__ == "__main__":
iface.launch()