from huggingface_hub import snapshot_download
import gradio as gr
import numpy as np
from tinysam import sam_model_registry, SamPredictor

# Fetch the TinySAM checkpoint from the Hugging Face Hub.
snapshot_download("merve/tinysam", local_dir="tinysam")

# Build the TinySAM model (ViT-Tiny image encoder) and wrap it in a predictor.
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint="./tinysam/tinysam.pth")
predictor = SamPredictor(sam)
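
# (Optional) On a GPU-equipped machine the underlying model could be moved
# to CUDA for faster inference. A minimal sketch, assuming torch is
# importable (it is a dependency of tinysam):
#
#     import torch
#     sam.to("cuda" if torch.cuda.is_available() else "cpu")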

def infer(img):
    # The gr.ImageEditor payload is a dict: "background" is the original
    # image, "layers"[0] holds the user's point scribble, and "composite"
    # is the flattened overlay (unused here).
    image = img["background"].convert("RGB")
    point_prompt = img["layers"][0]

    predictor.set_image(np.array(image))

    # Reduce the scribble to a single point prompt: the centroid of the
    # non-zero pixels in the drawing layer.
    img_arr = np.array(point_prompt)
    nonzero_indices = np.nonzero(img_arr)
    if nonzero_indices[0].size == 0:
        raise gr.Error("Draw a point on the image before submitting.")
    center_x = int(np.mean(nonzero_indices[1]))
    center_y = int(np.mean(nonzero_indices[0]))
    input_point = np.array([[center_x, center_y]])

    # Label 1 marks the point as a foreground prompt.
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
    )

    # gr.AnnotatedImage expects (base image, [(mask, label), ...]).
    result_label = [(masks[0, :, :], "mask")]
    return image, result_label
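
# Note: predictor.predict returns multiple candidate masks ranked by
# `scores`; a variant of this demo could surface the highest-scoring mask
# instead of the first one, e.g. inside infer():
#
#     best = int(np.argmax(scores))
#     result_label = [(masks[best, :, :], "mask")]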


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(type="pil")
            submit_btn = gr.Button("Segment")
        output = gr.AnnotatedImage()
    # Wire the click event to the existing output component rather than
    # instantiating a second, detached AnnotatedImage.
    submit_btn.click(infer, inputs=im, outputs=output)

demo.launch(debug=True)
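
# For a hosted demo, enabling Gradio's request queue can keep the app
# responsive under concurrent use. A sketch of an alternative launch,
# with an assumed queue size of 10:
#
#     demo.queue(max_size=10).launch(debug=True)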