import torch
import celldetection as cd
import cv2
import numpy as np

__all__ = ['contours2labels', 'CpnInterface']
def contours2labels(contours, size, overlap=False, max_iter=999):
    """Convert detected contours to a label image of the given spatial size.

    With ``overlap=False``, overlapping pixels are resolved by dilating the
    unambiguous object cores until every contested pixel is assigned to a
    single label.
    """
    labels = cd.data.contours2labels(cd.asnumpy(contours), size, initial_depth=3)
    if not overlap:
        kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
        mask_sm = np.sum(labels > 0, axis=-1)
        mask = mask_sm > 1  # all overlaps
        if mask.any():
            mask_ = mask_sm == 1  # all cores
            lbl = np.zeros(labels.shape[:2], dtype='float64')
            lbl[mask_] = labels.max(-1)[mask_]
            for _ in range(max_iter):
                lbl_ = np.copy(lbl)
                m = mask & (lbl <= 0)  # overlap pixels that are still unassigned
                if not np.any(m):
                    break
                lbl[m] = cv2.dilate(lbl, kernel=kernel)[m]
                if np.allclose(lbl_, lbl):
                    break
        else:
            lbl = labels.max(-1)
        labels = lbl.astype('int')
    return labels
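
# Usage sketch (illustrative, not part of the original module): for the contours of a
# single image as produced by a CPN model, e.g.
#   label_img = contours2labels(out['contours'][0], img.shape[:2], overlap=False)
# `label_img` is a 2D integer array in which each detected object carries its own label
# value and 0 denotes background; with `overlap=True` the (possibly multi-channel) label
# stack from cd.data.contours2labels is returned unchanged.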
class CpnInterface:
    def __init__(self, model, device=None, **kwargs):
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        model = cd.resolve_model(model, **kwargs)
        if not isinstance(model, cd.models.LitCpn):
            model = cd.models.LitCpn(model)
        self.model = model.to(self.device)
        self.model.eval()
        self.model.requires_grad_(False)
        self.tile_size = 1664  # tile edge length for tiled inference
        self.overlap = 384  # overlap between neighboring tiles
    def __call__(
            self,
            img,
            div=255,
            reduce_labels=True,
            return_labels=True,
            return_viewable_contours=True,
    ):
        if img.ndim == 2:  # grayscale input: replicate to three channels
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = img / div  # normalize intensities by `div`
        x = cd.data.to_tensor(img, transpose=True, dtype=torch.float32)[None]
        with torch.no_grad():
            out = cd.asnumpy(self.model(x, crop_size=self.tile_size,
                                        stride=max(64, self.tile_size - self.overlap)))
        # if torch.cuda.device_count():
        #     print(cd.GpuStats())
        contours, = out['contours']
        boxes, = out['boxes']
        scores, = out['scores']
        labels = None
        if return_labels or return_viewable_contours:
            labels = contours2labels(contours, img.shape[:2], overlap=not reduce_labels)
        return dict(
            contours=contours,
            labels=labels,
            boxes=boxes,
            scores=scores,
        )
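

# Minimal usage sketch (illustrative; not part of the original file). The checkpoint
# identifier and image path below are placeholders and would need to be replaced with a
# real CPN checkpoint (anything accepted by cd.resolve_model) and an actual image file.
if __name__ == '__main__':
    bgr = cv2.imread('example.png')  # placeholder path
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    interface = CpnInterface('path/to/cpn_checkpoint.pt')  # placeholder checkpoint
    out = interface(rgb)
    print('objects:', len(out['contours']), 'label image shape:', out['labels'].shape)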