# MEDIAR / predict.py
# Uploaded by ghlee94 — commit 2a13495 ("Init"), original size 41.8 kB.
import torch
from torch.nn import (
Module,
Conv2d,
BatchNorm2d,
Identity,
UpsamplingBilinear2d,
Mish,
ReLU,
Sequential,
)
from torch.nn.functional import interpolate, grid_sample, pad
import numpy as np
from copy import deepcopy
import os, argparse, math
import tifffile as tif
from typing import Tuple, List, Mapping
from monai.utils import (
BlendMode,
PytorchPadMode,
convert_data_type,
ensure_tuple,
fall_back_tuple,
look_up_option,
convert_to_dst_type,
)
from monai.utils.misc import ensure_tuple_size, ensure_tuple_rep, issequenceiterable
from monai.networks.layers.convutils import gaussian_1d
from monai.networks.layers.simplelayers import separable_filtering
from segmentation_models_pytorch import MAnet
from skimage.io import imread as io_imread
from skimage.util.dtype import dtype_range
from skimage._shared.utils import _supported_float_type
from scipy.ndimage import find_objects, binary_fill_holes
########################### Data Loading Modules #########################################################
# (min, max) intensity limits keyed both by the dtype objects from skimage's
# ``dtype_range`` and by their ``__name__`` strings, plus 10-, 12- and 14-bit
# integer ranges and the "bool"/"float" aliases.
DTYPE_RANGE = {
    **dtype_range,
    **{dtype.__name__: limits for dtype, limits in dtype_range.items()},
    "uint10": (0, (1 << 10) - 1),
    "uint12": (0, (1 << 12) - 1),
    "uint14": (0, (1 << 14) - 1),
    "bool": dtype_range[bool],
    "float": dtype_range[np.float64],
}
def _output_dtype(dtype_or_range, image_dtype):
    """Return the numpy scalar type that ``rescale_intensity`` should emit.

    Args:
        dtype_or_range: a (min, max) pair (list/tuple/ndarray), a type object,
            or a key of ``DTYPE_RANGE``.
        image_dtype: dtype of the input image; used only for the pair case.

    Returns:
        A float type for a pair, the type itself for a type, the matching
        numpy type for a canonical dtype key, and ``np.uint16`` for the
        non-numpy keys (uint10/uint12/uint14).

    Raises:
        ValueError: for anything else.
    """
    kind = type(dtype_or_range)
    if kind in (list, tuple, np.ndarray):
        # An explicit (min, max) pair always yields floating output.
        return _supported_float_type(image_dtype)
    if kind == type:
        return dtype_or_range
    if dtype_or_range not in DTYPE_RANGE:
        raise ValueError(
            "Incorrect value for out_range, should be a valid image data "
            f"type or a pair of values, got {dtype_or_range}."
        )
    try:
        return np.dtype(dtype_or_range).type
    except TypeError:
        # uint10 / uint12 / uint14 are not numpy dtypes; store them in 16 bits.
        return np.uint16
def intensity_range(image, range_values="image", clip_negative=False):
    """Return the (min, max) intensity pair described by ``range_values``.

    ``"image"`` uses the image's own extrema; ``"dtype"`` or any key of
    ``DTYPE_RANGE`` uses that data type's nominal range; any other value is
    unpacked as an explicit (min, max) pair. ``clip_negative`` forces the
    lower bound to 0, but only for data-type ranges.
    """
    if range_values == "dtype":
        range_values = image.dtype.type
    if range_values == "image":
        return np.min(image), np.max(image)
    if range_values not in DTYPE_RANGE:
        low, high = range_values
        return low, high
    low, high = DTYPE_RANGE[range_values]
    return (0 if clip_negative else low), high
def rescale_intensity(image, in_range="image", out_range="dtype"):
    """Linearly map intensities from ``in_range`` onto ``out_range``.

    The image is first clipped to ``in_range``. When that range is
    degenerate (min == max) the clipped image is instead clipped to
    ``out_range`` and cast. The result dtype comes from ``_output_dtype``.
    """
    out_dtype = _output_dtype(out_range, image.dtype)
    imin, imax = (float(v) for v in intensity_range(image, in_range))
    omin, omax = (
        float(v)
        for v in intensity_range(image, out_range, clip_negative=(imin >= 0))
    )
    clipped = np.clip(image, imin, imax)
    if imin == imax:
        return np.clip(clipped, omin, omax).astype(out_dtype)
    unit = (clipped - imin) / (imax - imin)
    return np.asarray(unit * (omax - omin) + omin, dtype=out_dtype)
def _normalize(img):
    """Rescale ``img`` to uint8 using the non-zero pixels' [min, 99.5th percentile].

    Zero pixels are excluded when estimating the percentiles. An image with no
    non-zero pixels is returned as an all-zero uint8 array; previously
    ``np.percentile`` raised on the empty selection.
    """
    non_zero_vals = img[np.nonzero(img)]
    if non_zero_vals.size == 0:
        return np.zeros(img.shape, dtype=np.uint8)
    percentiles = np.percentile(non_zero_vals, [0, 99.5])
    img_norm = rescale_intensity(
        img, in_range=(percentiles[0], percentiles[1]), out_range="uint8"
    )
    return img_norm.astype(np.uint8)
def pred_transforms(filename):
    """Load an image file and return a normalised (1, C, H, W) float tensor.

    Steps:
    - ``.tif``/``.tiff`` files are read with tifffile, everything else with skimage.
    - Grayscale images are repeated to 3 channels.
    - Channels beyond the third are dropped.
    - The image is percentile-normalised to uint8 (``_normalize``), moved to
      channel-first layout, and min-max scaled to [0, 1].

    NOTE(review): images with exactly 2 trailing channels (or channel-first
    TIFFs) pass through unchanged — confirm that inputs are always HxW or HxWxC
    with C >= 3.
    """
    is_tiff = filename.endswith((".tif", ".tiff"))
    img = tif.imread(filename) if is_tiff else io_imread(filename)

    if img.ndim == 2:
        img = np.repeat(np.expand_dims(img, axis=-1), 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] > 3:
        img = img[:, :, :3]

    img = _normalize(img.astype(np.float32))
    img = np.moveaxis(img, -1, 0)

    img_min, img_max = img.min(), img.max()
    if img_max > img_min:
        img = (img - img_min) / (img_max - img_min)
    else:
        # Constant image: avoid 0/0 (which produced an all-NaN network input).
        img = np.zeros(img.shape, dtype=np.float32)
    return torch.FloatTensor(img).unsqueeze(0)
