import os |
import time |
from concurrent import futures |
import h5py |
import numpy as np |
import torch |
from skimage import measure |
from torch import nn |
from tqdm import tqdm |
from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset |
from pytorch3dunet.datasets.utils import SliceBuilder |
from pytorch3dunet.unet3d.model import UNet2D |
from pytorch3dunet.unet3d.utils import get_logger |
# Module-level logger shared by the helpers and predictors defined below.
logger = get_logger('UNetPredictor')
def _get_output_file(dataset, suffix='_predictions', output_dir=None):
    """Build the path of the H5 file that predictions for ``dataset`` are written to.

    The file name is the stem of ``dataset.file_path`` followed by ``suffix`` and an ``.h5``
    extension. It is placed in ``output_dir``, or next to the input file when ``output_dir``
    is None.
    """
    source_dir, base_name = os.path.split(dataset.file_path)
    target_dir = source_dir if output_dir is None else output_dir
    stem, _ext = os.path.splitext(base_name)
    return os.path.join(target_dir, stem + suffix + '.h5')
def _get_dataset_name(config, prefix='predictions'):
    """Return the name of the H5 dataset that predictions are stored under.

    Args:
        config (dict): global config; its ``dest_dataset_name`` entry wins when present.
        prefix (str): fallback name used when ``dest_dataset_name`` is absent.

    Returns:
        str: the dataset name.
    """
    # Previously the fallback was hard-coded to 'predictions', silently ignoring `prefix`.
    return config.get('dest_dataset_name', prefix)
def _is_2d_model(model):
    """Return True when ``model`` is a ``UNet2D``, looking through an ``nn.DataParallel`` wrapper."""
    inner = model.module if isinstance(model, nn.DataParallel) else model
    return isinstance(inner, UNet2D)
class _AbstractPredictor:
    """Common state for predictors: the model, the output directory and the configs.

    Subclasses implement ``__call__(test_loader)``.
    """

    def __init__(self, model, output_dir, config, **kwargs):
        self.config = config
        self.output_dir = output_dir
        self.model = model
        # Any extra keyword arguments are kept as predictor-specific settings.
        self.predictor_config = kwargs

    @staticmethod
    def volume_shape(dataset):
        """Return the spatial shape of ``dataset.raw``.

        A 3-dimensional raw array is returned as-is; otherwise the leading axis
        (presumably channels — confirm against the dataset) is dropped.
        """
        raw = dataset.raw
        return raw.shape if raw.ndim == 3 else raw.shape[1:]

    def __call__(self, test_loader):
        raise NotImplementedError
