Skip to content

Commit 3afe82e

Browse files
committed
Merge branch 'stes/fix-model-load' into stes/0.6.1
2 parents 595ffdf + abdba00 commit 3afe82e

4 files changed

Lines changed: 194 additions & 56 deletions

File tree

.github/workflows/build.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,28 @@ jobs:
2929
# https://pytorch.org/get-started/previous-versions/
3030
torch-version: ["2.6.0", "2.10.0"]
3131
sklearn-version: ["latest"]
32+
numpy-version: ["latest"]
33+
3234
include:
3335
# windows test with standard config
3436
- os: windows-latest
3537
torch-version: 2.6.0
3638
python-version: "3.12"
3739
sklearn-version: "latest"
40+
numpy-version: "latest"
3841

3942
# legacy sklearn (several API differences)
4043
- os: ubuntu-latest
4144
torch-version: 2.6.0
4245
python-version: "3.12"
4346
sklearn-version: "legacy"
47+
numpy-version: "latest"
48+
49+
- os: ubuntu-latest
50+
torch-version: 2.6.0
51+
python-version: "3.12"
52+
sklearn-version: "latest"
53+
numpy-version: "legacy"
4454

4555
# TODO(stes): latest torch and python
4656
# requires a PyTables release compatible with
@@ -55,6 +65,7 @@ jobs:
5565
torch-version: 2.4.0
5666
python-version: "3.10"
5767
sklearn-version: "legacy"
68+
numpy-version: "latest"
5869

5970
runs-on: ${{ matrix.os }}
6071

@@ -88,6 +99,11 @@ jobs:
8899
run: |
89100
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
90101
102+
- name: Check numpy legacy version
103+
if: matrix.numpy-version == 'legacy'
104+
run: |
105+
pip install "numpy<2" '.[dev,datasets,integrations]'
106+
91107
- name: Run the formatter
92108
run: |
93109
make format

cebra/integrations/sklearn/cebra.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import importlib.metadata
2525
import itertools
26+
import pickle
27+
import warnings
2628
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
2729
Union)
2830

@@ -1336,6 +1338,26 @@ def _get_state(self):
13361338
}
13371339
return state
13381340

1341+
def _get_state_dict(self):
1342+
backend = "sklearn"
1343+
return {
1344+
'args': self.get_params(),
1345+
'state': self._get_state(),
1346+
'state_dict': self.solver_.state_dict(),
1347+
'metadata': {
1348+
'backend':
1349+
backend,
1350+
'cebra_version':
1351+
cebra.__version__,
1352+
'torch_version':
1353+
torch.__version__,
1354+
'numpy_version':
1355+
np.__version__,
1356+
'sklearn_version':
1357+
importlib.metadata.distribution("scikit-learn").version
1358+
}
1359+
}
1360+
13391361
def save(self,
13401362
filename: str,
13411363
backend: Literal["torch", "sklearn"] = "sklearn"):
@@ -1384,28 +1406,16 @@ def save(self,
13841406
"""
13851407
if sklearn_utils.check_fitted(self):
13861408
if backend == "torch":
1409+
warnings.warn(
1410+
"Saving with backend='torch' is deprecated and will be removed in a future version. "
1411+
"Please use backend='sklearn' instead.",
1412+
DeprecationWarning,
1413+
stacklevel=2,
1414+
)
13871415
checkpoint = torch.save(self, filename)
13881416

13891417
elif backend == "sklearn":
1390-
checkpoint = torch.save(
1391-
{
1392-
'args': self.get_params(),
1393-
'state': self._get_state(),
1394-
'state_dict': self.solver_.state_dict(),
1395-
'metadata': {
1396-
'backend':
1397-
backend,
1398-
'cebra_version':
1399-
cebra.__version__,
1400-
'torch_version':
1401-
torch.__version__,
1402-
'numpy_version':
1403-
np.__version__,
1404-
'sklearn_version':
1405-
importlib.metadata.distribution("scikit-learn"
1406-
).version
1407-
}
1408-
}, filename)
1418+
checkpoint = torch.save(self._get_state_dict(), filename)
14091419
else:
14101420
raise NotImplementedError(f"Unsupported backend: {backend}")
14111421
else:
@@ -1457,15 +1467,52 @@ def load(cls,
14571467
>>> tmp_file.unlink()
14581468
"""
14591469
supported_backends = ["auto", "sklearn", "torch"]
1470+
14601471
if backend not in supported_backends:
14611472
raise NotImplementedError(
14621473
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
14631474
)
14641475

1465-
checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
1476+
if backend not in ["auto", "sklearn"]:
1477+
warnings.warn(
1478+
"From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
1479+
"the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
1480+
category=DeprecationWarning,
1481+
stacklevel=2,
1482+
)
14661483

1467-
if backend == "auto":
1468-
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
1484+
backend = "sklearn"
1485+
1486+
# NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
1487+
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
1488+
# introduced in torch 2.6.0.
1489+
try:
1490+
checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs)
1491+
except pickle.UnpicklingError as e:
1492+
if weights_only is not False:
1493+
if packaging.version.parse(
1494+
cebra.__version__) < packaging.version.parse("0.7"):
1495+
warnings.warn(
1496+
"Failed to unpickle checkpoint with weights_only=True. "
1497+
"Falling back to loading with weights_only=False. "
1498+
"This is unsafe and should only be done if you trust the source of the model file. "
1499+
"In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
1500+
category=UserWarning,
1501+
stacklevel=2,
1502+
)
1503+
else:
1504+
raise ValueError(
1505+
"Failed to unpickle checkpoint with weights_only=True. "
1506+
"This may be due to an incompatible model file format. "
1507+
"To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
1508+
"Example: CEBRA.load(filename, weights_only=False)."
1509+
) from e
1510+
1511+
checkpoint = _safe_torch_load(filename,
1512+
weights_only=False,
1513+
**kwargs)
1514+
checkpoint = _check_type_checkpoint(checkpoint)
1515+
checkpoint = checkpoint._get_state_dict()
14691516