################################################################################
########################### MODEL Architecture #################################
class SegformerGH(MAnet):
    """MAnet with Mish activations and two dedicated prediction heads.

    The decoder output feeds a 2-channel gradient-flow head and a 1-channel
    cell-probability head. ``forward`` concatenates them along dim 1 in the
    order (gradflow, cellprob).
    """

    def __init__(
        self,
        encoder_name: str = "mit_b5",
        encoder_weights="imagenet",
        decoder_channels=(256, 128, 64, 32, 32),
        decoder_pab_channels=256,
        in_channels: int = 3,
        classes: int = 3,
    ):
        super().__init__(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            decoder_channels=decoder_channels,
            decoder_pab_channels=decoder_pab_channels,
            in_channels=in_channels,
            classes=classes,
        )
        # Replace every ReLU in encoder and decoder with Mish.
        for submodule in (self.encoder, self.decoder):
            convert_relu_to_mish(submodule)

        head_in = decoder_channels[-1]
        self.cellprob_head = DeepSegmantationHead(
            in_channels=head_in, out_channels=1, kernel_size=3
        )
        self.gradflow_head = DeepSegmantationHead(
            in_channels=head_in, out_channels=2, kernel_size=3
        )

    def forward(self, x):
        """Encode and decode ``x``, then return [gradflow, cellprob] stacked on dim 1."""
        self.check_input_shape(x)
        decoded = self.decoder(*self.encoder(x))
        flow = self.gradflow_head(decoded)
        prob = self.cellprob_head(decoded)
        return torch.cat([flow, prob], dim=1)
class DeepSegmantationHead(Sequential):
    """Two-convolution prediction head.

    Layout: conv -> Mish -> BatchNorm -> conv -> (bilinear upsample | identity)
    -> identity. The first convolution halves the channel count. Both use
    ``kernel_size // 2`` padding, so for odd kernels the spatial size is kept
    unless ``upsampling > 1``. Submodule order, and therefore state-dict keys
    "0".."5", is fixed.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        mid_channels = in_channels // 2
        same_pad = kernel_size // 2
        first_conv = Conv2d(
            in_channels, mid_channels, kernel_size=kernel_size, padding=same_pad
        )
        norm = BatchNorm2d(mid_channels)
        second_conv = Conv2d(
            mid_channels, out_channels, kernel_size=kernel_size, padding=same_pad
        )
        if upsampling > 1:
            resize = UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = Identity()
        super().__init__(
            first_conv, Mish(inplace=True), norm, second_conv, resize, Identity()
        )
def convert_relu_to_mish(model):
    """Recursively replace every ``ReLU`` submodule of ``model`` with ``Mish(inplace=True)``."""
    for name, module in model.named_children():
        if not isinstance(module, ReLU):
            convert_relu_to_mish(module)
            continue
        setattr(model, name, Mish(inplace=True))
#####################################################################################
########################### Sliding Window Inference #################################
class GaussianFilter(Module):
    """Separable Gaussian smoothing over ``spatial_dims`` axes.

    Args:
        spatial_dims: number of spatial axes to smooth.
        sigma: standard deviation, either one scalar used for every axis or
            a sequence with exactly one entry per spatial axis.
        truncated: forwarded to MONAI's ``gaussian_1d``.
        approx: forwarded to MONAI's ``gaussian_1d`` (e.g. ``"erf"``).
        requires_grad: whether the per-axis sigmas are trainable.

    Raises:
        ValueError: if ``sigma`` is a sequence whose length is not ``spatial_dims``.

    The input layout is presumably (batch, channel, *spatial), as in
    ``compute_importance_map`` — TODO confirm against MONAI ``separable_filtering``.
    """

    def __init__(
        self, spatial_dims, sigma, truncated=4.0, approx="erf", requires_grad=False,
    ) -> None:
        if issequenceiterable(sigma):
            if len(sigma) != spatial_dims:  # type: ignore
                raise ValueError(
                    f"sigma must have one entry per spatial dim ({spatial_dims}), "
                    f"got {len(sigma)}."  # type: ignore
                )
        else:
            sigma = [deepcopy(sigma) for _ in range(spatial_dims)]  # type: ignore
        super().__init__()
        self.sigma = [
            torch.nn.Parameter(
                torch.as_tensor(
                    s,
                    dtype=torch.float,
                    device=s.device if isinstance(s, torch.Tensor) else None,
                ),
                requires_grad=requires_grad,
            )
            for s in sigma  # type: ignore
        ]
        self.truncated = truncated
        self.approx = approx
        # Register each sigma so it moves with .to() and appears in state_dict.
        for idx, param in enumerate(self.sigma):
            self.register_parameter(f"kernel_sigma_{idx}", param)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Smooth ``x`` with one 1-D Gaussian kernel per spatial axis."""
        _kernel = [
            gaussian_1d(s, truncated=self.truncated, approx=self.approx)
            for s in self.sigma
        ]
        return separable_filtering(x=x, kernels=_kernel)
