File size: 877 Bytes
c263a47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import numpy as np
import supervision as sv
from segment_anything.modeling.sam import Sam
from segment_anything import SamPredictor
def sam_interactive_inference(
    image: np.ndarray,
    mask: np.ndarray,
    model: Sam,
    *,
    num_points: int = 5,
) -> sv.Detections:
    """Refine a coarse user-provided mask into SAM masks, one per polygon.

    The input mask is split into polygons with ``sv.mask_to_polygons``. For
    each polygon, ``num_points`` of its vertices are sampled at random (with
    replacement, using NumPy's global random state) and passed to SAM as
    positive point prompts. SAM's single mask for those prompts becomes one
    detection.

    Args:
        image: Image passed to ``SamPredictor.set_image``. Presumably an
            HxWx3 RGB uint8 array (SamPredictor's default format). TODO:
            confirm against callers.
        mask: Mask whose non-zero pixels mark the regions of interest. It is
            cast to ``bool`` before polygon extraction.
        model: Loaded SAM model used to build the predictor.
        num_points: Number of prompt points sampled per polygon. Defaults to
            5, the previously hard-coded value.

    Returns:
        ``sv.Detections`` with one entry per polygon. ``mask`` holds the SAM
        masks and ``xyxy`` their bounding boxes. If the input mask yields no
        polygons, an empty detections object with correctly shaped
        ``xyxy`` (0, 4) and ``mask`` (0, H, W) arrays is returned.

    Raises:
        ValueError: If ``num_points`` is less than 1.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")

    polygons = sv.mask_to_polygons(mask.astype(bool))
    if not polygons:
        # Nothing to prompt with: skip the image embedding entirely. Building
        # the result from an empty list would give a (0,) array instead of the
        # (0, 4) boxes / (0, H, W) masks that Detections expects.
        height, width = image.shape[:2]
        return sv.Detections(
            xyxy=np.empty((0, 4), dtype=np.float32),
            mask=np.empty((0, height, width), dtype=bool),
        )

    predictor = SamPredictor(model)
    predictor.set_image(image)

    # Every sampled point is a foreground prompt (label 1); the labels are the
    # same for every polygon, so build them once.
    point_labels = np.ones(num_points, dtype=int)

    sam_masks = []
    for polygon in polygons:
        # Rows of `polygon` are used directly as SAM point coordinates.
        # Sampling with replacement works even when the polygon has fewer
        # than `num_points` vertices.
        indexes = np.random.choice(polygon.shape[0], size=num_points, replace=True)
        point_coords = polygon[indexes]
        # predict() returns (masks, scores, logits); with
        # multimask_output=False there is exactly one mask, taken here.
        sam_mask = predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )[0][0]
        sam_masks.append(sam_mask)

    masks = np.array(sam_masks, dtype=bool)
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),
        mask=masks,
    )
|