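# QuadAttack / modelguidedattacks / losses / _qp_solver_patch.py
# Author: thomaspaniagua -- commit 71f183c ("QuadAttack release")
#
# Replaces qpth's batched primal-dual interior-point solver
# (qpth.solvers.pdipm.batch.forward) with the patched `forward` defined below.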
import qpth
from qpth.solvers.pdipm import batch as pdipm_b
from qpth.solvers.pdipm.batch import *
def reduce_stats(z):
    """Return the median of ``z`` taken over its non-NaN entries only.

    ``z`` is flattened by the boolean-mask indexing, so tensors of any rank
    (including 0-d) are accepted. Used to summarise per-batch statistics so
    that batch elements whose values became NaN do not dominate the summary.

    Note: ``Tensor.median`` returns the lower of the two middle values when
    the number of kept elements is even (it does not average them).
    """
    not_nan = z.isnan().logical_not()
    return z[not_nan].median()
def _solve_kkt_step(solver, Q, Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve one in-loop KKT (Newton) system with the selected ``solver``.

    The full-LU and iterative-refinement paths build the diagonal block
    ``bdiag(d)`` explicitly; the partial-LU path reuses ``Q_LU`` together with
    ``S_LU`` as last refreshed by ``factor_kkt(S_LU, R, d)`` in ``forward``.

    Returns:
        The tuple ``(dx, ds, dz, dy)`` produced by the selected solver.

    Raises:
        ValueError: if ``solver`` is not one of the supported ``KKTSolvers``.
    """
    if solver == KKTSolvers.LU_FULL:
        return factor_solve_kkt(Q, bdiag(d), G, A, rx, rs, rz, ry)
    if solver == KKTSolvers.LU_PARTIAL:
        return solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
    if solver == KKTSolvers.IR_UNOPT:
        return solve_kkt_ir(Q, bdiag(d), G, A, rx, rs, rz, ry)
    raise ValueError('Unsupported KKT solver: {}'.format(solver))


def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
            maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    """Batched primal-dual interior-point solve; installed over qpth's
    ``pdipm.batch.forward``.

    Expects the KKT pre-factorisation from the caller::

        Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)

    The residuals below correspond to the QP
    ``min 0.5 x'Qx + p'x  s.t.  Gx + s = h, s >= 0, Ax = b``
    (presumably with Q symmetric PSD, as qpth requires -- TODO confirm).

    Args:
        Q, p, G, h, A, b: batched QP data; sizes are obtained via
            ``get_sizes(G, A)``. ``A``/``b`` are only used when ``neq > 0``.
        Q_LU, S_LU, R: pre-factorised KKT pieces (see above). ``S_LU`` is
            refreshed in place by ``factor_kkt`` every iteration.
        eps: stop once the median (NaN-ignoring) primal residual over the
            batch drops below this value.
        verbose: ``1`` prints per-iteration median statistics; any value
            ``>= 0`` allows the ``INACC_ERR`` warning on exit.
        notImprovedLim: stop after this many consecutive iterations in which
            no batch element improved its combined residual.
        maxIter: maximum number of interior-point iterations.
        solver: a ``KKTSolvers`` member selecting the linear-system solver.

    Returns:
        ``(x, y, z, s)``: per batch element, the primal iterate, equality
        duals, inequality duals and slacks with the smallest combined residual
        seen. If the KKT refactorisation raises, the best iterate found so far
        is returned; that is all ``None`` if it fails on the first iteration
        or if ``maxIter <= 0``.

    Raises:
        ValueError: if ``solver`` is not a supported ``KKTSolvers`` member.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Initial point: solve the KKT system once with an identity scaling.
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = factor_solve_kkt(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    elif solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(nBatch, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(
            Q_LU, d, G, A, S_LU,
            p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    else:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise ValueError('Unsupported KKT solver: {}'.format(solver))

    # For every batch row whose smallest slack is negative, shift the whole
    # row so that its minimum becomes exactly 1 (rows already >= 0 untouched).
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    # Same shift for the inequality dual variables.
    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0
    for i in range(maxIter):
        # KKT residuals at the current iterate.
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            # Refactorisation failed: fall back to the best iterate so far.
            # (Not a bare `except:`, so Ctrl-C / SystemExit still propagate.)
            return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, reduce_stats(pri_resid), reduce_stats(dual_resid), reduce_stats(mu)))

        # Track, per batch element, the iterate with the smallest combined residual.
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]

        # Stop on stagnation, on a small median primal residual, or when the
        # duality measure mu exceeds 1e32 for every batch element.
        if nNotImproved == notImprovedLim or reduce_stats(pri_resid) < eps or mu.min() > 1e32:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
            return best['x'], best['y'], best['z'], best['s']

        # Predictor (affine-scaling) direction.
        dx_aff, ds_aff, dz_aff, dy_aff = _solve_kkt_step(
            solver, Q, Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

        # Centering parameter sig = (mu_aff / mu) ** 3, where mu_aff is the
        # complementarity after the (capped) affine step. get_step presumably
        # returns the largest step keeping its first argument nonnegative.
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        # Corrector direction: centering term plus the ds_aff * dz_aff product.
        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()
        dx_cor, ds_cor, dz_cor, dy_cor = _solve_kkt_step(
            solver, Q, Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None

        # Damped step (0.999 of the boundary step), capped at a full step of 1.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    # best['resids'] is still None when maxIter <= 0 (the loop never ran).
    if best['resids'] is not None and best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
    return best['x'], best['y'], best['z'], best['s']
# Monkey-patch: install the `forward` defined above as the solver entry point
# of qpth.solvers.pdipm.batch, replacing the module attribute in place.
setattr(pdipm_b, 'forward', forward)