class StandardPredictor(_AbstractPredictor):
    """
    Applies the model on the given dataset and saves the result as H5 file.
    Predictions from the network are kept in memory. If the results from the network don't fit into RAM
    use `LazyPredictor` instead.

    The output dataset name inside the H5 is given by the `dest_dataset_name` config argument. If the argument is
    not present in the config, 'predictions' is used as the default dataset name.

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        output_dir (str): path to the output directory (optional)
        config (dict): global config dict
    """

    def __init__(self, model, output_dir, config, **kwargs):
        super().__init__(model, output_dir, config, **kwargs)

    def __call__(self, test_loader):
        """Run patch-wise inference over ``test_loader`` and write the averaged predictions to H5.

        Each batch is padded by ``patch_halo``, passed through the model, and the halo is cropped off.
        Patches are accumulated into the prediction map together with a per-voxel count, which
        ``_save_results`` uses to average overlapping patches.
        """
        assert isinstance(test_loader.dataset, AbstractHDF5Dataset)

        logger.info(f"Processing '{test_loader.dataset.file_path}'...")
        start = time.time()

        prediction_channel = self.config.get('prediction_channel', None)
        if prediction_channel is not None:
            logger.info(f"Saving only channel '{prediction_channel}' from the network output")

        logger.info(f'Running inference on {len(test_loader)} batches')

        volume_shape = self.volume_shape(test_loader.dataset)
        out_channels = self.config['model'].get('out_channels')
        # The channel slice is loop-invariant, so compute it once together with the output shape.
        if prediction_channel is None:
            prediction_maps_shape = (out_channels,) + volume_shape
            channel_slice = slice(0, out_channels)
        else:
            prediction_maps_shape = (1,) + volume_shape
            channel_slice = slice(0, 1)
        logger.info(f'The shape of the output prediction maps (CDHW): {prediction_maps_shape}')

        # The model type does not change during inference; check it once instead of per batch.
        is_2d = _is_2d_model(self.model)
        patch_halo = self.predictor_config.get('patch_halo', (4, 4, 4))
        if is_2d:
            # No halo along Z for 2D models.
            patch_halo = list(patch_halo)
            patch_halo[0] = 0

        output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir)
        # The context manager guarantees the H5 file is closed even if inference raises.
        with h5py.File(output_file, 'w') as h5_output_file:
            logger.info('Allocating prediction and normalization arrays...')
            prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape,
                                                                                h5_output_file)

            self.model.eval()
            with torch.no_grad():
                for batch, indices in tqdm(test_loader):
                    if torch.cuda.is_available():
                        batch = batch.cuda(non_blocking=True)

                    # Pad by the halo (reflect mode); the halo is cropped off the prediction again below.
                    batch = _pad(batch, patch_halo)

                    if is_2d:
                        # Drop the (presumably singleton) Z axis for the 2D forward pass, then restore it.
                        prediction = self.model(torch.squeeze(batch, dim=-3))
                        prediction = torch.unsqueeze(prediction, dim=-3)
                    else:
                        prediction = self.model(batch)

                    prediction = _unpad(prediction, patch_halo)
                    prediction = prediction.cpu().numpy()

                    for pred, index in zip(prediction, indices):
                        if prediction_channel is not None:
                            # Keep only the requested channel, with a leading channel axis of size 1.
                            pred = np.expand_dims(pred[prediction_channel], axis=0)
                        index = (channel_slice,) + tuple(index)
                        prediction_map[index] += pred
                        normalization_mask[index] += 1

            logger.info(f'Finished inference in {time.time() - start:.2f} seconds')
            logger.info(f'Saving predictions to: {output_file}')
            self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset)

    def _allocate_prediction_maps(self, output_shape, output_file):
        """Allocate in-memory float32 prediction and uint8 count arrays; ``output_file`` is unused here."""
        prediction_map = np.zeros(output_shape, dtype='float32')
        normalization_mask = np.zeros(output_shape, dtype='uint8')
        return prediction_map, normalization_mask

    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
        """Average the accumulated predictions by their counts and write them to ``output_file``."""
        dataset_name = _get_dataset_name(self.config)
        # NOTE(review): voxels never covered by a patch have a count of 0 and become NaN here.
        prediction_map = prediction_map / normalization_mask
        output_file.create_dataset(dataset_name, data=prediction_map, compression="gzip")
def _pad(m, patch_halo):
    """Reflect-pad the last three axes of ``m`` by ``patch_halo`` = (z, y, x) on both sides.

    Returns ``m`` unchanged when ``patch_halo`` is None.
    """
    if patch_halo is None:
        return m
    halo_z, halo_y, halo_x = patch_halo
    # F.pad takes (left, right) pairs starting from the last axis.
    pad_spec = (halo_x, halo_x, halo_y, halo_y, halo_z, halo_z)
    return nn.functional.pad(m, pad_spec, mode='reflect')
def _unpad(m, patch_halo):
    """Crop the halo added by ``_pad`` from the trailing spatial axes of ``m``.

    Args:
        m: array/tensor whose last axes are spatial.
        patch_halo: (z, y, x) halo sizes, or None to return ``m`` unchanged. When z == 0 only
            the last two axes are cropped (as for 2D models).

    Returns:
        ``m`` with ``patch_halo`` elements removed from both ends of each cropped axis.
    """
    if patch_halo is None:
        return m
    z, y, x = patch_halo

    def _crop(p):
        # slice(p, -p) is slice(0, 0) (empty) for p == 0, so keep the full axis in that case.
        return slice(p, -p) if p > 0 else slice(None)

    if z == 0:
        return m[..., _crop(y), _crop(x)]
    return m[..., _crop(z), _crop(y), _crop(x)]
