File size: 1,896 Bytes
71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch import nn
import torch.nn.functional as F
# NOTE(review): presumably the upper bound for the CW trade-off constant ``c``
# used by an attack loop elsewhere; it is not referenced by CWExtensionLoss
# in this file section (that class uses its own literal) — confirm usage.
CARLINI_COEFF_UPPER = 1e10

class CWExtensionLoss(nn.Module):
    """Carlini-Wagner (CW) margin loss, extended to ordered top-k targets.

    With 1-D ``attack_targets`` (one class index per sample) this is the
    standard CW loss ``max(max_{j != t} z_j - z_t + confidence, 0)``.

    With 2-D ``attack_targets`` of shape ``(batch, k)`` the loss averages k
    CW-style terms. For every prefix ``targets[:, :i]`` (i = 1..k), the largest
    logit *outside* the prefix is compared with the smallest logit *inside* it.
    All terms reach zero only when each prefix of targets ranks above every
    other class. The 1-D case is exactly the k = 1 instance of this.
    """

    def __init__(self, confidence=0):
        super().__init__()
        # Margin added to every term: a term is zero only once the relevant
        # target logit exceeds the competing logit by at least this amount.
        self.confidence = confidence

    def precompute(self, *args, **kwargs):
        """No precomputation is needed for this loss; always returns ``{}``."""
        return {}

    def forward(self, logits_pred, attack_targets, **kwargs):
        """Compute the per-sample (extended) CW loss.

        Args:
            logits_pred: ``(batch, num_classes)`` tensor of logits.
            attack_targets: integer class indices, either ``(batch,)`` for the
                standard CW loss or ``(batch, k)`` for the ordered top-k loss.
                Moved to ``logits_pred``'s device and cast to int64.
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            ``(batch,)`` tensor of non-negative losses, in ``logits_pred``'s
            dtype.

        Raises:
            ValueError: if ``attack_targets`` is not 1-D or 2-D, or has no
                target columns.
        """
        if attack_targets.dim() not in (1, 2):
            raise ValueError(
                "attack_targets must be 1-D (batch,) or 2-D (batch, k), "
                f"got shape {tuple(attack_targets.shape)}"
            )
        targets = attack_targets.to(device=logits_pred.device, dtype=torch.long)
        if targets.dim() == 1:
            # The original CW loss is the k = 1 case of the prefix formulation.
            targets = targets.unsqueeze(1)

        k = targets.shape[1]
        if k == 0:
            # Averaging over zero terms would otherwise produce 0/0 = NaN.
            raise ValueError("attack_targets must contain at least one target per sample")

        num_classes = logits_pred.shape[1]
        taken = torch.zeros_like(logits_pred, dtype=torch.bool)
        min_target = None
        loss = 0
        for i in range(k):
            column = targets[:, i]
            # Grow the mask out-of-place: masked_fill below saves the mask for
            # its backward pass, so mutating it in place would break autograd.
            taken = taken | F.one_hot(column, num_classes).bool()

            target_logit = logits_pred.gather(1, column.unsqueeze(1)).squeeze(1)
            min_target = target_logit if min_target is None else torch.min(min_target, target_logit)

            # -inf (rather than subtracting a finite constant) guarantees that
            # prefix classes never win the max, whatever the logit scale.
            other = logits_pred.masked_fill(taken, float("-inf")).max(dim=1)[0]
            loss = loss + torch.clamp(other - min_target + self.confidence, min=0.0)

        return loss / k