mattmdjaga committed
Commit 21232f6
1 Parent(s): 206aa70

Added type hinting and some clean up

Files changed (2):
  1. .gitignore +1 -0
  2. app.py +18 -19
.gitignore CHANGED
@@ -1,2 +1,3 @@
 test.ipynb
 data
+__pycache__
app.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image, ImageDraw
 import requests
 from transformers import SamModel, SamProcessor
 import cv2
+from typing import List

 device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -12,7 +13,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
 processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

-def mask_2_dots(mask):
+def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
     gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
     _, thresh = cv2.threshold(gray, 127, 255, 0)
     kernel = np.ones((5,5),np.uint8)
@@ -26,34 +27,32 @@ def mask_2_dots(mask):
         points.append([cx, cy])
     return [points]

-def main_func(inputs):
-    dots = inputs['mask']
-    points = mask_2_dots(dots)
-
-    image_input = inputs['image']
+def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
     image_input = Image.fromarray(image_input)

     inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
-    # Forward pass
     outputs = model(**inputs)
-
-    # Postprocess outputs
-    draw = ImageDraw.Draw(image_input)
-    for point in points[0]:
-        draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")
-
-
     masks = processor.image_processor.post_process_masks(
         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
     )
-    #scores = outputs.iou_scores
+    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)

-    mask = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
+    return masks
+
+def main_func(inputs) -> List[Image.Image]:
+    dots = inputs['mask']
+    points = mask_2_dots(dots)
+    image_input = inputs['image']
+    masks = foward_pass(image_input, points)
+
+    image_input = Image.fromarray(image_input)
+    draw = ImageDraw.Draw(image_input)
+    for point in points[0]:
+        draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")

     pred_masks = [image_input]
-    for i in range(mask.shape[2]):
-        #mask[:,:,i] = mask[:,:,i] * scores[0][i].item()
-        pred_masks.append(Image.fromarray((mask[:,:,i] * 255).astype(np.uint8)))
+    for i in range(masks.shape[2]):
+        pred_masks.append(Image.fromarray((masks[:,:,i] * 255).astype(np.uint8)))

     return pred_masks

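
The hunk context above skips the middle of `mask_2_dots` (the stretch between the morphology kernel and `points.append([cx, cy])`). A common OpenCV recipe for that step is to find the external contours of the cleaned-up sketch mask and take each contour's centroid as one SAM point prompt. The sketch below is illustrative only, not the file's exact code; the helper name and the morphological-opening call are assumptions:

```python
import cv2
import numpy as np
from typing import List

def centroids_from_sketch(mask: np.ndarray) -> List[List[List[int]]]:
    # Binarise the RGB sketch mask, remove small specks, then use each
    # remaining blob's centroid as one point prompt for SAM.
    gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, 0)
    kernel = np.ones((5, 5), np.uint8)
    opened = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)  # assumed clean-up step
    contours, _ = cv2.findContours(opened, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    points = []
    for cnt in contours:
        m = cv2.moments(cnt)
        if m["m00"] == 0:  # skip degenerate contours with zero area
            continue
        points.append([int(m["m10"] / m["m00"]), int(m["m01"] / m["m00"])])
    return [points]  # outer list: one group of point prompts per image
```

The outer list matches how the result is consumed later: `processor(image_input, input_points=points, ...)` expects the points nested per image.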
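The refactor leaves `main_func`'s contract unchanged: a dict with `'image'` and `'mask'` arrays in, a list of PIL images out (the input image annotated with red prompt dots first, then one image per SAM mask proposal returned by the new `foward_pass`). A hedged local smoke test, assuming app.py is importable and using a made-up image path and a synthetic scribble in place of a drawn mask:

```python
import numpy as np
from PIL import Image

from app import main_func  # assumes app.py is on the import path

image = np.array(Image.open("example.jpg").convert("RGB"))  # hypothetical test image
mask = np.zeros_like(image)
mask[100:140, 200:240] = 255  # a small white block stands in for one drawn prompt

outputs = main_func({"image": image, "mask": mask})
outputs[0].save("prompt_points.png")        # input image with red prompt dots
for i, proposal in enumerate(outputs[1:], 1):
    proposal.save(f"mask_{i}.png")          # one image per SAM mask proposal
```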
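Nothing in this commit touches the UI glue, but the `{'image': ..., 'mask': ...}` dict that `main_func` consumes is the shape Gradio 3.x's sketch tool produces, so the wiring presumably looks roughly like the sketch below. The component choices and labels here are assumptions, not taken from the repo:

```python
import gradio as gr  # assuming Gradio 3.x, where tool="sketch" yields an image/mask dict

# main_func is assumed to be defined above in app.py (as in the diff).
demo = gr.Interface(
    fn=main_func,
    inputs=gr.Image(source="upload", tool="sketch", type="numpy"),
    outputs=gr.Gallery(label="Prompt points and SAM mask proposals"),
    title="Point-prompted segmentation with SAM",
)

if __name__ == "__main__":
    demo.launch()
```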