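"""Gradio demo for VSCyto2D: virtual staining of the nucleus and cell membrane
from label-free (e.g., QPI or phase contrast) images."""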

import cmap
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from numpy.typing import ArrayLike
from skimage import exposure
from skimage.transform import resize
from skimage.util import invert
from viscy.light.engine import VSUNet


class VSGradio:
    def __init__(self, model_config, model_ckpt_path):
        self.model_config = model_config
        self.model_ckpt_path = model_ckpt_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.load_model()

    def load_model(self):
        # Load the model checkpoint and move it to the correct device (GPU or CPU)
        self.model = VSUNet.load_from_checkpoint(
            self.model_ckpt_path,
            architecture="UNeXt2_2D",
            model_config=self.model_config,
        )
        self.model.to(self.device)
        self.model.eval()

    def normalize_fov(self, image: ArrayLike):
        """Normalize the field of view (FOV) to zero mean and unit variance."""
        mean = np.mean(image)
        std = np.std(image)
        return (image - mean) / std
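
    # Example (illustrative): an array [0.0, 10.0] has mean 5 and std 5,
    # so it normalizes to [-1.0, 1.0].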

    def preprocess_image_standard(self, image: ArrayLike):
        """Apply standard preprocessing (adaptive histogram equalization)."""
        return exposure.equalize_adapthist(image)

    def downscale_image(self, inp: ArrayLike, scale_factor: float):
        """Downscale the image by the given scale factor."""
        height, width = inp.shape
        new_height = int(height * scale_factor)
        new_width = int(width * scale_factor)
        return resize(inp, (new_height, new_width), anti_aliasing=True)
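
    # Example (illustrative): a (512, 512) image with scale_factor=0.5
    # comes back as (256, 256).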

    def predict(self, inp, cell_diameter: float):
        # Normalize the input to zero mean and unit variance
        inp = self.normalize_fov(inp)
        original_shape = inp.shape
        # Rescale the image so cells match the diameter the model expects
        inp = apply_rescale_image(inp, cell_diameter, expected_cell_diameter=30)
        # Convert the input to a tensor and add batch, channel, and depth
        # axes: (H, W) -> (B=1, C=1, Z=1, H, W)
        inp = torch.from_numpy(inp.astype(np.float32))
        # Prepare the input dictionary and move it to the correct device (GPU or CPU)
        test_dict = dict(
            index=None,
            source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
        )
        # Run model inference
        with torch.inference_mode():
            self.model.on_predict_start()  # Necessary preprocessing for the model
            pred = (
                self.model.predict_step(test_dict, 0, 0).cpu().numpy()
            )  # Move output back to CPU for post-processing
        # Split the two output channels: nucleus and membrane
        nuc_pred = pred[0, 0, 0]
        mem_pred = pred[0, 1, 0]
        # Resize predictions back to the original image size
        nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
        mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
        # Define colormaps: nucleus in green, membrane in magenta
        green_colormap = cmap.Colormap("green")
        magenta_colormap = cmap.Colormap("magenta")
        # Apply the colormaps to the predictions
        nuc_rgb = apply_colormap(nuc_pred, green_colormap)
        mem_rgb = apply_colormap(mem_pred, magenta_colormap)
        return nuc_rgb, mem_rgb
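

# Example usage outside the Gradio UI (illustrative, assuming `model_config`
# and `model_ckpt_path` are defined as in the __main__ block below):
#   vs = VSGradio(model_config, model_ckpt_path)
#   nuc_rgb, mem_rgb = vs.predict(grayscale_2d_array, cell_diameter=30.0)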


def apply_colormap(prediction, colormap: cmap.Colormap):
    """Apply a colormap to a single-channel prediction image."""
    # Rescale the prediction into the valid colormap input range [0, 1]
    prediction = exposure.rescale_intensity(prediction, out_range=(0, 1))
    # Apply the colormap to get a color image
    rgb_image = colormap(prediction)
    # Convert the output from [0, 1] floats to [0, 255] uint8 for display
    rgb_image_uint8 = (rgb_image * 255).astype(np.uint8)
    return rgb_image_uint8
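
# Note: a cmap.Colormap called on an array maps each scalar in [0, 1] to a
# color, so a (H, W) prediction typically comes back as an (H, W, 4) RGBA
# float array before the uint8 conversion above.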


def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
    """Apply the image adjustments (invert, gamma) in sequence."""
    # Optionally invert the image
    if invert_image:
        image = invert(image, signed_float=False)
    # Apply gamma adjustment
    image = exposure.adjust_gamma(image, gamma_factor)
    # Rescale to 8-bit for display
    return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
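
# Example (illustrative): brighten a dim image without inverting it;
# gamma < 1 brightens, gamma > 1 darkens:
#   preview = apply_image_adjustments(img, invert_image=False, gamma_factor=0.5)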


def apply_rescale_image(
    image, cell_diameter: float, expected_cell_diameter: float = 30
):
    """Rescale the image so cell size matches what the model was trained on."""
    # Assume the model was trained with cells ~30 microns in diameter;
    # resize the input image according to the scaling factor
    scale_factor = expected_cell_diameter / float(cell_diameter)
    image = resize(
        image,
        (int(image.shape[0] * scale_factor), int(image.shape[1] * scale_factor)),
        anti_aliasing=True,
    )
    return image
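
# Example (illustrative): for cells ~60 um in diameter, scale_factor = 30 / 60
# = 0.5, so a (1024, 1024) input is resampled to (512, 512) before inference.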


# Load the custom CSS from the file
def load_css(file_path):
    with open(file_path, "r") as file:
        return file.read()


if __name__ == "__main__":
    # Download the model checkpoint from Hugging Face
    model_ckpt_path = hf_hub_download(
        repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
    )

    # Model configuration (UNeXt2_2D)
    model_config = {
        "in_channels": 1,
        "out_channels": 2,
        "encoder_blocks": [3, 3, 9, 3],
        "dims": [96, 192, 384, 768],
        "decoder_conv_blocks": 2,
        "stem_kernel_size": [1, 2, 2],
        "in_stack_depth": 1,
        "pretraining": False,
    }
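    # Note: these values must match the architecture the checkpoint was
    # trained with; changing them will break checkpoint loading.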

    vsgradio = VSGradio(model_config, model_ckpt_path)

    # Initialize the Gradio app using Blocks
    with gr.Blocks(css=load_css("style.css")) as demo:
        # Title and description
        gr.HTML(
            "<div class='title-block'>Image Translation (Virtual Staining) of cellular landmark organelles</div>"
        )
        gr.HTML(
            """
            <div class='description-block'>
                <p><b>Model:</b> VSCyto2D</p>
                <p><b>Input:</b> Label-free image (e.g., QPI or phase contrast).</p>
                <p><b>Output:</b> Virtual staining of nucleus and membrane.</p>
                <p><b>Note:</b> The model works well with QPI and sometimes generalizes to phase contrast and DIC. We continue to diagnose and improve generalization.</p>
                <p>Check out our preprint: <a href='https://www.biorxiv.org/content/10.1101/2024.05.31.596901' target='_blank'><i>Liu et al., Robust virtual staining of landmark organelles</i></a></p>
                <p>For training, inference, and evaluation of the model, refer to the <a href='https://github.com/mehta-lab/VisCy/tree/main/examples/virtual_staining/dlmbl_exercise' target='_blank'>GitHub repository</a>.</p>
            </div>
            """
        )

        # Layout for input and output images
        with gr.Row():
            input_image = gr.Image(type="numpy", image_mode="L", label="Upload Image")
            adjusted_image = gr.Image(
                type="numpy", image_mode="L", label="Adjusted Image (Preview)"
            )
            with gr.Column():
                output_nucleus = gr.Image(
                    type="numpy", image_mode="RGB", label="VS Nucleus"
                )
                output_membrane = gr.Image(
                    type="numpy", image_mode="RGB", label="VS Membrane"
                )

        # Checkbox for applying invert
        preprocess_invert = gr.Checkbox(label="Apply Invert", value=False)

        # Slider for gamma adjustment
        gamma_factor = gr.Slider(
            label="Adjust Gamma", minimum=0.1, maximum=5.0, value=1.0, step=0.1
        )

        # Input field for the cell diameter in microns
        cell_diameter = gr.Textbox(
            label="Cell Diameter [um]",
            value="30.0",
            placeholder="Enter cell diameter in microns",
        )

        # Re-render the adjusted preview whenever the input image or any
        # adjustment control changes
        for component in (input_image, preprocess_invert, gamma_factor):
            component.change(
                fn=apply_image_adjustments,
                inputs=[input_image, preprocess_invert, gamma_factor],
                outputs=adjusted_image,
            )

        # Button to trigger prediction; the adjusted image (not the raw
        # upload) is sent to the model
        submit_button = gr.Button("Submit")
        submit_button.click(
            vsgradio.predict,
            inputs=[adjusted_image, cell_diameter],
            outputs=[output_nucleus, output_membrane],
        )

        # Example images
        gr.Examples(
            examples=[
                "examples/a549.png",
                "examples/hek.png",
                "examples/ctc_HeLa.png",
                "examples/livecell_A172.png",
            ],
            inputs=input_image,
        )

        # Footer information
        gr.HTML(
            """
            <div class='article-block'>
                <p>Model trained primarily on HEK293T, BJ5, and A549 cells. For best results, use quantitative phase images (QPI).</p>
            </div>
            """
        )

    # Launch the Gradio app
    demo.launch()