"""Scratch script: ordered top-K logit attack in feature space, posed as a QP.

Given a linear classification head (``head_W``, ``head_bias``) and anchor
features ``rand_feats``, look for the features closest to the anchor whose
logits rank the classes in ``attack_targets`` in the given order, each
separated by at least ``MARGIN``, with every other class at least ``MARGIN``
below the last attack target.

Three formulations appear below:

1. a per-constraint cvxpy problem (commented out, kept for reference),
2. a vectorized cvxpy problem (commented out, kept for reference),
3. a batched formulation solved with ``qpth`` (live code).

NOTE(review): the file this was recovered from had its line breaks collapsed
and two statements truncated; every reconstruction is flagged inline with
``NOTE(review)`` and should be confirmed against the original script.
"""
import time

import cvxpy as cp
import qpth
import torch
from cvxpylayers.torch import CvxpyLayer
from torch.nn import functional as F

from modelguidedattacks import cls_models

torch.manual_seed(0)
device = "cuda"

# model = cls_models.get_model("imagenet", "resnet18", device)
rand_feats = torch.randn(1, 512, device=device)
attack_targets = [4, 7, 5, 9, 2]

# # pred_logits = model.head(rand_feats)
# # head_W, head_bias = model.head_matrices()

# NOTE(review): the checkpoint path is an empty string in the source, so this
# call cannot succeed as written. Fill in the real path (or restore the
# commented-out model code above) before running.
(head_W, head_bias, pred_logits) = torch.load("")

# Overrides the random features / hard-coded targets defined above with a
# saved attack case.
rand_feats, rand_logits, attack_targets = torch.load(
    "attack_case.p", map_location=device
)

reconstructed_logits = rand_feats @ head_W.T + head_bias

# head_W is indexed as [num_classes, num_feats].
num_feats = head_W.shape[1]
num_classes = head_W.shape[0]

# cvxpy symbols for the reference formulations. `x`, `A`, `b` and `logits`
# are rebound to plain tensors / values further down.
x = cp.Variable(num_feats)
anchor_feats = cp.Parameter(x.shape)

A = cp.Parameter(head_W.shape)
b = cp.Parameter(head_bias.shape)

logits = A @ x + b

# Minimum required gap between consecutive logits in the enforced ordering.
MARGIN = 0.1

