File size: 6,171 Bytes
35aeee1
86d2765
64212e0
 
 
 
 
 
35aeee1
64e25c0
35aeee1
 
 
 
 
 
 
 
 
 
 
 
fece6fa
 
 
 
35aeee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
815b096
35aeee1
 
 
 
 
 
 
 
 
b75e844
 
 
 
35aeee1
 
b75e844
 
 
 
 
 
 
35aeee1
 
 
 
 
 
 
 
 
86d2765
35aeee1
 
 
86d2765
35aeee1
 
64212e0
35aeee1
 
cbeaab6
35aeee1
 
 
 
 
 
64212e0
35aeee1
64212e0
0a9dccb
d8ffd20
ce23710
cbeaab6
2d4b37f
cbeaab6
 
2d4b37f
cbeaab6
 
 
35aeee1
 
cbeaab6
 
64212e0
cbeaab6
 
 
 
64212e0
cbeaab6
 
 
 
64212e0
cbeaab6
e38e937
cbeaab6
 
 
64212e0
cbeaab6
 
 
 
 
 
 
 
d8ffd20
 
 
64212e0
cbeaab6
 
 
 
 
 
 
 
 
 
 
 
 
 
64212e0
 
 
cbeaab6
 
 
 
 
 
 
 
64212e0
35aeee1
64212e0
e9c0dec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import gradio as gr
from prediction import run_image_prediction
import torch
import torchvision.transforms as T
from celle.utils import process_image
from PIL import Image
from matplotlib import pyplot as plt
from celle_main import instantiate_from_config
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

