test-3

Files changed:
- Segformer_best_state_dict.ckpt +3 -0
- app.py +27 -25
- load_lightning_SD_to_Usual_SD.ipynb +7 -21
Segformer_best_state_dict.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:800bb5ba3fff6c5539542cc6d9548da73dbc1a35c0dc686f0bade3b3c6c5746c
+size 256373829
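What got committed here is a Git LFS pointer rather than the checkpoint bytes themselves: the roughly 256 MB state dict is stored in LFS and identified by the sha256 oid above. After cloning the Space and running git lfs pull, one quick way to confirm the downloaded file matches the pointer is to hash it, as in the small sketch below (only the file name and oid come from this commit; everything else is illustrative).

import hashlib

# Hash the materialized LFS file and compare it against the pointer's sha256 oid.
EXPECTED_OID = "800bb5ba3fff6c5539542cc6d9548da73dbc1a35c0dc686f0bade3b3c6c5746c"

sha256 = hashlib.sha256()
with open("Segformer_best_state_dict.ckpt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha256.update(chunk)

assert sha256.hexdigest() == EXPECTED_OID, "checkpoint does not match the LFS pointer"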
app.py
CHANGED
@@ -18,7 +18,7 @@ class Configs:
     IMAGE_SIZE: tuple[int, int] = (288, 288)  # W, H
     MEAN: tuple = (0.485, 0.456, 0.406)
     STD: tuple = (0.229, 0.224, 0.225)
-    MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
+    MODEL_PATH: str = "nvidia/segformer-b4-finetuned-ade-512-512"  # os.path.join(os.getcwd(), "segformer_trained_weights")
 
 
 def get_model(*, model_path, num_classes):
@@ -58,6 +58,8 @@ if __name__ == "__main__":
     CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
 
     model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
+    model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
+
     model.to(DEVICE)
     model.eval()
     _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
@@ -70,29 +72,29 @@ if __name__ == "__main__":
         ]
     )
 
-    images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
-    examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
-    demo = gr.Interface(
-        fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
-        inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
-        outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
-        examples=examples,
-        cache_examples=False,
-        allow_flagging="never",
-        title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
-    )
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
+    # examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
+    # demo = gr.Interface(
+    #     fn=partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE),
+    #     inputs=gr.Image(type="pil", height=300, width=300, label="Input image"),
+    #     outputs=gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor),
+    #     examples=examples,
+    #     cache_examples=False,
+    #     allow_flagging="never",
+    #     title="Medical Image Segmentation with UW-Madison GI Tract Dataset",
+    # )
+
+    with gr.Blocks(title="Medical Image Segmentation") as demo:
+        gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
+        with gr.Row():
+            img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
+            img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
+
+        section_btn = gr.Button("Generate Predictions")
+        section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
+
+        images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
+        examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
+        gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
 
     demo.launch()
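For context on the app.py changes: MODEL_PATH now points at the Hugging Face Hub ID nvidia/segformer-b4-finetuned-ade-512-512, and the fine-tuned weights are applied afterwards from Segformer_best_state_dict.ckpt via load_state_dict. The body of get_model is not part of this diff, so the sketch below is only one plausible implementation; the SegformerForSemanticSegmentation call, ignore_mismatched_sizes, and NUM_CLASSES = 4 are assumptions.

import os
import torch
from transformers import SegformerForSemanticSegmentation

NUM_CLASSES = 4  # assumed: large bowel, small bowel, stomach + background; the real value lives in Configs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def get_model(*, model_path, num_classes):
    # Assumed implementation: start from the ADE20K-finetuned SegFormer-B4 and
    # re-size its classifier head, which requires ignore_mismatched_sizes=True.
    return SegformerForSemanticSegmentation.from_pretrained(
        model_path,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )


model = get_model(model_path="nvidia/segformer-b4-finetuned-ade-512-512", num_classes=NUM_CLASSES)

# The .ckpt added in this commit holds a plain state dict, so it can be applied directly.
CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))
model.to(DEVICE).eval()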
load_lightning_SD_to_Usual_SD.ipynb
CHANGED
@@ -109,20 +109,6 @@
     "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mveb-101\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
    ]
   },
-  {
-   "data": {
-    "application/vnd.jupyter.widget-view+json": {
-     "model_id": "2e6699f8bae4469fb42d361bf569b161",
-     "version_major": 2,
-     "version_minor": 0
-    },
-    "text/plain": [
-     "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016933333332902596, max=1.0…"
-    ]
-   },
-   "metadata": {},
-   "output_type": "display_data"
-  },
   {
    "data": {
     "text/html": [
@@ -138,7 +124,7 @@
   {
    "data": {
    "text/html": [
-     "Run data is saved locally in <code>c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-
+     "Run data is saved locally in <code>c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-20230719_111329-pfqbikbe</code>"
    ],
    "text/plain": [
     "<IPython.core.display.HTML object>"
@@ -150,7 +136,7 @@
   {
    "data": {
    "text/html": [
-     "Syncing run <strong><a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/
+     "Syncing run <strong><a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/pfqbikbe' target=\"_blank\">generous-music-2</a></strong> to <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
    ],
    "text/plain": [
     "<IPython.core.display.HTML object>"
@@ -174,7 +160,7 @@
   {
    "data": {
    "text/html": [
-     " View run at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/
+     " View run at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/pfqbikbe' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/pfqbikbe</a>"
    ],
    "text/plain": [
     "<IPython.core.display.HTML object>"
@@ -189,7 +175,7 @@
    "text": [
     "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-jsr2fn8v:v0, 977.89MB. 1 files... \n",
     "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
-     "Done. 0:0:
+     "Done. 0:0:3.9\n"
    ]
   }
  ],
@@ -294,12 +280,12 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "
+   "torch.save(model.state_dict(), \"Segformer_best_state_dict.ckpt\")"
   ]
  },
 {
  "cell_type": "code",
-  "execution_count":
+  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -308,7 +294,7 @@
 },
 {
  "cell_type": "code",
-  "execution_count":
+  "execution_count": null,
  "metadata": {},
 "outputs": [],
 "source": [
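The notebook edits above show where the committed checkpoint comes from: a W&B model artifact (model-jsr2fn8v:v0, about 978 MB) is downloaded, the Lightning checkpoint is loaded into the plain model, and the weights are re-saved with torch.save(model.state_dict(), "Segformer_best_state_dict.ckpt"). Most of the notebook's source cells are not visible in this diff, so the sketch below only illustrates the usual Lightning-to-plain-state-dict conversion; the artifact file name, the "model." key prefix, and the from_pretrained arguments are assumptions.

import torch
import wandb
from transformers import SegformerForSemanticSegmentation

# Sketch, not the notebook's exact code: pull the W&B artifact referenced in the
# log output above and strip the LightningModule attribute prefix from its keys.
run = wandb.init(project="UWMGI_Medical_Image_Segmentation")
artifact = run.use_artifact("model-jsr2fn8v:v0", type="model")
ckpt_dir = artifact.download()

lightning_ckpt = torch.load(f"{ckpt_dir}/model.ckpt", map_location="cpu")  # file name inside the artifact is assumed
plain_sd = {k.removeprefix("model."): v for k, v in lightning_ckpt["state_dict"].items()}  # "model." prefix assumed

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b4-finetuned-ade-512-512",
    num_labels=4,  # assumed: 3 organ classes + background
    ignore_mismatched_sizes=True,
)
model.load_state_dict(plain_sd)

# This line is taken from the notebook diff above.
torch.save(model.state_dict(), "Segformer_best_state_dict.ckpt")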