yuxin commited on
Commit
da9b19f
1 Parent(s): 420b0c2

mk quick start

Browse files
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "SegVolModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "model_segvol_single.SegVolConfig",
7
+ "AutoModel": "model_segvol_single.SegVolModel"
8
+ },
9
+ "model_type": "segvol",
10
+ "patch_size": [
11
+ 4,
12
+ 16,
13
+ 16
14
+ ],
15
+ "spatial_size": [
16
+ 32,
17
+ 256,
18
+ 256
19
+ ],
20
+ "test_mode": true,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.18.0"
23
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model_segvol_single.py ADDED
@@ -0,0 +1,1951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+ import numpy as np
3
+ import monai.transforms as transforms
4
+ import nibabel as nib
5
+ from scipy import sparse
6
+ import ast
7
+
8
+ class SegVolConfig(PretrainedConfig):
9
+ model_type = "segvol"
10
+
11
+ def __init__(
12
+ self,
13
+ test_mode=True,
14
+ **kwargs,
15
+ ):
16
+ self.spatial_size = [32, 256, 256]
17
+ self.patch_size = [4, 16, 16]
18
+ self.test_mode = test_mode
19
+ super().__init__(**kwargs)
20
+
21
+ class SegVolModel(PreTrainedModel):
22
+ config_class = SegVolConfig
23
+
24
+ def __init__(self, config):
25
+ super().__init__(config)
26
+ sam_model = _build_sam(
27
+ image_encoder_type='vit',
28
+ embed_dim = 768,
29
+ patch_size=self.config.patch_size,
30
+ checkpoint=None,
31
+ image_size=self.config.spatial_size,
32
+ )
33
+ self.model = SegVol(
34
+ image_encoder=sam_model.image_encoder,
35
+ mask_decoder=sam_model.mask_decoder,
36
+ prompt_encoder=sam_model.prompt_encoder,
37
+ roi_size=self.config.spatial_size,
38
+ patch_size=self.config.patch_size,
39
+ # clip_model=self.config.clip_model,
40
+ test_mode=self.config.test_mode,
41
+ )
42
+
43
+ self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
44
+
45
+ def forward_test(self,
46
+ image,
47
+ zoomed_image=None,
48
+ text_prompt=None,
49
+ bbox_prompt_group=None,
50
+ point_prompt_group=None,
51
+ use_zoom=True,):
52
+ device = image.device
53
+ assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
54
+ assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None), 'Drive SegVol using at least one type of prompt'
55
+ bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
56
+ if bbox_prompt_group is not None:
57
+ bbox_prompt, bbox_prompt_map = bbox_prompt_group
58
+ if point_prompt_group is not None:
59
+ point_prompt, point_prompt_map = point_prompt_group
60
+ volume_shape = image[0][0].shape
61
+
62
+ with torch.no_grad():
63
+ logits_global_single = self.model(zoomed_image,
64
+ text=text_prompt,
65
+ boxes=bbox_prompt,
66
+ points=point_prompt)
67
+ logits_global_single = F.interpolate(
68
+ logits_global_single.cpu(),
69
+ size=volume_shape, mode='nearest')
70
+ if not use_zoom:
71
+ return logits_global_single
72
+
73
+ if point_prompt_map is not None:
74
+ binary_points = F.interpolate(
75
+ point_prompt_map.float(),
76
+ size=volume_shape, mode='nearest')
77
+ if bbox_prompt_map is not None:
78
+ binary_cube = F.interpolate(
79
+ bbox_prompt_map.float(),
80
+ size=volume_shape, mode='nearest')
81
+
82
+ min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(self.config.spatial_size, logits_global_single[0][0])
83
+ if min_d is None:
84
+ print('Fail to detect foreground!')
85
+ return logits_global_single
86
+
87
+ # Crop roi
88
+ image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
89
+ global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
90
+
91
+ assert not (bbox_prompt is not None and point_prompt is not None), 'Do not use point prompt and box prompt at the same time.'
92
+ prompt_reflection = None
93
+ if bbox_prompt is not None:
94
+ binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
95
+ prompt_reflection = (
96
+ binary_cube_cropped,
97
+ global_preds
98
+ )
99
+ if point_prompt is not None:
100
+ binary_points_cropped = binary_points[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
101
+ prompt_reflection = (
102
+ binary_points_cropped,
103
+ global_preds
104
+ )
105
+
106
+ ## inference
107
+ with torch.no_grad():
108
+ logits_single_cropped = sliding_window_inference(
109
+ image_single_cropped.to(device), prompt_reflection,
110
+ self.config.spatial_size, 1, self.model, 0.5,
111
+ text=text_prompt,
112
+ use_box=bbox_prompt is not None,
113
+ use_point=point_prompt is not None,
114
+ )
115
+ logits_single_cropped = logits_single_cropped.cpu().squeeze()
116
+ logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
117
+ return logits_global_single
118
+
119
+ def forward_train(self, image, train_organs, train_labels):
120
+ loss = self.model(image, text=None, boxes=None, points=None,
121
+ train_organs=train_organs,
122
+ train_labels=train_labels)
123
+ return loss
124
+
125
+ def forward(self, **kwargs):
126
+ if self.config.test_mode:
127
+ return self.forward_test(kwargs['image'],
128
+ kwargs['zoomed_image'],
129
+ kwargs['text_prompt'],
130
+ kwargs['bbox_prompt_group'],
131
+ kwargs['point_prompt_group'],
132
+ kwargs['use_zoom'])
133
+ else:
134
+ return self.forward_train(kwargs['image'],
135
+ kwargs['train_organs'],
136
+ kwargs['train_labels'])
137
+
138
+ # processor
139
+ class SegVolProcessor():
140
+ def __init__(self, spatial_size) -> None:
141
+ self.img_loader = transforms.LoadImage()
142
+ self.transform4test = transforms.Compose(
143
+ [
144
+ DimTranspose(keys=["image", "label"]),
145
+ MinMaxNormalization(),
146
+ transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
147
+ transforms.ToTensord(keys=["image", "label"]),
148
+ ]
149
+ )
150
+ self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
151
+ self.transform4train = transforms.Compose(
152
+ [
153
+ # transforms.AddChanneld(keys=["image"]),
154
+ DimTranspose(keys=["image", "label"]),
155
+ MinMaxNormalization(),
156
+ transforms.CropForegroundd(keys=["image", "label"], source_key="image"),
157
+ transforms.SpatialPadd(keys=["image", "label"], spatial_size=spatial_size, mode='constant'),
158
+ transforms.OneOf(transforms=[
159
+ transforms.Resized(keys=["image", "label"],spatial_size=spatial_size),
160
+ transforms.RandCropByPosNegLabeld(
161
+ keys=["image", "label"],
162
+ label_key="label",
163
+ spatial_size=spatial_size,
164
+ pos=2,
165
+ neg=1,
166
+ num_samples=1,
167
+ image_key="image",
168
+ image_threshold=0,
169
+ ),
170
+ ],
171
+ weights=[1, 3]
172
+ ),
173
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
174
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
175
+ transforms.RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
176
+ transforms.RandScaleIntensityd(keys="image", factors=0.2, prob=0.2),
177
+ transforms.RandShiftIntensityd(keys="image", offsets=0.2, prob=0.2),
178
+ transforms.ToTensord(keys=["image", "label"]),
179
+ ]
180
+ )
181
+
182
+ # ct_path is path for a ct scan file with nii.gz format
183
+ # gt_path is path for a ground truth file with nii.gz format
184
+ def preprocess_ct_gt(self, ct_path, gt_path, category):
185
+ item = {}
186
+ # generate ct_voxel_ndarray
187
+ ct_voxel_ndarray, _ = self.img_loader(ct_path)
188
+ ct_voxel_ndarray = np.array(ct_voxel_ndarray).squeeze()
189
+ ct_shape = ct_voxel_ndarray.shape
190
+ ct_voxel_ndarray = np.expand_dims(ct_voxel_ndarray, axis=0)
191
+ ct_voxel_ndarray = self.ForegroundNorm(ct_voxel_ndarray)
192
+ item['image'] = ct_voxel_ndarray
193
+
194
+ # generate gt_voxel_ndarray
195
+ gt_voxel_ndarray, _ = self.img_loader(gt_path)
196
+ gt_voxel_ndarray = np.array(gt_voxel_ndarray)
197
+ present_categories = np.unique(gt_voxel_ndarray)
198
+ gt_masks = []
199
+ for cls_idx in range(len(category)):
200
+ # ignore background
201
+ cls = cls_idx + 1
202
+ if cls not in present_categories:
203
+ gt_voxel_ndarray_category = np.zeros(ct_shape)
204
+ gt_masks.append(gt_voxel_ndarray_category)
205
+ else:
206
+ gt_voxel_ndarray_category = gt_voxel_ndarray.copy()
207
+ gt_voxel_ndarray_category[gt_voxel_ndarray != cls] = 0
208
+ gt_voxel_ndarray_category[gt_voxel_ndarray == cls] = 1
209
+ gt_masks.append(gt_voxel_ndarray_category)
210
+ gt_voxel_ndarray = np.stack(gt_masks, axis=0)
211
+ assert gt_voxel_ndarray.shape[0] == len(category) and gt_voxel_ndarray.shape[1:] == ct_voxel_ndarray.shape[1:]
212
+ item['label'] = gt_voxel_ndarray.astype(np.int32)
213
+
214
+ # transform
215
+ return item['image'], item['label']
216
+
217
+ def load_uniseg_case(self, ct_npy_path, gt_npy_path):
218
+ img_array = np.load(ct_npy_path)
219
+ allmatrix_sp= sparse.load_npz(gt_npy_path)
220
+ if 'mask_' in gt_npy_path:
221
+ gt_shape = ast.literal_eval(gt_npy_path.split('_')[-1].replace('.npz', ''))
222
+ else:
223
+ gt_shape = ast.literal_eval(gt_npy_path.split('.')[-2])
224
+ gt_array=allmatrix_sp.toarray().reshape(gt_shape)
225
+ return img_array, gt_array
226
+
227
+ def ForegroundNorm(self, ct_narray):
228
+ ct_voxel_ndarray = ct_narray.copy()
229
+ ct_voxel_ndarray = ct_voxel_ndarray.flatten()
230
+ thred = np.mean(ct_voxel_ndarray)
231
+ voxel_filtered = ct_voxel_ndarray[(ct_voxel_ndarray > thred)]
232
+ upper_bound = np.percentile(voxel_filtered, 99.95)
233
+ lower_bound = np.percentile(voxel_filtered, 00.05)
234
+ mean = np.mean(voxel_filtered)
235
+ std = np.std(voxel_filtered)
236
+ ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
237
+ ct_narray = (ct_narray - mean) / max(std, 1e-8)
238
+ return ct_narray
239
+
240
+ def zoom_transform(self, ct_npy, gt_npy):
241
+ item = {
242
+ 'image': ct_npy,
243
+ 'label': gt_npy
244
+ }
245
+ item = self.transform4test(item)
246
+ item_zoom_out = self.zoom_out_transform(item)
247
+ item['zoom_out_image'] = item_zoom_out['image']
248
+ item['zoom_out_label'] = item_zoom_out['label']
249
+ return item
250
+
251
+ def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0, device='cpu'):
252
+ point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
253
+ points_single = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
254
+ binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
255
+ return points_single, binary_points_resize
256
+
257
+ def bbox_prompt_b(self, label_single_resize, device='cpu'):
258
+ box_single = generate_box(label_single_resize).unsqueeze(0).float().to(device)
259
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
260
+ return box_single, binary_cube_resize
261
+
262
+ def dice_score(self, preds, labels, device='cpu'):
263
+ assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
264
+ predict = preds.view(1, -1).to(device)
265
+ target = labels.view(1, -1).to(device)
266
+
267
+ predict = torch.sigmoid(predict)
268
+ predict = torch.where(predict > 0.5, 1., 0.)
269
+
270
+ tp = torch.sum(torch.mul(predict, target))
271
+ den = torch.sum(predict) + torch.sum(target) + 1
272
+ dice = 2 * tp / den
273
+ return dice
274
+
275
+ def save_preds(self, ct_path, save_path, logits_mask, start_coord, end_coord):
276
+ ct = nib.load(ct_path)
277
+ logits_mask = logits_mask.transpose(-1, -3)
278
+ start_coord[-1], start_coord[-3] = start_coord[-3], start_coord[-1]
279
+ end_coord[-1], end_coord[-3] = end_coord[-3], end_coord[-1]
280
+ preds_save = torch.zeros(ct.shape)
281
+ preds_save[start_coord[0]:end_coord[0],
282
+ start_coord[1]:end_coord[1],
283
+ start_coord[2]:end_coord[2]] = torch.sigmoid(logits_mask)
284
+ preds_save = torch.where(preds_save > 0.5, 1., 0.).numpy()
285
+ preds_nii = nib.Nifti1Image(preds_save, affine=ct.affine, header=ct.header)
286
+ nib.save(preds_nii, save_path)
287
+
288
+ def train_transform(self, ct_npy, gt_npy):
289
+ item = {
290
+ 'image': ct_npy,
291
+ 'label': gt_npy
292
+ }
293
+ item = self.transform4train(item)
294
+ if type(item) is list:
295
+ assert len(item) == 1
296
+ item = item[0]
297
+ return item
298
+
299
+ class MinMaxNormalization(transforms.Transform):
300
+ def __call__(self, data):
301
+ d = dict(data)
302
+ k = "image"
303
+ d[k] = d[k] - d[k].min()
304
+ d[k] = d[k] / np.clip(d[k].max(), a_min=1e-8, a_max=None)
305
+ return d
306
+
307
+ class DimTranspose(transforms.Transform):
308
+ def __init__(self, keys):
309
+ self.keys = keys
310
+
311
+ def __call__(self, data):
312
+ d = dict(data)
313
+ for key in self.keys:
314
+ d[key] = np.swapaxes(d[key], -1, -3)
315
+ return d
316
+
317
+ # prompts
318
+ def generate_box(pred_pre, bbox_shift=None):
319
+ meaning_post_label = pred_pre # [h, w, d]
320
+ ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
321
+ if all(tensor.nelement() == 0 for tensor in ones_idx):
322
+ bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
323
+ return bboxes
324
+ min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
325
+ max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
326
+
327
+
328
+ if bbox_shift is None:
329
+ corner_min = []
330
+ corner_max = []
331
+ shape = meaning_post_label.shape
332
+ for coor in min_coords:
333
+ coor_ = max(0, coor)
334
+ corner_min.append(coor_)
335
+ for idx, coor in enumerate(max_coords):
336
+ coor_ = min(shape[idx], coor)
337
+ corner_max.append(coor_)
338
+ corner_min = torch.tensor(corner_min)
339
+ corner_max = torch.tensor(corner_max)
340
+ return torch.cat((corner_min, corner_max), dim=0)
341
+ else:
342
+ # add perturbation to bounding box coordinates
343
+ corner_min = []
344
+ corner_max = []
345
+ shape = meaning_post_label.shape
346
+ for coor in min_coords:
347
+ coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
348
+ corner_min.append(coor_)
349
+ for idx, coor in enumerate(max_coords):
350
+ coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
351
+ corner_max.append(coor_)
352
+ corner_min = torch.tensor(corner_min)
353
+ corner_max = torch.tensor(corner_max)
354
+ return torch.cat((corner_min, corner_max), dim=0)
355
+
356
+
357
+ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
358
+ spacial_dim = 3
359
+ points = torch.zeros((0, 3))
360
+ labels = torch.zeros((0))
361
+ pos_thred = 0.9
362
+ neg_thred = 0.1
363
+
364
+ # get pos/net indices
365
+ positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
366
+ negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
367
+
368
+ ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
369
+ if all(tmp.nelement() == 0 for tmp in ones_idx):
370
+ # all neg
371
+ num_positive_extra = 0
372
+ selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
373
+ points = torch.cat((points, selected_positive_point), dim=0)
374
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
375
+ else:
376
+ # random select a pos point
377
+ random_idx = torch.randint(len(positive_indices[0]), (1,))
378
+ selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
379
+ points = torch.cat((points, selected_positive_point), dim=0)
380
+ labels = torch.cat((labels, torch.ones((1))))
381
+
382
+ if num_positive_extra > 0:
383
+ pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
384
+ extra_positive_points = []
385
+ for pos_idx in pos_idx_list:
386
+ extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
387
+ extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
388
+ points = torch.cat((points, extra_positive_points), dim=0)
389
+ labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
390
+
391
+ if num_negative_extra > 0:
392
+ neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
393
+ extra_negative_points = []
394
+ for neg_idx in neg_idx_list:
395
+ extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
396
+ extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
397
+ points = torch.cat((points, extra_negative_points), dim=0)
398
+ labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
399
+
400
+ if fix_extra_point_num is None:
401
+ left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
402
+ else:
403
+ left_point_num = fix_extra_point_num + 1 - labels.shape[0]
404
+
405
+ for _ in range(left_point_num):
406
+ ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
407
+ points = torch.cat((points, ignore_point), dim=0)
408
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
409
+
410
+ return points, labels
411
+
412
+ # SegVol
413
+ import torch
414
+ import torch.nn as nn
415
+ import torch.nn.functional as F
416
+ import numpy as np
417
+ from transformers import CLIPTextModel, CLIPTextConfig
418
+ import random
419
+
420
+ #%% set up model
421
+ class SegVol(nn.Module):
422
+ def __init__(self,
423
+ image_encoder,
424
+ mask_decoder,
425
+ prompt_encoder,
426
+ roi_size,
427
+ patch_size,
428
+ # clip_model,
429
+ test_mode=False,
430
+ ):
431
+ super().__init__()
432
+ self.image_encoder = image_encoder
433
+ self.mask_decoder = mask_decoder
434
+ self.prompt_encoder = prompt_encoder
435
+ self.text_encoder = TextEncoder()
436
+ self.feat_shape = np.array(roi_size)/np.array(patch_size)
437
+ self.test_mode = test_mode
438
+ self.dice_loss = BinaryDiceLoss()
439
+ self.bce_loss = BCELoss()
440
+ self.decoder_iter = 6
441
+
442
+ def forward(self, image, text=None, boxes=None, points=None, **kwargs):
443
+ bs = image.shape[0]
444
+ img_shape = (image.shape[2], image.shape[3], image.shape[4])
445
+ image_embedding, _ = self.image_encoder(image)
446
+ image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
447
+ int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
448
+ # test mode
449
+ if self.test_mode:
450
+ return self.forward_decoder(image_embedding, img_shape, text, boxes, points)
451
+
452
+ # train mode
453
+ ## sl
454
+ sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
455
+ ## ssl
456
+ # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
457
+ return sl_loss
458
+
459
+ def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
460
+ device = image_embedding.device
461
+ with torch.no_grad():
462
+ if boxes is not None:
463
+ if len(boxes.shape) == 2:
464
+ boxes = boxes[:, None, :] # (B, 1, 6)
465
+ if text is not None:
466
+ text_embedding = self.text_encoder(text, device) # (B, 768)
467
+ else:
468
+ text_embedding = None
469
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
470
+ points=points,
471
+ boxes=boxes,
472
+ masks=None,
473
+ text_embedding=text_embedding,
474
+ )
475
+
476
+ dense_pe = self.prompt_encoder.get_dense_pe()
477
+ low_res_masks, _ = self.mask_decoder(
478
+ image_embeddings=image_embedding,
479
+ text_embedding = text_embedding,
480
+ image_pe=dense_pe,
481
+ sparse_prompt_embeddings=sparse_embeddings,
482
+ dense_prompt_embeddings=dense_embeddings,
483
+ multimask_output=False,
484
+ )
485
+ logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
486
+ return logits
487
+
488
+ def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
489
+ device = image_embedding.device
490
+ iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels, device)
491
+ # select prompt
492
+ prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
493
+ [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
494
+ [iter_bboxes, iter_points, None]]
495
+ sl_loss = 0
496
+ for prompt in prompt_options:
497
+ bboxes, points, organs = prompt
498
+ logits = self.forward_decoder(image_embedding, img_shape, text=organs, boxes=bboxes, points=points)
499
+ # cal loss
500
+ sl_loss_dice = self.dice_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
501
+ sl_loss_bce = self.bce_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
502
+ sl_loss += sl_loss_dice + sl_loss_bce
503
+ return sl_loss
504
+
505
+ # def unsupervised_forward(self, image, image_embedding, pseudo_seg_cleaned, img_shape):
506
+ # sll_loss = 0
507
+ # for iter in range(self.decoder_iter):
508
+ # if iter % 2 == 0:
509
+ # pseudo_labels, pseudo_points_prompt = self.build_pseudo_point_prompt_label(image.shape, pseudo_seg_cleaned)
510
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=None, points=pseudo_points_prompt)
511
+ # else:
512
+ # pseudo_labels, pseudo_bboxes_prompt = self.build_pseudo_box_prompt_label(image.shape, pseudo_seg_cleaned)
513
+ # logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=pseudo_bboxes_prompt, points=None)
514
+ # # cal loss
515
+ # sll_loss_dice = self.dice_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
516
+ # sll_loss_bce = self.bce_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
517
+ # sll_loss += sll_loss_dice + sll_loss_bce
518
+ # return sll_loss
519
+
520
+ def build_prompt_label(self, bs, training_organs, train_labels, device):
521
+ # generate prompt & label
522
+ iter_organs = []
523
+ iter_bboxes = []
524
+ iter_points_ax = []
525
+ iter_point_labels = []
526
+ for sample_idx in range(bs):
527
+ # organ prompt
528
+ iter_organs.append(training_organs)
529
+ # box prompt
530
+ box = generate_box(train_labels[sample_idx], bbox_shift=10)
531
+ iter_bboxes.append(box)
532
+ # point prompt
533
+ num_positive_extra_max, num_negative_extra_max = 10, 10
534
+ num_positive_extra = random.randint(0, num_positive_extra_max)
535
+ num_negative_extra = random.randint(0, num_negative_extra_max)
536
+ point, point_label = select_points(
537
+ train_labels[sample_idx],
538
+ num_positive_extra=num_positive_extra,
539
+ num_negative_extra=num_negative_extra,
540
+ fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
541
+ iter_points_ax.append(point)
542
+ iter_point_labels.append(point_label)
543
+ # batched prompt
544
+ iter_points_ax = torch.stack(iter_points_ax, dim=0).to(device)
545
+ iter_point_labels = torch.stack(iter_point_labels, dim=0).to(device)
546
+ iter_points = (iter_points_ax, iter_point_labels)
547
+ iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(device)
548
+ return iter_points, iter_bboxes, iter_organs
549
+
550
+ # def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
551
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
552
+ # # generate points
553
+ # points = []
554
+ # point_labels = []
555
+ # for batch_idx in range(input_shape[0]):
556
+ # # generate pseudo label
557
+ # unique_ids = torch.unique(seg_labels[batch_idx])
558
+ # unique_ids = unique_ids[unique_ids != -1]
559
+ # region_id = random.choice(unique_ids).item()
560
+ # pseudo_labels[batch_idx][seg_labels[batch_idx]==region_id] = 1
561
+ # # generate point prompt
562
+ # num_positive_extra_max, num_negative_extra_max = 10, 10
563
+ # num_positive_extra = random.randint(4, num_positive_extra_max)
564
+ # num_negative_extra = random.randint(0, num_negative_extra_max)
565
+ # assert len(pseudo_labels[batch_idx][0].shape) == 3
566
+ # point, point_label = select_points(
567
+ # pseudo_labels[batch_idx][0],
568
+ # num_positive_extra=num_positive_extra,
569
+ # num_negative_extra=num_negative_extra,
570
+ # fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
571
+ # points.append(point)
572
+ # point_labels.append(point_label)
573
+ # points = torch.stack(points, dim=0).to(self.custom_device)
574
+ # point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
575
+ # pseudo_points_prompt = (points, point_labels)
576
+ # return pseudo_labels, pseudo_points_prompt
577
+
578
+ # def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
579
+ # pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
580
+ # iter_bboxes = []
581
+ # # generate boxes
582
+ # for batch_idx in range(input_shape[0]):
583
+ # # generate ori pseudo label
584
+ # unique_ids = torch.unique(seg_labels_cleaned[batch_idx])
585
+ # unique_ids = unique_ids[unique_ids != -1]
586
+ # region_id = random.choice(unique_ids).item()
587
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==region_id] = 1
588
+ # # generate box prompt
589
+ # box = generate_box(pseudo_labels[batch_idx][0])
590
+ # iter_bboxes.append(box)
591
+ # # refine pseudo label
592
+ # x_min, y_min, z_min, x_max, y_max, z_max = box
593
+ # binary_cube = torch.zeros_like(pseudo_labels[batch_idx][0]).int()
594
+ # binary_cube[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1] = 1
595
+ # # cal iou
596
+ # mask_label = seg_labels_cleaned[batch_idx][0]
597
+ # assert binary_cube.shape == mask_label.shape, str(binary_cube.shape) + ' ' + str(mask_label.shape)
598
+ # mask_values_in_binary_cube = mask_label[binary_cube == 1]
599
+ # unique_mask_values = torch.unique(mask_values_in_binary_cube)
600
+ # # print('unique_mask_values ', unique_mask_values)
601
+ # for value in unique_mask_values:
602
+ # if value == -1: continue
603
+ # mask_area = (mask_label == value)
604
+ # intersection = (binary_cube & mask_area)
605
+ # iou = intersection.float().sum() / mask_area.float().sum()
606
+ # if iou > 0.90:
607
+ # # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
608
+ # pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
609
+
610
+ # bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
611
+ # return pseudo_labels, bboxes
612
+
613
+ class TextEncoder(nn.Module):
614
+ def __init__(self):
615
+ super().__init__()
616
+ config = CLIPTextConfig()
617
+ self.clip_text_model = CLIPTextModel(config)
618
+ self.tokenizer = None
619
+ self.dim_align = nn.Linear(512, 768)
620
+ # freeze text encoder
621
+ for param in self.clip_text_model.parameters():
622
+ param.requires_grad = False
623
+
624
+ def organ2tokens(self, organ_names, device):
625
+ text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
626
+ tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
627
+ for key in tokens.keys():
628
+ tokens[key] = tokens[key].to(device)
629
+ return tokens
630
+
631
+ def forward(self, text, device):
632
+ if text is None:
633
+ return None
634
+ if type(text) is str:
635
+ # text is supposed to be list
636
+ text = [text]
637
+ tokens = self.organ2tokens(text, device)
638
+ clip_outputs = self.clip_text_model(**tokens)
639
+ text_embedding = clip_outputs.pooler_output
640
+ text_embedding = self.dim_align(text_embedding)
641
+ return text_embedding
642
+
643
+ # loss
644
+ import torch
645
+ import torch.nn as nn
646
+
647
+ class BinaryDiceLoss(nn.Module):
648
+ def __init__(self, smooth=1, p=2, reduction='mean'):
649
+ super(BinaryDiceLoss, self).__init__()
650
+ self.smooth = smooth
651
+ self.p = p
652
+ self.reduction = reduction
653
+
654
+ def forward(self, predict, target):
655
+ predict = torch.sigmoid(predict)
656
+ target_ = target.clone()
657
+ target_[target == -1] = 0
658
+ assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape[0])
659
+ predict = predict.contiguous().view(predict.shape[0], -1)
660
+ target_ = target_.contiguous().view(target_.shape[0], -1)
661
+
662
+ num = torch.sum(torch.mul(predict, target_), dim=1)
663
+ den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth
664
+
665
+ dice_score = 2*num / den
666
+ dice_loss = 1 - dice_score
667
+
668
+ # dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
669
+ dice_loss_avg = dice_loss.sum() / dice_loss.shape[0]
670
+
671
+ return dice_loss_avg
672
+
673
+ class BCELoss(nn.Module):
674
+ def __init__(self):
675
+ super(BCELoss, self).__init__()
676
+ self.criterion = nn.BCEWithLogitsLoss()
677
+
678
+ def forward(self, predict, target):
679
+ assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape)
680
+ target_ = target.clone()
681
+ target_[target == -1] = 0
682
+
683
+ ce_loss = self.criterion(predict, target_)
684
+
685
+ return ce_loss
686
+
687
+ # monai inference
688
+
689
+ # Copyright (c) MONAI Consortium
690
+ # Licensed under the Apache License, Version 2.0 (the "License");
691
+ # you may not use this file except in compliance with the License.
692
+ # You may obtain a copy of the License at
693
+ # http://www.apache.org/licenses/LICENSE-2.0
694
+ # Unless required by applicable law or agreed to in writing, software
695
+ # distributed under the License is distributed on an "AS IS" BASIS,
696
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
697
+ # See the License for the specific language governing permissions and
698
+ # limitations under the License.
699
+
700
+ import warnings
701
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
702
+
703
+ import torch
704
+ import torch.nn.functional as F
705
+ import random
706
+
707
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
708
+ from monai.transforms import Resize
709
+ from monai.utils import (
710
+ BlendMode,
711
+ PytorchPadMode,
712
+ convert_data_type,
713
+ ensure_tuple,
714
+ fall_back_tuple,
715
+ look_up_option,
716
+ optional_import,
717
+ )
718
+
719
+ tqdm, _ = optional_import("tqdm", name="tqdm")
720
+
721
+ __all__ = ["sliding_window_inference"]
722
+
723
+ def logits2roi_coor(spatial_size, logits_global_single):
724
+ # crop predict
725
+ pred_global_single = torch.sigmoid(logits_global_single) > 0.5
726
+ ## get all pos idx
727
+ nonzero_indices = torch.nonzero(pred_global_single)
728
+ if nonzero_indices.shape[0] == 0:
729
+ return None, None, None, None, None, None
730
+ ## get boundary
731
+ min_d, max_d = nonzero_indices[:, 0].min(), nonzero_indices[:, 0].max()
732
+ min_h, max_h = nonzero_indices[:, 1].min(), nonzero_indices[:, 1].max()
733
+ min_w, max_w = nonzero_indices[:, 2].min(), nonzero_indices[:, 2].max()
734
+ ## padding
735
+ crop_d, crop_h, crop_w = max_d - min_d + 1, max_h - min_h + 1, max_w - min_w + 1,
736
+ window_d, window_h, window_w = spatial_size
737
+ padding_d, padding_h, padding_w = max(0, window_d-crop_d), max(0, window_h-crop_h), max(0, window_w-crop_w)
738
+ global_d, global_h, global_w = logits_global_single.shape
739
+ min_d = max(0, min_d - int(padding_d)//2)
740
+ min_h = max(0, min_h - int(padding_h)//2)
741
+ min_w = max(0, min_w - int(padding_w)//2)
742
+ max_d = min(global_d, max_d + int(padding_d)//2)
743
+ max_h = min(global_h, max_h + int(padding_h)//2)
744
+ max_w = min(global_w, max_w + int(padding_w)//2)
745
+ return min_d, min_h, min_w, max_d, max_h, max_w
746
+
747
+ def build_binary_cube(bbox, binary_cube_shape):
748
+ min_coord = bbox[0][:3].int().tolist()
749
+ max_coord = bbox[0][3:].int().tolist()
750
+ binary_cube = torch.zeros(binary_cube_shape)
751
+ binary_cube[min_coord[0]:max_coord[0]+1, min_coord[1]:max_coord[1]+1, min_coord[2]:max_coord[2]+1] = 1
752
+ return binary_cube
753
+
754
+ def build_binary_points(points, labels, shape):
755
+ binary_points = torch.zeros(shape, dtype=torch.int16)
756
+ binary_points[points[labels == 1, 0].long(), points[labels == 1, 1].long(), points[labels == 1, 2].long()] = 1
757
+ return binary_points
758
+
759
+ def sliding_window_inference(
760
+ inputs: torch.Tensor,
761
+ prompt_reflection: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
762
+ roi_size: Union[Sequence[int], int],
763
+ sw_batch_size: int,
764
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
765
+ overlap: float = 0.25,
766
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
767
+ sigma_scale: Union[Sequence[float], float] = 0.125,
768
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
769
+ cval: float = 0.0,
770
+ sw_device: Union[torch.device, str, None] = None,
771
+ device: Union[torch.device, str, None] = None,
772
+ progress: bool = False,
773
+ roi_weight_map: Union[torch.Tensor, None] = None,
774
+ *args: Any,
775
+ **kwargs: Any,
776
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
777
+ """
778
+ Sliding window inference on `inputs` with `predictor`.
779
+
780
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
781
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
782
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
783
+ could be ([128,64,256], [64,32,128]).
784
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
785
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
786
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
787
+
788
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
789
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
790
+
791
+ Args:
792
+ inputs: input image to be processed (assuming NCHW[D])
793
+ roi_size: the spatial window size for inferences.
794
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
795
+ if the components of the `roi_size` are non-positive values, the transform will use the
796
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
797
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
798
+ sw_batch_size: the batch size to run window slices.
799
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
800
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
801
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
802
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
803
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
804
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
805
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
806
+ to ensure the scaled output ROI sizes are still integers.
807
+ If the `predictor`'s input and output spatial sizes are different,
808
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
809
+ overlap: Amount of overlap between scans.
810
+ mode: {``"constant"``, ``"gaussian"``}
811
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
812
+
813
+ - ``"constant``": gives equal weight to all predictions.
814
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
815
+
816
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
817
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
818
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
819
+ spatial dimensions.
820
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
821
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
822
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
823
+ cval: fill value for 'constant' padding mode. Default: 0
824
+ sw_device: device for the window data.
825
+ By default the device (and accordingly the memory) of the `inputs` is used.
826
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
827
+ device: device for the stitched output prediction.
828
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
829
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
830
+ `inputs` and `roi_size`. Output is on the `device`.
831
+ progress: whether to print a `tqdm` progress bar.
832
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
833
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
834
+ args: optional args to be passed to ``predictor``.
835
+ kwargs: optional keyword args to be passed to ``predictor``.
836
+
837
+ Note:
838
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
839
+
840
+ """
841
+ print('sliding window inference for ROI')
842
+ text = kwargs['text']
843
+ use_box = kwargs['use_box']
844
+ use_point = kwargs['use_point']
845
+ assert not (use_box and use_point)
846
+ compute_dtype = inputs.dtype
847
+ num_spatial_dims = len(inputs.shape) - 2
848
+ if overlap < 0 or overlap >= 1:
849
+ raise ValueError("overlap must be >= 0 and < 1.")
850
+
851
+ # determine image spatial size and batch size
852
+ # Note: all input images must have the same image size and batch size
853
+ batch_size, _, *image_size_ = inputs.shape
854
+
855
+ if device is None:
856
+ device = inputs.device
857
+ if sw_device is None:
858
+ sw_device = inputs.device
859
+
860
+ roi_size = fall_back_tuple(roi_size, image_size_)
861
+ # in case that image size is smaller than roi size
862
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
863
+ pad_size = []
864
+ for k in range(len(inputs.shape) - 1, 1, -1):
865
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
866
+ half = diff // 2
867
+ pad_size.extend([half, diff - half])
868
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
869
+ #############
870
+ if use_point or use_box:
871
+ binary_prompt_map, global_preds = prompt_reflection
872
+ global_preds = F.pad(global_preds, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
873
+ #############
874
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
875
+
876
+ # Store all slices in list
877
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
878
+ num_win = len(slices) # number of windows per image
879
+ total_slices = num_win * batch_size # total number of windows
880
+
881
+ # Create window-level importance map
882
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
883
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
884
+ importance_map = roi_weight_map
885
+ else:
886
+ try:
887
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
888
+ except BaseException as e:
889
+ raise RuntimeError(
890
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
891
+ ) from e
892
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
893
+ # handle non-positive weights
894
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
895
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
896
+
897
+ # Perform predictions
898
+ dict_key, output_image_list, count_map_list = None, [], []
899
+ _initialized_ss = -1
900
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
901
+
902
+ # for each patch
903
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
904
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
905
+ unravel_slice = [
906
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
907
+ for idx in slice_range
908
+ ]
909
+ window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
910
+ #############
911
+
912
+ boxes = None
913
+ points = None
914
+ if use_point:
915
+ window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
916
+ point, point_label = select_points(window_binary_prompt_map.squeeze())
917
+ points = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
918
+ pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
919
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
920
+ if use_box:
921
+ if num_win == 1:
922
+ window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
923
+ boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().to(device)
924
+ else:
925
+ pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
926
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
927
+ seg_prob_out = predictor(window_data, text, boxes, points) # batched patch segmentation
928
+ #############
929
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
930
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
931
+ if isinstance(seg_prob_out, torch.Tensor):
932
+ seg_prob_tuple = (seg_prob_out,)
933
+ elif isinstance(seg_prob_out, Mapping):
934
+ if dict_key is None:
935
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
936
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
937
+ is_tensor_output = False
938
+ else:
939
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
940
+ is_tensor_output = False
941
+
942
+ # for each output in multi-output list
943
+ for ss, seg_prob in enumerate(seg_prob_tuple):
944
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
945
+
946
+ # compute zoom scale: out_roi_size/in_roi_size
947
+ zoom_scale = []
948
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
949
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
950
+ ):
951
+ _scale = out_w_i / float(in_w_i)
952
+ if not (img_s_i * _scale).is_integer():
953
+ warnings.warn(
954
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
955
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
956
+ )
957
+ zoom_scale.append(_scale)
958
+
959
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
960
+ # construct multi-resolution outputs
961
+ output_classes = seg_prob.shape[1]
962
+ output_shape = [batch_size, output_classes] + [
963
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
964
+ ]
965
+ # allocate memory to store the full output and the count for overlapping parts
966
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
967
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
968
+ _initialized_ss += 1
969
+
970
+ # resizing the importance_map
971
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
972
+
973
+ # store the result in the proper location of the full output. Apply weights from importance map.
974
+ for idx, original_idx in zip(slice_range, unravel_slice):
975
+ # zoom roi
976
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
977
+ for axis in range(2, len(original_idx_zoom)):
978
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
979
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
980
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
981
+ warnings.warn(
982
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
983
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
984
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
985
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
986
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
987
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
988
+ )
989
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
990
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
991
+ # store results and weights
992
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
993
+ count_map_list[ss][original_idx_zoom] += (
994
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
995
+ )
996
+
997
+ # account for any overlapping sections
998
+ for ss in range(len(output_image_list)):
999
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
1000
+
1001
+ # remove padding if image_size smaller than roi_size
1002
+ for ss, output_i in enumerate(output_image_list):
1003
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
1004
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
1005
+
1006
+ zoom_scale = [
1007
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
1008
+ ]
1009
+
1010
+ final_slicing: List[slice] = []
1011
+ for sp in range(num_spatial_dims):
1012
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
1013
+ slice_dim = slice(
1014
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
1015
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
1016
+ )
1017
+ final_slicing.insert(0, slice_dim)
1018
+ while len(final_slicing) < len(output_i.shape):
1019
+ final_slicing.insert(0, slice(None))
1020
+ output_image_list[ss] = output_i[final_slicing]
1021
+
1022
+ if dict_key is not None: # if output of predictor is a dict
1023
+ final_output = dict(zip(dict_key, output_image_list))
1024
+ else:
1025
+ final_output = tuple(output_image_list) # type: ignore
1026
+ return final_output[0] if is_tensor_output else final_output # type: ignore
1027
+
1028
+
1029
+ def _get_scan_interval(
1030
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
1031
+ ) -> Tuple[int, ...]:
1032
+ """
1033
+ Compute scan interval according to the image size, roi size and overlap.
1034
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
1035
+ use 1 instead to make sure sliding window works.
1036
+
1037
+ """
1038
+ if len(image_size) != num_spatial_dims:
1039
+ raise ValueError("image coord different from spatial dims.")
1040
+ if len(roi_size) != num_spatial_dims:
1041
+ raise ValueError("roi coord different from spatial dims.")
1042
+
1043
+ scan_interval = []
1044
+ for i in range(num_spatial_dims):
1045
+ if roi_size[i] == image_size[i]:
1046
+ scan_interval.append(int(roi_size[i]))
1047
+ else:
1048
+ interval = int(roi_size[i] * (1 - overlap))
1049
+ scan_interval.append(interval if interval > 0 else 1)
1050
+ return tuple(scan_interval)
1051
+
1052
+ # build 3D SAM
1053
+ import torch
1054
+ import numpy as np
1055
+ from monai.networks.nets import ViT
1056
+
1057
+ def _build_sam(
1058
+ image_encoder_type,
1059
+ embed_dim,
1060
+ patch_size,
1061
+ checkpoint,
1062
+ image_size,
1063
+ ):
1064
+ mlp_dim = 3072
1065
+ num_layers = 12
1066
+ num_heads = 12
1067
+ pos_embed = 'perceptron'
1068
+ dropout_rate = 0.0
1069
+
1070
+ image_encoder=ViT(
1071
+ in_channels=1,
1072
+ img_size=image_size,
1073
+ patch_size=patch_size,
1074
+ hidden_size=embed_dim,
1075
+ mlp_dim=mlp_dim,
1076
+ num_layers=num_layers,
1077
+ num_heads=num_heads,
1078
+ pos_embed=pos_embed,
1079
+ classification=False,
1080
+ dropout_rate=dropout_rate,
1081
+ )
1082
+ image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]
1083
+
1084
+ if checkpoint is not None:
1085
+ with open(checkpoint, "rb") as f:
1086
+ state_dict = torch.load(f, map_location='cpu')['state_dict']
1087
+ encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
1088
+ image_encoder.load_state_dict(encoder_dict)
1089
+ print(f'===> image_encoder.load_param: {checkpoint}')
1090
+ sam = Sam(
1091
+ image_encoder=image_encoder,
1092
+ prompt_encoder=PromptEncoder(
1093
+ embed_dim=embed_dim,
1094
+ image_embedding_size=image_embedding_size,
1095
+ input_image_size=image_size,
1096
+ mask_in_chans=16,
1097
+ ),
1098
+ mask_decoder=MaskDecoder(
1099
+ image_encoder_type=image_encoder_type,
1100
+ num_multimask_outputs=3,
1101
+ transformer=TwoWayTransformer(
1102
+ depth=2,
1103
+ embedding_dim=embed_dim,
1104
+ mlp_dim=2048,
1105
+ num_heads=8,
1106
+ ),
1107
+ transformer_dim=embed_dim,
1108
+ iou_head_depth=3,
1109
+ iou_head_hidden_dim=256,
1110
+ image_size=np.array(image_size),
1111
+ patch_size=np.array(patch_size),
1112
+ ),
1113
+ pixel_mean=[123.675, 116.28, 103.53],
1114
+ pixel_std=[58.395, 57.12, 57.375],
1115
+ )
1116
+ sam.eval()
1117
+ return sam
1118
+
1119
+ # mask decoder
1120
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1121
+ # All rights reserved.
1122
+
1123
+ # This source code is licensed under the license found in the
1124
+ # LICENSE file in the root directory of this source tree.
1125
+
1126
+ import torch
1127
+ from torch import nn
1128
+ from torch.nn import functional as F
1129
+
1130
+ from typing import List, Tuple, Type, Optional
1131
+
1132
+ class MaskDecoder(nn.Module):
1133
+ def __init__(
1134
+ self,
1135
+ *,
1136
+ image_encoder_type: str,
1137
+ transformer_dim: int,
1138
+ transformer: nn.Module,
1139
+ num_multimask_outputs: int = 3,
1140
+ activation: Type[nn.Module] = nn.GELU,
1141
+ iou_head_depth: int = 3,
1142
+ iou_head_hidden_dim: int = 256,
1143
+ image_size,
1144
+ patch_size,
1145
+ ) -> None:
1146
+ """
1147
+ Predicts masks given an image and prompt embeddings, using a
1148
+ transformer architecture.
1149
+
1150
+ Arguments:
1151
+ transformer_dim (int): the channel dimension of the transformer
1152
+ transformer (nn.Module): the transformer used to predict masks
1153
+ num_multimask_outputs (int): the number of masks to predict
1154
+ when disambiguating masks
1155
+ activation (nn.Module): the type of activation to use when
1156
+ upscaling masks
1157
+ iou_head_depth (int): the depth of the MLP used to predict
1158
+ mask quality
1159
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
1160
+ used to predict mask quality
1161
+ """
1162
+ super().__init__()
1163
+ self.transformer_dim = transformer_dim
1164
+ self.transformer = transformer
1165
+
1166
+ self.num_multimask_outputs = num_multimask_outputs
1167
+
1168
+ self.iou_token = nn.Embedding(1, transformer_dim)
1169
+ self.num_mask_tokens = num_multimask_outputs + 1
1170
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
1171
+
1172
+ if image_encoder_type == 'swin_vit':
1173
+ self.feat_shape = image_size/patch_size
1174
+ self.output_upscaling = nn.Sequential(
1175
+ nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
1176
+ nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # swin
1177
+ activation(),
1178
+ nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), # swin
1179
+ # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1), # vit
1180
+ activation(),
1181
+ )
1182
+ else:
1183
+ self.feat_shape = image_size/patch_size * 2
1184
+ self.output_upscaling = nn.Sequential(
1185
+ nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
1186
+ nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))), # vit
1187
+ activation(),
1188
+ nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
1189
+ # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1),
1190
+ activation(),
1191
+ )
1192
+ self.output_hypernetworks_mlps = nn.ModuleList(
1193
+ [
1194
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
1195
+ for i in range(self.num_mask_tokens)
1196
+ ]
1197
+ )
1198
+
1199
+ self.iou_prediction_head = MLP(
1200
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
1201
+ )
1202
+
1203
+ self.txt_align_upscaled_embedding = nn.Linear(768, 96)
1204
+
1205
+ def forward(
1206
+ self,
1207
+ image_embeddings: torch.Tensor,
1208
+ text_embedding: Optional[torch.Tensor],
1209
+ image_pe: torch.Tensor,
1210
+ sparse_prompt_embeddings: torch.Tensor,
1211
+ dense_prompt_embeddings: torch.Tensor,
1212
+ multimask_output: bool,
1213
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1214
+ """
1215
+ Predict masks given image and prompt embeddings.
1216
+
1217
+ Returns:
1218
+ torch.Tensor: batched predicted masks
1219
+ """
1220
+ # print('--------------decoder here--------------')
1221
+ masks, iou_pred = self.predict_masks(
1222
+ image_embeddings=image_embeddings,
1223
+ text_embedding=text_embedding,
1224
+ image_pe=image_pe,
1225
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
1226
+ dense_prompt_embeddings=dense_prompt_embeddings,
1227
+ )
1228
+
1229
+ # Select the correct mask or masks for output
1230
+ if multimask_output:
1231
+ mask_slice = slice(1, None)
1232
+ else:
1233
+ mask_slice = slice(0, 1)
1234
+ masks = masks[:, mask_slice, :, :, :]
1235
+ iou_pred = iou_pred[:, mask_slice]
1236
+
1237
+ # Prepare output
1238
+ return masks, iou_pred
1239
+
1240
+ def predict_masks(
1241
+ self,
1242
+ image_embeddings: torch.Tensor,
1243
+ text_embedding: torch.Tensor,
1244
+ image_pe: torch.Tensor,
1245
+ sparse_prompt_embeddings: torch.Tensor,
1246
+ dense_prompt_embeddings: torch.Tensor,
1247
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1248
+ """Predicts masks. See 'forward' for more details."""
1249
+ # Concatenate output tokens
1250
+ output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
1251
+ output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
1252
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
1253
+ # Expand per-image data in batch direction to be per-mask
1254
+ if image_embeddings.shape[0] != tokens.shape[0]:
1255
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
1256
+ else:
1257
+ src = image_embeddings
1258
+ src = src + dense_prompt_embeddings
1259
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
1260
+ b, c, h, w, d = src.shape
1261
+
1262
+ # Run the transformer
1263
+ hs, src = self.transformer(src, pos_src, tokens)
1264
+ iou_token_out = hs[:, 0, :]
1265
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
1266
+
1267
+ # Upscale mask embeddings and predict masks using the mask tokens
1268
+ src = src.transpose(1, 2).view(b, c, h, w, d)
1269
+ upscaled_embedding = self.output_upscaling(src)
1270
+ hyper_in_list: List[torch.Tensor] = []
1271
+ for i in range(self.num_mask_tokens):
1272
+ hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
1273
+ hyper_in = torch.stack(hyper_in_list, dim=1)
1274
+ b, c, h, w, d = upscaled_embedding.shape
1275
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d)
1276
+
1277
+ if text_embedding is not None:
1278
+ text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1)
1279
+ upscaled_embedding = upscaled_embedding.view(b, c, h * w * d)
1280
+ sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d)
1281
+ sim = sim.repeat(1, masks.shape[1], 1, 1, 1)
1282
+ masks = masks + sim
1283
+ iou_pred = self.iou_prediction_head(iou_token_out)
1284
+
1285
+ return masks, iou_pred
1286
+
1287
+ # Lightly adapted from
1288
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
1289
+ class MLP(nn.Module):
1290
+ def __init__(
1291
+ self,
1292
+ input_dim: int,
1293
+ hidden_dim: int,
1294
+ output_dim: int,
1295
+ num_layers: int,
1296
+ sigmoid_output: bool = False,
1297
+ ) -> None:
1298
+ super().__init__()
1299
+ self.num_layers = num_layers
1300
+ h = [hidden_dim] * (num_layers - 1)
1301
+ self.layers = nn.ModuleList(
1302
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
1303
+ )
1304
+ self.sigmoid_output = sigmoid_output
1305
+
1306
+ def forward(self, x):
1307
+ for i, layer in enumerate(self.layers):
1308
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1309
+ if self.sigmoid_output:
1310
+ x = F.sigmoid(x)
1311
+ return x
1312
+
1313
+ # prompt encoder
1314
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1315
+ # All rights reserved.
1316
+
1317
+ # This source code is licensed under the license found in the
1318
+ # LICENSE file in the root directory of this source tree.
1319
+
1320
+ import numpy as np
1321
+ import torch
1322
+ from torch import nn
1323
+
1324
+ from typing import Any, Optional, Tuple, Type
1325
+
1326
+ class PromptEncoder(nn.Module):
1327
+ def __init__(
1328
+ self,
1329
+ embed_dim: int,
1330
+ image_embedding_size: Tuple[int, int, int],
1331
+ input_image_size: Tuple[int, int, int],
1332
+ mask_in_chans: int,
1333
+ activation: Type[nn.Module] = nn.GELU,
1334
+ ) -> None:
1335
+ """
1336
+ Encodes prompts for input to SAM's mask decoder.
1337
+
1338
+ Arguments:
1339
+ embed_dim (int): The prompts' embedding dimension
1340
+ image_embedding_size (tuple(int, int)): The spatial size of the
1341
+ image embedding, as (H, W).
1342
+ input_image_size (int): The padded size of the image as input
1343
+ to the image encoder, as (H, W).
1344
+ mask_in_chans (int): The number of hidden channels used for
1345
+ encoding input masks.
1346
+ activation (nn.Module): The activation to use when encoding
1347
+ input masks.
1348
+ """
1349
+ super().__init__()
1350
+ self.embed_dim = embed_dim
1351
+ self.input_image_size = input_image_size
1352
+ self.image_embedding_size = image_embedding_size
1353
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
1354
+
1355
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
1356
+ point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
1357
+ self.point_embeddings = nn.ModuleList(point_embeddings)
1358
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
1359
+
1360
+ self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
1361
+ self.mask_downscaling = nn.Sequential(
1362
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
1363
+ LayerNorm2d(mask_in_chans // 4),
1364
+ activation(),
1365
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
1366
+ LayerNorm2d(mask_in_chans),
1367
+ activation(),
1368
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
1369
+ )
1370
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
1371
+
1372
+ def get_dense_pe(self) -> torch.Tensor:
1373
+ """
1374
+ Returns the positional encoding used to encode point prompts,
1375
+ applied to a dense set of points the shape of the image encoding.
1376
+
1377
+ Returns:
1378
+ torch.Tensor: Positional encoding with shape
1379
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
1380
+ """
1381
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
1382
+
1383
+ def _embed_points(
1384
+ self,
1385
+ points: torch.Tensor,
1386
+ labels: torch.Tensor,
1387
+ pad: bool,
1388
+ ) -> torch.Tensor:
1389
+ """Embeds point prompts."""
1390
+ points = points + 0.5 # Shift to center of pixel
1391
+ if pad:
1392
+ padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
1393
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
1394
+ points = torch.cat([points, padding_point], dim=1)
1395
+ labels = torch.cat([labels, padding_label], dim=1)
1396
+ point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
1397
+ point_embedding[labels == -1] = 0.0
1398
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
1399
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
1400
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
1401
+ return point_embedding
1402
+
1403
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
1404
+ """Embeds box prompts."""
1405
+ boxes = boxes + 0.5 # Shift to center of pixel
1406
+ coords = boxes.reshape(-1, 2, 3)
1407
+ corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
1408
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
1409
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
1410
+ return corner_embedding
1411
+
1412
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
1413
+ """Embeds mask inputs."""
1414
+ mask_embedding = self.mask_downscaling(masks)
1415
+ return mask_embedding
1416
+
1417
+ def _get_batch_size(
1418
+ self,
1419
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
1420
+ boxes: Optional[torch.Tensor],
1421
+ masks: Optional[torch.Tensor],
1422
+ text_embedding: Optional[torch.Tensor],
1423
+ ) -> int:
1424
+ """
1425
+ Gets the batch size of the output given the batch size of the input prompts.
1426
+ """
1427
+ if points is not None:
1428
+ return points[0].shape[0]
1429
+ elif boxes is not None:
1430
+ return boxes.shape[0]
1431
+ elif masks is not None:
1432
+ return masks.shape[0]
1433
+ elif text_embedding is not None:
1434
+ return text_embedding.shape[0]
1435
+ else:
1436
+ return 1
1437
+
1438
+ def _get_device(self) -> torch.device:
1439
+ return self.point_embeddings[0].weight.device
1440
+
1441
+ def forward(
1442
+ self,
1443
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
1444
+ boxes: Optional[torch.Tensor],
1445
+ masks: Optional[torch.Tensor],
1446
+ text_embedding: Optional[torch.Tensor],
1447
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1448
+
1449
+ bs = self._get_batch_size(points, boxes, masks, text_embedding)
1450
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
1451
+
1452
+ if points is not None:
1453
+ coords, labels = points
1454
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
1455
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
1456
+
1457
+ if boxes is not None:
1458
+ box_embeddings = self._embed_boxes(boxes)
1459
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
1460
+
1461
+ if text_embedding is not None:
1462
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
1463
+
1464
+ if masks is not None:
1465
+ dense_embeddings = self._embed_masks(masks)
1466
+ else:
1467
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
1468
+ bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]), int(self.image_embedding_size[2])
1469
+ )
1470
+
1471
+ return sparse_embeddings, dense_embeddings
1472
+
1473
+
1474
+ class PositionEmbeddingRandom(nn.Module):
1475
+ """
1476
+ Positional encoding using random spatial frequencies.
1477
+ """
1478
+
1479
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
1480
+ super().__init__()
1481
+ if scale is None or scale <= 0.0:
1482
+ scale = 1.0
1483
+ self.register_buffer(
1484
+ "positional_encoding_gaussian_matrix",
1485
+ scale * torch.randn((3, num_pos_feats)),
1486
+ )
1487
+
1488
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
1489
+ """Positionally encode points that are normalized to [0,1]."""
1490
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
1491
+ coords = 2 * coords - 1
1492
+ coords = coords @ self.positional_encoding_gaussian_matrix
1493
+ coords = 2 * np.pi * coords
1494
+ # outputs d_1 x ... x d_n x C shape
1495
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
1496
+
1497
+ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
1498
+ """Generate positional encoding for a grid of the specified size."""
1499
+ h, w, d = size
1500
+ device: Any = self.positional_encoding_gaussian_matrix.device
1501
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
1502
+ y_embed = grid.cumsum(dim=0) - 0.5
1503
+ x_embed = grid.cumsum(dim=1) - 0.5
1504
+ z_embed = grid.cumsum(dim=2) - 0.5
1505
+ y_embed = y_embed / h
1506
+ x_embed = x_embed / w
1507
+ z_embed = z_embed / d
1508
+
1509
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
1510
+ return pe.permute(3, 0, 1, 2) # C x H x W x D
1511
+
1512
+ def forward_with_coords(
1513
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
1514
+ ) -> torch.Tensor:
1515
+ """Positionally encode points that are not normalized to [0,1]."""
1516
+ coords = coords_input.clone()
1517
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
1518
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
1519
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
1520
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
1521
+
1522
+ # two way transformer
1523
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1524
+ # All rights reserved.
1525
+
1526
+ # This source code is licensed under the license found in the
1527
+ # LICENSE file in the root directory of this source tree.
1528
+
1529
+ import torch
1530
+ from torch import Tensor, nn
1531
+
1532
+ import math
1533
+ from typing import Tuple, Type
1534
+
1535
+ class TwoWayTransformer(nn.Module):
1536
+ def __init__(
1537
+ self,
1538
+ depth: int,
1539
+ embedding_dim: int,
1540
+ num_heads: int,
1541
+ mlp_dim: int,
1542
+ activation: Type[nn.Module] = nn.ReLU,
1543
+ attention_downsample_rate: int = 2,
1544
+ ) -> None:
1545
+ """
1546
+ A transformer decoder that attends to an input image using
1547
+ queries whose positional embedding is supplied.
1548
+
1549
+ Args:
1550
+ depth (int): number of layers in the transformer
1551
+ embedding_dim (int): the channel dimension for the input embeddings
1552
+ num_heads (int): the number of heads for multihead attention. Must
1553
+ divide embedding_dim
1554
+ mlp_dim (int): the channel dimension internal to the MLP block
1555
+ activation (nn.Module): the activation to use in the MLP block
1556
+ """
1557
+ super().__init__()
1558
+ self.depth = depth
1559
+ self.embedding_dim = embedding_dim
1560
+ self.num_heads = num_heads
1561
+ self.mlp_dim = mlp_dim
1562
+ self.layers = nn.ModuleList()
1563
+
1564
+ for i in range(depth):
1565
+ self.layers.append(
1566
+ TwoWayAttentionBlock(
1567
+ embedding_dim=embedding_dim,
1568
+ num_heads=num_heads,
1569
+ mlp_dim=mlp_dim,
1570
+ activation=activation,
1571
+ attention_downsample_rate=attention_downsample_rate,
1572
+ skip_first_layer_pe=(i == 0),
1573
+ )
1574
+ )
1575
+
1576
+ self.final_attn_token_to_image = Attention(
1577
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1578
+ )
1579
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
1580
+
1581
+ def forward(
1582
+ self,
1583
+ image_embedding: Tensor,
1584
+ image_pe: Tensor,
1585
+ point_embedding: Tensor,
1586
+ ) -> Tuple[Tensor, Tensor]:
1587
+ """
1588
+ Args:
1589
+ image_embedding (torch.Tensor): image to attend to. Should be shape
1590
+ B x embedding_dim x h x w for any h and w.
1591
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
1592
+ have the same shape as image_embedding.
1593
+ point_embedding (torch.Tensor): the embedding to add to the query points.
1594
+ Must have shape B x N_points x embedding_dim for any N_points.
1595
+
1596
+ Returns:
1597
+ torch.Tensor: the processed point_embedding
1598
+ torch.Tensor: the processed image_embedding
1599
+ """
1600
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
1601
+ bs, c, h, w, d = image_embedding.shape
1602
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
1603
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
1604
+
1605
+ # Prepare queries
1606
+ queries = point_embedding
1607
+ keys = image_embedding
1608
+
1609
+ # Apply transformer blocks and final layernorm
1610
+ for layer in self.layers:
1611
+ queries, keys = layer(
1612
+ queries=queries,
1613
+ keys=keys,
1614
+ query_pe=point_embedding,
1615
+ key_pe=image_pe,
1616
+ )
1617
+
1618
+ # Apply the final attention layer from the points to the image
1619
+ q = queries + point_embedding
1620
+ k = keys + image_pe
1621
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
1622
+ queries = queries + attn_out
1623
+ queries = self.norm_final_attn(queries)
1624
+
1625
+ return queries, keys
1626
+
1627
+
1628
+ class TwoWayAttentionBlock(nn.Module):
1629
+ def __init__(
1630
+ self,
1631
+ embedding_dim: int,
1632
+ num_heads: int,
1633
+ mlp_dim: int = 2048,
1634
+ activation: Type[nn.Module] = nn.ReLU,
1635
+ attention_downsample_rate: int = 2,
1636
+ skip_first_layer_pe: bool = False,
1637
+ ) -> None:
1638
+ """
1639
+ A transformer block with four layers: (1) self-attention of sparse
1640
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
1641
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
1642
+ inputs.
1643
+
1644
+ Arguments:
1645
+ embedding_dim (int): the channel dimension of the embeddings
1646
+ num_heads (int): the number of heads in the attention layers
1647
+ mlp_dim (int): the hidden dimension of the mlp block
1648
+ activation (nn.Module): the activation of the mlp block
1649
+ skip_first_layer_pe (bool): skip the PE on the first layer
1650
+ """
1651
+ super().__init__()
1652
+ self.self_attn = Attention(embedding_dim, num_heads)
1653
+ self.norm1 = nn.LayerNorm(embedding_dim)
1654
+
1655
+ self.cross_attn_token_to_image = Attention(
1656
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1657
+ )
1658
+ self.norm2 = nn.LayerNorm(embedding_dim)
1659
+
1660
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
1661
+ self.norm3 = nn.LayerNorm(embedding_dim)
1662
+
1663
+ self.norm4 = nn.LayerNorm(embedding_dim)
1664
+ self.cross_attn_image_to_token = Attention(
1665
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1666
+ )
1667
+
1668
+ self.skip_first_layer_pe = skip_first_layer_pe
1669
+
1670
+ def forward(
1671
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
1672
+ ) -> Tuple[Tensor, Tensor]:
1673
+ # Self attention block
1674
+ if self.skip_first_layer_pe:
1675
+ queries = self.self_attn(q=queries, k=queries, v=queries)
1676
+ else:
1677
+ q = queries + query_pe
1678
+ attn_out = self.self_attn(q=q, k=q, v=queries)
1679
+ queries = queries + attn_out
1680
+ queries = self.norm1(queries)
1681
+
1682
+ # Cross attention block, tokens attending to image embedding
1683
+ q = queries + query_pe
1684
+ k = keys + key_pe
1685
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
1686
+ queries = queries + attn_out
1687
+ queries = self.norm2(queries)
1688
+
1689
+ # MLP block
1690
+ mlp_out = self.mlp(queries)
1691
+ queries = queries + mlp_out
1692
+ queries = self.norm3(queries)
1693
+
1694
+ # Cross attention block, image embedding attending to tokens
1695
+ q = queries + query_pe
1696
+ k = keys + key_pe
1697
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
1698
+ keys = keys + attn_out
1699
+ keys = self.norm4(keys)
1700
+
1701
+ return queries, keys
1702
+
1703
+
1704
+ class Attention(nn.Module):
1705
+ """
1706
+ An attention layer that allows for downscaling the size of the embedding
1707
+ after projection to queries, keys, and values.
1708
+ """
1709
+
1710
+ def __init__(
1711
+ self,
1712
+ embedding_dim: int,
1713
+ num_heads: int,
1714
+ downsample_rate: int = 1,
1715
+ ) -> None:
1716
+ super().__init__()
1717
+ self.embedding_dim = embedding_dim
1718
+ self.internal_dim = embedding_dim // downsample_rate
1719
+ self.num_heads = num_heads
1720
+ assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
1721
+
1722
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
1723
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
1724
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
1725
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
1726
+
1727
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
1728
+ b, n, c = x.shape
1729
+ x = x.reshape(b, n, num_heads, c // num_heads)
1730
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
1731
+
1732
+ def _recombine_heads(self, x: Tensor) -> Tensor:
1733
+ b, n_heads, n_tokens, c_per_head = x.shape
1734
+ x = x.transpose(1, 2)
1735
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
1736
+
1737
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
1738
+ # Input projections
1739
+ q = self.q_proj(q)
1740
+ k = self.k_proj(k)
1741
+ v = self.v_proj(v)
1742
+
1743
+ # Separate into heads
1744
+ q = self._separate_heads(q, self.num_heads)
1745
+ k = self._separate_heads(k, self.num_heads)
1746
+ v = self._separate_heads(v, self.num_heads)
1747
+
1748
+ # Attention
1749
+ _, _, _, c_per_head = q.shape
1750
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
1751
+ attn = attn / math.sqrt(c_per_head)
1752
+ attn = torch.softmax(attn, dim=-1)
1753
+
1754
+ # Get output
1755
+ out = attn @ v
1756
+ out = self._recombine_heads(out)
1757
+ out = self.out_proj(out)
1758
+
1759
+ return out
1760
+
1761
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
1762
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
1763
+ class LayerNorm2d(nn.Module):
1764
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
1765
+ super().__init__()
1766
+ self.weight = nn.Parameter(torch.ones(num_channels))
1767
+ self.bias = nn.Parameter(torch.zeros(num_channels))
1768
+ self.eps = eps
1769
+
1770
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1771
+ u = x.mean(1, keepdim=True)
1772
+ s = (x - u).pow(2).mean(1, keepdim=True)
1773
+ x = (x - u) / torch.sqrt(s + self.eps)
1774
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
1775
+ return x
1776
+
1777
+ class MLPBlock(nn.Module):
1778
+ def __init__(
1779
+ self,
1780
+ embedding_dim: int,
1781
+ mlp_dim: int,
1782
+ act: Type[nn.Module] = nn.GELU,
1783
+ ) -> None:
1784
+ super().__init__()
1785
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
1786
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
1787
+ self.act = act()
1788
+
1789
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1790
+ return self.lin2(self.act(self.lin1(x)))
1791
+
1792
+
1793
+ # sam
1794
+ class Sam(nn.Module):
1795
+ mask_threshold: float = 0.0
1796
+ image_format: str = "RGB"
1797
+
1798
+ def __init__(
1799
+ self,
1800
+ image_encoder,
1801
+ prompt_encoder,
1802
+ mask_decoder,
1803
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
1804
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
1805
+ ) -> None:
1806
+ """
1807
+ SAM predicts object masks from an image and input prompts.
1808
+
1809
+ Arguments:
1810
+ image_encoder (ImageEncoderViT): The backbone used to encode the
1811
+ image into image embeddings that allow for efficient mask prediction.
1812
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
1813
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
1814
+ and encoded prompts.
1815
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
1816
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
1817
+ """
1818
+ super().__init__()
1819
+ self.image_encoder = image_encoder
1820
+ self.prompt_encoder = prompt_encoder
1821
+ self.mask_decoder = mask_decoder
1822
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
1823
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
1824
+
1825
+ @property
1826
+ def device(self) -> Any:
1827
+ return self.pixel_mean.device
1828
+
1829
+ @torch.no_grad()
1830
+ def forward(
1831
+ self,
1832
+ batched_input: List[Dict[str, Any]],
1833
+ multimask_output: bool,
1834
+ ) -> List[Dict[str, torch.Tensor]]:
1835
+ """
1836
+ Predicts masks end-to-end from provided images and prompts.
1837
+ If prompts are not known in advance, using SamPredictor is
1838
+ recommended over calling the model directly.
1839
+
1840
+ Arguments:
1841
+ batched_input (list(dict)): A list over input images, each a
1842
+ dictionary with the following keys. A prompt key can be
1843
+ excluded if it is not present.
1844
+ 'image': The image as a torch tensor in 3xHxW format,
1845
+ already transformed for input to the model.
1846
+ 'original_size': (tuple(int, int)) The original size of
1847
+ the image before transformation, as (H, W).
1848
+ 'point_coords': (torch.Tensor) Batched point prompts for
1849
+ this image, with shape BxNx2. Already transformed to the
1850
+ input frame of the model.
1851
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
1852
+ with shape BxN.
1853
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
1854
+ Already transformed to the input frame of the model.
1855
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
1856
+ in the form Bx1xHxW.
1857
+ multimask_output (bool): Whether the model should predict multiple
1858
+ disambiguating masks, or return a single mask.
1859
+
1860
+ Returns:
1861
+ (list(dict)): A list over input images, where each element is
1862
+ as dictionary with the following keys.
1863
+ 'masks': (torch.Tensor) Batched binary mask predictions,
1864
+ with shape BxCxHxW, where B is the number of input prompts,
1865
+ C is determined by multimask_output, and (H, W) is the
1866
+ original size of the image.
1867
+ 'iou_predictions': (torch.Tensor) The model's predictions
1868
+ of mask quality, in shape BxC.
1869
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
1870
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
1871
+ to subsequent iterations of prediction.
1872
+ """
1873
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
1874
+ image_embeddings = self.image_encoder(input_images)
1875
+
1876
+ outputs = []
1877
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
1878
+ if "point_coords" in image_record:
1879
+ points = (image_record["point_coords"], image_record["point_labels"])
1880
+ else:
1881
+ points = None
1882
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1883
+ points=points,
1884
+ boxes=image_record.get("boxes", None),
1885
+ masks=image_record.get("mask_inputs", None),
1886
+ )
1887
+ low_res_masks, iou_predictions = self.mask_decoder(
1888
+ image_embeddings=curr_embedding.unsqueeze(0),
1889
+ image_pe=self.prompt_encoder.get_dense_pe(),
1890
+ sparse_prompt_embeddings=sparse_embeddings,
1891
+ dense_prompt_embeddings=dense_embeddings,
1892
+ multimask_output=multimask_output,
1893
+ )
1894
+ masks = self.postprocess_masks(
1895
+ low_res_masks,
1896
+ input_size=image_record["image"].shape[-2:],
1897
+ original_size=image_record["original_size"],
1898
+ )
1899
+ masks = masks > self.mask_threshold
1900
+ outputs.append(
1901
+ {
1902
+ "masks": masks,
1903
+ "iou_predictions": iou_predictions,
1904
+ "low_res_logits": low_res_masks,
1905
+ }
1906
+ )
1907
+ return outputs
1908
+
1909
+ def postprocess_masks(
1910
+ self,
1911
+ masks: torch.Tensor,
1912
+ input_size: Tuple[int, ...],
1913
+ original_size: Tuple[int, ...],
1914
+ ) -> torch.Tensor:
1915
+ """
1916
+ Remove padding and upscale masks to the original image size.
1917
+
1918
+ Arguments:
1919
+ masks (torch.Tensor): Batched masks from the mask_decoder,
1920
+ in BxCxHxW format.
1921
+ input_size (tuple(int, int)): The size of the image input to the
1922
+ model, in (H, W) format. Used to remove padding.
1923
+ original_size (tuple(int, int)): The original size of the image
1924
+ before resizing for input to the model, in (H, W) format.
1925
+
1926
+ Returns:
1927
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
1928
+ is given by original_size.
1929
+ """
1930
+ masks = F.interpolate(
1931
+ masks,
1932
+ (self.image_encoder.img_size, self.image_encoder.img_size),
1933
+ mode="bilinear",
1934
+ align_corners=False,
1935
+ )
1936
+ masks = masks[..., : input_size[0], : input_size[1]]
1937
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
1938
+ return masks
1939
+
1940
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
1941
+ """Normalize pixel values and pad to a square input."""
1942
+ # Normalize colors
1943
+ # TODO
1944
+ x = (x - self.pixel_mean) / self.pixel_std
1945
+
1946
+ # Pad
1947
+ h, w = x.shape[-2:]
1948
+ padh = self.image_encoder.img_size - h
1949
+ padw = self.image_encoder.img_size - w
1950
+ x = F.pad(x, (0, padw, 0, padh))
1951
+ return x
SegVol_v1.pth → pytorch_model.bin RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b751dc95f1a0c0c6086c1e6fa7f8a17bbb87635e5226e15f5d156fbd364dbb85
3
- size 1660308695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:500f2758a8f989339b2b2baf09a819169bc87549795193d3cfe505726ac0b399
3
+ size 723726667
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip", "special_tokens_map_file": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip/special_tokens_map.json", "tokenizer_class": "CLIPTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff