diff --git a/README.md b/README.md index 1035f617..45e2def5 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ solver = scs.SCS(data, cone, verbose=False) sol = solver.solve() print(sol["info"]["status"]) # 'solved' +print(sol["info"]["aa_stats"]) # Anderson acceleration diagnostics print(sol["x"]) # primal solution ``` diff --git a/scs/scsobject.h b/scs/scsobject.h index 2f589ac0..d3610df9 100644 --- a/scs/scsobject.h +++ b/scs/scsobject.h @@ -955,7 +955,7 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) { /* else: SCS will overwite sol if _warm_start is false */ /* so we don't need to set to zeros here */ - PyObject *x, *y, *s, *return_dict, *info_dict; + PyObject *x, *y, *s, *return_dict, *info_dict, *aa_stats_dict; scs_float *_x, *_y, *_s; /* release the GIL */ Py_BEGIN_ALLOW_THREADS; @@ -1027,17 +1027,21 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) { #ifdef SFLOAT char *outarg_string = "{s:L,s:L,s:L,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f," "s:f,s:f,s:f,s:f,s:f,s:L,s:L,s:s}"; + char *aa_stats_string = "{s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:f,s:f}"; #else char *outarg_string = "{s:L,s:L,s:L,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d," "s:d,s:d,s:d,s:d,s:d,s:L,s:L,s:s}"; + char *aa_stats_string = "{s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:d,s:d}"; #endif #else #ifdef SFLOAT char *outarg_string = "{s:i,s:i,s:i,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f," "s:f,s:f,s:f,s:f,s:f,s:i,s:i,s:s}"; + char *aa_stats_string = "{s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:f,s:f}"; #else char *outarg_string = "{s:i,s:i,s:i,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d," "s:d,s:d,s:d,s:d,s:d,s:i,s:i,s:s}"; + char *aa_stats_string = "{s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:d,s:d}"; #endif #endif @@ -1066,18 +1070,39 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) { "rejected_accel_steps", (scs_int)info.rejected_accel_steps, "accepted_accel_steps", (scs_int)info.accepted_accel_steps, "status", info.status); + aa_stats_dict = Py_BuildValue( + aa_stats_string, + "iter", (scs_int)info.aa_stats.iter, + "n_accept", (scs_int)info.aa_stats.n_accept, + "n_reject_lapack", (scs_int)info.aa_stats.n_reject_lapack, + "n_reject_rank0", (scs_int)info.aa_stats.n_reject_rank0, + "n_reject_nonfinite", (scs_int)info.aa_stats.n_reject_nonfinite, + "n_reject_weight_cap", (scs_int)info.aa_stats.n_reject_weight_cap, + "n_safeguard_reject", (scs_int)info.aa_stats.n_safeguard_reject, + "last_rank", (scs_int)info.aa_stats.last_rank, + "last_aa_norm", (scs_float)info.aa_stats.last_aa_norm, + "last_regularization", (scs_float)info.aa_stats.last_regularization); /* clang-format on */ + if (!info_dict || !aa_stats_dict || + PyDict_SetItemString(info_dict, "aa_stats", aa_stats_dict) < 0) { + Py_DECREF(x); + Py_DECREF(y); + Py_DECREF(s); + Py_XDECREF(info_dict); + Py_XDECREF(aa_stats_dict); + return NULL; + } + return_dict = Py_BuildValue("{s:O,s:O,s:O,s:O}", "x", x, "y", y, "s", s, "info", info_dict); - /* Give up ownership to the return dictionary. x/y/s are non-NULL - * (NULL-checked above). info_dict can be NULL if Py_BuildValue OOM'd, - * in which case return_dict is also NULL (we propagate it) — use - * Py_XDECREF to avoid dereferencing NULL in that path. */ + /* Give up ownership to the return dictionary. x/y/s/info_dict are non-NULL + * here, and Py_BuildValue borrowed each with "O". */ Py_DECREF(x); Py_DECREF(y); Py_DECREF(s); - Py_XDECREF(info_dict); + Py_DECREF(info_dict); + Py_DECREF(aa_stats_dict); return return_dict; } diff --git a/scs_source b/scs_source index 0929b5e0..86f4b8a5 160000 --- a/scs_source +++ b/scs_source @@ -1 +1 @@ -Subproject commit 0929b5e09ab66f6d7ec571c3d97683da36d1cc81 +Subproject commit 86f4b8a5ce6cc67fdf37d51777ec889665c796bf diff --git a/test/test_scs_coverage.py b/test/test_scs_coverage.py index 764f098c..4e397ce3 100644 --- a/test/test_scs_coverage.py +++ b/test/test_scs_coverage.py @@ -332,6 +332,20 @@ def test_status_constants(): "setup_time", "solve_time", "scale", "accepted_accel_steps", "rejected_accel_steps", + "aa_stats", +} + +_EXPECTED_AA_STATS_KEYS = { + "iter", + "n_accept", + "n_reject_lapack", + "n_reject_rank0", + "n_reject_nonfinite", + "n_reject_weight_cap", + "n_safeguard_reject", + "last_rank", + "last_aa_norm", + "last_regularization", } @@ -343,6 +357,14 @@ def test_info_dict_has_expected_keys(): assert key in info, f"Missing key '{key}' in sol['info']" +def test_aa_stats_dict_has_expected_keys(): + solver = scs.SCS(_make_data(), _CONE, verbose=False) + sol = solver.solve() + aa_stats = sol["info"]["aa_stats"] + for key in _EXPECTED_AA_STATS_KEYS: + assert key in aa_stats, f"Missing key '{key}' in sol['info']['aa_stats']" + + def test_info_status_val_matches_constant(): solver = scs.SCS(_make_data(), _CONE, verbose=False) sol = solver.solve() @@ -1244,6 +1266,7 @@ def test_first_solve_warm_start_true(): "scale", "comp_slack", "accepted_accel_steps", "rejected_accel_steps", + "aa_stats", } @@ -1294,6 +1317,19 @@ def test_accel_steps_nonnegative(): assert sol["info"]["rejected_accel_steps"] >= 0 +def test_aa_stats_no_acceleration(): + """AA diagnostics should be present even when acceleration is disabled.""" + solver = scs.SCS( + _make_data(), _CONE, acceleration_lookback=0, verbose=False + ) + sol = solver.solve() + aa_stats = sol["info"]["aa_stats"] + for key in _EXPECTED_AA_STATS_KEYS - {"last_aa_norm", "last_regularization"}: + assert aa_stats[key] == 0 + assert np.isnan(aa_stats["last_aa_norm"]) + assert aa_stats["last_regularization"] == 0.0 + + def test_pobj_matches_c_dot_x(): """For an LP (no P), primal objective should equal c'x.""" solver = scs.SCS(_make_data(), _CONE, eps_abs=1e-8, eps_rel=1e-8, verbose=False)