SkalskiP commited on
Commit
242f627
1 Parent(s): c2b0349

Add refine_mask utility and update Dockerfile and app.py

Browse files

Introduced a new utility function 'refine_mask' in utils.py for refining masks by removing small islands or filling small holes based on a given area threshold.

Files changed (3) hide show
  1. Dockerfile +1 -0
  2. app.py +2 -1
  3. utils.py +39 -0
Dockerfile CHANGED
@@ -42,6 +42,7 @@ RUN mkdir -p $HOME/app/weights
42
  RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
43
 
44
  COPY app.py .
 
45
 
46
  RUN find $HOME/app
47
 
 
42
  RUN wget -c -O $HOME/app/weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
43
 
44
  COPY app.py .
45
+ COPY utils.py .
46
 
47
  RUN find $HOME/app
48
 
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
 
4
  import gradio as gr
@@ -7,7 +8,7 @@ import supervision as sv
7
 
8
  from typing import List
9
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
10
-
11
 
12
  HOME = os.getenv("HOME")
13
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
1
  import os
2
+ import cv2
3
  import torch
4
 
5
  import gradio as gr
 
8
 
9
  from typing import List
10
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
11
+ from utils import refine_mask
12
 
13
  HOME = os.getenv("HOME")
14
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+ import numpy as np
4
+
5
+
6
+ def refine_mask(
7
+ mask: np.ndarray,
8
+ area_threshold: float,
9
+ mode: str = 'islands'
10
+ ) -> np.ndarray:
11
+ """
12
+ Refines a mask by removing small islands or filling small holes based on area
13
+ threshold.
14
+
15
+ Parameters:
16
+ mask (np.ndarray): Input binary mask.
17
+ area_threshold (float): Threshold for relative area to remove or fill features.
18
+ mode (str): Operation mode ('islands' for removing islands, 'holes' for filling
19
+ holes).
20
+
21
+ Returns:
22
+ np.ndarray: Refined binary mask.
23
+ """
24
+ mask = np.uint8(mask * 255)
25
+ operation = cv2.RETR_EXTERNAL if mode == 'islands' else cv2.RETR_CCOMP
26
+ contours, _ = cv2.findContours(
27
+ mask, operation, cv2.CHAIN_APPROX_SIMPLE
28
+ )
29
+ total_area = cv2.countNonZero(mask) if mode == 'islands' else mask.size
30
+
31
+ for contour in contours:
32
+ area = cv2.contourArea(contour)
33
+ relative_area = area / total_area
34
+ if relative_area < area_threshold:
35
+ cv2.drawContours(
36
+ mask, [contour], -1, (0 if mode == 'islands' else 255), -1
37
+ )
38
+
39
+ return np.where(mask > 0, 1, 0)