File size: 5,531 Bytes
6f49966
 
b5baf02
 
 
 
 
469f43d
2b6c2bd
469f43d
2b6c2bd
 
c331e65
469f43d
2b6c2bd
b5baf02
 
 
 
 
 
469f43d
 
 
 
 
6abb9e2
c331e65
 
469f43d
b5baf02
469f43d
 
 
 
 
b5baf02
c331e65
b5baf02
 
 
 
 
c331e65
 
b5baf02
 
 
 
 
 
 
 
 
 
c331e65
 
b5baf02
 
 
 
c331e65
b5baf02
 
c331e65
 
 
 
 
 
 
 
469f43d
 
c331e65
 
 
b5baf02
469f43d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f49966
 
b5baf02
469f43d
 
c331e65
 
 
 
469f43d
 
c331e65
 
469f43d
2b6c2bd
c331e65
 
 
b5baf02
c331e65
 
 
 
b5baf02
c331e65
 
 
 
 
 
 
b5baf02
469f43d
 
 
 
 
 
 
 
 
6f49966
469f43d
 
 
 
 
 
 
b5baf02
c331e65
 
 
 
 
 
 
 
b5baf02
6f49966
 
 
469f43d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import torch
import tempfile
import os
from omegaconf import OmegaConf
from sam_segment import predict_masks_with_sam
from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask, show_points
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

def mkstemp(suffix, dir=None):
    """Create an empty temporary file and return its path.

    Unlike ``tempfile.mkstemp`` this does not hand back an open OS-level
    file descriptor: the descriptor is closed immediately so callers only
    deal with the path.  The file itself is left on disk; deleting it is the
    caller's responsibility.

    Args:
        suffix: text appended to the generated file name (e.g. ``".png"``).
        dir: directory to create the file in; ``None`` uses the system
            default temporary directory.

    Returns:
        ``pathlib.Path`` of the newly created (empty) file.
    """
    descriptor, location = tempfile.mkstemp(suffix=f"{suffix}", dir=dir)
    # Only the path is needed; release the handle so it does not leak.
    os.close(descriptor)
    return Path(location)


def get_sam_feat(img):
    """Run SAM's image encoder on *img* and keep the result in the predictor.

    ``model['sam']`` is a ``SamPredictor``.  Its ``set_image`` stores the
    image embedding inside the predictor, so later ``predict`` calls (made by
    ``get_masked_img``) operate on this image.  Nothing is returned: the
    Gradio button that triggers this callback has no outputs.

    Args:
        img: image array from the ``gr.Image`` input (presumably RGB uint8
            HxWx3, SamPredictor's default image format -- TODO confirm).
    """
    sam_predictor = model['sam']
    sam_predictor.set_image(img)
    return None

 
def get_masked_img(img, w, h, dilate_kernel_size=15):
    """Predict SAM masks for one clicked point and render an overlay for each.

    Relies on the embedding stored in ``model['sam']`` (a ``SamPredictor``)
    by ``get_sam_feat``.  SamPredictor requires ``set_image`` to have been
    called before ``predict``.

    Args:
        img: the image shown in the UI (numpy array from ``gr.Image``).
        w: x (column) coordinate of the clicked foreground point.
        h: y (row) coordinate of the clicked foreground point.
        dilate_kernel_size: kernel size passed to ``dilate_mask`` to grow each
            mask and avoid unmasked edge artifacts.  ``None`` disables
            dilation.  Defaults to 15, the previously hard-coded value.

    Returns:
        A flat tuple: one matplotlib figure per predicted mask, followed by
        the masks themselves (uint8, 0/255, dilated unless disabled).
    """
    point_coords = [w, h]
    point_labels = [1]  # label 1 = foreground point
    masks, _, _ = model['sam'].predict(
        point_coords=np.array([point_coords]),
        point_labels=np.array(point_labels),
        multimask_output=True,
    )

    # Boolean masks -> 0/255 uint8 images.
    masks = masks.astype(np.uint8) * 255

    # Dilate masks to avoid an unmasked edge effect around the object.
    if dilate_kernel_size is not None:
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]
    else:
        masks = list(masks)

    # Loop-invariant figure geometry.  The figure is sized from the image's
    # pixel dimensions at the current DPI, enlarged by 1/0.77 (an empirical
    # factor kept from the original code).
    dpi = plt.rcParams['figure.dpi']
    height, width = img.shape[:2]
    fig_size = (width / dpi / 0.77, height / dpi / 0.77)

    figs = []
    for mask in masks:
        # Build the "pointed and masked" preview.  It is returned to gr.Plot
        # directly.  The old code also saved it to a temp PNG that was never
        # used or deleted, leaking a file per mask per click.
        fig = plt.figure(figsize=fig_size)
        plt.imshow(img)
        plt.axis('off')
        show_points(plt.gca(), [point_coords], point_labels,
                    size=(width * 0.04) ** 2)
        show_mask(plt.gca(), mask, random_color=False)
        figs.append(fig)
        # Remove the figure from pyplot's registry so figures don't pile up.
        # The Figure object stays usable by gr.Plot.
        plt.close(fig)
    return (*figs, *masks)


