Spaces:
Runtime error
Runtime error
Hector Lopez
commited on
Commit
·
3fa54be
1
Parent(s):
bc3d4e9
Multiple refactors
Browse files
app.py
CHANGED
@@ -3,8 +3,7 @@ import PIL
|
|
3 |
import torch
|
4 |
|
5 |
from utils import plot_img_no_mask, get_models
|
6 |
-
from
|
7 |
-
from model import get_model, predict, prepare_prediction, predict_class
|
8 |
|
9 |
DET_CKPT = 'efficientDet_icevision.ckpt'
|
10 |
CLASS_CKPT = 'class_ViT_taco_7_class.pth'
|
|
|
3 |
import torch
|
4 |
|
5 |
from utils import plot_img_no_mask, get_models
|
6 |
+
from model import predict, prepare_prediction, predict_class
|
|
|
7 |
|
8 |
DET_CKPT = 'efficientDet_icevision.ckpt'
|
9 |
CLASS_CKPT = 'class_ViT_taco_7_class.pth'
|
model.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
from io import BytesIO
|
2 |
-
from typing import Union
|
3 |
from icevision import *
|
4 |
from icevision.models.checkpoint import model_from_checkpoint
|
5 |
from classifier import transform_image
|
6 |
from icevision.models import ross
|
7 |
|
8 |
-
import collections
|
9 |
import PIL
|
10 |
import torch
|
11 |
import numpy as np
|
@@ -13,44 +12,34 @@ import torchvision
|
|
13 |
|
14 |
MODEL_TYPE = ross.efficientdet
|
15 |
|
16 |
-
def
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
fixed_state_dict = collections.OrderedDict()
|
33 |
-
|
34 |
-
for k, v in ckpt['state_dict'].items():
|
35 |
-
new_k = k[6:]
|
36 |
-
fixed_state_dict[new_k] = v
|
37 |
-
|
38 |
-
return fixed_state_dict
|
39 |
-
|
40 |
-
def predict(model : object, image : Union[str, BytesIO], detection_threshold : float):
|
41 |
img = PIL.Image.open(image)
|
42 |
-
|
43 |
-
|
44 |
-
img = PIL.Image.fromarray(img)
|
45 |
class_map = ClassMap(classes=['Waste'])
|
46 |
transforms = tfms.A.Adapter([
|
47 |
*tfms.A.resize_and_pad(512),
|
48 |
tfms.A.Normalize()
|
49 |
])
|
50 |
-
|
|
|
51 |
pred_dict = MODEL_TYPE.end2end_detect(img,
|
52 |
transforms,
|
53 |
-
|
54 |
class_map=class_map,
|
55 |
detection_threshold=detection_threshold,
|
56 |
return_as_pil_img=False,
|
@@ -61,32 +50,67 @@ def predict(model : object, image : Union[str, BytesIO], detection_threshold : f
|
|
61 |
|
62 |
return pred_dict
|
63 |
|
64 |
-
def prepare_prediction(pred_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
|
66 |
boxes = torch.stack(boxes)
|
67 |
|
|
|
68 |
scores = torch.as_tensor(pred_dict['detection']['scores'])
|
69 |
labels = torch.as_tensor(pred_dict['detection']['label_ids'])
|
|
|
70 |
image = np.array(pred_dict['img'])
|
71 |
|
72 |
-
|
|
|
|
|
73 |
boxes = boxes[fixed_boxes, :]
|
74 |
|
75 |
return boxes, image
|
76 |
|
77 |
-
def predict_class(classifier, image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
preds = []
|
79 |
|
80 |
for bbox in bboxes:
|
81 |
img = image.copy()
|
82 |
bbox = np.array(bbox).astype(int)
|
|
|
|
|
83 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
84 |
cropped_img = np.array(cropped_img)
|
85 |
|
|
|
86 |
tran_image = transform_image(cropped_img, 224)
|
|
|
87 |
tran_image = tran_image.transpose(2, 0, 1)
|
88 |
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
|
89 |
-
|
|
|
90 |
y_preds = classifier(tran_image)
|
91 |
preds.append(y_preds.softmax(1).detach().numpy())
|
92 |
|
|
|
1 |
from io import BytesIO
|
2 |
+
from typing import Dict, Tuple, Union
|
3 |
from icevision import *
|
4 |
from icevision.models.checkpoint import model_from_checkpoint
|
5 |
from classifier import transform_image
|
6 |
from icevision.models import ross
|
7 |
|
|
|
8 |
import PIL
|
9 |
import torch
|
10 |
import numpy as np
|
|
|
12 |
|
13 |
MODEL_TYPE = ross.efficientdet
|
14 |
|
15 |
+
def predict(det_model : torch.nn.Module, image : Union[str, BytesIO],
|
16 |
+
detection_threshold : float) -> Dict:
|
17 |
+
"""
|
18 |
+
Make a prediction with the detection model.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
det_model (torch.nn.Module): Detection model
|
22 |
+
image (Union[str, BytesIO]): Image filepath if the image is one of
|
23 |
+
the example images and BytesIO if the image is a custom image
|
24 |
+
uploaded by the user.
|
25 |
+
detection_threshold (float): Detection threshold
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Dict: Prediction dictionary.
|
29 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
img = PIL.Image.open(image)
|
31 |
+
|
32 |
+
# Class map and transforms
|
|
|
33 |
class_map = ClassMap(classes=['Waste'])
|
34 |
transforms = tfms.A.Adapter([
|
35 |
*tfms.A.resize_and_pad(512),
|
36 |
tfms.A.Normalize()
|
37 |
])
|
38 |
+
|
39 |
+
# Single prediction
|
40 |
pred_dict = MODEL_TYPE.end2end_detect(img,
|
41 |
transforms,
|
42 |
+
det_model,
|
43 |
class_map=class_map,
|
44 |
detection_threshold=detection_threshold,
|
45 |
return_as_pil_img=False,
|
|
|
50 |
|
51 |
return pred_dict
|
52 |
|
53 |
+
def prepare_prediction(pred_dict : Dict,
|
54 |
+
nms_threshold : str) -> Tuple[torch.Tensor, np.ndarray]:
|
55 |
+
"""
|
56 |
+
Get the predictions in a right format.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
pred_dict (Dict): Prediction dictionary.
|
60 |
+
nms_threshold (float): Threshold for the NMS postprocess.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Tuple: Tuple containing the following:
|
64 |
+
- (torch.Tensor): Bounding boxes
|
65 |
+
- (np.ndarray): Image data
|
66 |
+
"""
|
67 |
+
# Convert each box to a tensor and stack them into an unique tensor
|
68 |
boxes = [box.to_tensor() for box in pred_dict['detection']['bboxes']]
|
69 |
boxes = torch.stack(boxes)
|
70 |
|
71 |
+
# Get the scores and labels as tensor
|
72 |
scores = torch.as_tensor(pred_dict['detection']['scores'])
|
73 |
labels = torch.as_tensor(pred_dict['detection']['label_ids'])
|
74 |
+
|
75 |
image = np.array(pred_dict['img'])
|
76 |
|
77 |
+
# Apply NMS to postprocess the bounding boxes
|
78 |
+
fixed_boxes = torchvision.ops.batched_nms(boxes, scores,
|
79 |
+
labels,nms_threshold)
|
80 |
boxes = boxes[fixed_boxes, :]
|
81 |
|
82 |
return boxes, image
|
83 |
|
84 |
+
def predict_class(classifier : torch.nn.Module, image : np.ndarray,
|
85 |
+
bboxes : torch.Tensor) -> np.ndarray:
|
86 |
+
"""
|
87 |
+
Predict the class of each detected object.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
classifier (torch.nn.Module): Classifier model.
|
91 |
+
image (np.ndarray): Image data.
|
92 |
+
bboxes (torch.Tensor): Bounding boxes.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
np.ndarray: Array containing the predicted class for each object.
|
96 |
+
"""
|
97 |
preds = []
|
98 |
|
99 |
for bbox in bboxes:
|
100 |
img = image.copy()
|
101 |
bbox = np.array(bbox).astype(int)
|
102 |
+
|
103 |
+
# Get the bounding box content
|
104 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
105 |
cropped_img = np.array(cropped_img)
|
106 |
|
107 |
+
# Apply transformations to the cropped image
|
108 |
tran_image = transform_image(cropped_img, 224)
|
109 |
+
# Channels first
|
110 |
tran_image = tran_image.transpose(2, 0, 1)
|
111 |
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
|
112 |
+
|
113 |
+
# Make prediction
|
114 |
y_preds = classifier(tran_image)
|
115 |
preds.append(y_preds.softmax(1).detach().numpy())
|
116 |
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
icevision[all]
|
2 |
matplotlib
|
3 |
effdet
|
|
|
4 |
Pillow==8.4.0
|
|
|
1 |
icevision[all]
|
2 |
matplotlib
|
3 |
effdet
|
4 |
+
mmcv-full
|
5 |
Pillow==8.4.0
|
utils.py
CHANGED
@@ -4,8 +4,8 @@ import numpy as np
|
|
4 |
import cv2
|
5 |
import torch
|
6 |
|
|
|
7 |
from classifier import CustomViT
|
8 |
-
from model import get_model
|
9 |
|
10 |
def plot_img_no_mask(image : np.ndarray, boxes : torch.Tensor, labels):
|
11 |
colors = {
|
@@ -67,7 +67,15 @@ def get_models(
|
|
67 |
- (torch.nn.Module): Classifier model
|
68 |
"""
|
69 |
print('Loading the detection model')
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
det_model.eval()
|
72 |
|
73 |
print('Loading the classifier model')
|
|
|
4 |
import cv2
|
5 |
import torch
|
6 |
|
7 |
+
from icevision.models.checkpoint import model_from_checkpoint
|
8 |
from classifier import CustomViT
|
|
|
9 |
|
10 |
def plot_img_no_mask(image : np.ndarray, boxes : torch.Tensor, labels):
|
11 |
colors = {
|
|
|
67 |
- (torch.nn.Module): Classifier model
|
68 |
"""
|
69 |
print('Loading the detection model')
|
70 |
+
checkpoint_and_model = model_from_checkpoint(
|
71 |
+
detection_ckpt,
|
72 |
+
model_name='ross.efficientdet',
|
73 |
+
backbone_name='d0',
|
74 |
+
img_size=512,
|
75 |
+
classes=['Waste'],
|
76 |
+
revise_keys=[(r'^model\.', '')])
|
77 |
+
|
78 |
+
det_model = checkpoint_and_model['model']
|
79 |
det_model.eval()
|
80 |
|
81 |
print('Loading the classifier model')
|