File size: 5,688 Bytes
86d2765
64212e0
 
 
 
 
 
 
86d2765
 
64212e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c77a96
64212e0
3c77a96
 
64212e0
 
 
 
 
 
 
 
3c77a96
 
64212e0
 
3c77a96
 
64212e0
 
3c77a96
 
64212e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8307a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from huggingface_hub import hf_hub_download
from prediction import run_image_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt


def _render_heatmap(heatmap):
    """Render a heatmap tensor as a PIL image using matplotlib.

    The plot is drawn on a dedicated figure and encoded to PNG in memory, so
    repeated or concurrent calls neither pile artists onto a shared global
    figure nor race on a shared file on disk.

    Args:
        heatmap: Object exposing ``.cpu()`` whose result ``imshow`` can draw
            (presumably a 2-D torch tensor -- confirm against
            ``run_image_prediction``).

    Returns:
        A fully loaded ``PIL.Image.Image`` of the rendered plot.
    """
    import io

    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap.cpu(), cmap='rainbow', interpolation='bicubic')
        ax.axis('off')

        png_buffer = io.BytesIO()
        fig.savefig(png_buffer, format='png', bbox_inches='tight', dpi=256)
    finally:
        # Always release the figure; pyplot keeps every open figure alive.
        plt.close(fig)

    png_buffer.seek(0)
    rendered = Image.open(png_buffer)
    # Image.open is lazy; force decoding now so the result is self-contained.
    rendered.load()
    return rendered


def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
    """Run a CELL-E 2 prediction and return the images shown in the UI.

    Args:
        model_name: Name of the model repository under the ``HuangLab``
            Hugging Face organisation. Names containing ``'Finetuned'`` are
            processed with the ``'OpenCell'`` dataset settings, all others
            with ``'HPA'``.
        sequence_input: Amino acid sequence passed to the model.
        nucleus_image: Nucleus image (required).
        protein_image: Optional protein image; only used to build the
            displayed threshold mask, it is not passed to the model.

    Returns:
        A 4-tuple of PIL images: the processed nucleus image, the protein
        mask (median-thresholded, or all ones when no protein image was
        given), the predicted threshold image, and the rendered heatmap.

    Raises:
        gr.Error: If ``nucleus_image`` is ``None``.
    """
    # Fail early with a readable message instead of an opaque error from
    # deep inside the preprocessing code.
    if nucleus_image is None:
        raise gr.Error("A nucleus image is required to run the model.")

    repo_id = f"HuangLab/{model_name}"
    model = hf_hub_download(repo_id=repo_id, filename="model.ckpt")
    config = hf_hub_download(repo_id=repo_id, filename="config.yaml")
    # Downloaded for their side effect only; the returned paths are unused
    # here (presumably the model config refers to these files -- confirm).
    hf_hub_download(repo_id=repo_id, filename="nucleus_vqgan.yaml")
    hf_hub_download(repo_id=repo_id, filename="threshold_vqgan.yaml")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = 'OpenCell' if 'Finetuned' in model_name else 'HPA'

    nucleus_image = process_image(nucleus_image, dataset, 'nucleus')
    if protein_image is not None:
        protein_image = process_image(protein_image, dataset, 'protein')
        # Binarise around the median, then drop the two leading dimensions
        # and convert the boolean mask to floats for display.
        protein_image = protein_image > torch.median(protein_image)
        protein_image = protein_image[0, 0]
        protein_image = protein_image * 1.0
    else:
        # No protein image supplied: show an all-ones placeholder.
        protein_image = torch.ones((256, 256))

    threshold, heatmap = run_image_prediction(sequence_input=sequence_input,
                                              nucleus_image=nucleus_image,
                                              model_ckpt_path=model,
                                              model_config_path=config,
                                              device=device)

    heatmap = _render_heatmap(heatmap)

    to_pil = T.ToPILImage()
    return (to_pil(nucleus_image[0, 0]),
            to_pil(protein_image),
            to_pil(threshold),
            heatmap)


# Build the Blocks UI: model selector, sequence box, input images, the four
# output images, and a button that runs `gradio_demo`.
with gr.Blocks() as demo:
    gr.Markdown("Select the prediction model.")
    gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
    gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good for live-cell predictions on HEK cells.")
    with gr.Row():
        model_name = gr.Dropdown(['CELL-E_2_HPA_480', 'CELL-E_2_HPA_Finetuned_480'],
                                 value='CELL-E_2_HPA_480', label='Model Name')
    with gr.Row():
        gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")

    with gr.Row():
        sequence_input = gr.Textbox(value='MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
                                    label='Sequence')
    with gr.Row():
        gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
        gr.Markdown("The protein image is optional and is just used for display.")

    # NOTE(review): `.style(...)` on layout blocks is believed to be
    # deprecated in later Gradio releases -- confirm against the pinned
    # Gradio version before upgrading.
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(value='https://huggingface.co/spaces/HuangLab/CELL-E_2/resolve/main/images/Armadillo%20repeat-containing%20X-linked%20protein%205%20nucleus.jpg',
                                 type='pil',
                                 label='Nucleus Image',
                                 image_mode='L')

        protein_image = gr.Image(type='pil', label='Protein Image (Optional)')

    with gr.Row():
        gr.Markdown("Image predictions are shown below.")

    # Output components, in the order `gradio_demo` returns its results.
    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type='pil',
                                      label='Nucleus Image',
                                      image_mode='L')

        protein_threshold_image = gr.Image(type='pil',
                                           label='Protein Threshold Image',
                                           image_mode='L')

        predicted_threshold_image = gr.Image(type='pil',
                                             label='Predicted Threshold image',
                                             image_mode='L')

        predicted_heatmap = gr.Image(type='pil',
                                     label='Predicted Heatmap')
    with gr.Row():
        button = gr.Button("Run Model")

        # Order must match the positional parameters of `gradio_demo`.
        inputs = [model_name,
                  sequence_input,
                  nucleus_image,
                  protein_image]

        # Order must match the tuple returned by `gradio_demo`.
        outputs = [nucleus_image_crop,
                   protein_threshold_image,
                   predicted_threshold_image,
                   predicted_heatmap]

        button.click(gradio_demo, inputs, outputs)

# Amino acid sequence shared by both example rows (identical to the default
# value of the sequence textbox).
_EXAMPLE_SEQUENCE = 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'

# Example rows: model name, sequence, nucleus image path, protein image path.
examples = [
    [
        'CELL-E_2_HPA_Finetuned_480',
        _EXAMPLE_SEQUENCE,
        'images/Proteasome activator complex subunit 3 nucleus.png',
        'images/Proteasome activator complex subunit 3 protein.png',
    ],
    [
        'CELL-E_2_HPA_480',
        _EXAMPLE_SEQUENCE,
        'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
        'images/Armadillo repeat-containing X-linked protein 5 protein.jpg',
    ],
]

# NOTE: `examples` is only referenced by the disabled Interface call below,
# so the rows above are not currently shown in the UI.
# demo = gr.Interface(gradio_demo, inputs, outputs, examples, cache_examples=True, layout = layout)
demo.launch(share=True)