veb-101 commited on
Commit
c438991
·
1 Parent(s): 30534a5

First commit - All files added.

Browse files
Files changed (47) hide show
  1. .gitignore +2 -0
  2. README.md +21 -7
  3. app.py +89 -0
  4. load_lightning_SD_to_Usual_SD.ipynb +341 -0
  5. requirements.txt +4 -0
  6. samples/case101_day26_slice_0096_266_266_1.50_1.50.png +0 -0
  7. samples/case107_day0_slice_0089_266_266_1.50_1.50.png +0 -0
  8. samples/case107_day21_slice_0069_266_266_1.50_1.50.png +0 -0
  9. samples/case113_day12_slice_0108_360_310_1.50_1.50.png +0 -0
  10. samples/case119_day20_slice_0063_266_266_1.50_1.50.png +0 -0
  11. samples/case119_day25_slice_0075_266_266_1.50_1.50.png +0 -0
  12. samples/case119_day25_slice_0095_266_266_1.50_1.50.png +0 -0
  13. samples/case121_day14_slice_0057_266_266_1.50_1.50.png +0 -0
  14. samples/case122_day25_slice_0087_266_266_1.50_1.50.png +0 -0
  15. samples/case124_day19_slice_0110_266_266_1.50_1.50.png +0 -0
  16. samples/case124_day20_slice_0110_266_266_1.50_1.50.png +0 -0
  17. samples/case130_day0_slice_0106_266_266_1.50_1.50.png +0 -0
  18. samples/case134_day21_slice_0085_360_310_1.50_1.50.png +0 -0
  19. samples/case139_day0_slice_0062_234_234_1.50_1.50.png +0 -0
  20. samples/case139_day18_slice_0094_266_266_1.50_1.50.png +0 -0
  21. samples/case146_day25_slice_0053_276_276_1.63_1.63.png +0 -0
  22. samples/case147_day0_slice_0085_360_310_1.50_1.50.png +0 -0
  23. samples/case148_day0_slice_0113_360_310_1.50_1.50.png +0 -0
  24. samples/case149_day15_slice_0057_266_266_1.50_1.50.png +0 -0
  25. samples/case29_day0_slice_0065_266_266_1.50_1.50.png +0 -0
  26. samples/case2_day1_slice_0054_266_266_1.50_1.50.png +0 -0
  27. samples/case2_day1_slice_0077_266_266_1.50_1.50.png +0 -0
  28. samples/case32_day19_slice_0091_266_266_1.50_1.50.png +0 -0
  29. samples/case32_day19_slice_0100_266_266_1.50_1.50.png +0 -0
  30. samples/case33_day21_slice_0114_266_266_1.50_1.50.png +0 -0
  31. samples/case36_day16_slice_0064_266_266_1.50_1.50.png +0 -0
  32. samples/case40_day0_slice_0094_266_266_1.50_1.50.png +0 -0
  33. samples/case41_day25_slice_0049_266_266_1.50_1.50.png +0 -0
  34. samples/case63_day22_slice_0076_266_266_1.50_1.50.png +0 -0
  35. samples/case63_day26_slice_0093_266_266_1.50_1.50.png +0 -0
  36. samples/case65_day28_slice_0133_266_266_1.50_1.50.png +0 -0
  37. samples/case66_day36_slice_0101_266_266_1.50_1.50.png +0 -0
  38. samples/case67_day0_slice_0049_266_266_1.50_1.50.png +0 -0
  39. samples/case67_day0_slice_0086_266_266_1.50_1.50.png +0 -0
  40. samples/case74_day18_slice_0101_266_266_1.50_1.50.png +0 -0
  41. samples/case74_day19_slice_0084_266_266_1.50_1.50.png +0 -0
  42. samples/case81_day28_slice_0066_266_266_1.50_1.50.png +0 -0
  43. samples/case85_day29_slice_0102_360_310_1.50_1.50.png +0 -0
  44. samples/case89_day19_slice_0082_360_310_1.50_1.50.png +0 -0
  45. samples/case89_day20_slice_0087_266_266_1.50_1.50.png +0 -0
  46. segformer_trained_weights/config.json +82 -0
  47. segformer_trained_weights/pytorch_model.bin +3 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ artifacts
2
+ wandb
README.md CHANGED
@@ -1,13 +1,27 @@
1
  ---
2
- title: UWMGI Medical Image Segmentation
3
- emoji: 📈
4
- colorFrom: pink
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.37.0
8
  app_file: app.py
9
  pinned: false
