File size: 4,569 Bytes
3fad000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4443af
 
 
 
 
3fad000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4443af
 
 
3fad000
 
 
 
 
 
 
 
 
 
 
a4443af
3fad000
 
 
 
 
 
 
a4443af
3fad000
 
 
 
 
a4443af
3fad000
a4443af
3fad000
 
 
 
 
 
 
a4443af
3fad000
5b6c473
3fad000
 
 
a4443af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os

# Uninstall any existing mmcv-full / mmsegmentation, then install the bundled
# mmcv-full 1.5.0 wheel (cp310, linux x86_64) and the extra requirements.
# Commands run sequentially through the shell; their exit statuses are not checked.
for _pip_cmd in (
    "pip uninstall -y mmcv-full",
    "pip uninstall -y mmsegmentation",
    "pip install ./mmcv_full-1.5.0-cp310-cp310-linux_x86_64.whl",
    "pip install -r requirements-extras.txt",
):
    os.system(_pip_cmd)
# os.system("cp /home/user/data/dinov2_vitg14_ade20k_m2f.pth /home/user/.cache/torch/hub/checkpoints/dinov2_vitg14_ade20k_m2f.pth")

import gradio as gr

import base64
import cv2
import math
import itertools
from functools import partial
from PIL import Image
import numpy as np
import pandas as pd

import dinov2.eval.segmentation.utils.colormaps as colormaps

import torch
import torch.nn.functional as F
from mmseg.apis import init_segmentor, inference_segmentor

import dinov2.eval.segmentation.models
import dinov2.eval.segmentation_m2f.models.segmentors

import urllib

import mmcv
from mmcv.runner import load_checkpoint

# Placeholders; `model` is rebound once the segmentor is built further down.
# `model_loaded` is not read anywhere in this file.
model, model_loaded = None, False

# Public DINOv2 artifacts: ViT-g/14 ADE20K Mask2Former ("m2f") config and checkpoint.
DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
CONFIG_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f_config.py"
CHECKPOINT_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth"


def load_config_from_url(url: str) -> str:
    """Fetch the resource at ``url`` and return its body decoded as UTF-8.

    Args:
        url: Any URL accepted by ``urllib.request.urlopen`` (e.g. ``https://``
            or ``file://``).

    Returns:
        The response body as a ``str``.
    """
    # `import urllib` (as done at file level) does not load the
    # `urllib.request` submodule; import it explicitly instead of relying on
    # some other library having imported it first.
    import urllib.request

    with urllib.request.urlopen(url) as response:
        return response.read().decode("utf-8")


# Download the Mask2Former config text and parse it as an mmcv Python config.
cfg_str = load_config_from_url(CONFIG_URL)
cfg = mmcv.Config.fromstring(cfg_str, file_format=".py")


# Colour palettes shipped with dinov2, keyed by dataset name; only ADE20K is used.
DATASET_COLORMAPS = {
    "ade20k": colormaps.ADE20K_COLORMAP,
    "voc2012": colormaps.VOC2012_COLORMAP,
}
colormap = DATASET_COLORMAPS["ade20k"]

# Flatten the RGB colour list and zero-pad it to 768 values (256 RGB
# triples) so it can be handed to PIL's `putpalette`.
flattened = np.array(colormap).flatten()
zeros = np.zeros(768)
zeros[: flattened.size] = flattened
colorMap = list(zeros.astype(np.uint8))

# Build the segmentor from the config, load the released weights on CPU,
# then move it to the GPU and switch to inference mode.
model = init_segmentor(cfg)
load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
model.cuda().eval()

class CenterPadding(torch.nn.Module):
    """Zero-pad every dimension after the first two up to a multiple of ``multiple``.

    The padding for each dimension is split as evenly as possible; when it is
    odd, the extra element goes on the trailing ("after") side.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(before, after)`` padding that rounds ``size`` up to a multiple."""
        target = self.multiple * math.ceil(size / self.multiple)
        total = target - size
        before = total // 2
        return before, total - before

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects (last_before, last_after, prev_before, prev_after, ...),
        # so visit the padded dimensions from the last one backwards.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)


