Spaces:
Runtime error
Runtime error
Hector Lopez
commited on
Commit
•
e5bb367
1
Parent(s):
8123f86
feature: Use ViT as classifier
Browse files- app.py +4 -4
- classifier.py +56 -4
- model.py +10 -4
app.py
CHANGED
@@ -5,14 +5,14 @@ import cv2
|
|
5 |
import PIL
|
6 |
import torch
|
7 |
|
8 |
-
from classifier import CustomEfficientNet
|
9 |
from model import get_model, predict, prepare_prediction, predict_class
|
10 |
|
11 |
print('Creating the model')
|
12 |
-
model = get_model('
|
13 |
print('Loading the classifier')
|
14 |
-
classifier =
|
15 |
-
classifier.load_state_dict(torch.load('
|
16 |
|
17 |
def plot_img_no_mask(image, boxes, labels):
|
18 |
colors = {
|
|
|
5 |
import PIL
|
6 |
import torch
|
7 |
|
8 |
+
from classifier import CustomEfficientNet, CustomViT
|
9 |
from model import get_model, predict, prepare_prediction, predict_class
|
10 |
|
11 |
print('Creating the model')
|
12 |
+
model = get_model('efficientDet_icevision.ckpt')
|
13 |
print('Loading the classifier')
|
14 |
+
classifier = CustomViT(target_size=7, pretrained=False)
|
15 |
+
classifier.load_state_dict(torch.load('class_ViT_taco_7_class.pth', map_location='cpu'))
|
16 |
|
17 |
def plot_img_no_mask(image, boxes, labels):
|
18 |
colors = {
|
classifier.py
CHANGED
@@ -1,12 +1,27 @@
|
|
1 |
import timm
|
2 |
import torch.nn as nn
|
3 |
-
|
4 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
8 |
|
9 |
-
return
|
10 |
|
11 |
class CustomEfficientNet(nn.Module):
|
12 |
"""
|
@@ -43,3 +58,40 @@ class CustomEfficientNet(nn.Module):
|
|
43 |
x = self.model(x)
|
44 |
|
45 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import timm
|
2 |
import torch.nn as nn
|
3 |
+
import albumentations as A
|
4 |
import torch
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
class CustomNormalization(A.ImageOnlyTransform):
|
8 |
+
def _norm(self, img):
|
9 |
+
return img / 255.
|
10 |
+
|
11 |
+
def apply(self, img, **params):
|
12 |
+
return self._norm(img)
|
13 |
+
|
14 |
+
def transform_image(image, size):
|
15 |
+
transforms = [
|
16 |
+
A.Resize(size, size,
|
17 |
+
interpolation=cv2.INTER_NEAREST),
|
18 |
+
CustomNormalization(p=1),
|
19 |
+
]
|
20 |
|
21 |
+
augs = A.Compose(transforms)
|
22 |
+
transformed = augs(image=image)
|
23 |
|
24 |
+
return transformed['image']
|
25 |
|
26 |
class CustomEfficientNet(nn.Module):
|
27 |
"""
|
|
|
58 |
x = self.model(x)
|
59 |
|
60 |
return x
|
61 |
+
|
62 |
+
class CustomViT(nn.Module):
|
63 |
+
"""
|
64 |
+
This class defines a custom ViT network.
|
65 |
+
|
66 |
+
Parameters
|
67 |
+
----------
|
68 |
+
target_size : int
|
69 |
+
Number of units for the output layer.
|
70 |
+
pretrained : bool
|
71 |
+
Determine if pretrained weights are used.
|
72 |
+
|
73 |
+
Attributes
|
74 |
+
----------
|
75 |
+
model : nn.Module
|
76 |
+
CustomViT model.
|
77 |
+
"""
|
78 |
+
def __init__(self, model_name : str = 'vit_base_patch16_224',
|
79 |
+
target_size : int = 4, pretrained : bool = True):
|
80 |
+
super().__init__()
|
81 |
+
self.model = timm.create_model(model_name,
|
82 |
+
pretrained=pretrained,
|
83 |
+
num_classes=target_size)
|
84 |
+
|
85 |
+
in_features = self.model.head.in_features
|
86 |
+
self.model.head = nn.Sequential(
|
87 |
+
#nn.Dropout(0.5),
|
88 |
+
nn.Linear(in_features, 256),
|
89 |
+
nn.ReLU(),
|
90 |
+
nn.Dropout(0.5),
|
91 |
+
nn.Linear(256, target_size)
|
92 |
+
)
|
93 |
+
|
94 |
+
def forward(self, x : torch.Tensor) -> torch.Tensor:
|
95 |
+
x = self.model(x)
|
96 |
+
|
97 |
+
return x
|
model.py
CHANGED
@@ -6,6 +6,8 @@ import torch
|
|
6 |
import numpy as np
|
7 |
import torchvision
|
8 |
|
|
|
|
|
9 |
import icevision.models.ross.efficientdet
|
10 |
|
11 |
MODEL_TYPE = icevision.models.ross.efficientdet
|
@@ -81,10 +83,14 @@ def predict_class(model, image, bboxes):
|
|
81 |
img = image.copy()
|
82 |
bbox = np.array(bbox).astype(int)
|
83 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
84 |
-
cropped_img = np.array(cropped_img)
|
85 |
-
cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
88 |
preds.append(y_preds.softmax(1).detach().numpy())
|
89 |
|
90 |
preds = np.concatenate(preds).argmax(1)
|
|
|
6 |
import numpy as np
|
7 |
import torchvision
|
8 |
|
9 |
+
from classifier import transform_image
|
10 |
+
|
11 |
import icevision.models.ross.efficientdet
|
12 |
|
13 |
MODEL_TYPE = icevision.models.ross.efficientdet
|
|
|
83 |
img = image.copy()
|
84 |
bbox = np.array(bbox).astype(int)
|
85 |
cropped_img = PIL.Image.fromarray(img).crop(bbox)
|
86 |
+
cropped_img = np.array(cropped_img)
|
87 |
+
#cropped_img = torch.as_tensor(cropped_img, dtype=torch.float).unsqueeze(0)
|
88 |
+
|
89 |
+
tran_image = transform_image(cropped_img, 224)
|
90 |
+
tran_image = tran_image.transpose(2, 0, 1)
|
91 |
+
tran_image = torch.as_tensor(tran_image, dtype=torch.float).unsqueeze(0)
|
92 |
+
print(tran_image.shape)
|
93 |
+
y_preds = model(tran_image)
|
94 |
preds.append(y_preds.softmax(1).detach().numpy())
|
95 |
|
96 |
preds = np.concatenate(preds).argmax(1)
|