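'''
Anomaly segmentation model that chains Grounding DINO (text-guided region
proposal), SAM (mask refinement), and an ImageNet-pretrained feature
extractor (visual saliency) into a language/property/saliency/confidence
prompting pipeline. See the usage sketch at the bottom of this file.
'''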
from typing import List, Tuple

import cv2
import numpy as np
import torch
from PIL import Image

# Grounding DINO, slightly modified from original repo
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
# segment anything
from SAM.segment_anything import build_sam, SamPredictor
# ImageNet pretrained feature extractor
from .modelinet import ModelINet


class Model(torch.nn.Module):
    def __init__(self,
                 ## DINO
                 dino_config_file,
                 dino_checkpoint,

                 ## SAM
                 sam_checkpoint,

                 ## Parameters
                 box_threshold,
                 text_threshold,

                 ## Others
                 out_size=256,
                 device='cuda',

                 ):
        '''

        Args:
            dino_config_file: the config file for Grounding DINO
            dino_checkpoint: the checkpoint path for Grounding DINO
            sam_checkpoint: the checkpoint path for SAM
            box_threshold: the score threshold for filtering boxes
            text_threshold: the score threshold for filtering phrases
            out_size: the desired output resolution of the anomaly map
            device: the running device, e.g., 'cuda:0'

        NOTE:
            1. In our published paper, the property prompt P^P is applied to
            the region R. In this repo, we instead apply P^P to the bounding
            box-level region R^B.
            2. The IoU constraint has not been added in this repo.
            3. This module only accepts a batch size of 1 (BS=1).
        '''
        super(Model, self).__init__()

        # Build Model
        self.anomaly_region_generator = self.load_dino(dino_config_file, dino_checkpoint, device=device)
        self.anomaly_region_refiner = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

        self.transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        self.visual_saliency_extractor = ModelINet(device=device)

        self.pixel_mean = [123.675, 116.28, 103.53]
        self.pixel_std = [58.395, 57.12, 57.375]

        # Parameters
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold

        # Others
        self.out_size = out_size
        self.device = device
        self.is_sam_set = False

    def load_dino(self, model_config_path, model_checkpoint_path, device) -> torch.nn.Module:
        '''
        Build Grounding DINO from a config file and load its checkpoint.

        Args:
            model_config_path: the config file for Grounding DINO
            model_checkpoint_path: the checkpoint path for Grounding DINO
            device: the running device, e.g., 'cuda:0'

        Returns:
            the Grounding DINO model in eval mode, moved to `device`
        '''
        args = SLConfig.fromfile(model_config_path)
        args.device = device
        model = build_model(args)
        checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
        model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
        _ = model.eval()
        model = model.to(device)
        return model

    def get_grounding_output(self, image, caption, device="cpu") -> Tuple[torch.Tensor, torch.Tensor, str]:
        # normalize the caption: lower-case, strip, and ensure a trailing period
        caption = caption.lower().strip()

        if not caption.endswith("."):
            caption = caption + "."
        image = image.to(device)

        with torch.no_grad():
            outputs = self.anomaly_region_generator(image[None], captions=[caption])

        logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
        boxes = outputs["pred_boxes"][0]  # (nq, 4)

        return boxes, logits, caption

    def set_ensemble_text_prompts(self, text_prompt_list: List[Tuple[str, str]], verbose=False) -> None:
        # each entry is a (defect prompt, background-filter prompt) pair
        self.defect_prompt_list = [f[0] for f in text_prompt_list]
        self.filter_prompt_list = [f[1] for f in text_prompt_list]

        if verbose:
            print('used ensemble text prompts ===========')

            for d, t in zip(self.defect_prompt_list, self.filter_prompt_list):
                print(f'det prompts: {d}')
                print(f'filtered background: {t}')

            print('======================================')

    def set_property_text_prompts(self, property_prompts, verbose=False) -> None:
        # NOTE: the property prompt is a fixed-template sentence; the fields below
        # are recovered by word position, so the template must not be reworded
        words = property_prompts.split(' ')

        self.object_prompt = words[7]
        self.object_number = int(words[5])
        self.k_mask = int(words[12])
        self.defect_area_threshold = float(words[19])
        self.object_max_area = 1. / self.object_number
        self.object_min_area = 0.
        self.similar = words[6]

        if verbose:
            print(f'{self.object_prompt}, '
                  f'{self.object_number}, '
                  f'{self.k_mask}, '
                  f'{self.defect_area_threshold}, '
                  f'{self.object_max_area}, '
                  f'{self.object_min_area}')

    def ensemble_text_guided_mask_proposal(self, image, object_phrase_list, filtered_phrase_list,
                                           object_max_area, object_min_area,
                                           bbox_score_thr, text_score_thr):

        size = image.shape[:2]
        H, W = size[0], size[1]

        dino_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        dino_image, _ = self.transform(dino_image, None)  # 3, h, w

        if not self.is_sam_set:
            self.anomaly_region_refiner.set_image(image)
            self.is_sam_set = True

        ensemble_boxes = []
        ensemble_logits = []
        ensemble_phrases = []

        max_box_area = 0.

        for object_phrase, filtered_phrase in zip(object_phrase_list, filtered_phrase_list):

            ########## language prompts for region proposal
            boxes, logits, object_phrase = self.text_guided_region_proposal(dino_image, object_phrase)

            ########## property prompts for region filter
            boxes_filtered, logits_filtered, pred_phrases = self.bbox_suppression(boxes, logits, object_phrase,
                                                                                  filtered_phrase,
                                                                                  bbox_score_thr, text_score_thr,
                                                                                  object_max_area, object_min_area)
            ## in case there is no box left
            if boxes_filtered is not None:
                ensemble_boxes += [boxes_filtered]
                ensemble_logits += logits_filtered
                ensemble_phrases += pred_phrases

                boxes_area = boxes_filtered[:, 2] * boxes_filtered[:, 3]

                if boxes_area.max() > max_box_area:
                    max_box_area = boxes_area.max()

        if ensemble_boxes:
            ensemble_boxes = torch.cat(ensemble_boxes, dim=0)
            ensemble_logits = np.stack(ensemble_logits, axis=0)

            # denormalize: convert (cx, cy, w, h) in [0, 1] to absolute (x1, y1, x2, y2)
            for i in range(ensemble_boxes.size(0)):
                ensemble_boxes[i] = ensemble_boxes[i] * torch.Tensor([W, H, W, H]).to(self.device)
                ensemble_boxes[i][:2] -= ensemble_boxes[i][2:] / 2
                ensemble_boxes[i][2:] += ensemble_boxes[i][:2]

            # region 2 mask
            masks, logits = self.region_refine(ensemble_boxes, ensemble_logits, H, W)

        else:  # in case there is no box left
            masks = [np.zeros((H, W), dtype=bool)]
            logits = [0]
            max_box_area = 1

        return masks, logits, max_box_area

    def text_guided_region_proposal(self, dino_image, object_phrase):
        # directly use the output of Grounding DINO
        boxes, logits, caption = self.get_grounding_output(
            dino_image, object_phrase, device=self.device
        )

        return boxes, logits, caption

    def bbox_suppression(self, boxes, logits, object_phrase, filtered_phrase,
                         bbox_score_thr, text_score_thr,
                         object_max_area, object_min_area,
                         with_logits=True):

        # filter output
        logits_filt = logits.clone()
        boxes_filt = boxes.clone()
        boxes_area = boxes_filt[:, 2] * boxes_filt[:, 3]  # normalized area (w * h in [0, 1])

        # filter the bounding boxes according to the box similarity and the area

        # strategy1: bbox score thr
        box_score_mask = logits_filt.max(dim=1)[0] > bbox_score_thr

        # strategy2: max area
        box_max_area_mask = boxes_area < object_max_area

        # strategy3: min area
        box_min_area_mask = boxes_area > object_min_area

        filt_mask = torch.bitwise_and(box_score_mask, box_max_area_mask)
        filt_mask = torch.bitwise_and(filt_mask, box_min_area_mask)

        if torch.sum(filt_mask) == 0:  # in case there are no matches
            return None, None, None
        else:
            logits_filt = logits_filt[filt_mask]  # num_filt, 256
            boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

        # get phrases
        tokenizer = self.anomaly_region_generator.tokenizer
        tokenized = tokenizer(object_phrase)

        # build pred
        pred_phrases = []
        boxes_filtered = []
        logits_filtered = []
        for logit, box in zip(logits_filt, boxes_filt):
            # strategy4: text score thr
            pred_phrase = get_phrases_from_posmap(logit > text_score_thr, tokenized, tokenizer)

            # strategy5: filter background
            if pred_phrase.count(filtered_phrase) > 0:  # drop boxes whose phrase matches the background/object category
                continue

            if with_logits:
                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
            else:
                pred_phrases.append(pred_phrase)

            boxes_filtered.append(box)
            logits_filtered.append(logit.max().item())

        if not boxes_filtered:
            return None, None, None

        boxes_filtered = torch.stack(boxes_filtered, dim=0)

        return boxes_filtered, logits_filtered, pred_phrases

    def region_refine(self, boxes_filtered, logits_filtered, H, W):
        # boxes_filtered is a tensor here, so guard on its length rather than comparing against []
        if boxes_filtered is None or len(boxes_filtered) == 0:
            return [np.zeros((H, W), dtype=bool)], [0]

        transformed_boxes = self.anomaly_region_refiner.transform.apply_boxes_torch(boxes_filtered, (H, W)).to(
            self.device)

        masks, _, _ = self.anomaly_region_refiner.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        masks = masks.cpu().squeeze(1).numpy()

        return masks, logits_filtered

    def saliency_prompting(self, image, object_masks, defect_masks, defect_logits):

        ###### Self Similarity Calculation
        similarity_map = self.visual_saliency_calculation(image, object_masks)

        ###### Rescore
        defect_masks, defect_rescores = self.rescore(defect_masks, defect_logits, similarity_map)

        return defect_masks, defect_rescores, similarity_map

    def single_object_similarity(self, image, object_masks):
        # use the GPU version; only the features of objects are considered

        # as calculating whole-image similarity is memory-costly, we use a small resolution here
        self.visual_saliency_extractor.set_img_size(256)
        resize_image = cv2.resize(image, (256, 256))
        features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)

        B, C, H, W = features.shape
        assert B == 1
        features_flatten = features.view(B * C, H * W)

        # pairwise self-similarity between spatial locations, converted to a dissimilarity score
        features_self_similarity = features_flatten.T @ features_flatten
        features_self_similarity = 0.5 * (1 - features_self_similarity)

        features_self_similarity = features_self_similarity.sort(dim=1, descending=True)[0]

        # by default we use N=400 for saliency calculation: average the 400 largest distances per location
        features_self_similarity = torch.mean(features_self_similarity[:, :400], dim=1)
        heat_map = features_self_similarity.view(H, W).cpu().numpy()

        mask_anomaly_scores = cv2.resize(heat_map, (image.shape[1], image.shape[0]))
        # mask_anomaly_scores[~object_masks] = 0.
        return mask_anomaly_scores

    def visual_saliency_calculation(self, image, object_masks):

        if self.object_number == 1:  # use single-instance strategy
            mask_area = np.sum(object_masks, axis=(1, 2))
            object_mask = object_masks[mask_area.argmax(), :, :]
            self_similarity_anomaly_map = self.single_object_similarity(image, object_mask)
            return self_similarity_anomaly_map

        else:  # use multi-instance strategy
            resize_image = cv2.resize(image, (1024, 1024))
            features, ratio_h, ratio_w = self.visual_saliency_extractor(resize_image)

            feature_size = features.shape[2:]
            object_masks_clone = object_masks.copy()
            object_masks_clone = object_masks_clone.astype(np.int32)

            resize_object_masks = []
            for object_mask in object_masks_clone:
                # cv2.resize expects (width, height); the feature map is square here, so (H, W) is safe
                resize_object_masks.append(cv2.resize(object_mask, feature_size, interpolation=cv2.INTER_NEAREST))

            mask_anomaly_scores = []

            for indx in range(len(resize_object_masks)):
                other_object_masks1 = resize_object_masks[:indx]
                other_object_masks2 = resize_object_masks[indx + 1:]
                other_object_masks = other_object_masks1 + other_object_masks2

                (one_mask_feature,
                 one_feature_location,
                 other_mask_features) = self.region_feature_extraction(
                    features,
                    resize_object_masks[indx],
                    other_object_masks
                )

                similarity = one_mask_feature @ other_mask_features.T  # (H*W, N)
                similarity = similarity.max(dim=1)[0]
                anomaly_score = 0.5 * (1. - similarity)
                anomaly_score = anomaly_score.cpu().numpy()

                mask_anomaly_score = np.zeros(feature_size)
                for location, score in zip(one_feature_location, anomaly_score):
                    mask_anomaly_score[location[0], location[1]] = score

                mask_anomaly_scores.append(mask_anomaly_score)

            mask_anomaly_scores = np.stack(mask_anomaly_scores, axis=0)
            mask_anomaly_scores = np.max(mask_anomaly_scores, axis=0)
            mask_anomaly_scores = cv2.resize(mask_anomaly_scores, (image.shape[1], image.shape[0]))

            return mask_anomaly_scores

    def region_feature_extraction(self, features, one_object_mask, other_object_masks):
        '''
        Use an ImageNet-pretrained network's features to describe regions.

        Args:
            features: feature map of shape (1, C, H, W)
            one_object_mask: binary mask of the target object at feature resolution
            other_object_masks: list of binary masks of the remaining objects

        Returns:
            the target object's features (N, C), their (h, w) locations, and
            the stacked features of all other objects (M, C)
        '''
        features_clone = features.clone()
        one_mask_feature = []
        one_feature_location = []
        for h in range(one_object_mask.shape[0]):
            for w in range(one_object_mask.shape[1]):
                if one_object_mask[h, w] > 0:
                    one_mask_feature += [features_clone[:, :, h, w].clone()]
                    one_feature_location += [np.array((h, w))]
                    features_clone[:, :, h, w] = 0.

        one_feature_location = np.stack(one_feature_location, axis=0)
        one_mask_feature = torch.cat(one_mask_feature, dim=0)

        B, C, H, W = features_clone.shape
        assert B == 1
        features_clone_flatten = features_clone.view(C, -1)

        other_mask_features = []
        for other_object_mask in other_object_masks:
            other_object_mask_flatten = other_object_mask.reshape(-1)
            other_mask_feature = features_clone_flatten[:, other_object_mask_flatten > 0]
            other_mask_features.append(other_mask_feature)

        other_mask_features = torch.cat(other_mask_features, dim=1).T

        return one_mask_feature, one_feature_location, other_mask_features

    def rescore(self, defect_masks, defect_logits, similarity_map):
        defect_rescores = []
        for mask, logit in zip(defect_masks, defect_logits):
            if similarity_map[mask].size == 0:  # empty mask: keep the original score
                similarity_score = 1.
            else:
                # exponential reweighting by the mean saliency inside the mask
                similarity_score = np.exp(3 * similarity_map[mask].mean())

            refined_score = logit * similarity_score
            defect_rescores.append(refined_score)

        defect_rescores = np.stack(defect_rescores, axis=0)

        return defect_masks, defect_rescores

    def confidence_prompting(self, defect_masks, defect_scores, similarity_map):
        # keep the k_mask highest-scoring masks
        mask_indx = defect_scores.argsort()[-self.k_mask:]

        filtered_masks = []
        filtered_scores = []

        for indx in mask_indx:
            filtered_masks.append(defect_masks[indx])
            filtered_scores.append(defect_scores[indx])

        anomaly_map = np.zeros(defect_masks[0].shape)
        # initialized to ones, so the fusion below acts as +1-smoothed averaging
        weight_map = np.ones(defect_masks[0].shape)

        for mask, logits in zip(filtered_masks, filtered_scores):
            anomaly_map += mask * logits
            weight_map += mask * 1.

        anomaly_map[weight_map > 0] /= weight_map[weight_map > 0]
        anomaly_map = cv2.resize(anomaly_map, (self.out_size, self.out_size))
        return anomaly_map

    def forward(self, image: np.ndarray):
        ####### Object TGMP for object detection
        object_masks, object_logits, object_area = self.ensemble_text_guided_mask_proposal(
            image,
            [self.object_prompt],
            ['PlaceHolder'],
            self.object_max_area,
            self.object_min_area,
            self.box_threshold,
            self.text_threshold
        )

        ###### Reasoning: set the anomaly area threshold according to object area
        self.defect_max_area = object_area * self.defect_area_threshold
        self.defect_min_area = 0.

        ####### language prompts and property prompts $\mathcal{P}^L$ $\mathcal{P}^P$
        ####### for region proposal and filter
        defect_masks, defect_logits, _ = self.ensemble_text_guided_mask_proposal(
            image,
            self.defect_prompt_list,
            self.filter_prompt_list,
            self.defect_max_area,
            self.defect_min_area,
            self.box_threshold,
            self.text_threshold
        )

        ###### saliency prompts $\mathcal{P}^S$
        defect_masks, defect_rescores, similarity_map = self.saliency_prompting(
            image,
            object_masks,
            defect_masks,
            defect_logits
        )

        ##### confidence prompts $\mathcal{P}^C$
        anomaly_map = self.confidence_prompting(defect_masks, defect_rescores, similarity_map)

        self.is_sam_set = False  # reset SAM's cached image embedding for the next input

        appendix = {'similarity_map': similarity_map}

        return anomaly_map, appendix
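

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The paths, thresholds, and prompt
# strings below are placeholders, not shipped defaults; substitute your own
# config files, checkpoints, and prompts. forward() expects a BGR uint8 image
# (cv2 convention) and batch size 1.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = Model(
        dino_config_file='path/to/GroundingDINO_config.py',     # placeholder
        dino_checkpoint='path/to/groundingdino_checkpoint.pth',  # placeholder
        sam_checkpoint='path/to/sam_checkpoint.pth',             # placeholder
        box_threshold=0.1,   # hypothetical values; tune per dataset
        text_threshold=0.1,
        out_size=256,
        device='cuda',
    )

    # one (defect prompt, background-filter prompt) pair per ensemble member
    model.set_ensemble_text_prompts([('defect. anomaly.', 'object')], verbose=True)

    # forward() needs the property-prompt fields; they are normally parsed from
    # a fixed-template sentence by set_property_text_prompts. The template is
    # repo-specific, so the derived attributes are set directly here for
    # illustration (all values are hypothetical).
    model.object_prompt = 'object'
    model.object_number = 1
    model.k_mask = 5
    model.defect_area_threshold = 0.3
    model.object_max_area = 1. / model.object_number
    model.object_min_area = 0.
    model.similar = 'similar'

    image = cv2.imread('path/to/image.png')  # placeholder path
    anomaly_map, appendix = model(image)
    print(anomaly_map.shape, appendix['similarity_map'].shape)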