Spaces:
Sleeping
Sleeping
# Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py
from pathlib import Path | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import nibabel as nib | |
import numpy as np | |
import torch | |
from monai.transforms import LoadImage | |
from mrsegmentator import inference | |
from mrsegmentator.utils import add_postfix | |
from PIL import Image, ImageColor, ImageDraw, ImageEnhance | |
from scipy import ndimage | |
from monai.transforms import LoadImage, Orientation, Spacing | |
import SimpleITK as sitk | |
import streamlit as st | |
# Default drawing state holding a single rectangle: a 100 x 100 box whose
# top-left corner sits at (50, 50), with a semi-transparent fill and a
# 3-unit-wide "#2909F1" outline.
# NOTE(review): the key set looks like a serialized canvas object, presumably
# handed to a drawable-canvas widget as its initial drawing -- confirm with
# the caller.
_INITIAL_RECT_OBJECT = dict(
    type="rect",
    version="4.4.0",
    # Position and size.
    originX="left",
    originY="top",
    left=50,
    top=50,
    width=100,
    height=100,
    # Fill and outline.
    fill="rgba(255, 165, 0, 0.3)",
    stroke="#2909F1",
    strokeWidth=3,
    strokeDashArray=None,
    strokeLineCap="butt",
    strokeDashOffset=0,
    strokeLineJoin="miter",
    strokeUniform=True,
    strokeMiterLimit=4,
    # Transform.
    scaleX=1,
    scaleY=1,
    angle=0,
    flipX=False,
    flipY=False,
    # Rendering.
    opacity=1,
    shadow=None,
    visible=True,
    backgroundColor="",
    fillRule="nonzero",
    paintFirst="fill",
    globalCompositeOperation="source-over",
    skewX=0,
    skewY=0,
    rx=0,
    ry=0,
)

initial_rectangle = dict(version="4.4.0", objects=[_INITIAL_RECT_OBJECT])
def run(tmpdirname):
    """Segment the selected image and cache the prediction in the session state.

    Does nothing when ``st.session_state.option`` is ``None``. Otherwise the
    image named by that option (resolved relative to this file's directory)
    is passed to ``inference.infer`` and the resulting segmentation is stored
    as:

    * ``st.session_state.preds_3D_ori`` -- the segmentation exactly as read
      from disk (a ``SimpleITK.Image``).
    * ``st.session_state.preds_3D`` -- the same segmentation reoriented to
      LPS, resampled to isotropic spacing and converted to an array.

    Args:
        tmpdirname: Output directory (``str`` or path-like) handed to
            ``inference.infer``; the segmentation file is read back from it.
    """
    if st.session_state.option is None:
        return

    image = Path(__file__).parent / str(st.session_state.option)
    inference.infer([image], tmpdirname, [0], split_level=1)

    # pathlib join instead of string concatenation so that both ``str`` and
    # ``Path`` directories are accepted.
    preds_path = str(Path(tmpdirname) / add_postfix(image.name, "seg"))

    # Read the file once and reuse it for both session entries.
    segmentation = sitk.ReadImage(preds_path)
    st.session_state.preds_3D_ori = segmentation

    # A label map must be resampled with nearest-neighbour interpolation:
    # the linear default of make_isotropic blends neighbouring label ids
    # into values that belong to no (or the wrong) structure.
    isotropic = make_isotropic(
        sitk.DICOMOrient(segmentation, "LPS"),
        interpolator=sitk.sitkNearestNeighbor,
    )
    st.session_state.preds_3D = sitk.GetArrayFromImage(isotropic)
