sagar007 committed on
Commit
0c00155
1 Parent(s): 4f39124

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -25
app.py CHANGED
@@ -1,18 +1,16 @@
1
  import gradio as gr
2
  import torch
3
- import cv2
4
  import numpy as np
5
  from PIL import Image
6
  import matplotlib.pyplot as plt
7
  import io
8
- from ultralytics import FastSAM
9
- from ultralytics.models.fastsam import FastSAMPrompt
10
 
11
  # Set up device
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
- # Load FastSAM model
15
- model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
16
 
17
  def fig2img(fig):
18
  buf = io.BytesIO()
@@ -21,16 +19,13 @@ def fig2img(fig):
21
  img = Image.open(buf)
22
  return img
23
 
24
- def plot_masks(annotations, output_shape):
25
  fig, ax = plt.subplots(figsize=(10, 10))
26
- ax.imshow(annotations[0].orig_img)
27
 
28
- for ann in annotations:
29
- for mask in ann.masks.data:
30
- mask = cv2.resize(mask.cpu().numpy().astype('uint8'), output_shape[::-1])
31
- masked = np.ma.masked_where(mask == 0, mask)
32
- ax.imshow(masked, alpha=0.5, cmap=plt.cm.get_cmap('jet'))
33
-
34
  ax.axis('off')
35
  plt.close()
36
  return fig2img(fig)
@@ -42,19 +37,15 @@ def segment_everything(input_image):
42
 
43
  input_image = Image.fromarray(input_image).convert("RGB")
44
 
45
- # Run FastSAM model in "everything" mode
46
- everything_results = model(input_image, device=device, retina_masks=True, imgsz=1024, conf=0.25, iou=0.9, agnostic_nms=True)
47
-
48
- # Prepare a Prompt Process object
49
- prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
50
-
51
- # Get everything segmentation
52
- ann = prompt_process.everything_prompt()
53
 
54
  # Plot the results
55
- result_image = plot_masks(ann, input_image.size)
56
 
57
- return result_image, f"Segmented everything in the image. Found {len(ann[0].masks)} objects."
58
 
59
  except Exception as e:
60
  return None, f"An error occurred: {str(e)}"
@@ -69,8 +60,8 @@ iface = gr.Interface(
69
  gr.Image(type="pil", label="Segmented Image"),
70
  gr.Textbox(label="Status")
71
  ],
72
- title="FastSAM Everything Segmentation",
73
- description="Upload an image to segment all objects using FastSAM."
74
  )
75
 
76
  # Launch the interface
 
1
  import gradio as gr
2
  import torch
 
3
  import numpy as np
4
  from PIL import Image
5
  import matplotlib.pyplot as plt
6
  import io
7
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
 
8
 
9
  # Set up device
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Load SAM 2 model
13
+ predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
14
 
15
  def fig2img(fig):
16
  buf = io.BytesIO()
 
19
  img = Image.open(buf)
20
  return img
21
 
22
def plot_masks(image, masks):
    """Overlay segmentation masks on an image and return the composite as a PIL Image.

    Args:
        image: Background image in any form matplotlib's ``imshow`` accepts
            (e.g. a PIL Image or an HxW[xC] array).
        masks: Iterable of 2-D mask arrays; pixels equal to 0 are treated as
            background and rendered transparent.

    Returns:
        A PIL.Image.Image produced by ``fig2img`` from the rendered figure.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)

    for mask in masks:
        # Hide background pixels so only the mask region is tinted.
        overlay = np.ma.masked_where(mask == 0, mask)
        # Pass the colormap by name: plt.cm.get_cmap() was deprecated in
        # matplotlib 3.7 and removed in 3.9; cmap='jet' is the supported form.
        ax.imshow(overlay, alpha=0.5, cmap='jet')

    ax.axis('off')
    # Close this specific figure. A bare plt.close() only closes the *current*
    # figure, which may not be `fig` if other figures exist — leaking figures.
    plt.close(fig)
    return fig2img(fig)
 
37
 
38
  input_image = Image.fromarray(input_image).convert("RGB")
39
 
40
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
41
+ predictor.set_image(input_image)
42
+ # Use 'everything' prompt
43
+ masks, _, _ = predictor.predict([])
 
 
 
 
44
 
45
  # Plot the results
46
+ result_image = plot_masks(input_image, masks)
47
 
48
+ return result_image, f"Segmented everything in the image. Found {len(masks)} objects."
49
 
50
  except Exception as e:
51
  return None, f"An error occurred: {str(e)}"
 
60
  gr.Image(type="pil", label="Segmented Image"),
61
  gr.Textbox(label="Status")
62
  ],
63
+ title="SAM 2 Everything Segmentation",
64
+ description="Upload an image to segment all objects using SAM 2."
65
  )
66
 
67
  # Launch the interface