# --- Reference formulation 1: one cvxpy constraint per logit pair ----------
# constraints = []
# for i in range(len(attack_targets) - 1):
#     constraints.append( logits[attack_targets[i]] - logits[attack_targets[i+1]] >= MARGIN)
# for i in range(num_classes):
#     if i in attack_targets:
#         continue
#     constraints.append(logits[attack_targets[-1]] - logits[i] >= MARGIN )
# objective = cp.Minimize(0.5 * cp.pnorm(x - anchor_feats, p=2))
# problem = cp.Problem(objective, constraints)
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print ("Non vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W@feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)

# Constraint matrix
# Each column holds one constraint: +1 on the class whose logit must be
# larger, -1 on the class whose logit must be smaller. This matrix only feeds
# the commented-out vectorized cvxpy problem; `D` is rebuilt from scratch in
# the qpth section below.
# NOTE(review): if `attack_targets` is a tensor at this point, `set()` hashes
# its elements by identity rather than value, so `non_attack_targets` would
# keep every class - confirm the intended input type.
num_constraints = num_classes - 1
D = torch.zeros((num_classes), num_constraints)
non_attack_targets = list(set(range(num_classes)) - set(attack_targets))

for constraint_cursor in range(num_constraints):
    if constraint_cursor < len(attack_targets) - 1:
        # Consecutive attack targets: earlier target above the next one.
        D[attack_targets[constraint_cursor], constraint_cursor] = 1
        D[attack_targets[constraint_cursor + 1], constraint_cursor] = -1
    else:
        # Remaining constraints: last attack target above a non-target class.
        non_attack_i = constraint_cursor - len(attack_targets) + 1
        D[attack_targets[-1], constraint_cursor] = 1
        D[non_attack_targets[non_attack_i], constraint_cursor] = -1

D = D.T

# --- Reference formulation 2: vectorized cvxpy problem ----------------------
# vectorized_differences = D @ logits
# vectorized_constraint = vectorized_differences >= torch.full(vectorized_differences.shape, fill_value=MARGIN).numpy()
# Q = 2*torch.eye(x.shape[0]).numpy()
# P = -2*anchor_feats
# G = D@A
# H = MARGIN - D @ b
# G = -G
# H = -H
# vectorized_constraint = G@x <= H
# objective = cp.Minimize((1/2)*cp.quad_form(x, Q) + P.T@x)
# problem = cp.Problem(objective, [vectorized_constraint])
# anchor_feats.value = rand_feats[0].cpu().numpy()
# A.value = head_W.detach().cpu().numpy()
# b.value = head_bias.detach().cpu().numpy()
# start_time = time.time()
# problem.solve()
# print ("vectorized sol", time.time() - start_time)
# logits_sol_torch = torch.from_numpy(logits.value)
# logits_check = logits_sol_torch.argsort(descending=True)
# feats_sol = torch.from_numpy(x.value[:, None]).float().to(rand_feats)
# sol_feat_norm = (feats_sol[:, 0].cpu() - rand_feats[0].cpu()).norm(dim=-1)
# sol_logits = head_W@feats_sol + head_bias[:, None]
# sol_sort = sol_logits.argsort(dim=0, descending=True)

# --- Batched formulation solved with qpth -----------------------------------
B = 2
nz = num_feats
nineq = num_constraints

# From here on `attack_targets` must be a tensor (it is expanded to B rows).
attack_targets = attack_targets.expand(B, -1)
K = attack_targets.shape[-1]

# Start with all classes should be less than smallest attack target
# Row i of D[batch] starts as: -1 on class i, +1 on the last attack target.
D = -torch.eye(num_classes, device=device)[None].repeat(B, 1, 1)
attack_targets_write = attack_targets[:, -1][:, None, None].expand(
    -1, D.shape[1], -1
)
D.scatter_(
    dim=2,
    index=attack_targets_write,
    src=torch.ones(attack_targets_write.shape, device=device),
)

# Clear out the constraint row for each item in the attack targets
attack_targets_clear = attack_targets[:, :, None].expand(-1, -1, D.shape[-1])
D.scatter_(
    dim=1,
    index=attack_targets_clear,
    src=torch.zeros(attack_targets_clear.shape, device=device),
)

batch_inds = torch.arange(B, device=device)[:, None].expand(-1, K - 1)
attack_targets_pos = attack_targets[:, :-1]  # [B, K-1]
attack_targets_neg = attack_targets[:, 1:]  # [B, K-1]

# Ordering constraints among attack targets live in the row of the *smaller*
# target: -1 on its own column ...
attack_targets_neg_inds = torch.stack((
    batch_inds,
    attack_targets_neg,
    attack_targets_neg,
), dim=0)  # [3, B, K - 1]
attack_targets_neg_inds = attack_targets_neg_inds.view(3, -1)
D[
    attack_targets_neg_inds[0],
    attack_targets_neg_inds[1],
    attack_targets_neg_inds[2],
] = -1

# ... and +1 on the column of the target that must rank directly above it.
attack_targets_pos_inds = torch.stack((
    batch_inds,
    attack_targets_neg,
    attack_targets_pos,
), dim=0)  # [3, B, K - 1]
D[
    attack_targets_pos_inds[0],
    attack_targets_pos_inds[1],
    attack_targets_pos_inds[2],
] = 1

# Rebind the cvxpy parameters to the actual head weights for the torch solve.
A = head_W.detach().to(device)
b = head_bias.detach().to(device)

# NOTE(review): the source had a dangling `D =` with no right-hand side at
# this point. D is left exactly as constructed above, which is
# shape-compatible with the G / H expressions below - confirm that no
# transformation of D was lost.

#rand_feats: [B, num_features]
# Objective ||z - anchor||^2 written as 1/2 z^T Q z + P^T z (up to a constant).
Q = 2 * torch.eye(nz, device=device)[None].expand(B, -1, -1)
# NOTE(review): this statement was truncated in the source (`P = -2*, -1)`).
# Reconstructed from the commented cvxpy formulation (`P = -2*anchor_feats`
# with `anchor_feats.value = rand_feats[0]`) - confirm.
P = -2 * rand_feats.expand(B, -1)

# D @ (A z + b) >= MARGIN  rewritten as  G z <= H.
# G = torch.randn(B, nineq, nz, device=device)
G = -D @ A

# h = torch.randn(B, nineq)
H = -(MARGIN - D @ b)

# Constraints are indexed by smaller logit
# First attack target isn't smaller than any logit, so its
# constraint index is redundant, but we keep it for easier parallelization
# Make this constraint all 0s
zero_inds = attack_targets[:, 0:1]  # [B, 1]
H.scatter_(
    dim=1,
    index=zero_inds,
    src=torch.zeros(zero_inds.shape, device=device),
)

# Empty tensors are passed in the two equality-constraint slots of the solver.
e = torch.empty(0, device=device)

# Loaded but not used below; presumably reference QP inputs kept around for
# manual comparison in a debugger.
Q_t, P_t, G_t, H_t = torch.load("qpinputs.p", map_location=device)

# Transposed so the solution can be left-multiplied by the head matrix.
z_sol = qpth.qp.QPFunction(verbose=True, check_Q_spd=False)(Q, P, G, H, e, e).T

logits = A @ z_sol + b[:, None]

# Dead assignment at the end of the script; presumably a breakpoint anchor.
x = 5