import torch

import qpth
from qpth.solvers.pdipm import batch as pdipm_b
# Besides the module alias above, the star import provides every solver
# internal used below that is not defined in this file: KKTSolvers,
# INACC_ERR, get_sizes, bdiag, get_step, factor_kkt, solve_kkt,
# factor_solve_kkt and solve_kkt_ir.
from qpth.solvers.pdipm.batch import *


def reduce_stats(z):
    """Return the median of the non-NaN entries of ``z``.

    ``forward`` uses this to collapse per-batch-element quantities
    (residuals, ``mu``) into a single scalar, so that batch elements whose
    iterates became NaN do not poison the summary.
    """
    return z[~z.isnan()].median()


def _shift_rows_positive(v):
    """Shift, in place, each row of ``v`` whose minimum is <= 0 so its minimum becomes 1.

    Rows whose minimum is already > 0 are left untouched (Mattingley & Boyd
    initialization rule).  The boundary case min == 0 must be shifted too,
    otherwise ``z / s`` later divides by zero.
    """
    row_min = torch.min(v, 1, keepdim=True)[0].expand_as(v)
    mask = row_min <= 0
    v[mask] -= row_min[mask] - 1


def _initial_iterate(solver, Q, p, G, h, A, b, Q_LU, S_LU, R,
                     nineq, neq, nBatch):
    """Return the starting (x, s, z, y) from one KKT solve with D = I and rs = 0.

    Raises:
        ValueError: if ``solver`` is not one of the supported KKTSolvers.
    """
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        return factor_solve_kkt(
            Q, D, G, A, p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    if solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(nBatch, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        return solve_kkt(
            Q_LU, d, G, A, S_LU, p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    if solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        return solve_kkt_ir(
            Q, D, G, A, p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    raise ValueError('Unknown KKT solver: {}'.format(solver))


def _solve_newton_system(solver, Q, Q_LU, S_LU, d, G, A, rx, rs, rz, ry):
    """Solve one in-loop KKT system with scaling ``d`` using the chosen backend.

    Returns the (dx, ds, dz, dy) tuple produced by the backend solver.
    """
    if solver == KKTSolvers.LU_FULL:
        return factor_solve_kkt(Q, bdiag(d), G, A, rx, rs, rz, ry)
    if solver == KKTSolvers.LU_PARTIAL:
        return solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
    if solver == KKTSolvers.IR_UNOPT:
        return solve_kkt_ir(Q, bdiag(d), G, A, rx, rs, rz, ry)
    raise ValueError('Unknown KKT solver: {}'.format(solver))


def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0,
            notImprovedLim=3, maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    """Solve a batch of QPs with a primal-dual interior-point method.

    Each problem is ``min_x 1/2 x^T Q x + p^T x`` subject to
    ``G x + s = h, s >= 0, A x = b``. The residuals ``rx``, ``rz`` and ``ry``
    below are the KKT conditions of exactly this problem. Each iteration
    takes an affine (predictor) step followed by a centering-corrector step.
    This is the Mehrotra scheme, with ``sigma = (mu_aff / mu) ** 3``.

    ``Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)`` must be passed in.
    ``factor_kkt(S_LU, R, d)`` is called every iteration and its return value
    is ignored, so it presumably updates ``S_LU`` in place.

    Args:
        eps: stop once the NaN-ignoring median primal residual is below this.
        verbose: 1 prints per-iteration medians of pri_resid, dual_resid and
            mu. Any value >= 0 prints ``INACC_ERR`` on exit when some best
            residual exceeds 1. Negative values silence both.
        notImprovedLim: stop after this many consecutive iterations in which
            no batch element improved its combined residual.
        maxIter: maximum number of iterations.
        solver: a ``KKTSolvers`` member selecting the linear-system backend.

    Returns:
        ``(x, y, z, s)``: for each batch element, the iterate with the
        smallest ``pri_resid + dual_resid + nineq * mu`` seen so far. ``y``
        may be None when there are no equality constraints. All four are
        None if ``maxIter == 0``, or if ``factor_kkt`` raises before any
        iterate was recorded.

    Raises:
        ValueError: if ``solver`` is not a supported ``KKTSolvers`` member.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    x, s, z, y = _initial_iterate(solver, Q, p, G, h, A, b, Q_LU, S_LU, R,
                                  nineq, neq, nBatch)

    # Slacks and inequality duals must start strictly positive.
    _shift_rows_positive(s)
    _shift_rows_positive(z)

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0
    ones = torch.ones(nBatch).type_as(Q)

    for i in range(maxIter):
        # KKT residuals at the current iterate; they are also the RHS of the
        # affine-scaling (predictor) system.
        rx = ((torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.)
              + torch.bmm(z.unsqueeze(1), G).squeeze(1)
              + torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1)
              + p)
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = ((torch.bmm(x.unsqueeze(1), A.transpose(1, 2)).squeeze(1) - b)
              if neq > 0 else 0.0)

        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            # Refactorization failed (presumably a numerically singular
            # system): fall back to the best iterate recorded so far.
            return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, reduce_stats(pri_resid), reduce_stats(dual_resid), reduce_stats(mu)))

        # Track the best iterate independently for every batch element.
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            improved = resids < best['resids']
            nNotImproved = 0 if improved.any() else nNotImproved + 1
            improved_nz = improved.repeat(nz, 1).t()
            improved_nineq = improved.repeat(nineq, 1).t()
            best['resids'][improved] = resids[improved]
            best['x'][improved_nz] = x[improved_nz]
            best['z'][improved_nineq] = z[improved_nineq]
            best['s'][improved_nineq] = s[improved_nineq]
            if neq > 0:
                improved_neq = improved.repeat(neq, 1).t()
                best['y'][improved_neq] = y[improved_neq]

        # Stop on stagnation, convergence, or (presumably) divergence.
        if (nNotImproved == notImprovedLim
                or reduce_stats(pri_resid) < eps
                or mu.min() > 1e32):
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
            return best['x'], best['y'], best['z'], best['s']

        # Predictor: affine-scaling direction.
        dx_aff, ds_aff, dz_aff, dy_aff = _solve_newton_system(
            solver, Q, Q_LU, S_LU, d, G, A, rx, rs, rz, ry)

        # Centering parameter: sigma = (s'z after the affine step / s'z) ** 3.
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)), ones)
        alpha_nineq = alpha.repeat(nineq, 1).t()
        sz_aff = torch.sum((s + alpha_nineq * ds_aff) * (z + alpha_nineq * dz_aff), 1).squeeze()
        sz = torch.sum(s * z, 1).squeeze()
        sig = (sz_aff / sz) ** 3

        # Corrector: centering plus second-order (ds_aff * dz_aff) term.
        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()
        dx_cor, ds_cor, dz_cor, dy_cor = _solve_newton_system(
            solver, Q, Q_LU, S_LU, d, G, A, rx, rs, rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = (dy_aff + dy_cor) if neq > 0 else None

        # Scale the step from get_step (presumably the largest step keeping
        # s and z nonnegative) by 0.999 and cap it at a full step of 1.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)), ones)
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = (y + alpha.repeat(neq, 1).t() * dy) if neq > 0 else None

    # maxIter exhausted. With maxIter == 0 nothing was recorded, so there is
    # no residual to report.
    if best['resids'] is not None and best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
    return best['x'], best['y'], best['z'], best['s']


# Rebind qpth's batched PDIPM entry point to this implementation. Code that
# looks up ``pdipm_b.forward`` at call time gets this version. References
# bound earlier via ``from ... import forward`` are unaffected.
pdipm_b.forward = forward