Spaces:
Runtime error
Runtime error
File size: 1,225 Bytes
1d6ca53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import numpy as np
import json
import onnxruntime as rt
# Filesystem locations of the ONNX model (handed to onnxruntime in predict) and
# of the JSON file that load_class_names reads. Both are relative paths, so
# they resolve against the process working directory.
model_path = 'models/model.onnx'
idx_to_class = 'models/idx_to_class.json'
# Per-channel mean and standard deviation used by normalise_image; entry i is
# applied to channel index i of the batch.
# NOTE(review): these look like the commonly quoted CIFAR-10 RGB statistics --
# confirm against the preprocessing used when the model was trained.
normalise_means = [0.4914, 0.4822, 0.4465]
normalise_stds = [0.2023, 0.1994, 0.2010]
def normalise_image(image):
    """Standardise the leading channels of an image batch.

    Channel ``i`` (axis 1) is transformed to
    ``(x - normalise_means[i]) / normalise_stds[i]`` for every channel that has
    an entry in the module-level statistics; any further channels are returned
    unchanged. The input is never modified.

    Args:
        image: 4-D array indexed as ``[batch, channel, row, column]``.
            NOTE(review): values are presumably already scaled to ``[0, 1]``,
            as done by ``predict`` -- confirm for other callers.

    Returns:
        A new array of the same shape. Floating (and complex) dtypes are
        preserved; any other dtype is promoted to ``float64`` so the
        standardised values are not truncated back to integers.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.inexact):
        result = image.copy()
    else:
        # Writing float results into an integer array would silently truncate
        # them, so work on a float64 copy instead.
        result = image.astype(np.float64)

    # Shape (1, C, 1, 1) lets the statistics broadcast over batch, rows and
    # columns in a single vectorised expression.
    means = np.asarray(normalise_means, dtype=result.dtype).reshape(1, -1, 1, 1)
    stds = np.asarray(normalise_stds, dtype=result.dtype).reshape(1, -1, 1, 1)

    channels = means.shape[1]
    result[:, :channels] = (result[:, :channels] - means) / stds
    return result
def load_class_names():
    """Load the class-index-to-label mapping from the ``idx_to_class`` file.

    Returns:
        The decoded JSON document. ``predict`` looks labels up by the string
        form of the class index, so the file is expected to contain an object
        keyed by ``"0"``, ``"1"``, ...

    Raises:
        FileNotFoundError: If the mapping file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; pinning the encoding keeps non-ASCII
    # labels from depending on the platform's default locale encoding.
    with open(idx_to_class, 'r', encoding='utf-8') as f:
        return json.load(f)
# Inference sessions keyed by model path. Building a session loads and
# prepares the model, so it is done once per path instead of once per call.
_session_cache = {}


def _get_session(path):
    """Return the cached ``InferenceSession`` for ``path``, creating it on first use."""
    session = _session_cache.get(path)
    if session is None:
        session = rt.InferenceSession(path)
        _session_cache[path] = session
    return session


def _softmax(logits):
    """Return the softmax of ``logits`` taken along axis 1."""
    # Subtracting the per-row maximum does not change the mathematical result
    # but keeps np.exp from overflowing to inf, which would turn the ratio
    # into NaN for large logits.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)


def predict(inp_image):
    """Classify one image and return its five most probable classes.

    Args:
        inp_image: 3-D array whose last axis is moved to the front before
            inference. NOTE(review): presumably height x width x channel with
            0-255 pixel values, given the division by 255 -- confirm against
            the caller.

    Returns:
        A dict mapping class label to probability (a Python ``float``) for the
        five highest-scoring classes, inserted in descending order of
        probability. The same dict is also printed to stdout.
    """
    class_names = load_class_names()

    # Reorder to channel-first, scale by 1/255, add a batch axis, standardise
    # per channel, then cast to float32 before handing the batch to the model.
    image = inp_image.transpose((2, 0, 1))
    image = image / 255.0
    image = np.expand_dims(image, axis=0)
    image = normalise_image(image)
    image = image.astype(np.float32)

    sess = _get_session(model_path)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    output = sess.run([output_name], {input_name: image})[0]

    prob = _softmax(output)
    # argsort is ascending: keep the last five indices and reverse them so the
    # most probable class comes first.
    top5 = np.argsort(prob[0])[-5:][::-1]
    class_probs = {class_names[str(i)]: float(prob[0][i]) for i in top5}
    print(class_probs)
    return class_probs