Emaad commited on
Commit
35aeee1
1 Parent(s): dacc4bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -46
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from prediction import run_image_prediction
3
  import torch
@@ -5,54 +6,84 @@ import torchvision.transforms as T
5
  from celle.utils import process_image
6
  from PIL import Image
7
  from matplotlib import pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
9
 
10
- def gradio_demo(model_name, sequence_input, nucleus_image, protein_image):
11
- model = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
12
- config = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
13
-
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
- if "Finetuned" in model_name:
17
- dataset = "OpenCell"
18
-
19
- else:
20
- dataset = "HPA"
21
-
22
- nucleus_image = process_image(nucleus_image, dataset, "nucleus")
23
- if protein_image:
24
- protein_image = process_image(protein_image, dataset, "protein")
25
- protein_image = protein_image > torch.median(protein_image)
26
- protein_image = protein_image[0, 0]
27
- protein_image = protein_image * 1.0
28
- else:
29
- protein_image = torch.ones((256, 256))
30
-
31
- threshold, heatmap = run_image_prediction(
32
- sequence_input=sequence_input,
33
- nucleus_image=nucleus_image,
34
- model_ckpt_path=model,
35
- model_config_path=config,
36
- device=device,
37
- )
38
-
39
- # Plot the heatmap
40
- plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
41
- plt.axis("off")
42
-
43
- # Save the plot to a temporary file
44
- plt.savefig("temp.png", bbox_inches="tight", dpi=256)
45
 
46
- # Open the temporary file as a PIL image
47
- heatmap = Image.open("temp.png")
48
 
49
- return (
50
- T.ToPILImage()(nucleus_image[0, 0]),
51
- T.ToPILImage()(protein_image),
52
- T.ToPILImage()(threshold),
53
- heatmap,
54
- )
55
 
 
56
 
57
  with gr.Blocks(theme='gradio/soft') as demo:
58
  gr.Markdown("Select the prediction model.")
@@ -64,8 +95,8 @@ with gr.Blocks(theme='gradio/soft') as demo:
64
  )
65
  with gr.Row():
66
  model_name = gr.Dropdown(
67
- ["CELL-E_2-HPA_480", "CELL-E_2-HPA_Finetuned_480"],
68
- value="CELL-E_2-HPA_480",
69
  label="Model Name",
70
  )
71
  with gr.Row():
@@ -120,6 +151,6 @@ with gr.Blocks(theme='gradio/soft') as demo:
120
  predicted_heatmap,
121
  ]
122
 
123
- button.click(gradio_demo, inputs, outputs)
124
 
125
  demo.launch(enable_queue=True)
 
1
+ import os
2
  import gradio as gr
3
  from prediction import run_image_prediction
4
  import torch
 
6
  from celle.utils import process_image
7
  from PIL import Image
8
  from matplotlib import pyplot as plt
9
+ from celle_main import instantiate_from_config
10
+ from omegaconf import OmegaConf
11
+
12
+
13
+ class model:
14
+ def __init__(self):
15
+ self.model = None
16
+ self.model_name = None
17
+
18
+ def gradio_demo(self, model_name, sequence_input, nucleus_image, protein_image):
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ if self.model_name != model_name:
22
+ self.model_name = model_name
23
+ model_ckpt_path = f"CELL-E_2-Image_Prediction/models/{model_name}.ckpt"
24
+ model_config_path = f"CELL-E_2-Image_Prediction/models/{model_name}.yaml"
25
+
26
+
27
+ # Load model config and set ckpt_path if not provided in config
28
+ config = OmegaConf.load(model_config_path)
29
+ if config["model"]["params"]["ckpt_path"] is None:
30
+ config["model"]["params"]["ckpt_path"] = model_ckpt_path
31
+
32
+ # Set condition_model_path and vqgan_model_path to None
33
+ config["model"]["params"]["condition_model_path"] = None
34
+ config["model"]["params"]["vqgan_model_path"] = None
35
+
36
+ base_path = os.getcwd()
37
+
38
+ os.chdir(os.path.dirname(model_ckpt_path))
39
+
40
+ # Instantiate model from config and move to device
41
+ self.model = instantiate_from_config(config.model).to(device)
42
+ self.model = torch.compile(self.model,mode='reduce-overhead')
43
+
44
+ os.chdir(base_path)
45
+
46
+
47
+ if "Finetuned" in model_name:
48
+ dataset = "OpenCell"
49
+
50
+ else:
51
+ dataset = "HPA"
52
+
53
+ nucleus_image = process_image(nucleus_image, dataset, "nucleus")
54
+ if protein_image:
55
+ protein_image = process_image(protein_image, dataset, "protein")
56
+ protein_image = protein_image > torch.median(protein_image)
57
+ protein_image = protein_image[0, 0]
58
+ protein_image = protein_image * 1.0
59
+ else:
60
+ protein_image = torch.ones((256, 256))
61
+
62
+ threshold, heatmap = run_image_prediction(
63
+ sequence_input=sequence_input,
64
+ nucleus_image=nucleus_image,
65
+ model=self.model,
66
+ device=device,
67
+ )
68
 
69
+ # Plot the heatmap
70
+ plt.imshow(heatmap.cpu(), cmap="rainbow", interpolation="bicubic")
71
+ plt.axis("off")
72
 
73
+ # Save the plot to a temporary file
74
+ plt.savefig("temp.png", bbox_inches="tight", dpi=256)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ # Open the temporary file as a PIL image
77
+ heatmap = Image.open("temp.png")
78
 
79
+ return (
80
+ T.ToPILImage()(nucleus_image[0, 0]),
81
+ T.ToPILImage()(protein_image),
82
+ T.ToPILImage()(threshold),
83
+ heatmap,
84
+ )
85
 
86
+ base_class = model()
87
 
88
  with gr.Blocks(theme='gradio/soft') as demo:
89
  gr.Markdown("Select the prediction model.")
 
95
  )
96
  with gr.Row():
97
  model_name = gr.Dropdown(
98
+ ["CELL-E_2_HPA_480", "CELL-E_2_HPA_Finetuned_480"],
99
+ value="CELL-E_2_HPA_480",
100
  label="Model Name",
101
  )
102
  with gr.Row():
 
151
  predicted_heatmap,
152
  ]
153
 
154
+ button.click(base_class.gradio_demo, inputs, outputs)
155
 
156
  demo.launch(enable_queue=True)