"""Gradio demo for CELL-E 2 image prediction from an amino-acid sequence."""

import io

import gradio as gr
import torch
import torchvision.transforms as T
from matplotlib import pyplot as plt  # noqa: F401  (kept for compatibility)
from matplotlib.figure import Figure
from PIL import Image

from celle.utils import process_image
from prediction import run_image_prediction

# Directory holding the "<model_name>.ckpt" / "<model_name>.yaml" pairs.
MODEL_DIR = "CELL-E_2-Image_Prediction/models"


def _pick_dataset(model_name):
    """Return the dataset name passed to ``process_image`` for this model."""
    # Finetuned checkpoints were finetuned on OpenCell; the others use HPA.
    return "OpenCell" if "Finetuned" in model_name else "HPA"


def _protein_threshold(protein_image, dataset):
    """Binarize the optional protein image at its median, for display only.

    When no protein image is uploaded, an all-ones 256 x 256 placeholder is
    returned instead.
    """
    if protein_image is None:
        return torch.ones((256, 256))
    protein = process_image(protein_image, dataset, "protein")
    mask = protein > torch.median(protein)
    # Take the first element of the two leading dims and turn the boolean
    # mask into a float tensor that ToPILImage can convert.
    return mask[0, 0] * 1.0


def _render_heatmap(heatmap):
    """Render a heatmap tensor with matplotlib and return it as a PIL image.

    An explicit ``Figure`` and an in-memory buffer are used instead of the
    pyplot global figure and a fixed ``temp.png`` on disk. Requests then
    neither accumulate artists on a shared figure nor overwrite each
    other's output file.
    """
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
    ax.axis("off")
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", dpi=256)
    buffer.seek(0)
    return Image.open(buffer)


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run CELL-E 2 image prediction for one request from the UI.

    Args:
        model_name: Dropdown choice; selects ``MODEL_DIR/<model_name>.ckpt``
            and ``.yaml``. Names containing "Finetuned" use OpenCell
            preprocessing, all others use HPA.
        sequence_input: Amino-acid sequence string.
        nucleus_image: Uploaded nucleus image (PIL, from ``gr.Image``).
        protein_image: Optional uploaded protein image (PIL) or ``None``;
            used only for display.

    Returns:
        Tuple of four PIL images: the processed nucleus image, the
        median-thresholded protein image (or placeholder), the predicted
        threshold image, and the rendered predicted heatmap.

    Raises:
        gr.Error: If no nucleus image was uploaded.
    """
    if nucleus_image is None:
        # Surface a readable message in the UI instead of a traceback from
        # process_image(None, ...).
        raise gr.Error("Uploading a nucleus image is necessary.")

    model_ckpt_path = f"{MODEL_DIR}/{model_name}.ckpt"
    model_config_path = f"{MODEL_DIR}/{model_name}.yaml"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = _pick_dataset(model_name)

    nucleus_image = process_image(nucleus_image, dataset, "nucleus")
    protein_image = _protein_threshold(protein_image, dataset)

    threshold, heatmap = run_image_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        model_ckpt_path=model_ckpt_path,
        model_config_path=model_config_path,
        device=device,
    )

    to_pil = T.ToPILImage()
    return (
        to_pil(nucleus_image[0, 0]),
        to_pil(protein_image),
        to_pil(threshold),
        _render_heatmap(heatmap),
    )


# NOTE(review): `.style(equal_height=True)` and `launch(enable_queue=True)`
# are deprecated/removed in newer Gradio releases; kept as-is because the
# pinned Gradio version is not visible here — confirm before upgrading.
with gr.Blocks(theme="gradio/soft") as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown(
        "CELL-E_2-HPA_480 is a good general-purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "CELL-E_2-HPA_Finetuned_480 is finetuned on OpenCell and is better suited for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2-HPA_480", "CELL-E_2-HPA_Finetuned_480"],
            value="CELL-E_2-HPA_480",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )
    with gr.Row():
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
        )
        gr.Markdown("The protein image is optional and is just used for display.")
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            type="pil",
            label="Nucleus Image",
            image_mode="L",
        )
        protein_image = gr.Image(type="pil", label="Protein Image (Optional)")
    with gr.Row():
        gr.Markdown("Image predictions are shown below.")
    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
        protein_threshold_image = gr.Image(
            type="pil", label="Protein Threshold Image", image_mode="L"
        )
        predicted_threshold_image = gr.Image(
            type="pil", label="Predicted Threshold image", image_mode="L"
        )
        predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
    with gr.Row():
        button = gr.Button("Run Model")

    inputs = [model_name, sequence_input, nucleus_image, protein_image]
    outputs = [
        nucleus_image_crop,
        protein_threshold_image,
        predicted_threshold_image,
        predicted_heatmap,
    ]

    button.click(gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)