def compute_importance_map(
    patch_size, mode=BlendMode.CONSTANT, sigma_scale=0.125, device="cpu"
):
    """Return a per-pixel blending weight map of shape ``patch_size``.

    Args:
        patch_size: spatial size of one window.
        mode: ``BlendMode.CONSTANT`` gives uniform weights of 1.
            ``BlendMode.GAUSSIAN`` gives a Gaussian centred on the patch, with
            per-axis sigma ``patch_size * sigma_scale``, normalised to peak 1.
            The original ignored ``mode`` and always returned the Gaussian.
        sigma_scale: scalar or per-axis fraction of the patch size used as sigma.
        device: device for the returned float32 tensor.

    Raises:
        ValueError: for an unknown ``mode`` (raised by ``look_up_option`` or below).
    """
    mode = look_up_option(mode, BlendMode)
    device = torch.device(device)
    if mode == BlendMode.CONSTANT:
        return torch.ones(tuple(patch_size), device=device, dtype=torch.float)
    if mode != BlendMode.GAUSSIAN:
        raise ValueError(f"Unsupported blend mode: {mode}.")

    # Place a unit impulse at the centre and blur it into a Gaussian.
    center_coords = [i // 2 for i in patch_size]
    sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
    sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
    importance_map = torch.zeros(patch_size, device=device)
    importance_map[tuple(center_coords)] = 1
    pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(
        device=device, dtype=torch.float
    )
    importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
    importance_map = importance_map.squeeze(0).squeeze(0)
    importance_map = importance_map / torch.max(importance_map)
    return importance_map.float()
def first(iterable, default=None):
    """Return the first item produced by ``iterable``, or ``default`` if it is empty."""
    return next(iter(iterable), default)
def dense_patch_slices(image_size, patch_size, scan_interval):
    """Enumerate window slices covering ``image_size`` with stride ``scan_interval``.

    Along each axis, windows start every ``scan_interval`` pixels until one
    reaches the image end. The last window is shifted back so it ends exactly
    at the border. A zero interval yields a single window on that axis.

    Returns:
        A list of tuples of ``slice`` objects (one per axis), in row-major
        order over the per-axis start positions.
    """
    ndim = len(image_size)
    patch_size = get_valid_patch_size(image_size, patch_size)
    scan_interval = ensure_tuple_size(scan_interval, ndim)

    per_axis_starts = []
    for size, patch, step in zip(image_size, patch_size, scan_interval):
        if step == 0:
            n_windows = 1
        else:
            candidates = range(int(math.ceil(float(size) / step)))
            last = first(d for d in candidates if d * step + patch >= size)
            n_windows = 1 if last is None else last + 1
        per_axis_starts.append(
            [k * step - max(k * step + patch - size, 0) for k in range(n_windows)]
        )

    grid = np.meshgrid(*per_axis_starts, indexing="ij")
    corners = np.asarray([axis_grid.flatten() for axis_grid in grid]).T
    return [
        tuple(slice(start, start + patch_size[d]) for d, start in enumerate(corner))
        for corner in corners
    ]
def get_valid_patch_size(image_size, patch_size):
    """Clamp ``patch_size`` to ``image_size`` per axis.

    A ``None`` or 0 entry in ``patch_size`` means "use the whole axis".
    """
    requested = ensure_tuple_size(patch_size, len(image_size))
    valid = []
    for dim_size, req in zip(image_size, requested):
        valid.append(min(dim_size, req if req else dim_size))
    return tuple(valid)
class Resize:
    """Nearest-neighbour resize of a channel-first array/tensor to ``spatial_size``."""

    def __init__(self, spatial_size):
        self.size_mode = "all"
        self.spatial_size = spatial_size

    def __call__(self, img):
        """Resize ``img`` (channel, *spatial) and return it in its original container type."""
        n_spatial_in = img.ndim - 1
        n_spatial_out = len(ensure_tuple(self.spatial_size))
        if n_spatial_out > n_spatial_in:
            # Add singleton spatial axes so the ranks match.
            img = img.reshape(ensure_tuple_size(img.shape, n_spatial_out + 1, 1))

        target = fall_back_tuple(self.spatial_size, img.shape[1:])
        if tuple(img.shape[1:]) == target:
            return img

        as_tensor = convert_data_type(img, torch.Tensor, dtype=torch.float)[0]
        resized = interpolate(
            input=as_tensor.unsqueeze(0), size=target, mode="nearest",
        ).squeeze(0)
        return convert_to_dst_type(resized, img)[0]
def sliding_window_inference(
    inputs,
    roi_size,
    sw_batch_size,
    predictor,
    overlap,
    mode=BlendMode.CONSTANT,
    sigma_scale=0.125,
    padding_mode=PytorchPadMode.CONSTANT,
    cval=0.0,
    sw_device=None,
    device=None,
    roi_weight_map=None,
):
    """Run ``predictor`` over overlapping windows of ``inputs`` and blend the results.

    Adapted from MONAI's ``sliding_window_inference``. ``inputs`` is indexed as
    (batch, channel, *spatial). Axes smaller than ``roi_size`` are padded
    (``padding_mode``/``cval``), and the padding is cropped from the output.

    Args:
        inputs: input tensor.
        roi_size: window size; an int is broadcast to every spatial axis.
        sw_batch_size: number of windows passed to ``predictor`` per call.
        predictor: callable returning a tensor, a mapping of tensors, or a
            sequence of tensors for a batch of windows. Outputs may be at a
            different spatial resolution than the windows.
        overlap: fractional overlap between neighbouring windows.
        mode: blending mode for ``compute_importance_map``.
        sigma_scale: Gaussian sigma as a fraction of the window size.
        padding_mode: torch padding mode used when inputs < roi_size.
        cval: fill value for constant padding.
        sw_device: device windows are moved to before ``predictor``
            (default: ``inputs.device``).
        device: device on which outputs are accumulated
            (default: ``inputs.device``; previously ``None`` crashed).
        roi_weight_map: optional precomputed importance map, used when the
            valid patch size equals ``roi_size``.

    Returns:
        A tensor when ``predictor`` returns a tensor. Otherwise a dict (keys
        sorted) or a tuple of tensors.
    """
    if sw_device is None:
        sw_device = inputs.device
    if device is None:
        device = inputs.device

    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    batch_size, _, *image_size_ = inputs.shape
    roi_size = fall_back_tuple(roi_size, image_size_)
    # Working size: at least roi_size along every axis.
    image_size = tuple(
        max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)
    )
    # torch's pad() takes (last_lo, last_hi, second_last_lo, ...), so build it back-to-front.
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = pad(
        inputs,
        pad=pad_size,
        mode=look_up_option(padding_mode, PytorchPadMode).value,
        value=cval,
    )

    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # windows per image
    total_slices = num_win * batch_size

    # Window-level importance (blending) map.
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        importance_map = compute_importance_map(
            valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device
        )
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # Clamp to a small positive floor so no pixel ends up with zero total weight.
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(
        compute_dtype
    )

    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether predictor returns a bare tensor

    for slice_g in range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        # Tuple indices: indexing with a list of slices is deprecated.
        unravel_slice = [
            tuple(
                [slice(idx // num_win, idx // num_win + 1), slice(None)]
                + list(slices[idx % num_win])
            )
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(
            sw_device
        )
        seg_prob_out = predictor(window_data)  # batched patch segmentation

        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)
            # Output-to-input resolution ratio per spatial axis.
            zoom_scale = [
                out_w_i / float(in_w_i)
                for _, out_w_i, in_w_i in zip(
                    image_size, seg_prob.shape[2:], window_data.shape[2:]
                )
            ]

            if _initialized_ss < ss:  # allocate buffers for this output on first sight
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d)
                    for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                output_image_list.append(
                    torch.zeros(output_shape, dtype=compute_dtype, device=device)
                )
                count_map_list.append(
                    torch.zeros(
                        [1, 1] + output_shape[2:], dtype=compute_dtype, device=device
                    )
                )
                _initialized_ss += 1

            # The resized importance map depends only on the output window
            # shape, so compute it once here rather than once per window.
            resizer = Resize(spatial_size=seg_prob.shape[2:])
            importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(
                compute_dtype
            )

            for idx, original_idx in zip(slice_range, unravel_slice):
                original_idx_zoom = list(original_idx)
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    original_idx_zoom[axis] = slice(
                        int(zoomed_start), int(zoomed_end), None
                    )
                original_idx_zoom = tuple(original_idx_zoom)
                output_image_list[ss][original_idx_zoom] += (
                    importance_map_zoom * seg_prob[idx - slice_g]
                )
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0)
                    .unsqueeze(0)
                    .expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # Normalise by the accumulated weights of overlapping windows.
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(
            compute_dtype
        )

    # Crop away the padding added when inputs were smaller than roi_size.
    for ss, output_i in enumerate(output_image_list):
        # Output resolution relative to the padded working size. The original
        # divided by roi_size, which mis-cropped multi-resolution outputs.
        zoom_scale = [
            out_d / img_d for out_d, img_d in zip(output_i.shape[2:], image_size)
        ]
        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(
                pad_size[sp * 2],
                image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2],
            )
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[tuple(final_slicing)]

    if dict_key is not None:
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    return final_output[0] if is_tensor_output else final_output  # type: ignore
def _get_scan_interval(
image_size, roi_size, num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
scan_interval = []
for i in range(num_spatial_dims):
if roi_size[i] == image_size[i]:
scan_interval.append(int(roi_size[i]))
else:
interval = int(roi_size[i] * (1 - overlap))
scan_interval.append(interval if interval > 0 else 1)
return tuple(scan_interval)
#####################################################################################
########################### Main Inference Functions #################################
def post_process(pred_mask, device):
    """Convert raw network output into an instance label map.

    Args:
        pred_mask: array whose first two channels are the flow field (dP) and
            whose last channel is a cell-probability logit (a sigmoid is
            applied here).
        device: torch device forwarded to ``compute_masks``.

    Returns:
        2-D integer label image (0 = background). Images of 5000*5000 pixels
        or more are processed in 2000x2000 tiles. Instances whose area is more
        than 2.7 standard deviations below the mean area are removed.
    """
    dP, cellprob = pred_mask[:2], 1 / (1 + np.exp(-pred_mask[-1]))
    H, W = pred_mask.shape[-2], pred_mask.shape[-1]

    if H * W < 5000 * 5000:
        labels = compute_masks(
            dP,
            cellprob,
            use_gpu=True,
            flow_threshold=0.4,
            device=device,
            cellprob_threshold=0.4,
        )[0]
    else:
        labels = _grid_compute_masks(dP, cellprob, H, W, device)
    return _drop_small_outlier_cells(labels)


def _grid_compute_masks(dP, cellprob, H, W, device, roi_size=2000):
    """Run ``compute_masks`` tile by tile; offset labels so they stay unique across tiles."""
    print("\n[Whole Slide] Grid Prediction starting...")
    n_H = (H + roi_size - 1) // roi_size  # ceil(H / roi_size)
    n_W = (W + roi_size - 1) // roi_size
    new_H, new_W = n_H * roi_size, n_W * roi_size

    # Zero-pad inputs up to a whole number of tiles.
    pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
    dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
    cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)
    dP_pad[:, :H, :W] = dP
    cellprob_pad[:H, :W] = cellprob

    next_label = 0
    for i in range(n_H):
        for j in range(n_W):
            print("Pred on Grid (%d, %d) processing..." % (i, j))
            rows = slice(roi_size * i, roi_size * (i + 1))
            cols = slice(roi_size * j, roi_size * (j + 1))
            tile = compute_masks(
                dP_pad[:, rows, cols],
                cellprob_pad[rows, cols],
                use_gpu=True,
                flow_threshold=0.4,
                device=device,
                cellprob_threshold=0.4,
            )[0]
            tile = np.asarray(tile, dtype=np.uint32)
            # Every tile is numbered from 1. Shift its labels past those
            # already used so distinct cells in different tiles do not merge.
            foreground = tile > 0
            if foreground.any():
                tile[foreground] += next_label
                next_label = int(tile.max())
            pred_pad[rows, cols] = tile
    return pred_pad[:H, :W]


