# NOTE: this header previously contained page-extraction residue
# ("Spaces: / Sleeping") and trailing "| |" artifacts; cleaned to valid Python.
import torch
from torch import nn
import torch.nn.functional as F
import cvxpy as cp
import qpth
from . import _qp_solver_patch
def solve_qp(Q, P, G, H):
    """Solve a batch of inequality-constrained QPs with qpth.

    For each batch item, minimizes 0.5 * z^T Q z + P^T z subject to
    G z <= H. No equality constraints are used (empty tensors passed).

    Works around a solver instability at batch size 1 (likely an internal
    .squeeze() collapsing the batch dimension) by duplicating the problem
    and discarding the second solution.
    """
    batch = Q.shape[0]
    single = batch == 1
    if single:
        # Artificially solve the same problem twice so the batch dim is 2.
        Q = Q.expand(2, -1, -1)
        P = P.expand(2, -1)
        G = G.expand(2, -1, -1)
        H = H.expand(2, -1)
    empty = torch.empty(0, device=Q.device)
    solver = qpth.qp.QPFunction(verbose=-1, eps=1e-2, check_Q_spd=False)
    solution = solver(Q, P, G, H, empty, empty)
    # Keep only the genuine solution when we duplicated the problem above.
    return solution[:1] if single else solution
class CVXProjLoss(nn.Module):
    """Feature-space projection loss for targeted attacks.

    Per batch item, builds a linear constraint matrix D over the logits
    encoding a desired ordering of the attack-target classes (each ranked
    target's logit exceeds the next by `margin`, and every non-target
    class stays below the last target), projects the current features
    onto the polytope of features satisfying those constraints by solving
    a QP, and returns the squared distance to that projection.
    """

    def __init__(self, confidence=0):
        super().__init__()
        # NOTE(review): `confidence` is stored but never read in this file —
        # presumably consumed elsewhere; confirm before removing.
        self.confidence = confidence

    def precompute(self, attack_targets, gt_labels, config):
        # Prediction-independent constants, forwarded to forward() as
        # keyword arguments. Only the projection margin is needed here.
        return {
            "margin": config.cvx_proj_margin
        }

    def forward(self, logits_pred, feats_pred, feats_pred_0, attack_targets, model, margin, **kwargs):
        # NOTE(review): `feats_pred_0` (initial features) is accepted but
        # unused here; the projection anchors on `feats_pred` — confirm intent.
        device = logits_pred.device
        # Linear head: logits = feats @ head_W.T + head_bias
        # (assumes head_W is [C, F] and head_bias is [C] — TODO confirm)
        head_W, head_bias = model.head_matrices()
        num_feats = head_W.shape[1]   # NOTE(review): unused local
        num_classes = head_W.shape[0]
        K = attack_targets.shape[-1]  # number of ranked attack targets per item
        B = logits_pred.shape[0]      # batch size
        # Start with all classes should be less than smallest attack target
        D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1)  # [B, C, C]
        # Each row c starts as: logit[last_target] - logit[c] (>= margin later).
        attack_targets_write = attack_targets[:, -1][:, None, None].expand(-1, D.shape[1], -1)
        D.scatter_(dim=2, index=attack_targets_write, src=torch.ones(attack_targets_write.shape, device=device))
        # Clear out the constraint row for each item in the attack targets
        # (targets get chain constraints below instead of the default one).
        attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
        D.scatter_(dim=1, index=attack_targets_clear, src=torch.zeros(attack_targets_clear.shape, device=device))
        # Chain constraints between consecutive targets:
        # row[target_{i+1}] encodes logit[target_i] - logit[target_{i+1}] >= margin.
        batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
        attack_targets_pos = attack_targets[:, :-1]  # [B, K-1]
        attack_targets_neg = attack_targets[:, 1:]   # [B, K-1]
        attack_targets_neg_inds = torch.stack((
            batch_inds,
            attack_targets_neg,
            attack_targets_neg
        ), dim=0)  # [3, B, K - 1]
        attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
        D[attack_targets_neg_inds[0], attack_targets_neg_inds[1], attack_targets_neg_inds[2]] = -1
        attack_targets_pos_inds = torch.stack((
            batch_inds,
            attack_targets_neg,
            attack_targets_pos
        ), dim=0)  # [3, B, K - 1]
        D[attack_targets_pos_inds[0], attack_targets_pos_inds[1], attack_targets_pos_inds[2]] = 1
        # Express the logit constraints in feature space:
        # D @ (A z + b) >= margin  ->  (-D A) z <= -(margin - D b)
        A = head_W
        b = head_bias
        Q = 2*torch.eye(feats_pred.shape[1], device=device)[None].expand(B, -1, -1)
        # We want the solution features to be as close as possible
        # to the current features but also head on the direction of
        # the smallest possible perturbation from the initial predicted
        # features
        anchor_feats = feats_pred
        P = -2*anchor_feats.expand(B, -1)
        G = -D@A
        H = -(margin - D @ b)
        # Constraints are indexed by smaller logit
        # First attack target isn't smaller than any logit, so its
        # constraint index is redundant, but we keep it for easier parallelization
        # Make this constraint all 0s
        zero_inds = attack_targets[:, 0:1]  # [B, 1]
        H.scatter_(dim=1, index=zero_inds, src=torch.zeros(zero_inds.shape, device=device))
        # Project onto the constraint polytope; z_sol is [B, F].
        z_sol = solve_qp(Q, P, G, H)
        # Per-item squared distance to the projection (no batch reduction).
        loss = (feats_pred - z_sol).square().sum(dim=-1)
        # loss_check = self.forward_check(logits_pred, feats_pred, attack_targets, model, **kwargs)
        return loss