def create_segmenter(cfg, backbone_model):
    """Build an mmseg segmentor whose backbone delegates to ``backbone_model``.

    The segmentor's backbone ``forward`` is replaced by
    ``backbone_model.get_intermediate_layers`` (with ``n`` taken from
    ``cfg.model.backbone.out_indices`` and ``reshape=True``). If the backbone
    exposes ``patch_size``, a forward pre-hook centre-pads inputs to a multiple
    of it. Weights are then initialised via ``init_weights``.
    """
    segmenter = init_segmentor(cfg)
    segmenter.backbone.forward = partial(
        backbone_model.get_intermediate_layers,
        n=cfg.model.backbone.out_indices,
        reshape=True,
    )

    if hasattr(backbone_model, "patch_size"):

        def _pad_to_patch_multiple(_module, inputs):
            # Pre-hook receives the positional inputs as a tuple; pad the first.
            return CenterPadding(backbone_model.patch_size)(inputs[0])

        segmenter.backbone.register_forward_pre_hook(_pad_to_patch_multiple)

    segmenter.init_weights()
    return segmenter


def _color_swatch_data_uri(color, size=20):
    """Return a ``data:`` URI for a ``size`` x ``size`` PNG filled with ``color``.

    ``color`` is an RGB triple. It is converted to BGR before encoding because
    OpenCV encodes BGR images. PNG is lossless, so the swatch shows the exact
    palette colour (a JPEG swatch can drift slightly).
    """
    swatch = np.zeros((size, size, 3), dtype=np.uint8)
    swatch[:, :] = color
    swatch = cv2.cvtColor(swatch, cv2.COLOR_RGB2BGR)
    ok, encoded = cv2.imencode(".png", swatch)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode a colour swatch")
    return "data:image/png;base64," + base64.b64encode(encoded).decode("utf-8")


def render_segmentation(segmentation_logits, dataset):
    """Colourise a label map and build an HTML legend of the labels present.

    Args:
        segmentation_logits: Integer label map (``uint8`` when called from
            ``predict``). It is not modified.
        dataset: Currently unused. The module-level ``colormap`` / ``colorMap``
            (ADE20K) and ``labelmap.txt`` are always used. Kept for interface
            compatibility.

    Returns:
        ``(segmented_image, html_output)``: a PIL image with ``colorMap``
        applied as its palette, and an HTML string containing one colour swatch
        plus the ``Name`` from ``labelmap.txt`` for each distinct label.
    """
    # Shift ids by one so that label ``k`` selects palette entry ``k`` and
    # labelmap row ``k - 1``. Done out of place so the caller's array is not mutated.
    labels = segmentation_logits + 1

    segmented_image = Image.fromarray(labels)
    segmented_image.putpalette(colorMap)

    unique_labels = np.unique(labels)
    colors = np.array(colormap, dtype=np.uint8)[unique_labels]
    # assumes labelmap.txt is tab-separated with a "Name" column, one row per class
    names = pd.read_csv("labelmap.txt", sep="\t")["Name"]

    legend_items = []
    for label, color in zip(unique_labels, colors):
        legend_items.append(
            '<div style="margin: 10px;">'
            f'<img src="{_color_swatch_data_uri(color)}" />'
            f"<p>{names.iloc[label - 1]}</p>"
            "</div>"
        )

    html_output = (
        '<div style="display: flex; flex-wrap: wrap;">'
        + "".join(legend_items)
        + "</div>"
    )
    return segmented_image, html_output


def predict(image_file):
    """Segment one uploaded image and return the rendered result.

    Args:
        image_file: Image from the Gradio input, or ``None`` when the user
            submits without uploading anything. Assumed to be an H x W x 3
            RGB array — TODO confirm against the input component settings.

    Returns:
        ``(segmented_image, html_output)`` from ``render_segmentation``, or
        ``(None, "")`` when no image was supplied.
    """
    if image_file is None:
        # Nothing to segment: clear both outputs instead of failing with an
        # IndexError on the channel slice below.
        return None, ""

    bgr = np.array(image_file)[:, :, ::-1]  # flip channel order RGB -> BGR (mmcv convention)
    segmentation_logits = inference_segmentor(model, bgr)[0].astype(np.uint8)
    return render_segmentation(segmentation_logits, "ade20k")

description = "Gradio demo for Semantic segmentation. To use it, simply upload your image"

# One image input; two outputs: the colourised mask as a PIL image and an
# HTML legend of the classes found.
image_input = gr.inputs.Image()
segmentation_outputs = [gr.outputs.Image(type="pil"), gr.outputs.HTML()]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=segmentation_outputs,
    title="Semantic Segmentation - DinoV2",
    description=description,
    examples=["example_1.jpg", "example_2.jpg"],
    cache_examples=False,
)

demo.launch()