update app
- README.md +4 -4
- app.py +126 -14
- requirements.txt +2 -0
README.md
CHANGED
@@ -2,19 +2,19 @@
 app_file: app.py
 colorFrom: red
 colorTo: green
-datasets:
+datasets:
 - lapix/CCAgT
 emoji: 💻
 license: mit
-models:
+models:
 - lapix/segformer-b3-finetuned-ccagt-400-300
 pinned: true
 sdk: gradio
 sdk_version: "3.3.1"
-tags:
+tags:
 - vision
 - image-segmentation
-task_ids:
+task_ids:
 - semantic-segmentation
 title: "SegFormer B3 CCAgT"
 ---
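The `datasets` and `models` entries in the front matter link the Space to the corresponding Hub repositories. The same identifiers resolve programmatically; a hypothetical sanity check (note the `datasets` library is not among the Space's own requirements):

    # Hypothetical: resolve the model and dataset IDs declared in the front matter.
    from datasets import load_dataset  # not in the Space's requirements.txt
    from transformers import SegformerForSemanticSegmentation

    model = SegformerForSemanticSegmentation.from_pretrained(
        'lapix/segformer-b3-finetuned-ccagt-400-300',
    )
    ccagt = load_dataset('lapix/CCAgT', 'semantic_segmentation')  # config name assumed
    print(model.config.num_labels, ccagt['test'].num_rows)
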
app.py
CHANGED
@@ -1,13 +1,21 @@
+import cv2
 import gradio as gr
+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
+from CCAgT_utils.categories import CategoriesInfos
+from CCAgT_utils.slice import __create_xy_slice
 from CCAgT_utils.types.mask import Mask
+from CCAgT_utils.visualization import plot
 from PIL import Image
 from torch import nn
 from transformers import SegformerFeatureExtractor
 from transformers import SegformerForSemanticSegmentation
+from transformers.modeling_outputs import SemanticSegmenterOutput


+matplotlib.use('Agg')
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
@@ -15,37 +23,130 @@ model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'
 model = SegformerForSemanticSegmentation.from_pretrained(
     model_hub_name,
 ).to(device)
+model.eval()
+
 feature_extractor = SegformerFeatureExtractor.from_pretrained(
     model_hub_name,
 )


-def …
-    image…
-…
-…
-    pixel_values = feature_extractor(
+def segment(
+    image: Image.Image,
+) -> SemanticSegmenterOutput:
+    inputs = feature_extractor(
         image,
         return_tensors='pt',
     ).to(device)

-…
-…
+    outputs = model(**inputs)
+
+    return outputs

-…
+
+def post_processing(
+    outputs: SemanticSegmenterOutput,
+    target_size: tuple[int, int],
+) -> np.ndarray:
+    logits = outputs.logits.cpu()

     upsampled_logits = nn.functional.interpolate(
         logits,
-        size=…
+        size=target_size,
         mode='bilinear',
         align_corners=False,
     )

     segmentation_mask = upsampled_logits.argmax(dim=1)[0]

-…
+    return np.array(segmentation_mask)
+
+
+def colorize(
+    mask: Mask,
+) -> np.ndarray:
+    return mask.colorized(CategoriesInfos()) / 255
+
+
+def check_and_resize(
+    image: np.ndarray,
+) -> np.ndarray:
+
+    if image.shape[0] > 1200 or image.shape[1] > 1600:
+        r = 1600.0 / image.shape[1]
+        dim = (1600, int(image.shape[0] * r))
+        return cv2.resize(image, dim, interpolation=cv2.INTER_AREA)
+
+    return image
+
+
+def process_big_images(
+    image: Image.Image,
+) -> Mask:
+    '''Process and post-processing for images bigger than 400x300'''
+    img = check_and_resize(np.asarray(image))
+    mask = np.zeros(shape=(img.shape[0], img.shape[1]), dtype=np.uint8)
+
+    for bbox in __create_xy_slice(image.size[1], image.size[0], 300, 400):
+        part = cv2.copyMakeBorder(
+            img,
+            bbox.y_init,
+            bbox.y_end,
+            bbox.x_init,
+            bbox.x_end,
+            cv2.BORDER_REFLECT,
+        )
+        target_size = (part.shape[0], part.shape[1])
+
+        outputs = segment(Image.fromarray(part))
+        msk = post_processing(outputs, target_size)
+
+        mask[bbox.slice_y, bbox.slice_x] = msk[bbox.slice_y, bbox.slice_x]

-    return …
+    return Mask(mask)
+
+
+def image_with_mask(
+    image: Image.Image,
+    mask: Mask,
+) -> plt.Figure:
+    fig = plt.figure(dpi=600)
+
+    plt.imshow(image)
+    plt.imshow(
+        mask.categorical,
+        cmap=mask.cmap(CategoriesInfos()),
+        vmax=max(mask.unique_ids),
+        vmin=min(mask.unique_ids),
+        interpolation='nearest',
+        alpha=0.4,
+    )
+    plt.axis('off')
+
+    return fig
+
+
+def categories_map(
+    mask: Mask,
+) -> plt.Figure:
+    fig = plt.figure(dpi=600)
+
+    handles = plot.create_handles(
+        CategoriesInfos(), selected_categories=mask.unique_ids,
+    )
+    plt.legend(handles=handles, fontsize=24, loc='center')
+    plt.axis('off')
+
+    return fig
+
+
+def main(image):
+    img = Image.fromarray(image)
+
+    mask = process_big_images(img)
+    mask_colorized = colorize(mask)
+    fig = image_with_mask(img, mask)
+
+    return categories_map(mask), mask_colorized, fig


 title = 'SegFormer (b3) - CCAgT dataset'
@@ -59,15 +160,26 @@ images with resolution of 400x300. The model was available at HF hub at
 examples = [
     [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
     [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
+] + [
+    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
+    for x in {3, 10, 12, 18, 35, 78, 89}
 ]

+
 demo = gr.Interface(
-…
+    main,
     inputs=[gr.Image()],
-    outputs=…
+    outputs=[
+        gr.Plot(label='Categories map'),
+        gr.Image(label='Mask'),
+        gr.Plot(label='Image with mask'),
+    ],
     title=title,
     description=description,
     examples=examples,
+    allow_flagging='never',
+    cache_examples=False,
 )

-…
+if __name__ == '__main__':
+    demo.launch()
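The heart of the update is `process_big_images`: large inputs are sliced into 400x300 windows, each window is segmented on its own, and the per-window argmax masks are stitched back into a single label map, so every forward pass stays near the resolution the model was fine-tuned on. A minimal sketch of the same idea, substituting a plain non-overlapping grid for CCAgT_utils' private `__create_xy_slice` helper and skipping the cv2.copyMakeBorder mirror padding of border tiles:

    import numpy as np

    def iter_tiles(height, width, tile_h=300, tile_w=400):
        # Hypothetical stand-in for __create_xy_slice: a plain grid, with
        # border tiles clipped instead of mirror-padded.
        for y in range(0, height, tile_h):
            for x in range(0, width, tile_w):
                yield (
                    slice(y, min(y + tile_h, height)),
                    slice(x, min(x + tile_w, width)),
                )

    def predict_tiled(image, predict):
        # `predict` maps an HxWx3 uint8 tile to an HxW integer label map, e.g.
        # lambda t: post_processing(segment(Image.fromarray(t)), t.shape[:2])
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for sy, sx in iter_tiles(*image.shape[:2]):
            mask[sy, sx] = predict(image[sy, sx])
        return mask

Clipped border tiles mean the model occasionally sees windows smaller than 400x300; the mirror padding in app.py avoids exactly that, at the cost of a little redundant computation.
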
requirements.txt
CHANGED
@@ -1,4 +1,6 @@
 CCAgT-utils
+matplotlib
 numpy
+opencv-python
 torch
 transformers
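matplotlib backs the two gr.Plot outputs (with the headless 'Agg' backend set at import) and opencv-python provides the cv2 resize and border-padding calls, so both become hard dependencies with this commit. Thanks to the new `if __name__ == '__main__':` guard, app.py can now be imported without starting the Gradio server, which allows a quick local smoke test; a sketch, assuming the requirements are installed and a sample image sits next to app.py (the first import downloads the model weights):

    # Hypothetical smoke test (not part of the commit).
    from PIL import Image

    from app import colorize, process_big_images

    image = Image.open('sampleA.png').convert('RGB')
    mask = process_big_images(image)  # CCAgT_utils Mask over a uint8 label map
    colored = colorize(mask)          # RGB array scaled to [0, 1]
    print(colored.shape)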