rman-rahimi-29 commited on
Commit
936c1ea
·
1 Parent(s): 8aeac52

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from model import create_effnetb2_model
7
+ from timeit import default_timer as timer
8
+ from typing import Tuple, Dict
9
+
10
+ # setting up class names
11
+ with open("class_names.txt", "r") as f:
12
+ class_names = [food.strip() for food in f.readlines()]
13
+
14
+ ### model and transforms prepration ###
15
+ effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=101,
16
+ seed=29)
17
+ # loading the saved weights
18
+ effnetb2.load_state_dict(
19
+ torch.load(
20
+ f="pretrained_effnetb2_feature_extractor_food101_20_percent.pth",
21
+ map_location=torch.device("cpu") # loading the model to cpu
22
+ )
23
+ )
24
+
25
+ ### predict function ###
26
+ def predict(img) -> Tuple[Dict, float]:
27
+ # start a timer
28
+ start_time = timer()
29
+
30
+ # transforming the input image
31
+ img = effnetb2_transforms(img).unsqueeze(0)
32
+
33
+ # putting the model into eval mode & making prediction
34
+ effnetb2.eval()
35
+ with torch.inference_mode():
36
+ # passing transformed img through the model and turn pred logits into probs
37
+ pred_probs = torch.softmax(effnetb2(img), dim=1)
38
+
39
+ # creating a prediction label & prediction probability dictionary
40
+ pred_labels_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
41
+
42
+ # calculate pred time
43
+ end_time = timer()
44
+ pred_time = round(end_time-start_time, 4)
45
+
46
+ # return pred dict and pred time
47
+ return pred_labels_and_probs, pred_time
48
+
49
+ ### gradio app ###
50
+ # creating title, description and article
51
+ title = "FoodVision Big"
52
+ description = "An EfficientNetB2 feature extractor computer vision model to classify images in 101 different classes!"
53
+
54
+ # creating an example list
55
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
56
+
57
+ # creating the gradio demo
58
+ demo = gr.Interface(fn=predict, # maps inputs to outputs
59
+ inputs=gr.Image(type="pil"),
60
+ outputs=[gr.Label(num_top_classes=5, label="Predictions"),
61
+ gr.Number(label="Prediction Time (s)")],
62
+ examples=example_list,
63
+ title=title,
64
+ description=description)
65
+
66
+ # launching the demo
67
+ demo.launch()