import torch
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

from .automatic_mask_generator_prob import SamAutomaticMaskAndProbabilityGenerator


class Segmentor:
    """Wrap a SAM model for automatic and prompt-driven segmentation.

    The instance owns one SAM model, an automatic mask generator built on it,
    and that generator's predictor, which is reused for interactive prompts
    (points, boxes, mask prompts).
    """

    def __init__(self, sam_args):
        """Build the SAM model and the generators that share it.

        sam_args:
            model_type: key used to look up the builder in sam_model_registry
            sam_checkpoint: path of SAM checkpoint
            generator_args: keyword args forwarded to the automatic generator

        The device is not read from sam_args: 'cuda' is used whenever
        torch.cuda.is_available() is true, otherwise 'cpu'.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sam = sam_model_registry[sam_args["model_type"]](
            checkpoint=sam_args["sam_checkpoint"]
        )
        self.sam.to(device=self.device)
        self.automatic_generator = SamAutomaticMaskAndProbabilityGenerator(
            model=self.sam, **sam_args['generator_args']
        )
        # The interactive predictor is the generator's own predictor, so both
        # paths share the same underlying model.
        self.interactive_predictor = self.automatic_generator.predictor
        # True once set_image() has embedded an image. Nothing in this class
        # resets it to False afterwards.
        self.have_embedded = False

    @torch.no_grad()
    def set_image(self, image):
        """Embed ``image`` in the predictor unless an embedding already exists.

        The call is a no-op while ``have_embedded`` is True, so the embedding
        is calculated only once per frame.
        """
        if not self.have_embedded:
            self.interactive_predictor.set_image(image)
            self.have_embedded = True

    @torch.no_grad()
    def interactive_predict(self, prompts, mode, multimask=True):
        """Run the predictor with the prompts selected by ``mode``.

        prompts: dict; the keys read depend on ``mode``
            'point'      -> 'point_coords', 'point_modes'
            'mask'       -> 'mask_prompt'
            'point_mask' -> 'point_coords', 'point_modes', 'mask_prompt'
        mode: one of 'point', 'mask', 'point_mask'
        multimask: forwarded to the predictor as ``multimask_output``

        return: (masks, scores, logits) as produced by the predictor

        raises:
            AssertionError: no image has been embedded via set_image()
            ValueError: ``mode`` is not one of the supported values
            KeyError: ``prompts`` lacks a key required by ``mode``
        """
        assert self.have_embedded, 'image embedding for sam need be set before predict.'

        if mode == 'point':
            predict_kwargs = {
                'point_coords': prompts['point_coords'],
                'point_labels': prompts['point_modes'],
            }
        elif mode == 'mask':
            predict_kwargs = {'mask_input': prompts['mask_prompt']}
        elif mode == 'point_mask':
            predict_kwargs = {
                'point_coords': prompts['point_coords'],
                'point_labels': prompts['point_modes'],
                'mask_input': prompts['mask_prompt'],
            }
        else:
            # Without this branch an unknown mode fell through to the return
            # statement and surfaced as an opaque UnboundLocalError.
            raise ValueError(
                f"unsupported mode {mode!r}; expected 'point', 'mask' or 'point_mask'"
            )

        masks, scores, logits = self.interactive_predictor.predict(
            multimask_output=multimask, **predict_kwargs
        )
        return masks, scores, logits

    @torch.no_grad()
    def automatic_segment(self, image):
        """Return the automatic generator's output for ``image``."""
        return self.automatic_generator.generate(image)

    @torch.no_grad()
    def segment_with_click(self, origin_frame, coords, modes, multimask=True):
        """Segment from point prompts using two prediction passes.

        The first pass uses the points alone; the logits of its best-scoring
        candidate are then fed back as a mask prompt, together with the same
        points, for the second pass.

        origin_frame: image passed to set_image()
        coords: point coordinates, forwarded as ``point_coords``
        modes: point labels, forwarded as ``point_labels``
        multimask: forwarded to both passes as ``multimask_output``

        return: best-scoring mask of the second pass, cast to np.uint8
        """
        self.set_image(origin_frame)

        prompts = {
            'point_coords': coords,
            'point_modes': modes,
        }
        _, scores, logits = self.interactive_predict(prompts, 'point', multimask)
        best_logit = logits[np.argmax(scores), :, :]

        # Re-prompt with the best logits (leading axis added) as mask input.
        prompts = {
            'point_coords': coords,
            'point_modes': modes,
            'mask_prompt': best_logit[None, :, :],
        }
        masks, scores, _ = self.interactive_predict(prompts, 'point_mask', multimask)
        return masks[np.argmax(scores)].astype(np.uint8)

    @torch.no_grad()
    def segment_with_box(self, origin_frame, bbox, reset_image=False):
        """Segment from a box prompt using two prediction passes.

        origin_frame: image to embed
        bbox: indexable as bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1];
            presumably ((x0, y0), (x1, y1)) -- confirm against callers
        reset_image: when True the predictor embeds ``origin_frame``
            unconditionally; otherwise the cached set_image() path is used

        return: single-element list holding the best-scoring mask of the
            second pass (not cast, unlike segment_with_click)

        NOTE: with reset_image=True the ``have_embedded`` flag is neither
        consulted nor updated.
        """
        if reset_image:
            self.interactive_predictor.set_image(origin_frame)
        else:
            self.set_image(origin_frame)

        box = np.array([bbox[0][0], bbox[0][1], bbox[1][0], bbox[1][1]])

        # First pass: box alone.
        _, scores, logits = self.interactive_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box,
            multimask_output=True,
        )
        best_logit = logits[np.argmax(scores), :, :]

        # Second pass: same box (with a leading axis, as before) plus the best
        # logits from the first pass as mask input.
        masks, scores, _ = self.interactive_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box[None, :],
            mask_input=best_logit[None, :, :],
            multimask_output=True,
        )
        return [masks[np.argmax(scores)]]