from typing import Tuple

import cv2
import numpy as np
import torch
from PIL import Image
# Grounding DINO, slightly modified from original repo
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything
from SAM.segment_anything import build_sam, SamPredictor
# ImageNet pretrained feature extractor
from .modelinet import ModelINet
class Model(torch.nn.Module):
def __init__(self,
## DINO
dino_config_file,
dino_checkpoint,
## SAM
sam_checkpoint,
## Parameters
box_threshold,
text_threshold,
## Others
out_size=256,
device='cuda',
):
'''
Args:
dino_config_file: the config file for DINO
dino_checkpoint: the path of checkpoint for DINO
sam_checkpoint: the path of checkpoint for SAM
            box_threshold: the confidence threshold for filtering bounding boxes
            text_threshold: the confidence threshold for filtering phrase (text) tokens
            out_size: the desired output resolution of the anomaly map
            device: the running device, e.g., 'cuda:0'
NOTE:
            1. In the published paper, the property prompt P^P is applied to the region R.
               In this repo, P^P is instead applied to the bounding-box-level region R^B.
            2. The IoU constraint has not been added in this repo.
            3. This module only accepts a batch size of 1.
'''
super(Model, self).__init__()
# Build Model
self.anomaly_region_generator = self.load_dino(dino_config_file, dino_checkpoint, device=device)
self.anomaly_region_refiner = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
self.transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
self.visual_saliency_extractor = ModelINet(device=device)
self.pixel_mean = [123.675, 116.28, 103.53]
self.pixel_std = [58.395, 57.12, 57.375]
# Parameters
self.box_threshold = box_threshold
self.text_threshold = text_threshold
# Others
self.out_size = out_size
self.device = device
self.is_sam_set = False
def load_dino(self, model_config_path, model_checkpoint_path, device) -> torch.nn.Module:
        '''
        Build Grounding DINO from a config file and load its checkpoint.
        Args:
            model_config_path: path to the Grounding DINO config file
            model_checkpoint_path: path to the Grounding DINO checkpoint
            device: the device the model is moved to, e.g., 'cuda:0'
        Returns:
            the Grounding DINO model in eval mode
        '''
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
_ = model.eval()
model = model.to(device)
return model
    def get_grounding_output(self, image, caption, device="cpu") -> Tuple[torch.Tensor, torch.Tensor, str]:
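        '''
        Run Grounding DINO on a single image with a text caption.
        Returns the query boxes (nq, 4) in normalized cxcywh format, the per-query
        token logits (nq, 256) after sigmoid, and the normalized caption.
        '''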
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
image = image.to(device)
with torch.no_grad():
outputs = self.anomaly_region_generator(image[None], captions=[caption])
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"][0] # (nq, 4)
return boxes, logits, caption
def set_ensemble_text_prompts(self, text_prompt_list: list, verbose=False) -> None:
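        '''
        Store the ensemble of (defect prompt, background phrase) pairs; each defect
        prompt drives a Grounding DINO query, and boxes whose predicted phrase
        contains the paired background phrase are filtered out later.
        '''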
self.defect_prompt_list = [f[0] for f in text_prompt_list]
self.filter_prompt_list = [f[1] for f in text_prompt_list]
if verbose:
print('used ensemble text prompts ===========')
for d, t in zip(self.defect_prompt_list, self.filter_prompt_list):
print(f'det prompts: {d}')
print(f'filtered background: {t}')
print('======================================')
def set_property_text_prompts(self, property_prompts, verbose=False) -> None:
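        '''
        Parse the property prompt sentence by fixed word positions (split on spaces):
        word 5 is the number of object instances, word 6 the similar/dissimilar flag,
        word 7 the object name, word 12 the number of anomaly masks to keep (k_mask),
        and word 19 the maximum ratio of anomaly area to object area.
        A hypothetical sentence matching this layout (for illustration only, not an
        official prompt from the paper) is:
            'the image of candle have 4 similar candle with a maximum of 1 '
            'anomaly. The anomaly would not exceed 0.05 object area'
        '''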
self.object_prompt = property_prompts.split(' ')[7]
self.object_number = int(property_prompts.split(' ')[5])
self.k_mask = int(property_prompts.split(' ')[12])
self.defect_area_threshold = float(property_prompts.split(' ')[19])
self.object_max_area = 1. / self.object_number
self.object_min_area = 0.
self.similar = property_prompts.split(' ')[6]
if verbose:
print(f'{self.object_prompt}, '
f'{self.object_number}, '
f'{self.k_mask}, '
f'{self.defect_area_threshold}, '
f'{self.object_max_area}, '
f'{self.object_min_area}')
def ensemble_text_guided_mask_proposal(self, image, object_phrase_list, filtered_phrase_list,
object_max_area, object_min_area,
bbox_score_thr, text_score_thr):
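        '''
        Text-guided mask proposal over an ensemble of prompts: Grounding DINO proposes
        boxes for every (object phrase, background phrase) pair, the boxes are filtered
        by score and area constraints, and SAM refines the surviving boxes into
        pixel-level masks. Returns the masks, their scores, and the largest normalized
        box area (used later as the object area).
        '''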
size = image.shape[:2]
H, W = size[0], size[1]
dino_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
dino_image, _ = self.transform(dino_image, None) # 3, h, w
        if not self.is_sam_set:
self.anomaly_region_refiner.set_image(image)
self.is_sam_set = True
ensemble_boxes = []
ensemble_logits = []
ensemble_phrases = []
max_box_area = 0.
for object_phrase, filtered_phrase in zip(object_phrase_list, filtered_phrase_list):
########## language prompts for region proposal
boxes, logits, object_phrase = self.text_guided_region_proposal(dino_image, object_phrase)
########## property prompts for region filter
boxes_filtered, logits_filtered, pred_phrases = self.bbox_suppression(boxes, logits, object_phrase,
filtered_phrase,
bbox_score_thr, text_score_thr,
object_max_area, object_min_area)
## in case there is no box left
if boxes_filtered is not None:
ensemble_boxes += [boxes_filtered]
ensemble_logits += logits_filtered
ensemble_phrases += pred_phrases
boxes_area = boxes_filtered[:, 2] * boxes_filtered[:, 3]
if boxes_area.max() > max_box_area:
max_box_area = boxes_area.max()
if ensemble_boxes != []:
ensemble_boxes = torch.cat(ensemble_boxes, dim=0)
ensemble_logits = np.stack(ensemble_logits, axis=0)
# denormalize the bbox
for i in range(ensemble_boxes.size(0)):
ensemble_boxes[i] = ensemble_boxes[i] * torch.Tensor([W, H, W, H]).to(self.device)
ensemble_boxes[i][:2] -= ensemble_boxes[i][2:] / 2
ensemble_boxes[i][2:] += ensemble_boxes[i][:2]
# region 2 mask
masks, logits = self.region_refine(ensemble_boxes, ensemble_logits, H, W)
else: # in case there is no box left
masks = [np.zeros((H, W), dtype=bool)]
logits = [0]
max_box_area = 1
return masks, logits, max_box_area
def text_guided_region_proposal(self, dino_image, object_phrase):
# directly use the output of Grounding DINO
boxes, logits, caption = self.get_grounding_output(
dino_image, object_phrase, device=self.device
)
return boxes, logits, caption
def bbox_suppression(self, boxes, logits, object_phrase, filtered_phrase,
bbox_score_thr, text_score_thr,
object_max_area, object_min_area,
with_logits=True):
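        '''
        Filter Grounding DINO proposals with the property prompts:
        (1) drop boxes whose best token score is below bbox_score_thr,
        (2) drop boxes larger than object_max_area,
        (3) drop boxes smaller than object_min_area,
        (4) keep only phrase tokens scoring above text_score_thr,
        (5) drop boxes whose phrase contains the filtered (background) phrase.
        Returns (None, None, None) when no box survives.
        '''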
# filter output
logits_filt = logits.clone()
boxes_filt = boxes.clone()
boxes_area = boxes_filt[:, 2] * boxes_filt[:, 3]
# filter the bounding boxes according to the box similarity and the area
# strategy1: bbox score thr
box_score_mask = logits_filt.max(dim=1)[0] > bbox_score_thr
# strategy2: max area
box_max_area_mask = boxes_area < (object_max_area)
# strategy3: min area
box_min_area_mask = boxes_area > (object_min_area)
filt_mask = torch.bitwise_and(box_score_mask, box_max_area_mask)
filt_mask = torch.bitwise_and(filt_mask, box_min_area_mask)
if torch.sum(filt_mask) == 0: # in case there are no matches
return None, None, None
else:
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
# get phrase
            tokenizer = self.anomaly_region_generator.tokenizer
            tokenized = tokenizer(object_phrase)
# build pred
pred_phrases = []
boxes_filtered = []
logits_filtered = []
for logit, box in zip(logits_filt, boxes_filt):
# strategy4: text score thr
                pred_phrase = get_phrases_from_posmap(logit > text_score_thr, tokenized, tokenizer)
# strategy5: filter background
                if pred_phrase.count(filtered_phrase) > 0:  # skip phrases that match the filtered background category
continue
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
boxes_filtered.append(box)
logits_filtered.append(logit.max().item())
if boxes_filtered == []:
return None, None, None
boxes_filtered = torch.stack(boxes_filtered, dim=0)
return boxes_filtered, logits_filtered, pred_phrases
def region_refine(self, boxes_filtered, logits_filtered, H, W):
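        '''
        Refine box-level regions into pixel-level masks: the filtered boxes (in xyxy
        image coordinates) are passed to SAM as box prompts, one mask per box.
        '''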
        if len(boxes_filtered) == 0:  # no boxes to refine
return [np.zeros((H, W), dtype=bool)], [0]
transformed_boxes = self.anomaly_region_refiner.transform.apply_boxes_torch(boxes_filtered, (H, W)).to(
self.device)
masks, _, _ = self.anomaly_region_refiner.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks = masks.cpu().squeeze(1).numpy()
return masks, logits_filtered
def saliency_prompting(self, image, object_masks, defect_masks, defect_logits):
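        '''
        Apply the saliency prompt: compute a self-similarity saliency map over the
        detected objects and use it to rescore the candidate defect masks.
        '''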
###### Self Similarity Calculation
similarity_map = self.visual_saliency_calculation(image, object_masks)
###### Rescore
defect_masks, defect_rescores = self.rescore(defect_masks, defect_logits, similarity_map)
return defect_masks, defect_rescores, similarity_map
def single_object_similarity(self, image, object_masks):
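        '''
        Saliency for a single-instance image: pairwise feature similarities are turned
        into distances (0.5 * (1 - sim)), each location is scored by the mean of its
        N=400 largest distances to other locations, and the score map is resized back
        to the input resolution.
        '''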
# use GPU version...
        # only consider the features of objects
        # computing whole-image similarity is memory-costly, so a small resolution is used here
self.visual_saliency_extractor.set_img_size(256)
resize_image = cv2.resize(image, (256, 256))
features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)
B, C, H, W = features.shape
assert B == 1
        features_flat = features.view(B * C, H * W)
        features_self_similarity = features_flat.T @ features_flat
features_self_similarity = 0.5 * (1 - features_self_similarity)
features_self_similarity = features_self_similarity.sort(dim=1, descending=True)[0]
# by default we use N=400 for saliency calculation
features_self_similarity = torch.mean(features_self_similarity[:, :400], dim=1)
        heat_map = features_self_similarity.view(H, W).cpu().numpy()
        mask_anomaly_scores = cv2.resize(heat_map, (image.shape[1], image.shape[0]))
# mask_anomaly_scores[~object_masks] = 0.
return mask_anomaly_scores
def visual_saliency_calculation(self, image, object_masks):
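        '''
        Compute the saliency map. With a single expected instance, self-similarity
        within the image is used; with multiple instances, each object region is
        compared against the other instances and scored by 0.5 * (1 - max similarity).
        '''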
if self.object_number == 1: # use single-instance strategy
mask_area = np.sum(object_masks, axis=(1, 2))
object_mask = object_masks[mask_area.argmax(), :, :]
self_similarity_anomaly_map = self.single_object_similarity(image, object_mask)
return self_similarity_anomaly_map
else: # use multi-instance strategy
resize_image = cv2.resize(image, (1024, 1024))
features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)
feature_size = features.shape[2:]
object_masks_clone = object_masks.copy()
object_masks_clone = object_masks_clone.astype(np.int32)
resize_object_masks = []
for object_mask in object_masks_clone:
resize_object_masks.append(cv2.resize(object_mask, feature_size, interpolation=cv2.INTER_NEAREST))
mask_anomaly_scores = []
for indx in range(len(resize_object_masks)):
other_object_masks1 = resize_object_masks[:indx]
other_object_masks2 = resize_object_masks[indx + 1:]
other_object_masks = other_object_masks1 + other_object_masks2
one_mask_feature, \
one_feature_location, \
other_mask_features = self.region_feature_extraction(
features,
resize_object_masks[indx],
other_object_masks
)
                similarity = one_mask_feature @ other_mask_features.T  # (num. mask pixels, num. other-mask pixels)
similarity = similarity.max(dim=1)[0]
anomaly_score = 0.5 * (1. - similarity)
anomaly_score = anomaly_score.cpu().numpy()
mask_anomaly_score = np.zeros(feature_size)
for location, score in zip(one_feature_location, anomaly_score):
mask_anomaly_score[location[0], location[1]] = score
mask_anomaly_scores.append(mask_anomaly_score)
mask_anomaly_scores = np.stack(mask_anomaly_scores, axis=0)
mask_anomaly_scores = np.max(mask_anomaly_scores, axis=0)
mask_anomaly_scores = cv2.resize(mask_anomaly_scores, (image.shape[1], image.shape[0]))
return mask_anomaly_scores
def region_feature_extraction(self, features, one_object_mask, other_object_masks):
        '''
        Use an ImageNet-pretrained network's features to describe one object mask and
        all other object masks (the current mask is zeroed out of the feature map
        before the other masks are read).
        Args:
            features: feature map of shape (1, C, H, W)
            one_object_mask: binary mask (H, W) of the current object
            other_object_masks: list of binary masks (H, W) of the remaining objects
        Returns:
            features of the current mask (N, C), their (h, w) locations (N, 2),
            and the stacked features of all other masks (M, C)
        '''
features_clone = features.clone()
one_mask_feature = []
one_feature_location = []
for h in range(one_object_mask.shape[0]):
for w in range(one_object_mask.shape[1]):
if one_object_mask[h, w] > 0:
one_mask_feature += [features_clone[:, :, h, w].clone()]
one_feature_location += [np.array((h, w))]
features_clone[:, :, h, w] = 0.
one_feature_location = np.stack(one_feature_location, axis=0)
one_mask_feature = torch.cat(one_mask_feature, dim=0)
B, C, H, W = features_clone.shape
assert B == 1
        features_clone_flat = features_clone.view(C, -1)
other_mask_features = []
for other_object_mask in other_object_masks:
            other_object_mask_flat = other_object_mask.reshape(-1)
            other_mask_feature = features_clone_flat[:, other_object_mask_flat > 0]
other_mask_features.append(other_mask_feature)
other_mask_features = torch.cat(other_mask_features, dim=1).T
return one_mask_feature, one_feature_location, other_mask_features
def rescore(self, defect_masks, defect_logits, similarity_map):
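        '''
        Rescore each defect mask: multiply its detection score by
        exp(3 * mean saliency) inside the mask; empty masks keep their original score.
        '''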
defect_rescores = []
for mask, logit in zip(defect_masks, defect_logits):
if similarity_map[mask].size == 0:
similarity_score = 1.
else:
similarity_score = np.exp(3 * similarity_map[mask].mean())
refined_score = logit * similarity_score
defect_rescores.append(refined_score)
defect_rescores = np.stack(defect_rescores, axis=0)
return defect_masks, defect_rescores
def confidence_prompting(self, defect_masks, defect_scores, similarity_map):
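        '''
        Apply the confidence prompt: keep the k_mask highest-scoring masks, accumulate
        their scores per pixel, normalize by the per-pixel weights, and resize the
        fused map to the output resolution.
        '''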
mask_indx = defect_scores.argsort()[-self.k_mask:]
filtered_masks = []
filtered_scores = []
for indx in mask_indx:
filtered_masks.append(defect_masks[indx])
filtered_scores.append(defect_scores[indx])
anomaly_map = np.zeros(defect_masks[0].shape)
weight_map = np.ones(defect_masks[0].shape)
for mask, logits in zip(filtered_masks, filtered_scores):
anomaly_map += mask * logits
weight_map += mask * 1.
anomaly_map[weight_map > 0] /= weight_map[weight_map > 0]
anomaly_map = cv2.resize(anomaly_map, (self.out_size, self.out_size))
return anomaly_map
def forward(self, image: np.ndarray):
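        '''
        Full SAA+ pipeline for one BGR image (as read by cv2.imread): detect the
        object with the object prompt, derive the anomaly area threshold from the
        object area, propose and filter defect regions with the language/property
        prompts, rescore them with the saliency prompt, and fuse the top-k masks with
        the confidence prompt into the final anomaly map.
        Returns the anomaly map (out_size x out_size) and an appendix dict containing
        the similarity map.
        '''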
####### Object TGMP for object detection
object_masks, object_logits, object_area = self.ensemble_text_guided_mask_proposal(
image,
[self.object_prompt],
['PlaceHolder'],
self.object_max_area,
self.object_min_area,
self.box_threshold,
self.text_threshold
)
###### Reasoning: set the anomaly area threshold according to object area
self.defect_max_area = object_area * self.defect_area_threshold
self.defect_min_area = 0.
        ####### language prompts and property prompts $\mathcal{P}^L$ $\mathcal{P}^P$
####### for region proposal and filter
defect_masks, defect_logits, _ = self.ensemble_text_guided_mask_proposal(
image,
self.defect_prompt_list,
self.filter_prompt_list,
self.defect_max_area,
self.defect_min_area,
self.box_threshold,
self.text_threshold
)
###### saliency prompts $\mathcal{P}^S$
defect_masks, defect_rescores, similarity_map = self.saliency_prompting(
image,
object_masks,
defect_masks,
defect_logits
)
##### confidence prompts $\mathcal{P}^C$
anomaly_map = self.confidence_prompting(defect_masks, defect_rescores, similarity_map)
self.is_sam_set = False
appendix = {'similarity_map': similarity_map}
return anomaly_map, appendix
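

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the module's API). The checkpoint paths,
# thresholds, and prompt strings below are placeholders/assumptions and should be
# replaced with the values from this repo's configs; the property prompt only
# needs to follow the word layout parsed in set_property_text_prompts.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    model = Model(
        dino_config_file='GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py',  # assumed path
        dino_checkpoint='weights/groundingdino_swint_ogc.pth',  # assumed path
        sam_checkpoint='weights/sam_vit_h_4b8939.pth',  # assumed path
        box_threshold=0.1,  # example value
        text_threshold=0.1,  # example value
        out_size=256,
        device='cuda',
    )
    # each entry is (defect prompt, background phrase to filter out)
    model.set_ensemble_text_prompts([('defect. anomaly.', 'candle')], verbose=True)
    # hypothetical sentence matching the fixed word positions parsed by set_property_text_prompts
    model.set_property_text_prompts(
        'the image of candle have 4 similar candle with a maximum of 1 '
        'anomaly. The anomaly would not exceed 0.05 object area',
        verbose=True,
    )
    image = cv2.imread('example.png')  # BGR image, placeholder path
    anomaly_map, appendix = model(image)
    print(anomaly_map.shape, anomaly_map.max())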