|
from huggingface_hub import snapshot_download |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import sys |
|
from tinysam import sam_model_registry, SamPredictor |
|
|
|
|
|
# Fetch the TinySAM weights from the Hugging Face Hub repo "merve/tinysam"
# into the local "tinysam" directory.
snapshot_download("merve/tinysam", local_dir="tinysam")

# Build the "vit_t" TinySAM model from the downloaded checkpoint and wrap it
# in a single module-level predictor shared by every request.
model_type = "vit_t"
sam = sam_model_registry[model_type](
    checkpoint="./tinysam/tinysam.pth",
)
predictor = SamPredictor(sam)
|
|
|
# Sample images bundled with the demo. Each entry is one row of inputs for
# gr.Examples, i.e. a one-element list holding an image path.
examples = [
    [f"assets/{filename}"]
    for filename in ("1.jpg", "2.jpg", "3.jpg", "4.jpeg", "5.jpg", "6.jpeg")
]

# The image editor initially shows the first example.
default_example = examples[0]
|
|
|
# Page header rendered with gr.Markdown. The inner <font> tag around
# "TinySAM" is now properly closed.
title = "<center><strong><font size='8'>TinySAM</font></strong> <a href='https://github.com/xinghaochen/TinySAM'><font size='6'>[GitHub]</font></a> </center>"

# Usage instructions shown next to the examples. They describe the UI as
# actually wired up: a single image editor whose first drawing layer is
# reduced to one point prompt, with segmentation re-run on every change
# (there are no buttons and no box-prompt mode).
description_p = """# Interactive Instance Segmentation

- Point-prompt instruction

<ol>
<li> Upload an image in the image editor, or pick one of the examples </li>
<li> Paint over the object you want to segment with the brush tool; the center of the painted strokes is used as the point prompt </li>
<li> The mask is computed automatically whenever the image or the strokes change, and is shown on the annotated output image </li>
</ol>

- Box prompts are not supported in this demo.

- Github [link](https://github.com/xinghaochen/TinySAM)
"""

# Centers h1 headings and justifies/pads elements with the "about" class.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
|
|
|
def _prompt_point(layer):
    """Return the integer (x, y) centroid of the painted pixels of ``layer``.

    ``layer`` is one drawing layer of the ``gr.ImageEditor`` value; it is
    converted with ``np.array``. A pixel counts as painted when any of its
    channels is nonzero. Returns ``None`` when nothing has been painted.
    """
    arr = np.array(layer)
    if arr.ndim == 3:
        # Collapse the channel axis so every painted pixel contributes exactly
        # once to the centroid, however many of its channels are nonzero.
        arr = arr.any(axis=-1)
    rows, cols = np.nonzero(arr)
    if rows.size == 0:
        return None
    return int(np.mean(cols)), int(np.mean(rows))


def infer(img):
    """Segment the object marked with the brush in the image editor.

    Args:
        img: Value of the ``gr.ImageEditor`` component (``type="pil"``): a
            dict whose ``"background"`` is the uploaded PIL image and whose
            ``"layers"`` hold the user's drawing layers. Only the first layer
            is used; the centroid of its painted pixels becomes the single
            point prompt.

    Returns:
        An ``(image, [(mask, "mask")])`` pair for ``gr.AnnotatedImage``: the
        RGB background image and the highest-scoring mask the predictor
        returns for the point prompt.

    Raises:
        gr.Error: If no image is loaded or nothing has been painted on it.
    """
    # gr.Error is only shown to the user when it is raised; constructing it
    # alone does nothing and lets execution continue into a crash.
    if img is None or img.get("background") is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")
    # The layer list can be empty (e.g. right after an upload or a clear).
    layers = img.get("layers") or []
    point_prompt = layers[0] if layers else None
    print("point_prompt : ", point_prompt)

    point = None if point_prompt is None else _prompt_point(point_prompt)
    if point is None:
        raise gr.Error("Please select a point on top of the image.")

    # Embed the image only once we know there is a usable prompt.
    predictor.set_image(np.array(image))
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        # presumably 1 marks a positive/foreground click (SAM convention)
        point_labels=np.array([1]),
    )

    # Keep only the candidate mask with the highest predicted score.
    result_label = [(masks[scores.argmax(), :, :], "mask")]
    return image, result_label
|
|
|
|
|
with gr.Blocks(css=css, title="TinySAM") as demo:
    # Header row holding the page title.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)

    # Input editor (image + brush prompt) and the segmentation result.
    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(value=default_example[0], type="pil")
            output = gr.AnnotatedImage()

    # Example gallery in one column, usage instructions in a sibling column.
    with gr.Row():
        with gr.Column():
            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(examples=examples, inputs=[im], examples_per_page=6)
        with gr.Column():
            gr.Markdown(description_p)

    # Re-run segmentation whenever the editor's content changes.
    im.change(infer, inputs=im, outputs=output)

demo.launch(debug=True)