import torch
import torch.nn as nn
from typing import Optional, Sequence, Tuple, Union

from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock


def build_sam_decoder_vit_h():
    """Build an ``ImageDecoderViT`` sized for a 1280-dim, 16-head encoder."""
    return _build_sam_decoder(
        encoder_embed_dim=1280,
        encoder_num_heads=16,
    )


def build_sam_decoder_vit_l():
    """Build an ``ImageDecoderViT`` sized for a 1024-dim, 16-head encoder."""
    return _build_sam_decoder(
        encoder_embed_dim=1024,
        encoder_num_heads=16,
    )


def build_sam_decoder_vit_b():
    """Build an ``ImageDecoderViT`` sized for a 768-dim, 12-head encoder."""
    return _build_sam_decoder(
        encoder_embed_dim=768,
        encoder_num_heads=12,
    )


# Maps a backbone name to its decoder factory; "default" aliases "vit_h".
sam_decoder_reg = {
    "default": build_sam_decoder_vit_h,
    "vit_h": build_sam_decoder_vit_h,
    "vit_l": build_sam_decoder_vit_l,
    "vit_b": build_sam_decoder_vit_b,
}


def _build_sam_decoder(
    encoder_embed_dim,
    encoder_num_heads,
):
    """Construct an ``ImageDecoderViT`` for 1024x1024 inputs and 16x16 patches.

    Args:
        encoder_embed_dim: Channel width of the encoder hidden states
            (forwarded as ``hidden_size``).
        encoder_num_heads: Encoder head count; the decoder only uses it to
            check that ``hidden_size`` divides evenly by it.

    Returns:
        A freshly initialised ``ImageDecoderViT``.
    """
    image_size = 1024
    vit_patch_size = 16
    return ImageDecoderViT(
        hidden_size=encoder_embed_dim,
        img_size=image_size,
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
    )


class ImageDecoderViT(nn.Module):
    """UNETR-style convolutional decoder over ViT hidden states.

    Combines the raw input image, four intermediate token maps from a ViT
    encoder and, optionally, a low-resolution mask into a dense prediction
    with ``out_channels`` channels.
    """

    def __init__(
        self,
        in_channels: int = 3,
        feature_size: int = 64,
        hidden_size: int = 1280,
        conv_block: bool = True,
        res_block: bool = True,
        norm_name: Union[Tuple, str] = "instance",
        dropout_rate: float = 0.0,
        spatial_dims: int = 2,
        img_size: int = 1024,
        patch_size: int = 16,
        out_channels: int = 1,
        num_heads: int = 12,
    ) -> None:
        """Create the decoder.

        Args:
            in_channels: Channels of the raw image passed to ``forward``.
            feature_size: Base channel width; stages use multiples of it.
            hidden_size: Channel width of each encoder hidden state.
            conv_block: Forwarded to the ``UnetrPrUpBlock`` encoders.
            res_block: Forwarded to every UNETR block.
            norm_name: Normalisation name forwarded to every UNETR block.
            dropout_rate: Validated to lie in [0, 1]; not otherwise used here.
            spatial_dims: Number of spatial dimensions (2 for images).
            img_size: Side length of the input image.
            patch_size: Encoder patch size; ``img_size // patch_size`` is the
                token-grid extent along every spatial dimension.
            out_channels: Channels of the prediction, and of ``low_res_mask``.
            num_heads: Only used to check that it divides ``hidden_size``.

        Raises:
            AssertionError: If ``dropout_rate`` is outside [0, 1] or
                ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")

        self.patch_size = patch_size
        # One token-grid extent per spatial dim. This used to be hard-coded to
        # two entries, which made proj_feat fail whenever spatial_dims != 2.
        self.feat_size = tuple(img_size // self.patch_size for _ in range(spatial_dims))
        self.hidden_size = hidden_size
        self.classification = False

        # Encodes the optional low-res mask (out_channels in) up to
        # feature_size * 4 channels so it can be concatenated with dec2.
        self.encoder_low_res_mask = nn.Sequential(
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=feature_size,
                out_channels=feature_size * 4,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
        )
        # Squeezes concat([dec2, mask features]) (2 * feature_size * 4
        # channels) back to dec2's width of feature_size * 4.
        self.decoder_fuse = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

        # Moves the trailing channel axis to position 1:
        # (B, *spatial, C) -> (B, C, *spatial); for 2-D this is (0, 3, 1, 2).
        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

    def proj_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Turn a channels-last token map into a channels-first feature map.

        ``x`` is reshaped to ``(B, *feat_size, hidden_size)`` and permuted to
        ``(B, hidden_size, *feat_size)``. Each sample must therefore contain
        exactly ``prod(feat_size) * hidden_size`` elements.
        """
        new_view = [x.size(0)] + self.proj_view_shape
        # reshape (rather than view) also accepts non-contiguous inputs;
        # for contiguous inputs the result is identical.
        x = x.reshape(new_view)
        return x.permute(self.proj_axes).contiguous()

    def forward(
        self,
        x_img: torch.Tensor,
        hidden_states_out: Sequence[torch.Tensor],
        low_res_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode encoder features into a dense prediction.

        Args:
            x_img: Raw input image with ``in_channels`` channels.
            hidden_states_out: At least four channels-last encoder token maps.
                Items 0-2 feed the skip encoders; item 3 is the bottleneck.
            low_res_mask: Optional mask with ``out_channels`` channels. Its
                encoded features are concatenated with ``dec2`` along the
                channel axis, so their spatial size must match ``dec2``'s.

        Returns:
            The output of ``self.out`` (``out_channels`` channels).
        """
        enc1 = self.encoder1(x_img)
        enc2 = self.encoder2(self.proj_feat(hidden_states_out[0]))
        enc3 = self.encoder3(self.proj_feat(hidden_states_out[1]))
        enc4 = self.encoder4(self.proj_feat(hidden_states_out[2]))

        # The last hidden state is the bottleneck the decoder starts from.
        dec4 = self.proj_feat(hidden_states_out[3])
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)

        # Identity test: `tensor != None` dispatches to Tensor.__ne__ instead
        # of checking whether a mask was supplied.
        if low_res_mask is not None:
            enc_mask = self.encoder_low_res_mask(low_res_mask)
            dec2 = self.decoder_fuse(torch.cat([dec2, enc_mask], dim=1))

        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        return self.out(dec0)