SoM / sam_utils.py
SkalskiP's picture
Updated the 'sam_utils.py' and 'app.py' modules to implement automated mask generation, result highlighting and mark generation functionalities.
7e50af9
raw history blame
No virus
1.13 kB
from typing import Optional

import numpy as np
import supervision as sv
from segment_anything.modeling.sam import Sam
from segment_anything import SamPredictor, SamAutomaticMaskGenerator
def sam_inference(
    image: np.ndarray,
    model: Sam
) -> sv.Detections:
    """Run SAM's automatic mask generation over a whole image.

    A fresh ``SamAutomaticMaskGenerator`` (default settings) is built around
    ``model`` on every call, and its raw output is converted to supervision's
    detection container.

    Args:
        image: Image handed to ``SamAutomaticMaskGenerator.generate``
            (presumably an HxWx3 uint8 RGB array — confirm with callers).
        model: Loaded SAM model wrapped by the generator.

    Returns:
        ``sv.Detections`` produced by ``sv.Detections.from_sam`` from the
        generator's list of mask records.
    """
    return sv.Detections.from_sam(
        SamAutomaticMaskGenerator(model).generate(image=image)
    )
def sam_interactive_inference(
    image: np.ndarray,
    mask: np.ndarray,
    model: Sam,
    *,
    points_per_polygon: int = 5,
    rng: Optional[np.random.Generator] = None,
) -> sv.Detections:
    """Segment the regions outlined by ``mask`` using SAM point prompts.

    Every polygon that ``sv.mask_to_polygons`` extracts from ``mask`` becomes
    one SAM query: ``points_per_polygon`` foreground points are sampled (with
    replacement) from the polygon's vertices, and SAM returns a single mask.

    Args:
        image: Image passed to ``SamPredictor.set_image`` (presumably an
            HxWx3 uint8 RGB array, SAM's default format — confirm with
            callers). ``image.shape[:2]`` is taken as (height, width).
        mask: User-drawn mask; every non-zero pixel is treated as foreground.
        model: Loaded SAM model used to build the predictor.
        points_per_polygon: Number of point prompts sampled per polygon.
            Must be at least 1.
        rng: Optional NumPy ``Generator`` for reproducible point sampling.
            When ``None`` the legacy global ``np.random`` state is used,
            matching the historical behavior.

    Returns:
        ``sv.Detections`` with one boolean mask and its bounding box per
        polygon. When ``mask`` contains no polygons, an empty detections
        object is returned (``xyxy`` shaped ``(0, 4)``, ``mask`` shaped
        ``(0, H, W)``) without running the model.

    Raises:
        ValueError: If ``points_per_polygon`` is less than 1.
    """
    if points_per_polygon < 1:
        raise ValueError("points_per_polygon must be at least 1")

    polygons = sv.mask_to_polygons(mask.astype(bool))
    if not polygons:
        # Nothing to prompt with: skip the costly image embedding. Stacking an
        # empty list would also produce (0,)-shaped arrays that fail
        # sv.Detections' shape validation, so build correctly shaped empties.
        height, width = image.shape[:2]
        return sv.Detections(
            xyxy=np.empty((0, 4), dtype=np.float32),
            mask=np.empty((0, height, width), dtype=bool),
        )

    # Keep the legacy global RNG as the default so existing callers that seed
    # np.random still get the same samples.
    choose = np.random.choice if rng is None else rng.choice

    predictor = SamPredictor(model)
    predictor.set_image(image)

    # Label 1 marks every sampled point as a foreground prompt.
    point_labels = np.ones(points_per_polygon, dtype=int)
    predicted_masks = []
    for polygon in polygons:
        # Sampling with replacement keeps the prompt size fixed even when
        # points_per_polygon exceeds the polygon's vertex count.
        indexes = choose(polygon.shape[0], size=points_per_polygon, replace=True)
        masks, _, _ = predictor.predict(
            point_coords=polygon[indexes],
            point_labels=point_labels,
            multimask_output=False,
        )
        # multimask_output=False -> exactly one candidate mask at index 0.
        predicted_masks.append(masks[0])

    stacked = np.array(predicted_masks, dtype=bool)
    return sv.Detections(
        xyxy=sv.mask_to_xyxy(stacked),
        mask=stacked,
    )