medical
AleksanderObuchowski's picture
Add files using upload-large-folder tool
5ceacbc verified
raw
history blame
26.7 kB
import logging
import os
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from collections import OrderedDict
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
# helper methods
from .registry import register_image_encoder
import mup.init
from mup import MuReadout, set_base_shapes
logger = logging.getLogger(__name__)
class MySequential(nn.Sequential):
    """Sequential container whose stages may exchange several values.

    When the value flowing between stages is a plain ``tuple`` it is
    unpacked into positional arguments of the next stage; any other value
    is handed over as a single argument.
    """

    def forward(self, *inputs):
        """Run the child modules in registration order.

        Args:
            *inputs: Positional arguments for the first child module.

        Returns:
            Whatever the last child module returns (the ``inputs`` tuple
            itself when the container has no children).
        """
        current = inputs
        for stage in self._modules.values():
            # Exact type check on purpose: tuple subclasses are passed whole.
            if type(current) is tuple:
                current = stage(*current)
            else:
                current = stage(current)
        return current
class PreNorm(nn.Module):
    """Residual wrapper computing ``x + drop_path(fn(norm(x)))``.

    Args:
        norm: Callable applied to the input before ``fn``, or ``None`` to
            feed the input to ``fn`` unchanged.
        fn: Callable invoked as ``fn(x, *args, **kwargs)``; it must return a
            pair ``(x, size)``.
        drop_path: Optional callable applied to the output of ``fn`` before
            the residual addition.
    """

    def __init__(self, norm, fn, drop_path=None):
        super().__init__()
        self.norm = norm
        self.fn = fn
        self.drop_path = drop_path

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped function and add the input back as a shortcut.

        Returns:
            Tuple ``(x, size)`` where ``size`` is whatever ``fn`` returned.
        """
        shortcut = x
        # Identity comparison: ``!= None`` would dispatch to a user-defined
        # ``__ne__`` and is discouraged by PEP 8.
        if self.norm is not None:
            x = self.norm(x)
        x, size = self.fn(x, *args, **kwargs)

        # Truthiness is kept here so that any falsy placeholder a caller may
        # already pass for ``drop_path`` is still skipped.
        if self.drop_path:
            x = self.drop_path(x)

        return shortcut + x, size
class Mlp(nn.Module):
    """ Two-layer perceptron (fc1 -> activation -> fc2) that forwards ``size`` untouched.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Width of the output; falls back to ``in_features``
            when falsy.
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        output_dim = out_features or in_features
        hidden_dim = hidden_features or in_features

        # Sub-module names (fc1 / act / fc2) are part of the state-dict keys.
        layers = OrderedDict()
        layers["fc1"] = nn.Linear(in_features, hidden_dim)
        layers["act"] = act_layer()
        layers["fc2"] = nn.Linear(hidden_dim, output_dim)
        self.net = nn.Sequential(layers)

    def forward(self, x, size):
        """Return ``(net(x), size)``; ``size`` is passed through unchanged."""
        out = self.net(x)
        return out, size
