Skip to content

Commit 2aad99a

Browse files
committed
fix loading logic for legacy torch
1 parent abdba00 commit 2aad99a

1 file changed

Lines changed: 21 additions & 24 deletions

File tree

cebra/integrations/sklearn/cebra.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,22 @@ def check_version(estimator):
6464
sklearn.__version__) < packaging.version.parse("1.6.dev")
6565

6666

67-
def _safe_torch_load(filename, weights_only, **kwargs):
68-
if weights_only is None:
69-
if packaging.version.parse(
70-
torch.__version__) >= packaging.version.parse("2.6.0"):
71-
weights_only = True
72-
else:
73-
weights_only = False
67+
def _safe_torch_load(filename, weights_only=False, **kwargs):
68+
checkpoint = None
69+
legacy_mode = packaging.version.parse(
70+
torch.__version__) < packaging.version.parse("2.6.0")
7471

75-
if not weights_only:
72+
if legacy_mode:
7673
checkpoint = torch.load(filename, weights_only=False, **kwargs)
7774
else:
78-
# NOTE(stes): This is only supported for torch 2.6+
7975
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
80-
checkpoint = torch.load(filename, weights_only=True, **kwargs)
76+
checkpoint = torch.load(filename,
77+
weights_only=weights_only,
78+
**kwargs)
79+
80+
if not isinstance(checkpoint, dict):
81+
_check_type_checkpoint(checkpoint)
82+
checkpoint = checkpoint._get_state_dict()
8183

8284
return checkpoint
8385

@@ -317,8 +319,9 @@ def _require_arg(key):
317319

318320
def _check_type_checkpoint(checkpoint):
319321
if not isinstance(checkpoint, cebra.CEBRA):
320-
raise RuntimeError("Model loaded from file is not compatible with "
321-
"the current CEBRA version.")
322+
raise RuntimeError(
323+
"Model loaded from file is not compatible with "
324+
f"the current CEBRA version. Got: {type(checkpoint)}")
322325
if not sklearn_utils.check_fitted(checkpoint):
323326
raise ValueError(
324327
"CEBRA model is not fitted. Loading it is not supported.")
@@ -1487,7 +1490,7 @@ def load(cls,
14871490
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
14881491
# introduced in torch 2.6.0.
14891492
try:
1490-
checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs)
1493+
checkpoint = _safe_torch_load(filename, **kwargs)
14911494
except pickle.UnpicklingError as e:
14921495
if weights_only is not False:
14931496
if packaging.version.parse(
@@ -1511,21 +1514,15 @@ def load(cls,
15111514
checkpoint = _safe_torch_load(filename,
15121515
weights_only=False,
15131516
**kwargs)
1514-
checkpoint = _check_type_checkpoint(checkpoint)
1515-
checkpoint = checkpoint._get_state_dict()
1516-
1517-
if isinstance(checkpoint, dict) and backend == "torch":
1518-
raise RuntimeError(
1519-
"Cannot use 'torch' backend with a dictionary-based checkpoint. "
1520-
"Please try a different backend.")
1521-
if not isinstance(checkpoint, dict) and backend == "sklearn":
1522-
raise RuntimeError(
1523-
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524-
"Please try a different backend.")
15251517

15261518
if backend != "sklearn":
15271519
raise ValueError(f"Unsupported backend: {backend}")
15281520

1521+
if not isinstance(checkpoint, dict):
1522+
raise RuntimeError(
1523+
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524+
f"Please try a different backend. Got: {type(checkpoint)}")
1525+
15291526
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
15301527

15311528
n_features = cebra_.n_features_

0 commit comments

Comments
 (0)