EmaadKhwaja commited on
Commit
cbeaab6
1 Parent(s): 3c77a96

formatting

Browse files
Files changed (1) hide show
  1. app.py +100 -77
app.py CHANGED
@@ -14,108 +14,131 @@ def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
14
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
15
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
- if 'Finetuned' in model_name:
19
- dataset = 'OpenCell'
20
 
21
  else:
22
- dataset = 'HPA'
23
-
24
- nucleus_image = process_image(nucleus_image,dataset,'nucleus')
25
  if protein_image:
26
- protein_image = process_image(protein_image,dataset,'protein')
27
  protein_image = protein_image > torch.median(protein_image)
28
- protein_image = protein_image[0,0]
29
- protein_image = protein_image*1.0
30
  else:
31
- protein_image = torch.ones((256,256))
32
-
33
-
34
- threshold, heatmap = run_image_prediction(sequence_input = sequence_input,
35
- nucleus_image = nucleus_image,
36
- model_ckpt_path=model,
37
- model_config_path=config,
38
- device=device)
39
-
 
40
  # Plot the heatmap
41
- plt.imshow(heatmap.cpu(), cmap='rainbow', interpolation = 'bicubic')
42
- plt.axis('off')
43
 
44
  # Save the plot to a temporary file
45
- plt.savefig('temp.png', bbox_inches='tight', dpi = 256)
46
 
47
  # Open the temporary file as a PIL image
48
- heatmap = Image.open('temp.png')
49
-
50
- return T.ToPILImage()(nucleus_image[0,0]), T.ToPILImage()(protein_image), T.ToPILImage()(threshold), heatmap
 
 
 
 
 
51
 
52
 
53
  with gr.Blocks() as demo:
54
  gr.Markdown("Select the prediction model.")
55
- gr.Markdown("CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF.")
56
- gr.Markdown("CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good more live-cell predictions on HEK cells.")
 
 
 
 
 
 
 
 
 
 
57
  with gr.Row():
58
- model_name = gr.Dropdown(['CELL-E_2_HPA_480','CELL-E_2_HPA_Finetuned_480'],
59
- value='CELL-E_2_HPA_480', label = 'Model Name')
 
 
60
  with gr.Row():
61
- gr.Markdown("Input the desired amino acid sequence. GFP is shown below by default.")
62
-
63
- with gr.Row():
64
- sequence_input = gr.Textbox(value='MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
65
- label = 'Sequence')
66
  with gr.Row():
67
- gr.Markdown("Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger.")
68
- gr.Markdown("The protein image is optional and is just used for display.")
69
-
 
 
70
  with gr.Row().style(equal_height=True):
71
- nucleus_image = gr.Image(value = 'https://huggingface.co/spaces/HuangLab/CELL-E_2/resolve/main/images/Armadillo%20repeat-containing%20X-linked%20protein%205%20nucleus.jpg',
72
- type='pil',
73
- label = 'Nucleus Image',
74
- image_mode='L')
75
-
76
- protein_image = gr.Image(type='pil', label = 'Protein Image (Optional)')
77
-
 
 
78
  with gr.Row():
79
- gr.Markdown("Image predictions are show below.")
80
-
81
- with gr.Row().style(equal_height=True):
82
- nucleus_image_crop = gr.Image(type='pil',
83
- label = 'Nucleus Image',
84
- image_mode='L')
85
-
86
- protein_threshold_image = gr.Image(type='pil',
87
- label = 'Protein Threshold Image',
88
- image_mode='L')
89
-
90
- predicted_threshold_image = gr.Image(type='pil',
91
- label = 'Predicted Threshold image',
92
- image_mode='L')
93
-
94
- predicted_heatmap = gr.Image(type='pil',
95
- label = 'Predicted Heatmap')
96
  with gr.Row():
97
  button = gr.Button("Run Model")
98
-
99
- inputs = [model_name,
100
- sequence_input,
101
- nucleus_image,
102
- protein_image]
103
 
104
- outputs = [nucleus_image_crop,
105
- protein_threshold_image,
106
- predicted_threshold_image,
107
- predicted_heatmap]
 
 
 
 
108
 
109
  button.click(gradio_demo, inputs, outputs)
110
 
111
- examples = [['CELL-E_2_HPA_Finetuned_480',
112
- 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
113
- 'images/Proteasome activator complex subunit 3 nucleus.png',
114
- 'images/Proteasome activator complex subunit 3 protein.png'],
115
- ['CELL-E_2_HPA_480',
116
- 'MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK',
117
- 'images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg',
118
- 'images/Armadillo repeat-containing X-linked protein 5 protein.jpg']]
 
 
 
 
 
 
119
 
120
  # demo = gr.Interface(gradio_demo, inputs, outputs, examples, cache_examples=True, layout = layout)
121
  demo.launch(share=True)
 
14
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="nucleus_vqgan.yaml")
15
  hf_hub_download(repo_id=f"HuangLab/{model_name}", filename="threshold_vqgan.yaml")
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ if "Finetuned" in model_name:
19
+ dataset = "OpenCell"
20
 
