import torch
from torch import nn
from torchvision.ops import MLP

from .. import losses

class InstanceGuide(nn.Module):
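    """
    Wraps a frozen classifier and searches, per batch, for an additive input
    perturbation that places `attack_targets` in the model's top-K predictions.

    Each forward() call alternates two phases:
      1. A small MLP proposes a feature-space offset whose resulting logits
         place `attack_targets` in the top-K ranks (trained against `loss_fn`).
      2. The input perturbation is optimized so the model's logits on the
         perturbed input match the best successful target logits found so far.
    """
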
    def __init__(self, model: nn.Module, optimizer=torch.optim.AdamW, loss_fn=losses.CWExtensionLoss) -> None:
        super().__init__()
        
        self.guided = True
        self.model = model

        # Freeze the wrapped model; only the per-call MLP and the input
        # perturbation are optimized inside forward().
        for p in self.model.parameters():
            p.requires_grad_(False)

        self.loss = loss_fn()
        self.optimizer = optimizer

        self.epochs = 30
        self.mlp_iterations = 5
        self.perturbation_iterations = 5

    def surject_perturbation(self, x):
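        """
        Hook applied to the raw perturbation before it is added to the input.
        The base class is the identity; subclasses may override it, e.g. to
        project the perturbation onto a norm ball or a valid pixel range.
        """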
        return x

    def forward(self, x, attack_targets):
        """
        x: [B, channels, H, W]
        attack_targets: [B, K]
        """

        B = x.shape[0]
        K = attack_targets.shape[-1]
        C = self.model.num_classes()

        with torch.no_grad():
            pred_clean, feats = self.model(x, return_features=True)

        # The clean predictions are treated as ground truth: the dataset is
        # constrained so that the model classifies the clean inputs correctly.
        attack_ground_truth = pred_clean.argmax(dim=-1) # [B]

        # 4-layer MLP (same width as the feature head) that predicts an offset
        # in feature space.
        mlp = MLP(self.model.head_features(),
                  [self.model.head_features()] * 4,
                  activation_layer=nn.GELU, inplace=None).to(x.device)

        # Small random additive perturbation, optimized directly in input space.
        x_perturbation = nn.Parameter(torch.randn(x.shape, device=x.device) * 1e-3)
        
        perturbation_optimizer = self.optimizer([x_perturbation], lr=1e-1)

        mlp_optimizer = self.optimizer(mlp.parameters(), lr=1e-3)

        # Best target logits/features found so far, initialized from the clean
        # outputs and updated whenever the MLP produces a successful ranking.
        logits_target_best = pred_clean
        feats_target_best = feats

        with torch.enable_grad():
            for i in range(self.epochs):
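                # Phase 1: train the MLP to predict a feature offset whose
                # logits place `attack_targets` in the top-K ranks.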
                for _ in range(self.mlp_iterations):
                    feature_offset = mlp(feats)
                    feats_target_pred = feature_offset + feats
                    logits_target_pred = self.model.head(feats_target_pred)
                    pred_classes = logits_target_pred.argsort(dim=-1, descending=True) # [B, C]
                    attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1) # [B]

                    with torch.no_grad():
                        logits_target_best = torch.where(
                            attack_successful[:, None].expand(-1, C),
                            logits_target_pred,
                            logits_target_best
                        )

                        feats_target_best = torch.where(
                            attack_successful[:, None].expand(-1, self.model.head_features()),
                            feats_target_pred,
                            feats_target_best
                        )

                    mlp_loss = self.loss(logits_pred=logits_target_pred, 
                                         prediction_feats=feats_target_pred,
                                         attack_targets=attack_targets, 
                                         attack_ground_truth=attack_ground_truth, 
                                         model=self.model)
                    # Attack loss per sample plus an L2 penalty on the feature
                    # offset, reduced to a scalar so backward() can be called.
                    mlp_loss = mlp_loss.mean() + feature_offset.view(B, -1).norm(p=2, dim=-1).mean()

                    mlp_optimizer.zero_grad()
                    mlp_loss.backward()
                    mlp_optimizer.step()

                feats_target_best = feats_target_best.detach()

                # Phase 2: optimize the input perturbation so the model's logits
                # match the best successful target logits found so far.
                for _ in range(self.perturbation_iterations):
                    x_perturbed = x + self.surject_perturbation(x_perturbation)
                    prediction, perturbed_feats = self.model(x_perturbed, return_features=True)
                    pred_classes = prediction.argsort(dim=-1, descending=True) # [B, C]
                    attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1) # [B]

                    perturbation_loss = (prediction - logits_target_best).view(B, -1).norm(dim=-1).mean()

                    perturbation_optimizer.zero_grad()
                    perturbation_loss.backward()
                    perturbation_optimizer.step()
            
        return prediction
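

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the package API).
# It assumes a classifier exposing the interface InstanceGuide relies on:
# num_classes(), head_features(), a `head` module, and
# forward(x, return_features=True) -> (logits, features).
# ToyClassifier and ToyTargetLoss are hypothetical stand-ins; in the repository
# the default loss is losses.CWExtensionLoss. Because this module uses a
# relative import, run it as part of its package (e.g. with `python -m ...`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class ToyClassifier(nn.Module):
        def __init__(self, n_classes: int = 10, feat_dim: int = 32) -> None:
            super().__init__()
            self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feat_dim))
            self.head = nn.Linear(feat_dim, n_classes)
            self._n_classes = n_classes
            self._feat_dim = feat_dim

        def num_classes(self) -> int:
            return self._n_classes

        def head_features(self) -> int:
            return self._feat_dim

        def forward(self, x, return_features: bool = False):
            feats = self.backbone(x)
            logits = self.head(feats)
            return (logits, feats) if return_features else logits

    class ToyTargetLoss:
        """Stand-in loss with the call signature InstanceGuide expects."""

        def __call__(self, logits_pred, prediction_feats, attack_targets,
                     attack_ground_truth, model):
            # Per-sample cross-entropy towards the first attack target.
            return nn.functional.cross_entropy(
                logits_pred, attack_targets[:, 0], reduction="none")

    guide = InstanceGuide(ToyClassifier(), loss_fn=ToyTargetLoss)
    images = torch.rand(4, 3, 8, 8)
    targets = torch.randint(0, 10, (4, 2))    # desired top-2 classes per image
    adversarial_logits = guide(images, targets)
    print(adversarial_logits.shape)           # torch.Size([4, 10])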