import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig

#%% set up model
class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 clip_ckpt,
                 roi_size,
                 patch_size,
                 test_mode=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder(clip_ckpt)
        # spatial size of the ViT feature map: roi_size / patch_size per axis
        self.feat_shape = np.array(roi_size) / np.array(patch_size)
        self.test_mode = test_mode
    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        image_embedding, _ = self.image_encoder(image)
        # reshape the ViT token sequence back into a 3D feature volume
        image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
            int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
        # test mode
        if self.test_mode:
            return self.forward_decoder(image_embedding, img_shape, text, boxes, points)

        # train mode
        # future release
    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        with torch.no_grad():
            if boxes is not None:
                if len(boxes.shape) == 2:
                    boxes = boxes[:, None, :]  # (B, 1, 6)
            if text is not None:
                text_embedding = self.text_encoder(text)  # (B, 768)
            else:
                text_embedding = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )

        dense_pe = self.prompt_encoder.get_dense_pe()
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding=text_embedding,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # upsample low-resolution mask logits to the input volume size
        logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
        return logits
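
#%% shape sanity check (illustrative only)
# A minimal sketch of the feature-map geometry assumed by SegVol.forward.
# roi_size and patch_size below are assumed example values, not defined in
# this file: with roi_size=(32, 256, 256) and patch_size=(4, 16, 16) the
# token grid reshaped in forward() is (8, 16, 16).
def _feat_shape_example():
    roi_size = (32, 256, 256)   # assumed example input volume size (D, H, W)
    patch_size = (4, 16, 16)    # assumed example ViT patch size
    feat_shape = np.array(roi_size) / np.array(patch_size)
    print(feat_shape)           # -> [ 8. 16. 16.]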

class TextEncoder(nn.Module):
    def __init__(self, clip_ckpt):
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
        # project CLIP's 512-dim pooled text feature to the 768-dim prompt space
        self.dim_align = nn.Linear(512, 768)
        # freeze text encoder
        for param in self.clip_text_model.parameters():
            param.requires_grad = False

    def organ2tokens(self, organ_names):
        # wrap each organ name in the CT prompt template before tokenizing
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        return tokens

    def forward(self, text):
        if text is None:
            return None
        if isinstance(text, str):
            text = [text]
        tokens = self.organ2tokens(text)
        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = clip_outputs.pooler_output
        text_embedding = self.dim_align(text_embedding)
        return text_embedding
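
#%% usage sketch (illustrative only)
# A minimal sketch of how TextEncoder turns organ names into prompt
# embeddings. "openai/clip-vit-base-patch32" is an assumed tokenizer
# checkpoint, not something this file specifies; the CLIPTextModel weights
# are randomly initialized here and would be loaded from a SegVol
# checkpoint before real use.
if __name__ == '__main__':
    text_encoder = TextEncoder(clip_ckpt="openai/clip-vit-base-patch32")  # assumed checkpoint
    with torch.no_grad():
        emb = text_encoder(["liver", "spleen"])
    print(emb.shape)  # -> torch.Size([2, 768])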