# MRSegmentator / app.py
# Last commit: "Update app.py" by DiGuaQiu (52811bb, verified)
# Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol
import tempfile
from pathlib import Path
import SimpleITK as sitk
from mrsegmentator.utils import add_postfix
import streamlit as st
import utils
print("script run")  # debug trace; printed on every Streamlit rerun of this script
st.title("MRSegmentator")
st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")
#############################################
# init session_state
# Streamlit re-executes this script on every interaction, so each key is only
# seeded when it is missing; values set by earlier runs/callbacks are kept.
_SESSION_DEFAULTS = {
    "option": None,  # path of the case currently shown, or None
    "reset_demo_case": False,  # set by reset_demo_case() to force a reload
    "preds_3D": None,  # segmentation overlay (as returned by utils.read_image)
    "preds_3D_ori": None,  # SimpleITK segmentation image, written out for download
    "preds_path": None,  # only ever reset to None in this file
    "data_item": None,  # image volume (as returned by utils.read_image)
    "rectangle_3Dbox": [0, 0, 0, 0, 0, 0],  # box prompt, reset by clear_prompts()
    "running": False,  # True while an inference pass is pending
    "transparency": 0.25,  # mask opacity passed to utils.make_fig
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Bundled demo volumes, loaded from images/<case>. Each one has a precomputed
# segmentation at segmentations/<add_postfix(case, "seg")>.
# (The file names suggest AMOS MRI cases — confirm.)
case_list = [f"amos_{case_id}_MRI.nii.gz" for case_id in ("0517", "0541", "0571")]
#############################################
#############################################
# reset functions
def clear_prompts():
    """Reset the 3D box prompt stored in session state to six zeros."""
    st.session_state["rectangle_3Dbox"] = [0] * 6
def reset_demo_case():
    """Forget the currently loaded case so the next script run reloads it.

    Clears the cached image volume and both prediction objects. Raises the
    ``reset_demo_case`` flag that the loading code checks, and resets the box
    prompt. Used as the ``on_change`` callback of the case selectbox and the
    file uploader.
    """
    state = st.session_state
    for key in ("data_item", "preds_3D", "preds_3D_ori"):
        state[key] = None
    state.reset_demo_case = True
    clear_prompts()
def clear_file():
    """Drop the current case entirely (``on_change`` callback of the demo-source radio).

    Unsets the selected case path, then delegates the rest of the reset to
    reset_demo_case(). That function already calls clear_prompts(), so no
    separate call is needed here.
    """
    st.session_state.option = None
    reset_demo_case()
#############################################
github_col, arxive_col = st.columns(2)
github_col.write("Git: https://github.com/hhaentze/mrsegmentator")
arxive_col.write("Paper: https://arxiv.org/abs/2405.06463")

# Where the displayed case comes from: a bundled, presegmented demo or an upload.
# Switching the source wipes all per-case state through clear_file().
demo_type = st.radio(
    "Demo case source",
    ["Select (presegmented)", "Upload"],
    on_change=clear_file,
)
with tempfile.TemporaryDirectory() as tmpdirname:
    # A fresh temporary directory exists for the duration of each script run.
    # Uploads are copied into it so they can be read from disk, and the same
    # directory is handed to utils.run().

    # --- 1. Resolve which case to show -----------------------------------------
    # The widget callbacks (reset_demo_case / clear_file) run before the script
    # re-executes and wipe the per-case state. Only in that situation, or on the
    # very first run, is anything (re)loaded from disk.
    new_case_pending = st.session_state.reset_demo_case or st.session_state.data_item is None

    if demo_type == "Select (presegmented)":
        selection = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
        if selection:
            uploaded_file = "images/" + selection
            # Demo cases ship with a precomputed segmentation. Load it only when
            # the case changes. Loading on every rerun re-read two NIfTI files per
            # slider move and immediately undid the "Clear" button.
            if new_case_pending:
                seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
                st.session_state.preds_3D = utils.read_image(seg_path)
                st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
        else:
            uploaded_file = None
    else:
        uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
        if uploaded_file is not None:
            # The file name is client-supplied. Keep only its final component so
            # it cannot point outside the temporary directory.
            upload_path = Path(tmpdirname) / Path(uploaded_file.name).name
            upload_path.write_bytes(uploaded_file.getvalue())
            uploaded_file = str(upload_path)

    st.session_state.option = uploaded_file

    if st.session_state.option is not None and new_case_pending:
        # Demo paths are relative to this file; upload paths are absolute, and
        # joining onto an absolute path keeps the absolute path.
        st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
        st.session_state.reset_demo_case = False

    # --- 2. Slice viewer ---------------------------------------------------------
    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        image_3D = st.session_state.data_item
        preds_3D = st.session_state.preds_3D
        # NOTE(review): the slider labels treat axis 0 as axial and axis 1 as
        # coronal. That depends on the orientation produced by utils.read_image.
        intensity_min, intensity_max = int(image_3D.min()), int(image_3D.max())
        px_range = st.slider(
            "Select intensity range",
            intensity_min,
            intensity_max,
            (intensity_min, intensity_max),
        )

        col_control1, col_control2 = st.columns(2)
        with col_control1:
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )
        with col_control2:
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)
        if preds_3D is not None:
            st.session_state.transparency = st.slider(
                "Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running
            )

        with col_image1:
            preds_z_array = None if preds_3D is None else preds_3D[selected_index_z]
            image_z = utils.make_fig(
                image_3D[selected_index_z], preds_z_array, px_range, st.session_state.transparency
            )
            st.image(image_z, use_column_width=False)
        with col_image2:
            preds_y_array = None if preds_3D is None else preds_3D[:, selected_index_y, :]
            image_y = utils.make_fig(
                image_3D[:, selected_index_y, :], preds_y_array, px_range, st.session_state.transparency
            )
            st.image(image_y, use_column_width=False)

    # --- 3. Action row: Clear / Download / Run ---------------------------------
    col1, col2, col3 = st.columns(3)
    with col1:
        # Empty headings — apparently vertical spacers that line the button up
        # with the radio in col3.
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if st.button(
            "Clear",
            use_container_width=True,
            disabled=(st.session_state.option is None or st.session_state.preds_3D is None),
        ):
            clear_prompts()
            st.session_state.preds_3D = None
            st.session_state.preds_path = None
            st.rerun()

    with col2:
        st.markdown("#")
        st.markdown("####")
        st.markdown("####")
        if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
            # Serialize the SimpleITK segmentation to bytes. The file is written
            # into a private directory rather than reopening a NamedTemporaryFile
            # by name, which fails on Windows.
            with tempfile.TemporaryDirectory() as download_dir:
                download_path = Path(download_dir) / "segmentation.nii.gz"
                sitk.WriteImage(st.session_state.preds_3D_ori, str(download_path))
                bytes_data = download_path.read_bytes()
            st.download_button(
                label="Download result (.nii.gz)",
                data=bytes_data,
                file_name="segmentation.nii.gz",
                mime="application/octet-stream",
                disabled=False,
            )

    with col3:
        fast_fold_label = "Model of Fold 1 (fast)"
        folds = st.radio(
            "Model",
            [fast_fold_label, "Ensemble Segmentation"],
            label_visibility="collapsed",
        )
        # Compare against the exact option text. The previous check used
        # "Model of Fold 1", which never matched, so the ensemble was always chosen.
        st.session_state.folds = (0,) if folds == fast_fold_label else (0, 1, 2, 3, 4)

        run_button_name = "Running" if st.session_state.running else "Run"
        # On-site inference is switched off (see the notice under the title).
        # Restore the commented condition to re-enable the button.
        if st.button(
            run_button_name,
            type="primary",
            use_container_width=True,
            disabled=True,
            # disabled=(st.session_state.data_item is None or st.session_state.running),
        ):
            st.session_state.running = True
            st.rerun()

    # --- 4. Inference pass ---------------------------------------------------------
    # Executed on the rerun requested by the Run button. The flag is cleared before
    # utils.run is called, so an exception there does not leave the widgets disabled.
    # NOTE(review): demo cases are read from images/, not from tmpdirname — confirm
    # what utils.run expects to find in the directory it is given.
    if st.session_state.running:
        st.session_state.running = False
        with st.status("Running...", expanded=False):
            utils.run(tmpdirname)
        st.rerun()