# Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py
"""Helpers for the Streamlit demo: run inference, map canvas boxes, render slices."""
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
from PIL import Image
from scipy import ndimage

import streamlit as st

# Canvas state containing a single rectangle object (left/top 50, 100 x 100).
initial_rectangle = {
    "version": "4.4.0",
    "objects": [
        {
            "type": "rect",
            "version": "4.4.0",
            "originX": "left",
            "originY": "top",
            "left": 50,
            "top": 50,
            "width": 100,
            "height": 100,
            "fill": "rgba(255, 165, 0, 0.3)",
            "stroke": "#2909F1",
            "strokeWidth": 3,
            "strokeDashArray": None,
            "strokeLineCap": "butt",
            "strokeDashOffset": 0,
            "strokeLineJoin": "miter",
            "strokeUniform": True,
            "strokeMiterLimit": 4,
            "scaleX": 1,
            "scaleY": 1,
            "angle": 0,
            "flipX": False,
            "flipY": False,
            "opacity": 1,
            "shadow": None,
            "visible": True,
            "backgroundColor": "",
            "fillRule": "nonzero",
            "paintFirst": "fill",
            "globalCompositeOperation": "source-over",
            "skewX": 0,
            "skewY": 0,
            "rx": 0,
            "ry": 0,
        }
    ],
}


def run(tmpdirname):
    """Segment the selected image and store the prediction in the session state.

    Does nothing when ``st.session_state.option`` is ``None``. Otherwise the
    selected file (resolved relative to this module's directory) is handed to
    ``inference.infer`` together with ``tmpdirname`` and the configured folds.
    The segmentation file is then read back from ``tmpdirname`` and stored as
    ``st.session_state.preds_3D`` (array produced by ``read_image``) and
    ``st.session_state.preds_3D_ori`` (the SimpleITK image as read from disk).

    Args:
        tmpdirname: Directory passed to ``inference.infer`` and from which the
            ``*seg*`` file is loaded afterwards. May be a ``str`` or a ``Path``.
    """
    if st.session_state.option is None:
        return

    image = Path(__file__).parent / str(st.session_state.option)
    inference.infer([image], tmpdirname, st.session_state.folds, split_level=1)

    seg_name = add_postfix(image.name, "seg")
    # pathlib join works for both str and Path inputs, unlike ``+ "/" +``.
    preds_path = str(Path(tmpdirname) / seg_name)

    # The prediction is a label map: resampling it with the default linear
    # interpolator would blend neighbouring label values into labels that do
    # not exist, so nearest-neighbour is requested explicitly.
    st.session_state.preds_3D = read_image(preds_path, interpolator=sitk.sitkNearestNeighbor)
    st.session_state.preds_3D_ori = sitk.ReadImage(preds_path)


def reflect_box_into_model(box_3d):
    """Rescale a ``(z1, y1, x1, z2, y2, x2)`` box into prompt coordinates.

    The x and y entries are multiplied by ``256 / 325`` and the z entries by
    ``32 / 325``; every result is truncated with ``int``.
    NOTE(review): 325 is presumably the canvas extent and 256/32 the model
    input size — confirm against the caller.

    Args:
        box_3d: Iterable of exactly six numbers in z1, y1, x1, z2, y2, x2 order.

    Returns:
        torch.Tensor holding the six rescaled integers in the same order.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    corners = ((z1, y1, x1), (z2, y2, x2))

    prompt = []
    for z, y, x in corners:
        prompt.append(int(z * 32.0 / 325.0))
        prompt.append(int(y * 256.0 / 325.0))
        prompt.append(int(x * 256.0 / 325.0))
    return torch.tensor(np.array(prompt))


def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first canvas rectangle into ``st.session_state.rectangle_3Dbox``.

    Only ``view == "xy"`` is handled: entries 1/2 receive the rectangle's
    top/left, entries 4/5 receive top/left plus the scaled height/width. The
    updated box is printed afterwards. Any other view leaves the box untouched.

    Args:
        json_data: Mapping with an ``"objects"`` list whose first element
            provides ``top``, ``left``, ``width``, ``height``, ``scaleX`` and
            ``scaleY``.
        view: Name of the view the rectangle was drawn in.
    """
    if view == "xy":
        rect = json_data["objects"][0]
        top, left = rect["top"], rect["left"]
        box = st.session_state.rectangle_3Dbox

        box[1] = top
        box[2] = left
        box[4] = top + rect["height"] * rect["scaleY"]
        box[5] = left + rect["width"] * rect["scaleX"]
        print(st.session_state.rectangle_3Dbox)


