oliverlevn committed on
Commit ade0106 · 1 Parent(s): 01ecd1d

Update app.py

Files changed (1)
  1. app.py +130 -0
app.py CHANGED
@@ -0,0 +1,130 @@
+ import gradio
+ import torch
+ import torchvision.transforms as T
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import os
+ import random
+ import PIL.Image as Image
+ import time
+ from model import create_fasterrcnn_model
+
+ # Category metadata for the dataset (COCO-style id / name / supercategory)
+ categories = [
+     {"id": 0, "name": "creatures", "supercategory": "none"},
+     {"id": 1, "name": "fish", "supercategory": "creatures"},
+     {"id": 2, "name": "jellyfish", "supercategory": "creatures"},
+     {"id": 3, "name": "penguin", "supercategory": "creatures"},
+     {"id": 4, "name": "puffin", "supercategory": "creatures"},
+     {"id": 5, "name": "shark", "supercategory": "creatures"},
+     {"id": 6, "name": "starfish", "supercategory": "creatures"},
+     {"id": 7, "name": "stingray", "supercategory": "creatures"},
+ ]
+
+ # 1. Create title and description strings
+ title = "Ocean Creature Detection with Faster R-CNN"
+ description = "A Faster R-CNN object detection model with a ResNet-50 backbone that detects ocean creatures (fish, jellyfish, penguins, puffins, sharks, starfish, and stingrays) in images."
+
+ # 2. Create the model and load saved weights
+ faster_rcnn = create_fasterrcnn_model(
+     num_classes=8,  # len(categories) would also work
+ )
+ faster_rcnn.load_state_dict(
+     torch.load(
+         f="./third_train.pth",
+         map_location=torch.device("cpu"),  # load to CPU
+     )
+ )
+
+ # 3. Create predict function
+ def predict(img):
+     """Runs Faster R-CNN on img and returns the path to an annotated image and the time taken."""
+     # Start the timer
+     start_time = time.time()
+
+     # Preprocess: resize and convert to a batched CPU tensor
+     device = "cpu"
+     transform = T.Compose([T.ToPILImage(), T.Resize(size=(768, 1024)), T.ToTensor()])
+     image_tensor = transform(img).to(device).unsqueeze(0)
+
+     # Run inference
+     faster_rcnn.eval()
+     with torch.no_grad():
+         predictions = faster_rcnn(image_tensor)
+     pred_boxes = predictions[0]["boxes"].cpu().numpy()
+     pred_scores = predictions[0]["scores"].cpu().numpy()
+     pred_labels = predictions[0]["labels"].cpu().numpy()
+     label_names = [categories[label]["name"] for label in pred_labels]
+
+     # Draw boxes on the resized image so the box coordinates line up with the pixels
+     resized_img = image_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
+     fig, ax = plt.subplots(1)
+     ax.imshow(resized_img)
+     for box, score, label_name in zip(pred_boxes, pred_scores, label_names):
+         if score > 0.5:
+             x1, y1, x2, y2 = box
+             w, h = x2 - x1, y2 - y1
+             rect = plt.Rectangle((x1, y1), w, h, fill=False, edgecolor="red", linewidth=2)
+             ax.add_patch(rect)
+             ax.text(x1, y1, f"{label_name}: {score:.2f}", fontsize=5, color="white",
+                     bbox=dict(facecolor="red", alpha=0.2))
+
+     # Save the figure to an image file and close it to free memory
+     random_name = str(random.randint(0, 99))
+     img_path = f"./{random_name}.png"
+     fig.savefig(img_path)
+     plt.close(fig)
+
+     # Calculate the prediction time
+     pred_time = round(time.time() - start_time, 5)
+
+     # Return the path to the saved image and the prediction time
+     return img_path, str(pred_time)
+
+
+ ### 4. Gradio app ###
+
+ # Build the examples list from the "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo: one image in, an annotated image and the prediction time out
+ demo = gradio.Interface(
+     fn=predict,  # mapping function from input to outputs
+     inputs=gradio.Image(type="numpy"),
+     outputs=[
+         gradio.Image(type="filepath", label="Image with Bounding Boxes"),
+         gradio.Textbox(label="Prediction Time (s)"),
+     ],
+     examples=example_list,
+     title=title,
+     description=description,
+ )
+
+ # Launch the demo!
+ demo.launch(debug=True)
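
Note: app.py imports create_fasterrcnn_model from model.py, which is not part of this commit. A minimal sketch of what such a helper might look like, assuming it wraps torchvision's Faster R-CNN with a ResNet-50 FPN backbone and swaps in a box predictor sized for the 8 dataset categories (the actual model.py may differ):

# Hypothetical model.py sketch (the real file is not included in this commit)
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def create_fasterrcnn_model(num_classes: int):
    # COCO-pretrained Faster R-CNN with a ResNet-50 FPN backbone
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    # Replace the classification head so it predicts num_classes classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model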