14701517
if isinstance(checkpoint, dict) and backend == "torch":
14711518
raise RuntimeError(
@@ -1476,10 +1523,10 @@ def load(cls,
14761523
"Cannot use 'sklearn' backend with a non dictionary-based checkpoint. "
14771524
"Please try a different backend.")
14781525

1479-
if backend == "sklearn":
1480-
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
1481-
else:
1482-
cebra_ = _check_type_checkpoint(checkpoint)
1526+
if backend != "sklearn":
1527+
raise ValueError(f"Unsupported backend: {backend}")
1528+
1529+
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
14831530

14841531
n_features = cebra_.n_features_
14851532
cebra_.solver_.n_features = ([

cebra/registry.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from __future__ import annotations
4747

4848
import fnmatch
49+
import functools
4950
import itertools
5051
import sys
5152
import textwrap
@@ -214,14 +215,29 @@ def _zip_dict(d):
214215
yield dict(zip(keys, combination))
215216

216217
def _create_class(cls, **default_kwargs):
218+
class_name = pattern.format(**default_kwargs)
217219

218-
@register(pattern.format(**default_kwargs), base=pattern)
220+
@register(class_name, base=pattern)
219221
class _ParametrizedClass(cls):
220222

221223
def __init__(self, *args, **kwargs):
222224
default_kwargs.update(kwargs)
223225
super().__init__(*args, **default_kwargs)
224226

227+
# Make the class pickleable by copying metadata from the base class
228+
# and registering it in the module namespace
229+
functools.update_wrapper(_ParametrizedClass, cls, updated=[])
230+
231+
# Set a unique qualname so pickle finds this class, not the base class
232+
unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
233+
_ParametrizedClass.__qualname__ = unique_name
234+
_ParametrizedClass.__name__ = unique_name
235+
236+
# Register in module namespace so pickle can find it via getattr
237+
parent_module = sys.modules.get(cls.__module__)
238+
if parent_module is not None:
239+
setattr(parent_module, unique_name, _ParametrizedClass)
240+
225241
def _parametrize(cls):
226242
for _default_kwargs in kwargs:
227243
_create_class(cls, **_default_kwargs)

0 commit comments

Comments
 (0)