Spaces:
Running
Running
import argparse | |
import os | |
import torch | |
import torch.nn.functional as F | |
import json | |
import monai.transforms as transforms | |
from model.segment_anything_volumetric import sam_model_registry | |
from model.network.model import SegVol | |
from model.data_process.demo_data_process import process_ct_gt | |
from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor | |
from model.utils.visualize import draw_result | |
import streamlit as st | |
def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _int_tuple(value):
    """Convert a comma-separated string such as '32,256,256' to a tuple of ints.

    ``type=tuple`` cannot be used with argparse: ``tuple('32,256')`` splits the
    string into single characters.
    """
    if isinstance(value, tuple):
        return value
    parts = [part for part in value.split(',') if part.strip()]
    if not parts:
        raise argparse.ArgumentTypeError(f'expected comma-separated integers, got {value!r}')
    try:
        return tuple(int(part) for part in parts)
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            f'expected comma-separated integers, got {value!r}') from err


def set_parse(argv=None):
    """Parse the SegVol demo's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what argparse does when given no list.

    Returns:
        argparse.Namespace with ``test_mode`` (bool), ``resume`` (str),
        ``infer_overlap`` (float), ``spatial_size`` and ``patch_size``
        (tuples of ints), ``work_dir`` (str) and ``clip_ckpt`` (str).
    """
    # %% set up parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_mode", default=True, type=_str2bool)
    parser.add_argument("--resume", type=str, default='SegVol_v1.pth')
    parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
    # Sizes are given on the command line as comma-separated ints, e.g. 32,256,256.
    parser.add_argument("-spatial_size", default=(32, 256, 256), type=_int_tuple,
                        help="comma-separated ints, e.g. 32,256,256")
    parser.add_argument("-patch_size", default=(4, 16, 16), type=_int_tuple,
                        help="comma-separated ints, e.g. 4,16,16")
    parser.add_argument('-work_dir', type=str, default='./work_dir')
    ### demo
    parser.add_argument("--clip_ckpt", type=str, default='model/config/clip')
    return parser.parse_args(argv)
def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
    """Run two-stage (zoom-out, then optional zoom-in) inference on one volume.

    Zoom-out: ``segvol_model`` is called once on ``image_resize`` and the
    resulting logits are interpolated (nearest) back to the spatial shape of
    ``image[0, 0]``.

    Zoom-in (only when ``args.use_zoom_in`` is true): a region of interest is
    obtained from ``logits2roi_coor``; that region of ``image[0, 0]`` is run
    through ``sliding_window_inference`` and the result is written back into
    the global logits.

    Args:
        args: namespace read for ``use_point_prompt``, ``use_box_prompt``,
            ``use_zoom_in``, ``spatial_size`` and ``infer_overlap``.
        segvol_model: callable invoked as
            ``segvol_model(image_resize, text=..., boxes=..., points=...)``.
        image: original-resolution input; ``image[0, 0]`` is used as the
            single volume (presumably batch and channel axes of size 1 —
            TODO confirm).
        image_resize: resized input passed to the model unchanged; its spatial
            shape is taken as ``image_resize.shape[2:]``.
        text_prompt: text prompt, or ``None`` for no text prompt.
        point_prompt: ``(points, labels)`` pair; only read when
            ``args.use_point_prompt`` is true.
        box_prompt: box tensor; only read when ``args.use_box_prompt`` is true.

    Returns:
        Logits tensor (on CPU) with the spatial shape of ``image[0, 0]``.

    Raises:
        ValueError: if the zoom-in stage is reached with both a box prompt and
            a point prompt enabled.
    """
    image_single_resize = image_resize
    image_single = image[0, 0]
    ori_shape = image_single.shape
    resize_shape = image_single_resize.shape[2:]

    # Generate prompts, each with a leading batch axis of size 1.
    text_single = None if text_prompt is None else [text_prompt]
    points_single = None
    box_single = None
    if args.use_point_prompt:
        point, point_label = point_prompt
        points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
        binary_points_resize = build_binary_points(point, point_label, resize_shape)
    if args.use_box_prompt:
        box_single = box_prompt.unsqueeze(0).float()
        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)

    ####################
    # zoom-out inference:
    print('--- zoom out inference ---')
    print(text_single)
    # `is not None` rather than `!= None`: `!=` on a tensor is an elementwise
    # operator and is not a reliable None check.
    print(f'use text-prompt [{text_single is not None}], use box-prompt [{box_single is not None}], use point-prompt [{points_single is not None}]')
    with torch.no_grad():
        logits_global_single = segvol_model(image_single_resize,
                                            text=text_single,
                                            boxes=box_single,
                                            points=points_single)

    # Resize the global logits back to the original resolution.
    logits_global_single = F.interpolate(
        logits_global_single.cpu(),
        size=ori_shape, mode='nearest')[0][0]

    # Bring the prompt masks to the original resolution for the zoom-in stage.
    if args.use_point_prompt:
        binary_points = F.interpolate(
            binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    if args.use_box_prompt:
        binary_cube = F.interpolate(
            binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]

    if not args.use_zoom_in:
        return logits_global_single

    ####################
    # zoom-in inference:
    min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
    if min_d is None:
        print('Fail to detect foreground!')
        return logits_global_single

    # The ROI bounds are inclusive, hence the +1 on every upper bound.
    roi = (slice(min_d, max_d + 1), slice(min_h, max_h + 1), slice(min_w, max_w + 1))
    image_single_cropped = image_single[roi].unsqueeze(0).unsqueeze(0)
    global_preds = (torch.sigmoid(logits_global_single[roi]) > 0.5).long()

    # Only one spatial prompt can be reflected into the zoom-in stage. This is an
    # explicit check rather than `assert`, which `python -O` strips out.
    if args.use_box_prompt and args.use_point_prompt:
        raise ValueError('zoom-in inference supports a box prompt or a point prompt, not both')

    prompt_reflection = None
    if args.use_box_prompt:
        prompt_reflection = (
            binary_cube[roi].unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0),
        )
    elif args.use_point_prompt:
        prompt_reflection = (
            binary_points[roi].unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0),
        )

    ## inference
    with torch.no_grad():
        logits_single_cropped = sliding_window_inference(
            image_single_cropped, prompt_reflection,
            args.spatial_size, 1, segvol_model, args.infer_overlap,
            text=text_single,
            use_box=args.use_box_prompt,
            use_point=args.use_point_prompt,
            logits_global_single=logits_global_single,
        )
        logits_single_cropped = logits_single_cropped.cpu().squeeze()

    # NOTE(review): when the refined logits have the same shape as the global
    # logits they are NOT written back. This presumably detects a fallback in
    # which sliding_window_inference returns the global logits, but it also
    # discards the refinement when the ROI spans the whole volume — confirm intent.
    if logits_single_cropped.shape != logits_global_single.shape:
        logits_global_single[roi] = logits_single_cropped
    return logits_global_single
