Waste-Detector / classifier.py
Hector Lopez
feature: Use ViT as classifier
e5bb367
import timm
import torch.nn as nn
import albumentations as A
import torch
import cv2
class CustomNormalization(A.ImageOnlyTransform):
    """Scale pixel values by dividing them by 255.

    Applied only to the ``image`` target (this is an ``A.ImageOnlyTransform``).
    For a NumPy ``uint8`` image in ``[0, 255]`` the result is a float array in
    ``[0, 1]``. The division is unconditional, so an image that is already
    normalized would be scaled again.
    """

    def _norm(self, img):
        # True division promotes integer arrays to floating point.
        return img / 255.

    def apply(self, img, **params):
        return self._norm(img)

    def get_transform_init_args_names(self):
        # This transform has no constructor arguments of its own. Without this
        # override albumentations' default implementation raises
        # NotImplementedError, which breaks repr() of this transform (and of
        # any A.Compose that contains it) as well as serialization.
        return ()
def transform_image(image, size, *, interpolation : int = cv2.INTER_NEAREST):
    """Resize ``image`` to ``size`` x ``size`` and scale its pixels by 1/255.

    Parameters
    ----------
    image : array-like
        Input image passed to albumentations (presumably an HxWxC NumPy
        array as loaded by cv2 -- confirm against callers).
    size : int
        Target height and width, in pixels.
    interpolation : int, optional, keyword-only
        OpenCV interpolation flag forwarded to ``A.Resize``. Defaults to
        ``cv2.INTER_NEAREST``, the previously hard-coded value. Any other
        choice (e.g. ``cv2.INTER_AREA`` for downscaling) should match the
        interpolation used when the model was trained.

    Returns
    -------
    The resized image after ``CustomNormalization`` (division by 255.).
    """
    pipeline = A.Compose([
        A.Resize(size, size, interpolation=interpolation),
        # p=1 so normalization is always applied.
        CustomNormalization(p=1),
    ])
    return pipeline(image=image)['image']
class CustomEfficientNet(nn.Module):
    """
    A ``timm`` EfficientNet whose classifier is replaced by a two-layer MLP.

    Parameters
    ----------
    model_name : str
        Name of the ``timm`` architecture to build. The model must expose a
        ``classifier`` layer with an ``in_features`` attribute.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.

    Attributes
    ----------
    model : nn.Module
        The ``timm`` model, with ``classifier`` set to
        ``Linear(in_features, 256) -> ReLU -> Linear(256, target_size)``.
    """

    def __init__(self, model_name : str = 'efficientnet_b0',
                 target_size : int = 4, pretrained : bool = True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        self.model = backbone

        # Width of the features that fed the original classifier layer.
        feature_dim = backbone.classifier.in_features

        # Keep this layer order: it fixes the state_dict keys
        # (classifier.0 / classifier.2) that saved checkpoints rely on.
        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, target_size),
        ]
        backbone.classifier = nn.Sequential(*head_layers)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the model; returns the raw final-Linear output."""
        return self.model(x)
class CustomViT(nn.Module):
    """
    A ``timm`` Vision Transformer whose head is replaced by a two-layer MLP.

    Parameters
    ----------
    model_name : str
        Name of the ``timm`` architecture to build. The model must expose a
        ``head`` layer with an ``in_features`` attribute.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.

    Attributes
    ----------
    model : nn.Module
        The ``timm`` model, with ``head`` set to
        ``Linear(in_features, 256) -> ReLU -> Dropout(0.5) -> Linear(256, target_size)``.
    """

    def __init__(self, model_name : str = 'vit_base_patch16_224',
                 target_size : int = 4, pretrained : bool = True):
        super().__init__()
        # num_classes sizes the head timm builds; that head is only read for
        # its in_features and then replaced below.
        backbone = timm.create_model(model_name,
                                     pretrained=pretrained,
                                     num_classes=target_size)
        self.model = backbone

        feature_dim = backbone.head.in_features

        # Keep this layer order: it fixes the state_dict keys
        # (head.0 / head.3) that saved checkpoints rely on.
        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, target_size),
        ]
        backbone.head = nn.Sequential(*head_layers)

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the model; returns the raw final-Linear output."""
        return self.model(x)