File size: 5,688 Bytes
86d2765 64212e0 86d2765 64212e0 3c77a96 64212e0 3c77a96 64212e0 3c77a96 64212e0 3c77a96 64212e0 3c77a96 64212e0 f8307a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import io

import gradio as gr
import torch
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from matplotlib import pyplot as plt
from PIL import Image

from celle.utils import process_image
from prediction import run_image_prediction
def _resolve_dataset(model_name):
    """Return the dataset name used for image preprocessing for *model_name*."""
    # Model names containing 'Finetuned' use OpenCell preprocessing; all
    # others use HPA preprocessing.
    return 'OpenCell' if 'Finetuned' in model_name else 'HPA'


def _download_model_files(model_name):
    """Download the files for ``HuangLab/<model_name>`` from the Hugging Face Hub.

    Returns:
        (checkpoint_path, config_path): local paths to ``model.ckpt`` and
        ``config.yaml``.
    """
    repo_id = f"HuangLab/{model_name}"
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # The return values are deliberately ignored. Presumably the model loader
    # finds these VQGAN configs in the local Hub cache -- TODO confirm.
    hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")
    return checkpoint_path, config_path


def _prepare_protein_image(protein_image, dataset):
    """Binarize the optional protein image for display.

    Pixels above the image median become 1.0 and the rest become 0.0. When no
    protein image was uploaded (``None``), a 256x256 tensor of ones is
    returned instead.
    """
    if protein_image is None:
        return torch.ones((256, 256))
    protein = process_image(protein_image, dataset, 'protein')
    protein = protein > torch.median(protein)
    # Keep only the first [0, 0] slice. This assumes process_image returns a
    # (batch, channel, H, W) tensor -- TODO confirm.
    return protein[0, 0] * 1.0


def _render_heatmap(heatmap):
    """Render *heatmap* with the 'rainbow' colormap and return it as a PIL image.

    The plot is drawn on a private figure and saved to an in-memory PNG.
    This avoids two problems: concurrent requests racing on a shared temp
    file, and pyplot state accumulating across calls. The figure is always
    closed.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap.cpu(), cmap='rainbow', interpolation='bicubic')
        ax.axis('off')
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight', dpi=256)
    finally:
        plt.close(fig)
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # PIL decodes lazily; force decoding while the buffer is alive
    return image


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run a CELL-E 2 image prediction for the Gradio UI.

    Args:
        model_name: Hub model name under the ``HuangLab`` organization,
            e.g. ``'CELL-E_2_HPA_480'``.
        sequence_input: amino-acid sequence to predict localization for.
        nucleus_image: uploaded nucleus image (required).
        protein_image: optional protein image, used only for display; may be
            ``None``.

    Returns:
        Four PIL images, in this order: the processed nucleus image, the
        median-thresholded protein image, the predicted threshold image, and
        the rendered predicted heatmap.

    Raises:
        gr.Error: if no nucleus image was provided.
    """
    if nucleus_image is None:
        raise gr.Error("A nucleus image is required.")

    checkpoint_path, config_path = _download_model_files(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = _resolve_dataset(model_name)

    nucleus_image = process_image(nucleus_image, dataset, 'nucleus')
    protein_image = _prepare_protein_image(protein_image, dataset)

    threshold, heatmap = run_image_prediction(
        sequence_input=sequence_input,
        nucleus_image=nucleus_image,
        model_ckpt_path=checkpoint_path,
        model_config_path=config_path,
        device=device,
    )

    to_pil = T.ToPILImage()
    return (
        to_pil(nucleus_image[0, 0]),
        to_pil(protein_image),
        to_pil(threshold),
        _render_heatmap(heatmap),
    )
# Default input sequence; the UI text below identifies it as GFP.
GFP_SEQUENCE = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'

# Gradio Blocks UI: choose a model, enter a sequence, upload a nucleus image
# (plus an optional protein image), then run gradio_demo on "Run Model".
# NOTE(review): `.style(equal_height=True)` is deprecated in later Gradio 3.x
# releases and removed in 4.x, where `gr.Row(equal_height=True)` replaces it.
# Confirm the pinned Gradio version before changing it.
with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
    gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is better suited for live-cell predictions on HEK cells.")
    with gr.Row():
        model_name = gr.Dropdown(['CELL-E_2_HPA_480', 'CELL-E_2_HPA_Finetuned_480'],
                                 value='CELL-E_2_HPA_480', label='Model Name')
    with gr.Row():
        gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")
    with gr.Row():
        sequence_input = gr.Textbox(value=GFP_SEQUENCE, label='Sequence')
    with gr.Row():
        gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
        gr.Markdown("The protein image is optional and is just used for display.")
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(value='https://huggingface.co/spaces/HuangLab/CELL-E_2/resolve/main/images/Armadillo%20repeat-containing%20X-linked%20protein%205%20nucleus.jpg',
                                 type='pil',
                                 label='Nucleus Image',
                                 image_mode='L')
        protein_image = gr.Image(type='pil', label='Protein Image (Optional)')
    with gr.Row():
        gr.Markdown("Image predictions are shown below.")
    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type='pil',
                                      label='Nucleus Image',
                                      image_mode='L')
        protein_threshold_image = gr.Image(type='pil',
                                           label='Protein Threshold Image',
                                           image_mode='L')
        predicted_threshold_image = gr.Image(type='pil',
                                             label='Predicted Threshold image',
                                             image_mode='L')
        predicted_heatmap = gr.Image(type='pil',
                                     label='Predicted Heatmap')
    with gr.Row():
        button = gr.Button("Run Model")

    # The order must match gradio_demo's parameters and return tuple.
    inputs = [model_name,
              sequence_input,
              nucleus_image,
              protein_image]
    outputs = [nucleus_image_crop,
               protein_threshold_image,
               predicted_threshold_image,
               predicted_heatmap]
    button.click(gradio_demo, inputs, outputs)

    # Each row lists one value per input component, in the order of `inputs`.
    # The image paths are relative to the working directory.
    examples = [['CELL-E_2_HPA_Finetuned_480',
                 GFP_SEQUENCE,
                 'images/Proteasome activator complex subunit 3 nucleus.png',
                 'images/Proteasome activator complex subunit 3 protein.png'],
                ['CELL-E_2_HPA_480',
                 GFP_SEQUENCE,
                 'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
                 'images/Armadillo repeat-containing X-linked protein 5 protein.jpg']]
    # Previously this list was built but never shown. Clicking an example now
    # fills in the input components.
    gr.Examples(examples=examples, inputs=inputs)

if __name__ == "__main__":
    demo.launch(share=True)
|