|
from pathlib import Path |
|
from typing import Callable, List, Optional, Tuple |
|
|
|
from monai.transforms import Compose |
|
|
|
from transforms.base import get_image_loading_transform, get_apply_crop_transform, get_stacking_transform |
|
from transforms.mask import get_mask_transform |
|
from transforms.coordinates import get_normalized_coordinates_transform |
|
from transforms.augmentation import * |
|
from transforms.backbone import * |
|
|
|
|
|
def _build_transforms_composition(hparams, transform_getters: List[Callable], *initial_args) -> Tuple[Compose, List[str]]: |
|
""" |
|
Builds a transforms composition from the given functions, which take the hparams and loaded keys as arguments, and |
|
produce a Compose containing the desired transforms. The initialization function receives the provided initial arguments. |
|
""" |
|
transforms = [] |
|
keys = [] |
|
|
|
for i in range(0, len(transform_getters)): |
|
if len(keys) == 0: |
|
assert i == 0, f"Function {transform_getters[i]} did not yield any loaded keys." |
|
|
|
transform, keys = transform_getters[0](hparams, *initial_args) |
|
else: |
|
transform, keys = transform_getters[i](hparams, keys) |
|
transforms.append(transform) |
|
|
|
return Compose(transforms), keys |
|
|
|
def _get_config_transform_by_name(transform_name: str) -> Callable: |
|
if transform_name == "intensity": |
|
return intensity_transform |
|
elif transform_name.startswith("spatial3d"): |
|
if "simple" in transform_name: |
|
return lambda hparams, loaded_keys: spatial_transform(hparams, loaded_keys, mode='simple') |
|
else: |
|
return lambda hparams, loaded_keys: spatial_transform(hparams, loaded_keys, mode='default') |
|
elif transform_name == "modelsgenesis": |
|
return models_genesis_transform |
|
elif transform_name == "pretrained_resnet": |
|
return pretrained_resnet_transform |
|
elif transform_name == "robustness": |
|
return robustness_transform |
|
else: |
|
raise ValueError(f"Unknown transform: {transform_name}") |
|
|
|
def get_training_transforms(hparams, image_dir: Path, mask_dir: Optional[Path] = None) -> Compose: |
|
transforms_base = [get_image_loading_transform, get_mask_transform] |
|
|
|
|
|
|
|
if "robustness" in hparams.transforms: transforms_base.append(_get_config_transform_by_name("robustness")) |
|
|
|
transforms_base.extend([get_apply_crop_transform, get_normalized_coordinates_transform]) |
|
|
|
|
|
preprocessing_transforms = ["modelsgenesis", "pretrained_resnet"] |
|
config_transforms = [_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if transform_name in preprocessing_transforms] |
|
|
|
|
|
exclusion_criterion = lambda transform_name: transform_name in preprocessing_transforms or transform_name == "robustness" |
|
config_transforms.extend([_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if not exclusion_criterion]) |
|
|
|
|
|
return _build_transforms_composition(hparams, transforms_base + config_transforms + [get_stacking_transform], image_dir, mask_dir)[0] |
|
|
|
def get_base_transforms(hparams, image_dir: Path, mask_dir: Optional[Path] = None) -> Compose: |
|
transforms_base = [get_image_loading_transform, get_mask_transform, get_apply_crop_transform, get_normalized_coordinates_transform] |
|
|
|
|
|
preprocessing_transforms = ["modelsgenesis", "pretrained_resnet"] |
|
config_transforms = [_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if transform_name in preprocessing_transforms] |
|
|
|
return _build_transforms_composition(hparams, transforms_base + config_transforms + [get_stacking_transform], image_dir, mask_dir)[0] |