class LazyPredictor(StandardPredictor):
    """
    Applies the model on the given dataset and saves the result in the `output_file` in the H5 format.
    Predicted patches are directly saved into the H5 and they won't be stored in memory. Since this predictor
    is slower than the `StandardPredictor` it should only be used when the predicted volume does not fit into RAM.

    The output dataset name inside the H5 is given by the `dest_dataset_name` config argument. If the argument is
    not present in the config, 'predictions' is used as the default dataset name.

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        output_dir (str): path to the output directory (optional)
        config (dict): global config dict
    """

    def __init__(self, model, output_dir, config, **kwargs):
        super().__init__(model, output_dir, config, **kwargs)

    def _allocate_prediction_maps(self, output_shape, output_file):
        """Create chunked, gzip-compressed prediction and count datasets directly inside ``output_file``."""
        dataset_name = _get_dataset_name(self.config)
        prediction_map = output_file.create_dataset(dataset_name, shape=output_shape, dtype='float32', chunks=True,
                                                    compression='gzip')
        normalization_mask = output_file.create_dataset('normalization', shape=output_shape, dtype='uint8',
                                                        chunks=True, compression='gzip')
        return prediction_map, normalization_mask

    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
        """Divide the on-disk predictions by their counts block by block, then drop the count dataset."""
        z, y, x = prediction_map.shape[1:]
        # Normalize in roughly 3x3x3 blocks so only one block is loaded at a time. Clamp to 1 so that
        # axes thinner than 3 voxels (e.g. Z == 1 for 2D data) don't yield a zero patch/stride size.
        patch_shape = (max(z // 3, 1), max(y // 3, 1), max(x // 3, 1))
        for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape):
            logger.info(f'Normalizing slice: {index}')
            prediction_map[index] /= normalization_mask[index]
            # Mark the region as normalized so that, if SliceBuilder emits overlapping slices,
            # a later slice does not divide the same voxels again.
            normalization_mask[index] = 1
        del output_file['normalization']
class DSB2018Predictor(_AbstractPredictor):
    """Runs the model on each batch and saves per-file results via ``dsb_save_batch`` in a process pool.

    Args:
        model: trained network used for prediction
        output_dir (str): output directory, or None to write next to each input file
        config (dict): global config dict
        save_segmentation (bool): also save a labelled segmentation of the thresholded predictions
        pmaps_thershold (float): threshold applied to the predictions before labelling
    """

    def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_thershold=0.5, **kwargs):
        super().__init__(model, output_dir, config, **kwargs)
        self.pmaps_thershold = pmaps_thershold
        self.save_segmentation = save_segmentation

    def _slice_from_pad(self, pad):
        """Return a slice removing ``pad`` elements from both ends (the full range when ``pad`` is 0)."""
        if pad == 0:
            return slice(None, None)
        else:
            return slice(pad, -pad)

    def __call__(self, test_loader):
        """Predict every batch and save the results asynchronously, re-raising any saving error."""
        self.model.eval()
        pending = []
        # The context manager shuts the pool down (waiting for workers) even if inference raises.
        with futures.ProcessPoolExecutor(max_workers=32) as executor:
            with torch.no_grad():
                for img, path in test_loader:
                    if torch.cuda.is_available():
                        img = img.cuda(non_blocking=True)
                    # Move results to host memory before handing them to a worker process.
                    pred = self.model(img).cpu().numpy()
                    # Previously `pred` (a required argument) and the saving options were not passed,
                    # so every worker call failed with a TypeError that was never surfaced.
                    pending.append(executor.submit(
                        dsb_save_batch,
                        self.output_dir,
                        path,
                        pred,
                        self.save_segmentation,
                        self.pmaps_thershold,
                    ))
            logger.info('Waiting for all predictions to be saved to disk...')
            for future in pending:
                # Re-raise any exception that occurred in a worker.
                future.result()
def dsb_save_batch(output_dir, path, pred, save_segmentation=True, pmaps_thershold=0.5):
    """Write each prediction in ``pred`` to ``<input stem>_predictions.h5``.

    Args:
        output_dir (str | None): target directory; None writes next to each input path.
        path: input file paths, paired element-wise with ``pred``.
        pred: batch of predictions; each element is squeezed before saving.
        save_segmentation (bool): also store a 'segmentation' dataset of connected-component
            labels (uint16) of ``prediction > pmaps_thershold``.
        pmaps_thershold (float): threshold used for the segmentation.
    """
    for pmap, img_path in zip(pred, path):
        logger.info(f'Processing {img_path}')
        pmap = pmap.squeeze()

        out_file = os.path.splitext(img_path)[0] + '_predictions.h5'
        if output_dir is not None:
            out_file = os.path.join(output_dir, os.path.basename(out_file))

        with h5py.File(out_file, 'w') as f:
            f.create_dataset('predictions', data=pmap, compression='gzip')
            if save_segmentation:
                labels = measure.label(pmap > pmaps_thershold).astype('uint16')
                f.create_dataset('segmentation', data=labels, compression='gzip')