Spaces:
Runtime error
Runtime error
File size: 2,621 Bytes
9fbf078 e5bb367 9fbf078 e5bb367 9fbf078 e5bb367 9fbf078 e5bb367 9fbf078 e5bb367 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import timm
import torch.nn as nn
import albumentations as A
import torch
import cv2
class CustomNormalization(A.ImageOnlyTransform):
    """
    Albumentations image-only transform that divides pixel values by 255.

    The result is returned as ``float32``. Dividing a uint8 array by a Python
    float directly (``img / 255.``) produces ``float64``. A float64 array
    converted to a tensor becomes a double tensor, which does not match the
    default float32 parameters of a PyTorch model and raises a dtype error in
    the forward pass. NOTE(review): the caller that converts the output to a
    tensor is not visible here, so confirm no code depends on float64 output.
    """

    def _norm(self, img):
        # copy=False skips a redundant copy when img is already float32.
        # float32 array / Python float stays float32 under NumPy promotion.
        return img.astype('float32', copy=False) / 255.

    def apply(self, img, **params):
        # Albumentations calls apply() with the image array; extra params
        # (e.g. shape info) are not needed for a pointwise rescale.
        return self._norm(img)
def transform_image(image, size):
    """
    Resize an image to a ``size`` x ``size`` square and rescale it by 1/255.

    Parameters
    ----------
    image : array-like
        Image passed to an Albumentations pipeline (``image=`` keyword).
    size : int
        Target height and width in pixels.

    Returns
    -------
    The transformed image taken from the pipeline's ``'image'`` entry.
    """
    # Nearest-neighbour resize is kept as-is; changing the interpolation
    # would alter the inputs the models were presumably trained on.
    pipeline = A.Compose([
        A.Resize(size, size, interpolation=cv2.INTER_NEAREST),
        CustomNormalization(p=1),
    ])
    result = pipeline(image=image)
    return result['image']
class CustomEfficientNet(nn.Module):
    """
    timm EfficientNet whose classifier is replaced by a two-layer MLP head.

    Parameters
    ----------
    model_name : str
        Model identifier passed to ``timm.create_model``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.

    Attributes
    ----------
    model : nn.Module
        The timm model, with ``classifier`` set to
        ``Linear(in_features, 256) -> ReLU -> Linear(256, target_size)``.
    """

    def __init__(self, model_name: str = 'efficientnet_b0',
                 target_size: int = 4, pretrained: bool = True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # The new head consumes the same feature width as the stock classifier.
        # NOTE(review): requires ``backbone.classifier`` to expose
        # ``in_features`` (an nn.Linear) for the chosen model_name.
        n_features = backbone.classifier.in_features
        # Child order (0: Linear, 1: ReLU, 2: Linear) determines state_dict
        # keys such as ``model.classifier.2.weight``; keep it stable so saved
        # checkpoints still load.
        backbone.classifier = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.ReLU(),
            nn.Linear(256, target_size),
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the final Linear layer of the head for ``x``."""
        return self.model(x)
class CustomViT(nn.Module):
    """
    timm Vision Transformer whose ``head`` is replaced by an MLP with dropout.

    Parameters
    ----------
    model_name : str
        Model identifier passed to ``timm.create_model``.
    target_size : int
        Number of units for the output layer.
    pretrained : bool
        Determine if pretrained weights are used.

    Attributes
    ----------
    model : nn.Module
        The timm model, with ``head`` set to
        ``Linear(in_features, 256) -> ReLU -> Dropout(0.5) -> Linear(256, target_size)``.
    """

    def __init__(self, model_name: str = 'vit_base_patch16_224',
                 target_size: int = 4, pretrained: bool = True):
        super().__init__()
        # num_classes is passed even though the head is replaced below; it is
        # kept so construction (and parameter initialisation order) is unchanged.
        backbone = timm.create_model(model_name,
                                     pretrained=pretrained,
                                     num_classes=target_size)
        # NOTE(review): requires ``backbone.head`` to expose ``in_features``
        # (an nn.Linear) for the chosen model_name.
        n_features = backbone.head.in_features
        # Child order (0: Linear, 1: ReLU, 2: Dropout, 3: Linear) determines
        # state_dict keys such as ``model.head.3.weight``; keep it stable.
        backbone.head = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, target_size),
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the final Linear layer of the head for ``x``."""
        return self.model(x)
|