# MRSegmentator / utils.py
# DiGuaQiu's picture
# Update utils.py
# 58ca1a1 verified
# raw
# history blame
# 6.29 kB
# Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
from PIL import Image
from scipy import ndimage
import streamlit as st
# Default drawing state holding a single semi-transparent orange rectangle at
# (left=50, top=50) with size 100x100. The key layout ("version", "objects",
# per-object "type"/"originX"/...) looks like a Fabric.js canvas serialization
# (e.g. an initial drawing for a Streamlit drawable canvas) -- NOTE(review): confirm
# against the consumer. Key order matches the original literal.
initial_rectangle = dict(
    version="4.4.0",
    objects=[
        dict(
            type="rect",
            version="4.4.0",
            originX="left",
            originY="top",
            left=50,
            top=50,
            width=100,
            height=100,
            fill="rgba(255, 165, 0, 0.3)",
            stroke="#2909F1",
            strokeWidth=3,
            strokeDashArray=None,
            strokeLineCap="butt",
            strokeDashOffset=0,
            strokeLineJoin="miter",
            strokeUniform=True,
            strokeMiterLimit=4,
            scaleX=1,
            scaleY=1,
            angle=0,
            flipX=False,
            flipY=False,
            opacity=1,
            shadow=None,
            visible=True,
            backgroundColor="",
            fillRule="nonzero",
            paintFirst="fill",
            globalCompositeOperation="source-over",
            skewX=0,
            skewY=0,
            rx=0,
            ry=0,
        )
    ],
)
def run(tmpdirname):
    """Segment the currently selected example volume and cache the result in session state.

    Does nothing when ``st.session_state.option`` is None. Otherwise the selected
    file (resolved relative to this module's directory) is passed to
    ``inference.infer``. The prediction it writes into *tmpdirname* (named via
    ``add_postfix(<image name>, "seg")``) is then loaded into:

    * ``st.session_state.preds_3D`` -- NumPy array of the prediction after LPS
      reorientation and isotropic resampling.
    * ``st.session_state.preds_3D_ori`` -- the prediction exactly as read from disk.

    Args:
        tmpdirname: Output directory for the inference step (``str`` or path-like).
    """
    if st.session_state.option is None:
        return

    image = Path(__file__).parent / str(st.session_state.option)
    inference.infer([image], tmpdirname, st.session_state.folds, split_level=1)

    preds_path = str(Path(tmpdirname) / add_postfix(image.name, "seg"))
    # Read the file once; DICOMOrient/Resample return new images, so the raw copy
    # stored below is not modified by the processing.
    preds_ori = sitk.ReadImage(preds_path)

    # The prediction is a label image, so it must be resampled with nearest-neighbour
    # interpolation: make_isotropic's default (linear) blends neighbouring label
    # values and invents labels along region borders. That is why read_image() is
    # not used here.
    oriented = sitk.DICOMOrient(preds_ori, "LPS")
    isotropic = make_isotropic(oriented, interpolator=sitk.sitkNearestNeighbor)
    st.session_state.preds_3D = sitk.GetArrayFromImage(isotropic)
    st.session_state.preds_3D_ori = preds_ori
