# SoM / sam_utils.py
# Last commit c263a47 by SkalskiP: "Add functionality for interactive mask generation"
# (Hugging Face file-viewer residue removed: "raw", "history blame", "No virus", "877 Bytes".)
import numpy as np
import supervision as sv
from segment_anything.modeling.sam import Sam
from segment_anything import SamPredictor
def sam_interactive_inference(
    image: np.ndarray,
    mask: np.ndarray,
    model: Sam,
    points_per_polygon: int = 5,
) -> sv.Detections:
    """Turn a roughly drawn mask into SAM segmentations, one per drawn region.

    The drawn ``mask`` is converted to polygons with ``sv.mask_to_polygons``.
    For every polygon, ``points_per_polygon`` of its vertices are sampled at
    random (with replacement) and passed to SAM as point prompts labelled 1;
    the single mask SAM returns for those prompts is kept.

    Sampling uses NumPy's global random state, so results are only
    reproducible if the caller seeds ``np.random`` beforehand.

    Args:
        image: Image handed to ``SamPredictor.set_image``.
            Assumed to be an ``(H, W, C)`` array - TODO confirm against caller.
        mask: Drawn region(s); any non-zero value counts as foreground.
            Presumably 2-D with the same height/width as ``image``.
        model: Loaded SAM model used to build the predictor.
        points_per_polygon: Number of prompt points sampled per polygon.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        ``sv.Detections`` with one boolean mask per polygon and the bounding
        boxes derived from those masks. Empty when ``mask`` has no polygons.

    Raises:
        ValueError: If ``points_per_polygon`` is smaller than 1.
    """
    if points_per_polygon < 1:
        raise ValueError("points_per_polygon must be at least 1")

    predictor = SamPredictor(model)
    predictor.set_image(image)

    refined_masks = []
    # Distinct local names are used inside the loop so the `mask` argument is
    # never rebound while its polygons are still being consumed.
    for polygon in sv.mask_to_polygons(mask.astype(bool)):
        # Prompts are drawn from the polygon's own vertices; sampling with
        # replacement keeps this valid for polygons with very few vertices.
        point_indexes = np.random.choice(
            polygon.shape[0], size=points_per_polygon, replace=True
        )
        prediction = predictor.predict(
            point_coords=polygon[point_indexes],
            point_labels=np.ones(points_per_polygon),
            multimask_output=False,
        )
        # First element of the result holds the masks; with
        # multimask_output=False only the first mask is taken.
        refined_masks.append(prediction[0][0])

    if refined_masks:
        masks = np.array(refined_masks, dtype=bool)
    else:
        # Without polygons np.array([]) would be 1-D; keep the (n, H, W)
        # layout with n == 0 so the downstream supervision calls receive a
        # 3-D array. Assumes predicted masks match the image's height/width.
        masks = np.zeros((0, *image.shape[:2]), dtype=bool)

    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks,
    )