# ttengwang
# support "segment everything in a paragraph"
# ccb14a3
import time
import torch
import cv2
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image
import matplotlib.pyplot as plt
import PIL
class BaseSegmenter:
    """Thin wrapper around Segment Anything (SAM).

    Holds a SAM model, a ``SamPredictor`` for prompt-driven segmentation
    (clicks / boxes) and a ``SamAutomaticMaskGenerator`` for "segment
    everything" mode. When ``reuse_feature`` is on, the image embedding
    computed by :meth:`set_image` is cached and reused by :meth:`inference`.
    """

    def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None, args=None):
        """Build the predictor and the automatic mask generator.

        Args:
            device: Device the model is moved to (e.g. ``'cuda'`` or ``'cpu'``).
            checkpoint: Path of the SAM checkpoint. If ``None`` (and no
                ``model`` is given) it is resolved via
                ``prepare_segmenter(model_name)``.
            model_name: Key into ``seg_model_map`` selecting the SAM variant.
            reuse_feature: Cache the image embedding in :meth:`set_image`
                and reuse it for subsequent :meth:`inference` calls.
            model: An already constructed SAM model; when given, no
                checkpoint is loaded and the model is used as-is.
            args: Optional namespace (or dict) of options. Only the keys
                ``pred_iou_thresh``, ``min_mask_region_area``,
                ``stability_score_thresh`` and ``box_nms_thresh`` are read;
                they are forwarded to ``SamAutomaticMaskGenerator``.
        """
        print(f"Initializing BaseSegmenter to {device}")
        self.device = device
        # str() so that a torch.device works as well as a plain string.
        self.torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32
        self.processor = None
        if model is None:
            if checkpoint is None:
                _, checkpoint = prepare_segmenter(model_name)
            self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
            self.model.to(device=self.device)
        else:
            self.model = model
        self.checkpoint = checkpoint
        self.reuse_feature = reuse_feature
        self.predictor = SamPredictor(self.model)

        sam_generator_keys = ('pred_iou_thresh', 'min_mask_region_area', 'stability_score_thresh', 'box_nms_thresh')
        # ``args`` defaults to None; vars(None) would raise TypeError, so fall
        # back to the generator's own defaults in that case.
        if args is None:
            arg_items = {}
        elif isinstance(args, dict):
            arg_items = args
        else:
            arg_items = vars(args)
        generator_args = {k: v for k, v in arg_items.items() if k in sam_generator_keys}
        self.mask_generator = SamAutomaticMaskGenerator(model=self.model, **generator_args)

        self.image_embedding = None
        self.image = None

    @torch.no_grad()
    def set_image(self, image: Union[np.ndarray, Image.Image, str]):
        """Load ``image`` and, if ``reuse_feature`` is on, cache its embedding.

        Args:
            image: str or PIL.Image or np.ndarray; converted with
                ``load_image(image, return_type='numpy')``.
        """
        image = load_image(image, return_type='numpy')
        self.image = image
        if self.reuse_feature:
            self.predictor.set_image(image)
            self.image_embedding = self.predictor.get_image_embedding()
            print(self.image_embedding.shape)

    @staticmethod
    def _wants_multimask(control):
        """Return True when ``control`` asks for multi-mask output."""
        # Legacy misspelt key: historically its mere presence enabled the mode.
        if 'mutimask_output' in control:
            return True
        value = control.get('multimask_output', False)
        if isinstance(value, str):
            return value.strip().lower() in ('true', '1', 'yes')
        return bool(value)

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
        """
        SAM inference of image according to control.
        Args:
            image: str or PIL.Image or np.ndarray
            control: dict to control SAM.
                prompt_type:
                    1. {control['prompt_type'] = ['everything']} to segment everything in the image.
                    2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
                    3. {control['prompt_type'] = ['click']} to segment according to click.
                    4. {control['prompt_type'] = ['box']} to segment according to box.
                input_point: list of [x, y] coordinates of click.
                input_label: List of labels for points accordingly, 0 for negative, 1 for positive.
                input_box: List of [x1, y1, x2, y2] coordinates of box.
                input_boxes: List of boxes, each [x1, y1, x2, y2]; one mask is predicted per box.
                multimask_output:
                    If true, the model will return three masks.
                    For ambiguous input prompts (such as a single click), this will often
                    produce better masks than a single prediction. If only a single
                    mask is needed, the model's predicted quality score can be used
                    to select the best mask. For non-ambiguous prompts, such as multiple
                    input prompts, multimask_output=False can give better results.
                    The misspelt legacy key 'mutimask_output' is still accepted.
        Returns:
            masks: np.ndarray of shape [num_masks, height, width].
            In 'everything' mode a tuple (masks, bbox, area) is returned instead,
            where bbox and area are taken from the mask generator's records.
        """
        image = load_image(image, return_type='numpy')
        prompt_type = control['prompt_type']

        if 'everything' in prompt_type:
            records = self.mask_generator.generate(image)
            if not records:
                # np.stack/concatenate cannot handle an empty list.
                height, width = image.shape[:2]
                return np.zeros((0, height, width), dtype=bool), np.zeros((0, 4)), np.zeros((0,))
            new_masks = np.stack([record["segmentation"] for record in records])
            bbox = np.array([record["bbox"] for record in records])
            area = np.array([record["area"] for record in records])
            return new_masks, bbox, area

        if not self.reuse_feature or self.image_embedding is None:
            self.set_image(image)
            if not self.reuse_feature:
                # set_image() only runs the encoder when reuse_feature is on;
                # calling the predictor unconditionally would encode twice.
                self.predictor.set_image(self.image)
        else:
            self.predictor.features = self.image_embedding

        if self._wants_multimask(control):
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(control['input_point']),
                point_labels=np.array(control['input_label']),
                multimask_output=True,
            )
        elif 'input_boxes' in control:
            transformed_boxes = self.predictor.transform.apply_boxes_torch(
                torch.tensor(control["input_boxes"], device=self.predictor.device),
                image.shape[:2],  # original size as (H, W), which is what SAM's transform expects
            )
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            masks = masks.squeeze(1).cpu().numpy()
        else:
            use_click = 'click' in prompt_type
            input_point = np.array(control['input_point']) if use_click else None
            input_label = np.array(control['input_label']) if use_click else None
            input_box = np.array(control['input_box']) if 'box' in prompt_type else None
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                box=input_box,
                multimask_output=False,
            )
            # With negative clicks, run a second pass seeded with the best
            # low-res logits from the first one. ``.get`` keeps box-only
            # prompts (no 'input_label' key) from raising KeyError.
            if 0 in control.get('input_label', ()):
                mask_input = logits[np.argmax(scores), :, :]
                masks, scores, logits = self.predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    box=input_box,
                    mask_input=mask_input[None, :, :],
                    multimask_output=False,
                )
        return masks
