|
import timm |
|
from torch import nn |
|
from torch.nn import functional as F |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.core.mixins import HyperparametersMixin |
|
|
|
|
|
class SyntheticModel(pl.LightningModule, HyperparametersMixin):
    """Image classifier: a timm backbone followed by a small MLP head.

    The backbone is created with ``num_classes=0`` so that timm returns
    pooled features instead of logits. The MLP head (``self.clf``) maps
    those features to ``num_classes`` scores.

    All defaults reproduce the original hard-coded configuration, so
    ``SyntheticModel()`` builds the same network as before.

    Args:
        backbone: timm model name used to build the feature extractor.
        pretrained: whether timm should load pretrained backbone weights.
        hidden_dim: width of the hidden layer in the classification head.
        num_classes: number of output scores produced by :meth:`forward`.
    """

    def __init__(
        self,
        backbone: str = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384',
        pretrained: bool = False,
        hidden_dim: int = 128,
        num_classes: int = 2,
    ):
        super().__init__()
        # Record the constructor arguments so ``load_from_checkpoint`` can
        # rebuild the same architecture. LightningModule already derives
        # from HyperparametersMixin; the explicit base is kept only for
        # compatibility.
        self.save_hyperparameters()

        self.model = timm.create_model(backbone,
                                       pretrained=pretrained,
                                       num_classes=0)

        # Derive the feature width from the backbone instead of hard-coding
        # it (previously 1536, valid only for the default ConvNeXt-Large).
        # Backbones with an MLP pre-logits stage expose ``head_hidden_size``;
        # otherwise fall back to ``num_features``.
        in_features = (getattr(self.model, 'head_hidden_size', None)
                       or self.model.num_features)

        self.clf = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, image):
        """Return classification scores for a batch of images.

        Args:
            image: input tensor passed straight to the timm backbone.
                Presumably (batch, 3, H, W); the default backbone name
                suggests 384x384 inputs — TODO confirm.

        Returns:
            Output of the MLP head, whose last dimension is ``num_classes``.
        """
        image_features = self.model(image)
        return self.clf(image_features)