File size: 3,581 Bytes
7aef3af
 
 
 
 
 
c456e88
7aef3af
 
 
 
 
 
 
 
48c1344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e53817f
48c1344
7aef3af
e53817f
 
 
 
 
 
 
 
 
 
7aef3af
e53817f
0ebcb8d
e53817f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7aef3af
 
e53817f
 
 
 
 
c85b146
 
 
e53817f
 
c85b146
 
48c1344
 
 
 
 
 
e53817f
48c1344
 
 
 
 
0ebcb8d
7aef3af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from huggingface_hub import snapshot_download
import gradio as gr
import numpy as np
import torch
import sys
from tinysam import sam_model_registry, SamPredictor


# Download the "merve/tinysam" repository from the Hugging Face Hub into the
# local "tinysam" directory; the checkpoint loaded below is read from there.
snapshot_download(repo_id="merve/tinysam", local_dir="tinysam")

# Pick the model builder registered under this key, build the model from the
# downloaded weights, and wrap it in the predictor used by `infer`.
model_type = "vit_t"
_build_sam = sam_model_registry[model_type]
sam = _build_sam(checkpoint="./tinysam/tinysam.pth")
predictor = SamPredictor(sam)

# Example images offered in the UI; each entry is a one-element row because
# the examples widget expects one value per input component.
examples = [
    ["assets/1.jpg"],
    ["assets/2.jpg"],
    ["assets/3.jpg"],
    ["assets/4.jpeg"],
    ["assets/5.jpg"],
    ["assets/6.jpeg"],
]

# Row whose image is preloaded into the editor when the page opens.
default_example = examples[0]

# Page header (HTML). The first <font> element is now properly closed with
# </font>; previously a second opening <font> tag left the markup unbalanced.
title = "<center><strong><font size='8'>TinySAM</font></strong> <a href='https://github.com/xinghaochen/TinySAM'><font size='6'>[GitHub]</font></a> </center>"

# Usage instructions rendered as Markdown next to the examples.
description_p = """# Interactive Instance Segmentation
                - Point-prompt instruction
                <ol>
                <li> Click on the left image (point input), visualizing the point on the right image </li>
                <li> Click the button of Segment with Point Prompt </li>
                </ol>
                - Box-prompt instruction
                <ol>
                <li> Click on the left image (one point input), visualizing the point on the right image </li>
                <li> Click on the left image (another point input), visualizing the point and the box on the right image</li>
                <li> Click the button of Segment with Box Prompt </li>
                </ol>
                - Github [link](https://github.com/xinghaochen/TinySAM)
              """

# Custom CSS passed to the Blocks container.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

def infer(img):
    """Segment the object under the point the user drew in the image editor.

    Args:
        img: Editor value, a dict with the keys ``"background"`` (the
            uploaded image), ``"layers"`` (drawn layers; the first one is
            read as the point prompt) and ``"composite"`` (not used here).

    Returns:
        tuple: ``(image, annotations)`` where ``image`` is the background
        converted to RGB and ``annotations`` is a one-element list pairing
        the highest-scoring predicted mask with the label ``"mask"``.

    Raises:
        gr.Error: If no image has been uploaded, or nothing has been drawn
            on top of it.
    """
    # gr.Error only reaches the user when it is raised; constructing it
    # without `raise` does nothing and lets execution fall through.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    layers = img["layers"]
    if not layers:
        raise gr.Error("Please select a point on top of the image.")
    point_prompt = layers[0]
    print("point_prompt : ", point_prompt)

    # Any non-zero pixel in the drawing layer counts as part of the prompt.
    prompt_pixels = np.array(point_prompt)
    if not np.any(prompt_pixels):
        raise gr.Error("Please select a point on top of the image.")

    # Hand the image to the predictor only once the prompt is known to be
    # valid, so no work is done for requests that are rejected anyway.
    image = img["background"].convert("RGB")
    predictor.set_image(np.array(image))

    # Use the centroid of the drawn pixels as the single point prompt:
    # axis 0 of the layer array is rows (y), axis 1 is columns (x).
    rows, cols = np.nonzero(prompt_pixels)[:2]
    center_x = int(np.mean(cols))
    center_y = int(np.mean(rows))
    input_point = np.array([[center_x, center_y]])
    input_label = np.array([1])

    masks, scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
    )
    best_mask = masks[scores.argmax(), :, :]
    return image, [(best_mask, "mask")]


# Assemble the demo page: a title row, a row holding the editor and the
# segmentation result, and a row with example images and usage instructions.
with gr.Blocks(css=css, title="TinySAM") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)
    with gr.Row():
        with gr.Column():
            # Editor on which the user draws the point prompt; it starts with
            # the first example image loaded and hands PIL images to `infer`.
            im = gr.ImageEditor(
                type="pil",
                value=default_example[0]
            )
        # Displays the value returned by `infer`.
        output = gr.AnnotatedImage()
    with gr.Row():
        with gr.Column():
            gr.Markdown("Try some of the examples below ⬇️")
            # Selecting an example loads it into the editor above.
            gr.Examples(
                examples=examples,
                inputs=[im],
                examples_per_page=6,
            )

        with gr.Column():
            # Description
            gr.Markdown(description_p)
    # Run `infer` on the editor's change event and show its result in `output`.
    im.change(infer, inputs=im, outputs=output)

# Start the app; this runs at import time because there is no __main__ guard.
demo.launch(debug=True)