def make_fig(image, preds, px_range=(10, 400), transparency=0.5):
    """Render a 2D slice, optionally overlaid with labels, as an RGB PIL image.

    Args:
        image: 2D array; it is clipped to ``px_range`` and shown in grey scale.
        preds: Optional 2D label array of the same shape. Pixels with a value
            above 0.1 are drawn with the "jet" colormap (scaled to 0..40) at
            ``transparency`` opacity, and label borders are drawn on top in a
            darkened variant of the same colormap.
        px_range: ``(min, max)`` used both for clipping and as display window.
        transparency: Opacity of the filled label overlay.

    Returns:
        PIL.Image.Image in mode "RGB" containing the rendered 4x4 inch figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        ax.imshow(
            image.clip(*px_range),
            cmap="Greys_r",
            vmin=px_range[0],
            vmax=px_range[1],
        )

        if preds is not None:
            labels = np.array(preds)

            fill_alpha = np.zeros(labels.shape)
            fill_alpha[labels > 0.1] = transparency
            ax.imshow(labels, cmap="jet", alpha=fill_alpha, vmin=0, vmax=40)

            # Label borders: a non-zero Laplacian of the single-label image
            # marks the pixels where that label meets something else.
            edge_slice = np.zeros(labels.shape, dtype=int)
            for label in np.unique(labels):
                single_label = labels.copy()
                single_label[single_label != label] = 0
                edge_slice[ndimage.laplace(single_label) != 0] = label

            # Use the converted array (works for non-ndarray input too) and
            # keep at least one colour so a slice without labels still yields
            # a usable colormap.
            n_colors = max(1, int(labels.max()))
            edge_colors = mpl.cm.jet(np.linspace(0, 1, n_colors))
            # Shift every RGBA channel down by 0.4 and clamp to [0, 1].
            edge_colors = (edge_colors - 0.4).clip(0, 1)

            edge_alpha = np.zeros(edge_slice.shape)
            edge_alpha[edge_slice > 0.01] = 0.9
            ax.imshow(
                edge_slice,
                alpha=edge_alpha,
                cmap=mpl.colors.ListedColormap(edge_colors),
                vmin=0,
                vmax=40,
            )

        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.canvas.draw()

        # transform to image
        # ``tostring_rgb`` was removed from recent Matplotlib releases; the
        # RGBA buffer with its alpha channel dropped holds the same pixels.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(np.ascontiguousarray(rgba[..., :3]))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # without this each call would leak one figure.
        plt.close(fig)


#######################################


def make_isotropic(image, interpolator=sitk.sitkLinear, spacing=None):
    """
    Many file formats (e.g. jpg, png,...) expect the pixels to be isotropic, same
    spacing for all axes. Saving non-isotropic data in these formats will result in
    distorted images. This function makes an image isotropic via resampling, if needed.
    Args:
        image (SimpleITK.Image): Input image.
        interpolator: By default the function uses a linear interpolator. For
                      label images one should use the sitkNearestNeighbor interpolator
                      so as not to introduce non-existent labels.
        spacing (float): Desired spacing. If none given then use the smallest spacing from
                         the original image.
    Returns:
        SimpleITK.Image with isotropic spacing which occupies the same region in space as
        the input image.
    """
    original_spacing = image.GetSpacing()

    # Image is already isotropic, just return a copy.
    if all(spc == original_spacing[0] for spc in original_spacing):
        return sitk.Image(image)

    # Make image isotropic via resampling.
    if spacing is None:
        spacing = min(original_spacing)

    # Keep the physical extent: size * spacing stays (approximately) constant.
    new_size = [
        int(round(size * spc / spacing))
        for size, spc in zip(image.GetSize(), original_spacing)
    ]
    return sitk.Resample(
        image,
        new_size,
        sitk.Transform(),
        interpolator,
        image.GetOrigin(),
        [spacing] * image.GetDimension(),
        image.GetDirection(),
        0,  # default pixel value
        image.GetPixelID(),
    )


def read_image(path, interpolator=sitk.sitkLinear):
    """Read an image, reorient it to LPS, make it isotropic and return an array.

    Args:
        path: File to read with ``sitk.ReadImage``.
        interpolator: Interpolator forwarded to ``make_isotropic``. Pass
            ``sitk.sitkNearestNeighbor`` for label images; the default matches
            the previous behaviour (linear).

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` for the resampled image.
    """
    img = sitk.DICOMOrient(sitk.ReadImage(path), "LPS")
    img = make_isotropic(img, interpolator=interpolator)
    return sitk.GetArrayFromImage(img)