VirtualStaining / app.py
edyoshikun's picture
fixing reference for mouse kidney
b01dd01
import gradio as gr
import torch
from viscy.light.engine import VSUNet
from huggingface_hub import hf_hub_download
from numpy.typing import ArrayLike
import numpy as np
from skimage import exposure
from skimage.transform import resize
from skimage.util import invert
import cmap
class VSGradio:
    """Serve virtual-staining predictions from a VSCyto2D checkpoint.

    The checkpoint is loaded once at construction time as a ``UNeXt2_2D``
    ``VSUNet`` and moved to CUDA when it is available, otherwise to the CPU.

    Args:
        model_config: Architecture settings forwarded to
            ``VSUNet.load_from_checkpoint`` as ``model_config``.
        model_ckpt_path: Local filesystem path of the checkpoint to load.
    """

    def __init__(self, model_config, model_ckpt_path):
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.load_model()

    def load_model(self):
        """Load the checkpoint, move it to ``self.device`` and set eval mode."""
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            architecture="UNeXt2_2D",
            model_config=self.model_config,
        )
        self.model.to(self.device)
        # Inference only: put the network in evaluation mode.
        self.model.eval()

    def normalize_fov(self, input: ArrayLike):
        """Normalize the field of view to zero mean and unit variance.

        Args:
            input: Image intensities (any array-like).

        Returns:
            Floating-point array of the same shape. A constant image (zero
            standard deviation) is only mean-centred, which gives all zeros
            instead of the NaNs a division by zero would produce.
        """
        fov = np.asarray(input)
        mean = np.mean(fov)
        std = np.std(fov)
        if std == 0:
            return fov - mean
        return (fov - mean) / std

    def preprocess_image_standard(self, input: ArrayLike):
        """Apply skimage's adaptive histogram equalization (CLAHE) to ``input``."""
        input = exposure.equalize_adapthist(input)
        return input

    def downscale_image(self, inp: ArrayLike, scale_factor: float):
        """Resize a 2D image by ``scale_factor`` (sizes truncated to int).

        Args:
            inp: 2D image; the unpacking of ``inp.shape`` requires exactly two axes.
            scale_factor: Multiplier applied to both height and width.
        """
        height, width = inp.shape
        new_height = int(height * scale_factor)
        new_width = int(width * scale_factor)
        return resize(inp, (new_height, new_width), anti_aliasing=True)

    def predict(self, inp, scaling_factor: float):
        """Virtually stain nuclei and membranes in a single 2D label-free image.

        Args:
            inp: 2D label-free image (the UI passes the adjusted preview).
            scaling_factor: Factor by which the image is resized before
                inference; anything ``float()`` accepts, e.g. the UI's
                Textbox string.

        Returns:
            ``(nuc_rgb, mem_rgb)``: uint8 colour images resized back to the
            input's shape, the nucleus rendered with the "green" colormap and
            the membrane with the "magenta" colormap.

        Raises:
            ValueError: if ``inp`` is None (no image was provided).
        """
        if inp is None:
            raise ValueError("predict() needs an input image, got None.")
        inp = self.normalize_fov(inp)
        original_shape = inp.shape
        # Resize so that cells appear at the scale the model was trained on.
        inp = apply_rescale_image(inp, scaling_factor)
        inp = torch.from_numpy(np.asarray(inp, dtype=np.float32))
        # Add three leading singleton axes to the 2D image; presumably
        # (batch, channel, depth) given in_stack_depth=1 -- TODO confirm.
        test_dict = dict(
            index=None,
            source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
        )
        with torch.inference_mode():
            # Run the module's predict-start hook before calling predict_step.
            self.model.on_predict_start()
            # Bring the output back to the CPU for numpy post-processing.
            pred = self.model.predict_step(test_dict, 0, 0).cpu().numpy()
        # Output channel 0 is the nucleus and channel 1 the membrane.
        nuc_pred = pred[0, 0, 0]
        mem_pred = pred[0, 1, 0]
        # Undo the rescaling so the outputs line up with the input image.
        nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
        mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
        green_colormap = cmap.Colormap("green")
        magenta_colormap = cmap.Colormap("magenta")
        nuc_rgb = apply_colormap(nuc_pred, green_colormap)
        mem_rgb = apply_colormap(mem_pred, magenta_colormap)
        return nuc_rgb, mem_rgb