class model:
    """Lazy-loading wrapper around a CELL-E 2 checkpoint for the Gradio demo.

    Keeps the instantiated (and torch.compile'd) network between calls so
    that repeated predictions with the same model name skip the expensive
    download / instantiate / compile step.
    """

    def __init__(self):
        # Loaded torch module and the HuggingFace model name it came from.
        self.model = None
        self.model_name = None

    def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
        """Run one image prediction for the Gradio UI.

        Args:
            model_name: HuggingFace repo suffix under "HuangLab/", e.g.
                "CELL-E_2_HPA_480".
            sequence_input: amino-acid sequence string.
            nucleus_image: PIL image (mode "L") of the nucleus channel.
            protein_image: optional PIL image, used only for display.

        Returns:
            Tuple of four PIL images: (cropped nucleus, protein threshold,
            predicted threshold, predicted heatmap).
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # (Re)load only when a different checkpoint is requested.
        if self.model_name != model_name:
            self.model_name = model_name
            model_ckpt_path = hf_hub_download(
                repo_id=f"HuangLab/{model_name}", filename="model.ckpt"
            )
            model_config_path = hf_hub_download(
                repo_id=f"HuangLab/{model_name}", filename="config.yaml"
            )
            # Fetch the VQGAN configs referenced by the main config so they
            # land in the same local cache directory as the checkpoint.
            hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
            hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")

            # Load model config and set ckpt_path if not provided in config.
            config = OmegaConf.load(model_config_path)
            if config["model"]["params"]["ckpt_path"] is None:
                config["model"]["params"]["ckpt_path"] = model_ckpt_path

            # Sub-model weights are already contained in the main checkpoint.
            config["model"]["params"]["condition_model_path"] = None
            config["model"]["params"]["vqgan_model_path"] = None

            # instantiate_from_config appears to resolve paths relative to the
            # working directory, so temporarily chdir into the cache dir.
            base_path = os.getcwd()
            os.chdir(os.path.dirname(model_ckpt_path))
            try:
                self.model = instantiate_from_config(config.model).to(device)
                self.model = torch.compile(self.model, mode="max-autotune")
            finally:
                # Fix: restore the working directory even if loading fails;
                # previously an exception left the process chdir'd away.
                os.chdir(base_path)

        # Finetuned checkpoints were trained on OpenCell; all others on HPA.
        if "Finetuned" in model_name:
            dataset = "OpenCell"
        else:
            dataset = "HPA"

        to_tensor = T.ToTensor()
        nucleus_tensor = to_tensor(nucleus_image)

        if protein_image is not None:
            protein_tensor = to_tensor(protein_image)
            stacked_images = torch.stack([nucleus_tensor, protein_tensor], dim=0)
            processed_images = process_image(stacked_images, dataset)
            nucleus_image = processed_images[0].unsqueeze(0)
            protein_image = processed_images[1].unsqueeze(0)
            # Binarize the protein channel for display as a threshold image.
            protein_image = protein_image > 0
            protein_image = 1.0 * protein_image
        else:
            # Bug fix: the nucleus image must still go through process_image
            # when no protein image is supplied; previously the raw PIL image
            # was forwarded to run_image_prediction and the tensor indexing
            # nucleus_image[0, 0] below would fail. Stack with a zero channel
            # so process_image sees the same input shape as the branch above
            # (assumes processing is per-channel-content independent — TODO
            # confirm against process_image).
            stacked_images = torch.stack(
                [nucleus_tensor, torch.zeros_like(nucleus_tensor)], dim=0
            )
            nucleus_image = process_image(stacked_images, dataset)[0].unsqueeze(0)
            # Plain white placeholder shown when no protein image was given.
            protein_image = torch.ones((256, 256))

        threshold, heatmap = run_image_prediction(
            sequence_input=sequence_input,
            nucleus_image=nucleus_image,
            model=self.model,
            device=device,
        )

        # Render the heatmap via matplotlib and round-trip through a PNG so
        # Gradio receives a colormapped PIL image.
        plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
        plt.axis("off")

        # Save the plot to a temporary file
        plt.savefig("temp.png", bbox_inches="tight", dpi=256)
        # Fix: close the figure to avoid accumulating matplotlib state and
        # memory across repeated button clicks.
        plt.close()

        # Open the temporary file as a PIL image
        heatmap = Image.open("temp.png")

        return (
            T.ToPILImage()(nucleus_image[0, 0]),
            T.ToPILImage()(protein_image),
            T.ToPILImage()(threshold),
            heatmap,
        )

# Single shared wrapper instance so repeated clicks reuse the loaded model.
base_class = model()

# --- Gradio UI layout -------------------------------------------------------
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Inputs")
    gr.Markdown("Select the prediction model. **Note the first run may take ~1-2 minutes, but will take 2-3 seconds afterwards.**")
    gr.Markdown(
        "```CELL-E_2_HPA_480``` is a good general purpose model for various cell types using ICC-IF."
    )
    # Fix: "is good more live-cell predictions" -> "is good for live-cell predictions".
    gr.Markdown(
        "```CELL-E_2_HPA_Finetuned_480``` is finetuned on OpenCell and is good for live-cell predictions on HEK cells."
    )
    with gr.Row():
        model_name = gr.Dropdown(
            ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
            value="CELL-E_2_HPA_480",
            label="Model Name",
        )
    with gr.Row():
        gr.Markdown(
            "Input the desired amino acid sequence. GFP is shown below by default."
        )

    with gr.Row():
        # Default sequence is GFP.
        sequence_input = gr.Textbox(
            value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
            label="Sequence",
        )
    with gr.Row():
        gr.Markdown(
            "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger. We provide default images in [images](https://huggingface.co/spaces/HuangLab/CELL-E_2/tree/main/images)"
        )
        gr.Markdown("The protein image is optional and is just used for display.")

    # NOTE(review): .style() is removed in Gradio 4.x — kept as-is since this
    # Space pins an older Gradio version; confirm before upgrading.
    with gr.Row().style(equal_height=True):
        nucleus_image = gr.Image(
            type="pil",
            label="Nucleus Image",
            image_mode="L",
        )

        protein_image = gr.Image(type="pil", label="Protein Image (Optional)")

    with gr.Row():
        gr.Markdown("## Outputs")

    with gr.Row():
        # Fix: "are show below" -> "are shown below".
        gr.Markdown("Image predictions are shown below.")

    with gr.Row().style(equal_height=True):
        nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")

        protein_threshold_image = gr.Image(
            type="pil", label="Protein Threshold Image", image_mode="L"
        )

        predicted_threshold_image = gr.Image(
            type="pil", label="Predicted Threshold image", image_mode="L"
        )

        predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
    with gr.Row():
        button = gr.Button("Run Model")

        # Wire the button to the prediction method: four inputs in, four
        # images out, matching gradio_demo's signature and return tuple.
        inputs = [model_name, sequence_input, nucleus_image, protein_image]

        outputs = [
            nucleus_image_crop,
            protein_threshold_image,
            predicted_threshold_image,
            predicted_heatmap,
        ]

        button.click(base_class.gradio_demo, inputs, outputs)

demo.launch(enable_queue=True)