Skip to content

Commit 1906fd9

Browse files
committed
ENH: testing.lazy_xp_function: torch.compile support
1 parent 05752f3 commit 1906fd9

4 files changed

Lines changed: 127 additions & 49 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,10 @@ dask-core = ">=2026.3.0" # No distributed, tornado, etc.
 minversion = "6.0"
 addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"]
 xfail_strict = true
-filterwarnings = ["error"]
+filterwarnings = [
+    "error",
+    'ignore:.*torch.jit.script_method.*',
+]
 log_cli_level = "INFO"
 testpaths = ["tests"]
 markers = [

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
 
 from __future__ import annotations
 
+import functools
 import io
 import math
 import pickle
 import types
 from collections.abc import Callable, Generator, Iterable, Iterator
+from enum import Enum, auto
 from functools import wraps
 from types import ModuleType
 from typing import (
@@ -47,12 +49,13 @@ def override(func):
 
 
 __all__ = [
+    "JitLibrary",
     "asarrays",
+    "autojit",
     "capabilities",
     "eager_shape",
     "in1d",
     "is_python_scalar",
-    "jax_autojit",
     "mean",
     "meta_namespace",
     "pickle_flatten",
@@ -515,7 +518,7 @@ def persistent_load(self, pid: Literal[0, 1]) -> object:  # numpydoc ignore=GL08
 
 class _AutoJITWrapper(Generic[T]):  # numpydoc ignore=PR01
     """
-    Helper of :func:`jax_autojit`.
+    Helper of :func:`autojit`.
 
     Wrap arbitrary inputs and outputs of the jitted function and
     convert them to/from PyTrees.
@@ -526,9 +529,9 @@ class _AutoJITWrapper(Generic[T]):  # numpydoc ignore=PR01
     _registered: ClassVar[bool] = False
     __slots__: tuple[str, ...] = ("_is_iter", "_obj")
 
-    def __init__(self, obj: T) -> None:  # numpydoc ignore=GL08
-        self._register()
-        if isinstance(obj, Iterator):
+    def __init__(self, obj: T, jit_library: JitLibrary) -> None:  # numpydoc ignore=GL08
+        self._register(jit_library)
+        if jit_library is JitLibrary.jax and isinstance(obj, Iterator):
             self._obj = list(obj)
             self._is_iter = True
         else:
@@ -541,24 +544,42 @@ def obj(self) -> T:  # numpydoc ignore=RT01
         return iter(self._obj) if self._is_iter else self._obj
 
     @classmethod
-    def _register(cls) -> None:  # numpydoc ignore=SS06
+    def _register(cls, jit_library: JitLibrary) -> None:  # numpydoc ignore=SS06,PR01
         """
         Register upon first use instead of at import time, to avoid
         globally importing JAX.
         """
         if not cls._registered:
-            import jax
-
-            jax.tree_util.register_pytree_node(
-                cls,
-                lambda instance: pickle_flatten(instance, jax.Array),  # pyright: ignore[reportUnknownArgumentType]
-                lambda aux_data, children: pickle_unflatten(children, aux_data),  # pyright: ignore[reportUnknownArgumentType]
-            )
+            if jit_library is JitLibrary.jax:
+                import jax
+
+                jax.tree_util.register_pytree_node(
+                    cls,
+                    lambda instance: pickle_flatten(instance, jax.Array),  # pyright: ignore[reportUnknownArgumentType]
+                    lambda aux_data, children: pickle_unflatten(children, aux_data),  # pyright: ignore[reportUnknownArgumentType]
+                )
+            elif jit_library is JitLibrary.torch:
+                import torch
+
+                torch.utils._pytree.register_pytree_node(
+                    cls,
+                    lambda instance: pickle_flatten(instance, torch.Tensor),  # pyright: ignore[reportUnknownArgumentType]
+                    pickle_unflatten,
+                )
             cls._registered = True
 
 
-def jax_autojit(
-    func: Callable[P, T],
+class JitLibrary(Enum):
+    """
+    Enum for JIT libraries compatible with `autojit`.
+    """
+
+    jax = auto()
+    torch = auto()
+
+
+def autojit(
+    func: Callable[P, T], jit_library: JitLibrary
 ) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01,SS03
     """
     Wrap `func` with ``jax.jit``, with the following differences:
@@ -601,19 +622,26 @@ def f(x: Array, y: float, plus: bool) -> Array:
     ``j1``, but on the flip side it means that it will be re-traced for every different
     value of ``y``, which likely makes it not fit for purpose in production.
     """
-    import jax
+    if jit_library is JitLibrary.jax:
+        import jax
+
+        jit_decorator = jax.jit
+    elif jit_library is JitLibrary.torch:
+        import torch
+
+        jit_decorator = functools.partial(torch.compile, fullgraph=True)
 
-    @jax.jit  # type: ignore[untyped-decorator]  # pyright: ignore[reportUntypedFunctionDecorator]
+    @jit_decorator  # type: ignore[untyped-decorator]  # pyright: ignore[reportUntypedFunctionDecorator]
     def inner(  # numpydoc ignore=GL08
         wargs: _AutoJITWrapper[Any],
     ) -> _AutoJITWrapper[T]:
         args, kwargs = wargs.obj
         res = func(*args, **kwargs)  # pyright: ignore[reportCallIssue]
-        return _AutoJITWrapper(res)
+        return _AutoJITWrapper(res, jit_library)
 
     @wraps(func)
     def outer(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
-        wargs = _AutoJITWrapper((args, kwargs))
+        wargs = _AutoJITWrapper((args, kwargs), jit_library)
         return inner(wargs).obj
 
     return outer

src/array_api_extra/testing.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
 from types import FunctionType, ModuleType
 from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
 
-from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
-from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten
+from ._lib._utils._compat import is_dask_namespace, is_jax_namespace, is_torch_namespace
+from ._lib._utils._helpers import JitLibrary, autojit, pickle_flatten, pickle_unflatten
 
 __all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
 
@@ -67,6 +67,7 @@ def lazy_xp_function(
     *,
     allow_dask_compute: bool | int = False,
     jax_jit: bool = True,
+    torch_compile: bool = True,
     static_argnums: Deprecated = DEPRECATED,
     static_argnames: Deprecated = DEPRECATED,
 ) -> None:  # numpydoc ignore=GL07
@@ -222,6 +223,7 @@ def test_myfunc(xp):
     tags: dict[str, bool | int | type] = {
         "allow_dask_compute": allow_dask_compute,
         "jax_jit": jax_jit,
+        "torch_compile": torch_compile,
     }
 
     if isinstance(func, tuple):
@@ -419,7 +421,19 @@ def iter_tagged() -> Iterator[
     elif is_jax_namespace(xp):
         for target, name, attr, func, tags in iter_tagged():
             if tags["jax_jit"]:
-                wrapped = jax_autojit(func)
+                wrapped = autojit(func, JitLibrary.jax)
+                # If we're dealing with a staticmethod or classmethod, make
+                # sure things stay that way.
+                if isinstance(attr, staticmethod):
+                    wrapped = staticmethod(wrapped)
+                elif isinstance(attr, classmethod):
+                    wrapped = classmethod(wrapped)
+                temp_setattr(target, name, wrapped)
+
+    elif is_torch_namespace(xp):
+        for target, name, attr, func, tags in iter_tagged():
+            if tags["torch_compile"]:
+                wrapped = autojit(func, JitLibrary.torch)
                 # If we're dealing with a staticmethod or classmethod, make
                 # sure things stay that way.
                 if isinstance(attr, staticmethod):

tests/test_helpers.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
-from collections.abc import Iterator
+import functools
+from collections.abc import Callable, Iterator
 from types import ModuleType
-from typing import TYPE_CHECKING, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Generic, ParamSpec, Protocol, TypeVar, cast
 
 import numpy as np
 import pytest
@@ -10,11 +11,12 @@
 from array_api_extra._lib._utils._compat import array_namespace
 from array_api_extra._lib._utils._compat import device as get_device
 from array_api_extra._lib._utils._helpers import (
+    JitLibrary,
     asarrays,
+    autojit,
     capabilities,
     eager_shape,
     in1d,
-    jax_autojit,
     meta_namespace,
     ndindex,
     pickle_flatten,
@@ -34,6 +36,7 @@ def override(func):
     return func
 
 
+P = ParamSpec("P")
 T = TypeVar("T")
 
 # FIXME calls xp.unique_values without size
@@ -359,41 +362,48 @@ def test_recursion(self):
         assert obj2[1] is obj2
 
 
-class TestJAXAutoJIT:
-    def test_basic(self, jnp: ModuleType):
-        @jax_autojit
+class AutoJitFunc(Protocol):
+    def __call__(
+        self,
+        func: Callable[P, T],
+    ) -> Callable[P, T]: ...
+
+
+class CheckAutoJIT:
+    def test_basic(self, autojit_func: AutoJitFunc, xp: ModuleType):
+        @autojit_func
         def f(x: Array, k: object = False) -> Array:
             return x + 1 if k else x - 1
 
         # Basic recognition of static_argnames
-        xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1]))
-        xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1]))
-        xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3]))
-        xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3]))
+        xp_assert_equal(f(xp.asarray([1, 2])), xp.asarray([0, 1]))
+        xp_assert_equal(f(xp.asarray([1, 2]), False), xp.asarray([0, 1]))
+        xp_assert_equal(f(xp.asarray([1, 2]), True), xp.asarray([2, 3]))
+        xp_assert_equal(f(xp.asarray([1, 2]), 1), xp.asarray([2, 3]))
 
         # static argument is not an ArrayLike
-        xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3]))
+        xp_assert_equal(f(xp.asarray([1, 2]), "foo"), xp.asarray([2, 3]))
 
         # static argument is not hashable, but serializable
-        xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3]))
+        xp_assert_equal(f(xp.asarray([1, 2]), ["foo"]), xp.asarray([2, 3]))
 
-    def test_wrapper(self, jnp: ModuleType):
-        @jax_autojit
+    def test_wrapper(self, autojit_func: AutoJitFunc, xp: ModuleType):
+        @autojit_func
         def f(w: Wrapper[Array]) -> Wrapper[Array]:
             return Wrapper(w.x + 1)
 
-        inp = Wrapper(jnp.asarray([1, 2]))
+        inp = Wrapper(xp.asarray([1, 2]))
         out = f(inp).x
-        xp_assert_equal(out, jnp.asarray([2, 3]))
+        xp_assert_equal(out, xp.asarray([2, 3]))
 
-    def test_static_hashable(self, jnp: ModuleType):
+    def test_static_hashable(self, autojit_func: AutoJitFunc, xp: ModuleType):
         """Static argument/return value is hashable, but not serializable"""
 
         class C:
             def __reduce__(self) -> object:  # type: ignore[override]  # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
                 raise Exception()
 
-        @jax_autojit
+        @autojit_func
         def f(x: object) -> object:
             return x
 
@@ -402,17 +412,20 @@ def f(x: object) -> object:
         assert out is inp
 
         # Serializable opaque input contains non-serializable object plus array
-        winp = Wrapper((C(), jnp.asarray([1, 2])))
+        winp = Wrapper((C(), xp.asarray([1, 2])))
         out = f(winp)
         assert isinstance(out, Wrapper)
         assert out.x[0] is winp.x[0]
         assert out.x[1] is not winp.x[1]
         xp_assert_equal(out.x[1], winp.x[1])
 
-    def test_arraylikes_are_static(self):
+    def test_arraylikes_are_static(
+        self,
+        autojit_func: AutoJitFunc,
+    ):
         pytest.importorskip("jax")
 
-        @jax_autojit
+        @autojit_func
         def f(x: list[int]) -> list[int]:
             assert isinstance(x, list)
             assert x == [1, 2]
@@ -422,15 +435,35 @@ def f(x: list[int]) -> list[int]:
         assert isinstance(out, list)
         assert out == [3, 4]
 
-    def test_iterators(self, jnp: ModuleType):
-        @jax_autojit
+    def test_iterators(self, autojit_func: AutoJitFunc, xp: ModuleType):
+        @autojit_func
         def f(x: Array) -> Iterator[Array]:
             return (x + i for i in range(2))
 
-        inp = jnp.asarray([1, 2])
+        inp = xp.asarray([1, 2])
         out = f(inp)
         assert isinstance(out, Iterator)
-        xp_assert_equal(next(out), jnp.asarray([1, 2]))
-        xp_assert_equal(next(out), jnp.asarray([2, 3]))
+        xp_assert_equal(next(out), xp.asarray([1, 2]))
+        xp_assert_equal(next(out), xp.asarray([2, 3]))
         with pytest.raises(StopIteration):
             _ = next(out)
+
+
+class TestJAXAutoJit(CheckAutoJIT):
+    @pytest.fixture
+    def xp(self, jnp: ModuleType) -> ModuleType:
+        return jnp
+
+    @pytest.fixture
+    def autojit_func(self) -> AutoJitFunc:
+        return functools.partial(autojit, jit_library=JitLibrary.jax)
+
+
+class TestTorchAutoJit(CheckAutoJIT):
+    @pytest.fixture
+    def xp(self, torch: ModuleType) -> ModuleType:
+        return torch
+
+    @pytest.fixture
+    def autojit_func(self) -> AutoJitFunc:
+        return functools.partial(autojit, jit_library=JitLibrary.torch)

0 commit comments

Comments
 (0)