gauthamk commited on
Commit
1d6ca53
·
1 Parent(s): bd591b6

add: gradio interface

Browse files
Files changed (3) hide show
  1. app.py +13 -0
  2. functions.py +47 -0
  3. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from functions import *
4
+
5
+ examples_dir = 'examples'
6
+ title = "Birds Classification - ResNet34 PyTorch"
7
+ examples = [os.path.join(examples_dir, i) for i in os.listdir('examples')]
8
+
9
+ interface = gr.Interface(fn=predict, inputs=gr.Image(type= 'numpy', shape=(64, 64)).style(height= 256),
10
+ outputs= gr.Label(num_top_classes= 5), cache_examples= False,
11
+ examples= examples, title= title)
12
+
13
+ interface.launch()
functions.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ import onnxruntime as rt
4
+
5
+ model_path = 'models/model.onnx'
6
+ idx_to_class = 'models/idx_to_class.json'
7
+
8
+ normalise_means = [0.4914, 0.4822, 0.4465]
9
+ normalise_stds = [0.2023, 0.1994, 0.2010]
10
+
11
+ def normalise_image(image):
12
+ image = image.copy()
13
+ for i in range(3):
14
+ image[:, i, :, :] = (image[:, i, :, :] - normalise_means[i]) / normalise_stds[i]
15
+ return image
16
+
17
+ def load_class_names():
18
+ with open(idx_to_class, 'r') as f:
19
+ class_names = json.load(f)
20
+ return class_names
21
+
22
+ def predict(inp_image):
23
+
24
+ class_names = load_class_names()
25
+
26
+ image = inp_image
27
+ image = image.transpose((2, 0, 1))
28
+
29
+ image = image / 255.0
30
+ image = np.expand_dims(image, axis=0)
31
+ image = normalise_image(image)
32
+ image = image.astype(np.float32)
33
+
34
+ sess = rt.InferenceSession(model_path)
35
+
36
+ input_name = sess.get_inputs()[0].name
37
+ output_name = sess.get_outputs()[0].name
38
+
39
+ output = sess.run([output_name], {input_name: image})[0]
40
+ prob = np.exp(output) / np.sum(np.exp(output), axis=1, keepdims=True)
41
+
42
+ top5 = np.argsort(prob[0])[-5:][::-1]
43
+
44
+ class_probs = {class_names[str(i)]: float(prob[0][i]) for i in top5}
45
+ print(class_probs)
46
+
47
+ return class_probs
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.16.1
2
+ json5==0.9.10
3
+ matplotlib==3.6.2
4
+ matplotlib-inline==0.1.6
5
+ numpy==1.23.5
6
+ onnx==1.13.0
7
+ onnxruntime==1.13.1
8
+ pandas==1.5.2
9
+ Pillow==9.3.0
10
+ torch==1.13.1
11
+ torchvision==0.14.1
12
+ tqdm==4.64.1