class DepthWiseConv2d(nn.Module):
    """ Depth-wise 2D convolution applied to a flattened token sequence.

    Args:
        dim_in: Number of channels (also used as the number of conv groups).
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,  # one group per channel -> depth-wise
            stride=stride,
            bias=bias,
        )
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_options)

    def forward(self, x, size):
        """Convolve tokens ``x`` of shape (B, N, C) laid out on an ``H x W`` grid.

        Args:
            x: Tensor of shape (B, N, C) with ``N == H * W``.
            size: Pair ``(H, W)``.

        Returns:
            Tuple of the convolved tokens, flattened back to (B, N', C), and
            the new spatial size ``(H', W')``.
        """
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        new_size = (feature_map.size(-2), feature_map.size(-1))
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, new_size
class ConvEmbed(nn.Module):
    """ Convolutional patch embedding with optional pre- or post-normalization.

    Args:
        patch_size: Kernel size of the projection convolution.
        in_chans: Number of input channels.
        embed_dim: Number of output channels.
        stride: Stride of the projection convolution.
        padding: Padding of the projection convolution.
        norm_layer: Factory called with the normalized width, or a falsy
            value for no normalization.
        pre_norm: If True the norm is sized for ``in_chans`` and applied
            before the projection; otherwise it is sized for ``embed_dim``
            and applied after it.
    """

    def __init__(
        self,
        patch_size=7,
        in_chans=3,
        embed_dim=64,
        stride=4,
        padding=2,
        norm_layer=None,
        pre_norm=True
    ):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding,
        )

        if norm_layer:
            self.norm = norm_layer(in_chans if pre_norm else embed_dim)
        else:
            self.norm = None
        self.pre_norm = pre_norm

    def forward(self, x, size):
        """Project ``x`` and return the tokens with their new grid size.

        Args:
            x: Either a 3-D token tensor (B, H*W, C) or a tensor that the
                projection convolution accepts directly.
            size: Pair ``(H, W)`` used to unflatten a 3-D input.

        Returns:
            Tuple of tokens with shape (B, H'*W', embed_dim) and ``(H', W')``.
        """
        height, width = size
        if x.dim() == 3:
            # Pre-normalization only happens on this token-shaped path.
            if self.pre_norm and self.norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)

        x = self.proj(x)

        _, _, height, width = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if not self.pre_norm and self.norm:
            x = self.norm(x)

        return x, (height, width)
class ChannelAttention(nn.Module):
    """ Grouped attention computed across channels instead of tokens.

    Within each group the attention matrix has shape
    (group_dim, group_dim): channels attend to channels, with the token
    axis being contracted.

    Args:
        dim: Channel width of the input.
        base_dim: Channel width of the muP base model.
        groups: Number of channel groups.
        base_groups: Number of channel groups of the muP base model.
        qkv_bias: Whether the qkv projection has a bias.
        dynamic_scale: If True, queries are scaled by ``N ** -0.5`` (N being
            the number of tokens); otherwise by ``dim ** -0.5``.
        standparam: True for standard parametrization, False for muP.
    """

    def __init__(self, dim, base_dim, groups=8, base_groups=8, qkv_bias=True, dynamic_scale=True, standparam=True):
        super().__init__()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.dynamic_scale = dynamic_scale

        self.dim = dim
        self.groups = groups
        self.group_dim = dim // groups

        self.base_dim = base_dim
        self.base_groups = base_groups
        self.base_group_dim = base_dim // base_groups

        self.group_wm = self.group_dim / self.base_group_dim  # Width multiplier for each group.

        self.standparam = standparam

    def forward(self, x, size):
        """Apply channel attention to tokens ``x`` of shape (B, N, C).

        Args:
            x: Tensor of shape (B, N, C) with ``C == dim``.
            size: Spatial size; returned unchanged.

        Returns:
            Tuple ``(x, size)`` with ``x`` of shape (B, N, C).
        """
        B, N, C = x.shape
        assert C == self.dim

        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Shape: [B, groups, N, group_dim].

        # Change the scaling factor.
        # Ref: examples/Transformer/model.py in muP.
        # Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/18.
        # (The scale is computed exactly once here; an earlier duplicate
        # assignment that was always overwritten has been removed.)
        if self.standparam:
            scale = N ** -0.5 if self.dynamic_scale else self.dim ** -0.5
        else:
            assert self.dynamic_scale  # Currently only support dynamic scale.
            scale = N ** -0.5

        q = q * scale
        attention = q.transpose(-1, -2) @ k  # Shape: [B, groups, group_dim, group_dim].
        attention = attention.softmax(dim=-1)

        if not self.standparam:
            # Follow https://github.com/microsoft/mup/issues/18.
            attention = attention / self.group_wm

        x = (attention @ v.transpose(-1, -2)).transpose(-1, -2)  # Back to [B, groups, N, group_dim].
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x, size
class ChannelBlock(nn.Module):
    """ Channel-attention block: [conv] -> channel attention -> [conv] -> MLP.

    Every stage is wrapped in ``PreNorm`` and therefore residual. The two
    depth-wise convolutions are optional, and the attention and MLP stages
    share a single drop-path module.
    """

    def __init__(self, dim, base_dim, groups, base_groups, mlp_ratio=4., qkv_bias=True,
                 drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 conv_at_attn=True, conv_at_ffn=True, dynamic_scale=True, standparam=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        # Optional 3x3 depth-wise conv (padding 1, stride 1) ahead of attention.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attn_norm = norm_layer(dim)
        attn = ChannelAttention(
            dim,
            base_dim,
            groups=groups,
            base_groups=base_groups,
            qkv_bias=qkv_bias,
            dynamic_scale=dynamic_scale,
            standparam=standparam,
        )
        self.channel_attn = PreNorm(attn_norm, attn, drop_path)

        # Optional 3x3 depth-wise conv ahead of the feed-forward network.
        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        ffn_norm = norm_layer(dim)
        ffn = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
        self.ffn = PreNorm(ffn_norm, ffn, drop_path)

    def forward(self, x, size):
        """Run the enabled stages in order and return ``(x, size)``."""
        for stage in (self.conv1, self.channel_attn, self.conv2, self.ffn):
            if stage is None:
                continue  # disabled optional convolution
            x, size = stage(x, size)
        return x, size
def window_partition(x, window_size: int):
    """Cut a (B, H, W, C) tensor into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C); the view below requires H and W to
            be multiples of ``window_size``.
        window_size: Side length of each window.

    Returns:
        Tensor of shape (B * H/window_size * W/window_size,
        window_size, window_size, C).
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes next to each other before flattening.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size: int, H: int, W: int):
    """Stitch square windows back into a (B, H, W, C) tensor.

    Args:
        windows: Tensor of shape (num_windows * B, window_size, window_size, C).
        window_size: Side length of each window.
        H: Height of the reassembled map.
        W: Width of the reassembled map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    windows_per_image = H * W // window_size // window_size
    batch = windows.shape[0] // windows_per_image
    rows = H // window_size
    cols = W // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave window rows/cols with in-window rows/cols again.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)
class WindowAttention(nn.Module):
    """ Multi-head self-attention restricted to non-overlapping square windows.

    Args:
        dim: Channel width of the input.
        base_dim: Channel width of the muP base model.
        num_heads: Number of attention heads.
        base_num_heads: Number of attention heads of the muP base model.
        window_size: Side length of the attention windows.
        qkv_bias: Whether the qkv projection has a bias.
        standparam: True for standard parametrization, False for muP.
    """

    def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size, qkv_bias=True, standparam=True):
        super().__init__()

        self.window_size = window_size
        self.dim = dim
        self.num_heads = num_heads
        self.base_dim = base_dim
        self.base_num_heads = base_num_heads

        head_dim = dim // num_heads
        base_head_dim = base_dim // base_num_heads

        # Change the scaling factor.
        # Ref: examples/Transformer/model.py in muP.
        # Note: We consider backward compatiblity and follow https://github.com/microsoft/mup/issues/17.
        if standparam:
            self.scale = float(head_dim) ** -0.5
        else:
            # TODO: Here we ensure backward compatibility, which may not be optimal.
            # We may add an argument called backward_comp. If it is set as False, we use
            #     float(head_dim) ** -1 * math.sqrt(attn_mult)
            # as in the Transformer example in muP.
            sp_scale = float(base_head_dim) ** -0.5  # The same as scaling in standard parametrization.
            width_mult = head_dim / base_head_dim  # Width multiplier for each head.
            # Equivalent to (base_head_dim ** 0.5) * (head_dim ** -1), as shown in the muP paper.
            self.scale = sp_scale / width_mult

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """Apply windowed attention to tokens ``x`` of shape (B, H*W, C).

        The map is zero-padded on the right/bottom up to a multiple of the
        window size; the padding is cropped away again after attention.

        Args:
            x: Tensor of shape (B, H*W, C).
            size: Pair ``(H, W)``; returned unchanged.

        Returns:
            Tuple ``(x, size)`` with ``x`` of shape (B, H*W, C).
        """
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        ws = self.window_size
        x = x.view(B, H, W, C)

        # Pad only at the right and bottom edges (left/top padding is zero).
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        windows = window_partition(x, ws)
        windows = windows.view(-1, ws * ws, C)

        B_, N, C = windows.shape
        qkv = self.qkv(windows)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.softmax(scores)

        out = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        out = self.proj(out)

        # merge windows
        out = out.view(-1, ws, ws, C)
        out = window_reverse(out, ws, Hp, Wp)

        if pad_r > 0 or pad_b > 0:
            out = out[:, :H, :W, :].contiguous()

        out = out.view(B, H * W, C)
        return out, size
class SpatialBlock(nn.Module):
    """ Spatial block: [conv] -> window attention -> [conv] -> MLP.

    Every stage is wrapped in ``PreNorm`` and therefore residual. The two
    depth-wise convolutions are optional, and the attention and MLP stages
    share a single drop-path module.
    """

    def __init__(self, dim, base_dim, num_heads, base_num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True, standparam=True):
        super().__init__()

        if drop_path_rate > 0.:
            drop_path = DropPath(drop_path_rate)
        else:
            drop_path = nn.Identity()

        # Optional 3x3 depth-wise conv (padding 1, stride 1) ahead of attention.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attn_norm = norm_layer(dim)
        attn = WindowAttention(
            dim,
            base_dim,
            num_heads,
            base_num_heads,
            window_size,
            qkv_bias=qkv_bias,
            standparam=standparam,
        )
        self.window_attn = PreNorm(attn_norm, attn, drop_path)

        # Optional 3x3 depth-wise conv ahead of the feed-forward network.
        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        ffn_norm = norm_layer(dim)
        ffn = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
        self.ffn = PreNorm(ffn_norm, ffn, drop_path)

    def forward(self, x, size):
        """Run the enabled stages in order and return ``(x, size)``."""
        for stage in (self.conv1, self.window_attn, self.conv2, self.ffn):
            if stage is None:
                continue  # disabled optional convolution
            x, size = stage(x, size)
        return x, size
class DaViT(nn.Module):
    """ DaViT: Dual-Attention Transformer

    Each stage consists of a convolutional patch embedding followed by
    ``depths[i]`` pairs of (spatial block, channel block).

    Args:
        img_size (int | tuple(int)): Input image size; only stored on the instance. Default: 224
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of spatial and channel blocks in different stages. Default: (1, 1, 3, 1)
        patch_size (tuple(int)): Patch sizes in different stages. Default: (7, 2, 2, 2)
        patch_stride (tuple(int)): Patch strides in different stages. Default: (4, 2, 2, 2)
        patch_padding (tuple(int)): Patch padding sizes in different stages. Default: (3, 0, 0, 0)
        patch_prenorm (tuple(bool)): Use pre-normalization or not in different stages. Default: (False, False, False, False)
        embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
        base_embed_dims (tuple(int)): Patch embedding dimension (base case for muP). Default: (64, 128, 192, 256)
        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
        base_num_heads (tuple(int)): Number of attention heads in different layers (base case for muP). Default: (3, 6, 12, 24)
        num_groups (tuple(int)): Number of groups in channel attention in different layers. Default: (3, 6, 12, 24)
        base_num_groups (tuple(int)): Number of groups in channel attention in different layers (base case for muP). Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer used by the patch embeddings and the final norm. Default: nn.LayerNorm.
        enable_checkpoint (bool): If True, enabling checkpoint. Default: False
        conv_at_attn (bool): If True, add convolution layer before attention. Default: True
        conv_at_ffn (bool): If True, add convolution layer before ffn. Default: True
        dynamic_scale (bool): If True, scale of channel attention is respect to the number of tokens. Default: True
        standparam (bool): Use standard parametrization or mu-parametrization. Default: True (i.e., use standard parametrization)
    """

    def __init__(
        self,
        img_size=224,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        base_embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        base_num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        base_num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
        dynamic_scale=True,
        standparam=True
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        self.img_size = img_size

        # One drop-path rate per block; every depth unit holds two blocks
        # (spatial + channel), hence the factor 2.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths) * 2)]

        depth_offset = 0
        convs = []
        blocks = []

        for i in range(self.num_stages):
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i]
            )
            convs.append(conv_embed)

            logger.info(f'=> Depth offset in stage {i}: {depth_offset}')
            block = MySequential(
                *[
                    MySequential(OrderedDict([
                        (
                            'spatial_block', SpatialBlock(
                                embed_dims[i],
                                base_embed_dims[i],
                                num_heads[i],
                                base_num_heads[i],
                                window_size,
                                drop_path_rate=dpr[depth_offset + j * 2],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                                standparam=standparam
                            )
                        ),
                        (
                            'channel_block', ChannelBlock(
                                embed_dims[i],
                                base_embed_dims[i],
                                num_groups[i],
                                base_num_groups[i],
                                drop_path_rate=dpr[depth_offset + j * 2 + 1],
                                qkv_bias=qkv_bias,
                                mlp_ratio=mlp_ratio,
                                conv_at_attn=conv_at_attn,
                                conv_at_ffn=conv_at_ffn,
                                dynamic_scale=dynamic_scale,
                                standparam=standparam
                            )
                        )
                    ])) for j in range(depths[i])
                ]
            )
            blocks.append(block)
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.norms = norm_layer(self.embed_dims[-1])
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        if standparam:
            self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
        else:
            self.head = MuReadout(self.embed_dims[-1], num_classes,
                                  readout_zero_init=True)  # Follow examples/ResNet/resnet.py in muP.

        # Device that checkpoint tensors are moved to in from_state_dict().
        if torch.cuda.is_available():
            self.device = torch.device(type="cuda", index=0)
        else:
            self.device = torch.device(type="cpu")

    def custom_init_weights(self, use_original_init=True):
        """Re-initialize all sub-modules.

        Args:
            use_original_init: True for the original initialization, False
                for the muP-aware initializers from ``mup.init``.
        """
        self.use_original_init = use_original_init
        logger.info('Custom init: {}'.format('original init' if self.use_original_init else 'muP init'))
        self.apply(self._custom_init_weights)

    @property
    def dim_out(self):
        """Embedding width of the last stage."""
        return self.embed_dims[-1]

    def _custom_init_weights(self, m):
        """Initialize a single module ``m`` (callback for ``nn.Module.apply``)."""
        # Customized initialization for weights.
        if self.use_original_init:
            # Original initialization.
            # Note: This is not SP init. We do not implement SP init here.
            custom_trunc_normal_ = trunc_normal_
            custom_normal_ = nn.init.normal_
        else:
            # muP.
            custom_trunc_normal_ = mup.init.trunc_normal_
            custom_normal_ = mup.init.normal_

        # These initializations will overwrite the existing initializations from the modules and adjusted by set_base_shapes().
        if isinstance(m, MuReadout):
            pass  # Note: MuReadout is already zero initialized due to readout_zero_init=True.
        elif isinstance(m, nn.Linear):
            custom_trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            custom_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):  # Follow P24 Layernorm Weights and Biases.
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0)

    def _try_remap_keys(self, pretrained_dict):
        """Map checkpoint keys of an alternative naming scheme to this model's names.

        Each substring listed in ``remap_keys`` is replaced in turn; later
        replacements see the result of earlier ones.

        Returns:
            Dict mapping every original key to its (possibly unchanged) new key.
        """
        remap_keys = {
            "conv_embeds": "convs",
            "main_blocks": "blocks",
            "0.cpe.0.proj": "spatial_block.conv1.fn.dw",
            "0.attn": "spatial_block.window_attn.fn",
            "0.cpe.1.proj": "spatial_block.conv2.fn.dw",
            "0.mlp": "spatial_block.ffn.fn.net",
            "1.cpe.0.proj": "channel_block.conv1.fn.dw",
            "1.attn": "channel_block.channel_attn.fn",
            "1.cpe.1.proj": "channel_block.conv2.fn.dw",
            "1.mlp": "channel_block.ffn.fn.net",
            "0.norm1": "spatial_block.window_attn.norm",
            "0.norm2": "spatial_block.ffn.norm",
            "1.norm1": "channel_block.channel_attn.norm",
            "1.norm2": "channel_block.ffn.norm"
        }

        full_key_mappings = {}
        for old_k in pretrained_dict:
            k = old_k
            for remap_key, replacement in remap_keys.items():
                if remap_key in k:
                    logger.info(f'=> Repace {remap_key} with {replacement}')
                    k = k.replace(remap_key, replacement)

            full_key_mappings[old_k] = k

        return full_key_mappings

    def from_state_dict(self, pretrained_dict, pretrained_layers=None, verbose=True):
        """Load matching entries of ``pretrained_dict`` into the model.

        Keys are remapped with ``_try_remap_keys``, a leading
        ``'image_encoder.'`` prefix is stripped, and keys unknown to the
        model are ignored. Loading is non-strict.

        Args:
            pretrained_dict: Mapping from parameter name to tensor.
            pretrained_layers: Top-level module names to load, or a sequence
                whose first element is ``'*'`` to load everything. ``None``
                or an empty sequence loads nothing.
            verbose: Log every key that is loaded.
        """
        layers = pretrained_layers if pretrained_layers is not None else []
        # Guard the index: an empty selection used to raise IndexError here.
        load_all = len(layers) > 0 and layers[0] == '*'

        prefix = 'image_encoder.'

        def stripped_key(key):
            return key[len(prefix):] if key.startswith(prefix) else key

        model_dict = self.state_dict()
        full_key_mappings = self._try_remap_keys(pretrained_dict)

        need_init_state_dict = {}
        for old_k, v in pretrained_dict.items():
            k = stripped_key(full_key_mappings[old_k])
            if k not in model_dict:
                continue
            if k.split('.')[0] in layers or load_all:
                if verbose:
                    logger.info(f'=> init {k} from pretrained state dict')
                need_init_state_dict[k] = v.to(self.device)

        self.load_state_dict(need_init_state_dict, strict=False)

    def from_pretrained(self, pretrained='', pretrained_layers=None, verbose=True):
        """Load weights from the checkpoint file ``pretrained`` if it exists.

        Args:
            pretrained: Path of the checkpoint file.
            pretrained_layers: Forwarded to ``from_state_dict``.
            verbose: Forwarded to ``from_state_dict``.
        """
        if not os.path.isfile(pretrained):
            logger.warning(f'=> pretrained model {pretrained} not found, skip loading')
            return

        logger.info(f'=> loading pretrained model {pretrained}')
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        pretrained_dict = torch.load(pretrained, map_location='cpu')
        self.from_state_dict(pretrained_dict, pretrained_layers, verbose)

    def forward_features(self, x):
        """Run all stages and return the pooled, normalized feature vector.

        Args:
            x: Input tensor whose dims 2 and 3 are the spatial size.

        Returns:
            Tensor of shape (B, embed_dims[-1]).
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            if self.enable_checkpoint:
                x, input_size = checkpoint.checkpoint(block, x, input_size)
            else:
                x, input_size = block(x, input_size)

        # Average over the token axis: (B, N, C) -> (B, C, 1) -> (B, C).
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        """Return the head output for input ``x``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def create_encoder(config_encoder):
    """Build a ``DaViT`` from an encoder config mapping.

    With standard parametrization (``SPEC['STANDPARAM']``, default True) the
    muP base shapes simply mirror the model's own shapes; otherwise they are
    read from the ``BASE_*`` entries of ``SPEC``.

    Args:
        config_encoder: Mapping with ``'SPEC'``, ``'NUM_CLASSES'`` and
            ``'IMAGE_SIZE'`` entries.

    Returns:
        The constructed ``DaViT`` instance.
    """
    spec = config_encoder['SPEC']
    standparam = spec.get('STANDPARAM', True)

    # Dummy values for muP parameters when standard parametrization is used.
    base_prefix = '' if standparam else 'BASE_'
    base_embed_dims = spec[base_prefix + 'DIM_EMBED']
    base_num_heads = spec[base_prefix + 'NUM_HEADS']
    base_num_groups = spec[base_prefix + 'NUM_GROUPS']

    model_kwargs = {
        'num_classes': config_encoder['NUM_CLASSES'],
        'depths': spec['DEPTHS'],
        'embed_dims': spec['DIM_EMBED'],
        'base_embed_dims': base_embed_dims,
        'num_heads': spec['NUM_HEADS'],
        'base_num_heads': base_num_heads,
        'num_groups': spec['NUM_GROUPS'],
        'base_num_groups': base_num_groups,
        'patch_size': spec['PATCH_SIZE'],
        'patch_stride': spec['PATCH_STRIDE'],
        'patch_padding': spec['PATCH_PADDING'],
        'patch_prenorm': spec['PATCH_PRENORM'],
        'drop_path_rate': spec['DROP_PATH_RATE'],
        'img_size': config_encoder['IMAGE_SIZE'],
        'window_size': spec.get('WINDOW_SIZE', 7),
        'enable_checkpoint': spec.get('ENABLE_CHECKPOINT', False),
        'conv_at_attn': spec.get('CONV_AT_ATTN', True),
        'conv_at_ffn': spec.get('CONV_AT_FFN', True),
        'dynamic_scale': spec.get('DYNAMIC_SCALE', True),
        'standparam': standparam,
    }
    return DaViT(**model_kwargs)
def create_mup_encoder(config_encoder):
    """Build an encoder and register its muP base shapes.

    Three encoders are created from ``config_encoder``: the target model, a
    base model (``BASE_*`` widths times 1.0) and a delta model (``BASE_*``
    widths times 2.0). The latter two are handed to ``set_base_shapes``.

    Returns:
        The target model after ``set_base_shapes`` has been applied.
    """

    def gen_config(config, wm):
        """Return a deep copy of ``config`` with widths set to base value * ``wm``."""
        new_config = copy.deepcopy(config)
        base_spec = config['SPEC']
        for name in ('DIM_EMBED', 'NUM_HEADS', 'NUM_GROUPS'):
            # New value = base value * width multiplier.
            new_values = [round(base_value * wm) for base_value in base_spec['BASE_' + name]]
            logger.info(f'config["SPEC"]["{name}"]: {new_config["SPEC"][name]} -> {new_values}')
            new_config['SPEC'][name] = new_values
        return new_config

    logger.info('muP: Create models and set base shapes')

    logger.info('=> Create model')
    model = create_encoder(config_encoder)

    logger.info('=> Create base model')
    base_model = create_encoder(gen_config(config_encoder, wm=1.0))

    logger.info('=> Create delta model')
    delta_model = create_encoder(gen_config(config_encoder, wm=2.0))

    logger.info('=> Set base shapes in model for training')
    set_base_shapes(model, base=base_model, delta=delta_model)

    return model
@register_image_encoder
def image_encoder(config_encoder, verbose, **kwargs):
    """Create, initialize and optionally load the DaViT image encoder.

    Args:
        config_encoder: Encoder config mapping; reads ``SPEC``,
            ``LOAD_PRETRAINED`` and, when loading, ``PRETRAINED`` and
            ``PRETRAINED_LAYERS``.
        verbose: Forwarded to ``from_pretrained`` to log every loaded key.
        **kwargs: Accepted and ignored.

    Returns:
        The initialized model.
    """
    spec = config_encoder['SPEC']
    standparam = spec.get('STANDPARAM', True)

    if standparam:
        logger.info('Create model with standard parameterization')
        model = create_encoder(config_encoder)
        model.custom_init_weights(use_original_init=True)
    else:
        logger.info('Create model with mu parameterization')
        model = create_mup_encoder(config_encoder)
        model.custom_init_weights(use_original_init=False)

    if config_encoder['LOAD_PRETRAINED']:
        # Only announce the load when one is actually attempted.
        logger.info('Load model from pretrained checkpoint')
        model.from_pretrained(
            config_encoder['PRETRAINED'],
            config_encoder['PRETRAINED_LAYERS'],
            verbose
        )

    return model