import numpy as np | |
from monai.transforms import MapTransform | |
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the necrotic and non-enhancing tumor core
    label 2 is the peritumoral edema
    label 4 is the GD-enhancing tumor
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    For every selected key, the label map ``d[key]`` is replaced by a
    float32 array of three binary masks stacked along a new leading axis,
    in this order:

    * channel 0 - TC: voxels with label 1 or 4
    * channel 1 - WT: voxels with label 1, 2 or 4
    * channel 2 - ET: voxels with label 4

    An input of shape ``S`` therefore yields an output of shape ``(3, *S)``.
    Any other label value (e.g. 0) is 0 in every channel.

    A key absent from the input is skipped when the transform was built with
    ``allow_missing_keys=True``; otherwise a ``KeyError`` is raised.
    """

    def __call__(self, data):
        """Return a shallow copy of ``data`` with each selected label map
        converted to TC/WT/ET channels; the input mapping is not modified."""
        d = dict(data)
        # Same skip rule as MapTransform.key_iterator, read via getattr so the
        # transform also works on MONAI releases without allow_missing_keys.
        allow_missing = getattr(self, "allow_missing_keys", False)
        for key in self.keys:
            if allow_missing and key not in d:
                continue
            label = d[key]  # raises KeyError for a missing key, as before
            # Compute each comparison once and reuse it for the derived masks.
            necrotic = label == 1
            edema = label == 2
            enhancing = label == 4
            # merge label 1 and label 4 to construct TC
            tumor_core = np.logical_or(necrotic, enhancing)
            # merge labels 1, 2 and 4 to construct WT
            whole_tumor = np.logical_or(tumor_core, edema)
            # label 4 alone is ET
            d[key] = np.stack(
                [tumor_core, whole_tumor, enhancing], axis=0
            ).astype(np.float32)
        return d