Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn.functional as F | |
from torch import optim | |
from torch.nn import Module | |
from torchvision import models, transforms | |
from torchvision.datasets import ImageFolder | |
from PIL import Image | |
import numpy as np | |
import onnxruntime | |
import gradio as gr | |
import json | |
def get_image(x):
    """Return the leading item of a ``', '``-separated string.

    Everything before the first ``', '`` is returned; if the separator does
    not occur, the whole string comes back unchanged.
    """
    head, _sep, _rest = x.partition(', ')
    return head
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that track gradients are detached from the autograd graph first,
    because ``.numpy()`` refuses tensors that require grad.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
# Model-input preprocessing pipeline. It is built once at import time instead
# of on every request. Steps: resize (shorter side) to 224, centre-crop to
# 224x224, convert to a float tensor, and normalize with the standard ImageNet
# per-channel mean/std.
_PREPROCESS = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(myarray):
    """Convert a raw image array into a normalized, batched model input.

    Args:
        myarray: Image data that ``PIL.Image.fromarray`` can read after it is
            cast to ``uint8`` (e.g. the array a Gradio ``Image`` input yields).

    Returns:
        Float tensor of shape ``(1, 3, 224, 224)``. It has a leading batch
        dimension of 1 and three RGB channels.
    """
    # Cast to 8-bit pixels; convert('RGB') forces three channels so grayscale
    # or RGBA uploads match what Normalize's 3-element mean/std expects.
    image = Image.fromarray(np.uint8(myarray)).convert('RGB')
    return _PREPROCESS(image).unsqueeze(0)
# Load the class-index -> label mapping used to name predictions. Keys are
# looked up as strings of the class index (label_map[str(i)]). The context
# manager closes the file even if JSON parsing fails.
with open('imagenet_label.json', encoding='utf-8') as f:
    label_map = json.load(f)
# Load the list of image names for similarity search: one entry per line, with
# surrounding whitespace stripped. The original left this file handle open;
# the context manager closes it once the list is built.
with open('img_list.txt', 'r', encoding='utf-8') as img_list_file:
    sub_test_list = [line.strip() for line in img_list_file]
# Load the precomputed image embeddings for similarity search.
# map_location='cpu' lets tensors that were saved from a GPU still load on a
# CPU-only host instead of raising at import time.
# SECURITY NOTE: torch.load unpickles the file. Only ever point it at a
# trusted, locally shipped artefact.
embeddings = torch.load('embeddings.pt', map_location='cpu')
# ONNX Runtime session options: allow 8 threads both within and across
# operators.
# NOTE(review): ONNX Runtime documents inter-op threads as used only when
# execution_mode is ExecutionMode.ORT_PARALLEL. Confirm whether the inter_op
# setting has any effect with the default (sequential) mode.
options = onnxruntime.SessionOptions()
options.intra_op_num_threads = options.inter_op_num_threads = 8

# Create the inference session for the exported model.
PATH = 'model_onnx.onnx'
ort_session = onnxruntime.InferenceSession(PATH, sess_options=options)

# The feed dictionary at inference time is keyed by the model's first input.
input_name = ort_session.get_inputs()[0].name
# Predict the five highest-scoring labels for an input image.
def get_classification(img):
    """Classify *img* with the ONNX model and return its top-5 labels.

    Args:
        img: Image array as delivered by the Gradio ``Image`` input. It is
            passed unchanged to ``transform_image``.

    Returns:
        dict mapping label name -> score, in descending score order. It holds
        at most 5 entries, and fewer when two class indices share a label
        name. NOTE(review): scores are the raw values of the model's first
        output (no softmax is applied here). Confirm whether the exported
        graph already ends in a softmax.
    """
    image_tensor = transform_image(img)
    ort_inputs = {input_name: to_numpy(image_tensor)}
    outputs = ort_session.run(None, ort_inputs)
    # First model output, indexed below as [batch=0, class_index].
    scores = outputs[0]
    top_indices = torch.topk(torch.from_numpy(scores), k=5).indices.squeeze(0).tolist()
    result = {}
    for idx in top_indices:
        label = label_map[str(idx)]
        # topk yields indices in descending score order. If two class ids map
        # to the same display name, keep the first (higher) score rather than
        # letting a lower-ranked duplicate overwrite it.
        if label not in result:
            result[label] = scores[0, idx].item()
    return result
# Gradio UI: a single image input (shape=(200, 200)) feeding get_classification,
# rendered with the "label" output component.
# NOTE(review): gr.inputs.* is the legacy Gradio input namespace. Confirm the
# pinned gradio version before upgrading, since newer releases drop it.
iface = gr.Interface(
    fn=get_classification,
    inputs=gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title='Image Classification',
)
# Start the web server (blocks at module import time, as in the original).
iface.launch()