def apply_colormap(prediction, colormap: cmap.Colormap):
    """Render a single-channel prediction as an 8-bit colour image.

    The prediction is first stretched to span exactly [0, 1] so the whole
    colormap is used, then looked up in ``colormap`` and scaled to 0-255.

    Args:
        prediction: Single-channel array of model output intensities.
        colormap: ``cmap.Colormap`` used for the colour lookup.

    Returns:
        ``np.uint8`` array whose channel layout is whatever ``colormap``
        returns for the stretched input.
    """
    unit_range = exposure.rescale_intensity(prediction, out_range=(0, 1))
    return (colormap(unit_range) * 255).astype(np.uint8)
def merge_images(nuc_rgb: ArrayLike, mem_rgb: ArrayLike) -> ArrayLike:
    """Overlay two colour images by keeping the brighter value per element.

    Using the element-wise maximum lets the nucleus and membrane renderings
    share one image without either darkening the other.
    """
    overlay = np.maximum(nuc_rgb, mem_rgb)
    return overlay
def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
    """Apply the preview adjustments to an image: optional inversion, then gamma.

    Args:
        image: Grayscale image from the upload widget, or None when the
            widget has been cleared.
        invert_image: If True, invert intensities before the gamma step.
        gamma_factor: Exponent passed to ``skimage.exposure.adjust_gamma``;
            must be non-negative.

    Returns:
        The adjusted image stretched to the full 0-255 range as ``np.uint8``,
        or None when ``image`` is None.
    """
    # A cleared upload widget delivers None; there is nothing to adjust, so
    # clear the preview instead of failing inside skimage.
    if image is None:
        return None
    if invert_image:
        image = invert(image, signed_float=False)
    image = exposure.adjust_gamma(image, gamma_factor)
    # Stretch to the full 8-bit range so the preview uses every display level.
    return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
def apply_rescale_image(image, scaling_factor: float):
    """Resize an image by a uniform factor along its first two axes.

    Args:
        image: Array whose first two axes are (height, width).
        scaling_factor: Positive, finite scale factor, or anything ``float()``
            accepts (the UI passes the Textbox string, e.g. ``"0.6"``).

    Returns:
        The resized image as returned by ``skimage.transform.resize``
        (floating point). Each spatial axis is at least 1 pixel long.

    Raises:
        ValueError: if ``scaling_factor`` cannot be parsed as a float or is
            not a positive finite number.
    """
    scaling_factor = float(scaling_factor)
    if not (np.isfinite(scaling_factor) and scaling_factor > 0):
        raise ValueError(
            f"scaling_factor must be a positive finite number, got {scaling_factor}"
        )
    # Truncate toward zero, but never collapse an axis to zero pixels.
    new_shape = (
        max(1, int(image.shape[0] * scaling_factor)),
        max(1, int(image.shape[1] * scaling_factor)),
    )
    return resize(image, new_shape, anti_aliasing=True)
def clear_outputs(image):
    """Reset the display state for a newly uploaded image.

    Returns a 3-tuple: the first element echoes ``image`` unchanged, and the
    remaining two are None so that both prediction panels are blanked.
    """
    cleared_nucleus = None
    cleared_membrane = None
    return image, cleared_nucleus, cleared_membrane
