|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from prediction import run_image_prediction |
|
import torch |
|
import torchvision.transforms as T |
|
from celle.utils import process_image |
|
from PIL import Image |
|
from matplotlib import pyplot as plt |
|
|
|
|
|
def _render_heatmap(heatmap):
    """Render ``heatmap`` as a rainbow-coloured PNG and return it as a PIL image.

    A dedicated figure is created and closed on every call, so successive
    predictions never draw on top of one another through pyplot's implicit
    "current figure", and no figure objects accumulate.  The PNG is written to
    an in-memory buffer rather than a fixed ``temp.png`` path so overlapping
    requests cannot overwrite each other's output.
    """
    import io  # function-scope import keeps this block self-contained

    buffer = io.BytesIO()
    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap.cpu(), cmap='rainbow', interpolation='bicubic')
        ax.axis('off')
        fig.savefig(buffer, format='png', bbox_inches='tight', dpi=256)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    buffer.seek(0)
    rendered = Image.open(buffer)
    # Image.open is lazy; decode now while the buffer is certainly alive.
    rendered.load()
    return rendered


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run a prediction for the Gradio UI and return the four display images.

    Args:
        model_name: Name of the model repository under the ``HuangLab``
            Hugging Face namespace.  Names containing ``'Finetuned'`` select
            the ``'OpenCell'`` preprocessing, all others ``'HPA'``.
        sequence_input: Amino acid sequence passed to ``run_image_prediction``.
        nucleus_image: Nucleus image; preprocessed with ``process_image`` and
            passed to the model.
        protein_image: Optional protein image (``None`` when not supplied).
            It is only thresholded for display and is not passed to the model.

    Returns:
        A 4-tuple ``(nucleus, protein_threshold, predicted_threshold,
        predicted_heatmap)`` of PIL images.
    """
    repo_id = f"HuangLab/{model_name}"
    model = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # Downloaded for the side effect only; the returned paths are unused here.
    # NOTE(review): presumably config.yaml refers to these files - confirm.
    hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = 'OpenCell' if 'Finetuned' in model_name else 'HPA'

    nucleus_image = process_image(nucleus_image, dataset, 'nucleus')

    if protein_image is not None:
        protein_image = process_image(protein_image, dataset, 'protein')
        # Binarise against the median intensity, then take element [0, 0]
        # (assumes leading batch/channel axes - TODO confirm) and convert the
        # boolean mask to float for ToPILImage.
        protein_image = protein_image > torch.median(protein_image)
        protein_image = protein_image[0, 0]
        protein_image = protein_image * 1.0
    else:
        # No protein image supplied: display an all-ones placeholder.
        protein_image = torch.ones((256, 256))

    threshold, heatmap = run_image_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        model_ckpt_path=model,
        model_config_path=config,
        device=device,
    )

    heatmap = _render_heatmap(heatmap)

    to_pil = T.ToPILImage()
    return (
        to_pil(nucleus_image[0, 0]),
        to_pil(protein_image),
        to_pil(threshold),
        heatmap,
    )
|
|
|
|
|
# Amino acid sequence of GFP: the default textbox value, also reused by the
# example rows below (previously repeated verbatim three times).
GFP_SEQUENCE = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'

# Build the UI: model picker, sequence box, input images, output images and a
# button that runs `gradio_demo`.
with gr.Blocks() as demo:
    # --- Model selection ---------------------------------------------------
    gr.Markdown("Select the prediction model.")
    gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
    gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good for live-cell predictions on HEK cells.")
    with gr.Row():
        model_name = gr.Dropdown(['CELL-E_2_HPA_480', 'CELL-E_2_HPA_Finetuned_480'],
                                 value='CELL-E_2_HPA_480', label='Model Name')

    # --- Sequence input ----------------------------------------------------
    with gr.Row():
        gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")

    with gr.Row():
        sequence_input = gr.Textbox(value=GFP_SEQUENCE,
                                    label='Sequence')

    # --- Image inputs ------------------------------------------------------
    with gr.Row():
        gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
        gr.Markdown("The protein image is optional and is just used for display.")

    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(value='https://huggingface.co/spaces/HuangLab/CELL-E_2/resolve/main/images/Armadillo%20repeat-containing%20X-linked%20protein%205%20nucleus.jpg',
                                 type='pil',
                                 label='Nucleus Image',
                                 image_mode='L')

        protein_image = gr.Image(type='pil', label='Protein Image (Optional)')

    # --- Outputs -----------------------------------------------------------
    with gr.Row():
        gr.Markdown("Image predictions are shown below.")

    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type='pil',
                                      label='Nucleus Image',
                                      image_mode='L')

        protein_threshold_image = gr.Image(type='pil',
                                           label='Protein Threshold Image',
                                           image_mode='L')

        predicted_threshold_image = gr.Image(type='pil',
                                             label='Predicted Threshold image',
                                             image_mode='L')

        predicted_heatmap = gr.Image(type='pil',
                                     label='Predicted Heatmap')

    # --- Run button and event wiring ----------------------------------------
    with gr.Row():
        button = gr.Button("Run Model")

    # Order matches the parameters of `gradio_demo`.
    inputs = [model_name,
              sequence_input,
              nucleus_image,
              protein_image]

    # Order matches the tuple returned by `gradio_demo`.
    outputs = [nucleus_image_crop,
               protein_threshold_image,
               predicted_threshold_image,
               predicted_heatmap]

    button.click(gradio_demo, inputs, outputs)

# Example rows in the same order as `inputs` (model, sequence, nucleus image
# path, protein image path).
# NOTE(review): in the code shown here this list is not attached to any
# component - confirm whether it is meant to be wired up.
examples = [['CELL-E_2_HPA_Finetuned_480',
             GFP_SEQUENCE,
             'images/Proteasome activator complex subunit 3 nucleus.png',
             'images/Proteasome activator complex subunit 3 protein.png'],
            ['CELL-E_2_HPA_480',
             GFP_SEQUENCE,
             'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
             'images/Armadillo repeat-containing X-linked protein 5 protein.jpg']]

demo.launch(share=True)
|
|