def reflect_box_into_model(box_3d):
    """Rescale a 3D box into the coordinate range used for the model prompt.

    Every coordinate is multiplied by a per-axis target extent (32 for z,
    256 for y and x), divided by 325 and truncated with ``int()``.
    NOTE(review): 325 is presumably the drawing-canvas size in pixels and
    (32, 256, 256) the model input shape -- confirm against the caller.

    Args:
        box_3d: Six values ordered ``(z1, y1, x1, z2, y2, x2)``.

    Returns:
        torch.Tensor holding the six rescaled integer coordinates in the
        same ``(z1, y1, x1, z2, y2, x2)`` order.
    """
    z1, y1, x1, z2, y2, x2 = box_3d

    def _rescale(value, extent):
        # Multiply first, then divide, then truncate towards zero.
        return int(value * extent / 325.0)

    prompt = [
        _rescale(z1, 32.0),
        _rescale(y1, 256.0),
        _rescale(x1, 256.0),
        _rescale(z2, 32.0),
        _rescale(y2, 256.0),
        _rescale(x2, 256.0),
    ]
    return torch.tensor(np.array(prompt))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into the session's 3D box.

    Only the ``"xy"`` view is handled; any other ``view`` leaves the box
    untouched. For ``"xy"`` the rectangle's top/left corner is written to
    entries 1 and 2 of ``st.session_state.rectangle_3Dbox`` and its
    bottom/right corner (origin plus size times scale) to entries 4 and 5.
    The updated box is then printed.

    Args:
        json_data: Mapping whose ``"objects"`` list starts with a rectangle
            exposing ``top``, ``left``, ``width``, ``height``, ``scaleX``
            and ``scaleY``.
        view: Name of the view the rectangle was drawn in.
    """
    if view != "xy":
        return

    rect = json_data["objects"][0]
    box = st.session_state.rectangle_3Dbox

    # Upper-left corner.
    box[1] = rect["top"]
    box[2] = rect["left"]
    # Lower-right corner, taking the rectangle's scale factors into account.
    box[4] = rect["top"] + rect["height"] * rect["scaleY"]
    box[5] = rect["left"] + rect["width"] * rect["scaleX"]

    print(st.session_state.rectangle_3Dbox)
def make_fig(image, preds, px_range=(10, 400), transparency=0.5):
    """Render an image slice, optionally overlaid with a label map, as a PIL image.

    The slice is clipped to ``px_range`` and drawn in grey scale. When
    ``preds`` is given, labels are drawn on top as a translucent "jet"
    overlay (colour scale fixed to 0-40) and the outline of every label is
    drawn in a darkened variant of the same colours.

    Args:
        image: 2D array-like exposing ``.clip``; the slice to show.
        preds: 2D label map with the same shape as ``image``, or ``None``
            to draw the image only.
        px_range: ``(low, high)`` intensity window used for clipping and
            for the grey-scale limits.
        transparency: Alpha applied to labelled pixels of the overlay.

    Returns:
        PIL.Image.Image in ``"RGB"`` mode holding the rendered figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        ax.imshow(
            image.clip(*px_range),
            cmap="Greys_r",
            vmin=px_range[0],
            vmax=px_range[1],
        )

        if preds is not None:
            labels = np.array(preds)

            # Translucent fill: only labelled pixels get a non-zero alpha.
            fill_alpha = np.zeros(labels.shape)
            fill_alpha[labels > 0.1] = transparency
            ax.imshow(labels, cmap="jet", alpha=fill_alpha, vmin=0, vmax=40)

            # Outlines: a non-zero Laplacian of each single-label mask marks
            # that label's border pixels.
            edge_slice = np.zeros(labels.shape, dtype=int)
            for label in np.unique(labels):
                mask = labels.copy()
                mask[mask != label] = 0
                edge_slice[ndimage.laplace(mask) != 0] = label

            # Darken the jet colours so outlines stand out from the fill.
            # The maximum is taken from the array so that non-ndarray
            # ``preds`` (e.g. nested lists) work as well.
            jet_colors = mpl.cm.jet(np.linspace(0, 1, int(labels.max())))
            edge_cmap = mpl.colors.ListedColormap((jet_colors - 0.4).clip(0, 1))

            edge_alpha = np.zeros(edge_slice.shape)
            edge_alpha[edge_slice > 0.01] = 0.9
            ax.imshow(edge_slice, alpha=edge_alpha, cmap=edge_cmap, vmin=0, vmax=40)

        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.canvas.draw()

        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later removed;
        # ``buffer_rgba`` is the supported way to read the rendered pixels.
        # ``convert`` copies the data, so closing the figure below is safe.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba).convert("RGB")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one figure.
        plt.close(fig)
####################################### | |
def make_isotropic(image, interpolator=sitk.sitkLinear, spacing=None):
    """Resample an image so that all axes share the same spacing.

    Formats such as jpg or png assume isotropic pixels; writing
    non-isotropic data to them yields distorted pictures. This function
    resamples the image onto an isotropic grid when that is needed.

    Note that an image which is already isotropic is returned as a copy
    without resampling, even when ``spacing`` is given.

    Args:
        image (SimpleITK.Image): Input image.
        interpolator: Interpolator used for resampling; linear by default.
            Use ``sitk.sitkNearestNeighbor`` for label images so that no
            non-existent labels are introduced.
        spacing (float): Desired spacing. Defaults to the smallest spacing
            of the input image.

    Returns:
        SimpleITK.Image with isotropic spacing that occupies the same
        region in space as the input image.
    """
    axis_spacings = image.GetSpacing()

    # No axis deviates from the first one: already isotropic, return a copy.
    if not any(value != axis_spacings[0] for value in axis_spacings):
        return sitk.Image(image)

    axis_sizes = image.GetSize()
    if spacing is None:
        spacing = min(axis_spacings)

    # Keep the physical extent: the voxel count scales with old/new spacing.
    out_size = []
    for n_voxels, axis_spacing in zip(axis_sizes, axis_spacings):
        out_size.append(int(round(n_voxels * axis_spacing / spacing)))
    out_spacing = [spacing] * image.GetDimension()

    return sitk.Resample(
        image,
        out_size,
        sitk.Transform(),
        interpolator,
        image.GetOrigin(),
        out_spacing,
        image.GetDirection(),
        0,  # default pixel value
        image.GetPixelID(),
    )
def read_image(path, interpolator=None):
    """Read an image file and return it as an LPS-oriented, isotropic array.

    The file is read with ``sitk.ReadImage``, reoriented to LPS, passed
    through ``make_isotropic`` and finally converted with
    ``sitk.GetArrayFromImage``.

    Args:
        path: Path of the image file to read.
        interpolator: SimpleITK interpolator used if the image has to be
            resampled. ``None`` (the default) selects ``sitk.sitkLinear``,
            which is the previous behaviour. Pass
            ``sitk.sitkNearestNeighbor`` for label images so that no
            intermediate label values are created.

    Returns:
        The array produced by ``sitk.GetArrayFromImage`` for the processed
        image.
    """
    if interpolator is None:
        # Resolved at call time; matches make_isotropic's own default.
        interpolator = sitk.sitkLinear

    img = sitk.DICOMOrient(sitk.ReadImage(path), "LPS")
    img = make_isotropic(img, interpolator=interpolator)
    return sitk.GetArrayFromImage(img)