if __name__ == "__main__":
    # Manual smoke test: load SAM, run each prompt on one image and save the
    # first resulting mask to 'seg.png'.
    from types import SimpleNamespace  # only needed by this demo

    image_path = 'test_images/img2.jpg'
    prompts = [
        # {
        #     "prompt_type":["click"],
        #     "input_point":[[500, 375]],
        #     "input_label":[1],
        #     "multimask_output":"True",
        # },
        {
            "prompt_type": ["click"],
            "input_point": [[1000, 600], [1325, 625]],
            "input_label": [1, 0],
        },
        # {
        #     "prompt_type":["click", "box"],
        #     "input_box":[425, 600, 700, 875],
        #     "input_point":[[575, 750]],
        #     "input_label": [0]
        # },
        # {
        #     "prompt_type":["box"],
        #     "input_boxes": [
        #         [75, 275, 1725, 850],
        #         [425, 600, 700, 875],
        #         [1375, 550, 1650, 800],
        #         [1240, 675, 1400, 750],
        #     ]
        # },
        # {
        #     "prompt_type":["everything"]
        # },
    ]

    init_time = time.time()
    segmenter = BaseSegmenter(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        # checkpoint='sam_vit_h_4b8939.pth',
        checkpoint='segmenter/sam_vit_h_4b8939.pth',
        # model_name is left at its default ('huge'); __init__ has no
        # `model_type` parameter, so passing it raised TypeError.
        reuse_feature=True,
        # Empty namespace: no mask-generator overrides, and keeps vars(args)
        # in __init__ from failing on None.
        args=SimpleNamespace(),
    )
    print(f'init time: {time.time() - init_time}')

    infer_time = time.time()
    for prompt in prompts:
        print(f'{prompt["prompt_type"]} mode')
        # Context manager closes the file handle once the pixels are copied.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)
        segmenter.set_image(image)
        masks = segmenter.inference(image, prompt)
        Image.fromarray(masks[0]).save('seg.png')
        print(masks.shape)
    print(f'infer time: {time.time() - infer_time}')