File size: 4,958 Bytes
81e7bf6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import sys
import modules.config
import numpy as np
import torch
from extras.GroundingDINO.util.inference import default_groundingdino
from extras.sam.predictor import SamPredictor
from rembg import remove, new_session
from segment_anything import sam_model_registry
from segment_anything.utils.amg import remove_small_regions
class SAMOptions:
    """Bundle of settings for GroundingDINO detection and SAM segmentation.

    The constructor stores every argument unchanged on the instance under
    the same name; no validation or conversion is performed here.
    """

    def __init__(
            self,
            # GroundingDINO
            dino_prompt: str = '',
            dino_box_threshold=0.3,
            dino_text_threshold=0.25,
            dino_erode_or_dilate=0,
            dino_debug=False,
            # SAM
            max_detections=2,
            model_type='vit_b'
    ):
        # --- GroundingDINO settings ---
        self.dino_prompt = dino_prompt
        self.dino_box_threshold = dino_box_threshold
        self.dino_text_threshold = dino_text_threshold
        self.dino_erode_or_dilate = dino_erode_or_dilate
        self.dino_debug = dino_debug

        # --- SAM settings ---
        self.max_detections = max_detections
        self.model_type = model_type
def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
    """
    Clean up a batch of masks with ``remove_small_regions`` in "holes" mode.

    Each mask is passed through ``remove_small_regions(mask, 400, mode="holes")``.
    Only the "holes" mode is applied here; no "islands" pass is run.

    :param masks: mask batch, indexed as [num_masks, 1, h, w]
    :return: CPU tensor with the same [num_masks, 1, h, w] layout holding the
        cleaned masks. An empty batch is returned unchanged (moved to CPU).
    """
    masks_cpu = masks.to('cpu')

    # np.stack raises ValueError on an empty sequence, so an empty batch
    # (no detections) is handed back as-is instead of crashing.
    if masks_cpu.shape[0] == 0:
        return masks_cpu

    # remove_small_regions returns (mask, changed_flag); only the mask is kept.
    fine_masks = [
        remove_small_regions(mask[0], 400, mode="holes")[0]
        for mask in masks_cpu.numpy()
    ]

    # Re-insert the singleton channel axis dropped by mask[0].
    stacked = np.stack(fine_masks, axis=0)[:, np.newaxis]
    return torch.from_numpy(stacked)
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
                             sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
    """
    Build a mask for ``image`` with either rembg or GroundingDINO + SAM.

    When ``mask_model`` is not ``'sam'`` (or ``sam_options`` is None) the image
    is passed to rembg's ``remove`` with ``only_mask=True``. Otherwise boxes
    are detected with GroundingDINO from ``sam_options.dino_prompt``, segmented
    with SAM, and the first ``max_detections`` masks are merged into one
    3-channel uint8 image (255 inside the mask, 0 outside).

    :param image: image array, or a dict holding the array under ``'image'``
    :param mask_model: ``'sam'`` or a model name forwarded to rembg's ``new_session``
    :param extras: extra keyword arguments forwarded to both ``new_session``
        and ``remove`` on the rembg path; defaults to an empty dict
    :param sam_options: options for the SAM path. The ``SAMOptions`` class
        itself (the declared default) is accepted and instantiated with its
        default values. ``max_detections == 0`` means "no limit".
    :return: ``(mask, dino_detection_count, sam_detection_count,
        sam_detection_on_mask_count)``; ``mask`` is None when ``image`` is None.
        With ``dino_debug`` set and at least one detection, ``mask`` is an
        image of the detected boxes drawn white on black instead.
    """
    dino_detection_count = 0
    sam_detection_count = 0
    sam_detection_on_mask_count = 0

    if image is None:
        return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    if extras is None:
        extras = {}

    # Only probe for the key on dicts: `'image' in ndarray` would run an
    # element-wise comparison of the whole array against a string.
    if isinstance(image, dict) and 'image' in image:
        image = image['image']

    if mask_model != 'sam' or sam_options is None:
        result = remove(
            image,
            session=new_session(mask_model, **extras),
            only_mask=True,
            **extras
        )

        return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    # The signature's default is the SAMOptions class rather than an instance;
    # the class has no option attributes until __init__ runs, so build one.
    if isinstance(sam_options, type):
        sam_options = sam_options()

    _, boxes, logits, _ = default_groundingdino(
        image=image,
        caption=sam_options.dino_prompt,
        box_threshold=sam_options.dino_box_threshold,
        text_threshold=sam_options.dino_text_threshold
    )

    height, width = image.shape[0], image.shape[1]

    # Scale boxes to pixels, then turn (cx, cy, w, h) into (x0, y0, x1, y1).
    # NOTE(review): assumes GroundingDINO yields normalized centre-format boxes - confirm.
    boxes = boxes * torch.Tensor([width, height, width, height])
    boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
    boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

    sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
    sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)

    sam_predictor = SamPredictor(sam)
    final_mask_tensor = torch.zeros((height, width))
    dino_detection_count = boxes.size(0)

    if dino_detection_count > 0:
        sam_predictor.set_image(image)

        padding = sam_options.dino_erode_or_dilate
        if padding != 0:
            # Grow (positive) or shrink (negative) every box on all four sides.
            boxes[:, :2] -= padding
            boxes[:, 2:] += padding

        if sam_options.dino_debug:
            from PIL import ImageDraw, Image
            debug_dino_image = Image.new("RGB", (width, height), color="black")
            draw = ImageDraw.Draw(debug_dino_image)
            for box in boxes.numpy():
                draw.rectangle(box.tolist(), fill="white")
            return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count

        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        masks = optimize_masks(masks)
        sam_detection_count = len(masks)

        # Resolve "0 means unlimited" into a local so the caller's options
        # object is not modified as a side effect.
        max_detections = sam_options.max_detections
        if max_detections == 0:
            max_detections = sys.maxsize

        # Also clamp to len(masks) so indexing below can never run past the end.
        sam_objects = min(len(logits), len(masks), max_detections)
        for obj_ind in range(sam_objects):
            final_mask_tensor += masks[obj_ind][0]
            sam_detection_on_mask_count += 1

    # Any pixel covered by at least one mask becomes white in all 3 channels.
    final_mask = (final_mask_tensor > 0).to('cpu').numpy()
    mask_image = np.dstack((final_mask, final_mask, final_mask)) * 255
    mask_image = np.array(mask_image, dtype=np.uint8)
    return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|