import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel
import torch.nn.functional as F


def create_sam_model(model_type, checkpoint, device: str = "cpu"):
    """Build a SAM model from the registry, load its checkpoint and move it to *device*.

    Args:
        model_type: Key into ``sam_model_registry`` selecting which builder to use.
        checkpoint: Checkpoint location, forwarded unchanged to the registry
            builder as its ``checkpoint`` keyword argument.
        device: Target device, passed to the model's ``.to()`` method.

    Returns:
        The model returned by the registry builder, after ``.to(device)``.

    Raises:
        KeyError: If *model_type* is not a key of ``sam_model_registry``. The
            message lists the registered keys so a typo is easy to spot.
    """
    try:
        build_sam = sam_model_registry[model_type]
    except KeyError:
        # Keep the exception type callers may already catch, but make the
        # message actionable instead of echoing only the bad key.
        available = ", ".join(sorted(sam_model_registry))
        raise KeyError(
            f"Unknown SAM model type {model_type!r}; expected one of: {available}"
        ) from None
    medsam_model = build_sam(checkpoint=checkpoint)
    return medsam_model.to(device)