import copy
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from transformers import OwlViTConfig
# NOTE: the DETR loss helpers used by DetrLoss at the bottom of this file are assumed to be
# importable from the Hugging Face DETR implementation; adjust this import if your
# transformers version has moved them (e.g. to transformers.loss.loss_for_object_detection).
from transformers.models.detr.modeling_detr import (
    center_to_corners_format,
    dice_loss,
    generalized_box_iou,
    nested_tensor_from_tensor_list,
    sigmoid_focal_loss,
)
# from transformers.models.owlvit.modeling_owlvit import OwlViTVisionTransformer

class OwlViTBoxPredictionHead(nn.Module):
    def __init__(self, config: OwlViTConfig):
        super().__init__()

        width = config.vision_config.hidden_size
        self.dense0 = nn.Linear(width, width)
        self.dense1 = nn.Linear(width, width)
        self.dense2 = nn.Linear(width, width)
        self.dense3 = nn.Linear(width, width)
        self.gelu = nn.GELU()
        self.dense4 = nn.Linear(width, 4)

    def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
        output = self.dense0(image_features)
        output = self.gelu(output)
        output = self.dense1(output)
        output = self.gelu(output)
        output = self.dense2(output)
        output = self.gelu(output)
        output = self.dense3(output)
        output = self.gelu(output)
        output = self.dense4(output)
        output = self.gelu(output)

        return output


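# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original training code): checks the
# expected shapes of the custom box head above. Unlike the stock OwlViT box
# head, this variant applies a GELU to the final 4-dim output; the grid box
# bias and sigmoid are still applied later in `box_predictor`. The default
# OwlViTConfig (ViT-B/32: 768-dim vision width, 24x24 = 576 patches) is an
# assumption used only for illustration.
# ---------------------------------------------------------------------------
def _sketch_box_head_shapes():
    config = OwlViTConfig()
    head = OwlViTBoxPredictionHead(config)
    image_feats = torch.randn(2, 576, config.vision_config.hidden_size)
    pred_boxes = head(image_feats)  # (batch_size, num_boxes, 4), pre-bias / pre-sigmoid
    assert pred_boxes.shape == (2, 576, 4)
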

class OwlViTClassPredictionHead(nn.Module):
    def __init__(self, config: OwlViTConfig):
        super().__init__()

        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size

        self.dense0 = nn.Linear(self.query_dim, out_dim)
        self.logit_shift = nn.Linear(self.query_dim, 1)
        self.logit_scale = nn.Linear(self.query_dim, 1)
        self.elu = nn.ELU()

    def forward(
        self,
        image_embeds: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor],
        query_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.FloatTensor]:
        image_class_embeds = self.dense0(image_embeds)
        if query_embeds is None:
            device = image_class_embeds.device
            batch_size, num_patches = image_class_embeds.shape[:2]
            pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
            return (pred_logits, image_class_embeds)

        # Normalize image and text features
        image_class_embeds = F.normalize(image_class_embeds, dim=-1) + 1e-6
        query_embeds = F.normalize(query_embeds, dim=-1) + 1e-6

        # Get class predictions
        pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)

        # Apply a learnable shift and scale to logits
        logit_shift = self.logit_shift(image_embeds)
        logit_scale = self.logit_scale(image_embeds)
        logit_scale = self.elu(logit_scale) + 1
        pred_logits = (pred_logits + logit_shift) * logit_scale

        if query_mask is not None:
            if query_mask.ndim > 1:
                query_mask = torch.unsqueeze(query_mask, dim=-2)

            pred_logits = pred_logits.to(torch.float64)
            pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
            pred_logits = pred_logits.to(torch.float32)

        return (pred_logits, image_class_embeds)

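# ---------------------------------------------------------------------------
# Hedged sketch (default OwlViT-B/32 sizes are an assumption) of how the class
# head scores every patch embedding against every text query: the einsum
# "...pd,...qd->...pq" is a batched dot product over the shared embedding
# dimension, giving one logit per (patch, query) pair before the learned
# shift/scale is applied.
# ---------------------------------------------------------------------------
def _sketch_class_head_logits():
    config = OwlViTConfig()
    head = OwlViTClassPredictionHead(config)
    image_feats = torch.randn(2, 576, config.vision_config.hidden_size)
    query_embeds = torch.randn(2, 12, config.text_config.hidden_size)  # e.g. 12 part queries
    query_mask = torch.ones(2, 12, dtype=torch.long)                   # all queries valid
    pred_logits, image_class_embeds = head(image_feats, query_embeds, query_mask)
    assert pred_logits.shape == (2, 576, 12)
    assert image_class_embeds.shape == (2, 576, config.text_config.hidden_size)
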

class OwlViTPredictionHead(nn.Module):
    def __init__(self, config: OwlViTConfig, num_classes: int, finetuned: bool):
        super().__init__()

        out_dim = config.text_config.hidden_size
        self.query_dim = config.vision_config.hidden_size
        self.finetuned = finetuned
        self.num_classes = num_classes

        self.mlp_image = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
            nn.GELU(),
            nn.Linear(in_features=self.query_dim, out_features=self.query_dim),
            nn.GELU(),
            nn.Linear(in_features=self.query_dim, out_features=out_dim),
            nn.GELU(),
        )
        
        # if self.finetuned:
        #     self.cls_head = nn.Sequential(
        #         nn.GELU(),
        #         nn.Linear(in_features=out_dim, out_features=out_dim),
        #         nn.GELU()
        #     )

    def forward(self,
                image_embeds: torch.FloatTensor,
                query_embeds: torch.FloatTensor,
                topk_idxs: torch.LongTensor,
    ) -> Tuple[torch.FloatTensor]:

        # topk_idxs arrives as (batch_size, 1, n_parts); swap to (batch_size, n_parts, 1) and scatter it
        # into a one-hot over patches: one_hot (batch_size, n_parts, n_patches)
        topk_idxs = torch.swapaxes(topk_idxs, 1, 2)
        one_hot = torch.zeros(
            topk_idxs.shape[0], topk_idxs.shape[1], image_embeds.shape[1], device=image_embeds.device
        ).scatter_(2, topk_idxs, 1)
        batch_size, n_parts = one_hot.shape[0], one_hot.shape[1]

        # (batch_size, n_parts, n_patches, 1) * (batch_size, 1, n_patches, hidden_dim) -> (batch_size, n_parts, n_patches, hidden_dim).sum(dim=-2)
        # e.g. n_patches = 3600 and hidden_dim = 1024 for OwlViT-L/14 at 840x840 input.
        image_embeds = (one_hot.unsqueeze(-1) * image_embeds.unsqueeze(1)).sum(dim=-2)

        # image_embeds = self.dense0(image_embeds)            # (batch_size, n_patches, 1024) --> (.., .., 768)
        image_embeds = self.mlp_image(image_embeds.view(-1, image_embeds.shape[-1])).view(batch_size, n_parts, -1)
        query_embeds = query_embeds.view(batch_size, -1, query_embeds.shape[-1])

        # if self.finetuned:
        #     image_embeds = self.cls_head(image_embeds)
        #     query_embeds = query_embeds.view(batch_size, -1, query_embeds.shape[-1])
        
        # Normalize image and text features
        image_embeds = F.normalize(image_embeds, dim=-1) + 1e-6  # (batch_size, n_parts, 768)
        query_embeds = F.normalize(query_embeds, dim=-1) + 1e-6  # (batch_size, num_classes * n_parts, 768)

        # Shape: torch.Size([bs, num_boxes, num_classes * num_parts])
        image_text_logits = torch.einsum('bnd, bid -> bni', image_embeds, query_embeds)
        image_text_logits_reshaped = image_text_logits.view(-1, image_text_logits.shape[-1])

        # Shape: (bs, num_classes * num_parts, num_boxes) --> (bs, num_classes, num_parts, num_boxes)
        pred_logits = image_text_logits.swapaxes(axis0=1, axis1=2).view(batch_size, self.num_classes, n_parts, -1)
        pred_logits = torch.diagonal(pred_logits, dim1=-2, dim2=-1)     # --> torch.Size([bs, num_classes, 12])
        # DEBUG: adding a sigmoid here was tried (PEIJIE); it did not help.
        # pred_logits = pred_logits.sigmoid()
        # pred_logits = abs(pred_logits) # for debugging
        
        final_pred_logits = torch.sum(pred_logits, dim=-1)

        return (image_text_logits_reshaped, final_pred_logits, pred_logits)

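# ---------------------------------------------------------------------------
# Hedged sketch (batch of 2, 3 classes, 12 parts and 576 patches are assumed
# sizes) of the selection logic in OwlViTPredictionHead.forward: the top-1
# patch index per part is turned into a one-hot over patches to pool one
# embedding per part, and torch.diagonal over the (num_parts, num_parts)
# block pairs part box p with the descriptor of the same part p per class.
# ---------------------------------------------------------------------------
def _sketch_prediction_head_shapes():
    config = OwlViTConfig()
    num_classes, num_parts, num_patches = 3, 12, 576
    head = OwlViTPredictionHead(config, num_classes=num_classes, finetuned=False)
    image_feats = torch.randn(2, num_patches, config.vision_config.hidden_size)
    query_embeds = torch.randn(2, num_classes * num_parts, config.text_config.hidden_size)
    topk_idxs = torch.randint(0, num_patches, (2, 1, num_parts))  # as returned by torch.topk(..., k=1, dim=1)
    image_text_logits, final_logits, part_logits = head(image_feats, query_embeds, topk_idxs)
    assert image_text_logits.shape == (2 * num_parts, num_classes * num_parts)
    assert part_logits.shape == (2, num_classes, num_parts)
    assert final_logits.shape == (2, num_classes)
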

class OwlViTForClassification(nn.Module):
    config_class = OwlViTConfig

    def __init__(
        self,
        owlvit_det_model,
        num_classes,
        weight_dict,
        device,
        freeze_box_heads=False,
        train_box_heads_only=False,
        network_type=None,
        logits_from_teacher=False,
        finetuned: bool = False,
        custom_box_head: bool = False,
    ):
        super(OwlViTForClassification, self).__init__()

        self.config = owlvit_det_model.config
        self.num_classes = num_classes
        self.num_parts = 12
        self.device = device

        self.sigmoid = nn.Sigmoid()
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # Use CE loss for classification OR only train with contrastive loss
        self.network_type = network_type
        self.logits_from_teacher = logits_from_teacher

        # Initialize OwlViT model from the teacher model
        self.owlvit = copy.deepcopy(owlvit_det_model.owlvit)
        self.layer_norm = copy.deepcopy(owlvit_det_model.layer_norm)

        # For image-level classification
        self.cls_head = OwlViTPredictionHead(self.config, self.num_classes, finetuned=finetuned)

        # For box prediction
        if custom_box_head:
            self.box_head = OwlViTBoxPredictionHead(self.config)
        else:
            self.box_head = copy.deepcopy(owlvit_det_model.box_head)

        # For box-level classification
        # Why not simply: self.class_head = copy.deepcopy(owlvit_det_model.class_head)?
        
        self.class_head = OwlViTClassPredictionHead(self.config)
        self.class_head.dense0.load_state_dict(owlvit_det_model.class_head.dense0.state_dict())
        self.class_head.logit_shift.load_state_dict(owlvit_det_model.class_head.logit_shift.state_dict())
        self.class_head.logit_scale.load_state_dict(owlvit_det_model.class_head.logit_scale.state_dict())

        # OwlViT: set equal weights for the bounding box, gIoU and classification losses
        # self.matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=1, giou_cost=1)

        # Losses for the criterion in DETR/OwlViT
        self.weight_dict = weight_dict
        losses = ["cardinality"]
        losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
        losses += ["labels"] if weight_dict["loss_ce"] > 0 else []

        self.criterion = DetrLoss(
            matcher=None,
            num_parts=self.num_parts,
            eos_coef=0.1,   # Following facebook/detr-resnet-50
            losses=losses,
        )

        self.freeze_parameters(freeze_box_heads, train_box_heads_only)
        del owlvit_det_model

    def freeze_parameters(self, freeze_box_heads, train_box_heads_only):
        # OwlViT's text encoder is frozen by default
        for param in self.owlvit.text_model.parameters():
            param.requires_grad = False
        for param in self.owlvit.text_projection.parameters():
            param.requires_grad = False

        # SKIP finetuning box heads
        if freeze_box_heads:
            for param in self.box_head.parameters():
                param.requires_grad = False
            for param in self.class_head.parameters():
                param.requires_grad = False

        # SKIP finetuning vision encoder and MLP head for classification --> Adjust weights of box heads only
        if train_box_heads_only:
            for param in self.owlvit.parameters():
                param.requires_grad = False
            for param in self.layer_norm.parameters():
                param.requires_grad = False
            for param in self.cls_head.parameters():
                param.requires_grad = False

    def update_num_classes(self, num_classes):
        self.num_classes = num_classes
        self.cls_head.num_classes = num_classes

    def image_text_embedder(self,
        input_ids: torch.Tensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor]:

        # Encode text and image
        outputs = self.owlvit(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Get image embeddings
        last_hidden_state = outputs.vision_model_output[0]      # 0: last_hidden_state; 1: pooled_output
        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

        # Resize class token
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Merge image embedding with class tokens
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)
        text_embeds = outputs[-4]

        return (text_embeds, image_embeds, outputs)

    def image_embedder(
        self,
        pixel_values: torch.FloatTensor
    ) -> Tuple[torch.FloatTensor]:

        # Get OwlViTModel vision embeddings (same as CLIP)
        vision_outputs = self.owlvit.vision_model(pixel_values=pixel_values, return_dict=True)

        # Apply post_layernorm to last_hidden_state, return non-projected output
        last_hidden_state = vision_outputs[0]
        image_embeds = self.owlvit.vision_model.post_layernorm(last_hidden_state)

        # Resize class token
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Merge image embedding with class tokens
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)

        return (image_embeds, vision_outputs)

    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        # Computes normalized xy corner coordinates from feature_map.
        if not feature_map.ndim == 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        device = feature_map.device
        num_patches = feature_map.shape[1]

        box_coordinates = np.stack(np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (h, w, 2) -> (h*w, 2)
        box_coordinates = box_coordinates.reshape(box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2])
        box_coordinates = torch.from_numpy(box_coordinates).to(device)

        return box_coordinates

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        # The box center is biased to its position on the feature grid
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # Unnormalize xy
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
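        # e.g. a patch whose normalized center x is 0.25 receives a coordinate bias of roughly
        # log(0.25 / 0.75) ~= -1.10 (the logit of 0.25), so a zero output from the box head is
        # mapped back to the patch center by the sigmoid in `box_predictor`.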

        # The box size is biased to the patch size
        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        # Compute box bias
        box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
        return box_bias

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
        Returns:
            pred_boxes:
                List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
        """
        # Bounding box detection head [batch_size, num_boxes, 4].
        pred_boxes = self.box_head(image_feats)

        # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
        pred_boxes += self.compute_box_bias(feature_map)
        pred_boxes = self.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)

        return (pred_logits, image_class_embeds)

    def _get_text_query_mask(self, text_inputs, text_embeds, batch_size: int):
        # Embed images and text queries
        input_ids = text_inputs["input_ids"]

        # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
        max_text_queries = input_ids.shape[0] // batch_size
        text_embeds = text_embeds.reshape(batch_size, max_text_queries, text_embeds.shape[-1])

        # If first token is 0, then this is a padded query [batch_size, num_queries].
        input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
        query_mask = input_ids[..., 0] > 0
        return query_mask, text_embeds
        

    def forward(self, image_inputs, text_inputs_parts, text_embeds, targets: dict = None):
        # Store outputs for computing losses
        loss_dict = {}

        if not isinstance(image_inputs, torch.Tensor):
            feature_map, _ = self.image_embedder(pixel_values = image_inputs['pixel_values'])
        else:
            feature_map = image_inputs
        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

        if self.logits_from_teacher:
            teacher_boxes_logits = torch.stack([target["logits"] for target in targets], dim=0).to(self.device)
            topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)

        else:
            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)
            
            # Embed images and text queries
            query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)
            
            # Predict object classes [batch_size, num_patches, num_queries+1]
            pred_logits_parts, class_embeds = self.class_predictor(image_feats, text_embeds_parts, query_mask)

            # Predict object boxes
            pred_boxes = self.box_predictor(image_feats, feature_map)
            
            # Get the top-1 predictions
            scores = self.sigmoid(pred_logits_parts)
            topk_scores, topk_idxs = torch.topk(scores, k=1, dim=1)
            mapping_indices = [
                (selected_indices, torch.arange(self.num_parts, device=self.device))
                for selected_indices in topk_idxs.squeeze(1)
            ]

            # get the selected_indexs for mapping_indices
            selected_idxs = torch.stack([item[0].cpu() for item in mapping_indices])
            loss_dict["pred_boxes"] = torch.gather(pred_boxes.cpu(), 1, selected_idxs.unsqueeze(-1).expand(*selected_idxs.shape, 4))
            
            if targets is not None:
                # ----------------------------------------------------------------------------------------
                #   Computing box + class + symmetric losses for box selection
                # ----------------------------------------------------------------------------------------
                outputs_loss = {}
                outputs_loss["logits"] = pred_logits_parts
                outputs_loss["pred_boxes"] = pred_boxes

                # Compute box + class losses
                loss_dict = self.criterion(outputs_loss, targets, mapping_indices)

                # Compute symmetric loss to get rid of the teacher model
                logits_per_image = torch.softmax(pred_logits_parts, dim=1)
                logits_per_text = torch.softmax(pred_logits_parts, dim=-1)

                # For getting rid of the teacher model
                if self.weight_dict["loss_sym_box_label"] > 0:
                    # Teacher per-box logits are read from the targets, mirroring the logits_from_teacher branch above
                    teacher_boxes_logits = torch.stack([target["logits"] for target in targets], dim=0).to(self.device)
                    sym_loss_box_label = self.loss_symmetric(logits_per_image, logits_per_text, teacher_boxes_logits)
                    loss_dict["loss_sym_box_label"] = sym_loss_box_label
                # ----------------------------------------------------------------------------------------

        # Predict image-level classes (batch_size, num_patches, num_queries)
        image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)

        if self.weight_dict["loss_xclip"] > 0:
            targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
            if self.network_type == "classification":
                one_hot = torch.zeros_like(pred_logits).scatter(1, targets_cls, 1).to(self.device)
                cls_loss = self.ce_loss(pred_logits, one_hot)
                loss_dict["loss_xclip"] = cls_loss
            else:
                # TODO: Need a linear classifier for this approach
                # Compute symmetric loss for part-descriptor contrastive learning
                logits_per_image = torch.softmax(image_text_logits, dim=0)
                logits_per_text = torch.softmax(image_text_logits, dim=-1)
                sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
                loss_dict["loss_xclip"] = sym_loss

        return pred_logits, part_logits, loss_dict

    def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
        # text/image logits (batch_size*num_boxes, num_classes*num_descs): The logits that softmax over text descriptors or boxes
        # targets (batch_size, 1): The ground truth label of box-text pair for classification OR
        # targets (batch_size, all_boxes, num_parts): The ground truth label of box-text pair for box selection
        # box_labels (batch_size, num_boxes), 0 for no box, 1 for box

        assert text_logits.shape == image_logits.shape

        # For image classification
        if image_logits.shape != targets.shape:
            batch_size = targets.shape[0]

            # get the matching labels (bs * 12, num_classes * num_parts)
            default_box_labels = torch.kron(torch.ones(batch_size, self.num_classes), torch.eye(self.num_parts)).to(self.device)
            if box_labels is None:
                box_labels = default_box_labels.clone()
            else:
                # (batch_size, num_boxes) -> (bs * num_boxes, num_classes * num_parts)
                box_labels = box_labels.view(-1, 1) * default_box_labels

            # Create one-hot encoding of targets; matching_labels shape: (bs * 12, num_classes * num_parts)
            target_one_hot = torch.zeros(batch_size, self.num_classes).to(self.device).scatter(1, targets.view(-1, 1), 1)
            target_one_hot = torch.kron(target_one_hot, torch.ones(self.num_parts, self.num_parts).to(self.device))

            matching_labels = target_one_hot * box_labels
        else:
            # For box selection: matching_labels shape: (bs, 576, num_parts)
            values, indices = torch.max(targets, dim=1)
            matching_labels = torch.zeros_like(targets).scatter(1, indices.unsqueeze(1), 1)

        loss_i = F.binary_cross_entropy_with_logits(image_logits, matching_labels, reduction='mean')
        loss_t = F.binary_cross_entropy_with_logits(text_logits, matching_labels, reduction='mean')
        sym_loss = (loss_i + loss_t).mean()

        return sym_loss
    
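# ---------------------------------------------------------------------------
# Hedged sketch (tiny, made-up sizes) of how `loss_symmetric` builds its
# matching labels for image-level classification: torch.kron tiles a
# num_parts x num_parts identity per class so that, for each image, only the
# (part box p, descriptor p) pairs of the ground-truth class are positives.
# ---------------------------------------------------------------------------
def _sketch_matching_labels(num_classes: int = 3, num_parts: int = 2):
    batch_size = 2
    targets = torch.tensor([[1], [0]])  # ground-truth class index per image
    default_box_labels = torch.kron(torch.ones(batch_size, num_classes), torch.eye(num_parts))
    target_one_hot = torch.zeros(batch_size, num_classes).scatter(1, targets, 1)
    target_one_hot = torch.kron(target_one_hot, torch.ones(num_parts, num_parts))
    matching_labels = target_one_hot * default_box_labels  # (batch_size * num_parts, num_classes * num_parts)
    assert matching_labels.shape == (batch_size * num_parts, num_classes * num_parts)
    return matching_labels
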
class DetrLoss(nn.Module):
    """
    This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
    we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
    of matched ground-truth / prediction (supervise class and box).

    A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
    parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
    the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
    be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
    (`max_obj_id` + 1). For more details on this, check the following discussion
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"


    Args:
        matcher (`DetrHungarianMatcher`):
            Module able to compute a matching between targets and proposals.
        num_parts (`int`):
            Number of part queries per image (takes the place of DETR's `num_classes`); no separate no-object
            category is added here.
        eos_coef (`float`):
            Relative classification weight applied to the no-object category.
        losses (`List[str]`):
            List of all the losses to be applied. See `get_loss` for a list of all available losses.
    """

    def __init__(self, matcher, num_parts, eos_coef, losses):
        super().__init__()
        self.matcher = matcher
        self.num_parts = num_parts
        self.eos_coef = eos_coef
        self.losses = losses

        # empty_weight = torch.ones(self.num_parts + 1)
        empty_weight = torch.ones(self.num_parts)
        empty_weight[-1] = self.eos_coef
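        # NOTE: unlike DETR there is no extra no-object class here, so eos_coef effectively
        # down-weights the last part class in the cross-entropy of `loss_labels`.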
        self.register_buffer("empty_weight", empty_weight)

    # removed logging parameter, which was part of the original implementation
    def loss_labels(self, outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes]
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
        # target_classes[idx] = target_classes_o

        source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
        target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"loss_ce": loss_ce}

        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """
        Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        """
        logits = outputs["logits"]
        device = logits.device
        target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
        card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
        are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")

        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        losses = {}

        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
        losses["loss_giou"] = loss_giou.sum() / num_boxes

        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.

        Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
        """
        if "pred_masks" not in outputs:
            raise KeyError("No predicted masks found in outputs")

        source_idx = self._get_source_permutation_idx(indices)
        target_idx = self._get_target_permutation_idx(indices)
        source_masks = outputs["pred_masks"]
        source_masks = source_masks[source_idx]
        masks = [t["masks"] for t in targets]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(source_masks)
        target_masks = target_masks[target_idx]

        # upsample predictions to the target size
        source_masks = nn.functional.interpolate(
            source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        source_masks = source_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(source_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
        }
        return losses

    def _get_source_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
        source_idx = torch.cat([source for (source, _) in indices])
        return batch_idx, source_idx

    def _get_target_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
        target_idx = torch.cat([target for (_, target) in indices])
        return batch_idx, target_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        if loss not in loss_map:
            raise ValueError(f"Loss {loss} not supported")
        return loss_map[loss](outputs, targets, indices, num_boxes)

    def forward(self, outputs, targets, indices):
        """
        This performs the loss computation.

        Args:
             outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
             targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depend on the
                losses applied, see each loss' doc.
             indices (`List[Tuple[torch.Tensor, torch.Tensor]]`):
                Pre-computed (prediction, target) index pairs, one tuple per image. Here they come from the argmax
                box selection in `OwlViTForClassification.forward` rather than from a Hungarian matcher.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
        # Retrieve the matching between the outputs of the last layer and the targets
        # indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        # (Niels): comment out function below, distributed training to be added
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_boxes)
        # (Niels) in original implementation, num_boxes is divided by get_world_size()
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "auxiliary_outputs" in outputs:
            for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
                # indices = self.matcher(auxiliary_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
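
# ---------------------------------------------------------------------------
# Hedged end-to-end sketch of wiring the classifier around a pretrained OwlViT
# detector. The checkpoint name, class count, and weight_dict values below are
# illustrative assumptions, not the settings used for the original experiments.
# ---------------------------------------------------------------------------
def _sketch_build_classifier(device: str = "cpu"):
    from transformers import OwlViTForObjectDetection

    teacher = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
    weight_dict = {
        "loss_ce": 0.0,             # box-level classification loss
        "loss_bbox": 0.0,           # L1 box regression loss
        "loss_giou": 0.0,           # generalized IoU loss
        "loss_sym_box_label": 0.0,  # symmetric loss against teacher box logits
        "loss_xclip": 1.0,          # image-level classification loss
    }
    model = OwlViTForClassification(
        owlvit_det_model=teacher,
        num_classes=200,
        weight_dict=weight_dict,
        device=device,
    )
    return model.to(device)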