import cv2
import numpy as np
import torch
from PIL import Image

# Grounding DINO, slightly modified from the original repo
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from SAM.segment_anything import build_sam, SamPredictor

# ImageNet-pretrained feature extractor
from .modelinet import ModelINet
class Model(torch.nn.Module):
    def __init__(self,
                 ## DINO
                 dino_config_file,
                 dino_checkpoint,
                 ## SAM
                 sam_checkpoint,
                 ## Parameters
                 box_threshold,
                 text_threshold,
                 ## Others
                 out_size=256,
                 device='cuda',
                 ):
        '''
        Args:
            dino_config_file: the config file for Grounding DINO
            dino_checkpoint: the path of the checkpoint for Grounding DINO
            sam_checkpoint: the path of the checkpoint for SAM
            box_threshold: the score threshold for filtering bounding boxes
            text_threshold: the score threshold for filtering matched phrases (tokens)
            out_size: the desired output resolution of the anomaly map
            device: the running device, e.g. 'cuda:0'
        NOTE:
            1. In the published paper, the property prompt P^P is applied to the region R.
               In this repo, we actually apply P^P to the bounding-box-level region R^B.
            2. The IoU constraint has not been added in this repo.
            3. This module only accepts a batch size of 1.
        '''
        super(Model, self).__init__()

        # Build models
        self.anomaly_region_generator = self.load_dino(dino_config_file, dino_checkpoint, device=device)
        self.anomaly_region_refiner = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
        self.transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.visual_saliency_extractor = ModelINet(device=device)
        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

        # Parameters
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

        # Others
        self.out_size = out_size
        self.device = device
        self.is_sam_set = False
    def load_dino(self, model_config_path, model_checkpoint_path, device) -> torch.nn.Module:
        '''
        Args:
            model_config_path: the config file for Grounding DINO
            model_checkpoint_path: the path of the checkpoint for Grounding DINO
            device: the device to load the model on
        Returns:
            the Grounding DINO model, in eval mode and moved to `device`
        '''
        args = SLConfig.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        _ = model.eval()
        model = model.to(device)
        return model

    def get_grounding_output(self, image, caption, device="cpu") -> (torch.Tensor, torch.Tensor, str):
        caption = caption.lower()
        caption = caption.strip()
        if not caption.endswith("."):
            caption = caption + "."
        image = image.to(device)
        with torch.no_grad():
            outputs = self.anomaly_region_generator(image[None], captions=[caption])
        logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"][0]  # (nq, 4)
        return boxes, logits, caption
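    # Note: `boxes` are normalized (cx, cy, w, h) proposals and `logits` are per-query,
    # per-token similarity scores after the sigmoid; both are thresholded later in
    # `bbox_suppression` and the boxes are denormalized in `ensemble_text_guided_mask_proposal`.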
    def set_ensemble_text_prompts(self, text_prompt_list: list, verbose=False) -> None:
        self.defect_prompt_list = [f[0] for f in text_prompt_list]
        self.filter_prompt_list = [f[1] for f in text_prompt_list]
        if verbose:
            print('used ensemble text prompts ===========')
            for d, t in zip(self.defect_prompt_list, self.filter_prompt_list):
                print(f'det prompts: {d}')
                print(f'filtered background: {t}')
            print('======================================')
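    # Each entry of `text_prompt_list` is a (defect prompt, background phrase) pair:
    # index 0 is a phrase sent to Grounding DINO to propose defect regions, and index 1
    # is a background/category phrase whose matches are discarded in `bbox_suppression`.
    # A hypothetical example (placeholder phrases, not the prompts shipped with this repo):
    #   [['defect. anomaly.', 'candle'], ['crack. hole.', 'candle']]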
    def set_property_text_prompts(self, property_prompts, verbose=False) -> None:
        self.object_prompt = property_prompts.split(' ')[7]
        self.object_number = int(property_prompts.split(' ')[5])
        self.k_mask = int(property_prompts.split(' ')[12])
        self.defect_area_threshold = float(property_prompts.split(' ')[19])
        self.object_max_area = 1. / self.object_number
        self.object_min_area = 0.
        self.similar = property_prompts.split(' ')[6]
        if verbose:
            print(f'{self.object_prompt}, '
                  f'{self.object_number}, '
                  f'{self.k_mask}, '
                  f'{self.defect_area_threshold}, '
                  f'{self.object_max_area}, '
                  f'{self.object_min_area}')
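    # The property prompt is parsed purely by word position: after splitting on spaces,
    # word 5 is the object count, word 6 the similarity word, word 7 the object phrase,
    # word 12 the maximum number of anomalies (k_mask), and word 19 the anomaly-area ratio.
    # A made-up sentence that merely matches this positional layout (placeholder wording,
    # not the exact prompt format used by the repo):
    #   "the picture shows a scene, 1 similar candle appears, with at most 5 anomaly
    #    regions whose area is below 0.3 of the object area"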
    def ensemble_text_guided_mask_proposal(self, image, object_phrase_list, filtered_phrase_list,
                                           object_max_area, object_min_area,
                                           bbox_score_thr, text_score_thr):
        size = image.shape[:2]
        H, W = size[0], size[1]
        dino_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        dino_image, _ = self.transform(dino_image, None)  # 3, h, w

        if not self.is_sam_set:
            self.anomaly_region_refiner.set_image(image)
            self.is_sam_set = True

        ensemble_boxes = []
        ensemble_logits = []
        ensemble_phrases = []
        max_box_area = 0.
        for object_phrase, filtered_phrase in zip(object_phrase_list, filtered_phrase_list):
            ########## language prompts for region proposal
            boxes, logits, object_phrase = self.text_guided_region_proposal(dino_image, object_phrase)

            ########## property prompts for region filtering
            boxes_filtered, logits_filtered, pred_phrases = self.bbox_suppression(boxes, logits, object_phrase,
                                                                                  filtered_phrase,
                                                                                  bbox_score_thr, text_score_thr,
                                                                                  object_max_area, object_min_area)
            ## in case there is no box left
            if boxes_filtered is not None:
                ensemble_boxes += [boxes_filtered]
                ensemble_logits += logits_filtered
                ensemble_phrases += pred_phrases
                boxes_area = boxes_filtered[:, 2] * boxes_filtered[:, 3]
                if boxes_area.max() > max_box_area:
                    max_box_area = boxes_area.max()

        if ensemble_boxes:
            ensemble_boxes = torch.cat(ensemble_boxes, dim=0)
            ensemble_logits = np.stack(ensemble_logits, axis=0)
            # denormalize the bboxes from normalized (cx, cy, w, h) to (x1, y1, x2, y2) in pixels
            for i in range(ensemble_boxes.size(0)):
                ensemble_boxes[i] = ensemble_boxes[i] * torch.Tensor([W, H, W, H]).to(self.device)
                ensemble_boxes[i][:2] -= ensemble_boxes[i][2:] / 2
                ensemble_boxes[i][2:] += ensemble_boxes[i][:2]
            # region to mask
            masks, logits = self.region_refine(ensemble_boxes, ensemble_logits, H, W)
        else:  # in case there is no box left
            masks = [np.zeros((H, W), dtype=bool)]
            logits = [0]
            max_box_area = 1
        return masks, logits, max_box_area
    def text_guided_region_proposal(self, dino_image, object_phrase):
        # directly use the output of Grounding DINO
        boxes, logits, caption = self.get_grounding_output(
            dino_image, object_phrase, device=self.device
        )
        return boxes, logits, caption

    def bbox_suppression(self, boxes, logits, object_phrase, filtered_phrase,
                         bbox_score_thr, text_score_thr,
                         object_max_area, object_min_area,
                         with_logits=True):
        # filter output
        logits_filt = logits.clone()
        boxes_filt = boxes.clone()
        boxes_area = boxes_filt[:, 2] * boxes_filt[:, 3]

        # filter the bounding boxes according to the box score and the area
        # strategy 1: bbox score threshold
        box_score_mask = logits_filt.max(dim=1)[0] > bbox_score_thr
        # strategy 2: max area
        box_max_area_mask = boxes_area < object_max_area
        # strategy 3: min area
        box_min_area_mask = boxes_area > object_min_area
        filt_mask = torch.bitwise_and(box_score_mask, box_max_area_mask)
        filt_mask = torch.bitwise_and(filt_mask, box_min_area_mask)

        if torch.sum(filt_mask) == 0:  # in case there are no matches
            return None, None, None
        else:
            logits_filt = logits_filt[filt_mask]  # num_filt, 256
            boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

        # get phrases
        tokenlizer = self.anomaly_region_generator.tokenizer
        tokenized = tokenlizer(object_phrase)

        # build predictions
        pred_phrases = []
        boxes_filtered = []
        logits_filtered = []
        for logit, box in zip(logits_filt, boxes_filt):
            # strategy 4: text score threshold
            pred_phrase = get_phrases_from_posmap(logit > text_score_thr, tokenized, tokenlizer)
            # strategy 5: filter the background, i.e. we don't want to predict the category itself
            if pred_phrase.count(filtered_phrase) > 0:
                continue
            if with_logits:
                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            else:
                pred_phrases.append(pred_phrase)
            boxes_filtered.append(box)
            logits_filtered.append(logit.max().item())

        if not boxes_filtered:
            return None, None, None
        boxes_filtered = torch.stack(boxes_filtered, dim=0)
        return boxes_filtered, logits_filtered, pred_phrases
    def region_refine(self, boxes_filtered, logits_filtered, H, W):
        if boxes_filtered is None or len(boxes_filtered) == 0:
            return [np.zeros((H, W), dtype=bool)], [0]
        transformed_boxes = self.anomaly_region_refiner.transform.apply_boxes_torch(boxes_filtered, (H, W)).to(
            self.device)
        masks, _, _ = self.anomaly_region_refiner.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        masks = masks.cpu().squeeze(1).numpy()
        return masks, logits_filtered
    def saliency_prompting(self, image, object_masks, defect_masks, defect_logits):
        ###### self-similarity calculation
        similarity_map = self.visual_saliency_calculation(image, object_masks)
        ###### rescore
        defect_masks, defect_rescores = self.rescore(defect_masks, defect_logits, similarity_map)
        return defect_masks, defect_rescores, similarity_map

    def single_object_similarity(self, image, object_masks):
        # use the GPU version...
        # only consider the features of the object
        # calculating the whole-image similarity is memory-costly, so we use a small resolution here...
        self.visual_saliency_extractor.set_img_size(256)
        resize_image = cv2.resize(image, (256, 256))
        features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)
        B, C, H, W = features.shape
        assert B == 1
        features_flattern = features.view(B * C, H * W)
        features_self_similarity = features_flattern.T @ features_flattern
        # assuming normalized features, this maps cosine similarity in [-1, 1] to an anomaly score in [0, 1]
        features_self_similarity = 0.5 * (1 - features_self_similarity)
        features_self_similarity = features_self_similarity.sort(dim=1, descending=True)[0]
        # by default we use N=400 for the saliency calculation
        features_self_similarity = torch.mean(features_self_similarity[:, :400], dim=1)
        heatMap2 = features_self_similarity.view(H, W).cpu().numpy()
        mask_anomaly_scores = cv2.resize(heatMap2, (image.shape[1], image.shape[0]))
        # mask_anomaly_scores[~object_masks] = 0.
        return mask_anomaly_scores
    def visual_saliency_calculation(self, image, object_masks):
        if self.object_number == 1:  # use the single-instance strategy
            object_masks = np.asarray(object_masks)  # handle the list fallback from region_refine
            mask_area = np.sum(object_masks, axis=(1, 2))
            object_mask = object_masks[mask_area.argmax(), :, :]
            self_similarity_anomaly_map = self.single_object_similarity(image, object_mask)
            return self_similarity_anomaly_map
        else:  # use the multi-instance strategy
            resize_image = cv2.resize(image, (1024, 1024))
            features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)
            feature_size = features.shape[2:]
            object_masks_clone = object_masks.copy()
            object_masks_clone = object_masks_clone.astype(np.int32)
            resize_object_masks = []
            for object_mask in object_masks_clone:
                # feature map is square here (1024x1024 input), so (H, W) ordering is safe for cv2.resize
                resize_object_masks.append(cv2.resize(object_mask, feature_size, interpolation=cv2.INTER_NEAREST))
            mask_anomaly_scores = []
            for indx in range(len(resize_object_masks)):
                other_object_masks1 = resize_object_masks[:indx]
                other_object_masks2 = resize_object_masks[indx + 1:]
                other_object_masks = other_object_masks1 + other_object_masks2
                one_mask_feature, one_feature_location, other_mask_features = self.region_feature_extraction(
                    features,
                    resize_object_masks[indx],
                    other_object_masks
                )
                similarity = one_mask_feature @ other_mask_features.T  # (H*W, N)
                similarity = similarity.max(dim=1)[0]
                anomaly_score = 0.5 * (1. - similarity)
                anomaly_score = anomaly_score.cpu().numpy()
                mask_anomaly_score = np.zeros(feature_size)
                for location, score in zip(one_feature_location, anomaly_score):
                    mask_anomaly_score[location[0], location[1]] = score
                mask_anomaly_scores.append(mask_anomaly_score)
            mask_anomaly_scores = np.stack(mask_anomaly_scores, axis=0)
            mask_anomaly_scores = np.max(mask_anomaly_scores, axis=0)
            mask_anomaly_scores = cv2.resize(mask_anomaly_scores, (image.shape[1], image.shape[0]))
            return mask_anomaly_scores
    def region_feature_extraction(self, features, one_object_mask, other_object_masks):
        '''
        Gather the ImageNet-pretrained features inside each mask.
        Args:
            features: feature map of shape (1, C, H, W)
            one_object_mask: binary mask (at feature resolution) of the query object
            other_object_masks: list of binary masks (at feature resolution) of the remaining objects
        Returns:
            the features of the query object, their (h, w) locations, and the features of the other objects
        '''
        features_clone = features.clone()
        one_mask_feature = []
        one_feature_location = []
        for h in range(one_object_mask.shape[0]):
            for w in range(one_object_mask.shape[1]):
                if one_object_mask[h, w] > 0:
                    one_mask_feature += [features_clone[:, :, h, w].clone()]
                    one_feature_location += [np.array((h, w))]
                    # zero out the query object's features so they are excluded from the other objects
                    features_clone[:, :, h, w] = 0.
        one_feature_location = np.stack(one_feature_location, axis=0)
        one_mask_feature = torch.cat(one_mask_feature, dim=0)

        B, C, H, W = features_clone.shape
        assert B == 1
        features_clone_flattern = features_clone.view(C, -1)
        other_mask_features = []
        for other_object_mask in other_object_masks:
            other_object_mask_flattern = other_object_mask.reshape(-1)
            other_mask_feature = features_clone_flattern[:, other_object_mask_flattern > 0]
            other_mask_features.append(other_mask_feature)
        other_mask_features = torch.cat(other_mask_features, dim=1).T
        return one_mask_feature, one_feature_location, other_mask_features
    def rescore(self, defect_masks, defect_logits, similarity_map):
        defect_rescores = []
        for mask, logit in zip(defect_masks, defect_logits):
            if similarity_map[mask].size == 0:
                similarity_score = 1.
            else:
                similarity_score = np.exp(3 * similarity_map[mask].mean())
            refined_score = logit * similarity_score
            defect_rescores.append(refined_score)
        defect_rescores = np.stack(defect_rescores, axis=0)
        return defect_masks, defect_rescores

    def confidence_prompting(self, defect_masks, defect_scores, similarity_map):
        mask_indx = defect_scores.argsort()[-self.k_mask:]
        filtered_masks = []
        filtered_scores = []
        for indx in mask_indx:
            filtered_masks.append(defect_masks[indx])
            filtered_scores.append(defect_scores[indx])

        anomaly_map = np.zeros(defect_masks[0].shape)
        weight_map = np.ones(defect_masks[0].shape)
        for mask, logits in zip(filtered_masks, filtered_scores):
            anomaly_map += mask * logits
            weight_map += mask * 1.
        anomaly_map[weight_map > 0] /= weight_map[weight_map > 0]
        anomaly_map = cv2.resize(anomaly_map, (self.out_size, self.out_size))
        return anomaly_map
    def forward(self, image: np.ndarray):
        ####### object TGMP for object detection
        object_masks, object_logits, object_area = self.ensemble_text_guided_mask_proposal(
            image,
            [self.object_prompt],
            ['PlaceHolder'],
            self.object_max_area,
            self.object_min_area,
            self.box_threshold,
            self.text_threshold
        )

        ###### reasoning: set the anomaly area threshold according to the object area
        self.defect_max_area = object_area * self.defect_area_threshold
        self.defect_min_area = 0.

        ####### language prompts and property prompts $\mathcal{P}^L$ $\mathcal{P}^P$
        ####### for region proposal and filtering
        defect_masks, defect_logits, _ = self.ensemble_text_guided_mask_proposal(
            image,
            self.defect_prompt_list,
            self.filter_prompt_list,
            self.defect_max_area,
            self.defect_min_area,
            self.box_threshold,
            self.text_threshold
        )

        ###### saliency prompts $\mathcal{P}^S$
        defect_masks, defect_rescores, similarity_map = self.saliency_prompting(
            image,
            object_masks,
            defect_masks,
            defect_logits
        )

        ##### confidence prompts $\mathcal{P}^C$
        anomaly_map = self.confidence_prompting(defect_masks, defect_rescores, similarity_map)

        self.is_sam_set = False
        appendix = {'similarity_map': similarity_map}
        return anomaly_map, appendix
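
# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint paths,
# image path, prompt strings, and threshold values below are placeholders and
# illustrative assumptions, not files or prompts shipped with this repo; the
# property prompt only needs to match the positional layout parsed by
# `set_property_text_prompts`. Because of the relative import above, this file
# must be run as part of its package (e.g. `python -m <package>.model`) for the
# block to execute.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    model = Model(
        dino_config_file='path/to/GroundingDINO_config.py',     # placeholder
        dino_checkpoint='path/to/groundingdino_checkpoint.pth',  # placeholder
        sam_checkpoint='path/to/sam_checkpoint.pth',             # placeholder
        box_threshold=0.1,
        text_threshold=0.1,
        out_size=256,
        device='cuda',
    )
    # (defect prompt, background phrase) pairs; placeholder wording
    model.set_ensemble_text_prompts([['defect. anomaly.', 'candle']], verbose=True)
    # placeholder sentence matching the positional layout (words 5, 6, 7, 12, 19)
    model.set_property_text_prompts(
        'the picture shows a scene, 1 similar candle appears, '
        'with at most 5 anomaly regions whose area is below 0.3 of the object area',
        verbose=True,
    )
    image = cv2.imread('path/to/image.png')  # BGR image, as expected by forward()
    anomaly_map, appendix = model(image)
    print(anomaly_map.shape, anomaly_map.max())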