"""Gradio demo that runs Segment Anything automatic mask generation on an image.

The app loads a SAM checkpoint once at import time, exposes ``predict`` as the
Gradio callback, and launches a shareable interface.
"""
import base64
import json
import os
from timeit import default_timer as timer
from typing import Tuple

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

from model import create_sam_model

# 1. Setup variables
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# 2. Model preparation and loading of the saved weights
medsam_model = create_sam_model(model_type, checkpoint, device)
mask_generator = SamAutomaticMaskGenerator(medsam_model)


# 3. Predict fn
def _draw_mask_overlay(ax, masks, alpha=0.35):
    """Paint every mask in ``masks`` onto ``ax`` as a semi-transparent colour layer.

    Each record is expected to carry a boolean ``segmentation`` array and an
    ``area`` value, as produced by ``SamAutomaticMaskGenerator.generate``.
    Does nothing when ``masks`` is empty.
    """
    if not masks:
        return

    # Largest masks first so that smaller regions are painted on top of them.
    ordered = sorted(masks, key=lambda record: record["area"], reverse=True)
    height, width = ordered[0]["segmentation"].shape[:2]

    # RGBA canvas that starts fully transparent (alpha channel == 0).
    overlay = np.zeros((height, width, 4), dtype=np.float32)
    for record in ordered:
        region = np.asarray(record["segmentation"], dtype=bool)
        overlay[region] = np.concatenate([np.random.random(3), [alpha]])
    ax.imshow(overlay)


@torch.no_grad()
def predict(img) -> Tuple[Figure, float]:
    """Segment ``img`` and return the annotated figure and the time taken.

    Args:
        img: PIL image supplied by the ``gr.Image(type="pil")`` input.

    Returns:
        A ``(figure, seconds)`` tuple: a matplotlib figure showing the image
        with the generated masks overlaid, and the mask-generation time in
        seconds rounded to 5 decimal places.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an image before running the prediction.")

    # Start the timer
    start_time = timer()

    # PIL images are already in RGB channel order, so no BGR->RGB swap is
    # needed; convert() also normalises RGBA / grayscale / palette inputs to
    # the 3-channel array the mask generator is fed.
    image = np.array(img.convert("RGB"))
    masks = mask_generator.generate(image)

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Build the figure without pyplot so it is not registered in pyplot's
    # global figure list; otherwise every request would leak a figure.
    fig = Figure(figsize=(20, 20))
    ax = fig.subplots()
    ax.imshow(image)
    _draw_mask_overlay(ax, masks)
    ax.axis("off")

    # Return the annotated figure and the prediction time
    return fig, pred_time


# 4. Gradio app
# Create title, description and article strings
title = "MedSam"
description = "a specialized SAM model finely tuned for the segmentation of medical images. With this app, effortlessly extract image embeddings using the model's advanced mask decoder."
article = "Created at gradio-sam-predictor-image-embedding-generator.ipynb ."

# Create examples list from "examples/" directory (sorted for a stable order;
# empty when the directory does not exist so the app still starts).
example_list = (
    [["examples/" + example] for example in sorted(os.listdir("examples"))]
    if os.path.isdir("examples")
    else []
)

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=gr.Image(type="pil"),  # what are the inputs?
    outputs=[
        gr.Plot(label="Predictions"),  # what are the outputs?
        gr.Number(label="Prediction time (s)"),
    ],  # our fn has two outputs, therefore we have two outputs
    examples=example_list or None,  # no examples section when none were found
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch(
    debug=False,  # print errors locally?
    share=True,  # generate a publicly shareable URL?
)