def _drop_small_outlier_cells(labels, n_std=2.7):
    """Zero out (in place) instances whose area is below mean - n_std * std."""
    cell_idx, cell_sizes = np.unique(labels, return_counts=True)
    # Exclude background explicitly; slicing [1:] dropped a real cell when
    # the label image contained no zeros.
    is_cell = cell_idx != 0
    cell_idx, cell_sizes = cell_idx[is_cell], cell_sizes[is_cell]
    if cell_sizes.size == 0:
        return labels
    too_small = cell_sizes < np.mean(cell_sizes) - n_std * np.std(cell_sizes)
    if too_small.any():
        labels[np.isin(labels, cell_idx[too_small])] = 0
    return labels
def hflip(x):
    """Mirror a batch of images along dim 3 (the last axis of an NCHW tensor)."""
    return torch.flip(x, dims=(3,))
def vflip(x):
    """Mirror a batch of images along dim 2 (the row axis of an NCHW tensor)."""
    return torch.flip(x, dims=(2,))
class DualTransform:
    """Base for test-time augmentations applied to an image and undone on the prediction.

    Attributes:
        identity_param: parameter value for which the transform is a no-op
            (``None`` in this base class).
        pname: name of the transform's keyword parameter.
        params: candidate values for that parameter.
    """

    identity_param = None

    def __init__(
        self, name: str, params,
    ):
        self.pname = name
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Apply the augmentation to ``image``; subclasses must override."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Undo the augmentation on a predicted ``mask``; subclasses must override."""
        raise NotImplementedError
