README.md ADDED
@@ -0,0 +1,192 @@
![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6565b54a9bf6665f10f75441/no60wyvKDTD-WV3pCt2P5.jpeg)

SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts point, box, and text prompts and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CTs, this foundation model supports the segmentation of over 200 anatomical categories.

The [**Paper**](https://arxiv.org/abs/2311.13385), [**Code**](https://github.com/BAAI-DCAI/SegVol), and [**Demo**](https://huggingface.co/spaces/BAAI/SegVol) have been released.

**Keywords**: 3D medical SAM, volumetric image segmentation

## Quickstart

### Requirements
```bash
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```
[pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) or a higher version is required. Please also install the following support packages:

```bash
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
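As an optional sanity check (a minimal sketch, assuming the pinned versions above), you can confirm the environment resolved correctly:

```python
import torch, monai, transformers, einops

print(torch.__version__)         # expected >= 1.11.0
print(monai.__version__)         # expected 0.9.0
print(transformers.__version__)  # expected 4.18.0
print(einops.__version__)        # expected 0.6.1
```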

### Test script

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol")
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through zoom_transform to generate zoomout & zoomin views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add batch dim manually
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
    data_item['image'].unsqueeze(0).to(device), data_item['label'].unsqueeze(0).to(device), data_item['zoom_out_image'].unsqueeze(0).to(device), data_item['zoom_out_label'].unsqueeze(0).to(device)

# take liver as the example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w/ batch dim

# bbox prompt
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w/ batch dim

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(image=data_item['image'],
      zoomed_image=data_item['zoom_out_image'],
      # point_prompt_group=[point_prompt, point_prompt_map],
      bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
      text_prompt=text_prompt,
      use_zoom=True
      )

# calculate dice score
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx], device)
print(dice)

# save prediction as nii.gz file
save_path='./Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
      start_coord=data_item['foreground_start_coord'],
      end_coord=data_item['foreground_end_coord'])
print('done')
```
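`forward_test` accepts only one spatial prompt at a time (the model code asserts that box and point prompts are not combined). To drive segmentation with a point prompt instead of a box, swap the prompt groups; a minimal sketch reusing the variables defined in the test script above:

```python
# Point + text prompt variant; bbox_prompt_group and point_prompt_group must not be passed together.
logits_mask = model.forward_test(image=data_item['image'],
      zoomed_image=data_item['zoom_out_image'],
      point_prompt_group=[point_prompt, point_prompt_map],
      text_prompt=text_prompt,
      use_zoom=True
      )
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx], device)
print(dice)
```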

### Training script

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol")
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through train transform
data_item = model.processor.train_transform(ct_npy, gt_npy)

# training example
# add batch dim manually
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)   # add batch dim

loss_step_avg = 0
for cls_idx in range(len(categories)):
    # optimizer.zero_grad()
    organs_cls = categories[cls_idx]
    labels_cls = gt3D[:, cls_idx]
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()
    # optimizer.step()

loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')

# save ckpt
model.save_pretrained('./ckpt')
```
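The loop above leaves `optimizer.zero_grad()` / `optimizer.step()` commented out. A minimal sketch of wiring in an optimizer; the choice of AdamW and the learning rate are illustrative assumptions, not the released training recipe:

```python
# Illustrative only: optimizer type and hyperparameters are assumptions.
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

for cls_idx in range(len(categories)):
    optimizer.zero_grad()
    loss = model.forward_train(image,
                               train_organs=categories[cls_idx],
                               train_labels=gt3D[:, cls_idx])
    loss.backward()
    optimizer.step()
```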

### Start with M3D-Seg dataset

We have released 25 open-source datasets (M3D-Seg) for training SegVol, and the preprocessed data have been uploaded to [ModelScope](https://www.modelscope.cn/datasets/GoodBaiBai88/M3D-Seg/summary) and [HuggingFace](https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg).
You can use the following script to easily load cases and insert them into the Test script and Training script above.

```python
import json, os
M3D_Seg_path = 'path/to/M3D-Seg'

# select a dataset
dataset_code = '0000'

# load json dict
json_path = os.path.join(M3D_Seg_path, dataset_code, dataset_code + '.json')
with open(json_path, 'r') as f:
    dataset_dict = json.load(f)

# get a case
ct_path = os.path.join(M3D_Seg_path, dataset_dict['train'][0]['image'])
gt_path = os.path.join(M3D_Seg_path, dataset_dict['train'][0]['label'])

# get categories
categories_dict = dataset_dict['labels']
categories = [x for _, x in categories_dict.items() if x != "background"]

# load npy data format
ct_npy, gt_npy = model.processor.load_uniseg_case(ct_path, gt_path)
```
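`load_uniseg_case` returns the same `ct_npy` / `gt_npy` pair that `preprocess_ct_gt` produces, so you can resume either script from the transform step; for example:

```python
# Inference: continue exactly as in the Test script from the transform step.
data_item = model.processor.zoom_transform(ct_npy, gt_npy)
# ... then add the batch dimension, build prompts, and call model.forward_test as above.

# Training: use the train transform instead.
data_item = model.processor.train_transform(ct_npy, gt_npy)
```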
config.json ADDED
@@ -0,0 +1,23 @@
{
  "architectures": [
    "SegVolModel"
  ],
  "auto_map": {
    "AutoConfig": "model_segvol_single.SegVolConfig",
    "AutoModel": "model_segvol_single.SegVolModel"
  },
  "model_type": "segvol",
  "patch_size": [
    4,
    16,
    16
  ],
  "spatial_size": [
    32,
    256,
    256
  ],
  "test_mode": true,
  "torch_dtype": "float32",
  "transformers_version": "4.18.0"
}
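`model_type`, `patch_size`, `spatial_size`, and `test_mode` are the fields consumed by `SegVolConfig` in `model_segvol_single.py`; `test_mode` selects whether `forward()` dispatches to `forward_test` or `forward_train`, and it can be overridden at load time, as the README scripts do. A minimal sketch:

```python
from transformers import AutoModel

# test_mode=True routes forward() to forward_test; test_mode=False routes it to forward_train.
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=False)
print(model.config.spatial_size, model.config.patch_size)  # [32, 256, 256] [4, 16, 16]
```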
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model_segvol_single.py ADDED
@@ -0,0 +1,1951 @@
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import numpy as np
3
+ import monai.transforms as transforms
4
+ import nibabel as nib
5
+ from scipy import sparse
6
+ import ast
7
+
8
+ class SegVolConfig(PretrainedConfig):
9
+ model_type = "segvol"
10
+
11
+ def __init__(
12
+ self,
13
+ test_mode=True,
14
+ **kwargs,
15
+ ):
16
+ self.spatial_size = [32, 256, 256]
17
+ self.patch_size = [4, 16, 16]
18
+ self.test_mode = test_mode
19
+ super().__init__(**kwargs)
20
+
21
+ class SegVolModel(PreTrainedModel):
22
+ config_class = SegVolConfig
23
+
24
+ def __init__(self, config):
25
+ super().__init__(config)
26
+ sam_model = _build_sam(
27
+ image_encoder_type='vit',
28
+ embed_dim = 768,
29
+ patch_size=self.config.patch_size,
30
+ checkpoint=None,
31
+ image_size=self.config.spatial_size,
32
+ )
33
+ self.model = SegVol(
34
+ image_encoder=sam_model.image_encoder,
35
+ mask_decoder=sam_model.mask_decoder,
36
+ prompt_encoder=sam_model.prompt_encoder,
37
+ roi_size=self.config.spatial_size,
38
+ patch_size=self.config.patch_size,
39
+ # clip_model=self.config.clip_model,
40
+ test_mode=self.config.test_mode,
41
+ )
42
+
43
+ self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
44
+
45
+ def forward_test(self,
46
+ image,
47
+ zoomed_image=None,
48
+ text_prompt=None,
49
+ bbox_prompt_group=None,
50
+ point_prompt_group=None,
51
+ use_zoom=True,):
52
+ device = image.device
53
+ assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
54
+ assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
55
+ bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
56
+ if bbox_prompt_group is not None:
57
+ bbox_prompt, bbox_prompt_map = bbox_prompt_group
58
+ if point_prompt_group is not None:
59
+ point_prompt, point_prompt_map = point_prompt_group
60
+ volume_shape = image[0][0].shape
61
+
62
+ with torch.no_grad():
63
+ logits_global_single = self.model(zoomed_image,
64
+ text=text_prompt,
65
+ boxes=bbox_prompt,
66
+ points=point_prompt)
67
+ logits_global_single = F.interpolate(
68
+ logits_global_single.cpu(),
69
+ size=volume_shape, mode='nearest')
70
+ if not use_zoom:
71
+ return logits_global_single
72
+
73
+ if point_prompt_map is not None:
74
+ binary_points = F.interpolate(
75
+ point_prompt_map.float(),
76
+ size=volume_shape, mode='nearest')
77
+ if bbox_prompt_map is not None:
78
+ binary_cube = F.interpolate(
79
+ bbox_prompt_map.float(),
80
+ size=volume_shape, mode='nearest')
81
+
82
+ min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(self.config.spatial_size, logits_global_single[0][0])
83
+ if min_d is None:
84
+ print('Fail to detect foreground!')
85
+ return logits_global_single
86
+
87
+ # Crop roi
88
+ image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
89
+ global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
90
+
91
+ assert not (bbox_prompt is not None and point_prompt is not None), 'Do not use point prompt and box prompt at the same time.'
92
+ prompt_reflection = None
93
+ if bbox_prompt is not None:
94
+ binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
95
+ prompt_reflection = (
96
+ binary_cube_cropped,
97
+ global_preds
98
+ )
99
+ if point_prompt is not None:
100
+ binary_points_cropped = binary_points[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
101
+ prompt_reflection = (
102
+ binary_points_cropped,
103
+ global_preds
104
+ )
105
+
106
+ ## inference
107
+ with torch.no_grad():
108
+ logits_single_cropped = sliding_window_inference(
109
+ image_single_cropped.to(device), prompt_reflection,
110
+ self.config.spatial_size, 1, self.model, 0.5,
111
+ text=text_prompt,
112
+ use_box=bbox_prompt is not None,
113
+ use_point=point_prompt is not None,
114
+ )
115
+ logits_single_cropped = logits_single_cropped.cpu().squeeze()
116
+ logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
117
+ return logits_global_single
118
+
119
+ def forward_train(self, image, train_organs, train_labels):
120
+ loss = self.model(image, text=None, boxes=None, points=None,
121
+ train_organs=train_organs,
122
+ train_labels=train_labels)
123
+ return loss
124
+
125
+ def forward(self, **kwargs):
126
+ if self.config.test_mode:
127
+ return self.forward_test(kwargs['image'],
128
+ kwargs['zoomed_image'],
129
+ kwargs['text_prompt'],
130
+ kwargs['bbox_prompt_group'],
131
+ kwargs['point_prompt_group'],
132
+ kwargs['use_zoom'])
133
+ else:
134
+ return self.forward_train(kwargs['image'],
135
+ kwargs['train_organs'],
136
+ kwargs['train_labels'])
137
+
138
+ # processor
139
+ class SegVolProcessor():
140
+ def __init__(self, spatial_size) -> None:
141
+ self.img_loader = transforms.LoadImage()
142
+ self.transform4test = transforms.Compose(
143
+ [
144
+ DimTranspose(keys=["image", "label"]),
145
+ MinMaxNormalization(),
146
+ transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
147
+ transforms.ToTensord(keys=["image", "label"]),
148
+ ]
149
+ )
150
+ self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
151
+ self.transform4train = transforms.Compose(
152
+ [
153
+ # transforms.AddChanneld(keys=["image"]),
154
+ DimTranspose(keys=["image", "label"]),
155
+ MinMaxNormalization(),
156
+ transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
157
+ transforms.SpatialPadd(keys=["image", "label"], spatial_size=spatial_size, mode='constant'),
158
+ transforms.OneOf(transforms=[
159
+ transforms.Resized(keys=["image", "label"],spatial_size=spatial_size),
160
+ transforms.RandCropByPosNegLabeld(
161
+ keys=["image", "label"],
162
+ label_key="label",
163
+ spatial_size=spatial_size,
164
+ pos=2,
165
+ neg=1,
166
+ num_samples=1,
167
+ image_key="image",
168
+ image_threshold=0,
169
+ ),
170
+ ],
171
+ weights=[1, 3]
172
+ ),
173
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
174
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
175
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
176
+ transforms.RandScaleIntensityd(keys="image", factors=0.2, prob=0.2),
177
+ transforms.RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2),
178
+ transforms.ToTensord(keys=["image", "label"]),
179
+ ]
180
+ )
181
+
182
+ # ct_path is path for a ct scan file with nii.gz format
183
+ # gt_path is path for a ground truth file with nii.gz format
184
+ def preprocess_ct_gt(self, ct_path, gt_path, category):
185
+ item = {}
186
+ # generate ct_voxel_ndarray
187
+ ct_voxel_ndarray, _ = self.img_loader(ct_path)
188
+ ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
189
+ ct_shape = ct_voxel_ndarray.shape
190
+ ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
191
+ ct_voxel_ndarray = self.ForegroundNorm(ct_voxel_ndarray)
192
+ item['image'] = ct_voxel_ndarray
193
+
194
+ # generate gt_voxel_ndarray
195
+ gt_voxel_ndarray, _ = self.img_loader(gt_path)
196
+ gt_voxel_ndarray = np.array(gt_voxel_ndarray)
197
+ present_categories = np.unique(gt_voxel_ndarray)
198
+ gt_masks = []
199
+ for cls_idx in range(len(category)):
200
+ # ignore background
201
+ cls = cls_idx + 1
202
+ if cls not in present_categories:
203
+ gt_voxel_ndarray_category = np.zeros(ct_shape)
204
+ gt_masks.append(gt_voxel_ndarray_category)
205
+ else:
206
+ gt_voxel_ndarray_category = gt_voxel_ndarray.copy()
207
+ gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0
208
+ gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1
209
+ gt_masks.append(gt_voxel_ndarray_category)
210
+ gt_voxel_ndarray = np.stack(gt_masks, axis=0)
211
+ assert gt_voxel_ndarray.shape[0] == len(category) and gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:]
212
+ item['label'] = gt_voxel_ndarray.astype(np.int32)
213
+
214
+ # transform
215
+ return item['image'], item['label']
216
+
217
+ def load_uniseg_case(self, ct_npy_path, gt_npy_path):
218
+ img_array = np.load(ct_npy_path)
219
+ allmatrix_sp= sparse.load_npz(gt_npy_path)
220
+ if 'mask_' in gt_npy_path:
221
+ gt_shape = ast.literal_eval(gt_npy_path.split('_')[-1].replace('.npz', ''))
222
+ else:
223
+ gt_shape = ast.literal_eval(gt_npy_path.split('.')[-2])
224
+ gt_array=allmatrix_sp.toarray().reshape(gt_shape)
225
+ return img_array, gt_array
226
+
227
+ def ForegroundNorm(self, ct_narray):
228
+ ct_voxel_ndarray = ct_narray.copy()
229
+ ct_voxel_ndarray = ct_voxel_ndarray.flatten()
230
+ thred = np.mean(ct_voxel_ndarray)
231
+ voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
232
+ upper_bound = np.percentile(voxel_filtered, 99.95)
233
+ lower_bound = np.percentile(voxel_filtered, 00.05)
234
+ mean = np.mean(voxel_filtered)
235
+ std = np.std(voxel_filtered)
236
+ ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
237
+ ct_narray = (ct_narray - mean) / max(std, 1e-8)
238
+ return ct_narray
239
+
240
+ def zoom_transform(self, ct_npy, gt_npy):
241
+ item = {
242
+ 'image': ct_npy,
243
+ 'label': gt_npy
244
+ }
245
+ item = self.transform4test(item)
246
+ item_zoom_out = self.zoom_out_transform(item)
247
+ item['zoom_out_image'] = item_zoom_out['image']
248
+ item['zoom_out_label'] = item_zoom_out['label']
249
+ return item
250
+
251
+ def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'):
252
+ point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
253
+ points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
254
+ binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
255
+ return points_single, binary_points_resize
256
+
257
+ def bbox_prompt_b(self, label_single_resize, device='cpu'):
258
+ box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device)
259
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
260
+ return box_single, binary_cube_resize
261
+
262
+ def dice_score(self, preds, labels, device='cpu'):
263
+ assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
264
+ predict = preds.view(1, -1).to(device)
265
+ target = labels.view(1, -1).to(device)
266
+
267
+ predict = torch.sigmoid(predict)
268
+ predict = torch.where(predict > 0.5, 1., 0.)
269
+
270
+ tp = torch.sum(torch.mul(predict, target))
271
+ den = torch.sum(predict) + torch.sum(target) + 1
272
+ dice = 2 * tp / den
273
+ return dice
274
+
275
+ def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
276
+ ct = nib.load(ct_path)
277
+ logits_mask = logits_mask.transpose(-1, -3)
278
+ start_coord[-1], start_coord[-3] = start_coord[-3], start_coord[-1]
279
+ end_coord[-1], end_coord[-3] = end_coord[-3], end_coord[-1]
280
+ preds_save = torch.zeros(ct.shape)
281
+ preds_save[start_coord[0]:end_coord[0],
282
+ start_coord[1]:end_coord[1],
283
+ start_coord[2]:end_coord[2]] = torch.sigmoid(logits_mask)
284
+ preds_save = torch.where(preds_save > 0.5, 1., 0.).numpy()
285
+ preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
286
+ nib.save(preds_nii, save_path)
287
+
288
+ def train_transform(self, ct_npy, gt_npy):
289
+ item = {
290
+ 'image': ct_npy,
291
+ 'label': gt_npy
292
+ }
293
+ item = self.transform4train(item)
294
+ if type(item) is list:
295
+ assert len(item) == 1
296
+ item = item[0]
297
+ return item
298
+
299
+ class MinMaxNormalization(transforms.Transform):
300
+ def __call__(self, data):
301
+ d = dict(data)
302
+ k = "image"
303
+ d[k] = d[k] - d[k].min()
304
+ d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
305
+ return d
306
+
307
+ class DimTranspose(transforms.Transform):
308
+ def __init__(self, keys):
309
+ self.keys = keys
310
+
311
+ def __call__(self, data):
312
+ d = dict(data)
313
+ for key in self.keys:
314
+ d[key] = np.swapaxes(d[key], -1, -3)
315
+ return d
316
+
317
+ # prompts
318
+ def generate_box(pred_pre, bbox_shift=None):
319
+ meaning_post_label = pred_pre # [h, w, d]
320
+ ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
321
+ if all(tensor.nelement() == 0 for tensor in ones_idx):
322
+ bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
323
+ return bboxes
324
+ min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
325
+ max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
326
+
327
+
328
+ if bbox_shift is None:
329
+ corner_min = []
330
+ corner_max = []
331
+ shape = meaning_post_label.shape
332
+ for coor in min_coords:
333
+ coor_ = max(0, coor)
334
+ corner_min.append(coor_)
335
+ for idx, coor in enumerate(max_coords):
336
+ coor_ = min(shape[idx], coor)
337
+ corner_max.append(coor_)
338
+ corner_min = torch.tensor(corner_min)
339
+ corner_max = torch.tensor(corner_max)
340
+ return torch.cat((corner_min, corner_max), dim=0)
341
+ else:
342
+ # add perturbation to bounding box coordinates
343
+ corner_min = []
344
+ corner_max = []
345
+ shape = meaning_post_label.shape
346
+ for coor in min_coords:
347
+ coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
348
+ corner_min.append(coor_)
349
+ for idx, coor in enumerate(max_coords):
350
+ coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
351
+ corner_max.append(coor_)
352
+ corner_min = torch.tensor(corner_min)
353
+ corner_max = torch.tensor(corner_max)
354
+ return torch.cat((corner_min, corner_max), dim=0)
355
+
356
+
357
+ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
358
+ spacial_dim = 3
359
+ points = torch.zeros((0, 3))
360
+ labels = torch.zeros((0))
361
+ pos_thred = 0.9
362
+ neg_thred = 0.1
363
+
364
+ # get pos/net indices
365
+ positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
366
+ negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
367
+
368
+ ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
369
+ if all(tmp.nelement() == 0 for tmp in ones_idx):
370
+ # all neg
371
+ num_positive_extra = 0
372
+ selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
373
+ points = torch.cat((points, selected_positive_point), dim=0)
374
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
375
+ else:
376
+ # random select a pos point
377
+ random_idx = torch.randint(len(positive_indices[0]), (1,))
378
+ selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
379
+ points = torch.cat((points, selected_positive_point), dim=0)
380
+ labels = torch.cat((labels, torch.ones((1))))
381
+
382
+ if num_positive_extra > 0:
383
+ pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
384
+ extra_positive_points = []
385
+ for pos_idx in pos_idx_list:
386
+ extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
387
+ extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
388
+ points = torch.cat((points, extra_positive_points), dim=0)
389
+ labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
390
+
391
+ if num_negative_extra > 0:
392
+ neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
393
+ extra_negative_points = []
394
+ for neg_idx in neg_idx_list:
395
+ extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
396
+ extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
397
+ points = torch.cat((points, extra_negative_points), dim=0)
398
+ labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
399
+
400
+ if fix_extra_point_num is None:
401
+ left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
402
+ else:
403
+ left_point_num = fix_extra_point_num + 1 - labels.shape[0]
404
+
405
+ for _ in range(left_point_num):
406
+ ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
407
+ points = torch.cat((points, ignore_point), dim=0)
408
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
409
+
410
+ return points, labels
411
+
412
+ # SegVol
413
+ import torch
414
+ import torch.nn as nn
415
+ import torch.nn.functional as F
416
+ import numpy as np
417
+ from transformers import CLIPTextModel, CLIPTextConfig
418
+ import random
419
+
420
+ #%% set up model
421
+ class SegVol(nn.Module):
422
+ def __init__(self,
423
+ image_encoder,
424
+ mask_decoder,
425
+ prompt_encoder,
426
+ roi_size,
427
+ patch_size,
428
+ # clip_model,
429
+ test_mode=False,
430
+ ):
431
+ super().__init__()
432
+ self.image_encoder = image_encoder
433
+ self.mask_decoder = mask_decoder
434
+ self.prompt_encoder = prompt_encoder
435
+ self.text_encoder = TextEncoder()
436
+ self.feat_shape = np.array(roi_size)/np.array(patch_size)
437
+ self.test_mode = test_mode
438
+ self.dice_loss = BinaryDiceLoss()
439
+ self.bce_loss = BCELoss()
440
+ self.decoder_iter = 6
441
+
442
+ def forward(self, image, text=None, boxes=None, points=None, **kwargs):
443
+ bs = image.shape[0]
444
+ img_shape = (image.shape[2], image.shape[3], image.shape[4])
445
+ image_embedding, _ = self.image_encoder(image)
446
+ image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
447
+ int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
448
+ # test mode
449
+ if self.test_mode:
450
+ return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
451
+
452
+ # train mode
453
+ ## sl
454
+ sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
455
+ ## ssl
456
+ # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
457
+ return sl_loss
458
+
459
+ def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
460
+ device = image_embedding.device
461
+ with torch.no_grad():
462
+ if boxes is not None:
463
+ if len(boxes.shape) == 2:
464
+ boxes = boxes[:, None, :] # (B, 1, 6)
465
+ if text is not None:
466
+ text_embedding = self.text_encoder(text, device) # (B, 768)
467
+ else:
468
+ text_embedding = None
469
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
470
+ points=points,
471
+ boxes=boxes,
472
+ masks=None,
473
+ text_embedding=text_embedding,
474
+ )
475
+
476
+ dense_pe = self.prompt_encoder.get_dense_pe()
477
+ low_res_masks, _ = self.mask_decoder(
478
+ image_embeddings=image_embedding,
479
+ text_embedding = text_embedding,
480
+ image_pe=dense_pe,
481
+ sparse_prompt_embeddings=sparse_embeddings,
482
+ dense_prompt_embeddings=dense_embeddings,
483
+ multimask_output=False,
484
+ )
485
+ logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
486
+ return logits
487
+
488
+ def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
489
+ device = image_embedding.device
490
+ iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device)
491
+ # select prompt
492
+ prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
493
+ [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
494
+ [iter_bboxes, iter_points, None]]
495
+ sl_loss = 0
496
+ for prompt in prompt_options:
497
+ bboxes, points, organs = prompt
498
+ logits = self.forward_decoder(image_embedding, img_shape, text=organs, boxes=bboxes, points=points)
499
+ # cal loss
500
+ sl_loss_dice = self.dice_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
501
+ sl_loss_bce = self.bce_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
502
+ sl_loss += sl_loss_dice + sl_loss_bce
503
+ return sl_loss
504
+
505
+ # def unsupervised_forward(self, image, image_embedding, pseudo_seg_cleaned, img_shape):
506
+ # sll_loss = 0
507
+ # for iter in range(self.decoder_iter):
508
+ # if iter % 2 == 0:
509
+ # pseudo_labels, pseudo_points_prompt = self.build_pseudo_point_prompt_label(image.shape, pseudo_seg_cleaned)
510
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=None, points=pseudo_points_prompt)
511
+ # else:
512
+ # pseudo_labels, pseudo_bboxes_prompt = self.build_pseudo_box_prompt_label(image.shape, pseudo_seg_cleaned)
513
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=pseudo_bboxes_prompt, points=None)
514
+ # # cal loss
515
+ # sll_loss_dice = self.dice_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
516
+ # sll_loss_bce = self.bce_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
517
+ # sll_loss += sll_loss_dice + sll_loss_bce
518
+ # return sll_loss
519
+
520
+ def build_prompt_label(self, bs, training_organs, train_labels, device):
521
+ # generate prompt & label
522
+ iter_organs = []
523
+ iter_bboxes = []
524
+ iter_points_ax = []
525
+ iter_point_labels = []
526
+ for sample_idx in range(bs):
527
+ # organ prompt
528
+ iter_organs.append(training_organs)
529
+ # box prompt
530
+ box = generate_box(train_labels[sample_idx], bbox_shift=10)
531
+ iter_bboxes.append(box)
532
+ # point prompt
533
+ num_positive_extra_max, num_negative_extra_max = 10, 10
534
+ num_positive_extra = random.randint(0, num_positive_extra_max)
535
+ num_negative_extra = random.randint(0, num_negative_extra_max)
536
+ point, point_label = select_points(
537
+ train_labels[sample_idx],
538
+ num_positive_extra=num_positive_extra,
539
+ num_negative_extra=num_negative_extra,
540
+ fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
541
+ iter_points_ax.append(point)
542
+ iter_point_labels.append(point_label)
543
+ # batched prompt
544
+ iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device)
545
+ iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device)
546
+ iter_points = (iter_points_ax, iter_point_labels)
547
+ iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device)
548
+ return iter_points, iter_bboxes, iter_organs
549
+
550
+ # def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
551
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
552
+ # # generate points
553
+ # points = []
554
+ # point_labels = []
555
+ # for batch_idx in range(input_shape[0]):
556
+ # # generate pseudo label
557
+ # unique_ids = torch.unique(seg_labels[batch_idx])
558
+ # unique_ids = unique_ids[unique_ids != -1]
559
+ # region_id = random.choice(unique_ids).item()
560
+ # pseudo_labels[batch_idx][seg_labels[batch_idx]==region_id] = 1
561
+ # # generate point prompt
562
+ # num_positive_extra_max, num_negative_extra_max = 10, 10
563
+ # num_positive_extra = random.randint(4, num_positive_extra_max)
564
+ # num_negative_extra = random.randint(0, num_negative_extra_max)
565
+ # assert len(pseudo_labels[batch_idx][0].shape) == 3
566
+ # point, point_label = select_points(
567
+ # pseudo_labels[batch_idx][0],
568
+ # num_positive_extra=num_positive_extra,
569
+ # num_negative_extra=num_negative_extra,
570
+ # fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
571
+ # points.append(point)
572
+ # point_labels.append(point_label)
573
+ # points = torch.stack(points, dim=0).to(self.custom_device)
574
+ # point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
575
+ # pseudo_points_prompt = (points, point_labels)
576
+ # return pseudo_labels, pseudo_points_prompt
577
+
578
+ # def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
579
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
580
+ # iter_bboxes = []
581
+ # # generate boxes
582
+ # for batch_idx in range(input_shape[0]):
583
+ # # generate ori pseudo label
584
+ # unique_ids = torch.unique(seg_labels_cleaned[batch_idx])
585
+ # unique_ids = unique_ids[unique_ids != -1]
586
+ # region_id = random.choice(unique_ids).item()
587
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==region_id] = 1
588
+ # # generate box prompt
589
+ # box = generate_box(pseudo_labels[batch_idx][0])
590
+ # iter_bboxes.append(box)
591
+ # # refine pseudo label
592
+ # x_min, y_min, z_min, x_max, y_max, z_max = box
593
+ # binary_cube = torch.zeros_like(pseudo_labels[batch_idx][0]).int()
594
+ # binary_cube[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1] = 1
595
+ # # cal iou
596
+ # mask_label = seg_labels_cleaned[batch_idx][0]
597
+ # assert binary_cube.shape == mask_label.shape, str(binary_cube.shape) + ' ' + str(mask_label.shape)
598
+ # mask_values_in_binary_cube = mask_label[binary_cube == 1]
599
+ # unique_mask_values = torch.unique(mask_values_in_binary_cube)
600
+ # # print('unique_mask_values ', unique_mask_values)
601
+ # for value in unique_mask_values:
602
+ # if value == -1: continue
603
+ # mask_area = (mask_label == value)
604
+ # intersection = (binary_cube & mask_area)
605
+ # iou = intersection.float().sum() / mask_area.float().sum()
606
+ # if iou > 0.90:
607
+ # # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
608
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
609
+
610
+ # bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
611
+ # return pseudo_labels, bboxes
612
+
613
+ class TextEncoder(nn.Module):
614
+ def __init__(self):
615
+ super().__init__()
616
+ config = CLIPTextConfig()
617
+ self.clip_text_model = CLIPTextModel(config)
618
+ self.tokenizer = None
619
+ self.dim_align = nn.Linear(512, 768)
620
+ # freeze text encoder
621
+ for param in self.clip_text_model.parameters():
622
+ param.requires_grad = False
623
+
624
+ def organ2tokens(self, organ_names, device):
625
+ text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
626
+ tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
627
+ for key in tokens.keys():
628
+ tokens[key] = tokens[key].to(device)
629
+ return tokens
630
+
631
+ def forward(self, text, device):
632
+ if text is None:
633
+ return None
634
+ if type(text) is str:
635
+ # text is supposed to be list
636
+ text = [text]
637
+ tokens = self.organ2tokens(text, device)
638
+ clip_outputs = self.clip_text_model(**tokens)
639
+ text_embedding = clip_outputs.pooler_output
640
+ text_embedding = self.dim_align(text_embedding)
641
+ return text_embedding
642
+
643
+ # loss
644
+ import torch
645
+ import torch.nn as nn
646
+
647
+ class BinaryDiceLoss(nn.Module):
648
+ def __init__(self, smooth=1, p=2, reduction='mean'):
649
+ super(BinaryDiceLoss, self).__init__()
650
+ self.smooth = smooth
651
+ self.p = p
652
+ self.reduction = reduction
653
+
654
+ def forward(self, predict, target):
655
+ predict = torch.sigmoid(predict)
656
+ target_ = target.clone()
657
+ target_[target == -1] = 0
658
+ assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape[0])
659
+ predict = predict.contiguous().view(predict.shape[0], -1)
660
+ target_ = target_.contiguous().view(target_.shape[0], -1)
661
+
662
+ num = torch.sum(torch.mul(predict, target_), dim=1)
663
+ den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth
664
+
665
+ dice_score = 2*num / den
666
+ dice_loss = 1 - dice_score
667
+
668
+ # dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
669
+ dice_loss_avg = dice_loss.sum() / dice_loss.shape[0]
670
+
671
+ return dice_loss_avg
672
+
673
+ class BCELoss(nn.Module):
674
+ def __init__(self):
675
+ super(BCELoss, self).__init__()
676
+ self.criterion = nn.BCEWithLogitsLoss()
677
+
678
+ def forward(self, predict, target):
679
+ assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape)
680
+ target_ = target.clone()
681
+ target_[target == -1] = 0
682
+
683
+ ce_loss = self.criterion(predict, target_)
684
+
685
+ return ce_loss
686
+
687
+ # monai inference
688
+
689
+ # Copyright (c) MONAI Consortium
690
+ # Licensed under the Apache License, Version 2.0 (the "License");
691
+ # you may not use this file except in compliance with the License.
692
+ # You may obtain a copy of the License at
693
+ # http://www.apache.org/licenses/LICENSE-2.0
694
+ # Unless required by applicable law or agreed to in writing, software
695
+ # distributed under the License is distributed on an "AS IS" BASIS,
696
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
697
+ # See the License for the specific language governing permissions and
698
+ # limitations under the License.
699
+
700
+ import warnings
701
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
702
+
703
+ import torch
704
+ import torch.nn.functional as F
705
+ import random
706
+
707
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
708
+ from monai.transforms import Resize
709
+ from monai.utils import (
710
+ BlendMode,
711
+ PytorchPadMode,
712
+ convert_data_type,
713
+ ensure_tuple,
714
+ fall_back_tuple,
715
+ look_up_option,
716
+ optional_import,
717
+ )
718
+
719
+ tqdm, _ = optional_import("tqdm", name="tqdm")
720
+
721
+ __all__ = ["sliding_window_inference"]
722
+
723
+ def logits2roi_coor(spatial_size, logits_global_single):
724
+ # crop predict
725
+ pred_global_single = torch.sigmoid(logits_global_single) > 0.5
726
+ ## get all pos idx
727
+ nonzero_indices = torch.nonzero(pred_global_single)
728
+ if nonzero_indices.shape[0] == 0:
729
+ return None, None, None, None, None, None
730
+ ## get boundary
731
+ min_d, max_d = nonzero_indices[:, 0].min(), nonzero_indices[:, 0].max()
732
+ min_h, max_h = nonzero_indices[:, 1].min(), nonzero_indices[:, 1].max()
733
+ min_w, max_w = nonzero_indices[:, 2].min(), nonzero_indices[:, 2].max()
734
+ ## padding
735
+ crop_d, crop_h, crop_w = max_d - min_d + 1, max_h - min_h + 1, max_w - min_w + 1,
736
+ window_d, window_h, window_w = spatial_size
737
+ padding_d, padding_h, padding_w = max(0, window_d-crop_d), max(0, window_h-crop_h), max(0, window_w-crop_w)
738
+ global_d, global_h, global_w = logits_global_single.shape
739
+ min_d = max(0, min_d - int(padding_d)//2)
740
+ min_h = max(0, min_h - int(padding_h)//2)
741
+ min_w = max(0, min_w - int(padding_w)//2)
742
+ max_d = min(global_d, max_d + int(padding_d)//2)
743
+ max_h = min(global_h, max_h + int(padding_h)//2)
744
+ max_w = min(global_w, max_w + int(padding_w)//2)
745
+ return min_d, min_h, min_w, max_d, max_h, max_w
746
+
747
+ def build_binary_cube(bbox, binary_cube_shape):
748
+ min_coord = bbox[0][:3].int().tolist()
749
+ max_coord = bbox[0][3:].int().tolist()
750
+ binary_cube = torch.zeros(binary_cube_shape)
751
+ binary_cube[min_coord[0]:max_coord[0]+1, min_coord[1]:max_coord[1]+1, min_coord[2]:max_coord[2]+1] = 1
752
+ return binary_cube
753
+
754
+ def build_binary_points(points, labels, shape):
755
+ binary_points = torch.zeros(shape, dtype=torch.int16)
756
+ binary_points[points[labels == 1, 0].long(), points[labels == 1, 1].long(), points[labels == 1, 2].long()] = 1
757
+ return binary_points
758
+
759
+ def sliding_window_inference(
760
+ inputs: torch.Tensor,
761
+ prompt_reflection: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
762
+ roi_size: Union[Sequence[int], int],
763
+ sw_batch_size: int,
764
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
765
+ overlap: float = 0.25,
766
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
767
+ sigma_scale: Union[Sequence[float], float] = 0.125,
768
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
769
+ cval: float = 0.0,
770
+ sw_device: Union[torch.device, str, None] = None,
771
+ device: Union[torch.device, str, None] = None,
772
+ progress: bool = False,
773
+ roi_weight_map: Union[torch.Tensor, None] = None,
774
+ *args: Any,
775
+ **kwargs: Any,
776
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
777
+ """
778
+ Sliding window inference on `inputs` with `predictor`.
779
+
780
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
781
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
782
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
783
+ could be ([128,64,256], [64,32,128]).
784
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
785
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
786
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
787
+
788
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
789
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
790
+
791
+ Args:
792
+ inputs: input image to be processed (assuming NCHW[D])
793
+ roi_size: the spatial window size for inferences.
794
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
795
+ if the components of the `roi_size` are non-positive values, the transform will use the
796
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
797
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
798
+ sw_batch_size: the batch size to run window slices.
799
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
800
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
801
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
802
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
803
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
804
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
805
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
806
+ to ensure the scaled output ROI sizes are still integers.
807
+ If the `predictor`'s input and output spatial sizes are different,
808
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
809
+ overlap: Amount of overlap between scans.
810
+ mode: {``"constant"``, ``"gaussian"``}
811
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
812
+
813
+ - ``"constant``": gives equal weight to all predictions.
814
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
815
+
816
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
817
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
818
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
819
+ spatial dimensions.
820
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
821
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
822
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
823
+ cval: fill value for 'constant' padding mode. Default: 0
824
+ sw_device: device for the window data.
825
+ By default the device (and accordingly the memory) of the `inputs` is used.
826
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
827
+ device: device for the stitched output prediction.
828
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
829
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
830
+ `inputs` and `roi_size`. Output is on the `device`.
831
+ progress: whether to print a `tqdm` progress bar.
832
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
833
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
834
+ args: optional args to be passed to ``predictor``.
835
+ kwargs: optional keyword args to be passed to ``predictor``.
836
+
837
+ Note:
838
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
839
+
840
+ """
841
+ print('sliding window inference for ROI')
842
+ text = kwargs['text']
843
+ use_box = kwargs['use_box']
844
+ use_point = kwargs['use_point']
845
+ assert not (use_box and use_point)
846
+ compute_dtype = inputs.dtype
847
+ num_spatial_dims = len(inputs.shape) - 2
848
+ if overlap < 0 or overlap >= 1:
849
+ raise ValueError("overlap must be >= 0 and < 1.")
850
+
851
+ # determine image spatial size and batch size
852
+ # Note: all input images must have the same image size and batch size
853
+ batch_size, _, *image_size_ = inputs.shape
854
+
855
+ if device is None:
856
+ device = inputs.device
857
+ if sw_device is None:
858
+ sw_device = inputs.device
859
+
860
+ roi_size = fall_back_tuple(roi_size, image_size_)
861
+ # in case that image size is smaller than roi size
862
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
863
+ pad_size = []
864
+ for k in range(len(inputs.shape) - 1, 1, -1):
865
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
866
+ half = diff // 2
867
+ pad_size.extend([half, diff - half])
868
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
869
+ #############
870
+ if use_point or use_box:
871
+ binary_prompt_map, global_preds = prompt_reflection
872
+ global_preds = F.pad(global_preds, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
873
+ #############
874
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
875
+
876
+ # Store all slices in list
877
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
878
+ num_win = len(slices) # number of windows per image
879
+ total_slices = num_win * batch_size # total number of windows
880
+
881
+ # Create window-level importance map
882
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
883
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
884
+ importance_map = roi_weight_map
885
+ else:
886
+ try:
887
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
888
+ except BaseException as e:
889
+ raise RuntimeError(
890
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
891
+ ) from e
892
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
893
+ # handle non-positive weights
894
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
895
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
896
+
897
+ # Perform predictions
898
+ dict_key, output_image_list, count_map_list = None, [], []
899
+ _initialized_ss = -1
900
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
901
+
902
+ # for each patch
903
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
904
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
905
+ unravel_slice = [
906
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
907
+ for idx in slice_range
908
+ ]
909
+ window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
910
+ #############
911
+
912
+ boxes = None
913
+ points = None
914
+ if use_point:
915
+ window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
916
+ point, point_label = select_points(window_binary_prompt_map.squeeze())
917
+ points = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
918
+ pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
919
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
920
+ if use_box:
921
+ if num_win == 1:
922
+ window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
923
+ boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().to(device)
924
+ else:
925
+ pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
926
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
927
+ seg_prob_out = predictor(window_data, text, boxes, points) # batched patch segmentation
928
+ #############
929
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
930
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
931
+ if isinstance(seg_prob_out, torch.Tensor):
932
+ seg_prob_tuple = (seg_prob_out,)
933
+ elif isinstance(seg_prob_out, Mapping):
934
+ if dict_key is None:
935
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
936
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
937
+ is_tensor_output = False
938
+ else:
939
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
940
+ is_tensor_output = False
941
+
942
+ # for each output in multi-output list
943
+ for ss, seg_prob in enumerate(seg_prob_tuple):
944
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
945
+
946
+ # compute zoom scale: out_roi_size/in_roi_size
947
+ zoom_scale = []
948
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
949
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
950
+ ):
951
+ _scale = out_w_i / float(in_w_i)
952
+ if not (img_s_i * _scale).is_integer():
953
+ warnings.warn(
954
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
955
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
956
+ )
957
+ zoom_scale.append(_scale)
958
+
959
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
960
+ # construct multi-resolution outputs
961
+ output_classes = seg_prob.shape[1]
962
+ output_shape = [batch_size, output_classes] + [
963
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
964
+ ]
965
+ # allocate memory to store the full output and the count for overlapping parts
966
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
967
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
968
+ _initialized_ss += 1
969
+
970
+ # resizing the importance_map
971
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
972
+
973
+ # store the result in the proper location of the full output. Apply weights from importance map.
974
+ for idx, original_idx in zip(slice_range, unravel_slice):
975
+ # zoom roi
976
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
977
+ for axis in range(2, len(original_idx_zoom)):
978
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
979
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
980
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
981
+ warnings.warn(
982
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
983
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
984
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
985
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
986
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
987
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
988
+ )
989
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
990
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
991
+ # store results and weights
992
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
993
+ count_map_list[ss][original_idx_zoom] += (
994
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
995
+ )
996
+
997
+ # account for any overlapping sections
998
+ for ss in range(len(output_image_list)):
999
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
1000
+
1001
+ # remove padding if image_size smaller than roi_size
1002
+ for ss, output_i in enumerate(output_image_list):
1003
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
1004
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
1005
+
1006
+ zoom_scale = [
1007
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
1008
+ ]
1009
+
1010
+ final_slicing: List[slice] = []
1011
+ for sp in range(num_spatial_dims):
1012
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
1013
+ slice_dim = slice(
1014
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
1015
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
1016
+ )
1017
+ final_slicing.insert(0, slice_dim)
1018
+ while len(final_slicing) < len(output_i.shape):
1019
+ final_slicing.insert(0, slice(None))
1020
+ output_image_list[ss] = output_i[final_slicing]
1021
+
1022
+ if dict_key is not None: # if output of predictor is a dict
1023
+ final_output = dict(zip(dict_key, output_image_list))
1024
+ else:
1025
+ final_output = tuple(output_image_list) # type: ignore
1026
+ return final_output[0] if is_tensor_output else final_output # type: ignore
1027
+
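The division by the accumulated count map above is what blends overlapping windows into a single prediction. A toy 1D illustration of the same weighted-accumulate-then-normalize idea (a standalone sketch with uniform weights, not the actual inference code):

```python
import torch

# Two overlapping windows of length 4 on a length-6 axis, with uniform importance weights.
output = torch.zeros(6)   # accumulated weighted predictions
count = torch.zeros(6)    # accumulated weights
for start, pred in [(0, torch.full((4,), 2.0)), (2, torch.full((4,), 4.0))]:
    output[start:start + 4] += pred   # weight of 1.0 everywhere in this toy case
    count[start:start + 4] += 1.0
blended = output / count              # the overlapping region averages to 3.0
print(blended)                        # tensor([2., 2., 3., 3., 4., 4.])
```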
1028
+
1029
+ def _get_scan_interval(
1030
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
1031
+ ) -> Tuple[int, ...]:
1032
+ """
1033
+ Compute scan interval according to the image size, roi size and overlap.
1034
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
1035
+ use 1 instead to make sure sliding window works.
1036
+
1037
+ """
1038
+ if len(image_size) != num_spatial_dims:
1039
+ raise ValueError("image coord different from spatial dims.")
1040
+ if len(roi_size) != num_spatial_dims:
1041
+ raise ValueError("roi coord different from spatial dims.")
1042
+
1043
+ scan_interval = []
1044
+ for i in range(num_spatial_dims):
1045
+ if roi_size[i] == image_size[i]:
1046
+ scan_interval.append(int(roi_size[i]))
1047
+ else:
1048
+ interval = int(roi_size[i] * (1 - overlap))
1049
+ scan_interval.append(interval if interval > 0 else 1)
1050
+ return tuple(scan_interval)
1051
+
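As a quick check of the interval rule documented above, a minimal standalone sketch for an assumed 3D case (it mirrors the formula rather than importing anything from this file):

```python
# interval = int((1 - overlap) * roi_size), clamped to >= 1; a full-size axis yields one window.
image_size = (128, 256, 256)   # assumed example volume shape
roi_size = (32, 256, 128)      # assumed sliding-window ROI
overlap = 0.5

scan_interval = []
for img_d, roi_d in zip(image_size, roi_size):
    if roi_d == img_d:
        scan_interval.append(roi_d)                # single window along this axis
    else:
        scan_interval.append(max(int(roi_d * (1 - overlap)), 1))
print(tuple(scan_interval))                        # (16, 256, 64)
```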
1052
+ # build 3D SAM
1053
+ import torch
1054
+ import numpy as np
1055
+ from monai.networks.nets import ViT
1056
+
1057
+ def _build_sam(
1058
+ image_encoder_type,
1059
+ embed_dim,
1060
+ patch_size,
1061
+ checkpoint,
1062
+ image_size,
1063
+ ):
1064
+ mlp_dim = 3072
1065
+ num_layers = 12
1066
+ num_heads = 12
1067
+ pos_embed = 'perceptron'
1068
+ dropout_rate = 0.0
1069
+
1070
+ image_encoder=ViT(
1071
+ in_channels=1,
1072
+ img_size=image_size,
1073
+ patch_size=patch_size,
1074
+ hidden_size=embed_dim,
1075
+ mlp_dim=mlp_dim,
1076
+ num_layers=num_layers,
1077
+ num_heads=num_heads,
1078
+ pos_embed=pos_embed,
1079
+ classification=False,
1080
+ dropout_rate=dropout_rate,
1081
+ )
1082
+ image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
1083
+
1084
+ if checkpoint is not None:
1085
+ with open(checkpoint, "rb") as f:
1086
+ state_dict = torch.load(f, map_location='cpu')['state_dict']
1087
+ encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
1088
+ image_encoder.load_state_dict(encoder_dict)
1089
+ print(f'===> image_encoder.load_param: {checkpoint}')
1090
+ sam = Sam(
1091
+ image_encoder=image_encoder,
1092
+ prompt_encoder=PromptEncoder(
1093
+ embed_dim=embed_dim,
1094
+ image_embedding_size=image_embedding_size,
1095
+ input_image_size=image_size,
1096
+ mask_in_chans=16,
1097
+ ),
1098
+ mask_decoder=MaskDecoder(
1099
+ image_encoder_type=image_encoder_type,
1100
+ num_multimask_outputs=3,
1101
+ transformer=TwoWayTransformer(
1102
+ depth=2,
1103
+ embedding_dim=embed_dim,
1104
+ mlp_dim=2048,
1105
+ num_heads=8,
1106
+ ),
1107
+ transformer_dim=embed_dim,
1108
+ iou_head_depth=3,
1109
+ iou_head_hidden_dim=256,
1110
+ image_size=np.array(image_size),
1111
+ patch_size=np.array(patch_size),
1112
+ ),
1113
+ pixel_mean=[123.675, 116.28, 103.53],
1114
+ pixel_std=[58.395, 57.12, 57.375],
1115
+ )
1116
+ sam.eval()
1117
+ return sam
1118
+
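For orientation, a hedged sketch of calling `_build_sam` once the whole module is loaded. The hyperparameters below are illustrative assumptions, not the released SegVol configuration, and `checkpoint=None` skips loading pretrained encoder weights:

```python
# Illustrative call only; embed_dim, patch_size and image_size are assumed values.
sam_model = _build_sam(
    image_encoder_type='vit',
    embed_dim=768,
    patch_size=(4, 16, 16),
    checkpoint=None,               # or a path to a checkpoint containing 'model.encoder.*' keys
    image_size=(32, 256, 256),     # must be divisible by patch_size along each axis
)
print(sum(p.numel() for p in sam_model.parameters()), "parameters")
```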
1119
+ # mask decoder
1120
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1121
+ # All rights reserved.
1122
+
1123
+ # This source code is licensed under the license found in the
1124
+ # LICENSE file in the root directory of this source tree.
1125
+
1126
+ import torch
1127
+ from torch import nn
1128
+ from torch.nn import functional as F
1129
+
1130
+ from typing import List, Tuple, Type, Optional
1131
+
1132
+ class MaskDecoder(nn.Module):
1133
+ def __init__(
1134
+ self,
1135
+ *,
1136
+ image_encoder_type: str,
1137
+ transformer_dim: int,
1138
+ transformer: nn.Module,
1139
+ num_multimask_outputs: int = 3,
1140
+ activation: Type[nn.Module] = nn.GELU,
1141
+ iou_head_depth: int = 3,
1142
+ iou_head_hidden_dim: int = 256,
1143
+ image_size,
1144
+ patch_size,
1145
+ ) -> None:
1146
+ """
1147
+ Predicts masks given an image and prompt embeddings, using a
1148
+ transformer architecture.
1149
+
1150
+ Arguments:
1151
+ transformer_dim (int): the channel dimension of the transformer
1152
+ transformer (nn.Module): the transformer used to predict masks
1153
+ num_multimask_outputs (int): the number of masks to predict
1154
+ when disambiguating masks
1155
+ activation (nn.Module): the type of activation to use when
1156
+ upscaling masks
1157
+ iou_head_depth (int): the depth of the MLP used to predict
1158
+ mask quality
1159
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
1160
+ used to predict mask quality
1161
+ image_encoder_type (str): which image encoder variant is used
+ ('swin_vit' or a plain ViT); selects the matching upscaling path
+ image_size, patch_size: spatial size of the input volume and of each
+ patch, used to derive the decoder feature-map shape
+ """
1162
+ super().__init__()
1163
+ self.transformer_dim = transformer_dim
1164
+ self.transformer = transformer
1165
+
1166
+ self.num_multimask_outputs = num_multimask_outputs
1167
+
1168
+ self.iou_token = nn.Embedding(1, transformer_dim)
1169
+ self.num_mask_tokens = num_multimask_outputs + 1
1170
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
1171
+
1172
+ if image_encoder_type == 'swin_vit':
1173
+ self.feat_shape = image_size/patch_size
1174
+ self.output_upscaling = nn.Sequential(
1175
+ nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
1176
+ nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # swin
1177
+ activation(),
1178
+ nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), # swin
1179
+ # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1), # vit
1180
+ activation(),
1181
+ )
1182
+ else:
1183
+ self.feat_shape = image_size/patch_size * 2
1184
+ self.output_upscaling = nn.Sequential(
1185
+ nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
1186
+ nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # vit
1187
+ activation(),
1188
+ nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
1189
+ # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1),
1190
+ activation(),
1191
+ )
1192
+ self.output_hypernetworks_mlps = nn.ModuleList(
1193
+ [
1194
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
1195
+ for i in range(self.num_mask_tokens)
1196
+ ]
1197
+ )
1198
+
1199
+ self.iou_prediction_head = MLP(
1200
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
1201
+ )
1202
+
1203
+ self.txt_align_upscaled_embedding = nn.Linear(768, 96)
1204
+
1205
+ def forward(
1206
+ self,
1207
+ image_embeddings: torch.Tensor,
1208
+ text_embedding: Optional[torch.Tensor],
1209
+ image_pe: torch.Tensor,
1210
+ sparse_prompt_embeddings: torch.Tensor,
1211
+ dense_prompt_embeddings: torch.Tensor,
1212
+ multimask_output: bool,
1213
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1214
+ """
1215
+ Predict masks given image and prompt embeddings.
1216
+
1217
+ Returns:
1218
+ torch.Tensor: batched predicted masks
+ torch.Tensor: batched predictions of mask quality (IoU)
1219
+ """
1220
+ # print('--------------decoder here--------------')
1221
+ masks, iou_pred = self.predict_masks(
1222
+ image_embeddings=image_embeddings,
1223
+ text_embedding=text_embedding,
1224
+ image_pe=image_pe,
1225
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
1226
+ dense_prompt_embeddings=dense_prompt_embeddings,
1227
+ )
1228
+
1229
+ # Select the correct mask or masks for output
1230
+ if multimask_output:
1231
+ mask_slice = slice(1, None)
1232
+ else:
1233
+ mask_slice = slice(0, 1)
1234
+ masks = masks[:, mask_slice, :, :, :]
1235
+ iou_pred = iou_pred[:, mask_slice]
1236
+
1237
+ # Prepare output
1238
+ return masks, iou_pred
1239
+
1240
+ def predict_masks(
1241
+ self,
1242
+ image_embeddings: torch.Tensor,
1243
+ text_embedding: torch.Tensor,
1244
+ image_pe: torch.Tensor,
1245
+ sparse_prompt_embeddings: torch.Tensor,
1246
+ dense_prompt_embeddings: torch.Tensor,
1247
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1248
+ """Predicts masks. See 'forward' for more details."""
1249
+ # Concatenate output tokens
1250
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1251
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
1252
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
1253
+ # Expand per-image data in batch direction to be per-mask
1254
+ if image_embeddings.shape[0] != tokens.shape[0]:
1255
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
1256
+ else:
1257
+ src = image_embeddings
1258
+ src = src + dense_prompt_embeddings
1259
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
1260
+ b, c, h, w, d = src.shape
1261
+
1262
+ # Run the transformer
1263
+ hs, src = self.transformer(src, pos_src, tokens)
1264
+ iou_token_out = hs[:, 0, :]
1265
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
1266
+
1267
+ # Upscale mask embeddings and predict masks using the mask tokens
1268
+ src = src.transpose(1, 2).view(b, c, h, w, d)
1269
+ upscaled_embedding = self.output_upscaling(src)
1270
+ hyper_in_list: List[torch.Tensor] = []
1271
+ for i in range(self.num_mask_tokens):
1272
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
1273
+ hyper_in = torch.stack(hyper_in_list, dim=1)
1274
+ b, c, h, w, d = upscaled_embedding.shape
1275
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d)
1276
+
1277
+ if text_embedding is not None:
1278
+ text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1)
1279
+ upscaled_embedding = upscaled_embedding.view(b, c, h * w * d)
1280
+ sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d)
1281
+ sim = sim.repeat(1, masks.shape[1], 1, 1, 1)
1282
+ masks = masks + sim
1283
+ iou_pred = self.iou_prediction_head(iou_token_out)
1284
+
1285
+ return masks, iou_pred
1286
+
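To make the text-alignment step in `predict_masks` concrete, here is a shape-only sketch of the similarity term that gets repeated across mask channels and added to the predictions (toy dimensions chosen for readability):

```python
import torch
import torch.nn as nn

# Toy sizes: one prompt, 96 upscaled channels, a 4x4x4 feature grid, a 768-d text feature.
b, c, h, w, d = 1, 96, 4, 4, 4
upscaled = torch.randn(b, c, h, w, d)               # stands in for output_upscaling(src)
text_embedding = torch.randn(b, 768)                # stands in for the CLIP text feature

txt_align = nn.Linear(768, c)                       # mirrors txt_align_upscaled_embedding
text_down = txt_align(text_embedding).unsqueeze(1)  # B x 1 x C
sim = (text_down @ upscaled.view(b, c, -1)).view(b, 1, h, w, d)
print(sim.shape)  # torch.Size([1, 1, 4, 4, 4]); repeated over mask channels, then added
```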
1287
+ # Lightly adapted from
1288
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
1289
+ class MLP(nn.Module):
1290
+ def __init__(
1291
+ self,
1292
+ input_dim: int,
1293
+ hidden_dim: int,
1294
+ output_dim: int,
1295
+ num_layers: int,
1296
+ sigmoid_output: bool = False,
1297
+ ) -> None:
1298
+ super().__init__()
1299
+ self.num_layers = num_layers
1300
+ h = [hidden_dim] * (num_layers - 1)
1301
+ self.layers = nn.ModuleList(
1302
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
1303
+ )
1304
+ self.sigmoid_output = sigmoid_output
1305
+
1306
+ def forward(self, x):
1307
+ for i, layer in enumerate(self.layers):
1308
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1309
+ if self.sigmoid_output:
1310
+ x = torch.sigmoid(x)  # F.sigmoid is deprecated in recent torch versions
1311
+ return x
1312
+
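A quick sanity check of the hypernetwork `MLP` defined above, using toy dimensions rather than the model's own:

```python
import torch

mlp = MLP(input_dim=768, hidden_dim=768, output_dim=96, num_layers=3)
tokens = torch.randn(2, 768)   # e.g. two mask tokens
print(mlp(tokens).shape)       # torch.Size([2, 96])
```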
1313
+ # prompt encoder
1314
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1315
+ # All rights reserved.
1316
+
1317
+ # This source code is licensed under the license found in the
1318
+ # LICENSE file in the root directory of this source tree.
1319
+
1320
+ import numpy as np
1321
+ import torch
1322
+ from torch import nn
1323
+
1324
+ from typing import Any, Optional, Tuple, Type
1325
+
1326
+ class PromptEncoder(nn.Module):
1327
+ def __init__(
1328
+ self,
1329
+ embed_dim: int,
1330
+ image_embedding_size: Tuple[int, int, int],
1331
+ input_image_size: Tuple[int, int, int],
1332
+ mask_in_chans: int,
1333
+ activation: Type[nn.Module] = nn.GELU,
1334
+ ) -> None:
1335
+ """
1336
+ Encodes prompts for input to SAM's mask decoder.
1337
+
1338
+ Arguments:
1339
+ embed_dim (int): The prompts' embedding dimension
1340
+ image_embedding_size (tuple(int, int, int)): The spatial size of the
1341
+ image embedding, as (H, W, D).
1342
+ input_image_size (tuple(int, int, int)): The padded size of the image as
1343
+ input to the image encoder, as (H, W, D).
1344
+ mask_in_chans (int): The number of hidden channels used for
1345
+ encoding input masks.
1346
+ activation (nn.Module): The activation to use when encoding
1347
+ input masks.
1348
+ """
1349
+ super().__init__()
1350
+ self.embed_dim = embed_dim
1351
+ self.input_image_size = input_image_size
1352
+ self.image_embedding_size = image_embedding_size
1353
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
1354
+
1355
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
1356
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
1357
+ self.point_embeddings = nn.ModuleList(point_embeddings)
1358
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
1359
+
1360
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
1361
+ # note: kept from the original 2D SAM code; only exercised when an
+ # explicit mask prompt is passed to forward()
+ self.mask_downscaling = nn.Sequential(
1362
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
1363
+ LayerNorm2d(mask_in_chans // 4),
1364
+ activation(),
1365
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
1366
+ LayerNorm2d(mask_in_chans),
1367
+ activation(),
1368
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
1369
+ )
1370
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
1371
+
1372
+ def get_dense_pe(self) -> torch.Tensor:
1373
+ """
1374
+ Returns the positional encoding used to encode point prompts,
1375
+ applied to a dense set of points the shape of the image encoding.
1376
+
1377
+ Returns:
1378
+ torch.Tensor: Positional encoding with shape
1379
+ 1x(embed_dim)x(embedding_h)x(embedding_w)x(embedding_d)
1380
+ """
1381
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
1382
+
1383
+ def _embed_points(
1384
+ self,
1385
+ points: torch.Tensor,
1386
+ labels: torch.Tensor,
1387
+ pad: bool,
1388
+ ) -> torch.Tensor:
1389
+ """Embeds point prompts."""
1390
+ points = points + 0.5 # Shift to center of pixel
1391
+ if pad:
1392
+ padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
1393
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
1394
+ points = torch.cat([points, padding_point], dim=1)
1395
+ labels = torch.cat([labels, padding_label], dim=1)
1396
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
1397
+ point_embedding[labels == -1] = 0.0
1398
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
1399
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
1400
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
1401
+ return point_embedding
1402
+
1403
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
1404
+ """Embeds box prompts."""
1405
+ boxes = boxes + 0.5 # Shift to center of pixel
1406
+ coords = boxes.reshape(-1, 2, 3)
1407
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
1408
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
1409
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
1410
+ return corner_embedding
1411
+
1412
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
1413
+ """Embeds mask inputs."""
1414
+ mask_embedding = self.mask_downscaling(masks)
1415
+ return mask_embedding
1416
+
1417
+ def _get_batch_size(
1418
+ self,
1419
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
1420
+ boxes: Optional[torch.Tensor],
1421
+ masks: Optional[torch.Tensor],
1422
+ text_embedding: Optional[torch.Tensor],
1423
+ ) -> int:
1424
+ """
1425
+ Gets the batch size of the output given the batch size of the input prompts.
1426
+ """
1427
+ if points is not None:
1428
+ return points[0].shape[0]
1429
+ elif boxes is not None:
1430
+ return boxes.shape[0]
1431
+ elif masks is not None:
1432
+ return masks.shape[0]
1433
+ elif text_embedding is not None:
1434
+ return text_embedding.shape[0]
1435
+ else:
1436
+ return 1
1437
+
1438
+ def _get_device(self) -> torch.device:
1439
+ return self.point_embeddings[0].weight.device
1440
+
1441
+ def forward(
1442
+ self,
1443
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
1444
+ boxes: Optional[torch.Tensor],
1445
+ masks: Optional[torch.Tensor],
1446
+ text_embedding: Optional[torch.Tensor],
1447
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1448
+
1449
+ bs = self._get_batch_size(points, boxes, masks, text_embedding)
1450
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
1451
+
1452
+ if points is not None:
1453
+ coords, labels = points
1454
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
1455
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
1456
+
1457
+ if boxes is not None:
1458
+ box_embeddings = self._embed_boxes(boxes)
1459
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
1460
+
1461
+ if text_embedding is not None:
1462
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
1463
+
1464
+ if masks is not None:
1465
+ dense_embeddings = self._embed_masks(masks)
1466
+ else:
1467
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
1468
+ bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]), int(self.image_embedding_size[2])
1469
+ )
1470
+
1471
+ return sparse_embeddings, dense_embeddings
1472
+
1473
+
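A hedged usage sketch of `PromptEncoder` with one box prompt and one text embedding. All sizes are toy assumptions, the box is given as two 3D corner points whose axis order is assumed here, and it relies on the helper classes defined further below (`PositionEmbeddingRandom`, `LayerNorm2d`) already being available, i.e. the whole module has been imported:

```python
import torch

pe = PromptEncoder(
    embed_dim=768,
    image_embedding_size=(8, 16, 16),     # assumed feature-grid size
    input_image_size=(32, 256, 256),      # assumed padded input size
    mask_in_chans=16,
)
boxes = torch.tensor([[[10., 10., 2., 120., 120., 20.]]])   # one box, two 3D corners (axis order assumed)
text_emb = torch.randn(1, 768)                              # stand-in CLIP text feature
sparse, dense = pe(points=None, boxes=boxes, masks=None, text_embedding=text_emb)
print(sparse.shape, dense.shape)   # torch.Size([1, 3, 768]) torch.Size([1, 768, 8, 16, 16])
```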
1474
+ class PositionEmbeddingRandom(nn.Module):
1475
+ """
1476
+ Positional encoding using random spatial frequencies.
1477
+ """
1478
+
1479
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
1480
+ super().__init__()
1481
+ if scale is None or scale <= 0.0:
1482
+ scale = 1.0
1483
+ self.register_buffer(
1484
+ "positional_encoding_gaussian_matrix",
1485
+ scale * torch.randn((3, num_pos_feats)),
1486
+ )
1487
+
1488
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
1489
+ """Positionally encode points that are normalized to [0,1]."""
1490
+ # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
1491
+ coords = 2 * coords - 1
1492
+ coords = coords @ self.positional_encoding_gaussian_matrix
1493
+ coords = 2 * np.pi * coords
1494
+ # outputs d_1 x ... x d_n x C shape
1495
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
1496
+
1497
+ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
1498
+ """Generate positional encoding for a grid of the specified size."""
1499
+ h, w, d = size
1500
+ device: Any = self.positional_encoding_gaussian_matrix.device
1501
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
1502
+ y_embed = grid.cumsum(dim=0) - 0.5
1503
+ x_embed = grid.cumsum(dim=1) - 0.5
1504
+ z_embed = grid.cumsum(dim=2) - 0.5
1505
+ y_embed = y_embed / h
1506
+ x_embed = x_embed / w
1507
+ z_embed = z_embed / d
1508
+
1509
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
1510
+ return pe.permute(3, 0, 1, 2) # C x H x W x D
1511
+
1512
+ def forward_with_coords(
1513
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int, int]
1514
+ ) -> torch.Tensor:
1515
+ """Positionally encode points that are not normalized to [0,1]."""
1516
+ coords = coords_input.clone()
1517
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
1518
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
1519
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
1520
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
1521
+
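A small check of the random-Fourier positional encoding above, on an assumed toy grid:

```python
import torch

pe_layer = PositionEmbeddingRandom(num_pos_feats=384)    # 2 * 384 = 768 output channels
dense_pe = pe_layer((8, 16, 16))                         # dense encoding for an 8x16x16 grid
print(dense_pe.shape)                                    # torch.Size([768, 8, 16, 16])

coords = torch.tensor([[[16., 4., 128.]]])               # one unnormalized 3D point
point_pe = pe_layer.forward_with_coords(coords, (32, 256, 256))
print(point_pe.shape)                                    # torch.Size([1, 1, 768])
```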
1522
+ # two way transformer
1523
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1524
+ # All rights reserved.
1525
+
1526
+ # This source code is licensed under the license found in the
1527
+ # LICENSE file in the root directory of this source tree.
1528
+
1529
+ import torch
1530
+ from torch import Tensor, nn
1531
+
1532
+ import math
1533
+ from typing import Tuple, Type
1534
+
1535
+ class TwoWayTransformer(nn.Module):
1536
+ def __init__(
1537
+ self,
1538
+ depth: int,
1539
+ embedding_dim: int,
1540
+ num_heads: int,
1541
+ mlp_dim: int,
1542
+ activation: Type[nn.Module] = nn.ReLU,
1543
+ attention_downsample_rate: int = 2,
1544
+ ) -> None:
1545
+ """
1546
+ A transformer decoder that attends to an input image using
1547
+ queries whose positional embedding is supplied.
1548
+
1549
+ Args:
1550
+ depth (int): number of layers in the transformer
1551
+ embedding_dim (int): the channel dimension for the input embeddings
1552
+ num_heads (int): the number of heads for multihead attention. Must
1553
+ divide embedding_dim
1554
+ mlp_dim (int): the channel dimension internal to the MLP block
1555
+ activation (nn.Module): the activation to use in the MLP block
1556
+ """
1557
+ super().__init__()
1558
+ self.depth = depth
1559
+ self.embedding_dim = embedding_dim
1560
+ self.num_heads = num_heads
1561
+ self.mlp_dim = mlp_dim
1562
+ self.layers = nn.ModuleList()
1563
+
1564
+ for i in range(depth):
1565
+ self.layers.append(
1566
+ TwoWayAttentionBlock(
1567
+ embedding_dim=embedding_dim,
1568
+ num_heads=num_heads,
1569
+ mlp_dim=mlp_dim,
1570
+ activation=activation,
1571
+ attention_downsample_rate=attention_downsample_rate,
1572
+ skip_first_layer_pe=(i == 0),
1573
+ )
1574
+ )
1575
+
1576
+ self.final_attn_token_to_image = Attention(
1577
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1578
+ )
1579
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
1580
+
1581
+ def forward(
1582
+ self,
1583
+ image_embedding: Tensor,
1584
+ image_pe: Tensor,
1585
+ point_embedding: Tensor,
1586
+ ) -> Tuple[Tensor, Tensor]:
1587
+ """
1588
+ Args:
1589
+ image_embedding (torch.Tensor): image to attend to. Should be shape
1590
+ B x embedding_dim x h x w x d for any h, w and d.
1591
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
1592
+ have the same shape as image_embedding.
1593
+ point_embedding (torch.Tensor): the embedding to add to the query points.
1594
+ Must have shape B x N_points x embedding_dim for any N_points.
1595
+
1596
+ Returns:
1597
+ torch.Tensor: the processed point_embedding
1598
+ torch.Tensor: the processed image_embedding
1599
+ """
1600
+ # BxCxHxWxD -> BxHWDxC == B x N_image_tokens x C
1601
+ bs, c, h, w, d = image_embedding.shape
1602
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
1603
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
1604
+
1605
+ # Prepare queries
1606
+ queries = point_embedding
1607
+ keys = image_embedding
1608
+
1609
+ # Apply transformer blocks and final layernorm
1610
+ for layer in self.layers:
1611
+ queries, keys = layer(
1612
+ queries=queries,
1613
+ keys=keys,
1614
+ query_pe=point_embedding,
1615
+ key_pe=image_pe,
1616
+ )
1617
+
1618
+ # Apply the final attention layer from the points to the image
1619
+ q = queries + point_embedding
1620
+ k = keys + image_pe
1621
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
1622
+ queries = queries + attn_out
1623
+ queries = self.norm_final_attn(queries)
1624
+
1625
+ return queries, keys
1626
+
1627
+
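A shape-only sketch of running the two-way transformer on toy dimensions; it assumes `TwoWayAttentionBlock`, `Attention` and `MLPBlock` from further below are already defined (whole module imported):

```python
import torch

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=512)
image_embedding = torch.randn(1, 256, 4, 4, 4)   # B x C x h x w x d
image_pe = torch.randn(1, 256, 4, 4, 4)          # same shape as the image embedding
point_embedding = torch.randn(1, 3, 256)         # three prompt tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)   # torch.Size([1, 3, 256]) torch.Size([1, 64, 256])
```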
1628
+ class TwoWayAttentionBlock(nn.Module):
1629
+ def __init__(
1630
+ self,
1631
+ embedding_dim: int,
1632
+ num_heads: int,
1633
+ mlp_dim: int = 2048,
1634
+ activation: Type[nn.Module] = nn.ReLU,
1635
+ attention_downsample_rate: int = 2,
1636
+ skip_first_layer_pe: bool = False,
1637
+ ) -> None:
1638
+ """
1639
+ A transformer block with four layers: (1) self-attention of sparse
1640
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
1641
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
1642
+ inputs.
1643
+
1644
+ Arguments:
1645
+ embedding_dim (int): the channel dimension of the embeddings
1646
+ num_heads (int): the number of heads in the attention layers
1647
+ mlp_dim (int): the hidden dimension of the mlp block
1648
+ activation (nn.Module): the activation of the mlp block
1649
+ skip_first_layer_pe (bool): skip the PE on the first layer
1650
+ """
1651
+ super().__init__()
1652
+ self.self_attn = Attention(embedding_dim, num_heads)
1653
+ self.norm1 = nn.LayerNorm(embedding_dim)
1654
+
1655
+ self.cross_attn_token_to_image = Attention(
1656
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1657
+ )
1658
+ self.norm2 = nn.LayerNorm(embedding_dim)
1659
+
1660
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
1661
+ self.norm3 = nn.LayerNorm(embedding_dim)
1662
+
1663
+ self.norm4 = nn.LayerNorm(embedding_dim)
1664
+ self.cross_attn_image_to_token = Attention(
1665
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1666
+ )
1667
+
1668
+ self.skip_first_layer_pe = skip_first_layer_pe
1669
+
1670
+ def forward(
1671
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
1672
+ ) -> Tuple[Tensor, Tensor]:
1673
+ # Self attention block
1674
+ if self.skip_first_layer_pe:
1675
+ queries = self.self_attn(q=queries, k=queries, v=queries)
1676
+ else:
1677
+ q = queries + query_pe
1678
+ attn_out = self.self_attn(q=q, k=q, v=queries)
1679
+ queries = queries + attn_out
1680
+ queries = self.norm1(queries)
1681
+
1682
+ # Cross attention block, tokens attending to image embedding
1683
+ q = queries + query_pe
1684
+ k = keys + key_pe
1685
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
1686
+ queries = queries + attn_out
1687
+ queries = self.norm2(queries)
1688
+
1689
+ # MLP block
1690
+ mlp_out = self.mlp(queries)
1691
+ queries = queries + mlp_out
1692
+ queries = self.norm3(queries)
1693
+
1694
+ # Cross attention block, image embedding attending to tokens
1695
+ q = queries + query_pe
1696
+ k = keys + key_pe
1697
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
1698
+ keys = keys + attn_out
1699
+ keys = self.norm4(keys)
1700
+
1701
+ return queries, keys
1702
+
1703
+
1704
+ class Attention(nn.Module):
1705
+ """
1706
+ An attention layer that allows for downscaling the size of the embedding
1707
+ after projection to queries, keys, and values.
1708
+ """
1709
+
1710
+ def __init__(
1711
+ self,
1712
+ embedding_dim: int,
1713
+ num_heads: int,
1714
+ downsample_rate: int = 1,
1715
+ ) -> None:
1716
+ super().__init__()
1717
+ self.embedding_dim = embedding_dim
1718
+ self.internal_dim = embedding_dim // downsample_rate
1719
+ self.num_heads = num_heads
1720
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
1721
+
1722
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
1723
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
1724
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
1725
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
1726
+
1727
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
1728
+ b, n, c = x.shape
1729
+ x = x.reshape(b, n, num_heads, c // num_heads)
1730
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
1731
+
1732
+ def _recombine_heads(self, x: Tensor) -> Tensor:
1733
+ b, n_heads, n_tokens, c_per_head = x.shape
1734
+ x = x.transpose(1, 2)
1735
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
1736
+
1737
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
1738
+ # Input projections
1739
+ q = self.q_proj(q)
1740
+ k = self.k_proj(k)
1741
+ v = self.v_proj(v)
1742
+
1743
+ # Separate into heads
1744
+ q = self._separate_heads(q, self.num_heads)
1745
+ k = self._separate_heads(k, self.num_heads)
1746
+ v = self._separate_heads(v, self.num_heads)
1747
+
1748
+ # Attention
1749
+ _, _, _, c_per_head = q.shape
1750
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
1751
+ attn = attn / math.sqrt(c_per_head)
1752
+ attn = torch.softmax(attn, dim=-1)
1753
+
1754
+ # Get output
1755
+ out = attn @ v
1756
+ out = self._recombine_heads(out)
1757
+ out = self.out_proj(out)
1758
+
1759
+ return out
1760
+
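And a minimal check of the downsampled attention layer above (toy sizes):

```python
import torch

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)  # projects to 128 internally
q = k = v = torch.randn(1, 5, 256)
print(attn(q, k, v).shape)   # torch.Size([1, 5, 256]); output re-projected to embedding_dim
```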
1761
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
1762
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
1763
+ class LayerNorm2d(nn.Module):
1764
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
1765
+ super().__init__()
1766
+ self.weight = nn.Parameter(torch.ones(num_channels))
1767
+ self.bias = nn.Parameter(torch.zeros(num_channels))
1768
+ self.eps = eps
1769
+
1770
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1771
+ u = x.mean(1, keepdim=True)
1772
+ s = (x - u).pow(2).mean(1, keepdim=True)
1773
+ x = (x - u) / torch.sqrt(s + self.eps)
1774
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
1775
+ return x
1776
+
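For reference, the channel-wise normalization above applied to a toy 2D feature map:

```python
import torch

ln = LayerNorm2d(num_channels=16)
x = torch.randn(2, 16, 8, 8)
y = ln(x)
print(y.shape, bool(y.mean(dim=1).abs().max() < 1e-4))   # per-position channel mean is ~0
```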
1777
+ class MLPBlock(nn.Module):
1778
+ def __init__(
1779
+ self,
1780
+ embedding_dim: int,
1781
+ mlp_dim: int,
1782
+ act: Type[nn.Module] = nn.GELU,
1783
+ ) -> None:
1784
+ super().__init__()
1785
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
1786
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
1787
+ self.act = act()
1788
+
1789
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1790
+ return self.lin2(self.act(self.lin1(x)))
1791
+
1792
+
1793
+ # sam
1794
+ from typing import Any, Dict, List, Tuple
+
+ class Sam(nn.Module):
1795
+ mask_threshold: float = 0.0
1796
+ image_format: str = "RGB"
1797
+
1798
+ def __init__(
1799
+ self,
1800
+ image_encoder,
1801
+ prompt_encoder,
1802
+ mask_decoder,
1803
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
1804
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
1805
+ ) -> None:
1806
+ """
1807
+ SAM predicts object masks from an image and input prompts.
1808
+
1809
+ Arguments:
1810
+ image_encoder (ImageEncoderViT): The backbone used to encode the
1811
+ image into image embeddings that allow for efficient mask prediction.
1812
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
1813
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
1814
+ and encoded prompts.
1815
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
1816
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
1817
+ """
1818
+ super().__init__()
1819
+ self.image_encoder = image_encoder
1820
+ self.prompt_encoder = prompt_encoder
1821
+ self.mask_decoder = mask_decoder
1822
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
1823
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
1824
+
1825
+ @property
1826
+ def device(self) -> Any:
1827
+ return self.pixel_mean.device
1828
+
1829
+ @torch.no_grad()
1830
+ def forward(
1831
+ self,
1832
+ batched_input: List[Dict[str, Any]],
1833
+ multimask_output: bool,
1834
+ ) -> List[Dict[str, torch.Tensor]]:
1835
+ """
1836
+ Predicts masks end-to-end from provided images and prompts.
1837
+ If prompts are not known in advance, using SamPredictor is
1838
+ recommended over calling the model directly.
1839
+
1840
+ Arguments:
1841
+ batched_input (list(dict)): A list over input images, each a
1842
+ dictionary with the following keys. A prompt key can be
1843
+ excluded if it is not present.
1844
+ 'image': The image as a torch tensor in 1xHxWxD format (single-channel volume),
1845
+ already transformed for input to the model.
1846
+ 'original_size': (tuple(int, int)) The original size of
1847
+ the image before transformation, as (H, W).
1848
+ 'point_coords': (torch.Tensor) Batched point prompts for
1849
+ this image, with shape BxNx3. Already transformed to the
1850
+ input frame of the model.
1851
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
1852
+ with shape BxN.
1853
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx6 (two corner points per box).
1854
+ Already transformed to the input frame of the model.
1855
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
1856
+ in the form Bx1xHxW.
1857
+ multimask_output (bool): Whether the model should predict multiple
1858
+ disambiguating masks, or return a single mask.
1859
+
1860
+ Returns:
1861
+ (list(dict)): A list over input images, where each element is
1862
+ as dictionary with the following keys.
1863
+ 'masks': (torch.Tensor) Batched binary mask predictions,
1864
+ with shape BxCxHxW, where B is the number of input prompts,
1865
+ C is determined by multimask_output, and (H, W) is the
1866
+ original size of the image.
1867
+ 'iou_predictions': (torch.Tensor) The model's predictions
1868
+ of mask quality, in shape BxC.
1869
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
1870
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
1871
+ to subsequent iterations of prediction.
1872
+ """
1873
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
1874
+ image_embeddings = self.image_encoder(input_images)
1875
+
1876
+ outputs = []
1877
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
1878
+ if "point_coords" in image_record:
1879
+ points = (image_record["point_coords"], image_record["point_labels"])
1880
+ else:
1881
+ points = None
1882
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1883
+ points=points,
1884
+ boxes=image_record.get("boxes", None),
1885
+ masks=image_record.get("mask_inputs", None),
1886
+ )
1887
+ low_res_masks, iou_predictions = self.mask_decoder(
1888
+ image_embeddings=curr_embedding.unsqueeze(0),
1889
+ image_pe=self.prompt_encoder.get_dense_pe(),
1890
+ sparse_prompt_embeddings=sparse_embeddings,
1891
+ dense_prompt_embeddings=dense_embeddings,
1892
+ multimask_output=multimask_output,
1893
+ )
1894
+ masks = self.postprocess_masks(
1895
+ low_res_masks,
1896
+ input_size=image_record["image"].shape[-2:],
1897
+ original_size=image_record["original_size"],
1898
+ )
1899
+ masks = masks > self.mask_threshold
1900
+ outputs.append(
1901
+ {
1902
+ "masks": masks,
1903
+ "iou_predictions": iou_predictions,
1904
+ "low_res_logits": low_res_masks,
1905
+ }
1906
+ )
1907
+ return outputs
1908
+
1909
+ def postprocess_masks(
1910
+ self,
1911
+ masks: torch.Tensor,
1912
+ input_size: Tuple[int, ...],
1913
+ original_size: Tuple[int, ...],
1914
+ ) -> torch.Tensor:
1915
+ """
1916
+ Remove padding and upscale masks to the original image size.
1917
+
1918
+ Arguments:
1919
+ masks (torch.Tensor): Batched masks from the mask_decoder,
1920
+ in BxCxHxW format.
1921
+ input_size (tuple(int, int)): The size of the image input to the
1922
+ model, in (H, W) format. Used to remove padding.
1923
+ original_size (tuple(int, int)): The original size of the image
1924
+ before resizing for input to the model, in (H, W) format.
1925
+
1926
+ Returns:
1927
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
1928
+ is given by original_size.
1929
+ """
1930
+ masks = F.interpolate(
1931
+ masks,
1932
+ (self.image_encoder.img_size, self.image_encoder.img_size),
1933
+ mode="bilinear",
1934
+ align_corners=False,
1935
+ )
1936
+ masks = masks[..., : input_size[0], : input_size[1]]
1937
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
1938
+ return masks
1939
+
1940
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
1941
+ """Normalize pixel values and pad to a square input."""
1942
+ # Normalize colors
1943
+ # TODO
1944
+ x = (x - self.pixel_mean) / self.pixel_std
1945
+
1946
+ # Pad
1947
+ h, w = x.shape[-2:]
1948
+ padh = self.image_encoder.img_size - h
1949
+ padw = self.image_encoder.img_size - w
1950
+ x = F.pad(x, (0, padw, 0, padh))
1951
+ return x
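Putting the pieces together, a hedged end-to-end sketch that wires the prompt encoder and mask decoder by hand with a stand-in feature map. All sizes are illustrative assumptions, and in practice these modules are driven by the surrounding SegVol inference code rather than called like this:

```python
import torch

sam = _build_sam(
    image_encoder_type='vit',
    embed_dim=768,
    patch_size=(4, 16, 16),          # assumed, not the released configuration
    checkpoint=None,
    image_size=(32, 256, 256),
)
image_embedding = torch.randn(1, 768, 8, 16, 16)   # stand-in for an encoded volume
text_embedding = torch.randn(1, 768)               # stand-in for a CLIP text feature
with torch.no_grad():
    sparse, dense = sam.prompt_encoder(points=None, boxes=None, masks=None,
                                       text_embedding=text_embedding)
    masks, iou_pred = sam.mask_decoder(
        image_embeddings=image_embedding,
        text_embedding=text_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=False,
    )
print(masks.shape, iou_pred.shape)   # torch.Size([1, 1, 32, 64, 64]) torch.Size([1, 1]) with these sizes
```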
SegVol_v1.pth → pytorch_model.bin RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
3
- size 1660308695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:500f2758a8f989339b2b2baf09a819169bc87549795193d3cfe505726ac0b399
3
+ size 723726667
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json ADDED
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip", "special_tokens_map_file": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip/special_tokens_map.json", "tokenizer_class": "CLIPTokenizer"}
vocab.json ADDED