File size: 5,136 Bytes
2e1faab
cda9ccb
 
 
3b352d4
964b7b7
 
3b352d4
 
964b7b7
3b352d4
964b7b7
3b352d4
 
 
 
964b7b7
3b352d4
 
964b7b7
3b352d4
 
 
 
 
 
 
964b7b7
 
3b352d4
 
 
 
 
964b7b7
 
 
 
3b352d4
 
 
 
964b7b7
 
 
3b352d4
964b7b7
 
 
 
 
 
3b352d4
 
 
964b7b7
3b352d4
 
 
 
 
 
964b7b7
 
 
 
 
 
 
 
 
cda9ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964b7b7
cda9ccb
964b7b7
cda9ccb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964b7b7
 
 
 
 
 
cda9ccb
964b7b7
cda9ccb
 
964b7b7
cda9ccb
964b7b7
cda9ccb
 
3b352d4
581ca56
964b7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f312659
964b7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581ca56
964b7b7
581ca56
964b7b7
 
 
50c5448
3b352d4
 
 
 
 
 
 
 
cda9ccb
 
3b352d4
 
 
 
964b7b7
 
 
3b352d4
 
964b7b7
3b352d4
964b7b7
3b352d4
964b7b7
 
50c5448
964b7b7
 
 
3b352d4
 
 
964b7b7
 
3b352d4
 
964b7b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from __future__ import annotations

import random

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from CCAgT_utils.categories import CategoriesInfos
from CCAgT_utils.types.mask import Mask
from CCAgT_utils.visualization import plot
from PIL import Image
from torch import nn
from transformers import SegformerFeatureExtractor
from transformers import SegformerForSemanticSegmentation
from transformers.modeling_outputs import SemanticSegmenterOutput


# Select the non-interactive Agg backend: figures are only rendered into
# images for the web UI, never shown in a desktop window.
matplotlib.use('Agg')
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hugging Face Hub repository providing both the fine-tuned weights and the
# preprocessing configuration loaded below.
model_hub_name = 'lapix/segformer-b3-finetuned-ccagt-400-300'

# Loaded once at import time and reused by every call to `segment`.
model = SegformerForSemanticSegmentation.from_pretrained(
    model_hub_name,
).to(device)
# Evaluation mode (e.g. dropout disabled): the model is only used for
# inference in this demo.
model.eval()

# Preprocessor matching the checkpoint above; turns an image into model inputs.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    model_hub_name,
)


def segment(
    image: Image.Image,
) -> SemanticSegmenterOutput:
    """Run the SegFormer model on a single image.

    Args:
        image: Input image; it is converted into model inputs by the
            module-level ``feature_extractor``.

    Returns:
        The raw model output. Its ``logits`` are not yet resized to the
        input image and still live on ``device``.
    """
    inputs = feature_extractor(
        image,
        return_tensors='pt',
    ).to(device)

    # Inference only: disabling autograd avoids building a computation graph,
    # which would otherwise keep every intermediate activation alive for the
    # lifetime of the returned output.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs


def post_processing(
    outputs: SemanticSegmenterOutput,
    target_size: tuple[int, int],
) -> np.ndarray:
    """Convert raw model logits into a per-pixel class-id map.

    Args:
        outputs: Model output whose ``logits`` form a 4-D
            ``(batch, classes, height, width)`` tensor.
        target_size: ``(height, width)`` the class map is resized to.

    Returns:
        A 2-D integer array of class ids for the first image of the batch.
    """
    cpu_logits = outputs.logits.cpu()

    # Resize the class scores (not the labels) so the argmax is taken at the
    # final resolution.
    resized_scores = nn.functional.interpolate(
        cpu_logits,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )

    labels_per_image = resized_scores.argmax(dim=1)
    first_image_labels = labels_per_image[0]
    return np.array(first_image_labels)


def colorize(
    mask: Mask,
) -> np.ndarray:
    """Colour ``mask`` with the CCAgT category palette, scaled for display.

    ``Mask.colorized`` presumably yields 0-255 colour values (TODO confirm);
    dividing by 255 rescales them to floats in ``[0, 1]``.
    """
    palette = CategoriesInfos()
    rgb = mask.colorized(palette)
    return rgb / 255


# Copied from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L35
def get_random_crop_coords(
    height: int,
    width: int,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
):
    """Compute the corners of a crop window inside a ``height`` x ``width`` image.

    ``h_start``/``w_start`` place the window as a fraction of the free space
    around it. With values in ``[0, 1)`` and a crop no larger than the image,
    the window stays inside the image.

    Returns:
        ``(x1, y1, x2, y2)``: left, top, right (exclusive) and bottom
        (exclusive) pixel coordinates.
    """
    y1 = int((height - crop_height + 1) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width + 1) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2


# Adapted from https://github.com/albumentations-team/albumentations/blob/b1af92ab8e57279f5acd5987770a86a8d6b6b0e5/albumentations/augmentations/crops/functional.py#L46
def random_crop(
    img: np.ndarray,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
) -> np.ndarray:
    """Cut a window of at most ``crop_height`` x ``crop_width`` out of ``img``.

    Args:
        img: Array whose first two axes are height and width; any trailing
            axes (e.g. channels) are kept.
        crop_height: Requested window height, clamped to the image height.
        crop_width: Requested window width, clamped to the image width.
        h_start: Vertical placement of the window, a fraction in ``[0, 1)``
            of the free space around it.
        w_start: Horizontal placement, analogous to ``h_start``.

    Returns:
        A view of ``img`` covering the selected window.
    """
    height, width = img.shape[:2]

    # Without clamping, a crop larger than the image yields a negative start
    # offset. Python's negative slicing then returns a partial window taken
    # from the wrong end of the image.
    crop_height = min(crop_height, height)
    crop_width = min(crop_width, width)

    x1, y1, x2, y2 = get_random_crop_coords(
        height, width, crop_height, crop_width, h_start, w_start,
    )
    return img[y1:y2, x1:x2]


