5959 empirical_covariance ,
6060 log_likelihood ,
6161)
62- from .rank import compute_rank
62+ from .rank import _compute_rank
6363from .utils import (
6464 _array_repr ,
6565 _check_fname ,
@@ -1226,6 +1226,21 @@ def _eigvec_subspace(eig, eigvec, mask):
12261226 return eig , eigvec
12271227
12281228
1229+ @verbose
1230+ def _compute_rank_raw_array (
1231+ data , info , rank , scalings , * , log_ch_type = None , verbose = None
1232+ ):
1233+ from .io import RawArray
1234+
1235+ return _compute_rank (
1236+ RawArray (data , info , copy = None , verbose = _verbose_safe_false ()),
1237+ rank ,
1238+ scalings ,
1239+ info ,
1240+ log_ch_type = log_ch_type ,
1241+ )
1242+
1243+
12291244def _compute_covariance_auto (
12301245 data ,
12311246 method ,
@@ -1237,22 +1252,31 @@ def _compute_covariance_auto(
12371252 stop_early ,
12381253 picks_list ,
12391254 rank ,
1255+ * ,
1256+ cov_kind = "" ,
1257+ log_ch_type = None ,
1258+ log_rank = True ,
12401259):
12411260 """Compute covariance auto mode."""
1242- from .io import RawArray
1243-
12441261 # rescale to improve numerical stability
12451262 orig_rank = rank
1246- rank = compute_rank (
1247- RawArray (data .T , info , copy = None , verbose = _verbose_safe_false ()),
1248- rank ,
1249- scalings ,
1263+ rank = _compute_rank_raw_array (
1264+ data .T ,
12501265 info ,
1266+ rank = rank ,
1267+ scalings = scalings ,
1268+ verbose = _verbose_safe_false (),
12511269 )
12521270 with _scaled_array (data .T , picks_list , scalings ):
12531271 C = np .dot (data .T , data )
12541272 _ , eigvec , mask = _smart_eigh (
1255- C , info , rank , proj_subspace = True , do_compute_rank = False
1273+ C ,
1274+ info ,
1275+ rank ,
1276+ proj_subspace = True ,
1277+ do_compute_rank = False ,
1278+ log_ch_type = log_ch_type ,
1279+ verbose = None if log_rank else _verbose_safe_false (),
12561280 )
12571281 eigvec = eigvec [mask ]
12581282 data = np .dot (data , eigvec .T )
@@ -1261,21 +1285,24 @@ def _compute_covariance_auto(
12611285 (key , np .searchsorted (used , picks )) for key , picks in picks_list
12621286 ]
12631287 sub_info = pick_info (info , used ) if len (used ) != len (mask ) else info
1264- logger .info (f"Reducing data rank from { len (mask )} -> { eigvec .shape [0 ]} " )
1288+ if log_rank :
1289+ logger .info (f"Reducing data rank from { len (mask )} -> { eigvec .shape [0 ]} " )
12651290 estimator_cov_info = list ()
1266- msg = "Estimating covariance using {}"
12671291
12681292 ok_sklearn = check_version ("sklearn" )
12691293 if not ok_sklearn and (len (method ) != 1 or method [0 ] != "empirical" ):
12701294 raise ValueError (
1271- " scikit-learn is not installed, `method` must be ` empirical` , got "
1272- f"{ method } "
1295+ ' scikit-learn is not installed, `method` must be " empirical" , got '
1296+ f"{ repr ( method ) } "
12731297 )
12741298
12751299 for method_ in method :
12761300 data_ = data .copy ()
12771301 name = method_ .__name__ if callable (method_ ) else method_
1278- logger .info (msg .format (name .upper ()))
1302+ logger .info (
1303+ f'Estimating { cov_kind + (" " if cov_kind else "" )} '
1304+ f"covariance using { name .upper ()} "
1305+ )
12791306 mp = method_params [method_ ]
12801307 _info = {}
12811308
@@ -1691,9 +1718,8 @@ def _get_ch_whitener(A, pca, ch_type, rank):
16911718 mask [:- rank ] = False
16921719
16931720 logger .info (
1694- " Setting small {} eigenvalues to zero ({})" .format (
1695- ch_type , "using PCA" if pca else "without PCA"
1696- )
1721+ f" Setting small { ch_type } eigenvalues to zero "
1722+ f'({ "using" if pca else "without" } PCA)'
16971723 )
16981724 if pca : # No PCA case.
16991725 # This line will reduce the actual number of variables in data
@@ -1791,6 +1817,8 @@ def _smart_eigh(
17911817 proj_subspace = False ,
17921818 do_compute_rank = True ,
17931819 on_rank_mismatch = "ignore" ,
1820+ * ,
1821+ log_ch_type = None ,
17941822 verbose = None ,
17951823):
17961824 """Compute eigh of C taking into account rank and ch_type scalings."""
@@ -1813,8 +1841,13 @@ def _smart_eigh(
18131841
18141842 noise_cov = Covariance (C , ch_names , [], projs , 0 )
18151843 if do_compute_rank : # if necessary
1816- rank = compute_rank (
1817- noise_cov , rank , scalings , info , on_rank_mismatch = on_rank_mismatch
1844+ rank = _compute_rank (
1845+ noise_cov ,
1846+ rank ,
1847+ scalings ,
1848+ info ,
1849+ on_rank_mismatch = on_rank_mismatch ,
1850+ log_ch_type = log_ch_type ,
18181851 )
18191852 assert C .ndim == 2 and C .shape [0 ] == C .shape [1 ]
18201853
@@ -1838,7 +1871,11 @@ def _smart_eigh(
18381871 else :
18391872 this_rank = rank [ch_type ]
18401873
1841- e , ev , m = _get_ch_whitener (this_C , False , ch_type .upper (), this_rank )
1874+ if log_ch_type is not None :
1875+ ch_type_ = log_ch_type
1876+ else :
1877+ ch_type_ = ch_type .upper ()
1878+ e , ev , m = _get_ch_whitener (this_C , False , ch_type_ , this_rank )
18421879 if proj_subspace :
18431880 # Choose the subspace the same way we do for projections
18441881 e , ev = _eigvec_subspace (e , ev , m )
@@ -1995,7 +2032,7 @@ def regularize(
19952032 else :
19962033 regs .update (mag = mag , grad = grad )
19972034 if rank != "full" :
1998- rank = compute_rank (cov , rank , scalings , info )
2035+ rank = _compute_rank (cov , rank , scalings , info )
19992036
20002037 info_ch_names = info ["ch_names" ]
20012038 ch_names_by_type = dict ()
@@ -2071,7 +2108,17 @@ def regularize(
20712108 return cov
20722109
20732110
2074- def _regularized_covariance (data , reg = None , method_params = None , info = None , rank = None ):
2111+ def _regularized_covariance (
2112+ data ,
2113+ reg = None ,
2114+ method_params = None ,
2115+ info = None ,
2116+ rank = None ,
2117+ * ,
2118+ log_ch_type = None ,
2119+ log_rank = None ,
2120+ cov_kind = "" ,
2121+ ):
20752122 """Compute a regularized covariance from data using sklearn.
20762123
20772124 This is a convenience wrapper for mne.decoding functions, which
@@ -2114,6 +2161,9 @@ def _regularized_covariance(data, reg=None, method_params=None, info=None, rank=
21142161 picks_list = picks_list ,
21152162 scalings = scalings ,
21162163 rank = rank ,
2164+ cov_kind = cov_kind ,
2165+ log_ch_type = log_ch_type ,
2166+ log_rank = log_rank ,
21172167 )[reg ]["data" ]
21182168 return cov
21192169
0 commit comments