mattmdjaga committed
Commit fd219d5
1 Parent(s): 21232f6

Added no grad and storing embeddings

Files changed (2):
  1. app.py +13 -2
  2. requirements.txt +2 -1
app.py CHANGED
@@ -13,6 +13,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
 processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
+embedding = None
+
 def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
     gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
     _, thresh = cv2.threshold(gray, 127, 255, 0)
@@ -27,11 +29,16 @@ def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
         points.append([cx, cy])
     return [points]
 
+@torch.no_grad()
 def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
+    global embedding
     image_input = Image.fromarray(image_input)
-
     inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
-    outputs = model(**inputs)
+    if not isinstance(embedding, torch.Tensor):
+        embedding = model.get_image_embeddings(inputs["pixel_values"])
+    del inputs["pixel_values"]
+
+    outputs = model.forward(image_embeddings=embedding, **inputs)
     masks = processor.image_processor.post_process_masks(
         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
     )
@@ -56,6 +63,9 @@ def main_func(inputs) -> List[Image.Image]:
 
     return pred_masks
 
+def reset_embedding():
+    global embedding
+    embedding = None
 
 with gr.Blocks() as demo:
     gr.Markdown("# How to use")
@@ -71,5 +81,6 @@ with gr.Blocks() as demo:
     image_button = gr.Button("Segment Image")
 
     image_button.click(main_func, inputs=image_input, outputs=image_output)
+    image_input.upload(reset_embedding)
 
 demo.launch()
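What the change buys: SamModel splits into a heavy ViT image encoder and a lightweight prompt-and-mask decoder, so the app now runs the encoder once per uploaded image (get_image_embeddings), reuses the cached embedding for every subsequent point prompt, and wraps inference in @torch.no_grad() to skip autograd bookkeeping. Below is a minimal self-contained sketch of the same pattern, using the same transformers SAM API as the diff; names such as segment, reset_cache, cached_embedding and the cat.png input are illustrative, not from the app.

    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    cached_embedding = None  # one embedding per currently loaded image

    @torch.no_grad()  # inference only, no gradients needed
    def segment(image: Image.Image, points):
        """Return post-processed masks for point prompts on `image`."""
        global cached_embedding
        inputs = processor(image, input_points=points, return_tensors="pt").to(device)
        if cached_embedding is None:
            # The expensive ViT encoder runs only on the first prompt;
            # later prompts on the same image reuse the cached result.
            cached_embedding = model.get_image_embeddings(inputs["pixel_values"])
        del inputs["pixel_values"]  # forward() takes embeddings instead of pixels
        outputs = model(image_embeddings=cached_embedding, **inputs)
        return processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )

    def reset_cache():  # hook this to the image-upload event, as the app does
        global cached_embedding
        cached_embedding = None

    # Example: segment around a single (x, y) point prompt on a hypothetical image.
    masks = segment(Image.open("cat.png").convert("RGB"), [[[450, 600]]])

The isinstance(embedding, torch.Tensor) check in the app plays the same role as the None test here, and image_input.upload(reset_embedding) clears the cache whenever a new image is uploaded so stale embeddings are never reused.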
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 torch
 git+https://github.com/huggingface/transformers
-opencv-python
+opencv-python
+gradio --upgrade
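One caveat on the new last line: --upgrade is a flag of the pip install command, not a per-requirement option that requirements files accept, so pip install -r requirements.txt can be expected to reject "gradio --upgrade". The conventional way to express this dependency is a plain gradio entry, with any upgrade done separately via pip install --upgrade gradio.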