class HorizontalFlip(DualTransform):
    """Test-time augmentation that mirrors images left <-> right via ``hflip``."""

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` flipped horizontally when ``apply`` is true."""
        return hflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the flip on ``mask`` (a horizontal flip is its own inverse)."""
        return hflip(mask) if apply else mask
class VerticalFlip(DualTransform):
    """Test-time augmentation that mirrors images top <-> bottom via ``vflip``."""

    identity_param = False

    def __init__(self):
        super().__init__("apply", [False, True])

    def apply_aug_image(self, image, apply=False, **kwargs):
        """Return ``image`` flipped vertically when ``apply`` is true."""
        return vflip(image) if apply else image

    def apply_deaug_mask(self, mask, apply=False, **kwargs):
        """Undo the flip on ``mask`` (a vertical flip is its own inverse)."""
        return vflip(mask) if apply else mask
#################### GradFlow Modules ##################################################
from scipy.ndimage.filters import maximum_filter1d
import scipy.ndimage
import fastremap
from skimage import morphology
from scipy.ndimage import mean
# Fallback devices used by the flow helpers when no device is passed in.
torch_GPU, torch_CPU = torch.device("cuda"), torch.device("cpu")
def _extend_centers_gpu(
neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device("cuda")
):
if device is not None:
device = device
nimg = neighbors.shape[0] // 9
pt = torch.from_numpy(neighbors).to(device)
T = torch.zeros((nimg, Ly, Lx), dtype=torch.double, device=device)
meds = torch.from_numpy(centers.astype(int)).to(device).long()
isneigh = torch.from_numpy(isneighbor).to(device)
for i in range(n_iter):
T[:, meds[:, 0], meds[:, 1]] += 1
Tneigh = T[:, pt[:, :, 0], pt[:, :, 1]]
Tneigh *= isneigh
T[:, pt[0, :, 0], pt[0, :, 1]] = Tneigh.mean(axis=1)
del meds, isneigh, Tneigh
T = torch.log(1.0 + T)
# gradient positions
grads = T[:, pt[[2, 1, 4, 3], :, 0], pt[[2, 1, 4, 3], :, 1]]
del pt
dy = grads[:, 0] - grads[:, 1]
dx = grads[:, 2] - grads[:, 3]
del grads
mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2)
return mu_torch
def diameters(masks):
    """Estimate object diameters from a label image.

    Returns:
        (md, d): ``d`` is sqrt(area) of every non-zero label, in label order.
        ``md`` is the median of ``d`` divided by sqrt(pi)/2, i.e. the diameter
        of a circle with the median area. ``md`` is 0 when there are no objects.
    """
    labels, counts = np.unique(np.int32(masks), return_counts=True)
    # Drop background by value. The original sliced [1:], which discarded a
    # real object whenever the image had no zero pixels.
    counts = counts[labels != 0]
    md = np.median(counts ** 0.5) if counts.size else 0.0
    if np.isnan(md):
        md = 0
    md /= (np.pi ** 0.5) / 2
    return md, counts ** 0.5
def masks_to_flows_gpu(masks, device=None):
    """Compute unit-length flow vectors pointing towards each mask's centre.

    Args:
        masks: 2-D integer label image (0 = background); labels may skip
            values (missing labels are ignored).
        device: torch device for the diffusion (default ``cuda``).

    Returns:
        (mu0, mu_c): ``mu0`` is the (2, Ly, Lx) flow field (dy, dx), zero on
        background. ``mu_c`` is an all-zero array of the same shape, kept for
        interface compatibility.
    """
    if device is None:
        device = torch.device("cuda")
    Ly0, Lx0 = masks.shape
    Ly, Lx = Ly0 + 2, Lx0 + 2
    # Pad by one pixel so every foreground pixel has 8 in-bounds neighbours.
    masks_padded = np.zeros((Ly, Lx), np.int64)
    masks_padded[1:-1, 1:-1] = masks

    # Neighbour coordinates: self, up, down, left, right, then diagonals.
    y, x = np.nonzero(masks_padded)
    neighborsY = np.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), axis=0)
    neighborsX = np.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), axis=0)
    neighbors = np.stack((neighborsY, neighborsX), axis=-1)

    # Centre of each mask: the member pixel closest to the median coordinate.
    slices = scipy.ndimage.find_objects(masks)
    centers = np.zeros((masks.max(), 2), "int")
    for i, si in enumerate(slices):
        if si is None:
            continue
        sr, sc = si
        yi, xi = np.nonzero(masks[sr, sc] == (i + 1))
        yi = yi.astype(np.int32) + 1  # add padding
        xi = xi.astype(np.int32) + 1  # add padding
        ymed = np.median(yi)
        xmed = np.median(xi)
        imin = np.argmin((xi - xmed) ** 2 + (yi - ymed) ** 2)
        centers[i, 0] = yi[imin] + sr.start
        centers[i, 1] = xi[imin] + sc.start

    # A neighbour only counts if it belongs to the same mask as the pixel.
    neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1]]
    isneighbor = neighbor_masks == neighbor_masks[0]

    # Iterations scale with the largest bounding box. Skip missing labels,
    # whose find_objects entry is None (unpacking it used to raise TypeError).
    present = [s for s in slices if s is not None]
    ext = np.array(
        [[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in present]
    )
    n_iter = 2 * (ext.sum(axis=1)).max()

    mu = _extend_centers_gpu(
        neighbors, centers, isneighbor, Ly, Lx, n_iter=n_iter, device=device
    )
    # Normalise to unit length (1e-20 avoids division by zero).
    mu /= 1e-20 + (mu ** 2).sum(axis=0) ** 0.5

    mu0 = np.zeros((2, Ly0, Lx0))
    mu0[:, y - 1, x - 1] = mu
    mu_c = np.zeros_like(mu0)
    return mu0, mu_c
def masks_to_flows(masks, use_gpu=False, device=None):
    """Convert an integer label image into its normalised flow field.

    Args:
        masks: 2-D (Ly, Lx) or 3-D (Lz, Ly, Lx) integer labels, 0 = background.
        use_gpu: run on ``device`` (default ``torch_GPU``) when true, and on
            ``torch_CPU`` when false. Previously ``use_gpu=False`` raised a
            NameError.
        device: torch device used when ``use_gpu`` is true.

    Returns:
        float flows with one channel per mask axis: (2, Ly, Lx) for 2-D input
        and (3, Lz, Ly, Lx) for 3-D input. All zeros when there are no masks
        or only a single labelled pixel.

    Raises:
        ValueError: if a non-empty ``masks`` is neither 2-D nor 3-D.
    """
    if masks.max() == 0 or (masks != 0).sum() == 1:
        # One channel per axis, matching the non-empty outputs below.
        return np.zeros((masks.ndim, *masks.shape), "float32")

    # masks_to_flows_gpu is plain torch code, so it runs on either device.
    if not use_gpu:
        device = torch_CPU
    elif device is None:
        device = torch_GPU

    if masks.ndim == 3:
        Lz, Ly, Lx = masks.shape
        mu = np.zeros((3, Lz, Ly, Lx), np.float32)
        # Accumulate 2-D flows from the three orthogonal slicing directions.
        # Empty planes are skipped (masks_to_flows_gpu cannot handle them).
        for z in range(Lz):
            if masks[z].any():
                mu[[1, 2], z] += masks_to_flows_gpu(masks[z], device=device)[0]
        for y in range(Ly):
            if masks[:, y].any():
                mu[[0, 2], :, y] += masks_to_flows_gpu(masks[:, y], device=device)[0]
        for x in range(Lx):
            if masks[:, :, x].any():
                mu[[0, 1], :, :, x] += masks_to_flows_gpu(
                    masks[:, :, x], device=device
                )[0]
        return mu
    if masks.ndim == 2:
        return masks_to_flows_gpu(masks, device=device)[0]
    raise ValueError("masks_to_flows only takes 2D or 3D arrays")
def steps2D_interp(p, dP, niter, use_gpu=False, device=None):
    """Euler-integrate points through a flow field using bilinear sampling.

    Args:
        p: (2, n_points) array of (y, x) start positions, as passed by
            ``follow_flows``.
        dP: (2, Ly, Lx) flow field (dy, dx).
        niter: number of integration steps.
        use_gpu: when false, always run on the CPU. Previously this path was
            ``assert print("ho")``, which always failed, although the torch
            implementation works on any device.
        device: torch device used when ``use_gpu`` is true (default ``cuda``).

    Returns:
        (2, n_points) array of final (y, x) positions, clamped to the image.
    """
    if not use_gpu:
        device = torch.device("cpu")
    elif device is None:
        device = torch.device("cuda")

    # (X-1, Y-1): extents used to map pixel coordinates onto [-1, 1].
    shape = np.array(dP.shape[1:])[[1, 0]].astype("float") - 1
    # grid_sample expects (x, y) points shaped [1, 1, n_points, 2].
    pt = torch.from_numpy(p[[1, 0]].T).float().to(device).unsqueeze(0).unsqueeze(0)
    im = torch.from_numpy(dP[[1, 0]]).float().to(device).unsqueeze(0)

    # Normalise positions to [0, 1] and flows to the same scale.
    for k in range(2):
        im[:, k, :, :] *= 2.0 / shape[k]
        pt[:, :, :, k] /= shape[k]
    pt = pt * 2 - 1  # to [-1, 1]

    for _ in range(niter):
        dPt = grid_sample(im, pt, align_corners=False)
        for k in range(2):  # step and keep points inside the image
            pt[:, :, :, k] = torch.clamp(pt[:, :, :, k] + dPt[:, k, :, :], -1.0, 1.0)

    # Undo the normalisation.
    pt = (pt + 1) * 0.5
    for k in range(2):
        pt[:, :, :, k] *= shape[k]
    return pt[:, :, :, [1, 0]].cpu().numpy().squeeze().T
def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None):
    """Move every pixel with non-negligible flow along ``dP`` for ``niter`` steps.

    Args:
        dP: (2, Ly, Lx) flow field.
        mask: unused; kept for interface compatibility.
        niter: number of integration steps.
        interp: must be true. Only bilinear interpolation is implemented.
        use_gpu, device: forwarded to ``steps2D_interp``.

    Returns:
        (p, inds): ``p`` is a (2, Ly, Lx) float32 array of final positions,
        where pixels without flow keep their own coordinates. ``inds`` holds the
        (row, col) indices of moved pixels, or ``None`` when fewer than 5
        pixels have |dP[0]| > 1e-3.

    Raises:
        NotImplementedError: if ``interp`` is false. This replaces the
            previous message-less ``assert print(...)`` failure.
    """
    shape = np.array(dP.shape[1:]).astype(np.int32)
    niter = np.uint32(niter)
    p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    p = np.array(p).astype(np.float32)
    inds = np.array(np.nonzero(np.abs(dP[0]) > 1e-3)).astype(np.int32).T
    if inds.ndim < 2 or inds.shape[0] < 5:
        return p, None
    if not interp:
        raise NotImplementedError(
            "follow_flows only supports interp=True (bilinear flow interpolation)."
        )
    p[:, inds[:, 0], inds[:, 1]] = steps2D_interp(
        p[:, inds[:, 0], inds[:, 1]], dP, niter, use_gpu=use_gpu, device=device
    )
    return p, inds
def flow_error(maski, dP_net, use_gpu=False, device=None):
    """Per-mask mean squared error between network flows and flows recomputed from masks.

    Args:
        maski: 2-D integer label image (0 = background).
        dP_net: (2, Ly, Lx) network flows. They are divided by 5 before
            comparison, matching the scaling applied in ``compute_masks``.
        use_gpu, device: forwarded to ``masks_to_flows``.

    Returns:
        (flow_errors, dP_masks): ``flow_errors[i]`` is the error summed over
        both flow channels for label ``i + 1``. ``dP_masks`` are the flows
        recomputed from ``maski``.

    Raises:
        ValueError: if ``dP_net``'s spatial shape differs from ``maski``'s.
            The original printed and returned ``None``, which made callers
            fail later while unpacking.
    """
    if dP_net.shape[1:] != maski.shape:
        raise ValueError(
            f"net flow shape {dP_net.shape[1:]} does not match mask shape {maski.shape}"
        )
    dP_masks = masks_to_flows(maski, use_gpu=use_gpu, device=device)
    n_labels = int(maski.max())
    labels = np.arange(1, n_labels + 1)
    flow_errors = np.zeros(n_labels)
    for i in range(dP_masks.shape[0]):
        flow_errors += mean((dP_masks[i] - dP_net[i] / 5.0) ** 2, maski, index=labels)
    return flow_errors, dP_masks
def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=None):
    """Zero out, in place, every label whose flow error exceeds ``threshold``; return ``masks``."""
    errors = flow_error(masks, flows, use_gpu, device)[0]
    bad_labels = np.nonzero(errors > threshold)[0] + 1  # error index i -> label i + 1
    masks[np.isin(masks, bad_labels)] = 0
    return masks
def get_masks(p, iscell=None, rpad=20):
    """Group final pixel positions into instance labels by histogram clustering.

    Each pixel's end point ``p`` (truncated to int) is binned into a histogram
    padded by ``rpad`` bins per side. Seeds are bins that equal the maximum of
    their 5-wide neighbourhood along every axis and hold more than 10 end
    points. Each seed is grown for 5 iterations into neighbouring bins that
    hold more than 2 end points. Every pixel then takes the label of the
    region its end point falls in.

    Parameters
    ----------
    p : array
        End positions; ``p[i]`` holds coordinate ``i`` of every pixel and
        ``p.shape[1:]`` is the image shape.
    iscell : unused
        NOTE(review): accepted but never read in this function.
    rpad : int
        Histogram padding (in bins) around the image extent.

    Returns
    -------
    M0 : array of shape ``p.shape[1:]``
        Label image, 0 = background. Labels covering more than 90% of the
        image are masked out with ``fastremap.mask``, then labels are
        renumbered in place with ``fastremap.renumber``.
    """
    pflows = []
    edges = []
    shape0 = p.shape[1:]
    dims = len(p)
    for i in range(dims):
        pflows.append(p[i].flatten().astype("int32"))
        # Bin edges at half-integers, so integer position v lands in bin v + rpad.
        edges.append(np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1))
    h, _ = np.histogramdd(tuple(pflows), bins=edges)
    hmax = h.copy()
    for i in range(dims):
        hmax = maximum_filter1d(hmax, 5, axis=i)
    # Seeds: local maxima of the end-point histogram with more than 10 hits.
    seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10))
    Nmax = h[seeds]
    isort = np.argsort(Nmax)[::-1]
    # NOTE(review): rebinding the loop variable does not reorder ``seeds``, so
    # ``isort`` has no effect. Fixing it would change which seed wins where
    # regions overlap (later k overwrite earlier ones below).
    for s in seeds:
        s = s[isort]
    pix = list(np.array(seeds).T)
    shape = h.shape
    # Offsets (0..2 per axis) of a 3x3 (or 3x3x3) neighbourhood.
    if dims == 3:
        expand = np.nonzero(np.ones((3, 3, 3)))
    else:
        expand = np.nonzero(np.ones((3, 3)))
    # NOTE(review): no-op loop; ``expand`` itself is unchanged.
    for e in expand:
        e = np.expand_dims(e, 1)
    # ``iter`` shadows the builtin within this function.
    for iter in range(5):
        for k in range(len(pix)):
            if iter == 0:
                # Turn the seed coordinate into a mutable per-axis list.
                pix[k] = list(pix[k])
            newpix = []
            iin = []
            for i, e in enumerate(expand):
                # Every current pixel shifted by each neighbourhood offset (-1..+1).
                epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
                epix = epix.flatten()
                iin.append(np.logical_and(epix >= 0, epix < shape[i]))
                newpix.append(epix)
            iin = np.all(tuple(iin), axis=0)
            # NOTE(review): no-op loop. The in-bounds filter ``iin`` is never
            # applied, so this relies on the ``rpad`` padding keeping ``epix``
            # inside the histogram. It also rebinds the parameter ``p``
            # (not read again below).
            for p in newpix:
                p = p[iin]
            newpix = tuple(newpix)
            # Keep only neighbour bins with more than 2 end points.
            igood = h[newpix] > 2
            for i in range(dims):
                pix[k][i] = newpix[i][igood]
            if iter == 4:
                pix[k] = tuple(pix[k])
    # Paint each grown region; later seeds overwrite earlier ones where they overlap.
    M = np.zeros(h.shape, np.uint32)
    for k in range(len(pix)):
        M[pix[k]] = 1 + k
    # Look up each pixel's label at its (padded) end-point bin.
    for i in range(dims):
        pflows[i] = pflows[i] + rpad
    M0 = M[tuple(pflows)]
    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * 0.9
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True)  # convenient to guarantee non-skipped labels
    M0 = np.reshape(M0, shape0)
    return M0
def fill_holes_and_remove_small_masks(masks, min_size=15):
    """Fill holes inside each mask and drop masks smaller than ``min_size``.

    Holes are filled with ``scipy.ndimage.binary_fill_holes``, slice by slice
    for 3-D masks. Surviving masks are relabelled consecutively from 1 in
    their original label order. ``masks`` is modified in place and returned.

    Parameters
    ----------
    masks : int array, 2-D [Ly x Lx] or 3-D [Lz x Ly x Lx]
        Labelled masks; 0 = background, 1, 2, ... = mask labels.
    min_size : int, optional (default 15)
        Minimum number of pixels a mask must have to be kept. Use a value
        <= 0 (e.g. -1) to keep every mask.

    Returns
    -------
    masks : int array, same shape as the input
        Masks with holes filled, small masks removed, and labels renumbered.
    """
    next_label = 1
    for idx, bbox in enumerate(find_objects(masks)):
        if bbox is None:
            continue  # label idx + 1 does not occur
        region = masks[bbox] == (idx + 1)
        area = region.sum()
        if min_size > 0 and area < min_size:
            masks[bbox][region] = 0
            continue
        if area == 0:
            continue
        if region.ndim == 3:
            for z in range(region.shape[0]):
                region[z] = binary_fill_holes(region[z])
        else:
            region = binary_fill_holes(region)
        masks[bbox][region] = next_label
        next_label += 1
    return masks
def compute_masks(
    dP,
    cellprob,
    p=None,
    niter=200,
    cellprob_threshold=0.4,
    flow_threshold=0.4,
    interp=True,
    resize=None,
    use_gpu=False,
    device=None,
):
    """Compute instance masks from flows ``dP`` and cell probabilities.

    Pipeline:
    1. Threshold ``cellprob`` and clean the binary map (holes/objects < 16 px).
    2. Follow the flows (``dP / 5``) of foreground pixels, unless end
       positions ``p`` are given.
    3. Cluster end points into labels with ``get_masks``.
    4. Drop masks whose flow error exceeds ``flow_threshold``.
    5. Fill holes and remove masks under 15 pixels.

    Returns:
        (mask, p). When there is no foreground, or too few moving pixels,
        ``mask`` and ``p`` are uint16 zero arrays sized ``resize`` (or
        ``cellprob.shape``).
    """

    def _empty_result():
        out_shape = resize if resize is not None else cellprob.shape
        return (
            np.zeros(out_shape, np.uint16),
            np.zeros((len(out_shape), *out_shape), np.uint16),
        )

    foreground = cellprob > cellprob_threshold
    foreground = morphology.remove_small_holes(foreground, area_threshold=16)
    foreground = morphology.remove_small_objects(foreground, min_size=16)
    if not np.any(foreground):
        return _empty_result()

    if p is None:
        p, inds = follow_flows(
            dP * foreground / 5.0,
            niter=niter,
            interp=interp,
            use_gpu=use_gpu,
            device=device,
        )
        if inds is None:
            return _empty_result()

    mask = get_masks(p, iscell=foreground)
    if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
        mask = remove_bad_flow_masks(
            mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device
        )
    mask = fill_holes_and_remove_small_masks(mask, min_size=15)
    return mask, p
def _predict_windows(model, img, overlap):
    """Run 512x512 Gaussian-blended sliding-window inference (4 windows per batch, reflect padding, stitched on CPU)."""
    return sliding_window_inference(
        img,
        512,
        4,
        model,
        padding_mode="reflect",
        mode="gaussian",
        overlap=overlap,
        device="cpu",
    )


def main(args):
    """Segment every image in ``args.input_path`` and write ``<stem>_label.tiff`` files.

    Ensemble and TTA strategy by pixel count:
    - < 2000*2000: first model, plus the second model on the h-flipped image.
    - < 5000*5000: first model plain and h-flipped, plus the second model
      plain and v-flipped.
    - larger: first model only.

    Fixes relative to the original:
    - The h-flip pass used ``mode="gauusian"``, which ``look_up_option``
      rejects, so it crashed.
    - Loading ``model_path2`` overwrote the first model for every later image.
      Both weight sets are now loaded once and swapped per image.
    """
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(args.model_path, map_location=args.device)
    model.eval()
    # load_state_dict copies into the live parameters in place, so snapshot
    # the first model's weights with clones.
    primary_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    secondary_state = torch.load(args.model_path2, map_location=args.device)

    hflip_tta = HorizontalFlip()
    vflip_tta = VerticalFlip()

    img_names = sorted(os.listdir(args.input_path))
    os.makedirs(args.output_path, exist_ok=True)

    for img_name in img_names:
        print(f"Segmenting {img_name}")
        img_path = os.path.join(args.input_path, img_name)
        img_data = pred_transforms(img_path).to(args.device)
        img_size = img_data.shape[-1] * img_data.shape[-2]
        overlap = 0.5 if 900000 < img_size < 1150000 else 0.6

        with torch.no_grad():
            model.load_state_dict(primary_state)
            outputs0 = _predict_windows(model, img_data, overlap).cpu().squeeze()

            if img_size < 2000 * 2000:
                model.load_state_dict(secondary_state)
                img2 = hflip_tta.apply_aug_image(img_data, apply=True)
                outputs2 = hflip_tta.apply_deaug_mask(
                    _predict_windows(model, img2, overlap), apply=True
                ).cpu().squeeze()
                # An h-flip negates channel 1 (presumably the x-flow), hence the minus.
                outputs = torch.zeros_like(outputs0)
                outputs[0] = (outputs0[0] + outputs2[0]) / 2
                outputs[1] = (outputs0[1] - outputs2[1]) / 2
                outputs[2] = (outputs0[2] + outputs2[2]) / 2
            elif img_size < 5000 * 5000:
                # First model, h-flip TTA.
                img2 = hflip_tta.apply_aug_image(img_data, apply=True)
                outputs2 = hflip_tta.apply_deaug_mask(
                    _predict_windows(model, img2, overlap), apply=True
                ).cpu().squeeze()

                # Second model: plain and v-flip TTA.
                model.load_state_dict(secondary_state)
                outputs1 = _predict_windows(model, img_data, overlap).cpu().squeeze()
                img3 = vflip_tta.apply_aug_image(img_data, apply=True)
                outputs3 = vflip_tta.apply_deaug_mask(
                    _predict_windows(model, img3, overlap), apply=True
                ).cpu().squeeze()

                # Merge; flips negate channel 1 (h-flip) or channel 0 (v-flip).
                outputs = torch.zeros_like(outputs0)
                outputs[0] = (outputs0[0] + outputs1[0] + outputs2[0] - outputs3[0]) / 4
                outputs[1] = (outputs0[1] + outputs1[1] - outputs2[1] + outputs3[1]) / 4
                outputs[2] = (outputs0[2] + outputs1[2] + outputs2[2] + outputs3[2]) / 4
            else:
                outputs = outputs0

        pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), args.device)
        # splitext keeps dots inside the stem ("a.1.tif" -> "a.1"); the old
        # split(".")[0] made such files overwrite each other.
        stem = os.path.splitext(img_name)[0]
        file_path = os.path.join(args.output_path, stem + "_label.tiff")
        tif.imwrite(file_path, pred_mask, compression="zlib")
parser = argparse.ArgumentParser("Submission for Challenge", add_help=False)
parser.add_argument(
    "--model_path",
    default="./model.pt",
    type=str,
    help="pickled full model (first ensemble member)",
)
parser.add_argument(
    "--model_path2",
    default="./model_sec.pth",
    type=str,
    help="state_dict loaded into the same architecture (second ensemble member)",
)
# Dataset parameters
parser.add_argument(
    "-i",
    "--input_path",
    default="/workspace/inputs/",
    type=str,
    help="directory containing the images to segment",
)
parser.add_argument(
    "-o", "--output_path", default="/workspace/outputs/", type=str, help="output path",
)
parser.add_argument("--device", default="cuda:0", type=str)

if __name__ == "__main__":
    # Parse only when run as a script. Parsing at import time consumed the
    # importer's argv, e.g. when importing this module to unpickle SegformerGH.
    args = parser.parse_args()
    print("Starting")
    main(args)