import albumentations as A
import cv2
import timm
import torch
import torch.nn as nn


class CustomNormalization(A.ImageOnlyTransform):
    """Scale pixel values from [0, 255] to [0, 1]."""

    def _norm(self, img):
        return img / 255.0

    def apply(self, img, **params):
        return self._norm(img)


def transform_image(image, size):
    """Resize an image to (size, size) and scale its pixel values to [0, 1].

    Parameters
    ----------
    image : np.ndarray
        Input image in HWC format.
    size : int
        Target height and width.

    Returns
    -------
    np.ndarray
        Transformed image.
    """
    transforms = [
        A.Resize(size, size, interpolation=cv2.INTER_NEAREST),
        CustomNormalization(p=1),
    ]
    augs = A.Compose(transforms)
    transformed = augs(image=image)
    return transformed['image']


class CustomEfficientNet(nn.Module):
    """
    This class defines a custom EfficientNet network.

    Parameters
    ----------
    model_name : str
        Name of the timm model to instantiate.
    target_size : int
        Number of units in the output layer.
    pretrained : bool
        Whether to load pretrained weights.

    Attributes
    ----------
    model : nn.Module
        EfficientNet model.
    """

    def __init__(self,
                 model_name: str = 'efficientnet_b0',
                 target_size: int = 4,
                 pretrained: bool = True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Replace the classifier with a small MLP head
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(256, target_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x


class CustomViT(nn.Module):
    """
    This class defines a custom ViT network.

    Parameters
    ----------
    model_name : str
        Name of the timm model to instantiate.
    target_size : int
        Number of units in the output layer.
    pretrained : bool
        Whether to load pretrained weights.

    Attributes
    ----------
    model : nn.Module
        ViT model.
    """

    def __init__(self,
                 model_name: str = 'vit_base_patch16_224',
                 target_size: int = 4,
                 pretrained: bool = True):
        super().__init__()
        self.model = timm.create_model(model_name,
                                       pretrained=pretrained,
                                       num_classes=target_size)
        # Replace the classification head with a small MLP head
        in_features = self.model.head.in_features
        self.model.head = nn.Sequential(
            # nn.Dropout(0.5),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, target_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x
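

# --- Usage sketch (illustrative only) ---
# A minimal example of how the transform and models above could be wired
# together. The image path 'sample.jpg' and the input size 224 are
# hypothetical choices for demonstration; pretrained=False is used here only
# to avoid downloading weights and is not implied by the module itself.
if __name__ == '__main__':
    image = cv2.imread('sample.jpg')                    # HWC, BGR, uint8
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)      # OpenCV loads BGR; convert to RGB
    image = transform_image(image, size=224)            # resize + scale to [0, 1]

    # Albumentations returns an HWC numpy array; PyTorch expects NCHW floats.
    tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float()

    with torch.no_grad():
        effnet = CustomEfficientNet(target_size=4, pretrained=False).eval()
        vit = CustomViT(target_size=4, pretrained=False).eval()
        print(effnet(tensor).shape)                     # torch.Size([1, 4])
        print(vit(tensor).shape)                        # torch.Size([1, 4])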