# PPC-SAM / ppc_decoder.py
# forSubAnony — v1 (commit 57abc33)
import torch
import torch.nn as nn
from typing import Tuple, Union
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
def build_sam_decoder_vit_h():
    """Return an ImageDecoderViT sized for a ViT-H encoder (1280-dim, 16 heads)."""
    embed_dim, heads = 1280, 16
    return _build_sam_decoder(encoder_embed_dim=embed_dim, encoder_num_heads=heads)
def build_sam_decoder_vit_l():
    """Return an ImageDecoderViT sized for a ViT-L encoder (1024-dim, 16 heads)."""
    embed_dim, heads = 1024, 16
    return _build_sam_decoder(encoder_embed_dim=embed_dim, encoder_num_heads=heads)
def build_sam_decoder_vit_b():
    """Return an ImageDecoderViT sized for a ViT-B encoder (768-dim, 12 heads)."""
    embed_dim, heads = 768, 12
    return _build_sam_decoder(encoder_embed_dim=embed_dim, encoder_num_heads=heads)
# Registry: backbone name -> zero-argument factory building the matching decoder.
# "default" resolves to the ViT-H variant.
sam_decoder_reg = dict(
    default=build_sam_decoder_vit_h,
    vit_h=build_sam_decoder_vit_h,
    vit_l=build_sam_decoder_vit_l,
    vit_b=build_sam_decoder_vit_b,
)
def _build_sam_decoder(
    encoder_embed_dim,
    encoder_num_heads,
    image_size=1024,
    vit_patch_size=16,
):
    """Construct an ImageDecoderViT matching a ViT encoder configuration.

    Args:
        encoder_embed_dim: Channel width of the encoder hidden states; passed
            to the decoder as ``hidden_size``.
        encoder_num_heads: Number of encoder attention heads; passed to the
            decoder as ``num_heads``.
        image_size: Side length of the square input image. Defaults to 1024,
            the value previously hard-coded here.
        vit_patch_size: Encoder patch size; ``image_size // vit_patch_size`` is
            the token-grid side length. Defaults to 16, the value previously
            hard-coded here.

    Returns:
        A newly constructed ImageDecoderViT.
    """
    return ImageDecoderViT(
        hidden_size=encoder_embed_dim,
        img_size=image_size,
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
    )
class ImageDecoderViT(nn.Module):
    """Convolutional decoder that turns ViT hidden states into an output map.

    Built from MONAI UNETR blocks. Each hidden state is reshaped into a
    channel-first feature map (see ``proj_feat``):

    * states 0-2 become skip features via ``encoder2``-``encoder4``;
    * state 3 seeds a chain of ``UnetrUpBlock`` decoders;
    * the final decoder stage uses a skip feature computed from the raw image
      by ``encoder1``.

    An optional low-resolution mask can be encoded and fused in after
    ``decoder4``.

    Args:
        in_channels: Channels of the raw image fed to ``encoder1``.
        feature_size: Base channel width of the convolutional path.
        hidden_size: Channel width of each ViT hidden state.
        conv_block: Forwarded to the ``UnetrPrUpBlock`` skip encoders.
        res_block: Forwarded to every MONAI block.
        norm_name: Normalization forwarded to every MONAI block.
        dropout_rate: Validated to lie in [0, 1]; not otherwise used here.
        spatial_dims: Number of spatial dimensions (2 for images).
        img_size: Side length of the input along every spatial axis.
        patch_size: ViT patch size; ``img_size // patch_size`` is the token-grid
            side length.
        out_channels: Channels of the output map, and the channel count
            ``encoder_low_res_mask`` expects from ``low_res_mask``.
        num_heads: ViT attention heads; only used to validate ``hidden_size``.

    Raises:
        AssertionError: If ``dropout_rate`` is outside [0, 1] or ``hidden_size``
            is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        feature_size: int = 64,
        hidden_size: int = 1280,
        conv_block: bool = True,
        res_block: bool = True,
        norm_name: Union[Tuple, str] = "instance",
        dropout_rate: float = 0.0,
        spatial_dims: int = 2,
        img_size: int = 1024,
        patch_size: int = 16,
        out_channels: int = 1,
        num_heads: int = 12,
    ) -> None:
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        self.patch_size = patch_size
        # One token-grid extent per spatial axis. Deriving the length from
        # spatial_dims keeps the view shape in proj_feat consistent with
        # proj_axes below. For spatial_dims == 2 this is the same 2-tuple as before.
        self.feat_size = (img_size // self.patch_size,) * spatial_dims
        self.hidden_size = hidden_size
        # Not read anywhere in this class; kept so external attribute access
        # keeps working.
        self.classification = False

        # Encodes the optional low-res mask (out_channels channels) up to
        # feature_size * 4 channels, matching dec2 so the two can be concatenated.
        self.encoder_low_res_mask = nn.Sequential(
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=feature_size,
                out_channels=feature_size * 4,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
        )
        # Squeezes concat(dec2, encoded mask) (4 + 4 = 8 x feature_size) back to
        # the 4 x feature_size width that decoder3 expects.
        self.decoder_fuse = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )

        # Skip-feature encoders: encoder1 works on the raw image; encoder2-4
        # upsample projected hidden states (fewer extra layers for deeper skips).
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )

        # Decoder chain: each stage upsamples and merges one skip feature.
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

        # After reshaping to (batch, *feat_size, hidden), move the trailing
        # channel axis to position 1: (batch, hidden, *feat_size).
        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

    def proj_feat(self, x):
        """Reshape a hidden state into a channel-first feature map.

        ``x`` is reshaped to ``(batch, *feat_size, hidden_size)`` and then
        permuted to ``(batch, hidden_size, *feat_size)``.
        """
        new_view = [x.size(0)] + self.proj_view_shape
        # reshape (unlike view) also accepts non-contiguous hidden states.
        x = x.reshape(new_view)
        return x.permute(self.proj_axes).contiguous()

    def forward(self, x_img, hidden_states_out, low_res_mask=None):
        """Decode ViT hidden states (plus optional mask prior) into an output map.

        Args:
            x_img: Raw input image, encoded by ``encoder1`` for the final skip.
            hidden_states_out: Indexable with at least four hidden states.
                Indices 0-2 feed ``encoder2``-``encoder4``; index 3 seeds the
                decoder. Each must reshape to ``(batch, *feat_size, hidden_size)``.
            low_res_mask: Optional mask with ``out_channels`` channels, fused in
                after ``decoder4`` when given.
                NOTE(review): its spatial size must match ``dec2`` (presumably
                4x the token grid, given the two 2x up-blocks before it);
                confirm against callers.

        Returns:
            Output of ``self.out`` (``out_channels`` channels).
        """
        enc1 = self.encoder1(x_img)
        enc2 = self.encoder2(self.proj_feat(hidden_states_out[0]))
        enc3 = self.encoder3(self.proj_feat(hidden_states_out[1]))
        enc4 = self.encoder4(self.proj_feat(hidden_states_out[2]))

        dec4 = self.proj_feat(hidden_states_out[3])
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)

        if low_res_mask is not None:
            enc_mask = self.encoder_low_res_mask(low_res_mask)
            fused_dec2 = self.decoder_fuse(torch.cat([dec2, enc_mask], dim=1))
            dec1 = self.decoder3(fused_dec2, enc2)
        else:
            dec1 = self.decoder3(dec2, enc2)

        out = self.decoder2(dec1, enc1)
        return self.out(out)