Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ solver = scs.SCS(data, cone, verbose=False)
sol = solver.solve()

print(sol["info"]["status"]) # 'solved'
print(sol["info"]["aa_stats"]) # Anderson acceleration diagnostics
print(sol["x"]) # primal solution
```

Expand Down
37 changes: 31 additions & 6 deletions scs/scsobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) {
/* else: SCS will overwite sol if _warm_start is false */
/* so we don't need to set to zeros here */

PyObject *x, *y, *s, *return_dict, *info_dict;
PyObject *x, *y, *s, *return_dict, *info_dict, *aa_stats_dict;
scs_float *_x, *_y, *_s;
/* release the GIL */
Py_BEGIN_ALLOW_THREADS;
Expand Down Expand Up @@ -1027,17 +1027,21 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) {
#ifdef SFLOAT
char *outarg_string = "{s:L,s:L,s:L,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,"
"s:f,s:f,s:f,s:f,s:f,s:L,s:L,s:s}";
char *aa_stats_string = "{s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:f,s:f}";
#else
char *outarg_string = "{s:L,s:L,s:L,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,"
"s:d,s:d,s:d,s:d,s:d,s:L,s:L,s:s}";
char *aa_stats_string = "{s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:L,s:d,s:d}";
#endif
#else
#ifdef SFLOAT
char *outarg_string = "{s:i,s:i,s:i,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,s:f,"
"s:f,s:f,s:f,s:f,s:f,s:i,s:i,s:s}";
char *aa_stats_string = "{s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:f,s:f}";
#else
char *outarg_string = "{s:i,s:i,s:i,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,s:d,"
"s:d,s:d,s:d,s:d,s:d,s:i,s:i,s:s}";
char *aa_stats_string = "{s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:i,s:d,s:d}";
#endif
#endif

Expand Down Expand Up @@ -1066,18 +1070,39 @@ static PyObject *SCS_solve(SCS *self, PyObject *args) {
"rejected_accel_steps", (scs_int)info.rejected_accel_steps,
"accepted_accel_steps", (scs_int)info.accepted_accel_steps,
"status", info.status);
aa_stats_dict = Py_BuildValue(
aa_stats_string,
"iter", (scs_int)info.aa_stats.iter,
"n_accept", (scs_int)info.aa_stats.n_accept,
"n_reject_lapack", (scs_int)info.aa_stats.n_reject_lapack,
"n_reject_rank0", (scs_int)info.aa_stats.n_reject_rank0,
"n_reject_nonfinite", (scs_int)info.aa_stats.n_reject_nonfinite,
"n_reject_weight_cap", (scs_int)info.aa_stats.n_reject_weight_cap,
"n_safeguard_reject", (scs_int)info.aa_stats.n_safeguard_reject,
"last_rank", (scs_int)info.aa_stats.last_rank,
"last_aa_norm", (scs_float)info.aa_stats.last_aa_norm,
"last_regularization", (scs_float)info.aa_stats.last_regularization);
/* clang-format on */

if (!info_dict || !aa_stats_dict ||
PyDict_SetItemString(info_dict, "aa_stats", aa_stats_dict) < 0) {
Py_DECREF(x);
Py_DECREF(y);
Py_DECREF(s);
Py_XDECREF(info_dict);
Py_XDECREF(aa_stats_dict);
return NULL;
}

return_dict = Py_BuildValue("{s:O,s:O,s:O,s:O}", "x", x, "y", y, "s", s,
"info", info_dict);
/* Give up ownership to the return dictionary. x/y/s are non-NULL
* (NULL-checked above). info_dict can be NULL if Py_BuildValue OOM'd,
* in which case return_dict is also NULL (we propagate it) — use
* Py_XDECREF to avoid dereferencing NULL in that path. */
/* Give up ownership to the return dictionary. x/y/s/info_dict are non-NULL
* here, and Py_BuildValue borrowed each with "O". */
Py_DECREF(x);
Py_DECREF(y);
Py_DECREF(s);
Py_XDECREF(info_dict);
Py_DECREF(info_dict);
Py_DECREF(aa_stats_dict);

return return_dict;
}
Expand Down
36 changes: 36 additions & 0 deletions test/test_scs_coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,20 @@ def test_status_constants():
"setup_time", "solve_time",
"scale",
"accepted_accel_steps", "rejected_accel_steps",
"aa_stats",
}

_EXPECTED_AA_STATS_KEYS = {
"iter",
"n_accept",
"n_reject_lapack",
"n_reject_rank0",
"n_reject_nonfinite",
"n_reject_weight_cap",
"n_safeguard_reject",
"last_rank",
"last_aa_norm",
"last_regularization",
}


Expand All @@ -343,6 +357,14 @@ def test_info_dict_has_expected_keys():
assert key in info, f"Missing key '{key}' in sol['info']"


def test_aa_stats_dict_has_expected_keys():
solver = scs.SCS(_make_data(), _CONE, verbose=False)
sol = solver.solve()
aa_stats = sol["info"]["aa_stats"]
for key in _EXPECTED_AA_STATS_KEYS:
assert key in aa_stats, f"Missing key '{key}' in sol['info']['aa_stats']"


def test_info_status_val_matches_constant():
solver = scs.SCS(_make_data(), _CONE, verbose=False)
sol = solver.solve()
Expand Down Expand Up @@ -1244,6 +1266,7 @@ def test_first_solve_warm_start_true():
"scale",
"comp_slack",
"accepted_accel_steps", "rejected_accel_steps",
"aa_stats",
}


Expand Down Expand Up @@ -1294,6 +1317,19 @@ def test_accel_steps_nonnegative():
assert sol["info"]["rejected_accel_steps"] >= 0


def test_aa_stats_no_acceleration():
"""AA diagnostics should be present even when acceleration is disabled."""
solver = scs.SCS(
_make_data(), _CONE, acceleration_lookback=0, verbose=False
)
sol = solver.solve()
aa_stats = sol["info"]["aa_stats"]
for key in _EXPECTED_AA_STATS_KEYS - {"last_aa_norm", "last_regularization"}:
assert aa_stats[key] == 0
assert np.isnan(aa_stats["last_aa_norm"])
assert aa_stats["last_regularization"] == 0.0


def test_pobj_matches_c_dot_x():
"""For an LP (no P), primal objective should equal c'x."""
solver = scs.SCS(_make_data(), _CONE, eps_abs=1e-8, eps_rel=1e-8, verbose=False)
Expand Down
Loading