Satyajithchary commited on
Commit
7365db2
1 Parent(s): e2c4088

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -17
app.py CHANGED
@@ -3,27 +3,80 @@ import cv2
3
  import numpy as np
4
  from PIL import Image
5
  import torch
 
 
 
 
 
 
6
 
7
- def load_image(image_file):
8
- img = Image.open(image_file)
9
- return img
 
 
 
 
 
 
 
10
 
11
- def process_image(image):
12
- image_cv = np.array(image.convert('RGB'))
13
- image_cv = cv2.cvtColor(image_cv, cv2.COLOR_RGB2BGR)
14
- unsam_plus_output = image_cv # Replace with UNSAM+ processing
15
- unsam_output = image_cv # Replace with UNSAM processing
16
- return unsam_plus_output, unsam_output
17
 
18
- st.title("UNSAM Image Processing")
 
 
 
 
19
 
20
- image_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
 
 
 
 
21
 
22
- if image_file is not None:
23
- original_image = load_image(image_file)
24
- st.image(original_image, caption="Original Image", use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- unsam_plus_output, unsam_output = process_image(original_image)
 
 
 
 
27
 
28
- st.image(unsam_plus_output, caption="UNSAM+ Output", use_column_width=True)
29
- st.image(unsam_output, caption="UNSAM Output", use_column_width=True)
 
3
  import numpy as np
4
  from PIL import Image
5
  import torch
6
+ from detectron2.engine import DefaultPredictor
7
+ from detectron2.config import get_cfg
8
+ from detectron2.projects.deeplab import add_deeplab_config
9
+ from detectron2.utils.colormap import random_color
10
+ from mask2former import add_maskformer2_config
11
+ from tqdm import tqdm
12
 
13
+ def setup_predictor(config_file, weights_path, device='cpu'):
14
+ cfg = get_cfg()
15
+ cfg.set_new_allowed(True)
16
+ add_deeplab_config(cfg)
17
+ add_maskformer2_config(cfg)
18
+ cfg.merge_from_file(config_file)
19
+ cfg.MODEL.WEIGHTS = weights_path
20
+ cfg.MODEL.DEVICE = device
21
+ predictor = DefaultPredictor(cfg)
22
+ return predictor
23
 
24
+ def area(mask):
25
+ if mask.size == 0:
26
+ return 0
27
+ return np.count_nonzero(mask) / mask.size
 
 
28
 
29
+ def vis_mask(input, mask, mask_color):
30
+ fg = mask > 0.5
31
+ rgb = np.copy(input)
32
+ rgb[fg] = (rgb[fg] * 0.5 + np.array(mask_color) * 0.5).astype(np.uint8)
33
+ return Image.fromarray(rgb)
34
 
35
+ def show_image(I, pool):
36
+ already_painted = np.zeros(np.array(I).shape[:2])
37
+ input = I.copy()
38
+ for mask in tqdm(pool):
39
+ already_painted += mask.astype(np.uint8)
40
+ overlap = (already_painted == 2)
41
+ if np.sum(overlap) != 0:
42
+ input = Image.fromarray(overlap[:, :, np.newaxis] * np.copy(I) + np.logical_not(overlap)[:, :, np.newaxis] * np.copy(input))
43
+ already_painted -= overlap
44
+ input = vis_mask(input, mask, random_color(rgb=True))
45
+ return input
46
 
47
+ # Load UnSAM and UnSAM+ predictors
48
+ unsam_predictor = setup_predictor(
49
+ "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
50
+ "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
51
+ )
52
+ unsam_plus_predictor = setup_predictor(
53
+ "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
54
+ "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"
55
+ )
56
+
57
+ st.title("Image Segmentation with UnSAM and UnSAM+")
58
+
59
+ # Upload image
60
+ uploaded_file = st.file_uploader("Choose an image...", type="png")
61
+
62
+ if uploaded_file is not None:
63
+ # Read the image
64
+ image = np.array(Image.open(uploaded_file))
65
+
66
+ # Display the original image
67
+ st.image(image, caption='Original Image', use_column_width=True)
68
+
69
+ # Run predictions for UnSAM+
70
+ unsam_plus_outputs = unsam_plus_predictor(image)['instances']
71
+ unsam_plus_masks = [mask.cpu().numpy() for mask in unsam_plus_outputs.pred_masks]
72
+ sorted_unsam_plus_masks = sorted(unsam_plus_masks, key=lambda m: area(m), reverse=True)
73
+ unsam_plus_image = show_image(image, sorted_unsam_plus_masks)
74
 
75
+ # Run predictions for UnSAM
76
+ unsam_outputs = unsam_predictor(image)['instances']
77
+ unsam_masks = [mask.cpu().numpy() for mask in unsam_outputs.pred_masks]
78
+ sorted_unsam_masks = sorted(unsam_masks, key=lambda m: area(m), reverse=True)
79
+ unsam_image = show_image(image, sorted_unsam_masks)
80
 
81
+ # Display the images side by side
82
+ st.image([image, unsam_plus_image, unsam_image], caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'], use_column_width=True)