import gradio as gr
import os
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import base64
import json
from typing import Tuple, Dict
from timeit import default_timer as timer
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import torch.nn.functional as F
from model import create_sam_model

# 1. Setup variables
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

# 2. Model preparation: create the model and load the saved weights
medsam_model = create_sam_model(model_type, checkpoint, device)

mask_generator = SamAutomaticMaskGenerator(
    model=medsam_model,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires OpenCV to run post-processing
)

# 3. Prediction functions
def show_anns(anns):
    """Overlay the generated masks on the current matplotlib axes, one random colour per mask."""
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Transparent RGBA canvas the size of the image; each mask gets a random colour at 35% opacity
    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)


@torch.no_grad()
def predict(img) -> Tuple[plt.Figure, float]:
    """Run automatic mask generation on img and return the annotated figure and the time taken."""
    # Start the timer
    start_time = timer()

    # Gradio passes a PIL image (already RGB); SAM expects an HWC uint8 RGB array
    image = np.array(img.convert("RGB"))

    masks = mask_generator.generate(image)

    # Calculate the prediction time
    pred_time = round(timer() - start_time, 5)

    # Plot the input image with the predicted masks overlaid
    fig = plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')

    # Return the figure with the predictions and the prediction time
    return fig, pred_time

# 4. Gradio app

# Create title, description and article strings
title = "MedSam"
description = "A specialized SAM model fine-tuned for the segmentation of medical images. With this app, effortlessly generate segmentation masks for an uploaded image."
article = "Created at gradio-sam-predictor-image-embedding-generator.ipynb."

# Create examples list from the "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create the Gradio demo
demo = gr.Interface(fn=predict,  # mapping function from input to output
                    inputs=gr.Image(type="pil"),  # what are the inputs?
                    outputs=[gr.Plot(label="Predictions"),  # what are the outputs?
                             gr.Number(label="Prediction time (s)")],  # our fn has two outputs, therefore we have two outputs
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch(debug=False,  # print errors locally?
            share=True)   # generate a publicly shareable URL?
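
# ---------------------------------------------------------------------------
# Optional sketch (not part of the app above): the script imports SamPredictor
# but only uses the automatic mask generator. The block below is a minimal,
# hedged example of prompt-based prediction and image-embedding extraction
# with SamPredictor, assuming create_sam_model returns a standard Sam instance
# (as the mask generator above already requires). The example image path, the
# centre-point prompt, and the RUN_PREDICTOR_SKETCH flag are illustrative
# assumptions; the flag keeps this code from running when the demo is served.
# ---------------------------------------------------------------------------
RUN_PREDICTOR_SKETCH = False
if RUN_PREDICTOR_SKETCH:
    predictor = SamPredictor(medsam_model)

    # Load one of the bundled examples and convert OpenCV's BGR output to RGB
    example_path = os.path.join("examples", os.listdir("examples")[0])
    example_image = cv2.cvtColor(cv2.imread(example_path), cv2.COLOR_BGR2RGB)

    # set_image runs the image encoder once; the embedding can then be reused
    predictor.set_image(example_image)
    embedding = predictor.get_image_embedding()  # (1, 256, 64, 64) for ViT-B

    # Predict masks from a single foreground point placed at the image centre
    point_coords = np.array([[example_image.shape[1] // 2, example_image.shape[0] // 2]])
    point_labels = np.array([1])  # 1 = foreground, 0 = background
    masks, scores, logits = predictor.predict(point_coords=point_coords,
                                              point_labels=point_labels,
                                              multimask_output=True)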