def load_css(file_path):
    """Return the contents of a custom CSS file.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Args:
        file_path: Path of the stylesheet to read.

    Raises:
        OSError: if the file cannot be opened (e.g. it does not exist).
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
if __name__ == "__main__":
    # Download the model checkpoint from Hugging Face
    model_ckpt_path = hf_hub_download(
        repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
    )
    # Architecture settings matching the VSCyto2D checkpoint
    model_config = {
        "in_channels": 1,
        "out_channels": 2,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [1, 2, 2],
        "in_stack_depth": 1,
        "pretraining": False,
    }
    vsgradio = VSGradio(model_config, model_ckpt_path)

    def update_preview(image, invert_image, gamma):
        """Return the adjusted preview image, or None when no image is loaded."""
        # A cleared upload widget delivers None; clear the preview instead of
        # running the adjustments on a missing image.
        if image is None:
            return None
        return apply_image_adjustments(image, invert_image, gamma)

    def clear_predictions():
        """Blank the nucleus, membrane and merged prediction panels."""
        return None, None, None

    def submit_and_merge(inp, scaling_factor, merge):
        """Run the model and route the predictions to the merged or split view.

        Args:
            inp: Adjusted preview image, or None if nothing has been uploaded.
            scaling_factor: Rescaling factor as typed into the Textbox.
            merge: State of the "merge predictions" checkbox.

        Returns:
            Six values for (merged_image, merged_image, output_nucleus,
            output_nucleus, output_membrane, output_membrane): each
            component's new value followed by a visibility update.

        Raises:
            gr.Error: shown in the UI when no image is loaded or the rescaling
                factor is not a positive number.
        """
        if inp is None:
            raise gr.Error("Please upload an image before submitting.")
        try:
            factor = float(scaling_factor)
        except (TypeError, ValueError):
            raise gr.Error(
                f"Rescaling factor must be a number, got {scaling_factor!r}."
            ) from None
        if not (np.isfinite(factor) and factor > 0):
            raise gr.Error("Rescaling factor must be a positive number.")
        nucleus, membrane = vsgradio.predict(inp, factor)
        if merge:
            return (
                merge_images(nucleus, membrane),
                gr.update(visible=True),
                nucleus,
                gr.update(visible=False),
                membrane,
                gr.update(visible=False),
            )
        return (
            None,
            gr.update(visible=False),
            nucleus,
            gr.update(visible=True),
            membrane,
            gr.update(visible=True),
        )

    def merge_predictions_fn(nucleus_image, membrane_image, merge):
        """Switch between the merged view and the separate prediction views.

        Returns values for (merged_image, merged_image, output_nucleus,
        output_membrane): the merged value, then three visibility updates.
        """
        if merge:
            # Before the first Submit there is nothing to merge; only toggle
            # the visibility instead of calling np.maximum on None.
            if nucleus_image is None or membrane_image is None:
                merged = None
            else:
                merged = merge_images(nucleus_image, membrane_image)
            return (
                merged,
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            )
        return (
            None,
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
        )

    # Initialize the Gradio app using Blocks
    with gr.Blocks(css=load_css("style.css")) as demo:
        # Title and description
        gr.HTML(
            """
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <a href="https://www.czbiohub.org/sf/" target="_blank">
                    <img src="https://huggingface.co/spaces/compmicro-czb/VirtualStaining/resolve/main/misc/czb_mark.png" style="width: 100px; height: auto; margin-right: 10px;">
                </a>
                <div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>
            </div>
            """
        )
        gr.HTML(
            """
            <div class='description-block'>
                <p><b>Model:</b> VSCyto2D</p>
                <p><b>Input:</b> label-free image (e.g., QPI or phase contrast).</p>
                <p><b>Output:</b> Virtual staining of nucleus and membrane.</p>
                <p><b>Note:</b> The model works well with QPI, and sometimes generalizes to phase contrast and DIC.<br>
                It was trained primarily on HEK293T, BJ5, and A549 cells imaged at 20x. <br>
                We continue to diagnose and improve generalization</p>
                <p>Check out our preprint: <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al., Robust virtual staining of landmark organelles</i></a></p>
                <p> For training your own model and analyzing large amounts of data, use our <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
            </div>
            """
        )

        # Layout for input and output images
        with gr.Row():
            input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
            adjusted_image = gr.Image(
                type="numpy",
                image_mode="L",
                label="Adjusted Image (Preview)",
                interactive=False,
            )
            with gr.Column():
                output_nucleus = gr.Image(
                    type="numpy", image_mode="RGB", label="VS Nucleus"
                )
                output_membrane = gr.Image(
                    type="numpy", image_mode="RGB", label="VS Membrane"
                )
                # Shares the column with the separate predictions; the merge
                # checkbox toggles which of them is visible.
                merged_image = gr.Image(
                    type="numpy", image_mode="RGB", label="Merged Image", visible=False
                )

        # Checkbox for applying invert
        preprocess_invert = gr.Checkbox(label="Invert Image", value=False)
        # Slider for gamma adjustment
        gamma_factor = gr.Slider(
            label="Adjust Gamma", minimum=0.01, maximum=5.0, value=1.0, step=0.1
        )
        # Free-text factor by which the image is resized before inference
        scaling_factor = gr.Textbox(
            label="Rescaling image factor",
            value="1.0",
            placeholder="Rescaling factor for the input image",
        )
        # Checkbox for merging predictions
        merge_checkbox = gr.Checkbox(
            label="Merge Predictions into one image", value=True
        )
        # Hidden fields that only carry per-example metadata for gr.Examples
        cell_name = gr.Textbox(
            label="Cell Name", placeholder="Cell Type", visible=False
        )
        imaging_modality = gr.Textbox(
            label="Imaging Modality", placeholder="Imaging Modality", visible=False
        )
        references = gr.Textbox(
            label="References", placeholder="References", visible=False
        )
        # Button to trigger prediction and update the output images
        submit_button = gr.Button("Submit")

        # Live preview: recompute whenever the image or an adjustment control
        # changes. This is the only handler that writes adjusted_image.
        preview_inputs = [input_image, preprocess_invert, gamma_factor]
        for control in preview_inputs:
            control.change(
                fn=update_preview,
                inputs=preview_inputs,
                outputs=adjusted_image,
            )

        submit_button.click(
            fn=submit_and_merge,
            inputs=[adjusted_image, scaling_factor, merge_checkbox],
            outputs=[
                merged_image,
                merged_image,
                output_nucleus,
                output_nucleus,
                output_membrane,
                output_membrane,
            ],
        )

        # A new input invalidates earlier predictions (including the merged
        # view). The preview is refreshed by update_preview, not here.
        input_image.change(
            fn=clear_predictions,
            inputs=None,
            outputs=[output_nucleus, output_membrane, merged_image],
        )

        # Toggle between merged and separate views when the checkbox is checked
        merge_checkbox.change(
            fn=merge_predictions_fn,
            inputs=[output_nucleus, output_membrane, merge_checkbox],
            outputs=[merged_image, merged_image, output_nucleus, output_membrane],
        )

        # Example images; each row also fills the hidden metadata fields and
        # the preprocessing controls.
        examples_component = gr.Examples(
            examples=[
                ["examples/a549.png", "A549", "QPI", 1.0, False, "1.0", "1"],
                ["examples/hek.png", "HEK293T", "QPI", 1.0, False, "1.0", "1"],
                ["examples/HEK_PhC.png", "HEK293T", "PhC", 1.2, True, "1.0", "1"],
                ["examples/livecell_A172.png", "A172", "PhC", 1.0, True, "1.0", "2"],
                ["examples/ctc_HeLa.png", "HeLa", "DIC", 0.7, False, "0.7", "3"],
                [
                    "examples/ctc_glioblastoma_astrocytoma_U373.png",
                    "Glioblastoma",
                    "PhC",
                    1.0,
                    True,
                    "2.0",
                    "3",
                ],
                ["examples/U2OS_BF.png", "U2OS", "Brightfield", 1.0, False, "0.3", "4"],
                ["examples/U2OS_QPI.png", "U2OS", "QPI", 1.0, False, "0.3", "4"],
                [
                    "examples/neuromast2.png",
                    "Zebrafish neuromast",
                    "QPI",
                    0.6,
                    False,
                    "1.2",
                    "1",
                ],
                [
                    "examples/mousekidney.png",
                    "Mouse Kidney",
                    "QPI",
                    0.8,
                    False,
                    "0.6",
                    "4",
                ],
            ],
            inputs=[
                input_image,
                cell_name,
                imaging_modality,
                gamma_factor,
                preprocess_invert,
                scaling_factor,
                references,
            ],
        )

        # Article or footer information
        gr.HTML(
            """
            <div class='article-block'>
                <li>1. <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'>Liu et al., Robust virtual staining of landmark organelles</a></li>
                <li>2. <a href='https://sartorius-research.github.io/LIVECell/' target='_blank'>Edlund et al., LIVECell-A large-scale dataset for label-free live cell segmentation</a></li>
                <li>3. <a href='https://celltrackingchallenge.net/' target='_blank'>Maska et al., The cell tracking challenge: 10 years of objective benchmarking </a></li>
                <li>4. <a href='https://elifesciences.org/articles/55502' target='_blank'>Guo et al., Revealing architectural order with quantitative label-free imaging and deep learning</a></li>
            </div>
            """
        )

    # Launch the Gradio app
    demo.launch()