# Hugging Face Space (status: Running)
import time | |
from urllib.request import urlopen | |
import gradio as gr | |
import numpy as np | |
import onnxruntime as ort | |
import torch | |
from PIL import Image | |
from imagenet_classes import IMAGENET2012_CLASSES | |
def read_image(image: Image.Image):
    """Convert a PIL image into the float32 batch array fed to the model.

    The image is forced to three RGB channels, cast to ``float32`` with no
    rescaling of the pixel values, reordered from height-width-channel to
    channel-height-width, and given a leading axis of length one.

    Args:
        image: The PIL image to convert.

    Returns:
        A ``float32`` NumPy array of shape ``(1, 3, height, width)``.
    """
    pixels = np.array(image.convert("RGB")).astype(np.float32)
    # HWC -> CHW, then prepend the batch axis.
    channels_first = np.transpose(pixels, (2, 0, 1))
    return channels_first[np.newaxis, ...]
# Execution providers handed to ONNX Runtime; only the CPU provider is listed.
providers = ["CPUExecutionProvider"]

# The model is loaded once at import time and the session is reused by every
# prediction request.
session = ort.InferenceSession(
    "merged_model_compose.onnx",
    providers=providers,
)

# Only the first declared input and the first declared output are used.
input_name, output_name = (
    session.get_inputs()[0].name,
    session.get_outputs()[0].name,
)
def predict(img):
    """Classify an image and return its highest-scoring classes.

    Args:
        img: PIL image supplied by the Gradio image input, or ``None`` when
            the form is submitted without an image.

    Returns:
        A dict mapping class name to softmax probability for up to five
        classes, in descending order of probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Without this guard a missing image fails inside read_image() with an
        # AttributeError that tells the user nothing; gr.Error is shown in the UI.
        raise gr.Error("Please upload an image before submitting.")

    outputs = session.run([output_name], {input_name: read_image(img)})
    logits = torch.from_numpy(outputs[0])
    # Softmax over dim 1, i.e. the output is treated as (batch, classes).
    probabilities = logits.softmax(dim=1)

    # torch.topk raises when k exceeds the dimension size, so never request
    # more entries than the model has classes.
    k = min(5, probabilities.shape[1])
    top_probabilities, top_class_indices = torch.topk(probabilities, k=k)

    im_classes = list(IMAGENET2012_CLASSES.values())
    # .tolist() yields plain Python ints/floats for indexing and for the result.
    # NOTE(review): classes that share a display name collapse into a single
    # key (the later one wins) — confirm the label names are unique.
    return {
        im_classes[class_index]: probability
        for class_index, probability in zip(
            top_class_indices[0].tolist(), top_probabilities[0].tolist()
        )
    }
# Gradio front end: a PIL image goes into predict(), and the returned
# name -> probability mapping is rendered as a label showing the top five classes.
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=5),
    title="Image Classification with ONNX TensorRT",
    description="Upload an image to classify it using the ONNX TensorRT model.",
)

if __name__ == "__main__":
    # Start the web app only when executed as a script, not on import.
    iface.launch()