|
import os |
|
import gradio as gr |
|
from prediction import run_image_prediction |
|
import torch |
|
import torchvision.transforms as T |
|
from celle.utils import process_image |
|
from PIL import Image |
|
from matplotlib import pyplot as plt |
|
from celle_main import instantiate_from_config |
|
from huggingface_hub import hf_hub_download |
|
from omegaconf import OmegaConf |
|
|
|
class model:
    """Gradio-facing predictor that lazily loads and caches one CELL-E 2 checkpoint.

    The checkpoint for the most recently requested ``model_name`` is kept on the
    instance, so repeated requests with the same model skip the slow
    download / instantiation / ``torch.compile`` step.
    """

    def __init__(self):
        # Compiled model currently held in memory, or None when nothing is loaded.
        self.model = None
        # Name of the checkpoint in ``self.model``; only set after a load succeeds.
        self.model_name = None

    @staticmethod
    def _dataset_for(model_name):
        """Return the preprocessing dataset key for ``model_name``.

        Checkpoints with "Finetuned" in their name use the "OpenCell"
        preprocessing; all others use "HPA".
        """
        return "OpenCell" if "Finetuned" in model_name else "HPA"

    def _load_model(self, model_name, device):
        """Download ``HuangLab/<model_name>`` from the Hub and build a compiled model.

        ``self.model`` / ``self.model_name`` are only populated once loading
        has fully succeeded, so a failed download or instantiation is retried
        on the next request instead of being treated as already cached.
        """
        repo_id = f"HuangLab/{model_name}"

        # Release the previous model first so two checkpoints are not kept alive
        # at the same time, and so a failure below leaves a consistent "nothing
        # loaded" state.
        self.model = None
        self.model_name = None

        model_ckpt_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
        model_config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        # These files are not opened here; they are only fetched into the same
        # snapshot directory as the checkpoint. Presumably config.yaml refers to
        # them by relative path (hence the chdir below) -- TODO confirm.
        for extra_file in ("nucleus_vqgan.yaml", "threshold_vqgan.yaml"):
            hf_hub_download(repo_id=repo_id, filename=extra_file)

        config = OmegaConf.load(model_config_path)
        if config["model"]["params"]["ckpt_path"] is None:
            config["model"]["params"]["ckpt_path"] = model_ckpt_path
        # Do not load the sub-models from separate weight files.
        # NOTE(review): presumably their weights are contained in model.ckpt -- confirm.
        config["model"]["params"]["condition_model_path"] = None
        config["model"]["params"]["vqgan_model_path"] = None

        # Instantiate from inside the snapshot directory so relative paths in the
        # config resolve, and always restore the working directory afterwards,
        # even if instantiation raises.
        base_path = os.getcwd()
        os.chdir(os.path.dirname(model_ckpt_path))
        try:
            loaded = instantiate_from_config(config.model).to(device)
            compiled = torch.compile(loaded, mode="max-autotune")
        finally:
            os.chdir(base_path)

        self.model = compiled
        self.model_name = model_name

    @staticmethod
    def _render_heatmap(heatmap):
        """Render ``heatmap`` with the "rainbow" colormap and return it as a PIL image.

        ``heatmap`` is expected to be a tensor ``imshow`` can display
        (presumably 2-D). A private figure is used and closed afterwards, and
        the PNG is written to memory rather than to a shared file on disk.
        """
        import io

        fig, ax = plt.subplots()
        try:
            ax.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
            ax.axis("off")
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png", bbox_inches="tight", dpi=256)
        finally:
            # Without closing, pyplot keeps every figure alive across requests.
            plt.close(fig)

        buffer.seek(0)
        image = Image.open(buffer)
        image.load()  # Image.open is lazy; decode now while the buffer is in scope.
        return image

    def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
        """Gradio callback: predict protein localization for ``sequence_input``.

        Args:
            model_name: Hub repository name under ``HuangLab/`` to predict with.
            sequence_input: amino-acid sequence string.
            nucleus_image: uploaded nucleus image (required).
            protein_image: optional uploaded protein image. It is only
                thresholded for display and is not passed to the model.

        Returns:
            A tuple of four PIL images:
            - the processed nucleus image;
            - the median-thresholded protein image (all ones when none was
              uploaded);
            - the predicted threshold image;
            - the rendered heatmap.

        Raises:
            gr.Error: if no nucleus image was provided.
        """
        if nucleus_image is None:
            raise gr.Error("Please upload a nucleus image.")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if self.model is None or self.model_name != model_name:
            self._load_model(model_name, device)

        dataset = self._dataset_for(model_name)

        nucleus_image = process_image(nucleus_image, dataset, "nucleus")

        if protein_image is not None:
            protein_image = process_image(protein_image, dataset, "protein")
            # Binarize at the median intensity, take element [0, 0] of the
            # leading two dims (presumably batch, channel), and cast to float.
            protein_image = (protein_image > torch.median(protein_image))[0, 0] * 1.0
        else:
            protein_image = torch.ones((256, 256))

        threshold, heatmap = run_image_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            model=self.model,
            device=device,
        )

        to_pil = T.ToPILImage()
        return (
            to_pil(nucleus_image[0, 0]),
            to_pil(protein_image),
            to_pil(threshold),
            self._render_heatmap(heatmap),
        )
|
|
|
# Hub repositories offered in the model dropdown (first entry is the default).
MODEL_NAMES = ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"]

# Default sequence shown in the textbox (GFP, per the UI text below).
DEFAULT_SEQUENCE = "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"

# One shared predictor instance, so the loaded checkpoint is reused across requests.
base_class = model()

with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~1-2 minutes, but will take 2-3 seconds afterwards.**")
    gr.Markdown(
        "```CELL-E_2_HPA_480``` is a good general purpose model for various cell types using ICC-IF."
    )
    gr.Markdown(
        "```CELL-E_2_HPA_Finetuned_480``` is finetuned on OpenCell and is better suited for live-cell predictions on HEK cells."
    )

    with gr.Row():
        model_name = gr.Dropdown(
            MODEL_NAMES,
            value=MODEL_NAMES[0],
            label="Model Name",
        )

    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        sequence_input = gr.Textbox(
            value=DEFAULT_SEQUENCE,
            label="Sequence",
        )

    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)."
        )
        gr.Markdown("The protein image is optional and is just used for display.")

    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            type="pil",
            label="Nucleus Image",
            image_mode="L",
        )
        protein_image = gr.Image(type="pil", label="Protein Image (Optional)")

    with gr.Row():
        gr.Markdown("## Outputs")

    with gr.Row():
        gr.Markdown("Image predictions are shown below.")

    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
        protein_threshold_image = gr.Image(
            type="pil", label="Protein Threshold Image", image_mode="L"
        )
        predicted_threshold_image = gr.Image(
            type="pil", label="Predicted Threshold Image", image_mode="L"
        )
        predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")

    with gr.Row():
        button = gr.Button("Run Model")

    # Order must match the parameters / return tuple of model.gradio_demo.
    inputs = [model_name, sequence_input, nucleus_image, protein_image]
    outputs = [
        nucleus_image_crop,
        protein_threshold_image,
        predicted_threshold_image,
        predicted_heatmap,
    ]

    button.click(base_class.gradio_demo, inputs, outputs)

if __name__ == "__main__":
    # Blocks.queue() replaces the deprecated launch(enable_queue=True) flag.
    demo.queue().launch()