#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import matplotlib.pyplot as plt import numpy as np import scipy from matplotlib.colors import ListedColormap from PIL import Image def extract_axial_mid_slice(ct, mask, crop=True): slice_idx = np.argmax(mask.sum(axis=(0, 1))) ct_slice_z = np.transpose(ct[:, :, slice_idx], axes=(1, 0)) mask_slice_z = np.transpose(mask[:, :, slice_idx], axes=(1, 0)) ct_slice_z = np.flip(ct_slice_z, axis=(0, 1)) mask_slice_z = np.flip(mask_slice_z, axis=(0, 1)) if crop: ct_range_x = np.where(ct_slice_z.max(axis=0) > -200)[0][[0, -1]] ct_slice_z = ct_slice_z[ ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1] ] mask_slice_z = mask_slice_z[ ct_range_x[0] : ct_range_x[1], ct_range_x[0] : ct_range_x[1] ] return ct_slice_z, mask_slice_z def extract_coronal_mid_slice(ct, mask, crop=True): # find the slice with max coherent extent of the organ coronary_extent = np.where(mask.sum(axis=(0, 2)))[0] max_extent = 0 max_extent_idx = 0 for idx in coronary_extent: label, num_features = scipy.ndimage.label(mask[:, idx, :]) if num_features > 1: continue else: extent = len(np.where(label.sum(axis=1))[0]) if extent > max_extent: max_extent = extent max_extent_idx = idx ct_slice_y = np.transpose(ct[:, max_extent_idx, :], axes=(1, 0)) mask_slice_y = np.transpose(mask[:, max_extent_idx, :], axes=(1, 0)) ct_slice_y = np.flip(ct_slice_y, axis=1) mask_slice_y = np.flip(mask_slice_y, axis=1) return ct_slice_y, mask_slice_y def save_slice( ct_slice, mask_slice, path, figsize=(12, 12), corner_text=None, unit_dict=None, aspect=1, show=False, xy_placement=None, class_color=1, fontsize=14, ): # colormap for shown segmentations color_array = plt.get_cmap("tab10")(range(10)) color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:, :]), axis=0) map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array) fig, axx = plt.subplots(1, figsize=figsize, frameon=False) axx.imshow( ct_slice, cmap="gray", vmin=-400, vmax=400, interpolation="spline36", aspect=aspect, origin="lower", ) axx.imshow( mask_slice * class_color, cmap=map_object_seg, vmin=0, vmax=9, alpha=0.2, interpolation="nearest", aspect=aspect, origin="lower", ) plt.axis("off") axx.axes.get_xaxis().set_visible(False) axx.axes.get_yaxis().set_visible(False) y_size, x_size = ct_slice.shape if corner_text is not None: bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5) texts = [] for k, v in corner_text.items(): if isinstance(v, str): texts.append("{:<9}{}".format(k + ":", v)) else: unit = unit_dict[k] if k in unit_dict else "" texts.append("{:<9}{:.0f} {}".format(k + ":", v, unit)) if xy_placement is None: # get the extent of textbox, remove, and the plot again with correct position t = axx.text( 0.5, 0.5, "\n".join(texts), color="white", transform=axx.transAxes, fontsize=fontsize, family="monospace", bbox=bbox_props, va="top", ha="left", ) xmin, xmax = t.get_window_extent().xmin, t.get_window_extent().xmax xmin, xmax = axx.transAxes.inverted().transform((xmin, xmax)) xy_placement = [1 - (xmax - xmin) - (xmax - xmin) * 0.09, 0.975] t.remove() axx.text( xy_placement[0], xy_placement[1], "\n".join(texts), color="white", transform=axx.transAxes, fontsize=fontsize, family="monospace", bbox=bbox_props, va="top", ha="left", ) if show: plt.show() else: fig.savefig(path, bbox_inches="tight", pad_inches=0) plt.close(fig) def slicedDilationOrErosion(input_mask, num_iteration, operation): """ Perform the dilation on the smallest slice that will fit the segmentation """ margin = 2 if num_iteration is None else num_iteration + 1 # find the minimum volume enclosing the organ x_idx = np.where(input_mask.sum(axis=(1, 2)))[0] x_start, x_end = x_idx[0] - margin, x_idx[-1] + margin y_idx = np.where(input_mask.sum(axis=(0, 2)))[0] y_start, y_end = y_idx[0] - margin, y_idx[-1] + margin z_idx = np.where(input_mask.sum(axis=(0, 1)))[0] z_start, z_end = z_idx[0] - margin, z_idx[-1] + margin struct = scipy.ndimage.generate_binary_structure(3, 1) struct = scipy.ndimage.iterate_structure(struct, num_iteration) if operation == "dilate": mask_slice = scipy.ndimage.binary_dilation( input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct ).astype(np.int8) elif operation == "erode": mask_slice = scipy.ndimage.binary_erosion( input_mask[x_start:x_end, y_start:y_end, z_start:z_end], structure=struct ).astype(np.int8) output_mask = input_mask.copy() output_mask[x_start:x_end, y_start:y_end, z_start:z_end] = mask_slice return output_mask def extract_organ_metrics( ct, all_masks, class_num=None, vol_per_pixel=None, erode_mask=True ): if erode_mask: eroded_mask = slicedDilationOrErosion( input_mask=(all_masks == class_num), num_iteration=3, operation="erode" ) ct_organ_vals = ct[eroded_mask == 1] else: ct_organ_vals = ct[all_masks == class_num] results = {} # in ml organ_vol = (all_masks == class_num).sum() * vol_per_pixel organ_mean = ct_organ_vals.mean() organ_median = np.median(ct_organ_vals) results = { "Organ": class_map_part_organs[class_num], "Volume": organ_vol, "Mean": organ_mean, "Median": organ_median, } return results def generate_slice_images( ct, all_masks, class_nums, unit_dict, vol_per_pixel, pix_dims, root, fontsize=20, show=False, ): all_results = {} colors = [1, 3, 4] for i, c_num in enumerate(class_nums): organ_name = class_map_part_organs[c_num] axial_path = os.path.join(root, organ_name.lower() + "_axial.png") coronal_path = os.path.join(root, organ_name.lower() + "_coronal.png") ct_slice_z, liver_slice_z = extract_axial_mid_slice(ct, all_masks == c_num) results = extract_organ_metrics( ct, all_masks, class_num=c_num, vol_per_pixel=vol_per_pixel ) save_slice( ct_slice_z, liver_slice_z, axial_path, figsize=(12, 12), corner_text=results, unit_dict=unit_dict, class_color=colors[i], fontsize=fontsize, show=show, ) ct_slice_y, liver_slice_y = extract_coronal_mid_slice(ct, all_masks == c_num) save_slice( ct_slice_y, liver_slice_y, coronal_path, figsize=(12, 12), aspect=pix_dims[2] / pix_dims[1], show=show, class_color=colors[i], ) all_results[results["Organ"]] = results if show: return return all_results def generate_liver_spleen_pancreas_report(root, organ_names): axial_imgs = [ Image.open(os.path.join(root, organ + "_axial.png")) for organ in organ_names ] coronal_imgs = [ Image.open(os.path.join(root, organ + "_coronal.png")) for organ in organ_names ] result_width = max( sum([img.size[0] for img in axial_imgs]), sum([img.size[0] for img in coronal_imgs]), ) result_height = max( [a.size[1] + c.size[1] for a, c in zip(axial_imgs, coronal_imgs)] ) result = Image.new("RGB", (result_width, result_height)) total_width = 0 for a_img, c_img in zip(axial_imgs, coronal_imgs): a_width, a_height = a_img.size c_width, c_height = c_img.size translate = (a_width - c_width) // 2 if a_width > c_width else 0 result.paste(im=a_img, box=(total_width, 0)) result.paste(im=c_img, box=(translate + total_width, a_height)) total_width += a_width result.save(os.path.join(root, "liver_spleen_pancreas_report.png")) # from https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/map_to_binary.py class_map_part_organs = { 1: "Spleen", 2: "Right Kidney", 3: "Left Kidney", 4: "Gallbladder", 5: "Liver", 6: "Stomach", 7: "Aorta", 8: "Inferior vena cava", 9: "portal Vein and Splenic Vein", 10: "Pancreas", 11: "Right Adrenal Gland", 12: "Left Adrenal Gland Left", 13: "lung_upper_lobe_left", 14: "lung_lower_lobe_left", 15: "lung_upper_lobe_right", 16: "lung_middle_lobe_right", 17: "lung_lower_lobe_right", }