def process_big_images(
    image: Image.Image,
    crop_height: int = 300,
    crop_width: int = 400,
) -> tuple[np.ndarray, Mask]:
    '''Segment ``image``, first randomly cropping it if it is too large.

    The model was fine-tuned on 400x300 (width x height) images. Inputs
    larger than ``crop_width`` x ``crop_height`` in either dimension are
    reduced to a random window of at most that size before inference.

    Args:
        image: Input image.
        crop_height: Maximum height of the processed window (default 300).
        crop_width: Maximum width of the processed window (default 400).

    Returns:
        ``(img, mask)``: the (possibly cropped) image as a NumPy array and
        the predicted segmentation mask at the same height and width.
    '''
    img = np.asarray(image)
    height, width = img.shape[:2]

    if height > crop_height or width > crop_width:
        # Clamp each side to the image so that an image too large in only one
        # dimension (e.g. 500x200) is cropped along that dimension alone.
        img = random_crop(
            img,
            min(crop_height, height),
            min(crop_width, width),
            random.random(),
            random.random(),
        )

    target_size = (img.shape[0], img.shape[1])

    outputs = segment(Image.fromarray(img))
    msk = post_processing(outputs, target_size)

    return img, Mask(msk)


def image_with_mask(
    image: Image.Image | np.ndarray,
    mask: Mask,
) -> plt.Figure:
    """Overlay the predicted mask semi-transparently on the image.

    The figure is created directly as a ``Figure`` rather than through
    ``plt.figure``. Pyplot keeps a global reference to every figure it
    creates until it is closed, so a long-running demo would leak one
    600-dpi figure per request. Pyplot's implicit "current figure" is also
    shared between concurrent requests.

    Args:
        image: The image that was segmented (PIL image or NumPy array).
        mask: Predicted mask, expected to match ``image`` in height and width.

    Returns:
        A 600-dpi figure showing the image with a 40%-opaque class overlay
        and hidden axes.
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot(1, 1, 1)

    ax.imshow(image)
    ax.imshow(
        mask.categorical,
        cmap=mask.cmap(CategoriesInfos()),
        # Colour range spans exactly the class ids present in this mask.
        vmax=max(mask.unique_ids),
        vmin=min(mask.unique_ids),
        # Nearest-neighbour sampling keeps class boundaries crisp.
        interpolation='nearest',
        alpha=0.4,
    )
    ax.axis('off')
    fig.tight_layout(pad=0)
    return fig


def categories_map(
    mask: Mask,
) -> plt.Figure:
    """Draw a legend of the categories present in ``mask``.

    Like the overlay figure, this builds a ``Figure`` directly instead of
    using ``plt.figure``. Pyplot would otherwise keep every per-request
    figure alive until explicitly closed, and its global current-figure
    state is shared between concurrent requests.

    Args:
        mask: Predicted mask; only its ``unique_ids`` appear in the legend.

    Returns:
        A 600-dpi figure containing only a centred legend (axes hidden).
    """
    fig = plt.Figure(dpi=600)
    ax = fig.add_subplot(1, 1, 1)

    handles = plot.create_handles(
        CategoriesInfos(), selected_categories=mask.unique_ids,
    )
    ax.legend(handles=handles, fontsize=24, loc='center')
    ax.axis('off')

    return fig


def main(image):
    """Gradio callback: segment an uploaded image and build the four outputs.

    Args:
        image: The uploaded image as a NumPy array; may be ``None`` when the
            form is submitted without an image.

    Returns:
        ``(categories legend figure, processed image, colourised mask,
        image-with-mask figure)``, in the order of the interface outputs.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        # Fail with a clear message instead of an opaque error inside PIL.
        raise ValueError('Please provide an input image.')

    image = Image.fromarray(image)

    img, mask = process_big_images(image)
    mask_colorized = colorize(mask)
    fig = image_with_mask(img, mask)

    return categories_map(mask), Image.fromarray(img), mask_colorized, fig


title = 'SegFormer (b3) - CCAgT dataset'
description = f"""
This is a demo of SegFormer fine-tuned on a subset of the
[CCAgT dataset](https://huggingface.co/datasets/lapix/CCAgT). The model
was trained to segment silver-stained (AgNOR technique) images of cervical
cells at a resolution of 400x300. The model is available on the HF Hub at
[{model_hub_name}](https://huggingface.co/{model_hub_name}). If the input
image is larger than 400x300, the demo segments a random crop of it of at
most 400x300.
"""
# Two samples shipped with the model, followed by test-split images of the
# CCAgT dataset. The ids are a tuple (not a set) so the examples are shown in
# a fixed, readable order instead of hash order.
examples = [
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleA.png'],
    [f'https://hf.co/{model_hub_name}/resolve/main/sampleB.png'],
] + [
    [f'https://datasets-server.huggingface.co/assets/lapix/CCAgT/--/semantic_segmentation/test/{x}/image/image.jpg']
    for x in (3, 10, 12, 18, 35, 78, 89)
]


# Wire the prediction callback to one image input and four outputs; the
# output order must match the tuple returned by `main`.
demo = gr.Interface(
    fn=main,
    inputs=[gr.Image()],
    outputs=[
        gr.Plot(label='Categories map'),
        gr.Image(label='Image'),
        gr.Image(label='Mask'),
        gr.Plot(label='Image with mask'),
    ],
    examples=examples,
    cache_examples=False,
    allow_flagging='never',
    description=description,
    title=title,
)

if __name__ == '__main__':
    # Start the web server only when executed as a script, not on import.
    demo.launch()