ghlee94 committed
Commit
bd25e7c
1 Parent(s): 2f2e97c

Implement cellseg prediction

Files changed (1): app.py +1248 -4
app.py CHANGED
@@ -1,7 +1,1251 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
  import gradio as gr
+ import torch
+ from torch.nn import (
+     Module,
+     Conv2d,
+     BatchNorm2d,
+     Identity,
+     UpsamplingBilinear2d,
+     Mish,
+     ReLU,
+     Sequential,
+ )
+ from torch.nn.functional import interpolate, grid_sample, pad
+ import numpy as np
+ from copy import deepcopy
+ import os, argparse, math
+ import tifffile as tif
+ from typing import Tuple, List, Mapping
+
+ from monai.utils import (
+     BlendMode,
+     PytorchPadMode,
+     convert_data_type,
+     ensure_tuple,
+     fall_back_tuple,
+     look_up_option,
+     convert_to_dst_type,
+ )
+ from monai.utils.misc import ensure_tuple_size, ensure_tuple_rep, issequenceiterable
+ from monai.networks.layers.convutils import gaussian_1d
+ from monai.networks.layers.simplelayers import separable_filtering
+
+ from segmentation_models_pytorch import MAnet
+
+ from skimage.io import imread as io_imread
+ from skimage.util.dtype import dtype_range
+ from skimage._shared.utils import _supported_float_type
+ from scipy.ndimage import find_objects, binary_fill_holes
+
+
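+ # Overview of this app, as implemented below:
+ #   1. pred_transforms() loads an image file and percentile-normalizes it to
+ #      a (1, 3, H, W) float tensor in [0, 1].
+ #   2. sliding_window_inference() tiles the image into overlapping 512x512
+ #      windows, runs the MAnet-based model on each, and blends the patch
+ #      predictions with a Gaussian importance map.
+ #   3. post_process()/compute_masks() follow the predicted gradient flows to
+ #      cluster pixels into instance labels, then filter unreliable masks.
+ #   4. A Gradio interface wires predict() to a file upload and image outputs.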
+ ########################### Data Loading Modules #########################################################
+ DTYPE_RANGE = dtype_range.copy()
+ DTYPE_RANGE.update((d.__name__, limits) for d, limits in dtype_range.items())
+ DTYPE_RANGE.update(
+     {
+         "uint10": (0, 2 ** 10 - 1),
+         "uint12": (0, 2 ** 12 - 1),
+         "uint14": (0, 2 ** 14 - 1),
+         "bool": dtype_range[bool],
+         "float": dtype_range[np.float64],
+     }
+ )
+
+
+ def _output_dtype(dtype_or_range, image_dtype):
+     if type(dtype_or_range) in [list, tuple, np.ndarray]:
+         # pair of values: always return float.
+         return _supported_float_type(image_dtype)
+     if type(dtype_or_range) == type:
+         # already a type: return it
+         return dtype_or_range
+     if dtype_or_range in DTYPE_RANGE:
+         # string key in DTYPE_RANGE dictionary
+         try:
+             # if it's a canonical numpy dtype, convert
+             return np.dtype(dtype_or_range).type
+         except TypeError:  # uint10, uint12, uint14
+             # otherwise, return uint16
+             return np.uint16
+     else:
+         raise ValueError(
+             "Incorrect value for out_range, should be a valid image data "
+             f"type or a pair of values, got {dtype_or_range}."
+         )
+
+
+ def intensity_range(image, range_values="image", clip_negative=False):
+     if range_values == "dtype":
+         range_values = image.dtype.type
+
+     if range_values == "image":
+         i_min = np.min(image)
+         i_max = np.max(image)
+     elif range_values in DTYPE_RANGE:
+         i_min, i_max = DTYPE_RANGE[range_values]
+         if clip_negative:
+             i_min = 0
+     else:
+         i_min, i_max = range_values
+     return i_min, i_max
+
+
+ def rescale_intensity(image, in_range="image", out_range="dtype"):
+     out_dtype = _output_dtype(out_range, image.dtype)
+
+     imin, imax = map(float, intensity_range(image, in_range))
+     omin, omax = map(
+         float, intensity_range(image, out_range, clip_negative=(imin >= 0))
+     )
+     image = np.clip(image, imin, imax)
+
+     if imin != imax:
+         image = (image - imin) / (imax - imin)
+         return np.asarray(image * (omax - omin) + omin, dtype=out_dtype)
+     else:
+         return np.clip(image, omin, omax).astype(out_dtype)
+
+
+ def _normalize(img):
+     non_zero_vals = img[np.nonzero(img)]
+     percentiles = np.percentile(non_zero_vals, [0, 99.5])
+     img_norm = rescale_intensity(
+         img, in_range=(percentiles[0], percentiles[1]), out_range="uint8"
+     )
+
+     return img_norm.astype(np.uint8)
+
+
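+ # Illustrative example (assumed values): for a uint16 image whose non-zero
+ # intensities span roughly 120..40000, np.percentile(..., [0, 99.5]) selects
+ # that range and rescale_intensity() maps it linearly onto 0..255 before the
+ # uint8 cast, so rare bright outliers are clipped instead of compressing the
+ # useful dynamic range.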
+ def pred_transforms(filename):
+     # LoadImage
+     img = (
+         tif.imread(filename)
+         if filename.endswith(".tif") or filename.endswith(".tiff")
+         else io_imread(filename)
+     )
+
+     if len(img.shape) == 2:
+         img = np.repeat(np.expand_dims(img, axis=-1), 3, axis=-1)
+     elif len(img.shape) == 3 and img.shape[-1] > 3:
+         img = img[:, :, :3]
+
+     img = img.astype(np.float32)
+     img = _normalize(img)
+     img = np.moveaxis(img, -1, 0)
+     img = (img - img.min()) / (img.max() - img.min())
+
+     return torch.FloatTensor(img).unsqueeze(0)
+
+
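+ # pred_transforms() always yields a (1, 3, H, W) FloatTensor: grayscale
+ # inputs are repeated to 3 channels, and inputs with more than 3 channels
+ # are truncated to the first 3, matching the in_channels=3 encoder below.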
+ ################################################################################
+
+ ########################### MODEL Architecture #################################
+ class SegformerGH(MAnet):
+     def __init__(
+         self,
+         encoder_name: str = "mit_b5",
+         encoder_weights="imagenet",
+         decoder_channels=(256, 128, 64, 32, 32),
+         decoder_pab_channels=256,
+         in_channels: int = 3,
+         classes: int = 3,
+     ):
+         super(SegformerGH, self).__init__(
+             encoder_name=encoder_name,
+             encoder_weights=encoder_weights,
+             decoder_channels=decoder_channels,
+             decoder_pab_channels=decoder_pab_channels,
+             in_channels=in_channels,
+             classes=classes,
+         )
+
+         convert_relu_to_mish(self.encoder)
+         convert_relu_to_mish(self.decoder)
+
+         self.cellprob_head = DeepSegmantationHead(
+             in_channels=decoder_channels[-1], out_channels=1, kernel_size=3,
+         )
+         self.gradflow_head = DeepSegmantationHead(
+             in_channels=decoder_channels[-1], out_channels=2, kernel_size=3,
+         )
+
+     def forward(self, x):
+         """Sequentially pass `x` through the model's encoder, decoder and heads"""
+         self.check_input_shape(x)
+
+         features = self.encoder(x)
+         decoder_output = self.decoder(*features)
+
+         gradflow_mask = self.gradflow_head(decoder_output)
+         cellprob_mask = self.cellprob_head(decoder_output)
+
+         masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
+
+         return masks
+
+
+ class DeepSegmantationHead(Sequential):
+     def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
+         conv2d_1 = Conv2d(
+             in_channels,
+             in_channels // 2,
+             kernel_size=kernel_size,
+             padding=kernel_size // 2,
+         )
+         bn = BatchNorm2d(in_channels // 2)
+         conv2d_2 = Conv2d(
+             in_channels // 2,
+             out_channels,
+             kernel_size=kernel_size,
+             padding=kernel_size // 2,
+         )
+         mish = Mish(inplace=True)
+
+         upsampling = (
+             UpsamplingBilinear2d(scale_factor=upsampling)
+             if upsampling > 1
+             else Identity()
+         )
+         activation = Identity()
+         super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation)
+
+
+ def convert_relu_to_mish(model):
+     for child_name, child in model.named_children():
+         if isinstance(child, ReLU):
+             setattr(model, child_name, Mish(inplace=True))
+         else:
+             convert_relu_to_mish(child)
+
+
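+ # The model emits 3 channels: channels 0-1 come from gradflow_head (the y/x
+ # flow field) and channel 2 from cellprob_head (cell-probability logits);
+ # post_process() below unpacks them in exactly that order.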
+ #####################################################################################
+
+ ########################### Sliding Window Inference #################################
+ class GaussianFilter(Module):
+     def __init__(
+         self, spatial_dims, sigma, truncated=4.0, approx="erf", requires_grad=False,
+     ) -> None:
+         if issequenceiterable(sigma):
+             if len(sigma) != spatial_dims:  # type: ignore
+                 raise ValueError
+         else:
+             sigma = [deepcopy(sigma) for _ in range(spatial_dims)]  # type: ignore
+         super().__init__()
+         self.sigma = [
+             torch.nn.Parameter(
+                 torch.as_tensor(
+                     s,
+                     dtype=torch.float,
+                     device=s.device if isinstance(s, torch.Tensor) else None,
+                 ),
+                 requires_grad=requires_grad,
+             )
+             for s in sigma  # type: ignore
+         ]
+         self.truncated = truncated
+         self.approx = approx
+         for idx, param in enumerate(self.sigma):
+             self.register_parameter(f"kernel_sigma_{idx}", param)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         _kernel = [
+             gaussian_1d(s, truncated=self.truncated, approx=self.approx)
+             for s in self.sigma
+         ]
+         return separable_filtering(x=x, kernels=_kernel)
+
+
+ def compute_importance_map(
+     patch_size, mode=BlendMode.CONSTANT, sigma_scale=0.125, device="cpu"
+ ):
+     mode = look_up_option(mode, BlendMode)
+     device = torch.device(device)
+
+     center_coords = [i // 2 for i in patch_size]
+     sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
+     sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
+
+     importance_map = torch.zeros(patch_size, device=device)
+     importance_map[tuple(center_coords)] = 1
+     pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(
+         device=device, dtype=torch.float
+     )
+     importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
+     importance_map = importance_map.squeeze(0).squeeze(0)
+     importance_map = importance_map / torch.max(importance_map)
+     importance_map = importance_map.float()
+
+     return importance_map
+
+
+ def first(iterable, default=None):
+     for i in iterable:
+         return i
+
+     return default
+
+
+ def dense_patch_slices(image_size, patch_size, scan_interval):
+     num_spatial_dims = len(image_size)
+     patch_size = get_valid_patch_size(image_size, patch_size)
+     scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)
+
+     scan_num = []
+     for i in range(num_spatial_dims):
+         if scan_interval[i] == 0:
+             scan_num.append(1)
+         else:
+             num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
+             scan_dim = first(
+                 d
+                 for d in range(num)
+                 if d * scan_interval[i] + patch_size[i] >= image_size[i]
+             )
+             scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
+
+     starts = []
+     for dim in range(num_spatial_dims):
+         dim_starts = []
+         for idx in range(scan_num[dim]):
+             start_idx = idx * scan_interval[dim]
+             start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
+             dim_starts.append(start_idx)
+         starts.append(dim_starts)
+     out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
+     return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
+
+
+ def get_valid_patch_size(image_size, patch_size):
+     ndim = len(image_size)
+     patch_size_ = ensure_tuple_size(patch_size, ndim)
+
+     # ensure patch size dimensions are not larger than image dimension, if a dimension is None or 0 use whole dimension
+     return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_))
+
+
+ class Resize:
+     def __init__(self, spatial_size):
+         self.size_mode = "all"
+         self.spatial_size = spatial_size
+
+     def __call__(self, img):
+         input_ndim = img.ndim - 1  # spatial ndim
+         output_ndim = len(ensure_tuple(self.spatial_size))
+
+         if output_ndim > input_ndim:
+             input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
+             img = img.reshape(input_shape)
+
+         spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:])
+
+         if (
+             tuple(img.shape[1:]) == spatial_size_
+         ):  # spatial shape is already the desired
+             return img
+
+         img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
+
+         resized = interpolate(
+             input=img_.unsqueeze(0), size=spatial_size_, mode="nearest",
+         )
+         out, *_ = convert_to_dst_type(resized.squeeze(0), img)
+         return out
+
+
+ def sliding_window_inference(
+     inputs,
+     roi_size,
+     sw_batch_size,
+     predictor,
+     overlap,
+     mode=BlendMode.CONSTANT,
+     sigma_scale=0.125,
+     padding_mode=PytorchPadMode.CONSTANT,
+     cval=0.0,
+     sw_device=None,
+     device=None,
+     roi_weight_map=None,
+ ):
+     compute_dtype = inputs.dtype
+     num_spatial_dims = len(inputs.shape) - 2
+     batch_size, _, *image_size_ = inputs.shape
+
+     roi_size = fall_back_tuple(roi_size, image_size_)
+     # in case that image size is smaller than roi size
+     image_size = tuple(
+         max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)
+     )
+     pad_size = []
+
+     for k in range(len(inputs.shape) - 1, 1, -1):
+         diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+         half = diff // 2
+         pad_size.extend([half, diff - half])
+
+     inputs = pad(
+         inputs,
+         pad=pad_size,
+         mode=look_up_option(padding_mode, PytorchPadMode).value,
+         value=cval,
+     )
+
+     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
+
+     # Store all slices in list
+     slices = dense_patch_slices(image_size, roi_size, scan_interval)
+     num_win = len(slices)  # number of windows per image
+     total_slices = num_win * batch_size  # total number of windows
+
+     # Create window-level importance map
+     valid_patch_size = get_valid_patch_size(image_size, roi_size)
+     if valid_patch_size == roi_size and (roi_weight_map is not None):
+         importance_map = roi_weight_map
+     else:
+         importance_map = compute_importance_map(
+             valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device
+         )
+
+     importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
+     # handle non-positive weights
+     min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
+     importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(
+         compute_dtype
+     )
+
+     # Perform predictions
+     dict_key, output_image_list, count_map_list = None, [], []
+     _initialized_ss = -1
+     is_tensor_output = (
+         True  # whether the predictor's output is a tensor (instead of dict/tuple)
+     )
+
+     # for each patch
+     for slice_g in range(0, total_slices, sw_batch_size):
+         slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
+         unravel_slice = [
+             [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)]
+             + list(slices[idx % num_win])
+             for idx in slice_range
+         ]
+         window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(
+             sw_device
+         )
+         seg_prob_out = predictor(window_data)  # batched patch segmentation
+
+         # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
+         seg_prob_tuple: Tuple[torch.Tensor, ...]
+         if isinstance(seg_prob_out, torch.Tensor):
+             seg_prob_tuple = (seg_prob_out,)
+         elif isinstance(seg_prob_out, Mapping):
+             if dict_key is None:
+                 dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
+             seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
+             is_tensor_output = False
+         else:
+             seg_prob_tuple = ensure_tuple(seg_prob_out)
+             is_tensor_output = False
+
+         # for each output in multi-output list
+         for ss, seg_prob in enumerate(seg_prob_tuple):
+             seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN
+
+             # compute zoom scale: out_roi_size/in_roi_size
+             zoom_scale = []
+             for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
+                 zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
+             ):
+                 _scale = out_w_i / float(in_w_i)
+
+                 zoom_scale.append(_scale)
+
+             if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
+                 # construct multi-resolution outputs
+                 output_classes = seg_prob.shape[1]
+                 output_shape = [batch_size, output_classes] + [
+                     int(image_size_d * zoom_scale_d)
+                     for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
+                 ]
+                 # allocate memory to store the full output and the count for overlapping parts
+                 output_image_list.append(
+                     torch.zeros(output_shape, dtype=compute_dtype, device=device)
+                 )
+                 count_map_list.append(
+                     torch.zeros(
+                         [1, 1] + output_shape[2:], dtype=compute_dtype, device=device
+                     )
+                 )
+                 _initialized_ss += 1
+
+             # resizing the importance_map
+             resizer = Resize(spatial_size=seg_prob.shape[2:])
+
+             # store the result in the proper location of the full output. Apply weights from importance map.
+             for idx, original_idx in zip(slice_range, unravel_slice):
+                 # zoom roi
+                 original_idx_zoom = list(
+                     original_idx
+                 )  # 4D for 2D image, 5D for 3D image
+                 for axis in range(2, len(original_idx_zoom)):
+                     zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
+                     zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
+
+                     original_idx_zoom[axis] = slice(
+                         int(zoomed_start), int(zoomed_end), None
+                     )
+                 importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(
+                     compute_dtype
+                 )
+                 # store results and weights
+                 output_image_list[ss][original_idx_zoom] += (
+                     importance_map_zoom * seg_prob[idx - slice_g]
+                 )
+                 count_map_list[ss][original_idx_zoom] += (
+                     importance_map_zoom.unsqueeze(0)
+                     .unsqueeze(0)
+                     .expand(count_map_list[ss][original_idx_zoom].shape)
+                 )
+
+     # account for any overlapping sections
+     for ss in range(len(output_image_list)):
+         output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(
+             compute_dtype
+         )
+
+     # remove padding if image_size smaller than roi_size
+     for ss, output_i in enumerate(output_image_list):
+         zoom_scale = [
+             seg_prob_map_shape_d / roi_size_d
+             for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
+         ]
+
+         final_slicing: List[slice] = []
+         for sp in range(num_spatial_dims):
+             slice_dim = slice(
+                 pad_size[sp * 2],
+                 image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2],
+             )
+             slice_dim = slice(
+                 int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
+                 int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
+             )
+             final_slicing.insert(0, slice_dim)
+         while len(final_slicing) < len(output_i.shape):
+             final_slicing.insert(0, slice(None))
+         output_image_list[ss] = output_i[final_slicing]
+
+     if dict_key is not None:  # if output of predictor is a dict
+         final_output = dict(zip(dict_key, output_image_list))
+     else:
+         final_output = tuple(output_image_list)  # type: ignore
+
+     return final_output[0] if is_tensor_output else final_output  # type: ignore
+
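+ # Minimal usage sketch (assumed shapes, mirroring the calls in predict()):
+ #
+ #     logits = sliding_window_inference(
+ #         torch.rand(1, 3, 3000, 3000),  # (B, C, H, W) input
+ #         roi_size=512, sw_batch_size=4, predictor=model,
+ #         overlap=0.6, mode="gaussian", padding_mode="reflect", device="cpu",
+ #     )  # -> (1, 3, 3000, 3000) blended prediction
+ #
+ # Each window's prediction is weighted by the Gaussian importance map,
+ # accumulated into a full-size buffer, and divided by the per-pixel weight
+ # sum from count_map_list.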
+
+ def _get_scan_interval(
+     image_size, roi_size, num_spatial_dims: int, overlap: float
+ ) -> Tuple[int, ...]:
+     scan_interval = []
+
+     for i in range(num_spatial_dims):
+         if roi_size[i] == image_size[i]:
+             scan_interval.append(int(roi_size[i]))
+         else:
+             interval = int(roi_size[i] * (1 - overlap))
+             scan_interval.append(interval if interval > 0 else 1)
+
+     return tuple(scan_interval)
+
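+ # Worked example: with roi_size=512 and overlap=0.6 (the values used in
+ # predict() below), the scan interval is int(512 * (1 - 0.6)) = 204 px, so
+ # neighboring windows overlap by roughly 60% before Gaussian blending.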
+
+ #####################################################################################
+
+ ########################### Main Inference Functions #################################
+ def post_process(pred_mask, device):
+     dP, cellprob = pred_mask[:2], 1 / (1 + np.exp(-pred_mask[-1]))
+     H, W = pred_mask.shape[-2], pred_mask.shape[-1]
+
+     if H * W < 5000 * 5000:
+         pred_mask = compute_masks(
+             dP,
+             cellprob,
+             use_gpu=True,
+             flow_threshold=0.4,
+             device=device,
+             cellprob_threshold=0.4,
+         )[0]
+
+     else:
+         print("\n[Whole Slide] Grid Prediction starting...")
+         roi_size = 2000
+
+         # Get patch grid by roi_size
+         if H % roi_size != 0:
+             n_H = H // roi_size + 1
+             new_H = roi_size * n_H
+         else:
+             n_H = H // roi_size
+             new_H = H
+
+         if W % roi_size != 0:
+             n_W = W // roi_size + 1
+             new_W = roi_size * n_W
+         else:
+             n_W = W // roi_size
+             new_W = W
+
+         # Allocate values on the grid
+         pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
+         dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
+         cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)
+
+         dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob
+
+         for i in range(n_H):
+             for j in range(n_W):
+                 print("Pred on Grid (%d, %d) processing..." % (i, j))
+                 dP_roi = dP_pad[
+                     :,
+                     roi_size * i : roi_size * (i + 1),
+                     roi_size * j : roi_size * (j + 1),
+                 ]
+                 cellprob_roi = cellprob_pad[
+                     roi_size * i : roi_size * (i + 1),
+                     roi_size * j : roi_size * (j + 1),
+                 ]
+
+                 pred_mask = compute_masks(
+                     dP_roi,
+                     cellprob_roi,
+                     use_gpu=True,
+                     flow_threshold=0.4,
+                     device=device,
+                     cellprob_threshold=0.4,
+                 )[0]
+
+                 pred_pad[
+                     roi_size * i : roi_size * (i + 1),
+                     roi_size * j : roi_size * (j + 1),
+                 ] = pred_mask
+
+         pred_mask = pred_pad[:H, :W]
+
+     cell_idx, cell_sizes = np.unique(pred_mask, return_counts=True)
+     cell_idx, cell_sizes = cell_idx[1:], cell_sizes[1:]
+     cell_drop = np.where(cell_sizes < np.mean(cell_sizes) - 2.7 * np.std(cell_sizes))
+
+     for drop_cell in cell_idx[cell_drop]:
+         pred_mask[pred_mask == drop_cell] = 0
+
+     return pred_mask
+
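+ # The 2000x2000 grid fallback bounds peak memory on whole-slide inputs, at
+ # the cost of possible seams at grid borders. The final filter drops labels
+ # whose pixel count falls below mean - 2.7 * std of all label sizes,
+ # treating such extreme outliers as likely debris.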
+
+ def hflip(x):
+     """flip batch of images horizontally"""
+     return x.flip(3)
+
+
+ def vflip(x):
+     """flip batch of images vertically"""
+     return x.flip(2)
+
+
+ class DualTransform:
+     identity_param = None
+
+     def __init__(
+         self, name: str, params,
+     ):
+         self.params = params
+         self.pname = name
+
+     def apply_aug_image(self, image, *args, **params):
+         raise NotImplementedError
+
+     def apply_deaug_mask(self, mask, *args, **params):
+         raise NotImplementedError
+
+
+ class HorizontalFlip(DualTransform):
+     """Flip images horizontally (left->right)"""
+
+     identity_param = False
+
+     def __init__(self):
+         super().__init__("apply", [False, True])
+
+     def apply_aug_image(self, image, apply=False, **kwargs):
+         if apply:
+             image = hflip(image)
+         return image
+
+     def apply_deaug_mask(self, mask, apply=False, **kwargs):
+         if apply:
+             mask = hflip(mask)
+         return mask
+
+
+ class VerticalFlip(DualTransform):
+     """Flip images vertically (up->down)"""
+
+     identity_param = False
+
+     def __init__(self):
+         super().__init__("apply", [False, True])
+
+     def apply_aug_image(self, image, apply=False, **kwargs):
+         if apply:
+             image = vflip(image)
+         return image
+
+     def apply_deaug_mask(self, mask, apply=False, **kwargs):
+         if apply:
+             mask = vflip(mask)
+         return mask
+
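+ # TTA helpers: apply_aug_image() flips the input, apply_deaug_mask() flips
+ # the prediction back into alignment. Note that mirroring a flow field also
+ # negates the mirrored component, which is why predict() below subtracts the
+ # x-flow channel of the hflip output (and the y-flow channel of the vflip
+ # output) when merging, rather than adding them.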
+
+ #################### GradFlow Modules ##################################################
+ from scipy.ndimage import maximum_filter1d
+ import scipy.ndimage
+ import fastremap
+ from skimage import morphology
+
+ from scipy.ndimage import mean
+
+ torch_GPU = torch.device("cuda")
+ torch_CPU = torch.device("cpu")
+
+
+ def _extend_centers_gpu(
+     neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device("cuda")
+ ):
+     nimg = neighbors.shape[0] // 9
+     pt = torch.from_numpy(neighbors).to(device)
+
+     T = torch.zeros((nimg, Ly, Lx), dtype=torch.double, device=device)
+     meds = torch.from_numpy(centers.astype(int)).to(device).long()
+     isneigh = torch.from_numpy(isneighbor).to(device)
+     for i in range(n_iter):
+         T[:, meds[:, 0], meds[:, 1]] += 1
+         Tneigh = T[:, pt[:, :, 0], pt[:, :, 1]]
+         Tneigh *= isneigh
+         T[:, pt[0, :, 0], pt[0, :, 1]] = Tneigh.mean(axis=1)
+     del meds, isneigh, Tneigh
+     T = torch.log(1.0 + T)
+     # gradient positions
+     grads = T[:, pt[[2, 1, 4, 3], :, 0], pt[[2, 1, 4, 3], :, 1]]
+     del pt
+     dy = grads[:, 0] - grads[:, 1]
+     dx = grads[:, 2] - grads[:, 3]
+     del grads
+     mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2)
+     return mu_torch
+
+
+ def diameters(masks):
+     _, counts = np.unique(np.int32(masks), return_counts=True)
+     counts = counts[1:]
+     md = np.median(counts ** 0.5)
+     if np.isnan(md):
+         md = 0
+     md /= (np.pi ** 0.5) / 2
+     return md, counts ** 0.5
+
+
+ def masks_to_flows_gpu(masks, device=None):
+     if device is None:
+         device = torch.device("cuda")
+
+     Ly0, Lx0 = masks.shape
+     Ly, Lx = Ly0 + 2, Lx0 + 2
+
+     masks_padded = np.zeros((Ly, Lx), np.int64)
+     masks_padded[1:-1, 1:-1] = masks
+
+     # get mask pixel neighbors
+     y, x = np.nonzero(masks_padded)
+     neighborsY = np.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), axis=0)
+     neighborsX = np.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), axis=0)
+     neighbors = np.stack((neighborsY, neighborsX), axis=-1)
+
+     # get mask centers
+     slices = scipy.ndimage.find_objects(masks)
+
+     centers = np.zeros((masks.max(), 2), "int")
+     for i, si in enumerate(slices):
+         if si is not None:
+             sr, sc = si
+
+             ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
+             yi, xi = np.nonzero(masks[sr, sc] == (i + 1))
+             yi = yi.astype(np.int32) + 1  # add padding
+             xi = xi.astype(np.int32) + 1  # add padding
+             ymed = np.median(yi)
+             xmed = np.median(xi)
+             imin = np.argmin((xi - xmed) ** 2 + (yi - ymed) ** 2)
+             xmed = xi[imin]
+             ymed = yi[imin]
+             centers[i, 0] = ymed + sr.start
+             centers[i, 1] = xmed + sc.start
+
+     # get neighbor validator (not all neighbors are in same mask)
+     neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1]]
+     isneighbor = neighbor_masks == neighbor_masks[0]
+     ext = np.array(
+         [[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in slices]
+     )
+     n_iter = 2 * (ext.sum(axis=1)).max()
+     # run diffusion
+     mu = _extend_centers_gpu(
+         neighbors, centers, isneighbor, Ly, Lx, n_iter=n_iter, device=device
+     )
+
+     # normalize
+     mu /= 1e-20 + (mu ** 2).sum(axis=0) ** 0.5
+
+     # put into original image
+     mu0 = np.zeros((2, Ly0, Lx0))
+     mu0[:, y - 1, x - 1] = mu
+     mu_c = np.zeros_like(mu0)
+     return mu0, mu_c
+
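+ # masks_to_flows_gpu() builds the reference flow field by simulated heat
+ # diffusion: _extend_centers_gpu() repeatedly injects heat at each mask's
+ # median pixel and averages it over within-mask neighborhoods, and the
+ # gradient of the resulting potential T points every pixel toward its cell
+ # center (the same construction Cellpose uses for its flow targets).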
+
+ def masks_to_flows(masks, use_gpu=False, device=None):
+     if masks.max() == 0 or (masks != 0).sum() == 1:
+         # dynamics_logger.warning('empty masks!')
+         return np.zeros((2, *masks.shape), "float32")
+
+     # NOTE: only the torch implementation is wired up below, so callers must
+     # pass use_gpu=True (compute_masks does); use_gpu=False would leave
+     # masks_to_flows_device undefined.
+     if use_gpu:
+         if use_gpu and device is None:
+             device = torch_GPU
+         elif device is None:
+             device = torch_CPU
+         masks_to_flows_device = masks_to_flows_gpu
+
+     if masks.ndim == 3:
+         Lz, Ly, Lx = masks.shape
+         mu = np.zeros((3, Lz, Ly, Lx), np.float32)
+         for z in range(Lz):
+             mu0 = masks_to_flows_device(masks[z], device=device)[0]
+             mu[[1, 2], z] += mu0
+         for y in range(Ly):
+             mu0 = masks_to_flows_device(masks[:, y], device=device)[0]
+             mu[[0, 2], :, y] += mu0
+         for x in range(Lx):
+             mu0 = masks_to_flows_device(masks[:, :, x], device=device)[0]
+             mu[[0, 1], :, :, x] += mu0
+         return mu
+     elif masks.ndim == 2:
+         mu, mu_c = masks_to_flows_device(masks, device=device)
+         return mu
+     else:
+         raise ValueError("masks_to_flows only takes 2D or 3D arrays")
+
+
+ def steps2D_interp(p, dP, niter, use_gpu=False, device=None):
+     shape = dP.shape[1:]
+     if use_gpu:
+         if device is None:
+             device = torch_GPU
+         shape = (
+             np.array(shape)[[1, 0]].astype("float") - 1
+         )  # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
+         pt = (
+             torch.from_numpy(p[[1, 0]].T).float().to(device).unsqueeze(0).unsqueeze(0)
+         )  # p is n_points by 2, so pt is [1 1 n_points 2]
+         im = (
+             torch.from_numpy(dP[[1, 0]]).float().to(device).unsqueeze(0)
+         )  # convert flow numpy array to tensor on GPU, add dimension
+         # normalize pt between 0 and 1, normalize the flow
+         for k in range(2):
+             im[:, k, :, :] *= 2.0 / shape[k]
+             pt[:, :, :, k] /= shape[k]
+
+         # normalize to between -1 and 1
+         pt = pt * 2 - 1
+
+         # here is where the stepping happens
+         for t in range(niter):
+             # align_corners default is False, just added to suppress warning
+             dPt = grid_sample(im, pt, align_corners=False)
+
+             for k in range(2):  # clamp the final pixel locations
+                 pt[:, :, :, k] = torch.clamp(
+                     pt[:, :, :, k] + dPt[:, k, :, :], -1.0, 1.0
+                 )
+
+         # undo the normalization from before, reverse order of operations
+         pt = (pt + 1) * 0.5
+         for k in range(2):
+             pt[:, :, :, k] *= shape[k]
+
+         p = pt[:, :, :, [1, 0]].cpu().numpy().squeeze().T
+         return p
+
+     else:
+         # The CPU stepping path from the original Cellpose code was stripped
+         # in this port; only the torch implementation above is available.
+         raise NotImplementedError("steps2D_interp requires use_gpu=True here")
+
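+ # steps2D_interp() is Euler integration in normalized coordinates: each of
+ # the niter steps samples the flow field at the current sub-pixel positions
+ # via grid_sample and moves the points along it, clamped to [-1, 1], so
+ # foreground pixels drift toward their cell centers.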
+
+ def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None):
+     shape = np.array(dP.shape[1:]).astype(np.int32)
+     niter = np.uint32(niter)
+
+     p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
+     p = np.array(p).astype(np.float32)
+
+     inds = np.array(np.nonzero(np.abs(dP[0]) > 1e-3)).astype(np.int32).T
+
+     if inds.ndim < 2 or inds.shape[0] < 5:
+         return p, None
+
+     if not interp:
+         # Non-interpolated stepping from the original Cellpose code was
+         # stripped in this port.
+         raise NotImplementedError("follow_flows requires interp=True here")
+     else:
+         p_interp = steps2D_interp(
+             p[:, inds[:, 0], inds[:, 1]], dP, niter, use_gpu=use_gpu, device=device
+         )
+         p[:, inds[:, 0], inds[:, 1]] = p_interp
+
+     return p, inds
+
+
+ def flow_error(maski, dP_net, use_gpu=False, device=None):
+     if dP_net.shape[1:] != maski.shape:
+         print("ERROR: net flow is not same size as predicted masks")
+         return
+
+     # flows predicted from estimated masks
+     dP_masks = masks_to_flows(maski, use_gpu=use_gpu, device=device)
+     # difference between predicted flows vs mask flows
+     flow_errors = np.zeros(maski.max())
+     for i in range(dP_masks.shape[0]):
+         flow_errors += mean(
+             (dP_masks[i] - dP_net[i] / 5.0) ** 2,
+             maski,
+             index=np.arange(1, maski.max() + 1),
+         )
+
+     return flow_errors, dP_masks
+
+
+ def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=None):
+     merrors, _ = flow_error(masks, flows, use_gpu, device)
+     badi = 1 + (merrors > threshold).nonzero()[0]
+     masks[np.isin(masks, badi)] = 0
+     return masks
+
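+ # Quality control: flow_error() re-derives flows from candidate masks and
+ # compares them with the network's predicted flows; remove_bad_flow_masks()
+ # then zeroes any label whose mean squared flow error exceeds flow_threshold
+ # (0.4 here), since inconsistent flows usually indicate a spurious mask.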
+
+ def get_masks(p, iscell=None, rpad=20):
+     pflows = []
+     edges = []
+     shape0 = p.shape[1:]
+     dims = len(p)
+
+     for i in range(dims):
+         pflows.append(p[i].flatten().astype("int32"))
+         edges.append(np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1))
+
+     h, _ = np.histogramdd(tuple(pflows), bins=edges)
+     hmax = h.copy()
+     for i in range(dims):
+         hmax = maximum_filter1d(hmax, 5, axis=i)
+
+     seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10))
+     Nmax = h[seeds]
+     isort = np.argsort(Nmax)[::-1]
+     for s in seeds:
+         s = s[isort]
+
+     pix = list(np.array(seeds).T)
+
+     shape = h.shape
+     if dims == 3:
+         expand = np.nonzero(np.ones((3, 3, 3)))
+     else:
+         expand = np.nonzero(np.ones((3, 3)))
+     for e in expand:
+         e = np.expand_dims(e, 1)
+
+     for iter in range(5):
+         for k in range(len(pix)):
+             if iter == 0:
+                 pix[k] = list(pix[k])
+             newpix = []
+             iin = []
+             for i, e in enumerate(expand):
+                 epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
+                 epix = epix.flatten()
+                 iin.append(np.logical_and(epix >= 0, epix < shape[i]))
+                 newpix.append(epix)
+             iin = np.all(tuple(iin), axis=0)
+             for p in newpix:
+                 p = p[iin]
+             newpix = tuple(newpix)
+             igood = h[newpix] > 2
+             for i in range(dims):
+                 pix[k][i] = newpix[i][igood]
+             if iter == 4:
+                 pix[k] = tuple(pix[k])
+
+     M = np.zeros(h.shape, np.uint32)
+     for k in range(len(pix)):
+         M[pix[k]] = 1 + k
+
+     for i in range(dims):
+         pflows[i] = pflows[i] + rpad
+     M0 = M[tuple(pflows)]
+
+     # remove big masks
+     uniq, counts = fastremap.unique(M0, return_counts=True)
+     big = np.prod(shape0) * 0.9
+     bigc = uniq[counts > big]
+     if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
+         M0 = fastremap.mask(M0, bigc)
+     fastremap.renumber(M0, in_place=True)  # convenient to guarantee non-skipped labels
+     M0 = np.reshape(M0, shape0)
+     return M0
+
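+ # get_masks() histograms the converged pixel positions on a padded grid:
+ # pixels of one cell collapse onto the same sink, so local maxima with
+ # h > 10 seed the labels, which are then grown for 5 iterations over
+ # neighboring bins with h > 2; labels covering more than 90% of the image
+ # are discarded as background.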
+ def fill_holes_and_remove_small_masks(masks, min_size=15):
+     """fill holes in masks (2D/3D) and discard masks smaller than min_size (2D)
+
+     fill holes in each mask using scipy.ndimage.morphology.binary_fill_holes
+     (might have issues at borders between cells, todo: check and fix)
+
+     Parameters
+     ----------------
+     masks: int, 2D or 3D array
+         labelled masks, 0=NO masks; 1,2,...=mask labels,
+         size [Ly x Lx] or [Lz x Ly x Lx]
+     min_size: int (optional, default 15)
+         minimum number of pixels per mask, can turn off with -1
+
+     Returns
+     ---------------
+     masks: int, 2D or 3D array
+         masks with holes filled and masks smaller than min_size removed,
+         0=NO masks; 1,2,...=mask labels,
+         size [Ly x Lx] or [Lz x Ly x Lx]
+     """
+
+     slices = find_objects(masks)
+     j = 0
+     for i, slc in enumerate(slices):
+         if slc is not None:
+             msk = masks[slc] == (i + 1)
+             npix = msk.sum()
+             if min_size > 0 and npix < min_size:
+                 masks[slc][msk] = 0
+             elif npix > 0:
+                 if msk.ndim == 3:
+                     for k in range(msk.shape[0]):
+                         msk[k] = binary_fill_holes(msk[k])
+                 else:
+                     msk = binary_fill_holes(msk)
+                 masks[slc][msk] = j + 1
+                 j += 1
+     return masks
+
+
+ def compute_masks(
+     dP,
+     cellprob,
+     p=None,
+     niter=200,
+     cellprob_threshold=0.4,
+     flow_threshold=0.4,
+     interp=True,
+     resize=None,
+     use_gpu=False,
+     device=None,
+ ):
+     """compute masks using dynamics from dP, cellprob, and boundary"""
+
+     cp_mask = cellprob > cellprob_threshold
+     cp_mask = morphology.remove_small_holes(cp_mask, area_threshold=16)
+     cp_mask = morphology.remove_small_objects(cp_mask, min_size=16)
+
+     if np.any(cp_mask):  # mask at this point is a cell cluster binary map, not labels
+         # follow flows
+         if p is None:
+             p, inds = follow_flows(
+                 dP * cp_mask / 5.0,
+                 niter=niter,
+                 interp=interp,
+                 use_gpu=use_gpu,
+                 device=device,
+             )
+             if inds is None:
+                 shape = resize if resize is not None else cellprob.shape
+                 mask = np.zeros(shape, np.uint16)
+                 p = np.zeros((len(shape), *shape), np.uint16)
+                 return mask, p
+
+         # calculate masks
+         mask = get_masks(p, iscell=cp_mask)
+
+         # flow thresholding factored out of get_masks
+         shape0 = p.shape[1:]
+         if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
+             # make sure labels are unique at output of get_masks
+             mask = remove_bad_flow_masks(
+                 mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device
+             )
+
+         mask = fill_holes_and_remove_small_masks(mask, min_size=15)
+
+     else:  # nothing to compute, just make it compatible
+         shape = resize if resize is not None else cellprob.shape
+         mask = np.zeros(shape, np.uint16)
+         p = np.zeros((len(shape), *shape), np.uint16)
+         return mask, p
+
+     return mask, p
+
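+ # End-to-end post-processing: threshold cellprob (> 0.4), follow_flows()
+ # integrates foreground pixels along the flow field, get_masks() clusters
+ # the converged points into labels, remove_bad_flow_masks() drops
+ # flow-inconsistent labels, and fill_holes_and_remove_small_masks() fills
+ # holes and removes labels under 15 px.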
+ def predict(img):
+     # Dataset parameters
+     ### for huggingface space
+     device = "cpu"
+     model_path = "./main_model.pt"
+     model_path2 = "./sub_model.pth"
+     ###
+     model = torch.load(model_path, map_location=device)
+     model.eval()
+     hflip_tta = HorizontalFlip()
+     vflip_tta = VerticalFlip()
+
+     img_name = img.name
+     # if img_name.endswith('.tif') or img_name.endswith('.tiff'):
+     #     img_data = tif.imread(img_name)
+     # else:
+     #     img_data = io.imread(img_name)
+
+     img_data = pred_transforms(img_name)
+     img_data = img_data.to(device)
+     img_size = img_data.shape[-1] * img_data.shape[-2]
+
+     if 900000 < img_size < 1150000:
+         overlap = 0.5
+     else:
+         overlap = 0.6
+
+     with torch.no_grad():
+         img0 = img_data
+         outputs0 = sliding_window_inference(
+             img0,
+             512,
+             4,
+             model,
+             padding_mode="reflect",
+             mode="gaussian",
+             overlap=overlap,
+             device="cpu",
+         )
+         outputs0 = outputs0.cpu().squeeze()
+
+         if img_size < 2000 * 2000:
+
+             model.load_state_dict(torch.load(model_path2, map_location=device))
+             model.eval()
+
+             img2 = hflip_tta.apply_aug_image(img_data, apply=True)
+             outputs2 = sliding_window_inference(
+                 img2,
+                 512,
+                 4,
+                 model,
+                 padding_mode="reflect",
+                 mode="gaussian",
+                 overlap=overlap,
+                 device="cpu",
+             )
+             outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True)
+             outputs2 = outputs2.cpu().squeeze()
+
+             outputs = torch.zeros_like(outputs0)
+             outputs[0] = (outputs0[0] + outputs2[0]) / 2
+             outputs[1] = (outputs0[1] - outputs2[1]) / 2
+             outputs[2] = (outputs0[2] + outputs2[2]) / 2
+
+         elif img_size < 5000 * 5000:
+             # Hflip TTA
+             img2 = hflip_tta.apply_aug_image(img_data, apply=True)
+             outputs2 = sliding_window_inference(
+                 img2,
+                 512,
+                 4,
+                 model,
+                 padding_mode="reflect",
+                 mode="gaussian",
+                 overlap=overlap,
+                 device="cpu",
+             )
+             outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True)
+             outputs2 = outputs2.cpu().squeeze()
+             img2 = img2.cpu()
+
+             ##################
+             #                #
+             #    ensemble    #
+             #                #
+             ##################
+
+             model.load_state_dict(torch.load(model_path2, map_location=device))
+             model.eval()
+
+             img1 = img_data
+             outputs1 = sliding_window_inference(
+                 img1,
+                 512,
+                 4,
+                 model,
+                 padding_mode="reflect",
+                 mode="gaussian",
+                 overlap=overlap,
+                 device="cpu",
+             )
+             outputs1 = outputs1.cpu().squeeze()
+
+             # Vflip TTA
+             img3 = vflip_tta.apply_aug_image(img_data, apply=True)
+             outputs3 = sliding_window_inference(
+                 img3,
+                 512,
+                 4,
+                 model,
+                 padding_mode="reflect",
+                 mode="gaussian",
+                 overlap=overlap,
+                 device="cpu",
+             )
+             outputs3 = vflip_tta.apply_deaug_mask(outputs3, apply=True)
+             outputs3 = outputs3.cpu().squeeze()
+             img3 = img3.cpu()
+
+             # Merge Results
+             outputs = torch.zeros_like(outputs0)
+             outputs[0] = (outputs0[0] + outputs1[0] + outputs2[0] - outputs3[0]) / 4
+             outputs[1] = (outputs0[1] + outputs1[1] - outputs2[1] + outputs3[1]) / 4
+             outputs[2] = (outputs0[2] + outputs1[2] + outputs2[2] + outputs3[2]) / 4
+         else:
+             outputs = outputs0
+
+     pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), device)
+
+     file_path = os.path.join(
+         os.getcwd(), img_name.split(".")[0] + "_label.tiff"
+     )
+
+     tif.imwrite(file_path, pred_mask, compression="zlib")
+     # return img_data, seg_rgb, join(os.getcwd(), 'segmentation.tiff')
+     return img_data, pred_mask, file_path
+
+
+ demo = gr.Interface(
+     predict,
+     # inputs=[gr.Image()],
+     # inputs="file",
+     inputs=[gr.File(label="input image")],
+     outputs=[gr.Image(label="image"), gr.Image(label="segmentation"), gr.File(label="download segmentation")],
+     title="NeurIPS Cellseg MEDIAR",
+ )
+ demo.launch()
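+ # Assumed runtime dependencies for this Space (inferred from the imports
+ # above, not pinned here): torch, monai, segmentation-models-pytorch,
+ # tifffile, scikit-image, scipy, fastremap, numpy, and gradio, plus the
+ # checkpoint files ./main_model.pt and ./sub_model.pth next to app.py.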