# © Recursion Pharmaceuticals 2024
from typing import Dict | |
import timm.models.vision_transformer as vit | |
import torch | |
def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """Build the ImageNet-pretrained timm ViT encoders used as baselines.

    Every backbone is created through ``_make_vit`` and then wrapped,
    scripted and frozen by ``_make_torchscripted_encoder``.

    Returns:
        A dict mapping each model name to its TorchScript encoder, in the
        order the models are listed below.
    """
    # Keep each name next to its constructor so the two cannot drift apart.
    constructors_by_name = {
        "vit_small_patch16_384": vit.vit_small_patch16_384,
        "vit_base_patch16_384": vit.vit_base_patch16_384,
        "vit_base_patch8_224": vit.vit_base_patch8_224,
        "vit_large_patch16_384": vit.vit_large_patch16_384,
    }
    # Instantiate all backbones first, then script them one after another.
    backbones = [_make_vit(ctor) for ctor in constructors_by_name.values()]
    scripted = [_make_torchscripted_encoder(backbone) for backbone in backbones]
    return dict(zip(constructors_by_name, scripted))
def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
    """Wrap a backbone with input preprocessing and compile it to TorchScript.

    The assembled pipeline is ``Normalizer -> LazyInstanceNorm2d ->
    vit_backbone``. It is run once on a dummy batch, switched to eval mode,
    scripted with ``torch.jit.script`` and frozen with ``torch.jit.freeze``.

    Args:
        vit_backbone: Module appended after the two preprocessing layers. It
            must accept the (2, 6, 256, 256) warm-up batch created below.

    Returns:
        The frozen TorchScript module for the whole pipeline.
    """
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),
        torch.nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        ),  # this module performs self-standardization, very important
        vit_backbone,
    ).to(device="cpu")
    # One throwaway forward pass gets the lazy modules built before scripting.
    # Its output is discarded, so skip recording an autograd graph for it.
    with torch.no_grad():
        encoder(dummy_input)
    return torch.jit.freeze(torch.jit.script(encoder.eval()))
def _make_vit(constructor):
    """Call a ViT constructor with the fixed baseline configuration.

    Args:
        constructor: Callable accepting the keyword arguments listed below
            (here, one of the timm ``vision_transformer`` factory functions).

    Returns:
        Whatever ``constructor`` returns for that configuration.
    """
    vit_kwargs = {
        "pretrained": True,  # download imagenet weights
        "img_size": 256,  # 256x256 crops
        "in_chans": 6,  # we expect 6-channel microscopy images
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "global_pool": "avg",  # minimal perf diff btwn "cls" and "avg"
    }
    return constructor(**vit_kwargs)
class Normalizer(torch.nn.Module):
    """Cast pixel values to float32 and divide them by 255."""

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        """Return ``pixels`` as float32, scaled by 1/255.

        Args:
            pixels: Tensor of pixel intensities. The encoder in this file is
                warmed up with uint8 data, but any dtype that
                ``Tensor.float()`` accepts works.

        Returns:
            A new float32 tensor; ``pixels`` itself is left unmodified.
        """
        # Divide out of place on purpose: ``Tensor.float()`` hands back the
        # very same tensor when the input is already float32, so an in-place
        # ``/=`` would silently rescale the caller's data.
        return pixels.float() / 255.0