Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,11 @@ dask-core = ">=2026.3.0" # No distributed, tornado, etc.
minversion = "6.0"
addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
xfail_strict = true
filterwarnings = ["error"]
filterwarnings = [
"error",
'ignore:.*torch.jit.script_method.*',
'ignore:.*accumulated_recompile_limit reached.*',
]
log_cli_level = "INFO"
testpaths = ["tests"]
markers = [
Expand Down
60 changes: 45 additions & 15 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

import functools
import io
import math
import pickle
import types
from collections.abc import Callable, Generator, Iterable, Iterator
from enum import Enum, auto
from functools import wraps
from types import ModuleType
from typing import (
Expand Down Expand Up @@ -47,12 +49,13 @@ def override(func):


__all__ = [
"JitLibrary",
"asarrays",
"autojit",
"capabilities",
"eager_shape",
"in1d",
"is_python_scalar",
"jax_autojit",
"mean",
"meta_namespace",
"pickle_flatten",
Expand Down Expand Up @@ -515,20 +518,20 @@ def persistent_load(self, pid: Literal[0, 1]) -> object: # numpydoc ignore=GL08

class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
"""
Helper of :func:`jax_autojit`.
Helper of :func:`autojit`.

Wrap arbitrary inputs and outputs of the jitted function and
convert them to/from PyTrees.
"""

_obj: Any
_is_iter: bool
_registered: ClassVar[bool] = False
_registered: ClassVar[set[JitLibrary]] = set()
__slots__: tuple[str, ...] = ("_is_iter", "_obj")

def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
self._register()
if isinstance(obj, Iterator):
def __init__(self, obj: T, jit_library: JitLibrary) -> None: # numpydoc ignore=GL08
self._register(jit_library)
if jit_library is JitLibrary.jax and isinstance(obj, Iterator):
self._obj = list(obj)
self._is_iter = True
Comment on lines +534 to 536
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

claude reckoned that, unlike JAX, we needn't treat iterables with this special case for torch.compile

else:
Expand All @@ -541,24 +544,44 @@ def obj(self) -> T: # numpydoc ignore=RT01
return iter(self._obj) if self._is_iter else self._obj

@classmethod
def _register(cls) -> None: # numpydoc ignore=SS06
def _register(cls, jit_library: JitLibrary) -> None: # numpydoc ignore=SS06,PR01
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
"""
if not cls._registered:
if jit_library in cls._registered:
return

if jit_library is JitLibrary.jax:
import jax

jax.tree_util.register_pytree_node(
cls,
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
)
cls._registered = True
elif jit_library is JitLibrary.torch:
import torch

torch.utils._pytree.register_pytree_node(
cls,
lambda instance: pickle_flatten(instance, torch.Tensor), # pyright: ignore[reportUnknownArgumentType]
pickle_unflatten,
)
cls._registered.add(jit_library)


def jax_autojit(
func: Callable[P, T],
class JitLibrary(Enum):
"""
Enum for JIT libraries compatible with `autojit`.
"""

jax = auto()
torch = auto()


def autojit(
func: Callable[P, T], jit_library: JitLibrary
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
"""
Wrap `func` with ``jax.jit``, with the following differences:
Expand Down Expand Up @@ -601,19 +624,26 @@ def f(x: Array, y: float, plus: bool) -> Array:
``j1``, but on the flip side it means that it will be re-traced for every different
value of ``y``, which likely makes it not fit for purpose in production.
"""
import jax
if jit_library is JitLibrary.jax:
import jax

jit_decorator = jax.jit
elif jit_library is JitLibrary.torch:
import torch

jit_decorator = functools.partial(torch.compile, fullgraph=True)

@jax.jit # type: ignore[untyped-decorator] # pyright: ignore[reportUntypedFunctionDecorator]
@jit_decorator # type: ignore[untyped-decorator] # pyright: ignore[reportUntypedFunctionDecorator]
def inner( # numpydoc ignore=GL08
wargs: _AutoJITWrapper[Any],
) -> _AutoJITWrapper[T]:
args, kwargs = wargs.obj
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
return _AutoJITWrapper(res)
return _AutoJITWrapper(res, jit_library)

@wraps(func)
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
wargs = _AutoJITWrapper((args, kwargs))
wargs = _AutoJITWrapper((args, kwargs), jit_library)
Comment thread
lucascolley marked this conversation as resolved.
return inner(wargs).obj

return outer
20 changes: 17 additions & 3 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from types import FunctionType, ModuleType
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast

from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace, is_torch_namespace
from ._lib._utils._helpers import JitLibrary, autojit, pickle_flatten, pickle_unflatten

__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]

Expand Down Expand Up @@ -67,6 +67,7 @@ def lazy_xp_function(
*,
allow_dask_compute: bool | int = False,
jax_jit: bool = True,
torch_compile: bool = True,
static_argnums: Deprecated = DEPRECATED,
static_argnames: Deprecated = DEPRECATED,
) -> None: # numpydoc ignore=GL07
Expand Down Expand Up @@ -222,6 +223,7 @@ def test_myfunc(xp):
tags: dict[str, bool | int | type] = {
"allow_dask_compute": allow_dask_compute,
"jax_jit": jax_jit,
"torch_compile": torch_compile,
}

if isinstance(func, tuple):
Expand Down Expand Up @@ -419,7 +421,19 @@ def iter_tagged() -> Iterator[
elif is_jax_namespace(xp):
for target, name, attr, func, tags in iter_tagged():
if tags["jax_jit"]:
wrapped = jax_autojit(func)
wrapped = autojit(func, JitLibrary.jax)
# If we're dealing with a staticmethod or classmethod, make
# sure things stay that way.
if isinstance(attr, staticmethod):
wrapped = staticmethod(wrapped)
elif isinstance(attr, classmethod):
wrapped = classmethod(wrapped)
temp_setattr(target, name, wrapped)

elif is_torch_namespace(xp):
for target, name, attr, func, tags in iter_tagged():
if tags["torch_compile"]:
wrapped = autojit(func, JitLibrary.torch)
Comment on lines +424 to +436
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor: could reduce some LoC perhaps)

# If we're dealing with a staticmethod or classmethod, make
# sure things stay that way.
if isinstance(attr, staticmethod):
Expand Down
85 changes: 59 additions & 26 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Iterator
import functools
from collections.abc import Callable, Iterator
from types import ModuleType
from typing import TYPE_CHECKING, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Generic, ParamSpec, Protocol, TypeVar, cast

import numpy as np
import pytest
Expand All @@ -10,11 +11,12 @@
from array_api_extra._lib._utils._compat import array_namespace
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import (
JitLibrary,
asarrays,
autojit,
capabilities,
eager_shape,
in1d,
jax_autojit,
meta_namespace,
ndindex,
pickle_flatten,
Expand All @@ -34,6 +36,7 @@ def override(func):
return func


P = ParamSpec("P")
T = TypeVar("T")

# FIXME calls xp.unique_values without size
Expand Down Expand Up @@ -359,41 +362,48 @@ def test_recursion(self):
assert obj2[1] is obj2


class TestJAXAutoJIT:
def test_basic(self, jnp: ModuleType):
@jax_autojit
class AutoJitFunc(Protocol):
def __call__(
self,
func: Callable[P, T],
) -> Callable[P, T]: ...


class CheckAutoJIT:
def test_basic(self, autojit_func: AutoJitFunc, xp: ModuleType):
@autojit_func
def f(x: Array, k: object = False) -> Array:
return x + 1 if k else x - 1

# Basic recognition of static_argnames
xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1]))
xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1]))
xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3]))
xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3]))
xp_assert_equal(f(xp.asarray([1, 2])), xp.asarray([0, 1]))
xp_assert_equal(f(xp.asarray([1, 2]), False), xp.asarray([0, 1]))
xp_assert_equal(f(xp.asarray([1, 2]), True), xp.asarray([2, 3]))
xp_assert_equal(f(xp.asarray([1, 2]), 1), xp.asarray([2, 3]))

# static argument is not an ArrayLike
xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3]))
xp_assert_equal(f(xp.asarray([1, 2]), "foo"), xp.asarray([2, 3]))

# static argument is not hashable, but serializable
xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3]))
xp_assert_equal(f(xp.asarray([1, 2]), ["foo"]), xp.asarray([2, 3]))

def test_wrapper(self, jnp: ModuleType):
@jax_autojit
def test_wrapper(self, autojit_func: AutoJitFunc, xp: ModuleType):
@autojit_func
def f(w: Wrapper[Array]) -> Wrapper[Array]:
return Wrapper(w.x + 1)

inp = Wrapper(jnp.asarray([1, 2]))
inp = Wrapper(xp.asarray([1, 2]))
out = f(inp).x
xp_assert_equal(out, jnp.asarray([2, 3]))
xp_assert_equal(out, xp.asarray([2, 3]))

def test_static_hashable(self, jnp: ModuleType):
def test_static_hashable(self, autojit_func: AutoJitFunc, xp: ModuleType):
"""Static argument/return value is hashable, but not serializable"""

class C:
def __reduce__(self) -> object: # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
raise Exception()

@jax_autojit
@autojit_func
def f(x: object) -> object:
return x

Expand All @@ -402,17 +412,20 @@ def f(x: object) -> object:
assert out is inp

# Serializable opaque input contains non-serializable object plus array
winp = Wrapper((C(), jnp.asarray([1, 2])))
winp = Wrapper((C(), xp.asarray([1, 2])))
out = f(winp)
assert isinstance(out, Wrapper)
assert out.x[0] is winp.x[0]
assert out.x[1] is not winp.x[1]
xp_assert_equal(out.x[1], winp.x[1])

def test_arraylikes_are_static(self):
def test_arraylikes_are_static(
self,
autojit_func: AutoJitFunc,
):
pytest.importorskip("jax")

@jax_autojit
@autojit_func
def f(x: list[int]) -> list[int]:
assert isinstance(x, list)
assert x == [1, 2]
Expand All @@ -422,15 +435,35 @@ def f(x: list[int]) -> list[int]:
assert isinstance(out, list)
assert out == [3, 4]

def test_iterators(self, jnp: ModuleType):
@jax_autojit
def test_iterators(self, autojit_func: AutoJitFunc, xp: ModuleType):
@autojit_func
def f(x: Array) -> Iterator[Array]:
return (x + i for i in range(2))

inp = jnp.asarray([1, 2])
inp = xp.asarray([1, 2])
out = f(inp)
assert isinstance(out, Iterator)
xp_assert_equal(next(out), jnp.asarray([1, 2]))
xp_assert_equal(next(out), jnp.asarray([2, 3]))
xp_assert_equal(next(out), xp.asarray([1, 2]))
xp_assert_equal(next(out), xp.asarray([2, 3]))
with pytest.raises(StopIteration):
_ = next(out)


class TestJAXAutoJit(CheckAutoJIT):
@pytest.fixture
def xp(self, jnp: ModuleType) -> ModuleType:
return jnp

@pytest.fixture
def autojit_func(self) -> AutoJitFunc:
return functools.partial(autojit, jit_library=JitLibrary.jax)


class TestTorchAutoJit(CheckAutoJIT):
@pytest.fixture
def xp(self, torch: ModuleType) -> ModuleType:
return torch

@pytest.fixture
def autojit_func(self) -> AutoJitFunc:
return functools.partial(autojit, jit_library=JitLibrary.torch)
Loading