Skip to content

Commit 589aa81

Browse files
committed
avoid clashing registrations
1 parent 1906fd9 commit 589aa81

1 file changed

Lines changed: 21 additions & 19 deletions

File tree

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
526526

527527
_obj: Any
528528
_is_iter: bool
529-
_registered: ClassVar[bool] = False
529+
_registered: ClassVar[set[JitLibrary]] = set()
530530
__slots__: tuple[str, ...] = ("_is_iter", "_obj")
531531

532532
def __init__(self, obj: T, jit_library: JitLibrary) -> None: # numpydoc ignore=GL08
@@ -549,24 +549,26 @@ def _register(cls, jit_library: JitLibrary) -> None: # numpydoc ignore=SS06,PR0
549549
Register upon first use instead of at import time, to avoid
550550
globally importing JAX.
551551
"""
552-
if not cls._registered:
553-
if jit_library is JitLibrary.jax:
554-
import jax
555-
556-
jax.tree_util.register_pytree_node(
557-
cls,
558-
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
559-
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
560-
)
561-
elif jit_library is JitLibrary.torch:
562-
import torch
563-
564-
torch.utils._pytree.register_pytree_node(
565-
cls,
566-
lambda instance: pickle_flatten(instance, torch.Tensor), # pyright: ignore[reportUnknownArgumentType]
567-
pickle_unflatten,
568-
)
569-
cls._registered = True
552+
if jit_library in cls._registered:
553+
return
554+
555+
if jit_library is JitLibrary.jax:
556+
import jax
557+
558+
jax.tree_util.register_pytree_node(
559+
cls,
560+
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
561+
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
562+
)
563+
elif jit_library is JitLibrary.torch:
564+
import torch
565+
566+
torch.utils._pytree.register_pytree_node(
567+
cls,
568+
lambda instance: pickle_flatten(instance, torch.Tensor), # pyright: ignore[reportUnknownArgumentType]
569+
pickle_unflatten,
570+
)
571+
cls._registered.add(jit_library)
570572

571573

572574
class JitLibrary(Enum):

0 commit comments

Comments
 (0)