curt-park commited on
Commit
acb3eab
·
1 Parent(s): 66201bd

Add all mask drawing

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ flagged
Makefile ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ env:
2
+ conda create -n segment-anything python=3.9
3
+
4
+ setup:
5
+ pip install -r requirements.txt
6
+
7
+ run:
8
+ gradio app.py
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PIL
3
+ from functools import lru_cache
4
+
5
+ from random import randint
6
+ import gradio as gr
7
+ import cv2
8
+ import torch
9
+ import numpy as np
10
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
11
+ from typing import List
12
+
13
+ CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
14
+ MODEL_TYPE = "default"
15
+ MAX_WIDTH = MAX_HEIGHT = 800
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+
19
+ @lru_cache
20
+ def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
21
+ sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
22
+ mask_generator = SamAutomaticMaskGenerator(sam)
23
+ return mask_generator
24
+
25
+
26
+ def adjust_image_size(image: np.ndarray) -> np.ndarray:
27
+ height, width = image.shape[:2]
28
+ if height > width:
29
+ if height > MAX_HEIGHT:
30
+ height, width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
31
+ else:
32
+ if width > MAX_WIDTH:
33
+ height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
34
+ image = cv2.resize(image, (width, height))
35
+ print(image.shape)
36
+ return image
37
+
38
+
39
+ def draw_masks(
40
+ image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
41
+ ) -> np.ndarray:
42
+ for mask in masks:
43
+ color = [randint(127, 255) for _ in range(3)]
44
+ segmentation = mask["segmentation"]
45
+
46
+ # draw mask overlay
47
+ colored_seg = np.expand_dims(segmentation, 0).repeat(3, axis=0)
48
+ colored_seg = np.moveaxis(colored_seg, 0, -1)
49
+ masked = np.ma.MaskedArray(image, mask=colored_seg, fill_value=color)
50
+ image_overlay = masked.filled()
51
+ image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
52
+
53
+ # draw contour
54
+ contours, _ = cv2.findContours(
55
+ np.uint8(segmentation), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
56
+ )
57
+ cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
58
+ return image
59
+
60
+
61
+ def segment(image_path: str, query: str) -> PIL.ImageFile.ImageFile:
62
+ mask_generator = load_mask_generator()
63
+ # reduce the size to save gpu memory
64
+ image = adjust_image_size(cv2.imread(image_path))
65
+ masks = mask_generator.generate(image)
66
+ image = draw_masks(image, masks)
67
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68
+ image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
69
+ return image
70
+
71
+
72
+ demo = gr.Interface(
73
+ fn=segment,
74
+ inputs=[gr.Image(type="filepath"), "text"],
75
+ outputs="image",
76
+ allow_flagging="never",
77
+ title="Segment Anything with CLIP",
78
+ examples=[
79
+ [os.path.join(os.path.dirname(__file__), "examples/dog.jpg"), ""],
80
+ [os.path.join(os.path.dirname(__file__), "examples/city.jpg"), ""],
81
+ [os.path.join(os.path.dirname(__file__), "examples/food.jpg"), ""],
82
+ [os.path.join(os.path.dirname(__file__), "examples/horse.jpg"), ""],
83
+ ],
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ demo.launch()
examples/city.jpg ADDED
examples/dog.jpg ADDED
examples/food.jpg ADDED
examples/horse.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==3.24.1
2
+ opencv-python==4.7.0.72
3
+ pycocotools==2.0.6
4
+ matplotlib==3.7.1
5
+ git+https://github.com/facebookresearch/segment-anything.git
6
+ git+https://github.com/openai/CLIP.git
sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
3
+ size 2564550879