def build_model(clip_ckpt='model/config/clip', resume='SegVol_v1.pth',
                roi_size=(32, 256, 256), patch_size=(4, 16, 16)):
    """Build the SegVol model in eval mode and load its checkpoint if present.

    Args:
        clip_ckpt: path handed to ``SegVol`` as ``clip_ckpt``.
        resume: checkpoint path; when it exists, its ``'model'`` entry is
            loaded with ``strict=True``.
        roi_size: value handed to ``SegVol`` as ``roi_size``.
        patch_size: value handed to ``SegVol`` as ``patch_size``.

    Returns:
        The ``SegVol`` model wrapped in ``torch.nn.DataParallel``, in eval mode.
    """
    st.write('building model')
    sam_model = sam_model_registry['vit']()
    segvol_model = SegVol(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
        clip_ckpt=clip_ckpt,
        roi_size=roi_size,
        patch_size=patch_size,
        test_mode=True,
    )
    # Wrapped before loading: presumably the checkpoint was saved from a
    # DataParallel model (state-dict keys prefixed with "module."), which a
    # strict=True load requires to match — TODO confirm.
    segvol_model = torch.nn.DataParallel(segvol_model)
    segvol_model.eval()

    # load param
    if os.path.isfile(resume):
        # Map the checkpoint to CPU so loading does not require a GPU.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(resume, map_location='cpu')
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint.get('epoch', 'unknown')))
    else:
        # Without a checkpoint the model keeps untrained weights; make that visible
        # instead of continuing silently.
        message = f"checkpoint '{resume}' not found; model weights were NOT loaded"
        print(message)
        st.warning(message)
    print('model build done!')
    return segvol_model
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt,
                   output_size=(325, 325, 325)):
    """Segment one case and return a binary mask resized to ``output_size``.

    Args:
        _image: original-resolution volume; a leading axis is added before it
            is passed to ``zoom_in_zoom_out``.
        _image_zoom_out: resized volume; a leading axis is added likewise.
        _point_prompt: ``(points, labels)`` pair, or ``None`` for no point prompt.
        text_prompt: text prompt, or ``None`` for no text prompt.
        _box_prompt: box tensor, or ``None`` for no box prompt.
        output_size: spatial size the logits are trilinearly resized to before
            thresholding (defaults to the previously hard-coded 325^3).

    Returns:
        numpy int array: ``sigmoid(logits) > 0.5`` with a leading axis of size 1
        followed by ``output_size``.
    """
    # seg config
    args = set_parse()
    args.use_zoom_in = True
    args.use_text_prompt = text_prompt is not None
    args.use_box_prompt = _box_prompt is not None
    args.use_point_prompt = _point_prompt is not None
    segvol_model = build_model()

    # run inference
    logits = zoom_in_zoom_out(
        args, segvol_model,
        _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
        text_prompt, _point_prompt, _box_prompt)
    print(logits.shape)

    # Add the channel axis with unsqueeze(0), which is exactly what
    # transforms.AddChannel did (img[None]). AddChannel is deprecated in MONAI
    # and removed in 1.3.
    resize_transform = transforms.Resize(output_size, mode='trilinear')
    logits = resize_transform(logits.unsqueeze(0))[0]
    print(logits.shape)
    return (torch.sigmoid(logits) > 0.5).int().numpy()