Spaces:
Build error
Build error
Satyajithchary
commited on
Commit
•
7365db2
1
Parent(s):
e2c4088
Update app.py
Browse files
app.py
CHANGED
@@ -3,27 +3,80 @@ import cv2
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
def
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
def
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
unsam_output = image_cv # Replace with UNSAM processing
|
16 |
-
return unsam_plus_output, unsam_output
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
st.image(
|
|
|
3 |
import numpy as np
|
4 |
from PIL import Image
|
5 |
import torch
|
6 |
+
from detectron2.engine import DefaultPredictor
|
7 |
+
from detectron2.config import get_cfg
|
8 |
+
from detectron2.projects.deeplab import add_deeplab_config
|
9 |
+
from detectron2.utils.colormap import random_color
|
10 |
+
from mask2former import add_maskformer2_config
|
11 |
+
from tqdm import tqdm
|
12 |
|
13 |
+
def setup_predictor(config_file, weights_path, device='cpu'):
|
14 |
+
cfg = get_cfg()
|
15 |
+
cfg.set_new_allowed(True)
|
16 |
+
add_deeplab_config(cfg)
|
17 |
+
add_maskformer2_config(cfg)
|
18 |
+
cfg.merge_from_file(config_file)
|
19 |
+
cfg.MODEL.WEIGHTS = weights_path
|
20 |
+
cfg.MODEL.DEVICE = device
|
21 |
+
predictor = DefaultPredictor(cfg)
|
22 |
+
return predictor
|
23 |
|
24 |
+
def area(mask):
|
25 |
+
if mask.size == 0:
|
26 |
+
return 0
|
27 |
+
return np.count_nonzero(mask) / mask.size
|
|
|
|
|
28 |
|
29 |
+
def vis_mask(input, mask, mask_color):
|
30 |
+
fg = mask > 0.5
|
31 |
+
rgb = np.copy(input)
|
32 |
+
rgb[fg] = (rgb[fg] * 0.5 + np.array(mask_color) * 0.5).astype(np.uint8)
|
33 |
+
return Image.fromarray(rgb)
|
34 |
|
35 |
+
def show_image(I, pool):
|
36 |
+
already_painted = np.zeros(np.array(I).shape[:2])
|
37 |
+
input = I.copy()
|
38 |
+
for mask in tqdm(pool):
|
39 |
+
already_painted += mask.astype(np.uint8)
|
40 |
+
overlap = (already_painted == 2)
|
41 |
+
if np.sum(overlap) != 0:
|
42 |
+
input = Image.fromarray(overlap[:, :, np.newaxis] * np.copy(I) + np.logical_not(overlap)[:, :, np.newaxis] * np.copy(input))
|
43 |
+
already_painted -= overlap
|
44 |
+
input = vis_mask(input, mask, random_color(rgb=True))
|
45 |
+
return input
|
46 |
|
47 |
+
# Load UnSAM and UnSAM+ predictors
|
48 |
+
unsam_predictor = setup_predictor(
|
49 |
+
"/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
|
50 |
+
"/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
|
51 |
+
)
|
52 |
+
unsam_plus_predictor = setup_predictor(
|
53 |
+
"/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
|
54 |
+
"/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"
|
55 |
+
)
|
56 |
+
|
57 |
+
st.title("Image Segmentation with UnSAM and UnSAM+")
|
58 |
+
|
59 |
+
# Upload image
|
60 |
+
uploaded_file = st.file_uploader("Choose an image...", type="png")
|
61 |
+
|
62 |
+
if uploaded_file is not None:
|
63 |
+
# Read the image
|
64 |
+
image = np.array(Image.open(uploaded_file))
|
65 |
+
|
66 |
+
# Display the original image
|
67 |
+
st.image(image, caption='Original Image', use_column_width=True)
|
68 |
+
|
69 |
+
# Run predictions for UnSAM+
|
70 |
+
unsam_plus_outputs = unsam_plus_predictor(image)['instances']
|
71 |
+
unsam_plus_masks = [mask.cpu().numpy() for mask in unsam_plus_outputs.pred_masks]
|
72 |
+
sorted_unsam_plus_masks = sorted(unsam_plus_masks, key=lambda m: area(m), reverse=True)
|
73 |
+
unsam_plus_image = show_image(image, sorted_unsam_plus_masks)
|
74 |
|
75 |
+
# Run predictions for UnSAM
|
76 |
+
unsam_outputs = unsam_predictor(image)['instances']
|
77 |
+
unsam_masks = [mask.cpu().numpy() for mask in unsam_outputs.pred_masks]
|
78 |
+
sorted_unsam_masks = sorted(unsam_masks, key=lambda m: area(m), reverse=True)
|
79 |
+
unsam_image = show_image(image, sorted_unsam_masks)
|
80 |
|
81 |
+
# Display the images side by side
|
82 |
+
st.image([image, unsam_plus_image, unsam_image], caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'], use_column_width=True)
|