import argparse
import os
import torch
import torch.nn.functional as F
import json
import monai.transforms as transforms
from model.segment_anything_volumetric import sam_model_registry
from model.network.model import SegVol
from model.data_process.demo_data_process import process_ct_gt
from model.utils.monai_inferers_utils import sliding_window_inference, generate_box, select_points, build_binary_cube, build_binary_points, logits2roi_coor
from model.utils.visualize import draw_result
import streamlit as st
def set_parse():
    # %% set up parser
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_mode", default=True, type=bool)
    parser.add_argument("--resume", type=str, default='SegVol_v1.pth')
    parser.add_argument("-infer_overlap", default=0.0, type=float, help="sliding window inference overlap")
    parser.add_argument("-spatial_size", default=(32, 256, 256), type=tuple)
    parser.add_argument("-patch_size", default=(4, 16, 16), type=tuple)
    parser.add_argument('-work_dir', type=str, default='./work_dir')
    ### demo
    parser.add_argument("--clip_ckpt", type=str, default='model/config/clip')
    args = parser.parse_args()
    return args
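# Note: zoom_in_zoom_out() below also reads args.use_zoom_in,
# args.use_text_prompt, args.use_box_prompt and args.use_point_prompt;
# these flags are not defined by the parser but are attached to the
# parsed namespace in inference_case().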
def zoom_in_zoom_out(args, segvol_model, image, image_resize, text_prompt, point_prompt, box_prompt):
    image_single_resize = image_resize
    image_single = image[0, 0]
    ori_shape = image_single.shape
    resize_shape = image_single_resize.shape[2:]
    # generate prompts
    text_single = None if text_prompt is None else [text_prompt]
    points_single = None
    box_single = None
    if args.use_point_prompt:
        point, point_label = point_prompt
        points_single = (point.unsqueeze(0).float(), point_label.unsqueeze(0).float())
        # rasterize the clicked points into a binary volume at zoom-out resolution
        binary_points_resize = build_binary_points(point, point_label, resize_shape)
    if args.use_box_prompt:
        box_single = box_prompt.unsqueeze(0).float()
        # rasterize the box into a binary cube at zoom-out resolution
        binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=resize_shape)
    ####################
    # zoom-out inference: coarse segmentation of the whole resized volume
    print('--- zoom out inference ---')
    print(text_single)
    print(f'use text-prompt [{text_single is not None}], use box-prompt [{box_single is not None}], use point-prompt [{points_single is not None}]')
    with torch.no_grad():
        logits_global_single = segvol_model(image_single_resize,
                                            text=text_single,
                                            boxes=box_single,
                                            points=points_single)
    # resize global logits back to the original volume shape
    logits_global_single = F.interpolate(
        logits_global_single.cpu(),
        size=ori_shape, mode='nearest')[0][0]
    # build prompt reflection for zoom-in: upsample the binary prompt
    # volumes to the original resolution as well
    if args.use_point_prompt:
        binary_points = F.interpolate(
            binary_points_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    if args.use_box_prompt:
        binary_cube = F.interpolate(
            binary_cube_resize.unsqueeze(0).unsqueeze(0).float(),
            size=ori_shape, mode='nearest')[0][0]
    # draw_result('unknown', image_single_resize, None, point_prompt, logits_global_single, logits_global_single)
    if not args.use_zoom_in:
        return logits_global_single
    ####################
    # zoom-in inference: re-segment the detected ROI at full resolution
    min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(args.spatial_size, logits_global_single)
    if min_d is None:
        print('Failed to detect foreground!')
        return logits_global_single

    # crop ROI
    image_single_cropped = image_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
    global_preds = (torch.sigmoid(logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]) > 0.5).long()

    assert not (args.use_box_prompt and args.use_point_prompt)
    # label_single_cropped = label_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1].unsqueeze(0).unsqueeze(0)
    # prompt_reflection pairs the cropped binary prompt volume with the
    # zoom-out prediction for the same window
    prompt_reflection = None
    if args.use_box_prompt:
        binary_cube_cropped = binary_cube[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_cube_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )
    if args.use_point_prompt:
        binary_points_cropped = binary_points[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
        prompt_reflection = (
            binary_points_cropped.unsqueeze(0).unsqueeze(0),
            global_preds.unsqueeze(0).unsqueeze(0)
        )
    ## inference
    with torch.no_grad():
        logits_single_cropped = sliding_window_inference(
            image_single_cropped, prompt_reflection,
            args.spatial_size, 1, segvol_model, args.infer_overlap,
            text=text_single,
            use_box=args.use_box_prompt,
            use_point=args.use_point_prompt,
            logits_global_single=logits_global_single,
        )
    logits_single_cropped = logits_single_cropped.cpu().squeeze()
    # paste the refined ROI logits back into the global map
    # (skipped when the ROI covered the entire volume)
    if logits_single_cropped.shape != logits_global_single.shape:
        logits_global_single[min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
    return logits_global_single
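# Illustrative shapes for the two-stage pass above (example values, not
# taken from the model config): with an original CT of ori_shape
# (64, 256, 256) and a detected ROI spanning (32, 128, 128), the
# zoom-out logits are upsampled to (64, 256, 256), the ROI window is
# re-segmented by sliding_window_inference in patches of
# args.spatial_size, and the (32, 128, 128) refined crop overwrites the
# matching window of the global logit map.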
@st.cache_resource
def build_model():
    # build model
    st.write('building model')
    clip_ckpt = 'model/config/clip'
    resume = 'SegVol_v1.pth'
    sam_model = sam_model_registry['vit']()
    segvol_model = SegVol(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
        clip_ckpt=clip_ckpt,
        roi_size=(32, 256, 256),
        patch_size=(4, 16, 16),
        test_mode=True,
    )
    segvol_model = torch.nn.DataParallel(segvol_model)
    segvol_model.eval()
    # load checkpoint (map to CPU so the demo also runs without a GPU)
    if os.path.isfile(resume):
        loc = 'cpu'
        checkpoint = torch.load(resume, map_location=loc)
        segvol_model.load_state_dict(checkpoint['model'], strict=True)
        print("loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch']))
    else:
        print("no checkpoint found at '{}', using randomly initialized weights".format(resume))
    print('model build done!')
    return segvol_model
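# The resume file is expected to be a dict-style checkpoint with at
# least {'model': <state_dict of the DataParallel-wrapped SegVol>,
# 'epoch': <int>}; strict=True loading fails on any key mismatch.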
@st.cache_data
def inference_case(_image, _image_zoom_out, _point_prompt, text_prompt, _box_prompt):
    # arguments prefixed with '_' are excluded from st.cache_data hashing
    # seg config
    args = set_parse()
    args.use_zoom_in = True
    args.use_text_prompt = text_prompt is not None
    args.use_box_prompt = _box_prompt is not None
    args.use_point_prompt = _point_prompt is not None
    segvol_model = build_model()
    # run inference
    logits = zoom_in_zoom_out(
        args, segvol_model,
        _image.unsqueeze(0), _image_zoom_out.unsqueeze(0),
        text_prompt, _point_prompt, _box_prompt)
    print(logits.shape)
    # resize the predicted logits to the fixed display resolution
    resize_transform = transforms.Compose([
        transforms.AddChannel(),  # deprecated in recent MONAI; EnsureChannelFirst is the modern equivalent
        transforms.Resize((325, 325, 325), mode='trilinear')
    ])
    logits = resize_transform(logits)[0]
    print(logits.shape)
    return (torch.sigmoid(logits) > 0.5).int().numpy()
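# ---------------------------------------------------------------------
# Minimal front-end sketch (illustrative, not part of the original
# file): how the cached helpers above could be driven from Streamlit
# widgets. The process_ct_gt signature and the demo file path are
# assumptions, not confirmed by this file.
#
# if __name__ == '__main__':
#     st.title('SegVol demo')
#     text_prompt = st.text_input('Anatomical text prompt', value='liver')
#     if st.button('Run segmentation'):
#         # hypothetical loader call: assumed to return the full-resolution
#         # volume and its zoom-out resize as tensors
#         image, image_zoom_out = process_ct_gt('demo.nii.gz')
#         mask = inference_case(image, image_zoom_out,
#                               _point_prompt=None,
#                               text_prompt=text_prompt,
#                               _box_prompt=None)
#         st.write(mask.shape)
# ---------------------------------------------------------------------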