# SAM image-decoder factories and the UNETR-style ImageDecoderViT module.
import torch
import torch.nn as nn
from typing import Tuple, Union
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
def build_sam_decoder_vit_h():
    """Build the decoder matching a SAM ViT-H encoder (embed dim 1280, 16 heads)."""
    return _build_sam_decoder(encoder_embed_dim=1280, encoder_num_heads=16)
def build_sam_decoder_vit_l():
    """Build the decoder matching a SAM ViT-L encoder (embed dim 1024, 16 heads)."""
    return _build_sam_decoder(encoder_embed_dim=1024, encoder_num_heads=16)
def build_sam_decoder_vit_b():
    """Build the decoder matching a SAM ViT-B encoder (embed dim 768, 12 heads)."""
    return _build_sam_decoder(encoder_embed_dim=768, encoder_num_heads=12)
# Registry of decoder factories keyed by SAM backbone name.
# "default" aliases the largest (ViT-H) variant.
sam_decoder_reg = {
    "default": build_sam_decoder_vit_h,
    "vit_h": build_sam_decoder_vit_h,
    "vit_l": build_sam_decoder_vit_l,
    "vit_b": build_sam_decoder_vit_b,
}
def _build_sam_decoder(
    encoder_embed_dim,
    encoder_num_heads,
    image_size=1024,
    vit_patch_size=16,
):
    """Instantiate an ``ImageDecoderViT`` for a given SAM encoder configuration.

    Args:
        encoder_embed_dim: hidden size of the ViT encoder (e.g. 1280 for ViT-H).
        encoder_num_heads: number of attention heads in the ViT encoder.
        image_size: input image side length; previously hard-coded to SAM's 1024.
        vit_patch_size: ViT patch side length; previously hard-coded to 16.

    Returns:
        A freshly constructed ``ImageDecoderViT``.
    """
    return ImageDecoderViT(
        hidden_size=encoder_embed_dim,
        img_size=image_size,
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
    )
