Spaces:
Sleeping
Sleeping
File size: 4,265 Bytes
71f183c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
from torch import nn
from torchvision.ops import MLP
from .. import losses
class InstanceGuide(nn.Module):
    """Instance-level attack guide.

    Wraps a frozen classifier and, per input batch, (a) trains a small MLP to
    propose target features/logits whose top-K classes match the requested
    attack targets, then (b) optimizes an additive input perturbation so the
    model's output matches the best target logits found so far.
    """

    def __init__(self, model: nn.Module, optimizer=torch.optim.AdamW,
                 loss_fn=losses.CWExtensionLoss) -> None:
        """
        model: classifier exposing ``num_classes()``, ``head_features()``,
            ``head`` and ``forward(x, return_features=True)``.
        optimizer: optimizer *class* (not instance); instantiated twice per
            ``forward`` call (once for the MLP, once for the perturbation).
        loss_fn: loss *class*; instantiated once here.
        """
        super().__init__()
        self.guided = True
        self.model = model
        # The attacked model is frozen: only the per-batch MLP and the
        # input perturbation receive gradients.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.loss = loss_fn()
        self.optimizer = optimizer
        self.epochs = 30
        self.mlp_iterations = 5
        self.perturbation_iterations = 5

    def surject_perturbation(self, x):
        """Identity projection; subclasses may override to constrain the
        perturbation (e.g. clamp to an epsilon ball)."""
        return x

    def forward(self, x, attack_targets):
        """
        x: [B, channels, H, W]
        attack_targets: [B, K] class indices, ordered by desired rank
        Returns the model's logits on the final perturbed input.
        """
        B = x.shape[0]
        K = attack_targets.shape[-1]
        C = self.model.num_classes()
        with torch.no_grad():
            pred_clean, feats = self.model(x, return_features=True)
            # We are assuming the clean predictions are ground truth since we
            # make that constraint on the dataset side.
            attack_ground_truth = pred_clean.argmax(dim=-1)  # [B]
        # Per-batch MLP predicting a feature-space offset toward the attack.
        mlp = MLP(self.model.head_features(),
                  [self.model.head_features()] * 3 + [self.model.head_features()],
                  activation_layer=nn.GELU, inplace=None).to(x.device)
        x_perturbation = nn.Parameter(torch.randn(x.shape, device=x.device) * 1e-3)
        perturbation_optimizer = self.optimizer([x_perturbation], lr=1e-1)
        mlp_optimizer = self.optimizer(mlp.parameters(), lr=1e-3)
        # Best (attack-successful) targets seen so far; start from clean outputs.
        logits_target_best = pred_clean
        feats_target_best = feats
        with torch.enable_grad():
            for _ in range(self.epochs):
                for _ in range(self.mlp_iterations):
                    # NOTE(review): a torch.cuda.synchronize() here was removed —
                    # it crashed on CPU-only hosts and added no correctness.
                    feature_offset = mlp(feats)
                    feats_target_pred = feature_offset + feats
                    logits_target_pred = self.model.head(feats_target_pred)
                    pred_classes = logits_target_pred.argsort(dim=-1, descending=True)  # [B, C]
                    # Success: the top-K ranking matches the targets exactly.
                    attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1)  # [B]
                    with torch.no_grad():
                        # Keep the latest successful targets per batch element.
                        logits_target_best = torch.where(
                            attack_successful[:, None].expand(-1, C),
                            logits_target_pred,
                            logits_target_best)
                        feats_target_best = torch.where(
                            attack_successful[:, None].expand(-1, self.model.head_features()),
                            feats_target_pred,
                            feats_target_best)
                    mlp_loss = self.loss(logits_pred=logits_target_pred,
                                         prediction_feats=feats_target_pred,
                                         attack_targets=attack_targets,
                                         attack_ground_truth=attack_ground_truth,
                                         model=self.model)
                    # BUG FIX: the offset-norm regularizer is a [B] vector; adding
                    # it to the scalar mean loss made mlp_loss non-scalar, so
                    # .backward() raised for B > 1. Reduce with .mean() (weight 1).
                    offset_penalty = feature_offset.view(B, -1).norm(dim=-1, p=2).mean()
                    mlp_loss = mlp_loss.mean() + offset_penalty
                    mlp_optimizer.zero_grad()
                    mlp_loss.backward()
                    mlp_optimizer.step()
                feats_target_best = feats_target_best.detach()
                for _ in range(self.perturbation_iterations):
                    x_perturbed = x + self.surject_perturbation(x_perturbation)
                    prediction, _perturbed_feats = self.model(x_perturbed, return_features=True)
                    # Pull the perturbed logits toward the best successful targets.
                    perturbation_loss = (prediction - logits_target_best).view(B, -1).norm(dim=-1).mean()
                    perturbation_optimizer.zero_grad()
                    perturbation_loss.backward()
                    perturbation_optimizer.step()
        return prediction