# MRSegmentator / app.py
# Uploaded by DiGuaQiu in commit c8ffaae ("Create app.py", verified); file size 7.23 kB.
# Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol
import tempfile
from pathlib import Path
import nibabel as nib
import numpy as np
from PIL import ImageDraw
from streamlit_drawable_canvas import st_canvas
from streamlit_image_coordinates import streamlit_image_coordinates
import nibabel as nib
import SimpleITK as sitk
import streamlit as st
import utils
from utils import (
initial_rectangle,
make_fig,
reflect_box_into_model,
reflect_json_data_to_3D_box,
run,
)
# Streamlit re-executes this whole script on every interaction; printing a
# marker leaves one line per rerun on the server console.
_RERUN_MARKER = "script run"
print(_RERUN_MARKER)
st.title("MRSegmentator")
#############################################
# init session_state
if "option" not in st.session_state:
st.session_state.option = None
if "reset_demo_case" not in st.session_state:
st.session_state.reset_demo_case = False
if "preds_3D" not in st.session_state:
st.session_state.preds_3D = None
st.session_state.preds_path = None
if "data_item" not in st.session_state:
st.session_state.data_item = None
if "rectangle_3Dbox" not in st.session_state:
st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
if "running" not in st.session_state:
st.session_state.running = False
if "transparency" not in st.session_state:
st.session_state.transparency = 0.25
case_list = [
"images/amos_0541_MRI.nii.gz",
"images/amos_0571_MRI.nii.gz",
"images/amos_0001_CT.nii.gz",
]
#############################################
#############################################
# reset functions
def clear_prompts():
st.session_state.rectangle_3Dbox = [0, 0, 0, 0, 0, 0]
def reset_demo_case():
    """Mark the current case as stale so the next script run reloads it.

    Registered as the ``on_change`` callback of the case selector and the
    file uploader. Dropping ``data_item`` and raising the ``reset_demo_case``
    flag makes the main script re-read the chosen volume. Any box prompt is
    cleared as well.
    """
    state = st.session_state
    state.reset_demo_case = True
    state.data_item = None
    clear_prompts()
def clear_file():
    """Forget the chosen case entirely.

    Registered as the ``on_change`` callback of the demo-source radio, so
    switching between "Select" and "Upload" starts from an empty selection.
    """
    reset_demo_case()
    st.session_state.option = None
    clear_prompts()
#############################################
# Project links, side by side above the case picker.
github_col, arxiv_col = st.columns(2)
github_col.write("Git: https://github.com/hhaentze/mrsegmentator")
arxiv_col.write("Paper: https://arxiv.org/abs/2405.06463")