def get_inpainted_img(img, mask0, mask1, mask2):
    """Remove the masked object from *img* with LaMa, once per candidate mask.

    Args:
        img: original image array from the UI.
        mask0, mask1, mask2: the three SAM masks shown in the UI.  After the
            round trip through ``gr.Image`` they may carry a channel axis;
            only channel 0 is used.  A ``None`` mask (e.g. masks were never
            predicted) produces ``None`` for that slot instead of an error.

    Returns:
        A list of three inpainted images (or ``None``), in mask order.
    """
    lama_config = "third_party/lama/configs/prediction/default.yaml"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The LaMa model is built once at import time and stored in
    # ``model['lama']``.  There is no module-level ``model_lama`` (that
    # assignment is commented out), so referencing it raised NameError.
    lama_model = model['lama']

    out = []
    for mask in (mask0, mask1, mask2):
        if mask is None:
            out.append(None)
            continue
        if mask.ndim == 3:
            # Keep a single channel.  Presumably the mask was displayed as a
            # gray image expanded to identical channels -- TODO confirm.
            mask = mask[:, :, 0]
        out.append(inpaint_img_with_builded_lama(
            lama_model, img, mask, lama_config, device=device))
    return out


## build models
# Pick the target device once; both models are placed on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# SAM with a ViT-H backbone.  It is moved to the device before being wrapped
# in a SamPredictor.
model_type = "vit_h"
ckpt_p = "pretrained_models/sam_vit_h_4b8939.pth"
model_sam = sam_model_registry[model_type](checkpoint=ckpt_p)
model_sam.to(device=device)

# LaMa inpainting model configuration.
lama_config = "third_party/lama/configs/prediction/default.yaml"
lama_ckpt = "pretrained_models/big-lama"

# Models used by the Gradio callbacks, keyed by name.  SAM's predictor is
# constructed before the LaMa model is built, as before.
model = {
    'sam': SamPredictor(model_sam),
    'lama': build_lama_model(lama_config, lama_ckpt, device=device),
}


with gr.Blocks() as demo:
    # Top row: input image, point preview, coordinates and action buttons.
    with gr.Row():
        img = gr.Image(label="Image")
        img_pointed = gr.Plot(label='Pointed Image')
        with gr.Column():
            with gr.Row():
                w = gr.Number(label="Point Coordinate W")
                h = gr.Number(label="Point Coordinate H")
            sam_feat = gr.Button("Generate Features Using SAM")
            sam_mask = gr.Button("Predict Mask Using SAM")
            lama = gr.Button("Inpaint Image Using LaMA")

    # TODO: consider hiding this row; end users probably don't need the raw
    # masks.  Note that the components must still exist, because lama.click
    # below reads the masks back from them as inputs.
    # gr.Image(type="numpy") replaces the deprecated gr.outputs.Image, which
    # was removed in Gradio 4.
    with gr.Row():
        mask_0 = gr.Image(type="numpy", label="Segmentation Mask 0")
        mask_1 = gr.Image(type="numpy", label="Segmentation Mask 1")
        mask_2 = gr.Image(type="numpy", label="Segmentation Mask 2")

    with gr.Row():
        img_with_mask_0 = gr.Plot(label="Image with Segmentation Mask 0")
        img_with_mask_1 = gr.Plot(label="Image with Segmentation Mask 1")
        img_with_mask_2 = gr.Plot(label="Image with Segmentation Mask 2")

    with gr.Row():
        img_rm_with_mask_0 = gr.Image(
            type="numpy", label="Image Removed with Segmentation Mask 0")
        img_rm_with_mask_1 = gr.Image(
            type="numpy", label="Image Removed with Segmentation Mask 1")
        img_rm_with_mask_2 = gr.Image(
            type="numpy", label="Image Removed with Segmentation Mask 2")

    def get_select_coords(img, evt: gr.SelectData):
        """Record the clicked pixel and preview it as a point on the image.

        ``evt.index`` supplies the click position.  Its first element feeds
        the W (x) box and its second the H (y) box, which get_masked_img
        uses as SAM's ``[x, y]`` point.

        Returns:
            (x, y, figure) for the W, H and "Pointed Image" components.
        """
        x, y = evt.index[0], evt.index[1]
        dpi = plt.rcParams['figure.dpi']
        height, width = img.shape[:2]
        fig = plt.figure(figsize=(width / dpi / 0.77, height / dpi / 0.77))
        plt.imshow(img)
        plt.axis('off')
        show_points(plt.gca(), [[x, y]], [1], size=(width * 0.04) ** 2)
        # Release pyplot's reference so one figure is not leaked per click.
        # gr.Plot can still render the returned Figure object.
        plt.close(fig)
        return x, y, fig

    img.select(get_select_coords, [img], [w, h, img_pointed])
    sam_feat.click(
        get_sam_feat,
        [img],
        []
    )
    sam_mask.click(
        get_masked_img,
        [img, w, h],
        [img_with_mask_0, img_with_mask_1, img_with_mask_2, mask_0, mask_1, mask_2]
    )

    lama.click(
        get_inpainted_img,
        [img, mask_0, mask_1, mask_2],
        [img_rm_with_mask_0, img_rm_with_mask_1, img_rm_with_mask_2]
    )


if __name__ == "__main__":
    demo.launch(debug=True)