# Source: uploaded by JuanLozada97, commit c6ccb48 ("first commit"), 433 bytes.
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import torch.nn.functional as F
def create_sam_model(model_type, checkpoint, device: str = "cpu") -> torch.nn.Module:
    """Build a Segment Anything model from the registry and move it to *device*.

    Args:
        model_type: Key into ``sam_model_registry`` selecting which model
            builder to use (e.g. ``"vit_b"`` -- confirm the available keys
            against the installed ``segment_anything`` version).
        checkpoint: Forwarded unchanged to the builder as ``checkpoint=``;
            presumably a path to a weights file -- confirm with callers.
        device: Target passed to ``Module.to`` (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        The constructed model, placed on *device*.

    Raises:
        KeyError: If *model_type* is not a key of ``sam_model_registry``.
    """
    try:
        builder = sam_model_registry[model_type]
    except KeyError:
        # Same exception type as a plain dict lookup (callers may catch it),
        # but list the valid choices so a typo is easy to diagnose.
        raise KeyError(
            f"Unknown SAM model_type {model_type!r}; "
            f"available: {sorted(sam_model_registry)}"
        ) from None

    medsam_model = builder(checkpoint=checkpoint)
    return medsam_model.to(device)