OpenPhenom / vit.py
Kian Kenyon-Dean
Reformat and add comments (#9)
560d738 unverified
raw
history blame
9.88 kB
# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch
def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cosine positional embeddings.

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in vit; must be divisible by 4 because it is split
        evenly into sin and cos features for the height axis and for the width axis
    length : int
        number of tokens along height or width of image after patching (assuming square)
    scale : float
        scale for sin/cos functions
    use_class_token : bool
        True - prepend a zero vector to be added to class_token, False - no vector added
    num_modality : int
        number of channel modalities; the same spatial encoding is tiled this many
        times along the token axis

    Returns
    -------
    positional_encoding : torch.nn.Parameter
        frozen (requires_grad=False) positional encoding to add to vit patch encodings,
        shaped [1, num_modality*length*length, embedding_dim] without the class-token row
        or [1, 1+num_modality*length*length, embedding_dim] with it

    Raises
    ------
    ValueError
        if embedding_dim is not divisible by 4
    """
    # Without this check the encoding silently comes out narrower than embedding_dim
    # (or fails later with an opaque size mismatch when the class-token row is attached).
    if embedding_dim % 4 != 0:
        raise ValueError(
            "embedding_dim must be divisible by 4 for 2D sin/cos embeddings, "
            f"got {embedding_dim}"
        )

    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )

    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)

    # one row per grid position (row-major over the grid), one column per frequency
    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)

    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        # the class token gets an all-zero positional vector at index 0
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    return torch.nn.Parameter(positional_encoding, requires_grad=False)
class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    """Patch embedding that runs one shared single-channel projection over every input channel."""

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        """
        Parameters
        ----------
        img_size : int
            forwarded to the parent patch embedding
        patch_size : int
            forwarded to the parent; also used as kernel size and stride of the projection
        embed_dim : int
            number of output channels of the projection
        bias : bool
            whether the projection convolution has a bias term
        """
        # in_chans is used by the parent's self.proj, which is replaced right below
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(
            in_channels=1,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project each channel of x separately with the shared convolution and return
        all resulting patch tokens as one sequence.

        The input is sliced along dim 1 one channel at a time
        (assumes x is laid out as batch, channel, height, width - TODO confirm).
        """
        per_channel = []
        for idx in range(x.shape[1]):
            # keep the channel axis (size 1) so the single-channel conv accepts the slice
            per_channel.append(self.proj(x.narrow(1, idx, 1)))

        # the new axis at dim=2 indexes the input channel: B, C(embed), M(chans), H, W
        stacked = torch.stack(per_channel, dim=2)

        # BCMHW -> BNC: merge (M, H, W) into one token axis, then put the embedding last
        return stacked.flatten(2).transpose(1, 2)
class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    """VisionTransformer variant whose positional embedding is sliced to the input's token count."""

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional embeddings to the token sequence x, prepend the class token
        when the model has one, and apply positional dropout.

        Rewrite of
        https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586

        MAIN DIFFERENCE with timm: the positional embedding is sliced DYNAMICALLY to the
        number of tokens in the input, which lets CA-MAEs actually be channel-agnostic
        at inference time.
        """
        batch_size, num_tokens = x.shape[0], x.shape[1]

        prefix_tokens = []
        if self.cls_token is not None:
            prefix_tokens.append(self.cls_token.expand(batch_size, -1, -1))
        # TODO: upgrade timm to get access to register tokens; once available,
        # self.reg_token should be expanded and appended to prefix_tokens here as well.

        if self.no_embed_class:
            # positions cover patch tokens only; prefix tokens are attached afterwards
            x = x + self.pos_embed[:, :num_tokens]
            if prefix_tokens:
                x = torch.cat([*prefix_tokens, x], dim=1)
        else:
            # prefix tokens are attached first, so the slice spans them as well
            if prefix_tokens:
                x = torch.cat([*prefix_tokens, x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]

        return self.pos_drop(x)  # type: ignore[no-any-return]
def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    """Convert a constructed ViT backbone, in place, into its channel-agnostic form.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the model to modify; its patch embedding, positional embedding and class are replaced
    max_in_chans : int
        number of channel modalities the positional embedding is tiled for

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same object that was passed in, now an instance of ChannelAgnosticViT
    """
    # swap in the shared single-channel patch projection, sized like the existing one
    previous_embed = vit_backbone.patch_embed
    channel_agnostic_embed = ChannelAgnosticPatchEmbed(
        img_size=previous_embed.img_size[0],
        patch_size=previous_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )
    vit_backbone.patch_embed = channel_agnostic_embed

    # fixed sin/cos positions, tiled once per channel modality
    has_class_token = vit_backbone.cls_token is not None
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=channel_agnostic_embed.grid_size[0],
        use_class_token=has_class_token,
        num_modality=max_in_chans,
    )

    # re-class the instance so that the overridden _pos_embed is actually the one used
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone
def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attach no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm; modified in place
    scale : float (default 10000.0)
        hyperparameter for sincos positional embeddings, recommend keeping at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
    """
    patch_embed = vit_backbone.patch_embed

    # number of tokens along height or width of image after patching (assuming square)
    tokens_per_side = patch_embed.img_size[0] // patch_embed.patch_size[0]

    # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=tokens_per_side,
        scale=scale,
        use_class_token=vit_backbone.cls_token is not None,
    )
    return vit_backbone
def vit_small_patch16_256(**kwargs):
    """Build a model through timm's ``vit_small_patch16_224`` factory with 256px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_small_patch16_224(**{**defaults, **kwargs})
def vit_small_patch32_512(**kwargs):
    """Build a model through timm's ``vit_small_patch32_384`` factory with 512px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 512,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_small_patch32_384(**{**defaults, **kwargs})
def vit_base_patch8_256(**kwargs):
    """Build a model through timm's ``vit_base_patch8_224`` factory with 256px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_base_patch8_224(**{**defaults, **kwargs})
def vit_base_patch16_256(**kwargs):
    """Build a model through timm's ``vit_base_patch16_224`` factory with 256px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_base_patch16_224(**{**defaults, **kwargs})
def vit_base_patch32_512(**kwargs):
    """Build a model through timm's ``vit_base_patch32_384`` factory with 512px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 512,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_base_patch32_384(**{**defaults, **kwargs})
def vit_large_patch8_256(**kwargs):
    """Construct ``vit.VisionTransformer`` directly with 256px, 6-channel, patch-size-8 defaults.

    Unlike the sibling factories this one does not go through a named timm factory,
    so the architecture (embed_dim, depth, num_heads) is spelled out in the defaults.
    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "patch_size": 8,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "drop_path_rate": 0.3,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.VisionTransformer(**{**defaults, **kwargs})
def vit_large_patch16_256(**kwargs):
    """Build a model through timm's ``vit_large_patch16_384`` factory with 256px, 6-channel defaults.

    Any keyword argument given here overrides the default of the same name;
    keywords not listed below are passed through unchanged.
    """
    defaults = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.3,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    return vit.vit_large_patch16_384(**{**defaults, **kwargs})