import torch
from torch import nn
from torchvision.ops import MLP

from .. import losses


class InstanceGuide(nn.Module):
    """Per-instance guided attack on a frozen classifier.

    For each input batch, a fresh MLP learns an additive offset on the model's
    clean features so that the classification head ranks ``attack_targets``
    as its top-K classes (in order). The best such target logits found so far
    are then used as a regression target for an additive input perturbation.
    """

    def __init__(self, model: nn.Module, optimizer=torch.optim.AdamW, loss_fn=losses.CWExtensionLoss) -> None:
        """
        Args:
            model: classifier under attack; all of its parameters are frozen
                here. From its use in ``forward`` it is expected to provide
                ``num_classes()``, ``head_features()``, ``head(feats)`` and a
                forward accepting ``return_features=True`` that returns
                ``(logits, feats)``.
            optimizer: optimizer *class*; it is instantiated inside
                ``forward`` for both the MLP and the input perturbation.
            loss_fn: loss *class*, instantiated here with no arguments.
        """
        super().__init__()
        self.guided = True
        self.model = model
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.loss = loss_fn()
        self.optimizer = optimizer
        self.epochs = 30
        self.mlp_iterations = 5
        self.perturbation_iterations = 5
        # Formerly hard-coded in ``forward``; exposed so they can be tuned.
        self.perturbation_lr = 1e-1
        self.mlp_lr = 1e-3
        self.offset_reg_weight = 1.0

    def surject_perturbation(self, x):
        """Map the raw perturbation parameter to the perturbation added to the input.

        Identity here. Presumably an override point for constraining or
        projecting the perturbation — NOTE(review): confirm with subclasses.
        """
        return x

    @staticmethod
    def _attack_successful(logits, attack_targets):
        """Return a ``[B]`` bool mask: the K highest logits are exactly ``attack_targets``, in order."""
        k = attack_targets.shape[-1]
        ranked = logits.argsort(dim=-1, descending=True)  # [B, C]
        return (ranked[:, :k] == attack_targets).all(dim=-1)

    def forward(self, x, attack_targets):
        """Optimise an additive perturbation of ``x`` toward the requested ranking.

        Args:
            x: input batch, ``[B, channels, H, W]``.
            attack_targets: ``[B, K]`` class indices. An attack on a sample
                succeeds when the model's K highest-scoring classes equal the
                sample's row, in the same order.

        Returns:
            The model's logits on the perturbed input from the last
            perturbation step (computed before that step's optimizer update).
            If no perturbation step runs (``epochs`` or
            ``perturbation_iterations`` is 0), the clean logits are returned.
        """
        B = x.shape[0]
        C = self.model.num_classes()
        head_features = self.model.head_features()

        with torch.no_grad():
            pred_clean, feats = self.model(x, return_features=True)
        # The clean predictions are treated as ground truth because the
        # dataset side enforces that constraint.
        attack_ground_truth = pred_clean.argmax(dim=-1)  # [B]

        # Fresh MLP per call: it learns a per-batch residual offset on the features.
        mlp = MLP(
            head_features,
            [head_features] * 4,
            activation_layer=nn.GELU,
            inplace=None,
        ).to(x.device)
        x_perturbation = nn.Parameter(torch.randn(x.shape, device=x.device) * 1e-3)
        perturbation_optimizer = self.optimizer([x_perturbation], lr=self.perturbation_lr)
        mlp_optimizer = self.optimizer(mlp.parameters(), lr=self.mlp_lr)

        # Best logits reached by a successful feature-space attack so far;
        # starts at the clean logits and only changes for successful samples.
        logits_target_best = pred_clean
        prediction = pred_clean

        with torch.enable_grad():
            for _ in range(self.epochs):
                # Stage 1: search feature space for logits that realise the target ranking.
                for _ in range(self.mlp_iterations):
                    feature_offset = mlp(feats)
                    feats_target_pred = feature_offset + feats
                    logits_target_pred = self.model.head(feats_target_pred)

                    attack_successful = self._attack_successful(logits_target_pred, attack_targets)
                    with torch.no_grad():
                        logits_target_best = torch.where(
                            attack_successful[:, None].expand(-1, C),
                            logits_target_pred,
                            logits_target_best,
                        )

                    mlp_loss = self.loss(
                        logits_pred=logits_target_pred,
                        prediction_feats=feats_target_pred,
                        attack_targets=attack_targets,
                        attack_ground_truth=attack_ground_truth,
                        model=self.model,
                    )
                    # Penalise large feature offsets. Reduce the per-sample norm to
                    # a scalar; a [B]-shaped loss makes backward() fail for B > 1.
                    offset_penalty = feature_offset.reshape(B, -1).norm(dim=-1, p=2).mean()
                    mlp_loss = mlp_loss.mean() + self.offset_reg_weight * offset_penalty

                    mlp_optimizer.zero_grad()
                    mlp_loss.backward()
                    mlp_optimizer.step()

                # Stage 2: move the input perturbation toward the best target logits.
                for _ in range(self.perturbation_iterations):
                    x_perturbed = x + self.surject_perturbation(x_perturbation)
                    prediction, _ = self.model(x_perturbed, return_features=True)
                    perturbation_loss = (prediction - logits_target_best).reshape(B, -1).norm(dim=-1).mean()

                    perturbation_optimizer.zero_grad()
                    perturbation_loss.backward()
                    perturbation_optimizer.step()

        return prediction