import timm
import torch
from torch import nn
class Model200M(torch.nn.Module):
    """Image classifier: a timm ConvNeXt backbone followed by a small MLP head.

    The backbone is created through ``timm.create_model`` with
    ``pretrained=False`` (no pretrained weights are downloaded here) and
    ``num_classes=0``. The head ``clf`` maps a 1536-wide feature vector to
    128 hidden units and then to 2 output values.

    Submodule names (``model``, ``clf``) and the layer order inside ``clf``
    determine the ``state_dict`` keys, so they must stay as they are for
    existing checkpoints to keep loading.
    """

    # timm model identifier of the backbone architecture.
    _BACKBONE_NAME = 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384'
    # Width of the feature vector the head consumes from the backbone.
    _FEATURE_DIM = 1536
    # Width of the single hidden layer in the head.
    _HIDDEN_DIM = 128
    # Number of values produced per input by the final linear layer.
    _NUM_OUTPUTS = 2

    def __init__(self) -> None:
        """Build the backbone and the classification head."""
        super().__init__()
        # NOTE(review): num_classes=0 is timm's convention for dropping the
        # classification layer so the model returns features instead of
        # logits -- confirm against the installed timm version.
        self.model = timm.create_model(
            self._BACKBONE_NAME,
            pretrained=False,
            num_classes=0,
        )
        self.clf = nn.Sequential(
            nn.Linear(self._FEATURE_DIM, self._HIDDEN_DIM),
            nn.ReLU(inplace=True),
            nn.Linear(self._HIDDEN_DIM, self._NUM_OUTPUTS),
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``image`` and apply the head to its output.

        Args:
            image: Input passed unchanged to the backbone. Presumably a
                batch of images shaped ``(batch, 3, H, W)`` -- TODO confirm
                against the caller's preprocessing.

        Returns:
            The output of ``clf``; its last dimension has size 2. No
            softmax or other normalisation is applied here.
        """
        image_features = self.model(image)
        return self.clf(image_features)