SkalskiP committed
Commit c263a47
1 Parent(s): 42c187d

Add functionality for interactive mask generation

Files changed (3)
  1. Dockerfile +2 -1
  2. app.py +28 -13
  3. sam_utils.py +30 -0
Dockerfile CHANGED
@@ -31,7 +31,7 @@ WORKDIR $HOME/app
 RUN pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
 
 # Install dependencies
-RUN pip install --no-cache-dir gradio==4.5.0 opencv-python supervision==0.17.0rc3 \
+RUN pip install --no-cache-dir gradio==3.50.2 opencv-python supervision==0.17.0rc3 \
     pillow requests
 
 # Install SAM and Detectron2
@@ -45,6 +45,7 @@ RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles
 COPY app.py .
 COPY utils.py .
 COPY gpt4v.py .
+COPY sam_utils.py .
 
 RUN find $HOME/app
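Note that the Gradio pin moves backwards here, from 4.5.0 to 3.50.2. Presumably this is because the `tool="sketch"` option on `gr.Image`, which the app.py change below depends on, exists only in the Gradio 3.x API; Gradio 4 removed it in favor of `gr.ImageEditor`. A minimal sketch of the 3.x behavior being pinned (the echo handler is illustrative, not part of this commit):

```python
import gradio as gr  # assumes gradio==3.50.2, as pinned above

# With tool="sketch", gr.Image passes the callback a dict
# {"image": np.ndarray, "mask": np.ndarray} instead of a bare array --
# the exact shape that app.py's inference() unpacks below.
demo = gr.Interface(
    fn=lambda payload: payload["mask"],  # echo the drawn mask back
    inputs=gr.Image(type="numpy", tool="sketch"),
    outputs=gr.Image(type="numpy"),
)

if __name__ == "__main__":
    demo.launch()
```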
app.py CHANGED
@@ -1,15 +1,16 @@
 import os
-import cv2
-import torch
+from typing import List, Dict
 
+import cv2
 import gradio as gr
 import numpy as np
 import supervision as sv
-
-from typing import List
+import torch
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
-from utils import postprocess_masks, Visualizer
+
 from gpt4v import prompt_image
+from utils import postprocess_masks, Visualizer
+from sam_utils import sam_interactive_inference
 
 HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
@@ -32,24 +33,33 @@ MARKDOWN = """
 
 - [ ] Support for alphabetic labels
 - [ ] Support for Semantic-SAM (multi-level)
-- [ ] Support for interactive mode
 - [ ] Support for result highlighting
+- [ ] Support for mask filtering based on granularity
 """
 
 SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
 
 
 def inference(
-    image: np.ndarray,
+    image_and_mask: Dict[str, np.ndarray],
     annotation_mode: List[str],
     mask_alpha: float
 ) -> np.ndarray:
+    image = image_and_mask['image']
+    mask = cv2.cvtColor(image_and_mask['mask'], cv2.COLOR_RGB2GRAY)
+    is_interactive = not np.all(mask == 0)
     visualizer = Visualizer(mask_opacity=mask_alpha)
-    mask_generator = SamAutomaticMaskGenerator(SAM)
-    result = mask_generator.generate(image=image)
-    detections = sv.Detections.from_sam(result)
-    detections = postprocess_masks(
-        detections=detections)
+    if is_interactive:
+        detections = sam_interactive_inference(
+            image=image,
+            mask=mask,
+            model=SAM)
+    else:
+        mask_generator = SamAutomaticMaskGenerator(SAM)
+        result = mask_generator.generate(image=image)
+        detections = sv.Detections.from_sam(result)
+        detections = postprocess_masks(
+            detections=detections)
     bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
     annotated_image = visualizer.visualize(
         image=bgr_image,
@@ -76,7 +86,12 @@ def prompt(message, history, image: np.ndarray, api_key: str) -> str:
 image_input = gr.Image(
     label="Input",
     type="numpy",
-    height=512)
+    height=512,
+    tool="sketch",
+    interactive=True,
+    brush_radius=20.0,
+    brush_color="#FFFFFF"
+)
 checkbox_annotation_mode = gr.CheckboxGroup(
     choices=["Mark", "Polygon", "Mask", "Box"],
     value=['Mark'],
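The dispatch in the new inference() deserves a gloss: the sketch canvas arrives as a dict, the RGB mask is collapsed to grayscale, and an all-zero mask (nothing drawn) keeps the original SamAutomaticMaskGenerator path, while any non-zero pixel routes to the new point-prompted path. A minimal sketch of the two payload shapes (synthetic arrays and commented-out calls, illustrative rather than from the commit):

```python
import numpy as np

h, w = 480, 640
payload = {
    "image": np.zeros((h, w, 3), dtype=np.uint8),  # placeholder RGB image
    "mask": np.zeros((h, w, 3), dtype=np.uint8),   # untouched sketch canvas
}
# All-zero mask -> automatic path:
# inference(payload, annotation_mode=["Mark"], mask_alpha=0.05)

payload["mask"][100:200, 100:200] = 255  # simulate a user scribble
# Non-zero mask -> interactive path via sam_interactive_inference:
# inference(payload, annotation_mode=["Mark"], mask_alpha=0.05)
```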
sam_utils.py ADDED
@@ -0,0 +1,30 @@
+import numpy as np
+import supervision as sv
+
+from segment_anything.modeling.sam import Sam
+from segment_anything import SamPredictor
+
+
+def sam_interactive_inference(
+    image: np.ndarray,
+    mask: np.ndarray,
+    model: Sam
+) -> sv.Detections:
+    predictor = SamPredictor(model)
+    predictor.set_image(image)
+    masks = []
+    for polygon in sv.mask_to_polygons(mask.astype(bool)):
+        random_point_indexes = np.random.choice(polygon.shape[0], size=5, replace=True)
+        input_point = polygon[random_point_indexes]
+        input_label = np.ones(5)
+        mask = predictor.predict(
+            point_coords=input_point,
+            point_labels=input_label,
+            multimask_output=False,
+        )[0][0]
+        masks.append(mask)
+    masks = np.array(masks, dtype=bool)
+    return sv.Detections(
+        xyxy=sv.mask_to_xyxy(masks),
+        mask=masks
+    )
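The helper turns each connected scribble blob into a polygon, samples five of its vertices (with replacement) as positive point prompts, and lets SamPredictor return one mask per blob, wrapped in an sv.Detections. A minimal driving example (not part of the commit; the checkpoint path, "vit_h" type, and image path are assumptions mirroring the Dockerfile download):

```python
import cv2
import numpy as np
from segment_anything import sam_model_registry

from sam_utils import sam_interactive_inference

# Assumed checkpoint location and model type, mirroring the Dockerfile.
sam = sam_model_registry["vit_h"](checkpoint="weights/sam_vit_h_4b8939.pth").to("cpu")

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # SamPredictor expects RGB
scribble = np.zeros(image.shape[:2], dtype=np.uint8)
scribble[100:200, 100:200] = 255  # stand-in for a user brush stroke

detections = sam_interactive_inference(image=image, mask=scribble, model=sam)
print(f"{len(detections)} region(s) segmented from the scribble")
```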