File size: 2,831 Bytes
078145b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from PIL import Image, ImageDraw
import cv2
import gradio as gr
import torch
from segment_anything import sam_model_registry
from automatic_mask_generator import SamAutomaticMaskGenerator

# Use the GPU when one is present; otherwise fall back to the CPU so the demo
# still starts (slowly) instead of crashing at `sam.to(device='cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ViT-H SAM backbone. The checkpoint path is relative to the current working
# directory, not to this script.
sam = sam_model_registry['vit_h'](checkpoint='./sam_vit_h_4b8939.pth')
sam.to(device=device)

# Project-local automatic mask generator used by `debug` to count objects.
# NOTE(review): min_mask_region_area=25 presumably drops mask regions smaller
# than 25 px (as in upstream SAM) — confirm against automatic_mask_generator.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    min_mask_region_area=25,
)

def binarize(x):
    """Return `x` as a 0/255 uint8 mask (for NumPy array input).

    Every element that is not equal to zero becomes 255; zeros stay 0.
    """
    is_foreground = x != 0
    as_bytes = is_foreground.astype('uint8')
    return as_bytes * 255

def draw_box(boxes=None, img=None):
    """Outline each box on an image, cycling through a fixed colour palette.

    Args:
        boxes: Sequence of boxes, each indexable as ``[x1, y1, x2, y2]`` in
            pixel coordinates. ``None`` (the default) means no boxes.
        img: PIL image to draw on; it is modified in place. When ``None``, a
            white 512x512 RGB canvas is created instead.

    Returns:
        The image that was drawn on, or ``None`` when there are neither boxes
        nor an input image.
    """
    # A `None` default avoids one mutable list being shared across calls.
    if boxes is None:
        boxes = []
    if len(boxes) == 0 and img is None:
        return None

    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    draw = ImageDraw.Draw(img)
    for bid, box in enumerate(boxes):
        # Colours repeat once there are more boxes than palette entries.
        draw.rectangle([box[0], box[1], box[2], box[3]],
                       outline=colors[bid % len(colors)], width=4)
    return img


def draw_pred_box(boxes=None, img=None, color="green"):
    """Outline every predicted box on an image in a single colour.

    Args:
        boxes: Sequence of boxes, each indexable as ``[x1, y1, x2, y2]`` in
            pixel coordinates. ``None`` (the default) means no boxes.
        img: PIL image to draw on; it is modified in place. When ``None``, a
            white 512x512 RGB canvas is created instead.
        color: Outline colour passed to ``ImageDraw.rectangle``; defaults to
            ``"green"`` (the previously hard-coded value).

    Returns:
        The image that was drawn on, or ``None`` when there are neither boxes
        nor an input image.
    """
    # A `None` default avoids one mutable list being shared across calls.
    if boxes is None:
        boxes = []
    if len(boxes) == 0 and img is None:
        return None

    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    for box in boxes:
        draw.rectangle([box[0], box[1], box[2], box[3]], outline=color, width=4)
    return img


def debug(input_img):
    """Gradio callback: turn the user's sketch into exemplar boxes and count.

    Args:
        input_img: Dict with an ``"image"`` array and a ``"mask"`` array, as
            produced by the sketch-tool ``gr.Image`` input this callback is
            wired to. NOTE(review): assumes channel 0 of the mask is non-zero
            where the user drew — confirm for the Gradio version in use.

    Returns:
        ``[exemplar_overlay, prediction_overlay, "Count: N"]``, where N is the
        number of masks returned by ``mask_generator.generate``.
    """
    sketch = input_img["mask"][..., 0]
    found, _ = cv2.findContours(sketch, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # One axis-aligned exemplar box per outer contour, as [x1, y1, x2, y2]
    # built from the contour's min/max pixel coordinates.
    exemplar_boxes = []
    for cnt in found:
        xs = cnt[:, 0, 0]
        ys = cnt[:, 0, 1]
        exemplar_boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])
    exemplar_view = draw_box(exemplar_boxes, Image.fromarray(input_img["image"]))

    predictions = mask_generator.generate(input_img["image"], exemplar_boxes)
    count = len(predictions)
    # Each prediction's 'bbox' is (x, y, w, h); convert to corner form.
    predicted_boxes = [
        [left, top, left + width, top + height]
        for left, top, width, height in (pred['bbox'] for pred in predictions)
    ]
    prediction_view = draw_pred_box(predicted_boxes, Image.fromarray(input_img["image"]))
    return [exemplar_view, prediction_view, "Count: {}".format(count)]

# HTML header text passed to gr.Interface as `description` when the demo is
# built further down in this file.
# NOTE(review): the "Project Page" and "Paper" links use a placeholder
# href=" " and do not point anywhere yet — fill in the real URLs.
description = """<p style="text-align: center; font-weight: bold;">
    <span style="font-size: 28px">Count Anything</span>
    <br>
    <span style="font-size: 18px" id="paper-info">
        [<a href=" " target="_blank">Project Page</a>]
        [<a href=" " target="_blank">Paper</a>]
        [<a href="https://github.com/Vision-Intelligence-and-Robots-Group/count-anything" target="_blank">GitHub</a>]
    </span>
</p>
"""

# Demo wiring: one sketchable 512x512 upload in; `debug` produces the
# exemplar overlay, the prediction overlay and the count text.
run = gr.Interface(
    fn=debug,
    inputs=gr.Image(shape=[512, 512], source="upload", tool="sketch").style(
        height=500, width=500
    ),
    outputs=[gr.Image(), gr.Image(), gr.Text()],
    description=description,
)

run.launch()