sagar007 commited on
Commit
e9cd6fd
Β·
verified Β·
1 Parent(s): 1432260

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -28
app.py CHANGED
@@ -1,39 +1,43 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
4
- from torchvision import transforms
 
5
 
6
- # Load pre-trained U-Net model
7
- model = torch.hub.load('nvidia/DeepLearningExamples:torchhub', 'unet', pretrained=True)
8
 
9
- # Define a function to segment an image
10
- def segment_image(image):
11
- # Preprocess image
12
- image = Image.fromarray(image)
13
- image = transforms.Compose([
14
- transforms.Resize((256, 256)),
15
- transforms.ToTensor(),
16
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17
- ])(image)
18
 
19
- # Run segmentation model
20
- output = model(image.unsqueeze(0))
21
- output = torch.argmax(output, dim=1)
22
 
23
- # Postprocess output
24
- output = output.squeeze(0).cpu().numpy()
25
- output = Image.fromarray(output.astype('uint8'))
26
 
27
- return output
 
 
 
 
 
 
 
 
28
 
29
- # Create Gradio app
30
- demo = gr.Interface(
31
  fn=segment_image,
32
- inputs=gr.Image(type="pil"),
33
- outputs=gr.Image(type="pil"),
34
- title="Segment Anything",
35
- description="Segment any image using a pre-trained U-Net model"
 
 
 
36
  )
37
 
38
- # Launch Gradio app
39
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import cv2
4
+ import numpy as np
5
+ from fastsam import FastSAM, FastSAMPrompt
6
 
7
+ # Load the FastSAM model
8
+ model = FastSAM('FastSAM-x.pt')
9
 
10
+ def segment_image(input_image, points):
11
+ # Prepare the image
12
+ input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
13
 
14
+ # Run the model
15
+ everything_results = model(input_image, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
 
16
 
17
+ # Prepare prompts
18
+ prompt_process = FastSAMPrompt(input_image, everything_results, device='cpu')
 
19
 
20
+ # Generate mask based on points
21
+ ann = prompt_process.point_prompt(points=points, pointlabel=[1] * len(points))
22
+
23
+ # Overlay the mask on the original image
24
+ result_image = input_image.copy()
25
+ mask = ann[0].astype(bool)
26
+ result_image[mask] = result_image[mask] * 0.5 + np.array([255, 0, 0]) * 0.5
27
+
28
+ return result_image
29
 
30
+ # Create Gradio interface
31
+ iface = gr.Interface(
32
  fn=segment_image,
33
+ inputs=[
34
+ gr.Image(type="numpy"),
35
+ gr.Image(type="numpy", tool="sketch", brush_radius=5, label="Click on objects to segment")
36
+ ],
37
+ outputs=gr.Image(type="numpy"),
38
+ title="FastSAM Image Segmentation",
39
+ description="Click on objects in the image to segment them using FastSAM."
40
  )
41
 
42
+ # Launch the interface
43
+ iface.launch()