JuanLozada97 committed
Commit dbeb499 · 1 Parent(s): 9b977a4

Upload 5 files
Files changed (5)
  1. app.py +84 -0
  2. examples/img_demo.png +0 -0
  3. model.py +7 -0
  4. requirements.txt +4 -0
  5. sam_vit_b_01ec64.pth +3 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ from timeit import default_timer as timer
+ from typing import Tuple
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ from segment_anything import SamAutomaticMaskGenerator
+
+ from model import create_sam_model
+
+ # 1. Set up variables
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ checkpoint = "sam_vit_b_01ec64.pth"
+ model_type = "vit_b"
+
+ # 2. Prepare the model and load the saved weights
+ medsam_model = create_sam_model(model_type, checkpoint, device)
+ mask_generator = SamAutomaticMaskGenerator(medsam_model)
+
+ # Mask-overlay helper, adapted from the official segment-anything example notebook
+ def show_anns(anns):
+     if len(anns) == 0:
+         return
+     sorted_anns = sorted(anns, key=lambda x: x["area"], reverse=True)
+     ax = plt.gca()
+     ax.set_autoscale_on(False)
+     h, w = sorted_anns[0]["segmentation"].shape
+     overlay = np.ones((h, w, 4))
+     overlay[:, :, 3] = 0  # fully transparent background
+     for ann in sorted_anns:
+         m = ann["segmentation"]
+         overlay[m] = np.concatenate([np.random.random(3), [0.35]])  # random color, 35% alpha
+     ax.imshow(overlay)
+
+ # 3. Prediction function
+ @torch.no_grad()
+ def predict(img) -> Tuple[plt.Figure, float]:
+     """Runs automatic mask generation on img and returns the overlay figure and the time taken."""
+     # Start the timer
+     start_time = timer()
+
+     # Gradio passes a PIL image; make sure it is RGB and convert it to a NumPy array
+     image = np.array(img.convert("RGB"))
+
+     masks = mask_generator.generate(image)
+
+     # Calculate the prediction time
+     pred_time = round(timer() - start_time, 5)
+
+     # Plot the image with the generated masks overlaid
+     fig = plt.figure(figsize=(20, 20))
+     plt.imshow(image)
+     show_anns(masks)
+     plt.axis("off")
+
+     # Return the figure and the prediction time
+     return fig, pred_time
+
+ # 4. Gradio app
+ # Create title, description and article strings
+ title = "MedSAM"
+ description = "A SAM model fine-tuned for medical image segmentation. Upload an image and the app overlays the automatically generated masks."
+ article = "Created at gradio-sam-predictor-image-embedding-generator.ipynb."
+
+ # Create the examples list from the "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,                  # mapping function from input to output
+                     inputs=gr.Image(type="pil"),  # input: a single image
+                     outputs=[gr.Plot(label="Predictions"),             # output 1: the mask overlay
+                              gr.Number(label="Prediction time (s)")],  # output 2: the timing
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ if __name__ == "__main__":
+     demo.launch(debug=False,  # print errors locally?
+                 share=True)   # generate a publicly shareable URL?
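For a quick check outside the browser UI, predict can also be called directly. A minimal sketch (not part of the upload): it uses the example image shipped in this commit, and it relies on the __main__ guard above so that importing app does not launch the demo.

    from PIL import Image
    from app import predict

    # Run mask generation on the bundled example and save the overlay
    fig, seconds = predict(Image.open("examples/img_demo.png"))
    fig.savefig("masks_overlay.png")
    print(f"Mask generation took {seconds}s")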
examples/img_demo.png ADDED
model.py ADDED
@@ -0,0 +1,7 @@
+ from segment_anything import sam_model_registry
+
+ def create_sam_model(model_type, checkpoint, device: str = "cpu"):
+     """Build a SAM model from the registry, load the checkpoint weights, and move it to device."""
+     medsam_model = sam_model_registry[model_type](checkpoint=checkpoint)
+     medsam_model = medsam_model.to(device)
+     return medsam_model
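A small sanity check for create_sam_model (a sketch, not part of the upload): the checkpoint in this commit is 375,042,383 bytes, which at 4 bytes per float32 parameter corresponds to roughly 94M parameters, so a freshly built vit_b model should report about that many.

    from model import create_sam_model

    # Build the model on CPU and count its parameters
    sam = create_sam_model("vit_b", "sam_vit_b_01ec64.pth", device="cpu")
    n_params = sum(p.numel() for p in sam.parameters())
    print(f"{n_params / 1e6:.1f}M parameters")  # expect roughly 94M for vit_b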
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.1.0
+ torchvision==0.16.0
+ gradio==3.50.2
+ git+https://github.com/facebookresearch/segment-anything.git
sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
+ size 375042383
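The .pth weights live in Git LFS; only this pointer file is stored in the repo. If a clone is missing the blob, the checkpoint can be fetched from the segment-anything release published by Meta and verified against the pointer's sha256. A sketch using only the standard library; the expected hash is the oid from the pointer above.

    import hashlib
    import urllib.request

    URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    EXPECTED = "ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912"

    # Download the checkpoint next to app.py
    urllib.request.urlretrieve(URL, "sam_vit_b_01ec64.pth")

    # Hash it in 1 MiB chunks and compare against the LFS pointer's oid
    digest = hashlib.sha256()
    with open("sam_vit_b_01ec64.pth", "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    assert digest.hexdigest() == EXPECTED, "checkpoint hash mismatch"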