Skip to content

Commit e23e9e1

Browse files
BUG: Fix bug with CSP rank (#12476)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 73661d1 commit e23e9e1

6 files changed

Lines changed: 297 additions & 82 deletions

File tree

doc/changes/devel/12476.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed bugs with handling of rank in :class:`mne.decoding.CSP`, by `Eric Larson`_.

examples/decoding/decoding_csp_eeg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
montage = make_standard_montage("standard_1005")
5050
raw.set_montage(montage)
5151
raw.annotations.rename(dict(T1="hands", T2="feet"))
52+
raw.set_eeg_reference(projection=True)
5253

5354
# Apply band-pass filter
5455
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

mne/cov.py

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
empirical_covariance,
6060
log_likelihood,
6161
)
62-
from .rank import compute_rank
62+
from .rank import _compute_rank
6363
from .utils import (
6464
_array_repr,
6565
_check_fname,
@@ -1226,6 +1226,21 @@ def _eigvec_subspace(eig, eigvec, mask):
12261226
return eig, eigvec
12271227

12281228

1229+
@verbose
1230+
def _compute_rank_raw_array(
1231+
data, info, rank, scalings, *, log_ch_type=None, verbose=None
1232+
):
1233+
from .io import RawArray
1234+
1235+
return _compute_rank(
1236+
RawArray(data, info, copy=None, verbose=_verbose_safe_false()),
1237+
rank,
1238+
scalings,
1239+
info,
1240+
log_ch_type=log_ch_type,
1241+
)
1242+
1243+
12291244
def _compute_covariance_auto(
12301245
data,
12311246
method,
@@ -1237,22 +1252,31 @@ def _compute_covariance_auto(
12371252
stop_early,
12381253
picks_list,
12391254
rank,
1255+
*,
1256+
cov_kind="",
1257+
log_ch_type=None,
1258+
log_rank=True,
12401259
):
12411260
"""Compute covariance auto mode."""
1242-
from .io import RawArray
1243-
12441261
# rescale to improve numerical stability
12451262
orig_rank = rank
1246-
rank = compute_rank(
1247-
RawArray(data.T, info, copy=None, verbose=_verbose_safe_false()),
1248-
rank,
1249-
scalings,
1263+
rank = _compute_rank_raw_array(
1264+
data.T,
12501265
info,
1266+
rank=rank,
1267+
scalings=scalings,
1268+
verbose=_verbose_safe_false(),
12511269
)
12521270
with _scaled_array(data.T, picks_list, scalings):
12531271
C = np.dot(data.T, data)
12541272
_, eigvec, mask = _smart_eigh(
1255-
C, info, rank, proj_subspace=True, do_compute_rank=False
1273+
C,
1274+
info,
1275+
rank,
1276+
proj_subspace=True,
1277+
do_compute_rank=False,
1278+
log_ch_type=log_ch_type,
1279+
verbose=None if log_rank else _verbose_safe_false(),
12561280
)
12571281
eigvec = eigvec[mask]
12581282
data = np.dot(data, eigvec.T)
@@ -1261,21 +1285,24 @@ def _compute_covariance_auto(
12611285
(key, np.searchsorted(used, picks)) for key, picks in picks_list
12621286
]
12631287
sub_info = pick_info(info, used) if len(used) != len(mask) else info
1264-
logger.info(f"Reducing data rank from {len(mask)} -> {eigvec.shape[0]}")
1288+
if log_rank:
1289+
logger.info(f"Reducing data rank from {len(mask)} -> {eigvec.shape[0]}")
12651290
estimator_cov_info = list()
1266-
msg = "Estimating covariance using {}"
12671291

12681292
ok_sklearn = check_version("sklearn")
12691293
if not ok_sklearn and (len(method) != 1 or method[0] != "empirical"):
12701294
raise ValueError(
1271-
"scikit-learn is not installed, `method` must be `empirical`, got "
1272-
f"{method}"
1295+
'scikit-learn is not installed, `method` must be "empirical", got '
1296+
f"{repr(method)}"
12731297
)
12741298

12751299
for method_ in method:
12761300
data_ = data.copy()
12771301
name = method_.__name__ if callable(method_) else method_
1278-
logger.info(msg.format(name.upper()))
1302+
logger.info(
1303+
f'Estimating {cov_kind + (" " if cov_kind else "")}'
1304+
f"covariance using {name.upper()}"
1305+
)
12791306
mp = method_params[method_]
12801307
_info = {}
12811308

@@ -1691,9 +1718,8 @@ def _get_ch_whitener(A, pca, ch_type, rank):
16911718
mask[:-rank] = False
16921719

16931720
logger.info(
1694-
" Setting small {} eigenvalues to zero ({})".format(
1695-
ch_type, "using PCA" if pca else "without PCA"
1696-
)
1721+
f" Setting small {ch_type} eigenvalues to zero "
1722+
f'({"using" if pca else "without"} PCA)'
16971723
)
16981724
if pca: # No PCA case.
16991725
# This line will reduce the actual number of variables in data
@@ -1791,6 +1817,8 @@ def _smart_eigh(
17911817
proj_subspace=False,
17921818
do_compute_rank=True,
17931819
on_rank_mismatch="ignore",
1820+
*,
1821+
log_ch_type=None,
17941822
verbose=None,
17951823
):
17961824
"""Compute eigh of C taking into account rank and ch_type scalings."""
@@ -1813,8 +1841,13 @@ def _smart_eigh(
18131841

18141842
noise_cov = Covariance(C, ch_names, [], projs, 0)
18151843
if do_compute_rank: # if necessary
1816-
rank = compute_rank(
1817-
noise_cov, rank, scalings, info, on_rank_mismatch=on_rank_mismatch
1844+
rank = _compute_rank(
1845+
noise_cov,
1846+
rank,
1847+
scalings,
1848+
info,
1849+
on_rank_mismatch=on_rank_mismatch,
1850+
log_ch_type=log_ch_type,
18181851
)
18191852
assert C.ndim == 2 and C.shape[0] == C.shape[1]
18201853

@@ -1838,7 +1871,11 @@ def _smart_eigh(
18381871
else:
18391872
this_rank = rank[ch_type]
18401873

1841-
e, ev, m = _get_ch_whitener(this_C, False, ch_type.upper(), this_rank)
1874+
if log_ch_type is not None:
1875+
ch_type_ = log_ch_type
1876+
else:
1877+
ch_type_ = ch_type.upper()
1878+
e, ev, m = _get_ch_whitener(this_C, False, ch_type_, this_rank)
18421879
if proj_subspace:
18431880
# Choose the subspace the same way we do for projections
18441881
e, ev = _eigvec_subspace(e, ev, m)
@@ -1995,7 +2032,7 @@ def regularize(
19952032
else:
19962033
regs.update(mag=mag, grad=grad)
19972034
if rank != "full":
1998-
rank = compute_rank(cov, rank, scalings, info)
2035+
rank = _compute_rank(cov, rank, scalings, info)
19992036

20002037
info_ch_names = info["ch_names"]
20012038
ch_names_by_type = dict()
@@ -2071,7 +2108,17 @@ def regularize(
20712108
return cov
20722109

20732110

2074-
def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=None):
2111+
def _regularized_covariance(
2112+
data,
2113+
reg=None,
2114+
method_params=None,
2115+
info=None,
2116+
rank=None,
2117+
*,
2118+
log_ch_type=None,
2119+
log_rank=None,
2120+
cov_kind="",
2121+
):
20752122
"""Compute a regularized covariance from data using sklearn.
20762123
20772124
This is a convenience wrapper for mne.decoding functions, which
@@ -2114,6 +2161,9 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=
21142161
picks_list=picks_list,
21152162
scalings=scalings,
21162163
rank=rank,
2164+
cov_kind=cov_kind,
2165+
log_ch_type=log_ch_type,
2166+
log_rank=log_rank,
21172167
)[reg]["data"]
21182168
return cov
21192169

mne/decoding/csp.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212
import numpy as np
1313
from scipy.linalg import eigh
1414

15-
from ..cov import _regularized_covariance
15+
from .._fiff.meas_info import create_info
16+
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
1617
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
1718
from ..evoked import EvokedArray
1819
from ..fixes import pinv
19-
from ..utils import _check_option, _validate_type, copy_doc, fill_doc
20+
from ..utils import (
21+
_check_option,
22+
_validate_type,
23+
_verbose_safe_false,
24+
copy_doc,
25+
fill_doc,
26+
)
2027
from .base import BaseEstimator
2128
from .mixin import TransformerMixin
2229

@@ -185,6 +192,9 @@ def fit(self, X, y):
185192
f"{n_classes} classes; use component_order='mutual_info' instead."
186193
)
187194

195+
# Convert rank to one that will run
196+
_validate_type(self.rank, (dict, None), "rank")
197+
188198
covs, sample_weights = self._compute_covariance_matrices(X, y)
189199
eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights)
190200
ix = self._order_components(
@@ -519,10 +529,28 @@ def _compute_covariance_matrices(self, X, y):
519529
elif self.cov_est == "epoch":
520530
cov_estimator = self._epoch_cov
521531

532+
# Someday we could allow the user to pass this, then we wouldn't need to convert
533+
# but in the meantime they can use a pipeline with a scaler
534+
self._info = create_info(n_channels, 1000.0, "mag")
535+
if self.rank is None:
536+
self._rank = _compute_rank_raw_array(
537+
X.transpose(1, 0, 2).reshape(X.shape[1], -1),
538+
self._info,
539+
rank=None,
540+
scalings=None,
541+
log_ch_type="data",
542+
)
543+
else:
544+
self._rank = {"mag": sum(self.rank.values())}
545+
522546
covs = []
523547
sample_weights = []
524-
for this_class in self._classes:
525-
cov, weight = cov_estimator(X[y == this_class])
548+
for ci, this_class in enumerate(self._classes):
549+
cov, weight = cov_estimator(
550+
X[y == this_class],
551+
cov_kind=f"class={this_class}",
552+
log_rank=ci == 0,
553+
)
526554

527555
if self.norm_trace:
528556
cov /= np.trace(cov)
@@ -532,29 +560,39 @@ def _compute_covariance_matrices(self, X, y):
532560

533561
return np.stack(covs), np.array(sample_weights)
534562

535-
def _concat_cov(self, x_class):
563+
def _concat_cov(self, x_class, *, cov_kind, log_rank):
536564
"""Concatenate epochs before computing the covariance."""
537565
_, n_channels, _ = x_class.shape
538566

539-
x_class = np.transpose(x_class, [1, 0, 2])
540-
x_class = x_class.reshape(n_channels, -1)
567+
x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1)
541568
cov = _regularized_covariance(
542-
x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank
569+
x_class,
570+
reg=self.reg,
571+
method_params=self.cov_method_params,
572+
rank=self._rank,
573+
info=self._info,
574+
cov_kind=cov_kind,
575+
log_rank=log_rank,
576+
log_ch_type="data",
543577
)
544578
weight = x_class.shape[0]
545579

546580
return cov, weight
547581

548-
def _epoch_cov(self, x_class):
582+
def _epoch_cov(self, x_class, *, cov_kind, log_rank):
549583
"""Mean of per-epoch covariances."""
550584
cov = sum(
551585
_regularized_covariance(
552586
this_X,
553587
reg=self.reg,
554588
method_params=self.cov_method_params,
555-
rank=self.rank,
589+
rank=self._rank,
590+
info=self._info,
591+
cov_kind=cov_kind,
592+
log_rank=log_rank and ii == 0,
593+
log_ch_type="data",
556594
)
557-
for this_X in x_class
595+
for ii, this_X in enumerate(x_class)
558596
)
559597
cov /= len(x_class)
560598
weight = len(x_class)
@@ -563,6 +601,20 @@ def _epoch_cov(self, x_class):
563601

564602
def _decompose_covs(self, covs, sample_weights):
565603
n_classes = len(covs)
604+
n_channels = covs[0].shape[0]
605+
assert self._rank is not None # should happen in _compute_covariance_matrices
606+
_, sub_vec, mask = _smart_eigh(
607+
covs.mean(0),
608+
self._info,
609+
self._rank,
610+
proj_subspace=True,
611+
do_compute_rank=False,
612+
log_ch_type="data",
613+
verbose=_verbose_safe_false(),
614+
)
615+
sub_vec = sub_vec[mask]
616+
covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float)
617+
assert covs[0].shape == (mask.sum(),) * 2
566618
if n_classes == 2:
567619
eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0))
568620
else:
@@ -573,6 +625,9 @@ def _decompose_covs(self, covs, sample_weights):
573625
eigen_vectors.T, covs, sample_weights
574626
)
575627
eigen_values = None
628+
# project back
629+
eigen_vectors = sub_vec.T @ eigen_vectors
630+
assert eigen_vectors.shape == (n_channels, mask.sum())
576631
return eigen_vectors, eigen_values
577632

578633
def _compute_mutual_info(self, covs, sample_weights, eigen_vectors):
@@ -824,6 +879,8 @@ def fit(self, X, y):
824879
reg=self.reg,
825880
method_params=self.cov_method_params,
826881
rank=self.rank,
882+
log_ch_type="data",
883+
log_rank=ii == 0,
827884
)
828885

829886
C = covs.mean(0)

0 commit comments

Comments (0)