thomaspaniagua
QuadAttack release
71f183c
raw
history blame
No virus
1.9 kB
import torch
from torch import nn
import torch.nn.functional as F
# Module-level constant; nothing in the visible part of this file reads it.
# NOTE(review): the name suggests an upper bound for a Carlini-Wagner
# trade-off coefficient used by callers elsewhere -- confirm before relying
# on that interpretation.
CARLINI_COEFF_UPPER = 1e10
class CWExtensionLoss(nn.Module):
    """Carlini-Wagner style hinge loss, extended to top-k target lists.

    ``forward`` treats ``logits_pred`` as ``(batch, num_classes)`` and
    accepts two layouts for ``attack_targets``:

    * 1-D, one class index per row: the per-row loss is
      ``max(0, best_other - target_logit + confidence)`` where ``best_other``
      is the largest logit among the non-target classes.
    * 2-D, ``k`` class indices per row: the targets are added one column at
      a time.  Before each addition, and once more after the last one, a
      hinge term ``max(0, best_unselected - weakest_selected + confidence)``
      is accumulated, where ``weakest_selected`` is the smallest logit among
      the targets added so far.  The sum is divided by ``k``.

    Args:
        confidence: margin added inside every hinge term.
    """

    # Value written over already-selected classes so that they can never win
    # the per-row max taken over the remaining classes.
    _MASKED_LOGIT = -10000.0

    # Starting value of the "weakest selected logit" before any target has
    # been selected; large enough that the first hinge term clamps to zero
    # for ordinary logits.
    _NO_TARGET_YET = 1e10

    def __init__(self, confidence=0):
        super().__init__()
        self.confidence = confidence

    def precompute(self, *args, **kwargs):
        """Return an empty dict; all arguments are ignored."""
        return {}

    def _best_unselected(self, logits, selected):
        """Return the per-row max of ``logits`` over classes not in ``selected``.

        ``selected`` is a boolean mask shaped like ``logits``.  Overwriting
        the selected entries (instead of multiplying by a 0/1 mask) avoids
        ``0 * inf = nan`` when a selected logit is not finite.
        """
        return logits.masked_fill(selected, self._MASKED_LOGIT).max(dim=1)[0]

    def forward(self, logits_pred, attack_targets, **kwargs):
        """Compute the per-row attack loss.

        Args:
            logits_pred: tensor indexed as ``(batch, num_classes)``.
            attack_targets: integer class indices, either 1-D (one target
                per row) or 2-D (``k`` targets per row).
            **kwargs: accepted and ignored.

        Returns:
            A 1-D tensor with one loss value per row of ``logits_pred``.

        Raises:
            ValueError: if ``attack_targets`` is 2-D with zero columns (the
                average over targets would otherwise be a division by zero).
        """
        # The previous arithmetic multiplied the logits by a float32 mask,
        # which promoted lower-precision logits to float32.  Keep that
        # result dtype explicitly.
        logits = logits_pred.to(
            torch.promote_types(logits_pred.dtype, torch.float32)
        )
        batch_size, num_classes = logits.shape[0], logits.shape[1]
        # Row index for (row, class) lookups; built once, on the logits'
        # device, instead of on the CPU inside the loop.
        rows = torch.arange(batch_size, device=logits.device)

        # Original CW attack loss: a single target per row.
        if attack_targets.dim() == 1:
            selected = F.one_hot(attack_targets, num_classes).bool()
            real = logits[rows, attack_targets]
            other = self._best_unselected(logits, selected)
            return torch.clamp(other - real + self.confidence, min=0.)

        # Extended CW loss for top-k attack tasks.
        num_targets = attack_targets.shape[1]
        if num_targets == 0:
            raise ValueError(
                "attack_targets must contain at least one target per row"
            )

        selected = torch.zeros(
            batch_size, num_classes, dtype=torch.bool, device=logits.device
        )
        weakest_selected = torch.full(
            (batch_size,),
            self._NO_TARGET_YET,
            dtype=logits.dtype,
            device=logits.device,
        )
        loss_cw_topk = 0
        for i in range(num_targets):
            other = self._best_unselected(logits, selected)
            loss_cw_topk = loss_cw_topk + torch.clamp(
                other - weakest_selected + self.confidence, min=0.
            )
            target_col = attack_targets[:, i]
            # masked_fill keeps its mask for the backward pass, so update a
            # copy rather than mutating the mask it has already seen.
            selected = selected.clone()
            selected[rows, target_col] = True
            weakest_selected = torch.min(
                logits[rows, target_col], weakest_selected
            )

        # Final term: every target is selected; compare against the rest.
        other = self._best_unselected(logits, selected)
        loss_cw_topk = loss_cw_topk + torch.clamp(
            other - weakest_selected + self.confidence, min=0.
        )
        return loss_cw_topk / num_targets