Skip to content

Commit 721eabb

Browse files
committed
Merge branch 'stes/fix-model-load' into stes/0.6.1
2 parents f3dfac1 + bd27653 commit 721eabb

1 file changed

Lines changed: 28 additions & 26 deletions

File tree

cebra/integrations/sklearn/cebra.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,13 @@
5252
# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
5353
# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
5454
CEBRA_LOAD_SAFE_GLOBALS = [
55-
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
56-
np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType
55+
cebra.data.Offset,
56+
torch.torch_version.TorchVersion,
57+
np.dtype,
58+
np.dtypes.Int32DType,
59+
np.dtypes.Int64DType,
60+
np.dtypes.Float32DType,
61+
np.dtypes.Float64DType,
5762
]
5863

5964

@@ -64,20 +69,22 @@ def check_version(estimator):
6469
sklearn.__version__) < packaging.version.parse("1.6.dev")
6570

6671

67-
def _safe_torch_load(filename, weights_only, **kwargs):
68-
if weights_only is None:
69-
if packaging.version.parse(
70-
torch.__version__) >= packaging.version.parse("2.6.0"):
71-
weights_only = True
72-
else:
73-
weights_only = False
72+
def _safe_torch_load(filename, weights_only=False, **kwargs):
73+
checkpoint = None
74+
legacy_mode = packaging.version.parse(
75+
torch.__version__) < packaging.version.parse("2.6.0")
7476

75-
if not weights_only:
77+
if legacy_mode:
7678
checkpoint = torch.load(filename, weights_only=False, **kwargs)
7779
else:
78-
# NOTE(stes): This is only supported for torch 2.6+
7980
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
80-
checkpoint = torch.load(filename, weights_only=True, **kwargs)
81+
checkpoint = torch.load(filename,
82+
weights_only=weights_only,
83+
**kwargs)
84+
85+
if not isinstance(checkpoint, dict):
86+
_check_type_checkpoint(checkpoint)
87+
checkpoint = checkpoint._get_state_dict()
8188

8289
return checkpoint
8390

@@ -317,8 +324,9 @@ def _require_arg(key):
317324

318325
def _check_type_checkpoint(checkpoint):
319326
if not isinstance(checkpoint, cebra.CEBRA):
320-
raise RuntimeError("Model loaded from file is not compatible with "
321-
"the current CEBRA version.")
327+
raise RuntimeError(
328+
"Model loaded from file is not compatible with "
329+
f"the current CEBRA version. Got: {type(checkpoint)}")
322330
if not sklearn_utils.check_fitted(checkpoint):
323331
raise ValueError(
324332
"CEBRA model is not fitted. Loading it is not supported.")
@@ -1487,7 +1495,7 @@ def load(cls,
14871495
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
14881496
# introduced in torch 2.6.0.
14891497
try:
1490-
checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs)
1498+
checkpoint = _safe_torch_load(filename, **kwargs)
14911499
except pickle.UnpicklingError as e:
14921500
if weights_only is not False:
14931501
if packaging.version.parse(
@@ -1511,21 +1519,15 @@ def load(cls,
15111519
checkpoint = _safe_torch_load(filename,
15121520
weights_only=False,
15131521
**kwargs)
1514-
checkpoint = _check_type_checkpoint(checkpoint)
1515-
checkpoint = checkpoint._get_state_dict()
1516-
1517-
if isinstance(checkpoint, dict) and backend == "torch":
1518-
raise RuntimeError(
1519-
"Cannot use 'torch' backend with a dictionary-based checkpoint. "
1520-
"Please try a different backend.")
1521-
if not isinstance(checkpoint, dict) and backend == "sklearn":
1522-
raise RuntimeError(
1523-
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1524-
"Please try a different backend.")
15251522

15261523
if backend != "sklearn":
15271524
raise ValueError(f"Unsupported backend: {backend}")
15281525

1526+
if not isinstance(checkpoint, dict):
1527+
raise RuntimeError(
1528+
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1529+
f"Please try a different backend. Got: {type(checkpoint)}")
1530+
15291531
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
15301532

15311533
n_features = cebra_.n_features_

0 commit comments

Comments
 (0)