First commit - All files added.
Browse files- .gitignore +2 -0
- README.md +21 -7
- app.py +89 -0
- load_lightning_SD_to_Usual_SD.ipynb +341 -0
- requirements.txt +4 -0
- samples/case101_day26_slice_0096_266_266_1.50_1.50.png +0 -0
- samples/case107_day0_slice_0089_266_266_1.50_1.50.png +0 -0
- samples/case107_day21_slice_0069_266_266_1.50_1.50.png +0 -0
- samples/case113_day12_slice_0108_360_310_1.50_1.50.png +0 -0
- samples/case119_day20_slice_0063_266_266_1.50_1.50.png +0 -0
- samples/case119_day25_slice_0075_266_266_1.50_1.50.png +0 -0
- samples/case119_day25_slice_0095_266_266_1.50_1.50.png +0 -0
- samples/case121_day14_slice_0057_266_266_1.50_1.50.png +0 -0
- samples/case122_day25_slice_0087_266_266_1.50_1.50.png +0 -0
- samples/case124_day19_slice_0110_266_266_1.50_1.50.png +0 -0
- samples/case124_day20_slice_0110_266_266_1.50_1.50.png +0 -0
- samples/case130_day0_slice_0106_266_266_1.50_1.50.png +0 -0
- samples/case134_day21_slice_0085_360_310_1.50_1.50.png +0 -0
- samples/case139_day0_slice_0062_234_234_1.50_1.50.png +0 -0
- samples/case139_day18_slice_0094_266_266_1.50_1.50.png +0 -0
- samples/case146_day25_slice_0053_276_276_1.63_1.63.png +0 -0
- samples/case147_day0_slice_0085_360_310_1.50_1.50.png +0 -0
- samples/case148_day0_slice_0113_360_310_1.50_1.50.png +0 -0
- samples/case149_day15_slice_0057_266_266_1.50_1.50.png +0 -0
- samples/case29_day0_slice_0065_266_266_1.50_1.50.png +0 -0
- samples/case2_day1_slice_0054_266_266_1.50_1.50.png +0 -0
- samples/case2_day1_slice_0077_266_266_1.50_1.50.png +0 -0
- samples/case32_day19_slice_0091_266_266_1.50_1.50.png +0 -0
- samples/case32_day19_slice_0100_266_266_1.50_1.50.png +0 -0
- samples/case33_day21_slice_0114_266_266_1.50_1.50.png +0 -0
- samples/case36_day16_slice_0064_266_266_1.50_1.50.png +0 -0
- samples/case40_day0_slice_0094_266_266_1.50_1.50.png +0 -0
- samples/case41_day25_slice_0049_266_266_1.50_1.50.png +0 -0
- samples/case63_day22_slice_0076_266_266_1.50_1.50.png +0 -0
- samples/case63_day26_slice_0093_266_266_1.50_1.50.png +0 -0
- samples/case65_day28_slice_0133_266_266_1.50_1.50.png +0 -0
- samples/case66_day36_slice_0101_266_266_1.50_1.50.png +0 -0
- samples/case67_day0_slice_0049_266_266_1.50_1.50.png +0 -0
- samples/case67_day0_slice_0086_266_266_1.50_1.50.png +0 -0
- samples/case74_day18_slice_0101_266_266_1.50_1.50.png +0 -0
- samples/case74_day19_slice_0084_266_266_1.50_1.50.png +0 -0
- samples/case81_day28_slice_0066_266_266_1.50_1.50.png +0 -0
- samples/case85_day29_slice_0102_360_310_1.50_1.50.png +0 -0
- samples/case89_day19_slice_0082_360_310_1.50_1.50.png +0 -0
- samples/case89_day20_slice_0087_266_266_1.50_1.50.png +0 -0
- segformer_trained_weights/config.json +82 -0
- segformer_trained_weights/pytorch_model.bin +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
artifacts
|
2 |
+
wandb
|
README.md
CHANGED
@@ -1,13 +1,27 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: gpl-2.0
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Medical Image Segmentation Gradio App
|
3 |
+
emoji: 🌖
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
+
|
13 |
+
# Medical Image Segmentation Gradio App
|
14 |
+
|
15 |
+
For the Gradio app we've removed the dependency on pytorch-lightning otherwise used in the project.
|
16 |
+
The `load_lightning_SD_to_Usual_SD.ipynb` notebook contains the steps used to convert pytorch-lightning checkpoint to a regular model checkpoint. This was mainly done to reduce the file size (977 MB --> 244 MB).
|
17 |
+
|
18 |
+
You can download the original saved checkpoint from over here: [wandb artifact](https://wandb.ai/veb-101/UM_medical_segmentation/artifacts/model/model-jsr2fn8v/v0/files)
|
19 |
+
|
20 |
+
Or via Python:
|
21 |
+
|
22 |
+
```python
|
23 |
+
import wandb
|
24 |
+
run = wandb.init()
|
25 |
+
artifact = run.use_artifact('veb-101/UM_medical_segmentation/model-jsr2fn8v:v0', type='model')
|
26 |
+
artifact_dir = artifact.download()
|
27 |
+
```
|
app.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
+
from glob import glob
|
5 |
+
from functools import partial
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision.transforms as TF
|
11 |
+
from transformers import SegformerForSemanticSegmentation
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class Configs:
|
16 |
+
NUM_CLASSES: int = 4 # including background.
|
17 |
+
CLASSES: tuple = ("Large bowel", "Small bowel", "Stomach")
|
18 |
+
IMAGE_SIZE: tuple[int, int] = (288, 288) # W, H
|
19 |
+
MEAN: tuple = (0.485, 0.456, 0.406)
|
20 |
+
STD: tuple = (0.229, 0.224, 0.225)
|
21 |
+
MODEL_PATH: str = os.path.join(os.getcwd(), "segformer_trained_weights")
|
22 |
+
|
23 |
+
|
24 |
+
def get_model(*, model_path, num_classes):
|
25 |
+
model = SegformerForSemanticSegmentation.from_pretrained(model_path, num_labels=num_classes, ignore_mismatched_sizes=True)
|
26 |
+
return model
|
27 |
+
|
28 |
+
|
29 |
+
@torch.inference_mode()
|
30 |
+
def predict(input_image, model=None, preprocess_fn=None, device="cpu"):
|
31 |
+
shape_H_W = input_image.size
|
32 |
+
input_tensor = preprocess_fn(input_image)
|
33 |
+
input_tensor = input_tensor.unsqueeze(0).to(device)
|
34 |
+
|
35 |
+
# Generate predictions
|
36 |
+
outputs = model(pixel_values=input_tensor.to(device), return_dict=True)
|
37 |
+
predictions = F.interpolate(outputs["logits"], size=shape_H_W, mode="bilinear", align_corners=False)
|
38 |
+
|
39 |
+
preds_argmax = predictions.argmax(dim=1).cpu().squeeze().numpy()
|
40 |
+
|
41 |
+
seg_info = [(preds_argmax == idx, class_name) for idx, class_name in enumerate(Configs.CLASSES, 1)]
|
42 |
+
|
43 |
+
return (input_image, seg_info)
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
# Create a mapping of class ID to RGB value.
|
48 |
+
id2color = {
|
49 |
+
0: (0, 0, 0), # background pixel
|
50 |
+
1: (0, 0, 255), # Stomach
|
51 |
+
2: (0, 255, 0), # Small bowel
|
52 |
+
3: (255, 0, 0), # large bowel
|
53 |
+
}
|
54 |
+
|
55 |
+
class2hexcolor = {"Stomach": "#007fff", "Small bowel": "#009A17", "Large bowel": "#FF0000"}
|
56 |
+
|
57 |
+
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
58 |
+
CKPT_PATH = os.path.join(os.getcwd(), "Segformer_best_state_dict.ckpt")
|
59 |
+
|
60 |
+
model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)
|
61 |
+
model.to(DEVICE)
|
62 |
+
model.eval()
|
63 |
+
_ = model(torch.randn(1, 3, *Configs.IMAGE_SIZE[::-1], device=DEVICE))
|
64 |
+
|
65 |
+
preprocess = TF.Compose(
|
66 |
+
[
|
67 |
+
TF.Resize(size=Configs.IMAGE_SIZE[::-1]),
|
68 |
+
TF.ToTensor(),
|
69 |
+
TF.Normalize(Configs.MEAN, Configs.STD, inplace=True),
|
70 |
+
]
|
71 |
+
)
|
72 |
+
|
73 |
+
with gr.Blocks(title="Medical Image Segmentation") as demo:
|
74 |
+
gr.Markdown("""<h1><center>Medical Image Segmentation with UW-Madison GI Tract Dataset</center></h1>""")
|
75 |
+
|
76 |
+
with gr.Row():
|
77 |
+
img_input = gr.Image(type="pil", height=300, width=300, label="Input image")
|
78 |
+
img_output = gr.AnnotatedImage(label="Predictions", height=300, width=300, color_map=class2hexcolor)
|
79 |
+
|
80 |
+
section_btn = gr.Button("Generate Predictions")
|
81 |
+
|
82 |
+
section_btn.click(partial(predict, model=model, preprocess_fn=preprocess, device=DEVICE), img_input, img_output)
|
83 |
+
|
84 |
+
images_dir = glob(os.path.join(os.getcwd(), "samples") + os.sep + "*.png")
|
85 |
+
examples = [i for i in np.random.choice(images_dir, size=8, replace=False)]
|
86 |
+
|
87 |
+
gr.Examples(examples=examples, inputs=img_input, outputs=img_output)
|
88 |
+
|
89 |
+
demo.launch()
|
load_lightning_SD_to_Usual_SD.ipynb
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"## Base Configurations"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 1,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [],
|
15 |
+
"source": [
|
16 |
+
"import os\n",
|
17 |
+
"import torch\n",
|
18 |
+
"from transformers import SegformerForSemanticSegmentation\n",
|
19 |
+
"from dataclasses import dataclass\n",
|
20 |
+
"\n",
|
21 |
+
"\n",
|
22 |
+
"@dataclass\n",
|
23 |
+
"class Configs:\n",
|
24 |
+
" NUM_CLASSES = 4\n",
|
25 |
+
" MODEL_PATH: str = \"nvidia/segformer-b4-finetuned-ade-512-512\""
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "markdown",
|
30 |
+
"metadata": {},
|
31 |
+
"source": [
|
32 |
+
"## Load Model To Inspect Parameter Names"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": 2,
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"\n",
|
42 |
+
"\n",
|
43 |
+
"def get_model(*, model_path, num_classes):\n",
|
44 |
+
" model = SegformerForSemanticSegmentation.from_pretrained(\n",
|
45 |
+
" model_path,\n",
|
46 |
+
" num_labels=num_classes,\n",
|
47 |
+
" ignore_mismatched_sizes=True,\n",
|
48 |
+
" )\n",
|
49 |
+
" return model"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 3,
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [
|
57 |
+
{
|
58 |
+
"name": "stderr",
|
59 |
+
"output_type": "stream",
|
60 |
+
"text": [
|
61 |
+
"Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b4-finetuned-ade-512-512 and are newly initialized because the shapes did not match:\n",
|
62 |
+
"- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([4, 768, 1, 1]) in the model instantiated\n",
|
63 |
+
"- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([4]) in the model instantiated\n",
|
64 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"name": "stdout",
|
69 |
+
"output_type": "stream",
|
70 |
+
"text": [
|
71 |
+
"\n",
|
72 |
+
"segformer.encoder.patch_embeddings.0.proj.weight\n",
|
73 |
+
"segformer.encoder.patch_embeddings.0.proj.bias\n",
|
74 |
+
"segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
|
75 |
+
"segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
|
76 |
+
"segformer.encoder.patch_embeddings.1.proj.weight\n",
|
77 |
+
"segformer.encoder.patch_embeddings.1.proj.bias\n"
|
78 |
+
]
|
79 |
+
}
|
80 |
+
],
|
81 |
+
"source": [
|
82 |
+
"model = get_model(model_path=Configs.MODEL_PATH, num_classes=Configs.NUM_CLASSES)\n",
|
83 |
+
"model_state_dict = model.state_dict()\n",
|
84 |
+
"\n",
|
85 |
+
"print()\n",
|
86 |
+
"for i, (key, val) in enumerate(model_state_dict.items()):\n",
|
87 |
+
" print(key)\n",
|
88 |
+
" if i == 5:\n",
|
89 |
+
" break"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "markdown",
|
94 |
+
"metadata": {},
|
95 |
+
"source": [
|
96 |
+
"## Download & load PyTorch-Lightning Checkpoint and Inspect Parameter Names"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 4,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [
|
104 |
+
{
|
105 |
+
"name": "stderr",
|
106 |
+
"output_type": "stream",
|
107 |
+
"text": [
|
108 |
+
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
|
109 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mveb-101\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"data": {
|
114 |
+
"application/vnd.jupyter.widget-view+json": {
|
115 |
+
"model_id": "2e6699f8bae4469fb42d361bf569b161",
|
116 |
+
"version_major": 2,
|
117 |
+
"version_minor": 0
|
118 |
+
},
|
119 |
+
"text/plain": [
|
120 |
+
"VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016933333332902596, max=1.0…"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
"metadata": {},
|
124 |
+
"output_type": "display_data"
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"data": {
|
128 |
+
"text/html": [
|
129 |
+
"Tracking run with wandb version 0.15.5"
|
130 |
+
],
|
131 |
+
"text/plain": [
|
132 |
+
"<IPython.core.display.HTML object>"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
"metadata": {},
|
136 |
+
"output_type": "display_data"
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"data": {
|
140 |
+
"text/html": [
|
141 |
+
"Run data is saved locally in <code>c:\\Users\\vaibh\\OneDrive\\Desktop\\Work\\BigVision\\BLOG_POSTS\\Medical_segmentation\\GRADIO_APP\\UWMGI_Medical_Image_Segmentation\\wandb\\run-20230719_044820-hnv9dwr2</code>"
|
142 |
+
],
|
143 |
+
"text/plain": [
|
144 |
+
"<IPython.core.display.HTML object>"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
"metadata": {},
|
148 |
+
"output_type": "display_data"
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"data": {
|
152 |
+
"text/html": [
|
153 |
+
"Syncing run <strong><a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2' target=\"_blank\">fanciful-jazz-1</a></strong> to <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
154 |
+
],
|
155 |
+
"text/plain": [
|
156 |
+
"<IPython.core.display.HTML object>"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
"metadata": {},
|
160 |
+
"output_type": "display_data"
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"data": {
|
164 |
+
"text/html": [
|
165 |
+
" View project at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation</a>"
|
166 |
+
],
|
167 |
+
"text/plain": [
|
168 |
+
"<IPython.core.display.HTML object>"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
"metadata": {},
|
172 |
+
"output_type": "display_data"
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"data": {
|
176 |
+
"text/html": [
|
177 |
+
" View run at <a href='https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2' target=\"_blank\">https://wandb.ai/veb-101/UWMGI_Medical_Image_Segmentation/runs/hnv9dwr2</a>"
|
178 |
+
],
|
179 |
+
"text/plain": [
|
180 |
+
"<IPython.core.display.HTML object>"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
"metadata": {},
|
184 |
+
"output_type": "display_data"
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"name": "stderr",
|
188 |
+
"output_type": "stream",
|
189 |
+
"text": [
|
190 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact model-jsr2fn8v:v0, 977.89MB. 1 files... \n",
|
191 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
|
192 |
+
"Done. 0:0:1.4\n"
|
193 |
+
]
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"import wandb\n",
|
198 |
+
"run = wandb.init()\n",
|
199 |
+
"artifact = run.use_artifact(r'veb-101/UM_medical_segmentation/model-jsr2fn8v:v0', type='model')\n",
|
200 |
+
"artifact_dir = artifact.download()"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 5,
|
206 |
+
"metadata": {},
|
207 |
+
"outputs": [
|
208 |
+
{
|
209 |
+
"name": "stdout",
|
210 |
+
"output_type": "stream",
|
211 |
+
"text": [
|
212 |
+
"dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecisionPlugin', 'hparams_name', 'hyper_parameters'])\n"
|
213 |
+
]
|
214 |
+
}
|
215 |
+
],
|
216 |
+
"source": [
|
217 |
+
"CKPT = torch.load(os.path.join(artifact_dir, \"model.ckpt\"), map_location=\"cpu\")\n",
|
218 |
+
"print(CKPT.keys())"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 6,
|
224 |
+
"metadata": {},
|
225 |
+
"outputs": [
|
226 |
+
{
|
227 |
+
"name": "stdout",
|
228 |
+
"output_type": "stream",
|
229 |
+
"text": [
|
230 |
+
"model.segformer.encoder.patch_embeddings.0.proj.weight\n",
|
231 |
+
"model.segformer.encoder.patch_embeddings.0.proj.bias\n",
|
232 |
+
"model.segformer.encoder.patch_embeddings.0.layer_norm.weight\n",
|
233 |
+
"model.segformer.encoder.patch_embeddings.0.layer_norm.bias\n",
|
234 |
+
"model.segformer.encoder.patch_embeddings.1.proj.weight\n",
|
235 |
+
"model.segformer.encoder.patch_embeddings.1.proj.bias\n"
|
236 |
+
]
|
237 |
+
}
|
238 |
+
],
|
239 |
+
"source": [
|
240 |
+
"TRAINED_CKPT_state_dict = CKPT[\"state_dict\"]\n",
|
241 |
+
"\n",
|
242 |
+
"for i, (key, val) in enumerate(TRAINED_CKPT_state_dict.items()):\n",
|
243 |
+
" print(key)\n",
|
244 |
+
" if i == 5:\n",
|
245 |
+
" break"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"cell_type": "markdown",
|
250 |
+
"metadata": {},
|
251 |
+
"source": [
|
252 |
+
"**The pytorch-lightning `state_dict()` has an extra `model.` string at the front that refers to the object/variable name that was holding the model in the `LightningModule` class.**\n",
|
253 |
+
"\n",
|
254 |
+
"We can simply iterate over the parameters and change the parameter key name. We'll create a new `OrderedDict` for it."
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": 7,
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"from collections import OrderedDict\n",
|
264 |
+
"\n",
|
265 |
+
"new_state_dict = OrderedDict()\n",
|
266 |
+
"\n",
|
267 |
+
"for key_name, value in CKPT[\"state_dict\"].items():\n",
|
268 |
+
" new_state_dict[key_name.replace(\"model.\", \"\")] = value"
|
269 |
+
]
|
270 |
+
},
|
271 |
+
{
|
272 |
+
"cell_type": "code",
|
273 |
+
"execution_count": 8,
|
274 |
+
"metadata": {},
|
275 |
+
"outputs": [
|
276 |
+
{
|
277 |
+
"data": {
|
278 |
+
"text/plain": [
|
279 |
+
"<All keys matched successfully>"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 8,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"model.load_state_dict(new_state_dict)"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 9,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"# torch.save(model.state_dict(), \"Segformer_best_state_dict.ckpt\")"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"cell_type": "code",
|
302 |
+
"execution_count": 10,
|
303 |
+
"metadata": {},
|
304 |
+
"outputs": [],
|
305 |
+
"source": [
|
306 |
+
"model.save_pretrained(\"segformer_trained_weights\")"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": 11,
|
312 |
+
"metadata": {},
|
313 |
+
"outputs": [],
|
314 |
+
"source": [
|
315 |
+
"model = get_model(model_path=os.path.join(os.getcwd(), \"segformer_trained_weights\"), num_classes=Configs.NUM_CLASSES)"
|
316 |
+
]
|
317 |
+
}
|
318 |
+
],
|
319 |
+
"metadata": {
|
320 |
+
"kernelspec": {
|
321 |
+
"display_name": "pytorchx",
|
322 |
+
"language": "python",
|
323 |
+
"name": "python3"
|
324 |
+
},
|
325 |
+
"language_info": {
|
326 |
+
"codemirror_mode": {
|
327 |
+
"name": "ipython",
|
328 |
+
"version": 3
|
329 |
+
},
|
330 |
+
"file_extension": ".py",
|
331 |
+
"mimetype": "text/x-python",
|
332 |
+
"name": "python",
|
333 |
+
"nbconvert_exporter": "python",
|
334 |
+
"pygments_lexer": "ipython3",
|
335 |
+
"version": "3.10.12"
|
336 |
+
},
|
337 |
+
"orig_nbformat": 4
|
338 |
+
},
|
339 |
+
"nbformat": 4,
|
340 |
+
"nbformat_minor": 2
|
341 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--find-links https://download.pytorch.org/whl/torch_stable.html
|
2 |
+
torch==2.0.0+cpu
|
3 |
+
torchvision==0.15.0
|
4 |
+
transformers==4.30.2
|
samples/case101_day26_slice_0096_266_266_1.50_1.50.png
ADDED
![]() |
samples/case107_day0_slice_0089_266_266_1.50_1.50.png
ADDED
![]() |
samples/case107_day21_slice_0069_266_266_1.50_1.50.png
ADDED
![]() |
samples/case113_day12_slice_0108_360_310_1.50_1.50.png
ADDED
![]() |
samples/case119_day20_slice_0063_266_266_1.50_1.50.png
ADDED
![]() |
samples/case119_day25_slice_0075_266_266_1.50_1.50.png
ADDED
![]() |
samples/case119_day25_slice_0095_266_266_1.50_1.50.png
ADDED
![]() |
samples/case121_day14_slice_0057_266_266_1.50_1.50.png
ADDED
![]() |
samples/case122_day25_slice_0087_266_266_1.50_1.50.png
ADDED
![]() |
samples/case124_day19_slice_0110_266_266_1.50_1.50.png
ADDED
![]() |
samples/case124_day20_slice_0110_266_266_1.50_1.50.png
ADDED
![]() |
samples/case130_day0_slice_0106_266_266_1.50_1.50.png
ADDED
![]() |
samples/case134_day21_slice_0085_360_310_1.50_1.50.png
ADDED
![]() |
samples/case139_day0_slice_0062_234_234_1.50_1.50.png
ADDED
![]() |
samples/case139_day18_slice_0094_266_266_1.50_1.50.png
ADDED
![]() |
samples/case146_day25_slice_0053_276_276_1.63_1.63.png
ADDED
![]() |
samples/case147_day0_slice_0085_360_310_1.50_1.50.png
ADDED
![]() |
samples/case148_day0_slice_0113_360_310_1.50_1.50.png
ADDED
![]() |
samples/case149_day15_slice_0057_266_266_1.50_1.50.png
ADDED
![]() |
samples/case29_day0_slice_0065_266_266_1.50_1.50.png
ADDED
![]() |
samples/case2_day1_slice_0054_266_266_1.50_1.50.png
ADDED
![]() |
samples/case2_day1_slice_0077_266_266_1.50_1.50.png
ADDED
![]() |
samples/case32_day19_slice_0091_266_266_1.50_1.50.png
ADDED
![]() |
samples/case32_day19_slice_0100_266_266_1.50_1.50.png
ADDED
![]() |
samples/case33_day21_slice_0114_266_266_1.50_1.50.png
ADDED
![]() |
samples/case36_day16_slice_0064_266_266_1.50_1.50.png
ADDED
![]() |
samples/case40_day0_slice_0094_266_266_1.50_1.50.png
ADDED
![]() |
samples/case41_day25_slice_0049_266_266_1.50_1.50.png
ADDED
![]() |
samples/case63_day22_slice_0076_266_266_1.50_1.50.png
ADDED
![]() |
samples/case63_day26_slice_0093_266_266_1.50_1.50.png
ADDED
![]() |
samples/case65_day28_slice_0133_266_266_1.50_1.50.png
ADDED
![]() |
samples/case66_day36_slice_0101_266_266_1.50_1.50.png
ADDED
![]() |
samples/case67_day0_slice_0049_266_266_1.50_1.50.png
ADDED
![]() |
samples/case67_day0_slice_0086_266_266_1.50_1.50.png
ADDED
![]() |
samples/case74_day18_slice_0101_266_266_1.50_1.50.png
ADDED
![]() |
samples/case74_day19_slice_0084_266_266_1.50_1.50.png
ADDED
![]() |
samples/case81_day28_slice_0066_266_266_1.50_1.50.png
ADDED
![]() |
samples/case85_day29_slice_0102_360_310_1.50_1.50.png
ADDED
![]() |
samples/case89_day19_slice_0082_360_310_1.50_1.50.png
ADDED
![]() |
samples/case89_day20_slice_0087_266_266_1.50_1.50.png
ADDED
![]() |
segformer_trained_weights/config.json
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "nvidia/segformer-b4-finetuned-ade-512-512",
|
3 |
+
"architectures": [
|
4 |
+
"SegformerForSemanticSegmentation"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.0,
|
7 |
+
"classifier_dropout_prob": 0.1,
|
8 |
+
"decoder_hidden_size": 768,
|
9 |
+
"depths": [
|
10 |
+
3,
|
11 |
+
8,
|
12 |
+
27,
|
13 |
+
3
|
14 |
+
],
|
15 |
+
"downsampling_rates": [
|
16 |
+
1,
|
17 |
+
4,
|
18 |
+
8,
|
19 |
+
16
|
20 |
+
],
|
21 |
+
"drop_path_rate": 0.1,
|
22 |
+
"hidden_act": "gelu",
|
23 |
+
"hidden_dropout_prob": 0.0,
|
24 |
+
"hidden_sizes": [
|
25 |
+
64,
|
26 |
+
128,
|
27 |
+
320,
|
28 |
+
512
|
29 |
+
],
|
30 |
+
"id2label": {
|
31 |
+
"0": "LABEL_0",
|
32 |
+
"1": "LABEL_1",
|
33 |
+
"2": "LABEL_2",
|
34 |
+
"3": "LABEL_3"
|
35 |
+
},
|
36 |
+
"image_size": 224,
|
37 |
+
"initializer_range": 0.02,
|
38 |
+
"label2id": {
|
39 |
+
"LABEL_0": 0,
|
40 |
+
"LABEL_1": 1,
|
41 |
+
"LABEL_2": 2,
|
42 |
+
"LABEL_3": 3
|
43 |
+
},
|
44 |
+
"layer_norm_eps": 1e-06,
|
45 |
+
"mlp_ratios": [
|
46 |
+
4,
|
47 |
+
4,
|
48 |
+
4,
|
49 |
+
4
|
50 |
+
],
|
51 |
+
"model_type": "segformer",
|
52 |
+
"num_attention_heads": [
|
53 |
+
1,
|
54 |
+
2,
|
55 |
+
5,
|
56 |
+
8
|
57 |
+
],
|
58 |
+
"num_channels": 3,
|
59 |
+
"num_encoder_blocks": 4,
|
60 |
+
"patch_sizes": [
|
61 |
+
7,
|
62 |
+
3,
|
63 |
+
3,
|
64 |
+
3
|
65 |
+
],
|
66 |
+
"reshape_last_stage": true,
|
67 |
+
"semantic_loss_ignore_index": 255,
|
68 |
+
"sr_ratios": [
|
69 |
+
8,
|
70 |
+
4,
|
71 |
+
2,
|
72 |
+
1
|
73 |
+
],
|
74 |
+
"strides": [
|
75 |
+
4,
|
76 |
+
2,
|
77 |
+
2,
|
78 |
+
2
|
79 |
+
],
|
80 |
+
"torch_dtype": "float32",
|
81 |
+
"transformers_version": "4.30.2"
|
82 |
+
}
|
segformer_trained_weights/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9d4e3ff92a26341874655b112c2dc458cbb5ecf6d03f078dea8dd92ab0639e4a
|
3 |
+
size 256300245
|