Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import torch.cuda.amp as amp | |
from typing import Optional, Tuple | |
from .util.transforms import ResizeLongestSide | |
class RegionSpot_Predictor:
    """
    Prompt-driven predictor for a RegionSpot model.

    The SAM sub-model (``regionspot.sam``) turns point / box / mask prompts into
    masks and mask tokens; RegionSpot's CLIP branch then scores those mask
    tokens. Call ``set_image`` once per image, then ``predict`` as often as
    needed.
    """

    def __init__(
        self,
        regionspot,
    ) -> None:
        """
        Arguments:
          regionspot: A RegionSpot model. Its ``sam`` attribute is used for
            mask prediction; its ``clip_model`` and ``forward_inference`` are
            used to classify the predicted masks.
        """
        super().__init__()
        self.regionspot = regionspot
        self.model = self.regionspot.sam
        # Brings images/prompts to the SAM encoder's input resolution
        # (longest side == image_encoder.img_size, see set_torch_image).
        self.transform = ResizeLongestSide(self.model.image_encoder.img_size)
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
        clip_input_size: int = 224,
    ) -> None:
        """
        Calculates the SAM and CLIP image embeddings for the provided image,
        allowing masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
            The channel order is reversed when it differs from the model's
            ``image_format``.
          clip_input_size (int): Side length of the square, normalised copy of
            the image fed to the CLIP branch.
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            # Produces a negative-stride view; resize_norm makes it contiguous.
            image = image[..., ::-1]

        # Square, CLIP-normalised copy of the image for the classification branch.
        self.resized_image = self.resize_norm(
            image, target_size=(clip_input_size, clip_input_size)
        )

        # Transform the image to the form expected by SAM: 1x3xHxW, longest side resized.
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        self.set_torch_image(input_image_torch, image.shape[:2])

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        ``self.resized_image`` (the CLIP input prepared by ``set_image``) must
        already be set, because the CLIP feature map is computed here as well.
        Runs without autograd so no computation graph is kept on the cached
        embeddings.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(input_image)
        self.clip_features = self.regionspot.clip_model.encode_image_featuremap(
            self.resized_image.to(self.device)
        ).detach()
        self.is_image_set = True

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        mask_threshold: float = 0.0,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks and their classes for the given input prompts, using the
        currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): Box prompt(s) in XYXY format. No batch
            dimension is added here, so the transformed boxes are passed to the
            prompt encoder as-is (presumably Nx4, one prompt per row — TODO
            confirm against ResizeLongestSide.apply_boxes).
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          mask_threshold (float): Logit threshold used to binarise the masks
            when return_logits is False.

        Returns:
          (np.ndarray): The output masks in BxCxHxW format, where B is the
            number of prompts, C is the number of masks per prompt, and (H, W)
            is the original image size.
          (np.ndarray): A BxC array containing the model's predictions for the
            quality of each mask.
          (np.ndarray): The highest sigmoid class score per mask token.
          (np.ndarray): The index of the highest-scoring class per mask token.

        Raises:
          RuntimeError: If no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        # Transform input prompts into the model's input frame.
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, _low_res_masks, max_values, max_index = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
            mask_threshold=mask_threshold,
        )
        return (
            masks.detach().cpu().numpy(),
            iou_predictions.detach().cpu().numpy(),
            max_values.detach().cpu().numpy(),
            max_index.detach().cpu().numpy(),
        )

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        mask_threshold: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks and their classes for the given input prompts, using the
        currently set image. Input prompts are batched torch tensors and are
        expected to already be transformed to the input frame using
        ResizeLongestSide. Runs without autograd.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          mask_threshold (float): Logit threshold used to binarise the masks
            when return_logits is False.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
          (torch.Tensor): The highest sigmoid class score per mask token.
          (torch.Tensor): The index of the highest-scoring class per mask token.

        Raises:
          RuntimeError: If no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        points = (point_coords, point_labels) if point_coords is not None else None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks; this mask decoder also returns the mask tokens.
        low_res_masks, iou_predictions, mask_token = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # One token per row; the token width is hard-coded to 256
        # (NOTE(review): assumed to equal the mask decoder's embedding dim).
        mask_token = mask_token.reshape(-1, 1, 256)

        # Classify every mask token with the CLIP branch.
        with amp.autocast(enabled=True):
            logits_per_image = self.regionspot.forward_inference(
                self.clip_features,
                mask_token.to(self.device),
                self.resized_image.to(self.device),
            )

        # Independent per-class sigmoid scores; keep the best class per token.
        probs_per_image = logits_per_image[0].sigmoid()
        max_values, max_index = torch.max(probs_per_image, dim=-1)

        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
        if not return_logits:
            masks = masks > mask_threshold

        return masks, iou_predictions, low_res_masks, max_values, max_index

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).

        Raises:
          RuntimeError: If no image has been set.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    def resize_norm(self, image, target_size=(224, 224)):
        """
        Resize an HWC uint8 image and normalise it for the CLIP branch.

        Arguments:
          image (np.ndarray): HWC image with pixel values in [0, 255]. Views
            with negative strides (e.g. from ``image[..., ::-1]``) are accepted.
          target_size (tuple(int, int)): Output (H, W).

        Returns:
          (torch.Tensor): A float tensor of shape 1x3xHxW, scaled to [0, 1],
            bilinearly resized, then normalised per channel with a fixed
            mean/std.
        """
        # torch.from_numpy rejects negative strides, so make the array contiguous first.
        image = np.ascontiguousarray(image)
        # Convert the numpy image to a torch tensor in CxHxW format, scaled to [0, 1].
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

        resized_image = F.interpolate(
            image.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
        ).squeeze(0)

        # Per-channel normalisation constants, broadcast over H and W.
        normalize_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(1).unsqueeze(2)
        normalize_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(1).unsqueeze(2)
        resized_image = (resized_image - normalize_mean) / normalize_std
        return resized_image.unsqueeze(0)

    @property
    def device(self) -> torch.device:
        """Device of the underlying SAM model; used for all input tensors."""
        return self.model.device

    def reset_image(self) -> None:
        """
        Resets the currently set image and its cached embeddings.

        ``resized_image`` is intentionally not cleared: ``set_image`` assigns
        it before calling ``set_torch_image``, which calls this method.
        """
        self.is_image_set = False
        self.features = None
        self.clip_features = None
        self.original_size = None
        self.input_size = None
        # Not read anywhere in this class; kept so existing attribute access still works.
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None