5252# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
5353# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
5454CEBRA_LOAD_SAFE_GLOBALS = [
55- cebra .data .Offset , torch .torch_version .TorchVersion , np .dtype ,
56- np .dtypes .Int32DType , np .dtypes .Float64DType , np .dtypes .Int64DType
55+ cebra .data .Offset ,
56+ torch .torch_version .TorchVersion ,
57+ np .dtype ,
58+ np .dtypes .Int32DType ,
59+ np .dtypes .Int64DType ,
60+ np .dtypes .Float32DType ,
61+ np .dtypes .Float64DType ,
5762]
5863
5964
@@ -64,20 +69,22 @@ def check_version(estimator):
6469 sklearn .__version__ ) < packaging .version .parse ("1.6.dev" )
6570
6671
67- def _safe_torch_load (filename , weights_only , ** kwargs ):
68- if weights_only is None :
69- if packaging .version .parse (
70- torch .__version__ ) >= packaging .version .parse ("2.6.0" ):
71- weights_only = True
72- else :
73- weights_only = False
72+ def _safe_torch_load (filename , weights_only = False , ** kwargs ):
73+ checkpoint = None
74+ legacy_mode = packaging .version .parse (
75+ torch .__version__ ) < packaging .version .parse ("2.6.0" )
7476
75- if not weights_only :
77+ if legacy_mode :
7678 checkpoint = torch .load (filename , weights_only = False , ** kwargs )
7779 else :
78- # NOTE(stes): This is only supported for torch 2.6+
7980 with torch .serialization .safe_globals (CEBRA_LOAD_SAFE_GLOBALS ):
80- checkpoint = torch .load (filename , weights_only = True , ** kwargs )
81+ checkpoint = torch .load (filename ,
82+ weights_only = weights_only ,
83+ ** kwargs )
84+
85+ if not isinstance (checkpoint , dict ):
86+ _check_type_checkpoint (checkpoint )
87+ checkpoint = checkpoint ._get_state_dict ()
8188
8289 return checkpoint
8390
@@ -317,8 +324,9 @@ def _require_arg(key):
317324
318325def _check_type_checkpoint (checkpoint ):
319326 if not isinstance (checkpoint , cebra .CEBRA ):
320- raise RuntimeError ("Model loaded from file is not compatible with "
321- "the current CEBRA version." )
327+ raise RuntimeError (
328+ "Model loaded from file is not compatible with "
329+ f"the current CEBRA version. Got: { type (checkpoint )} " )
322330 if not sklearn_utils .check_fitted (checkpoint ):
323331 raise ValueError (
324332 "CEBRA model is not fitted. Loading it is not supported." )
@@ -1487,7 +1495,7 @@ def load(cls,
14871495 # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
14881496 # introduced in torch 2.6.0.
14891497 try :
1490- checkpoint = _safe_torch_load (filename , weights_only = True , ** kwargs )
1498+ checkpoint = _safe_torch_load (filename , ** kwargs )
14911499 except pickle .UnpicklingError as e :
14921500 if weights_only is not False :
14931501 if packaging .version .parse (
@@ -1511,21 +1519,15 @@ def load(cls,
15111519 checkpoint = _safe_torch_load (filename ,
15121520 weights_only = False ,
15131521 ** kwargs )
1514- checkpoint = _check_type_checkpoint (checkpoint )
1515- checkpoint = checkpoint ._get_state_dict ()
1516-
1517- if isinstance (checkpoint , dict ) and backend == "torch" :
1518- raise RuntimeError (
1519- "Cannot use 'torch' backend with a dictionary-based checkpoint. "
1520- "Please try a different backend." )
1521- if not isinstance (checkpoint , dict ) and backend == "sklearn" :
1522- raise RuntimeError (
1523- "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524- "Please try a different backend." )
15251522
15261523 if backend != "sklearn" :
15271524 raise ValueError (f"Unsupported backend: { backend } " )
15281525
1526+ if not isinstance (checkpoint , dict ):
1527+ raise RuntimeError (
1528+ "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1529+ f"Please try a different backend. Got: { type (checkpoint )} " )
1530+
15291531 cebra_ = _load_cebra_with_sklearn_backend (checkpoint )
15301532
15311533 n_features = cebra_ .n_features_
0 commit comments