@@ -64,20 +64,22 @@ def check_version(estimator):
6464 sklearn .__version__ ) < packaging .version .parse ("1.6.dev" )
6565
6666
67- def _safe_torch_load (filename , weights_only , ** kwargs ):
68- if weights_only is None :
69- if packaging .version .parse (
70- torch .__version__ ) >= packaging .version .parse ("2.6.0" ):
71- weights_only = True
72- else :
73- weights_only = False
67+ def _safe_torch_load (filename , weights_only = False , ** kwargs ):
68+ checkpoint = None
69+ legacy_mode = packaging .version .parse (
70+ torch .__version__ ) < packaging .version .parse ("2.6.0" )
7471
75- if not weights_only :
72+ if legacy_mode :
7673 checkpoint = torch .load (filename , weights_only = False , ** kwargs )
7774 else :
78- # NOTE(stes): This is only supported for torch 2.6+
7975 with torch .serialization .safe_globals (CEBRA_LOAD_SAFE_GLOBALS ):
80- checkpoint = torch .load (filename , weights_only = True , ** kwargs )
76+ checkpoint = torch .load (filename ,
77+ weights_only = weights_only ,
78+ ** kwargs )
79+
80+ if not isinstance (checkpoint , dict ):
81+ _check_type_checkpoint (checkpoint )
82+ checkpoint = checkpoint ._get_state_dict ()
8183
8284 return checkpoint
8385
@@ -317,8 +319,9 @@ def _require_arg(key):
317319
318320def _check_type_checkpoint (checkpoint ):
319321 if not isinstance (checkpoint , cebra .CEBRA ):
320- raise RuntimeError ("Model loaded from file is not compatible with "
321- "the current CEBRA version." )
322+ raise RuntimeError (
323+ "Model loaded from file is not compatible with "
324+ f"the current CEBRA version. Got: { type (checkpoint )} " )
322325 if not sklearn_utils .check_fitted (checkpoint ):
323326 raise ValueError (
324327 "CEBRA model is not fitted. Loading it is not supported." )
@@ -1487,7 +1490,7 @@ def load(cls,
14871490 # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
14881491 # introduced in torch 2.6.0.
14891492 try :
1490- checkpoint = _safe_torch_load (filename , weights_only = True , ** kwargs )
1493+ checkpoint = _safe_torch_load (filename , ** kwargs )
14911494 except pickle .UnpicklingError as e :
14921495 if weights_only is not False :
14931496 if packaging .version .parse (
@@ -1511,21 +1514,15 @@ def load(cls,
15111514 checkpoint = _safe_torch_load (filename ,
15121515 weights_only = False ,
15131516 ** kwargs )
1514- checkpoint = _check_type_checkpoint (checkpoint )
1515- checkpoint = checkpoint ._get_state_dict ()
1516-
1517- if isinstance (checkpoint , dict ) and backend == "torch" :
1518- raise RuntimeError (
1519- "Cannot use 'torch' backend with a dictionary-based checkpoint. "
1520- "Please try a different backend." )
1521- if not isinstance (checkpoint , dict ) and backend == "sklearn" :
1522- raise RuntimeError (
1523- "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524- "Please try a different backend." )
15251517
15261518 if backend != "sklearn" :
15271519 raise ValueError (f"Unsupported backend: { backend } " )
15281520
1521+ if not isinstance (checkpoint , dict ):
1522+ raise RuntimeError (
1523+ "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524+ f"Please try a different backend. Got: { type (checkpoint )} " )
1525+
15291526 cebra_ = _load_cebra_with_sklearn_backend (checkpoint )
15301527
15311528 n_features = cebra_ .n_features_
0 commit comments