Spaces:
Running
Running
File size: 2,009 Bytes
8d70b1b 9139749 8d70b1b 9139749 8d70b1b 9139749 40ca095 ac8d134 9139749 8d70b1b ac8d134 8d70b1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
import torch.nn.functional as F
from torch import optim
from torch.nn import Module
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import numpy as np
import onnxruntime
import gradio as gr
import json
def get_image(x):
    """Return the part of string *x* before its first ", " separator.

    A string without the separator is returned unchanged, e.g.
    "tench, Tinca tinca" -> "tench" and "goldfish" -> "goldfish".
    Despite its name this is a pure string helper; nothing in this file
    calls it.
    """
    head, _sep, _rest = x.partition(', ')
    return head
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses tensors with ``requires_grad=True``. For a tensor that already
    lives on the CPU, the result may share memory with it.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
# Preprocess an image array into a normalized model-input batch.
def transform_image(myarray):
    """Turn an image array into a (1, 3, 224, 224) normalized float tensor.

    The array is cast with ``np.uint8`` and converted to an RGB PIL image.
    It is then resized (shorter side -> 224), center-cropped to 224x224,
    converted to a tensor, and normalized with the standard ImageNet
    channel mean/std. Finally a leading batch dimension is added.

    NOTE(review): presumably *myarray* is the H x W x C array produced by
    the Gradio Image input; confirm if called from elsewhere.
    """
    pil_img = Image.fromarray(np.uint8(myarray)).convert('RGB')
    pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    single = pipeline(pil_img)
    return single.unsqueeze(0)
# Class-index -> label lookup. Keys are the stringified class indices, as
# used by get_classification (label_map[str(i)]).
# A context manager guarantees the file is closed even if json.load raises.
# JSON text is UTF-8 by specification, so the encoding is stated explicitly
# rather than relying on the platform's locale default.
with open('imagenet_label.json', encoding='utf-8') as f:
    label_map = json.load(f)
# Load list of images for similarity: one entry per line, surrounding
# whitespace stripped. The original code never closed this file handle;
# the context manager closes it once the list is built.
with open('img_list.txt', 'r', encoding='utf-8') as f:
    sub_test_list = [line.strip() for line in f]
# Load images embedding for similarity.
# NOTE(review): torch.load deserializes via pickle, so only load trusted
# artifacts. Neither sub_test_list nor embeddings is referenced elsewhere
# in this file.
embeddings = torch.load('embeddings.pt')
# Load the exported ONNX classifier and record the name of its first input,
# which get_classification uses to build the feed dict.
PATH = 'model_onnx.onnx'
options = onnxruntime.SessionOptions()
# Intra-op thread pool size of 8. Per the ONNX Runtime docs,
# inter_op_num_threads is only used when execution_mode is ORT_PARALLEL;
# it is ignored under the default sequential mode.
options.intra_op_num_threads = options.inter_op_num_threads = 8
ort_session = onnxruntime.InferenceSession(PATH, sess_options=options)
input_name = ort_session.get_inputs()[0].name
# predict multi-level classification
def get_classification(img):
    """Classify *img* with the ONNX model and return its top-scoring labels.

    Args:
        img: array-like image (the Gradio Image input wired up below). It is
            passed to ``transform_image``, which builds a normalized
            (1, 3, 224, 224) float tensor.

    Returns:
        dict mapping label string -> model score for up to five of the
        highest-scoring classes, inserted in descending score order. This is
        the shape the Gradio "label" output expects.
    """
    image_tensor = transform_image(img)
    outputs = ort_session.run(None, {input_name: to_numpy(image_tensor)})
    # First model output; indexed below as (batch=1, num_classes).
    scores = outputs[0]
    # NOTE(review): scores are reported as-is. If the exported model emits
    # raw logits rather than probabilities, apply F.softmax first -- confirm.
    # torch.topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, scores.shape[-1])
    top_indices = torch.topk(torch.from_numpy(scores), k=k).indices.squeeze(0).tolist()
    result = {}
    for i in top_indices:
        label = label_map[str(i)]
        # topk returns indices in descending score order. Distinct indices
        # may share a label string (standard ImageNet label lists contain
        # duplicates such as "crane"), so keep the first -- highest -- score
        # rather than letting a lower one overwrite it.
        if label not in result:
            result[label] = scores[0, i].item()
    return result
# Build and launch the Gradio UI. A 200x200 image input feeds
# get_classification, and the resulting {label: score} dict is rendered by the
# built-in "label" output.
# NOTE(review): the input is sized to 200x200 here and then upscaled to 224
# by transform_image; consider shape=(224, 224) -- confirm intent.
iface = gr.Interface(
    fn=get_classification,
    inputs=gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title="Image Classification",
)
iface.launch()
|