@@ -526,7 +526,7 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
526526
527527 _obj : Any
528528 _is_iter : bool
529- _registered : ClassVar [bool ] = False
529+ _registered : ClassVar [set [ JitLibrary ]] = set ()
530530 __slots__ : tuple [str , ...] = ("_is_iter" , "_obj" )
531531
532532 def __init__ (self , obj : T , jit_library : JitLibrary ) -> None : # numpydoc ignore=GL08
@@ -549,24 +549,26 @@ def _register(cls, jit_library: JitLibrary) -> None: # numpydoc ignore=SS06,PR0
549549 Register upon first use instead of at import time, to avoid
550550 globally importing JAX.
551551 """
552- if not cls ._registered :
553- if jit_library is JitLibrary .jax :
554- import jax
555-
556- jax .tree_util .register_pytree_node (
557- cls ,
558- lambda instance : pickle_flatten (instance , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
559- lambda aux_data , children : pickle_unflatten (children , aux_data ), # pyright: ignore[reportUnknownArgumentType]
560- )
561- elif jit_library is JitLibrary .torch :
562- import torch
563-
564- torch .utils ._pytree .register_pytree_node (
565- cls ,
566- lambda instance : pickle_flatten (instance , torch .Tensor ), # pyright: ignore[reportUnknownArgumentType]
567- pickle_unflatten ,
568- )
569- cls ._registered = True
552+ if jit_library in cls ._registered :
553+ return
554+
555+ if jit_library is JitLibrary .jax :
556+ import jax
557+
558+ jax .tree_util .register_pytree_node (
559+ cls ,
560+ lambda instance : pickle_flatten (instance , jax .Array ), # pyright: ignore[reportUnknownArgumentType]
561+ lambda aux_data , children : pickle_unflatten (children , aux_data ), # pyright: ignore[reportUnknownArgumentType]
562+ )
563+ elif jit_library is JitLibrary .torch :
564+ import torch
565+
566+ torch .utils ._pytree .register_pytree_node (
567+ cls ,
568+ lambda instance : pickle_flatten (instance , torch .Tensor ), # pyright: ignore[reportUnknownArgumentType]
569+ pickle_unflatten ,
570+ )
571+ cls ._registered .add (jit_library )
570572
571573
572574class JitLibrary (Enum ):
0 commit comments