21
  else:
22
+ dataset = "HPA"
23
+
24
+ nucleus_image = process_image(nucleus_image, dataset, "nucleus")
25
  if protein_image:
26
+ protein_image = process_image(protein_image, dataset, "protein")
27
  protein_image = protein_image > torch.median(protein_image)
28
+ protein_image = protein_image[0, 0]
29
+ protein_image = protein_image * 1.0
30
  else:
31
+ protein_image = torch.ones((256, 256))
32
+
33
+ threshold, heatmap = run_image_prediction(
34
+ sequence_input=sequence_input,
35
+ nucleus_image=nucleus_image,
36
+ model_ckpt_path=model,
37
+ model_config_path=config,
38
+ device=device,
39
+ )
40
+
41
  # Plot the heatmap
42
+ plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
43
+ plt.axis("off")
44
 
45
  # Save the plot to a temporary file
46
+ plt.savefig("temp.png", bbox_inches="tight", dpi=256)
47
 
48
  # Open the temporary file as a PIL image
49
+ heatmap = Image.open("temp.png")
50
+
51
+ return (
52
+ T.ToPILImage()(nucleus_image[0, 0]),
53
+ T.ToPILImage()(protein_image),
54
+ T.ToPILImage()(threshold),
55
+ heatmap,
56
+ )
57
 
58
 
59
  with gr.Blocks() as demo:
60
  gr.Markdown("Select the prediction model.")
61
+ gr.Markdown(
62
+ "CELL-E_2_HPA_480 is a good general purpose model for various cell types using ICC-IF."
63
+ )
64
+ gr.Markdown(
65
+ "CELL-E_2_HPA_Finetuned_480 is finetuned on OpenCell and is good more live-cell predictions on HEK cells."
66
+ )
67
+ with gr.Row():
68
+ model_name = gr.Dropdown(
69
+ ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
70
+ value="CELL-E_2_HPA_480",
71
+ label="Model Name",
72
+ )
73
  with gr.Row():
74
+ gr.Markdown(
75
+ "Input the desired amino acid sequence. GFP is shown below by default."
76
+ )
77
+
78
  with gr.Row():
79
+ sequence_input = gr.Textbox(
80
+ value="MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
81
+ label="Sequence",
82
+ )
 
83
  with gr.Row():
84
+ gr.Markdown(
85
+ "Uploading a nucleus image is necessary. A random crop of 256 x 256 will be applied if larger."
86
+ )
87
+ gr.Markdown("The protein image is optional and is just used for display.")
88
+
89
  with gr.Row().style(equal_height=True):
90
+ nucleus_image = gr.Image(
91
+ value="https://huggingface.co/spaces/HuangLab/CELL-E_2/resolve/main/images/Armadillo%20repeat-containing%20X-linked%20protein%205%20nucleus.jpg",
92
+ type="pil",
93
+ label="Nucleus Image",
94
+ image_mode="L",
95
+ )
96
+
97
+ protein_image = gr.Image(type="pil", label="Protein Image (Optional)")
98
+
99
  with gr.Row():
100
+ gr.Markdown("Image predictions are show below.")
101
+
102
+ with gr.Row().style(equal_height=True):
103
+ nucleus_image_crop = gr.Image(type="pil", label="Nucleus Image", image_mode="L")
104
+
105
+ protein_threshold_image = gr.Image(
106
+ type="pil", label="Protein Threshold Image", image_mode="L"
107
+ )
108
+
109
+ predicted_threshold_image = gr.Image(
110
+ type="pil", label="Predicted Threshold image", image_mode="L"
111
+ )
112
+
113
+ predicted_heatmap = gr.Image(type="pil", label="Predicted Heatmap")
 
 
 
114
  with gr.Row():
115
  button = gr.Button("Run Model")
 
 
 
 
 
116
 
117
+ inputs = [model_name, sequence_input, nucleus_image, protein_image]
118
+
119
+ outputs = [
120
+ nucleus_image_crop,
121
+ protein_threshold_image,
122
+ predicted_threshold_image,
123
+ predicted_heatmap,
124
+ ]
125
 
126
  button.click(gradio_demo, inputs, outputs)
127
 
128
+ examples = [
129
+ [
130
+ "CELL-E_2_HPA_Finetuned_480",
131
+ "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
132
+ "images/Proteasome activator complex subunit 3 nucleus.png",
133
+ "images/Proteasome activator complex subunit 3 protein.png",
134
+ ],
135
+ [
136
+ "CELL-E_2_HPA_480",
137
+ "MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK",
138
+ "images/Armadillo repeat-containing X-linked protein 5 nucleus.jpg",
139
+ "images/Armadillo repeat-containing X-linked protein 5 protein.jpg",
140
+ ],
141
+ ]
142
 
143
  # demo = gr.Interface(gradio_demo, inputs, outputs, examples, cache_examples=True, layout = layout)
144
  demo.launch(share=True)