import gc
import matplotlib.pyplot as plt
import numpy as np
import torch
from models import torch_device
from transformers import SamModel, SamProcessor
import utils
import cv2
from scipy import ndimage

def load_sam():
    sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(torch_device)
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    sam_model_dict = dict(sam_model=sam_model, sam_processor=sam_processor)
    return sam_model_dict

# Not fully backward compatible with the previous implementation
# Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None):
    """target_mask_shape: (h, w)"""
    sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
    with torch.no_grad():
        with torch.autocast(torch_device):
            inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
            outputs = sam_model(**inputs)
        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
        )
        conf_scores = outputs.iou_scores.to(device="cpu", dtype=torch.float32).numpy()[0, 0]
        del inputs, outputs
        gc.collect()
        if torch_device == "cuda":
            torch.cuda.empty_cache()
    masks = masks[0][0].numpy()
    if target_mask_shape is not None:
        # cv2.resize takes (w, h), hence the [::-1]. The interpolation flag must
        # be passed as a keyword; positionally it would land in the `dst` argument.
        masks = np.array([
            cv2.resize(mask.astype(np.uint8) * 255, target_mask_shape[::-1], interpolation=cv2.INTER_LINEAR).astype(bool)
            for mask in masks
        ])
    return masks, conf_scores
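
# Example usage (a sketch; the image and coordinates are placeholders). The
# prompt nesting mirrors the call sites further below: a single point is
# passed as [[[x, y]]] and a single box as [[x_min, y_min, x_max, y_max]].
#   sam_model_dict = load_sam()
#   masks, conf_scores = sam(sam_model_dict, image, input_points=[[[320, 240]]], target_mask_shape=(64, 64))
#   # masks: boolean array with one entry per SAM mask proposal (3 by default);
#   # conf_scores: SAM's predicted IoU for each proposal.
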
def sam_point_input(sam_model_dict, image, input_points, **kwargs):
    return sam(sam_model_dict, image, input_points=input_points, **kwargs)

def sam_box_input(sam_model_dict, image, input_boxes, **kwargs):
    return sam(sam_model_dict, image, input_boxes=input_boxes, **kwargs)

def get_iou_with_resize(mask, masks, masks_shape):
    masks = np.array([
        cv2.resize(mask.astype(np.uint8) * 255, masks_shape[::-1], interpolation=cv2.INTER_LINEAR).astype(bool)
        for mask in masks
    ])
    return utils.iou(mask, masks)

def select_mask(masks, conf_scores, coarse_ious=None, rule="largest_over_conf", discourage_mask_below_confidence=0.85, discourage_mask_below_coarse_iou=0.2, verbose=False):
    """masks: numpy bool array"""
    mask_sizes = masks.sum(axis=(1, 2))
    # Another possible rule: IoU with the attention mask
    if rule == "largest_over_conf":
        # Prefer the largest segmentation, but discourage masks whose confidence
        # or coarse IoU is below the threshold by subtracting the maximum mask
        # size from their score.
        max_mask_size = np.max(mask_sizes)
        scores = mask_sizes - (conf_scores < discourage_mask_below_confidence) * max_mask_size
        if coarse_ious is not None:
            scores = scores - (coarse_ious < discourage_mask_below_coarse_iou) * max_mask_size
        if verbose:
            print(f"mask_sizes: {mask_sizes}, scores: {scores}")
    else:
        raise ValueError(f"Unknown rule: {rule}")
    mask_id = np.argmax(scores)
    mask = masks[mask_id]
    selection_conf = conf_scores[mask_id]
    selection_coarse_iou = coarse_ious[mask_id] if coarse_ious is not None else None
    if verbose:
        # print(f"Confidences: {conf_scores}")
        print(f"Selected a mask with confidence: {selection_conf}, coarse_iou: {selection_coarse_iou}")
    if verbose:
        plt.figure(figsize=(10, 8))
        # plt.suptitle("After SAM")
        for ind in range(3):
            plt.subplot(1, 3, ind + 1)
            # These masks are obtained before resizing. Guard the format spec:
            # applying ":.2f" to None would raise a TypeError.
            iou_str = f"{coarse_ious[ind]:.2f}" if coarse_ious is not None else "N/A"
            plt.title(f"Mask {ind}, score {scores[ind]}, conf {conf_scores[ind]:.2f}, iou {iou_str}")
            plt.imshow(masks[ind])
        plt.tight_layout()
        plt.show()
    return mask, selection_conf
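
# Worked example of the "largest_over_conf" rule (illustrative numbers): with
# mask_sizes = [100, 500, 300], conf_scores = [0.90, 0.70, 0.95], and the 0.85
# confidence threshold, mask 1 is penalized (500 - 500 = 0), so mask 2
# (size 300, conf 0.95) is selected even though mask 1 is larger.
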
def preprocess_mask(token_attn_np_smooth, mask_th, n_erode_dilate_mask=0):
    token_attn_np_smooth_normalized = token_attn_np_smooth - token_attn_np_smooth.min()
    token_attn_np_smooth_normalized /= token_attn_np_smooth_normalized.max()
    mask_thresholded = token_attn_np_smooth_normalized > mask_th
    if n_erode_dilate_mask:
        mask_thresholded = ndimage.binary_erosion(mask_thresholded, iterations=n_erode_dilate_mask)
        mask_thresholded = ndimage.binary_dilation(mask_thresholded, iterations=n_erode_dilate_mask)
    return mask_thresholded
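
# Illustrative example (hypothetical values): with the smoothed attention map
# normalized to [0, 1], mask_th=0.3 keeps pixels above 30% of the peak;
# n_erode_dilate_mask=1 then applies one erosion followed by one dilation (a
# morphological opening), which removes isolated noise pixels while leaving
# the main activated region approximately unchanged.
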
# The overall pipeline to refine the attention mask
def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H, W, use_box_input, gaussian_sigma, mask_th_for_box, n_erode_dilate_mask_for_box, mask_th_for_point, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
    # token_attn_np is for visualizations
    token_attn_np_smooth = ndimage.gaussian_filter(token_attn_np, sigma=gaussian_sigma)
    # Scale from attention-map coordinates to image coordinates: (w, h)
    mask_size_scale = height // token_attn_np_smooth.shape[1], width // token_attn_np_smooth.shape[0]
    if use_box_input:
        # Box input
        mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_box, n_erode_dilate_mask=n_erode_dilate_mask_for_box)
        input_boxes = utils.binary_mask_to_box(mask_binary, w_scale=mask_size_scale[0], h_scale=mask_size_scale[1])
        input_boxes = [input_boxes]
        masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
    else:
        # Point input: uses the coordinate of the attention maximum only
        mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_point, n_erode_dilate_mask=0)
        max_coord = np.unravel_index(token_attn_np_smooth.argmax(), token_attn_np_smooth.shape)
        # print("max_coord:", max_coord)
        input_points = [[[max_coord[1] * mask_size_scale[1], max_coord[0] * mask_size_scale[0]]]]
        masks, conf_scores = sam_point_input(model_dict, image=sam_input_image, input_points=input_points, target_mask_shape=(H, W))
    if verbose:
        plt.title("Coarse binary mask (used as the box prompt for box input, and for the coarse IoU)")
        plt.imshow(mask_binary)
        plt.show()
    coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
    mask_selected, conf_score_selected = select_mask(
        masks, conf_scores, coarse_ious=coarse_ious,
        rule="largest_over_conf",
        discourage_mask_below_confidence=discourage_mask_below_confidence,
        discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
        verbose=verbose,
    )
    return mask_selected, conf_score_selected
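
# Example invocation (a sketch; the 512x512 image size, 64x64 latent mask
# shape, and hyperparameter values are placeholder assumptions, not values
# fixed by this module):
#   mask, conf = sam_refine_attn(
#       sam_input_image=image, token_attn_np=attn_map_64x64, model_dict=sam_model_dict,
#       height=512, width=512, H=64, W=64, use_box_input=False,
#       gaussian_sigma=0.5, mask_th_for_box=0.05, n_erode_dilate_mask_for_box=1,
#       mask_th_for_point=0.25, discourage_mask_below_confidence=0.85,
#       discourage_mask_below_coarse_iou=0.25, verbose=False,
#   )
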
def sam_refine_box(sam_input_image, box, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
    # Convert the proportional box to pixel coordinates: (w, h)
    input_boxes = utils.scale_proportion(box, H=height, W=width)
    input_boxes = [input_boxes]
    masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
    mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
    if verbose:
        # The box is also the input prompt for SAM
        plt.title("Binary mask from input box (for iou)")
        plt.imshow(mask_binary)
        plt.show()
    coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
    mask_selected, conf_score_selected = select_mask(
        masks, conf_scores, coarse_ious=coarse_ious,
        rule="largest_over_conf",
        discourage_mask_below_confidence=discourage_mask_below_confidence,
        discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
        verbose=verbose,
    )
    return mask_selected, conf_score_selected
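
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original pipeline). The
    # 512x512 image, 64x64 target mask shape, proportional box format
    # (x_min, y_min, x_max, y_max in [0, 1]), and thresholds are assumptions.
    from PIL import Image
    model_dict = load_sam()
    dummy_image = Image.new("RGB", (512, 512), color="white")
    mask, conf = sam_refine_box(
        sam_input_image=dummy_image,
        box=(0.25, 0.25, 0.75, 0.75),
        model_dict=model_dict,
        height=512, width=512, H=64, W=64,
        discourage_mask_below_confidence=0.85,
        discourage_mask_below_coarse_iou=0.2,
        verbose=False,
    )
    print(f"Selected mask shape: {mask.shape}, confidence: {conf:.3f}")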