# Where the demo case comes from. Changing the source discards the current
# case via the clear_file callback.
demo_type = st.radio("Demo case source", ["Select", "Upload"], on_change=clear_file)
with tempfile.TemporaryDirectory() as tmpdirname:
    # Scratch directory for this script run only. It is removed when this
    # `with` block exits, so nothing written here persists across reruns.
    # Uploaded volumes are copied into it, and it is passed to `run` below.

    # modify demo case here
    if demo_type == "Select":
        # Either one of the repository-relative paths in case_list, or None.
        uploaded_file = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
    else:
        uploaded_file = st.file_uploader(
            "Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case
        )
        if uploaded_file is not None:
            # The file name is supplied by the client and is untrusted. Keep
            # only its final path component so the copy always lands inside
            # the scratch directory.
            upload_path = Path(tmpdirname) / Path(uploaded_file.name).name
            upload_path.write_bytes(uploaded_file.getvalue())
            # From here on, both branches hold a path string (or None).
            uploaded_file = str(upload_path)
    st.session_state.option = uploaded_file

    # (Re)load the volume when a case is chosen and it was either just
    # changed (reset_demo_case flag) or nothing is loaded yet.
    if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None
    ):
        # Relative demo paths resolve against this script's folder. An absolute
        # upload path replaces the base entirely when joined.
        image_path = Path(__file__).parent / str(uploaded_file)
        st.session_state.data_item = utils.read_image(image_path)
        # NOTE(review): the SimpleITK copy is presumably what `run` segments
        # and what the prediction geometry is copied from — confirm in utils.
        st.session_state.data_item_ori = sitk.ReadImage(str(image_path))
        st.session_state.reset_demo_case = False
        # A new case invalidates any previous prediction.
        st.session_state.preds_3D = None
        st.session_state.preds_path = None

    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        # Assumes utils.read_image returns an array-like with .min/.max/.shape
        # and numpy-style indexing (as used below) — TODO confirm.
        image_3D = st.session_state.data_item
        intensity_min = int(image_3D.min())
        intensity_max = int(image_3D.max())
        px_range = st.slider(
            "Select intensity range",
            intensity_min,
            intensity_max,
            (intensity_min, intensity_max),
        )

        # Axis 0 is browsed by the "Axial view" slider and axis 1 by the
        # "Coronal view" slider. Both start at the middle slice.
        col_control1, col_control2 = st.columns(2)
        with col_control1:
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )
        with col_control2:
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)
        # The opacity control only exists once there is a mask to overlay.
        if st.session_state.preds_3D is not None:
            st.session_state.transparency = st.slider(
                "Mask opacity", 0.0, 1.0, 0.5, disabled=st.session_state.running
            )

        with col_image1:
            image_z_array = image_3D[selected_index_z]
            preds_z_array = None
            if st.session_state.preds_3D is not None:
                preds_z_array = st.session_state.preds_3D[selected_index_z]
            image_z = make_fig(
                image_z_array, preds_z_array, px_range, st.session_state.transparency
            )
            st.image(image_z, use_column_width=False)
        with col_image2:
            image_y_array = image_3D[:, selected_index_y, :]
            preds_y_array = None
            if st.session_state.preds_3D is not None:
                preds_y_array = st.session_state.preds_3D[:, selected_index_y, :]
            image_y = make_fig(
                image_y_array, preds_y_array, px_range, st.session_state.transparency
            )
            st.image(image_y, use_column_width=False)

    ######################################################
    col1, col2, col3 = st.columns(3)

    with col1:
        # Drop the current prediction (and box prompt) and redraw.
        if st.button(
            "Clear",
            use_container_width=True,
            disabled=(
                st.session_state.option is None or st.session_state.preds_3D is None
            ),
        ):
            clear_prompts()
            st.session_state.preds_3D = None
            st.session_state.preds_path = None
            st.rerun()

    with col2:
        if (
            st.session_state.preds_3D is not None
            and st.session_state.data_item is not None
        ):
            # NOTE(review): preds_3D_ori is not initialised in this file. It is
            # presumably set by utils.run together with preds_3D — confirm.
            # SimpleITK must open the output path itself. A private temporary
            # directory is used instead of a NamedTemporaryFile because
            # re-opening an already-open NamedTemporaryFile by name fails on
            # Windows.
            with tempfile.TemporaryDirectory() as download_dir:
                download_path = Path(download_dir) / "segmentation.nii.gz"
                sitk.WriteImage(st.session_state.preds_3D_ori, str(download_path))
                bytes_data = download_path.read_bytes()
            st.download_button(
                label="Download result(.nii.gz)",
                data=bytes_data,
                file_name="segmentation.nii.gz",
                mime="application/octet-stream",
                disabled=False,
            )

    with col3:
        # Clicking only sets the flag and reruns. The flagged rerun first
        # redraws the widgets disabled and then performs the segmentation
        # below.
        run_button_name = "Run" if not st.session_state.running else "Running"
        if st.button(
            run_button_name,
            type="primary",
            use_container_width=True,
            disabled=(st.session_state.data_item is None or st.session_state.running),
        ):
            st.session_state.running = True
            st.rerun()

    if st.session_state.running:
        # Clear the flag before running so a failure does not leave the UI
        # locked in the "Running" state.
        st.session_state.running = False
        with st.status("Running...", expanded=False):
            run(tmpdirname)
        st.rerun()