def reflect_box_into_model(box_3d, *, canvas_size=325.0, xy_size=256.0, z_size=32.0):
    """Rescale a 3-D box from canvas coordinates to the model's prompt grid.

    Each coordinate ``v`` is mapped to ``int(v * target / canvas_size)``. The
    ``target`` is *xy_size* for x/y and *z_size* for z. ``int`` truncates toward
    zero. The defaults reproduce the original hard-coded 325 -> 256 (x, y) and
    325 -> 32 (z) scaling. NOTE(review): 325 is presumably the drawing-canvas
    size in pixels; confirm against the UI code.

    Args:
        box_3d: Six numbers unpacked as ``(z1, y1, x1, z2, y2, x2)``.
        canvas_size: Extent of the source coordinate system along every axis.
        xy_size: Target extent for the x and y axes.
        z_size: Target extent for the z axis.

    Returns:
        torch.Tensor of six integers in the same ``(z1, y1, x1, z2, y2, x2)`` order.
    """
    z1, y1, x1, z2, y2, x2 = box_3d

    def _scale(value, target):
        # Multiply before dividing to keep the original floating-point evaluation order.
        return int(value * target / canvas_size)

    prompt = [
        _scale(z1, z_size),
        _scale(y1, xy_size),
        _scale(x1, xy_size),
        _scale(z2, z_size),
        _scale(y2, xy_size),
        _scale(x2, xy_size),
    ]
    return torch.tensor(np.array(prompt))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle's extent into ``st.session_state.rectangle_3Dbox``.

    Only ``view == "xy"`` updates the box. Slots 1 and 2 receive the rectangle's
    ``top`` and ``left``. Slots 4 and 5 receive its bottom and right edges
    (``top + height * scaleY`` and ``left + width * scaleX``). Slots 0 and 3 are
    left untouched. The box is printed afterwards.

    Args:
        json_data: Drawing state whose ``["objects"][0]`` entry provides
            ``top``/``left``/``width``/``height``/``scaleX``/``scaleY``.
        view: Name of the view the rectangle was drawn in.
    """
    if view == "xy":
        rect = json_data["objects"][0]
        box = st.session_state.rectangle_3Dbox
        box[1] = rect["top"]
        box[2] = rect["left"]
        # Scaled extents turn the rectangle's size into far-corner coordinates.
        box[4] = rect["top"] + rect["height"] * rect["scaleY"]
        box[5] = rect["left"] + rect["width"] * rect["scaleX"]
    print(st.session_state.rectangle_3Dbox)
def make_fig(image, preds, px_range=(10, 400), transparency=0.5):
    """Render an image slice, optionally with a label overlay, into an RGB PIL image.

    Args:
        image: Slice to display. It is clipped to *px_range* with ``.clip`` and
            shown in grayscale, so it is expected to be a NumPy array
            (presumably 2-D -- TODO confirm at call sites).
        preds: Optional label slice (anything ``np.array`` accepts), presumably
            the same shape as *image*. Pixels with value > 0.1 are filled with a
            ``jet`` overlay on a fixed 0..40 value range. The boundary of each
            label value is drawn on top in a darkened ``jet`` palette.
        px_range: ``(low, high)`` display window for *image*.
        transparency: Alpha of the filled label overlay.

    Returns:
        PIL.Image.Image in ``"RGB"`` mode containing the rendered figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        ax.imshow(
            image.clip(*px_range),
            cmap="Greys_r",
            vmin=px_range[0],
            vmax=px_range[1],
        )
        if preds is not None:
            label_slice = np.array(preds)

            # Filled overlay: only labelled pixels are visible.
            fill_alpha = np.zeros(label_slice.shape)
            fill_alpha[label_slice > 0.1] = transparency
            ax.imshow(
                label_slice,
                cmap="jet",
                alpha=fill_alpha,
                vmin=0,
                vmax=40,
            )

            # Outline each label value: the Laplacian of a single-label mask is
            # non-zero only around that label's borders.
            edge_slice = np.zeros(label_slice.shape, dtype=int)
            for label in np.unique(label_slice):
                single = label_slice.copy()
                single[single != label] = 0
                edges = ndimage.laplace(single)
                edge_slice[edges != 0] = label

            # Darkened jet palette for the outlines. Use the converted array
            # (``preds`` may not be an ndarray) and at least one color, so an
            # all-background slice does not yield an empty colormap.
            n_colors = max(int(label_slice.max()), 1)
            colors = (mpl.cm.jet(np.linspace(0, 1, n_colors)) - 0.4).clip(0, 1)
            edge_cmap = mpl.colors.ListedColormap(colors)
            edge_alpha = np.zeros(edge_slice.shape)
            edge_alpha[edge_slice > 0.01] = 0.9
            ax.imshow(
                edge_slice,
                alpha=edge_alpha,
                cmap=edge_cmap,
                vmin=0,
                vmax=40,
            )

        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.canvas.draw()
        # tostring_rgb() was deprecated in matplotlib 3.8 and removed in 3.10.
        # buffer_rgba() is the supported accessor; dropping alpha yields RGB.
        # convert() copies the pixels, so the result outlives the closed figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert("RGB")
    finally:
        # Close the figure; otherwise every call (each Streamlit rerun) leaks one.
        plt.close(fig)
#######################################
def make_isotropic(image, interpolator=sitk.sitkLinear, spacing=None):
    """
    Many file formats (e.g. jpg, png,...) expect the pixels to be isotropic, with the
    same spacing for all axes. Saving non-isotropic data in these formats results in
    distorted images. This function makes an image isotropic via resampling, if needed.

    Args:
        image (SimpleITK.Image): Input image.
        interpolator: By default the function uses a linear interpolator. For
                      label images one should use the sitkNearestNeighbor interpolator
                      so as not to introduce non-existent labels.
        spacing (float): Desired spacing. If none is given, the smallest spacing of
                         the original image is used.

    Returns:
        SimpleITK.Image with isotropic spacing which occupies the same region in space as
        the input image. If every axis already has the target spacing, a copy of the
        input is returned without resampling.
    """
    original_spacing = image.GetSpacing()
    if spacing is None:
        spacing = min(original_spacing)

    # Already at the target isotropic spacing: just return a copy. Comparing
    # against the *target* (not just the first axis) means an explicit `spacing`
    # is honoured even when the input is already isotropic at a different spacing.
    if all(spc == spacing for spc in original_spacing):
        return sitk.Image(image)

    # Make image isotropic via resampling; keep the physical extent by scaling the
    # number of voxels along each axis with the spacing ratio.
    original_size = image.GetSize()
    new_spacing = [spacing] * image.GetDimension()
    new_size = [int(round(osz * ospc / spacing)) for osz, ospc in zip(original_size, original_spacing)]
    return sitk.Resample(
        image,
        new_size,
        sitk.Transform(),
        interpolator,
        image.GetOrigin(),
        new_spacing,
        image.GetDirection(),
        0,  # default pixel value
        image.GetPixelID(),
    )
def read_image(path):
    """Load a volume from *path* and return it as a NumPy array.

    The image is reoriented to ``"LPS"`` with ``sitk.DICOMOrient``. It is then
    resampled by ``make_isotropic`` with that function's default interpolator
    and spacing, and finally converted with ``sitk.GetArrayFromImage``.
    """
    oriented = sitk.DICOMOrient(sitk.ReadImage(path), "LPS")
    return sitk.GetArrayFromImage(make_isotropic(oriented))