SAM-arena / utils /SAM.py
Jose Benitez
Streamlit app (#4)
2e64de9 unverified
raw
history blame
4.61 kB
import torch
import numpy as np
from PIL import Image
import streamlit as st
import supervision as sv
from ultralytics import YOLO
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
from transformers import SamModel, SamProcessor
from utils.efficient_sam import load, inference_with_point
import sys
sys.path.insert(1, './utils')
from edge_sam import sam_model_registry, SamPredictor
from edge_sam.onnx import SamPredictorONNX
# Device shared by every model in this module: CUDA when torch reports it
# as available, CPU otherwise.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# EdgeSAM settings. Flip ENABLE_ONNX to True to use the ONNX encoder/decoder
# pair (intended to speed up inference) instead of the PyTorch checkpoint.
ENABLE_ONNX = False
ENCODER_ONNX_PATH = 'weights/edge_sam_3x_encoder.onnx'
DECODER_ONNX_PATH = 'weights/edge_sam_3x_decoder.onnx'
EDGESAM_CHECKPOINT = 'weights/edge_sam_3x.pth'

# Hugging Face SAM: the model is moved to DEVICE, the processor is paired
# with it at call time.
SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# FastSAM (ultralytics) and EfficientSAM (local helper) models.
FASTSAM_MODEL = FastSAM('FastSAM-x.pt')
EFFICIENT_SAM_MODEL = load(device=DEVICE)

# EdgeSAM predictor: built from the PyTorch checkpoint unless ONNX is enabled.
if not ENABLE_ONNX:
    sam = sam_model_registry["edge_sam"](
        EDGESAM_CHECKPOINT,
        upsample_mode="bicubic",
    ).to(device=DEVICE)
    sam.eval()
    predictor = SamPredictor(sam)
else:
    predictor = SamPredictorONNX(ENCODER_ONNX_PATH, DECODER_ONNX_PATH)
@st.cache_data
def SAM_points_inference(image: np.ndarray, input_points) -> sv.Detections:
    """Segment ``image`` with the Hugging Face SAM model from a point prompt.

    Args:
        image: Image array; it is converted with ``Image.fromarray`` before
            being handed to ``SAM_PROCESSOR``.
        input_points: Nested list of prompt coordinates, e.g.
            ``[[[773.0, 167.0]]]``. It is wrapped in one additional list
            before being passed to the processor as ``input_points``.

    Returns:
        An ``sv.Detections`` holding a single mask together with the bounding
        box derived from that mask via ``sv.mask_to_xyxy``.
    """
    print('Processing SAM... 📊')

    inputs = SAM_PROCESSOR(
        Image.fromarray(image),
        input_points=[input_points],
        return_tensors="pt",
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = SAM_MODEL(**inputs)

    post_processed = SAM_PROCESSOR.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # Keep only the very first mask of the post-processed output.
    # NOTE(review): this takes index 0 rather than ranking the candidates;
    # confirm that is the intended choice.
    mask = post_processed[0][0][0].numpy()

    # Add a leading axis so the single mask forms a batch of one.
    mask = mask[np.newaxis, ...]
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
@st.cache_data
def FastSAM_points_inference(
    input,
    input_points,
    input_labels,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25
):
    """Run FastSAM on ``input`` and select masks with a point prompt.

    Args:
        input: Image passed to ``FASTSAM_MODEL`` and to ``FastSAMPrompt``.
        input_points: Sequence of prompt points forwarded to
            ``FastSAMPrompt.point_prompt`` as ``points``.
        input_labels: One label per entry of ``input_points``. When it is
            ``None`` or its length does not match ``input_points``, every
            point is given label ``1``, the value that was previously
            hard-coded.
        input_size: Forwarded to the model as ``imgsz``.
        iou_threshold: Forwarded to the model as ``iou``.
        conf_threshold: Forwarded to the model as ``conf``.

    Returns:
        Whatever ``FastSAMPrompt.point_prompt`` returns for the given points.
    """
    print('Processing FastSAM... 📊')

    results = FASTSAM_MODEL(
        input,
        device=DEVICE,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )
    prompt_process = FastSAMPrompt(input, results, device=DEVICE)

    # Honour the caller's labels when they line up one-to-one with the points;
    # otherwise fall back to label 1 for every point so that the label list
    # always has one entry per point.
    if input_labels is not None and len(input_labels) == len(input_points):
        point_labels = list(input_labels)
    else:
        point_labels = [1] * len(input_points)

    # Point prompt
    detections = prompt_process.point_prompt(
        points=input_points,
        pointlabel=point_labels,
    )
    return detections
@st.cache_data
def EfficientSAM_points_inference(image: np.ndarray, input_points):
    """Segment ``image`` with EfficientSAM using the first prompt point.

    Only ``input_points[0]`` is read; its first two entries are truncated to
    integers and sent to ``inference_with_point`` as a 1x2 array.

    Returns:
        An ``sv.Detections`` with the resulting mask (given a leading axis)
        and the bounding box computed from it by ``sv.mask_to_xyxy``.
    """
    first_point = input_points[0]
    px = int(first_point[0])
    py = int(first_point[1])
    prompt = np.array([[px, py]])

    raw_mask = inference_with_point(image, prompt, EFFICIENT_SAM_MODEL, DEVICE)

    # Add a leading axis so the single mask forms a batch of one.
    batched_mask = raw_mask[np.newaxis, ...]
    boxes = sv.mask_to_xyxy(masks=batched_mask)
    return sv.Detections(xyxy=boxes, mask=batched_mask)
@st.cache_data
def EdgeSAM_points_inference(
    image_input,
    input_points,
    input_labels,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=False,
):
    """Predict an EdgeSAM mask for a point prompt and return the best one.

    Args:
        image_input: Image handed to ``predictor.set_image``; its ``shape``
            and ``dtype`` are printed for debugging.
        input_points: Prompt point coordinates, converted with ``np.array``.
        input_labels: Labels for ``input_points``, converted with ``np.array``.
        input_size, better_quality, withContours, use_retina,
        mask_random_color: Accepted but not read by this function.

    Returns:
        The mask with the highest score, with a leading axis of length one
        added via ``np.expand_dims``.
    """
    # NOTE(review): the image is forwarded to the predictor unchanged; no
    # colour-channel conversion is performed here.
    features = predictor.set_image(image_input)

    # Debug output describing the predictor and the incoming image.
    print(type(predictor))
    print(type(image_input))
    print(image_input.shape)
    print(image_input.dtype)

    coords = np.array(input_points)
    labels = np.array(input_labels)

    if ENABLE_ONNX:
        # The ONNX predictor is given a leading axis on the prompts, and the
        # matching axis is squeezed off the outputs afterwards.
        masks, scores, _ = predictor.predict(
            features=features,
            point_coords=coords[None],
            point_labels=labels[None],
        )
        masks, scores = masks.squeeze(0), scores.squeeze(0)
    else:
        masks, scores, _ = predictor.predict(
            features=features,
            point_coords=coords,
            point_labels=labels,
            num_multimask_outputs=4,
            use_stability_score=True,
        )

    print(f'scores: {scores}')
    area = masks.sum(axis=(1, 2))
    print(f'area: {area}')

    # Keep the highest-scoring candidate as a batch of one.
    best_index = scores.argmax()
    return np.expand_dims(masks[best_index], axis=0)