class ImageDecoderViT(nn.Module):
    """UNETR-style convolutional decoder over SAM ViT encoder features.

    Consumes the original image, four intermediate ViT hidden-state tensors,
    and an optional low-resolution mask prompt, and progressively upsamples
    and skip-fuses them back to full image resolution, emitting a
    ``out_channels``-channel prediction map via a 1x1 output block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        feature_size: int = 64,
        hidden_size: int = 1280,
        conv_block: bool = True,
        res_block: bool = True,
        norm_name: Union[Tuple, str] = "instance",
        dropout_rate: float = 0.0,
        spatial_dims: int = 2,
        img_size: int = 1024,
        patch_size: int = 16,
        out_channels: int = 1,
        num_heads: int = 12,
    ) -> None:
        """Create the decoder.

        Args:
            in_channels: channels of the raw input image.
            feature_size: base channel width; deeper stages use multiples of it.
            hidden_size: ViT embedding dimension of the encoder hidden states.
            conv_block: whether the projection-upsample blocks use conv blocks.
            res_block: whether all UNETR blocks use residual units.
            norm_name: normalization layer spec passed to the MONAI blocks.
            dropout_rate: validated for range only; not otherwise used here.
            spatial_dims: 2 for 2D decoding (default), per the MONAI blocks.
            img_size: input image side length (assumed square).
            patch_size: ViT patch side length; img_size must be divisible by it.
            out_channels: channels of the final prediction map.
            num_heads: encoder attention heads; must divide hidden_size.

        Raises:
            AssertionError: if ``dropout_rate`` is outside [0, 1] or
                ``hidden_size`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if not (0 <= dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")
        if hidden_size % num_heads != 0:
            raise AssertionError("hidden size should be divisible by num_heads.")
        self.patch_size = patch_size
        # Spatial size of the ViT token grid: (img/patch, img/patch).
        self.feat_size = (
            img_size // self.patch_size,
            img_size // self.patch_size,
        )
        self.hidden_size = hidden_size
        self.classification = False
        # Encodes an optional low-res mask prompt up to feature_size * 4
        # channels so it can be concatenated with the decoder stream.
        self.encoder_low_res_mask = nn.Sequential(
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=out_channels,
                out_channels=feature_size,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
            UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=feature_size,
                out_channels=feature_size * 4,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=res_block,
            ),
        )
        # Fuses the concatenated [decoder, mask] features (8x) back to 4x width.
        self.decoder_fuse = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Skip path 1: raw image -> feature_size channels at full resolution.
        self.encoder1 = UnetrBasicBlock(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=feature_size,
            kernel_size=3,
            stride=1,
            norm_name=norm_name,
            res_block=res_block,
        )
        # Skip paths 2-4: project ViT hidden states and upsample 2**num_layer+1 x.
        self.encoder2 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 2,
            num_layer=2,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder3 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 4,
            num_layer=1,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        self.encoder4 = UnetrPrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            num_layer=0,
            kernel_size=3,
            stride=1,
            upsample_kernel_size=2,
            norm_name=norm_name,
            conv_block=conv_block,
            res_block=res_block,
        )
        # Decoder ladder: each stage upsamples 2x and fuses the matching skip.
        self.decoder5 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=hidden_size,
            out_channels=feature_size * 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder4 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 8,
            out_channels=feature_size * 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder3 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 4,
            out_channels=feature_size * 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.decoder2 = UnetrUpBlock(
            spatial_dims=spatial_dims,
            in_channels=feature_size * 2,
            out_channels=feature_size,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=norm_name,
            res_block=res_block,
        )
        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
        # Permutation that moves the channel axis to position 1 after reshape:
        # (batch, H, W, C) -> (batch, C, H, W) for spatial_dims == 2.
        self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
        self.proj_view_shape = list(self.feat_size) + [self.hidden_size]

    def proj_feat(self, x):
        """Reshape flat ViT tokens (B, N, C) into a spatial map (B, C, H, W)."""
        new_view = [x.size(0)] + self.proj_view_shape
        x = x.view(new_view)
        x = x.permute(self.proj_axes).contiguous()
        return x

    def forward(self, x_img, hidden_states_out, low_res_mask=None):
        """Decode encoder features into a full-resolution prediction map.

        Args:
            x_img: input image tensor, ``(B, in_channels, img_size, img_size)``.
            hidden_states_out: sequence of at least four ViT hidden states,
                each ``(B, N, hidden_size)``; index 3 is the deepest.
            low_res_mask: optional mask prompt, ``(B, out_channels, h, w)``
                matching dec2's spatial size; ``None`` skips mask fusion.

        Returns:
            Tensor of shape ``(B, out_channels, img_size, img_size)``.
        """
        enc1 = self.encoder1(x_img)
        x2 = hidden_states_out[0]
        enc2 = self.encoder2(self.proj_feat(x2))
        x3 = hidden_states_out[1]
        enc3 = self.encoder3(self.proj_feat(x3))
        x4 = hidden_states_out[2]
        enc4 = self.encoder4(self.proj_feat(x4))
        dec4 = self.proj_feat(hidden_states_out[3])
        dec3 = self.decoder5(dec4, enc4)
        dec2 = self.decoder4(dec3, enc3)
        # Fix: identity comparison with None, not `!=` (PEP 8; `!=` can also
        # trigger tensor element-wise comparison semantics).
        if low_res_mask is not None:
            enc_mask = self.encoder_low_res_mask(low_res_mask)
            fused_dec2 = torch.cat([dec2, enc_mask], dim=1)
            fused_dec2 = self.decoder_fuse(fused_dec2)
            dec1 = self.decoder3(fused_dec2, enc2)
        else:
            dec1 = self.decoder3(dec2, enc2)
        out = self.decoder2(dec1, enc1)
        return self.out(out)