10
- license: gpl-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Medical Image Segmentation Gradio App
3
+ emoji: 🌖
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+
13
+ # Medical Image Segmentation Gradio App
14
+
15
+ For the Gradio app we've removed the dependency on pytorch-lightning otherwise used in the project.
16
+ The `load_lightning_SD_to_Usual_SD.ipynb` notebook contains the steps used to convert pytorch-lightning checkpoint to a regular model checkpoint. This was mainly done to reduce the file size (977 MB --> 244 MB).
17
+
18
+ You can download the original saved checkpoint from over here: [wandb artifact](https://wandb.ai/veb-101/UM_medical_segmentation/artifacts/model/model-jsr2fn8v/v0/files)
19
+
20
+ Or via Python:
21
+
22
+ ```python
23
+ import wandb
24
+ run = wandb.init()
25
+ artifact = run.use_artifact('veb-101/UM_medical_segmentation/model-jsr2fn8v:v0', type='model')
26
+ artifact_dir = artifact.download()
27
+ ```
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ from glob import glob
5
+ from functools import partial
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as TF
11
+ from transformers import SegformerForSemanticSegmentation
12
+
13
+
14
+ @dataclass
15
+ class Configs:
16
+ NUM_CLASSES: int = 4 # including background.
17
+ CLASSES: tuple = ("Large bowel", "Small bowel", "Stomach")
18
+ IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
19
+ MEAN: tuple = (0.485, 0.456, 0.406)
20
+ STD: tuple = (0.229, 0.224, 0.225)
21
+ MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
22
+
23
+
24
+ def get_model(*, model_path, num_classes):
25
+ model = SegformerForSemanticSegmentation.from_pretrained(model_path, num_labels=num_classes, ignore_mismatched_sizes=True)
26
+ return model
27
+
28
+
29
+ @torch.inference_mode()
30
+ def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
31
+ shape_H_W = input_image.size
32
+ input_tensor = preprocess_fn(input_image)
33
+ input_tensor = input_tensor.unsqueeze(0).to(device)
34
+
35
+ # Generate predictions
36
+ outputs = model(pixel_values=input_tensor.to(device), return_dict=True)
37
+ predictions = F.interpolate(outputs["logits"], size=shape_H_W, mode="bilinear", align_corners=False)
38
+
39
+ preds_argmax = predictions.argmax(dim=1).cpu().squeeze().numpy()
40
+
41
+ seg_info = [(preds_argmax == idx, class_name) for idx, class_name in enumerate(Configs.CLASSES, 1)]
42
+
43
+ return (input_image, seg_info)
44
+
45
+
46
+ if __name__ == "__main__":
47
+ # Create a mapping of class ID to RGB value.
48
+ id2color = {
49
+ 0: (0, 0, 0), # background pixel
50
+ 1: (0, 0, 255), # Stomach
51
+ 2: (0, 255, 0), # Small bowel
52
+ 3: (255, 0, 0), # large bowel
53
+ }
54
+
55
+ class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
56
+
57
+ DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
58
+ CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
59
+
60
+ model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
61
+ model.to(DEVICE)
62
+ model.eval()
63
+ _ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
64
+
65
+ preprocess = TF.Compose(
66
+ [
67
+ TF.Resize(size=Configs.IMAGE_SIZE[::-1]),
68
+ TF.ToTensor(),
69
+ TF.Normalize(Configs.MEAN, Configs.STD, inplace=True),
70
+ ]
71
+ )
72
+
73
+ with gr.Blocks(title="Medical Image Segmentation") as demo:
74
+ gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
75
+
76
+ with gr.Row():
77
+ img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
78
+ img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
79
+
80
+ section_btn = gr.Button("Generate Predictions")
81
+
82
+ section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
83
+
84
+ images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
85
+ examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
86
+
87
+ gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
88
+
89
+ demo.launch()
load_lightning_SD_to_Usual_SD.ipynb ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Base Configurations"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 1,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "from transformers import SegformerForSemanticSegmentation\n",
19
+ "from dataclasses import dataclass\n",
20
+ "\n",
21
+ "\n",
22
+ "@dataclass\n",
23
+ "class Configs:\n",
24
+ " NUM_CLASSES = 4\n",
25
+ " MODEL_PATH: str = \"nvidia/segformer-b4-finetuned-ade-512-512\""
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": [
32
+ "## Load Model To Inspect Parameter Names"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 2,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "\n",
42
+ "\n",
43
+ "def get_model(*, model_path, num_classes):\n",
44
+ " model = SegformerForSemanticSegmentation.from_pretrained(\n",
45
+ " model_path,\n",
46
+ " num_labels=num_classes,\n",
47
+ " ignore_mismatched_sizes=True,\n",
48
+ " )\n",
49
+ " return model"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 3,
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "name": "stderr",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:\n",
62
+ "- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated\n",
63
+ "- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated\n",
64
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
65
+ ]
66
+ },
67
+ {
68
+ "name": "stdout",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "\n",
72
+ "segformer.encoder.patch_embeddings.0.proj.weight\n",
73
+ "segformer.encoder.patch_embeddings.0.proj.bias\n",
74
+ "segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
75
+ "segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
76
+ "segformer.encoder.patch_embeddings.1.proj.weight\n",
77
+ "segformer.encoder.patch_embeddings.1.proj.bias\n"
78
+ ]
79
+ }
80
+ ],
81
+ "source": [
82
+ "model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)\n",
83
+ "model_state_dict = model.state_dict()\n",
84
+ "\n",
85
+ "print()\n",
86
+ "for i, (key, val) in enumerate(model_state_dict.items()):\n",
87
+ " print(key)\n",
88
+ " if i == 5:\n",
89
+ " break"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "metadata": {},
95
+ "source": [
96
+ "## Download & load PyTorch-Lightning Checkpoint and Inspect Parameter Names"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 4,
102
+ "metadata": {},
103
+ "outputs": [
104
+ {
105
+ "name": "stderr",
106
+ "output_type": "stream",
107
+ "text": [
108
+ "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
109
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mveb-101\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
110
+ ]
111
+ },
112
+ {
113
+ "data": {
114
+ "application/vnd.jupyter.widget-view+json": {
115
+ "model_id": "2e6699f8bae4469fb42d361bf569b161",
116
+ "version_major": 2,
117
+ "version_minor": 0
118
+ },
119
+ "text/plain": [
120
+ "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016933333332902596, max=1.0…"
121
+ ]
122
+ },
123
+ "metadata": {},
124
+ "output_type": "display_data"
125
+ },
126
+ {
127
+ "data": {
128
+ "text/html": [
129
+ "Tracking run with wandb version 0.15.5"
130
+ ],
131
+ "text/plain": [
132
+ "<IPython.core.display.HTML object>"
133
+ ]
134
+ },
135
+ "metadata": {},
136
+ "output_type": "display_data"
137
+ },
138
+ {
139
+ "data": {
140
+ "text/html": [
141
+ "Run data is saved locally in <code>c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-20230719_044820-hnv9dwr2</code>"
142
+ ],
143
+ "text/plain": [
144
+ "<IPython.core.display.HTML object>"
145
+ ]
146
+ },
147
+ "metadata": {},
148
+ "output_type": "display_data"
149
+ },
150
+ {
151
+ "data": {
152
+ "text/html": [
153
+ "Syncing run <strong><a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2' target=\"_blank\">fanciful-jazz-1</a></strong> to <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
154
+ ],
155
+ "text/plain": [
156
+ "<IPython.core.display.HTML object>"
157
+ ]
158
+ },
159
+ "metadata": {},
160
+ "output_type": "display_data"
161
+ },
162
+ {
163
+ "data": {
164
+ "text/html": [
165
+ " View project at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation</a>"
166
+ ],
167
+ "text/plain": [
168
+ "<IPython.core.display.HTML object>"
169
+ ]
170
+ },
171
+ "metadata": {},
172
+ "output_type": "display_data"
173
+ },
174
+ {
175
+ "data": {
176
+ "text/html": [
177
+ " View run at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2</a>"
178
+ ],
179
+ "text/plain": [
180
+ "<IPython.core.display.HTML object>"
181
+ ]
182
+ },
183
+ "metadata": {},
184
+ "output_type": "display_data"
185
+ },
186
+ {
187
+ "name": "stderr",
188
+ "output_type": "stream",
189
+ "text": [
190
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-jsr2fn8v:v0, 977.89MB. 1 files... \n",
191
+ "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
192
+ "Done. 0:0:1.4\n"
193
+ ]
194
+ }
195
+ ],
196
+ "source": [
197
+ "import wandb\n",
198
+ "run = wandb.init()\n",
199
+ "artifact = run.use_artifact(r'veb-101/UM_medical_segmentation/model-jsr2fn8v:v0', type='model')\n",
200
+ "artifact_dir = artifact.download()"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 5,
206
+ "metadata": {},
207
+ "outputs": [
208
+ {
209
+ "name": "stdout",
210
+ "output_type": "stream",
211
+ "text": [
212
+ "dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin', 'hparams_name', 'hyper_parameters'])\n"
213
+ ]
214
+ }
215
+ ],
216
+ "source": [
217
+ "CKPT = torch.load(os.path.join(artifact_dir, \"model.ckpt\"), map_location=\"cpu\")\n",
218
+ "print(CKPT.keys())"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 6,
224
+ "metadata": {},
225
+ "outputs": [
226
+ {
227
+ "name": "stdout",
228
+ "output_type": "stream",
229
+ "text": [
230
+ "model.segformer.encoder.patch_embeddings.0.proj.weight\n",
231
+ "model.segformer.encoder.patch_embeddings.0.proj.bias\n",
232
+ "model.segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
233
+ "model.segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
234
+ "model.segformer.encoder.patch_embeddings.1.proj.weight\n",
235
+ "model.segformer.encoder.patch_embeddings.1.proj.bias\n"
236
+ ]
237
+ }
238
+ ],
239
+ "source": [
240
+ "TRAINED_CKPT_state_dict = CKPT[\"state_dict\"]\n",
241
+ "\n",
242
+ "for i, (key, val) in enumerate(TRAINED_CKPT_state_dict.items()):\n",
243
+ " print(key)\n",
244
+ " if i == 5:\n",
245
+ " break"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "markdown",
250
+ "metadata": {},
251
+ "source": [
252
+ "**The pytorch-lightning `state_dict()` has an extra `model.` string at the front that refers to the object/variable name that was holding the model in the `LightningModule` class.**\n",
253
+ "\n",
254
+ "We can simply iterate over the parameters and change the parameter key name. We'll create a new `OrderedDict` for it."
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 7,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "from collections import OrderedDict\n",
264
+ "\n",
265
+ "new_state_dict = OrderedDict()\n",
266
+ "\n",
267
+ "for key_name, value in CKPT[\"state_dict\"].items():\n",
268
+ " new_state_dict[key_name.replace(\"model.\", \"\")] = value"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 8,
274
+ "metadata": {},
275
+ "outputs": [
276
+ {
277
+ "data": {
278
+ "text/plain": [
279
+ "<All keys matched successfully>"
280
+ ]
281
+ },
282
+ "execution_count": 8,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ }
286
+ ],
287
+ "source": [
288
+ "model.load_state_dict(new_state_dict)"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 9,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "# torch.save(model.state_dict(), \"Segformer_best_state_dict.ckpt\")"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": 10,
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": [
306
+ "model.save_pretrained(\"segformer_trained_weights\")"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": 11,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "model = get_model(model_path=os.path.join(os.getcwd(), \"segformer_trained_weights\"), num_classes=Configs.NUM_CLASSES)"
316
+ ]
317
+ }
318
+ ],
319
+ "metadata": {
320
+ "kernelspec": {
321
+ "display_name": "pytorchx",
322
+ "language": "python",
323
+ "name": "python3"
324
+ },
325
+ "language_info": {
326
+ "codemirror_mode": {
327
+ "name": "ipython",
328
+ "version": 3
329
+ },
330
+ "file_extension": ".py",
331
+ "mimetype": "text/x-python",
332
+ "name": "python",
333
+ "nbconvert_exporter": "python",
334
+ "pygments_lexer": "ipython3",
335
+ "version": "3.10.12"
336
+ },
337
+ "orig_nbformat": 4
338
+ },
339
+ "nbformat": 4,
340
+ "nbformat_minor": 2
341
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ torch==2.0.0+cpu
3
+ torchvision==0.15.0
4
+ transformers==4.30.2
samples/case101_day26_slice_0096_266_266_1.50_1.50.png ADDED
samples/case107_day0_slice_0089_266_266_1.50_1.50.png ADDED
samples/case107_day21_slice_0069_266_266_1.50_1.50.png ADDED
samples/case113_day12_slice_0108_360_310_1.50_1.50.png ADDED
samples/case119_day20_slice_0063_266_266_1.50_1.50.png ADDED
samples/case119_day25_slice_0075_266_266_1.50_1.50.png ADDED
samples/case119_day25_slice_0095_266_266_1.50_1.50.png ADDED
samples/case121_day14_slice_0057_266_266_1.50_1.50.png ADDED
samples/case122_day25_slice_0087_266_266_1.50_1.50.png ADDED
samples/case124_day19_slice_0110_266_266_1.50_1.50.png ADDED
samples/case124_day20_slice_0110_266_266_1.50_1.50.png ADDED
samples/case130_day0_slice_0106_266_266_1.50_1.50.png ADDED
samples/case134_day21_slice_0085_360_310_1.50_1.50.png ADDED
samples/case139_day0_slice_0062_234_234_1.50_1.50.png ADDED
samples/case139_day18_slice_0094_266_266_1.50_1.50.png ADDED
samples/case146_day25_slice_0053_276_276_1.63_1.63.png ADDED
samples/case147_day0_slice_0085_360_310_1.50_1.50.png ADDED
samples/case148_day0_slice_0113_360_310_1.50_1.50.png ADDED
samples/case149_day15_slice_0057_266_266_1.50_1.50.png ADDED
samples/case29_day0_slice_0065_266_266_1.50_1.50.png ADDED
samples/case2_day1_slice_0054_266_266_1.50_1.50.png ADDED
samples/case2_day1_slice_0077_266_266_1.50_1.50.png ADDED
samples/case32_day19_slice_0091_266_266_1.50_1.50.png ADDED
samples/case32_day19_slice_0100_266_266_1.50_1.50.png ADDED
samples/case33_day21_slice_0114_266_266_1.50_1.50.png ADDED
samples/case36_day16_slice_0064_266_266_1.50_1.50.png ADDED
samples/case40_day0_slice_0094_266_266_1.50_1.50.png ADDED
samples/case41_day25_slice_0049_266_266_1.50_1.50.png ADDED
samples/case63_day22_slice_0076_266_266_1.50_1.50.png ADDED
samples/case63_day26_slice_0093_266_266_1.50_1.50.png ADDED
samples/case65_day28_slice_0133_266_266_1.50_1.50.png ADDED
samples/case66_day36_slice_0101_266_266_1.50_1.50.png ADDED
samples/case67_day0_slice_0049_266_266_1.50_1.50.png ADDED
samples/case67_day0_slice_0086_266_266_1.50_1.50.png ADDED
samples/case74_day18_slice_0101_266_266_1.50_1.50.png ADDED
samples/case74_day19_slice_0084_266_266_1.50_1.50.png ADDED
samples/case81_day28_slice_0066_266_266_1.50_1.50.png ADDED
samples/case85_day29_slice_0102_360_310_1.50_1.50.png ADDED
samples/case89_day19_slice_0082_360_310_1.50_1.50.png ADDED
samples/case89_day20_slice_0087_266_266_1.50_1.50.png ADDED
segformer_trained_weights/config.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "nvidia/segformer-b4-finetuned-ade-512-512",
3
+ "architectures": [
4
+ "SegformerForSemanticSegmentation"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "classifier_dropout_prob": 0.1,
8
+ "decoder_hidden_size": 768,
9
+ "depths": [
10
+ 3,
11
+ 8,
12
+ 27,
13
+ 3
14
+ ],
15
+ "downsampling_rates": [
16
+ 1,
17
+ 4,
18
+ 8,
19
+ 16
20
+ ],
21
+ "drop_path_rate": 0.1,
22
+ "hidden_act": "gelu",
23
+ "hidden_dropout_prob": 0.0,
24
+ "hidden_sizes": [
25
+ 64,
26
+ 128,
27
+ 320,
28
+ 512
29
+ ],
30
+ "id2label": {
31
+ "0": "LABEL_0",
32
+ "1": "LABEL_1",
33
+ "2": "LABEL_2",
34
+ "3": "LABEL_3"
35
+ },
36
+ "image_size": 224,
37
+ "initializer_range": 0.02,
38
+ "label2id": {
39
+ "LABEL_0": 0,
40
+ "LABEL_1": 1,
41
+ "LABEL_2": 2,
42
+ "LABEL_3": 3
43
+ },
44
+ "layer_norm_eps": 1e-06,
45
+ "mlp_ratios": [
46
+ 4,
47
+ 4,
48
+ 4,
49
+ 4
50
+ ],
51
+ "model_type": "segformer",
52
+ "num_attention_heads": [
53
+ 1,
54
+ 2,
55
+ 5,
56
+ 8
57
+ ],
58
+ "num_channels": 3,
59
+ "num_encoder_blocks": 4,
60
+ "patch_sizes": [
61
+ 7,
62
+ 3,
63
+ 3,
64
+ 3
65
+ ],
66
+ "reshape_last_stage": true,
67
+ "semantic_loss_ignore_index": 255,
68
+ "sr_ratios": [
69
+ 8,
70
+ 4,
71
+ 2,
72
+ 1
73
+ ],
74
+ "strides": [
75
+ 4,
76
+ 2,
77
+ 2,
78
+ 2
79
+ ],
80
+ "torch_dtype": "float32",
81
+ "transformers_version": "4.30.2"
82
+ }
segformer_trained_weights/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d4e3ff92a26341874655b112c2dc458cbb5ecf6d03f078dea8dd92ab0639e4a
3
+ size 256300245