diff --git a/decent_array/__init__.py b/decent_array/__init__.py index 9893247..e9bd6b4 100644 --- a/decent_array/__init__.py +++ b/decent_array/__init__.py @@ -1,8 +1,73 @@ from decent_array import interoperability, types from decent_array._array import Array +from decent_array._constants import e, inf, nan, pi +from decent_array.types._dtypes import ( + bfloat16, + bool_, + bytes_, + complex64, + complex128, + complex256, + float16, + float32, + float64, + float128, + int8, + int16, + int32, + int64, + object_, + qint8, + qint16, + qint32, + quint8, + quint16, + uint8, + uint16, + uint32, + uint64, + unicode_, + void, +) + +__all_docs__ = [ + "Array", + "interoperability", + "types", +] __all__ = [ "Array", + "bfloat16", + "bool_", + "bytes_", + "complex64", + "complex128", + "complex256", + "e", + "float16", + "float32", + "float64", + "float128", + "inf", + "int8", + "int16", + "int32", + "int64", "interoperability", + "nan", + "object_", + "pi", + "qint8", + "qint16", + "qint32", + "quint8", + "quint16", "types", + "uint8", + "uint16", + "uint32", + "uint64", + "unicode_", + "void", ] diff --git a/decent_array/_array.py b/decent_array/_array.py index 4a0224a..cc73f1c 100644 --- a/decent_array/_array.py +++ b/decent_array/_array.py @@ -29,13 +29,13 @@ from typing import TYPE_CHECKING, Any, Self from decent_array.interoperability._backend_manager import register_backend_listener -from decent_array.types import _STRING_TO_DTYPE +from decent_array.types._dtypes import _BACKEND_DTYPE_TO_DTYPE if TYPE_CHECKING: from numpy.typing import NDArray from decent_array.interoperability._abstracts import Backend - from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices + from decent_array.types import ArrayKey, ArrayTypes, Devices, dtype _BACKEND_INSTANCE: Backend | None = None @@ -49,7 +49,7 @@ def _update_backend(backend: Backend | None) -> None: register_backend_listener(_update_backend) -class Array: # noqa: PLR0904 +class Array: """ Wrapper around a single backend-native array. @@ -60,7 +60,7 @@ class Array: # noqa: PLR0904 __slots__ = ("_backend", "value") - def __init__(self, value: SupportedArrayTypes) -> None: + def __init__(self, value: ArrayTypes) -> None: """ Wrap ``value`` in an :class:`Array`. @@ -401,21 +401,12 @@ def ndim(self) -> int: return self._backend.ndim(self) @property - def dtype(self) -> DTypes: - """ - Return dtype of the Array as item of DTypes enum. - - Raises: - ValueError: for dtypes that are not supported by all decent-array functions - - """ - # get framework-native dtype as string - # split takes care of types with names like "torch.float32" - dtype_name = str(self.value.dtype).split(".")[-1] + def dtype(self) -> dtype: + """Return dtype of the Array.""" + dtype = _BACKEND_DTYPE_TO_DTYPE.get(self.value.dtype) - dtype = _STRING_TO_DTYPE.get(dtype_name) if dtype is None: - raise ValueError(f"dtype {self.value.dtype} is not supported by all decent-array functions.") + raise ValueError(f"dtype {self.value.dtype} is not supported.") return dtype @@ -445,7 +436,7 @@ def all(self) -> bool: return self._backend.all(self) @property - def device(self) -> SupportedDevices: + def device(self) -> Devices: """Return the device of the array.""" return self._backend.device_of(self) diff --git a/decent_array/_constants.py b/decent_array/_constants.py new file mode 100644 index 0000000..a05e32b --- /dev/null +++ b/decent_array/_constants.py @@ -0,0 +1,10 @@ +"""Numerical constants.""" + +import math + +_CONSTANTS = ["e", "inf", "nan", "pi"] + +e = math.e +inf = math.inf +nan = math.nan +pi = math.pi diff --git a/decent_array/_utils.py b/decent_array/_utils.py new file mode 100644 index 0000000..559e3d5 --- /dev/null +++ b/decent_array/_utils.py @@ -0,0 +1,14 @@ +from typing import Any + +from decent_array._array import Array + + +def unwrap(x: Any) -> Any: # noqa: ANN401 + """ + Return the underlying value of an :class:`Array`, or pass ``x`` through. + + Typed as ``Any`` because operator dunders may pass either an :class:`Array` or a + Python scalar; the strict abstract signature would force a ``cast`` at every call + site without runtime benefit. + """ + return x.value if type(x) is Array else x diff --git a/decent_array/interoperability/__init__.py b/decent_array/interoperability/__init__.py index 8465bca..2fcdd50 100644 --- a/decent_array/interoperability/__init__.py +++ b/decent_array/interoperability/__init__.py @@ -21,7 +21,7 @@ bitwise_right_shift, bitwise_xor, ) -from ._iop.comparasion import equal, greater, greater_equal, less, less_equal, not_equal +from ._iop.comparison import equal, greater, greater_equal, less, less_equal, not_equal from ._iop.creation import eye, ones, ones_like, zeros, zeros_like from ._iop.linalg import dot, matmul, norm, vecdot, vector_norm from ._iop.manipulations import ( diff --git a/decent_array/interoperability/_abstracts/backend.py b/decent_array/interoperability/_abstracts/backend.py index 0d3789b..1912a02 100644 --- a/decent_array/interoperability/_abstracts/backend.py +++ b/decent_array/interoperability/_abstracts/backend.py @@ -18,25 +18,27 @@ from collections.abc import Sequence from typing import TYPE_CHECKING, Any -from decent_array.types import SupportedDevices +from decent_array.types._types import Devices if TYPE_CHECKING: from numpy.typing import NDArray from decent_array import Array - from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes + from decent_array.types import ArrayKey, ArrayTypes + from decent_array.types._dtypes import dtype -class Backend(ABC): # noqa: PLR0904 +class Backend(ABC): """ Abstract base class for a backend. - Concrete backends are bound to a single :class:`SupportedDevices` at construction + Concrete backends are bound to a single :class:`Devices` at construction time; that device is the default for all new arrays produced by this backend. """ - def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: - self.device: SupportedDevices = device + def __init__(self, device: Devices = Devices.CPU, name: str = "") -> None: + self.device: Devices = device + self.name: str = name # Array creation ------------------------------------------------------ @@ -61,12 +63,12 @@ def eye(self, n: int) -> Array: """Create an ``n x n`` identity matrix.""" @abstractmethod - def device_to_native(self, device: SupportedDevices) -> Any: # noqa: ANN401 - """Convert :class:`SupportedDevices` to the backend's native device representation.""" + def device_to_native(self, device: Devices) -> Any: # noqa: ANN401 + """Convert :class:`Devices` to the backend's native device representation.""" @abstractmethod - def device_of(self, x: Array) -> SupportedDevices: - """Return the :class:`SupportedDevices` of the given array.""" + def device_of(self, x: Array) -> Devices: + """Return the :class:`Devices` of the given array.""" # Array manipulation -------------------------------------------------- @@ -75,7 +77,7 @@ def copy(self, x: Array) -> Array: """Return a copy of ``x``.""" @abstractmethod - def to_numpy(self, x: SupportedArrayTypes | Array) -> NDArray[Any]: + def to_numpy(self, x: ArrayTypes | Array) -> NDArray[Any]: """Convert ``x`` to a NumPy array on the CPU.""" @abstractmethod @@ -135,7 +137,7 @@ def diagonal(self, x: Array, offset: int = 0) -> Array: """Extract the diagonal entries from a 2-D matrix at the given ``offset``.""" @abstractmethod - def astype(self, x: Array, dtype: DTypes) -> Array: + def astype(self, x: Array, dtype: dtype) -> Array: """Cast ``x`` to a different dtype.""" # Linalg -------------------------------------------------------------- @@ -393,3 +395,157 @@ def uniform_like(self, x: Array, low: float = 0.0, high: float = 1.0) -> Array: @abstractmethod def choice(self, x: Array, size: int, replace: bool = True) -> Array: """Sample ``size`` elements from ``x``.""" + + # DTYPES -------------------------------------------------------------- + + @property + @abstractmethod + def bool_(self) -> Any: # noqa: ANN401 + """Bool dtype.""" + + @property + @abstractmethod + def uint8(self) -> Any: # noqa: ANN401 + """Unsigned 8-bit integer dtype.""" + + @property + def uint16(self) -> Any | None: # noqa: ANN401 + """Unsigned 16-bit integer dtype (optional).""" + return None + + @property + def uint32(self) -> Any | None: # noqa: ANN401 + """Unsigned 32-bit integer dtype (optional).""" + return None + + @property + def uint64(self) -> Any | None: # noqa: ANN401 + """Unsigned 64-bit integer dtype (optional).""" + return None + + @property + @abstractmethod + def int8(self) -> Any: # noqa: ANN401 + """Signed 8-bit integer dtype.""" + + @property + @abstractmethod + def int16(self) -> Any: # noqa: ANN401 + """Signed 16-bit integer dtype.""" + + @property + @abstractmethod + def int32(self) -> Any: # noqa: ANN401 + """Signed 32-bit integer dtype.""" + + @property + def int64(self) -> Any: # noqa: ANN401 + """Signed 64-bit integer dtype.""" + return None + + @property + @abstractmethod + def float16(self) -> Any: # noqa: ANN401 + """16-bit floating-point dtype.""" + + @property + @abstractmethod + def float32(self) -> Any: # noqa: ANN401 + """32-bit floating-point dtype.""" + + @property + def float64(self) -> Any: # noqa: ANN401 + """64-bit floating-point dtype.""" + return None + + @property + def complex64(self) -> Any: # noqa: ANN401 + """64-bit complex dtype.""" + return None + + @property + def complex128(self) -> Any: # noqa: ANN401 + """128-bit complex dtype.""" + return None + + @property + def float128(self) -> Any | None: # noqa: ANN401 + """128-bit floating-point dtype (optional).""" + return None + + @property + def complex256(self) -> Any | None: # noqa: ANN401 + """256-bit complex dtype (optional).""" + return None + + @property + def qint8(self) -> Any | None: # noqa: ANN401 + """Quantized 8-bit integer dtype (optional).""" + return None + + @property + def qint16(self) -> Any | None: # noqa: ANN401 + """Quantized 16-bit integer dtype (optional).""" + return None + + @property + def qint32(self) -> Any | None: # noqa: ANN401 + """Quantized 32-bit integer dtype (optional).""" + return None + + @property + def quint8(self) -> Any | None: # noqa: ANN401 + """Quantized unsigned 8-bit integer dtype (optional).""" + return None + + @property + def quint16(self) -> Any | None: # noqa: ANN401 + """Quantized unsigned 16-bit integer dtype (optional).""" + return None + + @property + def bfloat16(self) -> Any | None: # noqa: ANN401 + """Brain floating point 16-bit dtype (optional).""" + return None + + @property + def unicode_(self) -> Any | None: # noqa: ANN401 + """Unicode string dtype (optional).""" + return None + + @property + def bytes_(self) -> Any | None: # noqa: ANN401 + """Byte string dtype (optional).""" + return None + + @property + def object_(self) -> Any | None: # noqa: ANN401 + """Python object dtype (optional).""" + return None + + @property + def void(self) -> Any | None: # noqa: ANN401 + """Raw/void dtype (optional).""" + return None + + # CONSTANTS ----------------------------------------------------------- + + @property + @abstractmethod + def e(self) -> Any: # noqa: ANN401 + """e = 2.71828...""" # noqa: D403 + + @property + @abstractmethod + def inf(self) -> Any: # noqa: ANN401 + """Infinity.""" + + @property + @abstractmethod + def nan(self) -> Any: # noqa: ANN401 + """Not-a-number.""" + + @property + @abstractmethod + def pi(self) -> Any: # noqa: ANN401 + """pi = 3.14159...""" # noqa: D403 diff --git a/decent_array/interoperability/_backend_manager.py b/decent_array/interoperability/_backend_manager.py index a359a67..c897fea 100644 --- a/decent_array/interoperability/_backend_manager.py +++ b/decent_array/interoperability/_backend_manager.py @@ -4,13 +4,15 @@ from collections.abc import Callable from contextvars import ContextVar -from decent_array.types import SupportedDevices, SupportedFrameworks +import decent_array._constants as constants +import decent_array.types._dtypes as dtypes +from decent_array.types._types import Devices, Frameworks from ._abstracts import Backend -_BACKEND_REGISTRY: dict[SupportedFrameworks, type[Backend]] = {} -_BACKEND_INSTANCES: dict[SupportedFrameworks, Backend] = {} -_ACTIVE_BACKEND: ContextVar[SupportedFrameworks | None] = ContextVar( +_BACKEND_REGISTRY: dict[Frameworks, type[Backend]] = {} +_BACKEND_INSTANCES: dict[Frameworks, Backend] = {} +_ACTIVE_BACKEND: ContextVar[Frameworks | None] = ContextVar( "decent_array.interoperability.active_backend", default=None ) _BACKEND_LISTENERS: list[Callable[[Backend | None], None]] = [] @@ -18,8 +20,8 @@ def set_backend( - backend: SupportedFrameworks | str, - device: SupportedDevices | str = SupportedDevices.CPU, + backend: Frameworks | str, + device: Devices | str = Devices.CPU, ) -> None: """ Set the active backend (and target device) for the current execution context. @@ -33,9 +35,9 @@ def set_backend( Backend modules are auto-imported on demand. Args: - backend: A :class:`~decent_array.types.SupportedFrameworks` value, its canonical string (e.g. + backend: A :class:`~decent_array.types.Frameworks` value, its canonical string (e.g. ``"numpy"``, ``"pytorch"``). - device: Target accelerator. Accepts a :class:`~decent_array.types.SupportedDevices` value or its + device: Target accelerator. Accepts a :class:`~decent_array.types.Devices` value or its string equivalent (``"cpu"``, ``"gpu"``, ``"mps"``). Defaults to CPU. The backend's array-creation methods produce arrays on this device by default. @@ -44,9 +46,9 @@ def set_backend( is already active in this context. ImportError: If the backend module cannot be imported (e.g. due to a missing optional dependency). - """ # noqa: DOC502 + """ requested = _normalize(backend) - requested_device = device if isinstance(device, SupportedDevices) else SupportedDevices(device) + requested_device = device if isinstance(device, Devices) else Devices(device) current = _ACTIVE_BACKEND.get() if current is not None and current != requested: @@ -69,6 +71,9 @@ def set_backend( for listener in _BACKEND_LISTENERS: listener(_BACKEND_INSTANCE) + _bind_dtypes(_BACKEND_INSTANCE) + _bind_constants(_BACKEND_INSTANCE) + def register_backend_listener(listener: Callable[[Backend | None], None]) -> None: """ @@ -87,11 +92,11 @@ def register_backend_listener(listener: Callable[[Backend | None], None]) -> Non def register_backend( - backend: SupportedFrameworks, + backend: Frameworks, cls: type[Backend], ) -> None: """ - Register a backend class under a :class:`SupportedFrameworks` value. + Register a backend class under a :class:`Frameworks` value. Called once per backend module *after* the class definition. Backends are instantiated lazily on first use. Re-registering replaces the @@ -127,7 +132,7 @@ def reset_backends() -> None: listener(None) -def default_device() -> SupportedDevices: +def default_device() -> Devices: """ Return the default device for the active backend. @@ -143,24 +148,24 @@ def default_device() -> SupportedDevices: return backend.device -def _normalize(backend: SupportedFrameworks | str) -> SupportedFrameworks: +def _normalize(backend: Frameworks | str) -> Frameworks: """ - Convert a backend identifier to its canonical :class:`SupportedFrameworks` value. + Convert a backend identifier to its canonical :class:`Frameworks` value. Raises: KeyError: If the input is not a valid backend identifier. """ - if isinstance(backend, SupportedFrameworks): + if isinstance(backend, Frameworks): return backend try: - return SupportedFrameworks(backend) + return Frameworks(backend) except ValueError as exc: - valid = ", ".join(f.value for f in SupportedFrameworks) + valid = ", ".join(f.value for f in Frameworks) raise KeyError(f"Unknown backend '{backend}'. Valid backends: {valid}.") from exc -def _instantiate(backend: SupportedFrameworks, device: SupportedDevices) -> Backend: +def _instantiate(backend: Frameworks, device: Devices) -> Backend: """ Get or create a backend instance for the given backend and device. @@ -185,7 +190,7 @@ def _instantiate(backend: SupportedFrameworks, device: SupportedDevices) -> Back return instance -def _auto_import(backend: SupportedFrameworks) -> None: +def _auto_import(backend: Frameworks) -> None: """ Import the backend's package so its registration side-effect runs. @@ -202,3 +207,25 @@ def _auto_import(backend: SupportedFrameworks) -> None: f"Failed to import the backend module for '{backend.value}'. Ensure the " "corresponding backend package is installed and importable." ) from exc + + +def _bind_dtypes(backend: Backend | None) -> None: + """Bind dtype objects to the corresponding backend dtypes (if available).""" + if backend is None: + return + for name in dtypes._SUPPORTED: # noqa: SLF001 + dt = getattr(dtypes, name) + backend_dt = getattr(backend, name, None) + dt._available = backend_dt is not None # noqa: SLF001 + dt._backend_dtype = backend_dt # noqa: SLF001 + + +def _bind_constants(backend: Backend | None) -> None: + """Bind constants to the corresponding backend constants.""" + if backend is None: + return + for name in constants._CONSTANTS: # noqa: SLF001 + backend_c = getattr(backend, name, None) + if backend_c is None: + return + setattr(constants, name, backend_c) diff --git a/decent_array/interoperability/_iop/comparasion.py b/decent_array/interoperability/_iop/comparison.py similarity index 100% rename from decent_array/interoperability/_iop/comparasion.py rename to decent_array/interoperability/_iop/comparison.py diff --git a/decent_array/interoperability/_iop/manipulations.py b/decent_array/interoperability/_iop/manipulations.py index c2450e2..629a2f8 100644 --- a/decent_array/interoperability/_iop/manipulations.py +++ b/decent_array/interoperability/_iop/manipulations.py @@ -23,7 +23,8 @@ from decent_array import Array from decent_array.interoperability._abstracts import Backend - from decent_array.types import DTypes, SupportedArrayTypes + from decent_array.types import ArrayTypes + from decent_array.types._dtypes import dtype _BACKEND_INSTANCE: Backend | None = None _error = RuntimeError("No backend active: call 'set_backend' with a supported framework to activate one.") @@ -44,7 +45,7 @@ def copy(x: Array) -> Array: return _BACKEND_INSTANCE.copy(x) -def to_numpy(x: SupportedArrayTypes | Array) -> NDArray[Any]: +def to_numpy(x: ArrayTypes | Array) -> NDArray[Any]: """Convert ``x`` to a NumPy array on CPU.""" if _BACKEND_INSTANCE is None: raise _error @@ -156,7 +157,7 @@ def diagonal(x: Array, offset: int = 0) -> Array: return _BACKEND_INSTANCE.diagonal(x, offset) -def astype(x: Array, dtype: DTypes) -> Array: +def astype(x: Array, dtype: dtype) -> Array: """Cast ``x`` to a different dtype.""" if _BACKEND_INSTANCE is None: raise _error diff --git a/decent_array/interoperability/_iop/rng.py b/decent_array/interoperability/_iop/rng.py index 6968ddb..3df90ca 100644 --- a/decent_array/interoperability/_iop/rng.py +++ b/decent_array/interoperability/_iop/rng.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, cast from decent_array.interoperability._backend_manager import _instantiate, register_backend_listener -from decent_array.types import SupportedDevices, SupportedFrameworks +from decent_array.types import Devices, Frameworks if TYPE_CHECKING: import numpy @@ -118,7 +118,7 @@ def set_rng_state(self, state: dict[str, Any]) -> None: active.set_rng_state(state) def numpy_backend(self) -> NumpyBackend: - return cast("NumpyBackend", _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU)) + return cast("NumpyBackend", _instantiate(Frameworks.NUMPY, Devices.CPU)) _COORDINATOR = _RngCoordinator() diff --git a/decent_array/interoperability/_iop/utils.py b/decent_array/interoperability/_iop/utils.py index e1b737a..ab69d22 100644 --- a/decent_array/interoperability/_iop/utils.py +++ b/decent_array/interoperability/_iop/utils.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from decent_array import Array from decent_array.interoperability._abstracts import Backend - from decent_array.types import ArrayKey, SupportedDevices + from decent_array.types import ArrayKey, Devices _BACKEND_INSTANCE: Backend | None = None _error = RuntimeError("No backend active: call 'set_backend' with a supported framework to activate one.") @@ -34,15 +34,15 @@ def _update_backend(backend: Backend | None) -> None: register_backend_listener(_update_backend) -def device_to_native(device: SupportedDevices) -> Any: # noqa: ANN401 - """Convert :class:`~decent_array.types.SupportedDevices` to the active backend's native device.""" +def device_to_native(device: Devices) -> Any: # noqa: ANN401 + """Convert :class:`~decent_array.types.Devices` to the active backend's native device.""" if _BACKEND_INSTANCE is None: raise _error return _BACKEND_INSTANCE.device_to_native(device) -def device_of(x: Array) -> SupportedDevices: - """Return the :class:`~decent_array.types.SupportedDevices` of ``x``.""" +def device_of(x: Array) -> Devices: + """Return the :class:`~decent_array.types.Devices` of ``x``.""" if _BACKEND_INSTANCE is None: raise _error return _BACKEND_INSTANCE.device_of(x) diff --git a/decent_array/interoperability/_jax/jax_backend.py b/decent_array/interoperability/_jax/jax_backend.py index 7173b92..744979c 100644 --- a/decent_array/interoperability/_jax/jax_backend.py +++ b/decent_array/interoperability/_jax/jax_backend.py @@ -10,7 +10,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence from time import time_ns from typing import Any, cast @@ -20,38 +20,18 @@ from numpy.typing import NDArray from decent_array import Array +from decent_array._utils import unwrap from decent_array.interoperability._abstracts import Backend from decent_array.interoperability._backend_manager import register_backend -from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices, SupportedFrameworks +from decent_array.types import ArrayKey, ArrayTypes, Devices, Frameworks +from decent_array.types._dtypes import _ALL_DTYPES, dtype -def _unwrap(x: Any) -> Any: # noqa: ANN401 - """Return the underlying value of an :class:`Array`, or pass ``x`` through.""" - return x.value if type(x) is Array else x - - -_DTYPE_MAP = { - DTypes.BOOL: jnp.bool_, - DTypes.UINT8: jnp.uint8, - DTypes.UINT16: jnp.uint16, - DTypes.UINT32: jnp.uint32, - DTypes.UINT64: jnp.uint64, - DTypes.INT8: jnp.int8, - DTypes.INT16: jnp.int16, - DTypes.INT32: jnp.int32, - DTypes.INT64: jnp.int64, - DTypes.FLOAT32: jnp.float32, - DTypes.FLOAT64: jnp.float64, - DTypes.COMPLEX64: jnp.complex64, - DTypes.COMPLEX128: jnp.complex128, -} - - -class JaxBackend(Backend): # noqa: PLR0904 +class JaxBackend(Backend): """JAX implementation of :class:`Backend`.""" - def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: - super().__init__(device) + def __init__(self, device: Devices = Devices.CPU) -> None: + super().__init__(device, name=Frameworks.JAX.value) self._native_device: jax.Device = self.device_to_native(device) self._key: jax.Array = jax.random.key(time_ns()) @@ -72,19 +52,19 @@ def ones_like(self, x: Array) -> Array: def eye(self, n: int) -> Array: return Array(jnp.eye(n, device=self._native_device)) - def device_to_native(self, device: SupportedDevices) -> jax.Device: - if device == SupportedDevices.CPU: + def device_to_native(self, device: Devices) -> jax.Device: + if device == Devices.CPU: return jax.devices("cpu")[0] - if device == SupportedDevices.GPU: + if device == Devices.GPU: return jax.devices("gpu")[0] raise ValueError(f"Unsupported device for JAX: {device}") - def device_of(self, x: Array) -> SupportedDevices: + def device_of(self, x: Array) -> Devices: platform = x.value.device.platform if platform == "gpu": - return SupportedDevices.GPU + return Devices.GPU if platform == "cpu": - return SupportedDevices.CPU + return Devices.CPU raise TypeError(f"Unsupported JAX platform: {platform}") # Array manipulation @@ -92,7 +72,7 @@ def device_of(self, x: Array) -> SupportedDevices: def copy(self, x: Array) -> Array: return Array(jnp.array(x.value, copy=True)) - def to_numpy(self, x: SupportedArrayTypes | Array) -> NDArray[Any]: + def to_numpy(self, x: ArrayTypes | Array) -> NDArray[Any]: return np.array(x.value if type(x) is Array else x) def from_numpy(self, x: NDArray[Any]) -> Array: @@ -147,10 +127,10 @@ def diagonal(self, x: Array, offset: int = 0) -> Array: raise ValueError(f"diagonal requires a 2-D array, got {x.value.ndim}-D") return Array(jnp.diagonal(x.value, offset=offset)) - def astype(self, x: Array, dtype: DTypes) -> Array: - if dtype not in _DTYPE_MAP: - raise ValueError(f"Unsupported dtype '{dtype.value}' for NumPy backend.") - return Array(jnp.asarray(x.value, dtype=_DTYPE_MAP[dtype])) + def astype(self, x: Array, dtype: dtype) -> Array: + if dtype not in _ALL_DTYPES.values(): + raise ValueError(f"Unsupported dtype '{dtype}' for JAX backend.") + return Array(jnp.asarray(x.value, dtype=dtype.backend_dtype)) # Linalg @@ -198,52 +178,52 @@ def all(self, x: Array, axis: int | tuple[int, ...] | None = None, keepdims: boo # covers both because PEP 484's numeric tower implicitly admits ``int``. def add(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.add(_unwrap(x1), _unwrap(x2))) + return Array(jnp.add(unwrap(x1), unwrap(x2))) def iadd[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = jnp.add(x1.value, _unwrap(x2)) + x1.value = jnp.add(x1.value, unwrap(x2)) return x1 def subtract(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.subtract(_unwrap(x1), _unwrap(x2))) + return Array(jnp.subtract(unwrap(x1), unwrap(x2))) def isubtract[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = jnp.subtract(x1.value, _unwrap(x2)) + x1.value = jnp.subtract(x1.value, unwrap(x2)) return x1 def multiply(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.multiply(_unwrap(x1), _unwrap(x2))) + return Array(jnp.multiply(unwrap(x1), unwrap(x2))) def imultiply[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = jnp.multiply(x1.value, _unwrap(x2)) + x1.value = jnp.multiply(x1.value, unwrap(x2)) return x1 def divide(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.divide(_unwrap(x1), _unwrap(x2))) + return Array(jnp.divide(unwrap(x1), unwrap(x2))) def idivide[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = jnp.divide(x1.value, _unwrap(x2)) + x1.value = jnp.divide(x1.value, unwrap(x2)) return x1 def floor_divide(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(jnp.floor_divide(_unwrap(x1), _unwrap(x2))) + return Array(jnp.floor_divide(unwrap(x1), unwrap(x2))) def ifloordiv[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value = jnp.floor_divide(x1.value, _unwrap(x2)) + x1.value = jnp.floor_divide(x1.value, unwrap(x2)) return x1 def remainder(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(jnp.remainder(_unwrap(x1), _unwrap(x2))) + return Array(jnp.remainder(unwrap(x1), unwrap(x2))) def imod[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value = jnp.remainder(x1.value, _unwrap(x2)) + x1.value = jnp.remainder(x1.value, unwrap(x2)) return x1 def pow(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.power(_unwrap(x1), _unwrap(x2))) + return Array(jnp.power(unwrap(x1), unwrap(x2))) def ipow[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = jnp.power(x1.value, _unwrap(x2)) + x1.value = jnp.power(x1.value, unwrap(x2)) return x1 def negative(self, x: Array) -> Array: @@ -258,61 +238,61 @@ def sqrt(self, x: Array) -> Array: # Comparisons def equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.equal(_unwrap(x1), _unwrap(x2))) + return Array(jnp.equal(unwrap(x1), unwrap(x2))) def not_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.not_equal(_unwrap(x1), _unwrap(x2))) + return Array(jnp.not_equal(unwrap(x1), unwrap(x2))) def less(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.less(_unwrap(x1), _unwrap(x2))) + return Array(jnp.less(unwrap(x1), unwrap(x2))) def less_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.less_equal(_unwrap(x1), _unwrap(x2))) + return Array(jnp.less_equal(unwrap(x1), unwrap(x2))) def greater(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.greater(_unwrap(x1), _unwrap(x2))) + return Array(jnp.greater(unwrap(x1), unwrap(x2))) def greater_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.greater_equal(_unwrap(x1), _unwrap(x2))) + return Array(jnp.greater_equal(unwrap(x1), unwrap(x2))) # Bitwise def bitwise_and(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(jnp.bitwise_and(_unwrap(x1), _unwrap(x2))) + return Array(jnp.bitwise_and(unwrap(x1), unwrap(x2))) def iand[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value = jnp.bitwise_and(x1.value, _unwrap(x2)) + x1.value = jnp.bitwise_and(x1.value, unwrap(x2)) return x1 def bitwise_invert(self, x: Array) -> Array: return Array(jnp.bitwise_not(x.value)) def bitwise_or(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(jnp.bitwise_or(_unwrap(x1), _unwrap(x2))) + return Array(jnp.bitwise_or(unwrap(x1), unwrap(x2))) def ior[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value = jnp.bitwise_or(x1.value, _unwrap(x2)) + x1.value = jnp.bitwise_or(x1.value, unwrap(x2)) return x1 def bitwise_xor(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(jnp.bitwise_xor(_unwrap(x1), _unwrap(x2))) + return Array(jnp.bitwise_xor(unwrap(x1), unwrap(x2))) def ixor[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value = jnp.bitwise_xor(x1.value, _unwrap(x2)) + x1.value = jnp.bitwise_xor(x1.value, unwrap(x2)) return x1 def bitwise_left_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(jnp.left_shift(_unwrap(x1), _unwrap(x2))) + return Array(jnp.left_shift(unwrap(x1), unwrap(x2))) def ilshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value = jnp.left_shift(x1.value, _unwrap(x2)) + x1.value = jnp.left_shift(x1.value, unwrap(x2)) return x1 def bitwise_right_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(jnp.right_shift(_unwrap(x1), _unwrap(x2))) + return Array(jnp.right_shift(unwrap(x1), unwrap(x2))) def irshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value = jnp.right_shift(x1.value, _unwrap(x2)) + x1.value = jnp.right_shift(x1.value, unwrap(x2)) return x1 # Operators @@ -321,7 +301,7 @@ def sign(self, x: Array) -> Array: return Array(jnp.sign(x.value)) def maximum(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(jnp.maximum(_unwrap(x1), _unwrap(x2))) + return Array(jnp.maximum(unwrap(x1), unwrap(x2))) def argmax(self, x: Array, axis: int | None = None, keepdims: bool = False) -> Array: return Array(jnp.argmax(x.value, axis=axis, keepdims=keepdims)) @@ -331,7 +311,7 @@ def argmin(self, x: Array, axis: int | None = None, keepdims: bool = False) -> A def set_item(self, x: Array, key: ArrayKey, value: bool | int | float | complex | Array) -> None: # JAX arrays are immutable; rebind the wrapper to a new array with `key` updated. - x.value = x.value.at[key].set(_unwrap(value)) + x.value = x.value.at[key].set(unwrap(value)) def get_item(self, x: Array, key: ArrayKey) -> Array: return Array(x.value[key]) @@ -381,5 +361,104 @@ def _next_key(self) -> jax.Array: self._key, sub = jax.random.split(self._key) return cast("jax.Array", sub) + # Dtypes + + @property + def bool_(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.bool_) + + @property + def uint8(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.uint8) + + @property + def uint16(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.uint16) + + @property + def uint32(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.uint32) + + @property + def uint64(self) -> Any: # noqa: ANN401 + if _x64_enabled(): + return np.dtype(jnp.uint64) + return None + + @property + def int8(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.int8) + + @property + def int16(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.int16) + + @property + def int32(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.int32) + + @property + def int64(self) -> Any: # noqa: ANN401 + if _x64_enabled(): + return np.dtype(jnp.int64) + return None + + @property + def float16(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.float16) + + @property + def float32(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.float32) + + @property + def float64(self) -> Any: # noqa: ANN401 + if _x64_enabled(): + return np.dtype(jnp.float64) + return None + + @property + def complex64(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.complex64) + + @property + def complex128(self) -> Any: # noqa: ANN401 + if _x64_enabled(): + return np.dtype(jnp.complex128) + return None + + @property + def bfloat16(self) -> Any: # noqa: ANN401 + return np.dtype(jnp.bfloat16) + + # Constants + + @property + def e(self) -> Any: # noqa: ANN401 + """e = 2.71828...""" # noqa: D403 + return jnp.e + + @property + def inf(self) -> Any: # noqa: ANN401 + """Infinity.""" + return jnp.inf + + @property + def nan(self) -> Any: # noqa: ANN401 + """Not-a-number.""" + return jnp.nan + + @property + def pi(self) -> Any: # noqa: ANN401 + """pi = 3.14159...""" # noqa: D403 + return jnp.pi + + +_JAX_CONFIG_READ = cast("Callable[[str], Any]", jax.config.read) + + +def _x64_enabled() -> bool: + return bool(_JAX_CONFIG_READ("jax_enable_x64")) + -register_backend(SupportedFrameworks.JAX, JaxBackend) +register_backend(Frameworks.JAX, JaxBackend) diff --git a/decent_array/interoperability/_numpy/numpy_backend.py b/decent_array/interoperability/_numpy/numpy_backend.py index 1c5f149..3fe66d0 100644 --- a/decent_array/interoperability/_numpy/numpy_backend.py +++ b/decent_array/interoperability/_numpy/numpy_backend.py @@ -15,46 +15,20 @@ from numpy.typing import NDArray from decent_array import Array +from decent_array._utils import unwrap from decent_array.interoperability._abstracts import Backend from decent_array.interoperability._backend_manager import register_backend -from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices, SupportedFrameworks - - -def _unwrap(x: Any) -> Any: # noqa: ANN401 - """ - Return the underlying value of an :class:`Array`, or pass ``x`` through. - - Typed as ``Any`` because operator dunders may pass either an :class:`Array` or a - Python scalar; the strict abstract signature would force a ``cast`` at every call - site without runtime benefit. - """ - return x.value if type(x) is Array else x - - -_DTYPE_MAP = { - DTypes.BOOL: np.bool_, - DTypes.UINT8: np.uint8, - DTypes.UINT16: np.uint16, - DTypes.UINT32: np.uint32, - DTypes.UINT64: np.uint64, - DTypes.INT8: np.int8, - DTypes.INT16: np.int16, - DTypes.INT32: np.int32, - DTypes.INT64: np.int64, - DTypes.FLOAT32: np.float32, - DTypes.FLOAT64: np.float64, - DTypes.COMPLEX64: np.complex64, - DTypes.COMPLEX128: np.complex128, -} - - -class NumpyBackend(Backend): # noqa: PLR0904 +from decent_array.types import ArrayKey, ArrayTypes, Devices, Frameworks +from decent_array.types._dtypes import _ALL_DTYPES, dtype + + +class NumpyBackend(Backend): """NumPy implementation of :class:`Backend`.""" - def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: - if device != SupportedDevices.CPU: + def __init__(self, device: Devices = Devices.CPU) -> None: + if device != Devices.CPU: raise ValueError(f"NumPy backend only supports CPU, got '{device.value}'.") - super().__init__(device) + super().__init__(device, name=Frameworks.NUMPY.value) self._rng: np.random.Generator = np.random.default_rng() # Array creation @@ -74,12 +48,12 @@ def ones_like(self, x: Array) -> Array: def eye(self, n: int) -> Array: return Array(np.eye(n)) - def device_to_native(self, device: SupportedDevices) -> Any: # noqa: ANN401 + def device_to_native(self, device: Devices) -> Any: # noqa: ANN401 # NumPy has no explicit device management; surface the request unchanged. return device - def device_of(self, x: Array) -> SupportedDevices: # noqa: ARG002 - return SupportedDevices.CPU + def device_of(self, x: Array) -> Devices: # noqa: ARG002 + return Devices.CPU # Array manipulation @@ -89,7 +63,7 @@ def copy(self, x: Array) -> Array: return Array(np.copy(v)) return Array(deepcopy(v)) - def to_numpy(self, x: SupportedArrayTypes | Array) -> NDArray[Any]: + def to_numpy(self, x: ArrayTypes | Array) -> NDArray[Any]: """Return the value of an :class:`Array` as a NumPy array.""" v = x.value if type(x) is Array else x if isinstance(v, np.ndarray): @@ -148,10 +122,10 @@ def diagonal(self, x: Array, offset: int = 0) -> Array: raise ValueError(f"diagonal requires a 2-D array, got {x.value.ndim}-D") return Array(np.diagonal(x.value, offset=offset)) - def astype(self, x: Array, dtype: DTypes) -> Array: - if dtype not in _DTYPE_MAP: - raise ValueError(f"Unsupported dtype '{dtype.value}' for NumPy backend.") - return Array(np.asarray(x.value, dtype=_DTYPE_MAP[dtype])) + def astype(self, x: Array, dtype: dtype) -> Array: + if dtype not in _ALL_DTYPES.values(): + raise ValueError(f"Unsupported dtype '{dtype}' for NumPy backend.") + return Array(np.asarray(x.value, dtype=dtype.backend_dtype)) # Linalg @@ -198,52 +172,52 @@ def all(self, x: Array, axis: int | tuple[int, ...] | None = None, keepdims: boo # ``Array | float`` covers both: PEP 484's numeric tower implicitly admits ``int``. def add(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.add(_unwrap(x1), _unwrap(x2))) + return Array(np.add(unwrap(x1), unwrap(x2))) def iadd[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value += _unwrap(x2) + x1.value += unwrap(x2) return x1 def subtract(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.subtract(_unwrap(x1), _unwrap(x2))) + return Array(np.subtract(unwrap(x1), unwrap(x2))) def isubtract[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value -= _unwrap(x2) + x1.value -= unwrap(x2) return x1 def multiply(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.multiply(_unwrap(x1), _unwrap(x2))) + return Array(np.multiply(unwrap(x1), unwrap(x2))) def imultiply[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value *= _unwrap(x2) + x1.value *= unwrap(x2) return x1 def divide(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.divide(_unwrap(x1), _unwrap(x2))) + return Array(np.divide(unwrap(x1), unwrap(x2))) def idivide[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value /= _unwrap(x2) + x1.value /= unwrap(x2) return x1 def floor_divide(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(np.floor_divide(_unwrap(x1), _unwrap(x2))) + return Array(np.floor_divide(unwrap(x1), unwrap(x2))) def ifloordiv[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value //= _unwrap(x2) + x1.value //= unwrap(x2) return x1 def remainder(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(np.remainder(_unwrap(x1), _unwrap(x2))) + return Array(np.remainder(unwrap(x1), unwrap(x2))) def imod[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value %= _unwrap(x2) + x1.value %= unwrap(x2) return x1 def pow(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.power(_unwrap(x1), _unwrap(x2))) + return Array(np.power(unwrap(x1), unwrap(x2))) def ipow[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value **= _unwrap(x2) + x1.value **= unwrap(x2) return x1 def negative(self, x: Array) -> Array: @@ -258,61 +232,61 @@ def sqrt(self, x: Array) -> Array: # Comparisons def equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.equal(_unwrap(x1), _unwrap(x2))) + return Array(np.equal(unwrap(x1), unwrap(x2))) def not_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.not_equal(_unwrap(x1), _unwrap(x2))) + return Array(np.not_equal(unwrap(x1), unwrap(x2))) def less(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.less(_unwrap(x1), _unwrap(x2))) + return Array(np.less(unwrap(x1), unwrap(x2))) def less_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.less_equal(_unwrap(x1), _unwrap(x2))) + return Array(np.less_equal(unwrap(x1), unwrap(x2))) def greater(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.greater(_unwrap(x1), _unwrap(x2))) + return Array(np.greater(unwrap(x1), unwrap(x2))) def greater_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.greater_equal(_unwrap(x1), _unwrap(x2))) + return Array(np.greater_equal(unwrap(x1), unwrap(x2))) # Bitwise def bitwise_and(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(np.bitwise_and(_unwrap(x1), _unwrap(x2))) + return Array(np.bitwise_and(unwrap(x1), unwrap(x2))) def iand[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value &= _unwrap(x2) + x1.value &= unwrap(x2) return x1 def bitwise_invert(self, x: Array) -> Array: return Array(np.bitwise_not(x.value)) def bitwise_or(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(np.bitwise_or(_unwrap(x1), _unwrap(x2))) + return Array(np.bitwise_or(unwrap(x1), unwrap(x2))) def ior[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value |= _unwrap(x2) + x1.value |= unwrap(x2) return x1 def bitwise_xor(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(np.bitwise_xor(_unwrap(x1), _unwrap(x2))) + return Array(np.bitwise_xor(unwrap(x1), unwrap(x2))) def ixor[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value ^= _unwrap(x2) + x1.value ^= unwrap(x2) return x1 def bitwise_left_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(np.left_shift(_unwrap(x1), _unwrap(x2))) + return Array(np.left_shift(unwrap(x1), unwrap(x2))) def ilshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value <<= _unwrap(x2) + x1.value <<= unwrap(x2) return x1 def bitwise_right_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(np.right_shift(_unwrap(x1), _unwrap(x2))) + return Array(np.right_shift(unwrap(x1), unwrap(x2))) def irshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value >>= _unwrap(x2) + x1.value >>= unwrap(x2) return x1 # Operators @@ -321,7 +295,7 @@ def sign(self, x: Array) -> Array: return Array(np.sign(x.value)) def maximum(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(np.maximum(_unwrap(x1), _unwrap(x2))) + return Array(np.maximum(unwrap(x1), unwrap(x2))) def argmax(self, x: Array, axis: int | None = None, keepdims: bool = False) -> Array: return Array(np.argmax(x.value, axis=axis, keepdims=keepdims)) @@ -330,7 +304,7 @@ def argmin(self, x: Array, axis: int | None = None, keepdims: bool = False) -> A return Array(np.argmin(x.value, axis=axis, keepdims=keepdims)) def set_item(self, x: Array, key: ArrayKey, value: bool | int | float | complex | Array) -> None: - x.value[key] = _unwrap(value) + x.value[key] = unwrap(value) def get_item(self, x: Array, key: ArrayKey) -> Array: return Array(x.value[key]) @@ -375,5 +349,111 @@ def uniform_like(self, x: Array, low: float = 0.0, high: float = 1.0) -> Array: def choice(self, x: Array, size: int, replace: bool = True) -> Array: return Array(self._rng.choice(x.value, size=size, replace=replace)) + # Dtypes + + @property + def bool_(self) -> Any: # noqa: ANN401 + return np.dtype(np.bool_) + + @property + def uint8(self) -> Any: # noqa: ANN401 + return np.dtype(np.uint8) + + @property + def uint16(self) -> Any: # noqa: ANN401 + return np.dtype(np.uint16) + + @property + def uint32(self) -> Any: # noqa: ANN401 + return np.dtype(np.uint32) + + @property + def uint64(self) -> Any: # noqa: ANN401 + return np.dtype(np.uint64) + + @property + def int8(self) -> Any: # noqa: ANN401 + return np.dtype(np.int8) + + @property + def int16(self) -> Any: # noqa: ANN401 + return np.dtype(np.int16) + + @property + def int32(self) -> Any: # noqa: ANN401 + return np.dtype(np.int32) + + @property + def int64(self) -> Any: # noqa: ANN401 + return np.dtype(np.int64) + + @property + def float16(self) -> Any: # noqa: ANN401 + return np.dtype(np.float16) + + @property + def float32(self) -> Any: # noqa: ANN401 + return np.dtype(np.float32) + + @property + def float64(self) -> Any: # noqa: ANN401 + return np.dtype(np.float64) + + @property + def complex64(self) -> Any: # noqa: ANN401 + return np.dtype(np.complex64) + + @property + def complex128(self) -> Any: # noqa: ANN401 + return np.dtype(np.complex128) + + @property + def float128(self) -> Any | None: # noqa: ANN401 + float128 = getattr(np, "float128", None) + return np.dtype(float128) if float128 is not None else None + + @property + def complex256(self) -> Any | None: # noqa: ANN401 + complex256 = getattr(np, "complex256", None) + return np.dtype(complex256) if complex256 is not None else None + + @property + def unicode_(self) -> Any: # noqa: ANN401 + return np.dtype(np.str_) + + @property + def bytes_(self) -> Any: # noqa: ANN401 + return np.dtype(np.bytes_) + + @property + def object_(self) -> Any: # noqa: ANN401 + return np.dtype(np.object_) + + @property + def void(self) -> Any: # noqa: ANN401 + return np.dtype(np.void) + + # Constants + + @property + def e(self) -> Any: # noqa: ANN401 + """e = 2.71828...""" # noqa: D403 + return np.e + + @property + def inf(self) -> Any: # noqa: ANN401 + """Infinity.""" + return np.inf + + @property + def nan(self) -> Any: # noqa: ANN401 + """Not-a-number.""" + return np.nan + + @property + def pi(self) -> Any: # noqa: ANN401 + """pi = 3.14159...""" # noqa: D403 + return np.pi + -register_backend(SupportedFrameworks.NUMPY, NumpyBackend) +register_backend(Frameworks.NUMPY, NumpyBackend) diff --git a/decent_array/interoperability/_pytorch/pytorch_backend.py b/decent_array/interoperability/_pytorch/pytorch_backend.py index a720d2d..36e1789 100644 --- a/decent_array/interoperability/_pytorch/pytorch_backend.py +++ b/decent_array/interoperability/_pytorch/pytorch_backend.py @@ -15,38 +15,18 @@ from numpy.typing import NDArray from decent_array import Array +from decent_array._utils import unwrap from decent_array.interoperability._abstracts import Backend from decent_array.interoperability._backend_manager import register_backend -from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices, SupportedFrameworks +from decent_array.types import ArrayKey, ArrayTypes, Devices, Frameworks +from decent_array.types._dtypes import _ALL_DTYPES, dtype -def _unwrap(x: Any) -> Any: # noqa: ANN401 - """Return the underlying value of an :class:`Array`, or pass ``x`` through.""" - return x.value if type(x) is Array else x - - -_DTYPE_MAP = { - DTypes.BOOL: torch.bool, - DTypes.UINT8: torch.uint8, - DTypes.UINT16: torch.uint16, - DTypes.UINT32: torch.uint32, - DTypes.UINT64: torch.uint64, - DTypes.INT8: torch.int8, - DTypes.INT16: torch.int16, - DTypes.INT32: torch.int32, - DTypes.INT64: torch.int64, - DTypes.FLOAT32: torch.float32, - DTypes.FLOAT64: torch.float64, - DTypes.COMPLEX64: torch.complex64, - DTypes.COMPLEX128: torch.complex128, -} - - -class PyTorchBackend(Backend): # noqa: PLR0904 +class PyTorchBackend(Backend): """PyTorch implementation of :class:`Backend`.""" - def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: - super().__init__(device) + def __init__(self, device: Devices = Devices.CPU) -> None: + super().__init__(device, name=Frameworks.PYTORCH.value) self._native_device: str = self.device_to_native(device) self._generator: torch.Generator = torch.Generator(device=self._native_device) @@ -67,23 +47,23 @@ def ones_like(self, x: Array) -> Array: def eye(self, n: int) -> Array: return Array(torch.eye(n, device=self._native_device)) - def device_to_native(self, device: SupportedDevices) -> str: - if device == SupportedDevices.CPU: + def device_to_native(self, device: Devices) -> str: + if device == Devices.CPU: return "cpu" - if device == SupportedDevices.GPU: + if device == Devices.GPU: return "cuda" - if device == SupportedDevices.MPS: + if device == Devices.MPS: return "mps" raise ValueError(f"Unsupported device: {device}") - def device_of(self, x: Array) -> SupportedDevices: + def device_of(self, x: Array) -> Devices: kind = x.value.device.type if kind == "cpu": - return SupportedDevices.CPU + return Devices.CPU if kind == "cuda": - return SupportedDevices.GPU + return Devices.GPU if kind == "mps": - return SupportedDevices.MPS + return Devices.MPS raise TypeError(f"Unsupported PyTorch device type: {kind}") # Array manipulation @@ -91,7 +71,7 @@ def device_of(self, x: Array) -> SupportedDevices: def copy(self, x: Array) -> Array: return Array(x.value.detach().clone()) - def to_numpy(self, x: SupportedArrayTypes | Array) -> NDArray[Any]: + def to_numpy(self, x: ArrayTypes | Array) -> NDArray[Any]: """Return the value of an :class:`Array` as a NumPy array.""" v = x.value if type(x) is Array else x if isinstance(v, torch.Tensor): @@ -157,10 +137,10 @@ def diagonal(self, x: Array, offset: int = 0) -> Array: raise ValueError(f"diagonal requires a 2-D array, got {x.value.ndim}-D") return Array(torch.diagonal(x.value, offset=offset)) - def astype(self, x: Array, dtype: DTypes) -> Array: - if dtype not in _DTYPE_MAP: - raise ValueError(f"Unsupported dtype '{dtype.value}' for PyTorch backend.") - return Array(x.value.to(dtype=_DTYPE_MAP[dtype])) + def astype(self, x: Array, dtype: dtype) -> Array: + if dtype not in _ALL_DTYPES.values(): + raise ValueError(f"Unsupported dtype '{dtype}' for PyTorch backend.") + return Array(x.value.to(dtype=dtype.backend_dtype)) # Linalg @@ -219,52 +199,52 @@ def all(self, x: Array, axis: int | tuple[int, ...] | None = None, keepdims: boo # ``Array | float`` covers both: PEP 484's numeric tower implicitly admits ``int``. def add(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.add(_unwrap(x1), _unwrap(x2))) + return Array(torch.add(unwrap(x1), unwrap(x2))) def iadd[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value.add_(_unwrap(x2)) + x1.value.add_(unwrap(x2)) return x1 def subtract(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.sub(_unwrap(x1), _unwrap(x2))) + return Array(torch.sub(unwrap(x1), unwrap(x2))) def isubtract[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value.sub_(_unwrap(x2)) + x1.value.sub_(unwrap(x2)) return x1 def multiply(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.mul(_unwrap(x1), _unwrap(x2))) + return Array(torch.mul(unwrap(x1), unwrap(x2))) def imultiply[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value.mul_(_unwrap(x2)) + x1.value.mul_(unwrap(x2)) return x1 def divide(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.div(_unwrap(x1), _unwrap(x2))) + return Array(torch.div(unwrap(x1), unwrap(x2))) def idivide[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value.div_(_unwrap(x2)) + x1.value.div_(unwrap(x2)) return x1 def floor_divide(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(torch.floor_divide(_unwrap(x1), _unwrap(x2))) + return Array(torch.floor_divide(unwrap(x1), unwrap(x2))) def ifloordiv[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value.floor_divide_(_unwrap(x2)) + x1.value.floor_divide_(unwrap(x2)) return x1 def remainder(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(torch.remainder(_unwrap(x1), _unwrap(x2))) + return Array(torch.remainder(unwrap(x1), unwrap(x2))) def imod[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value.remainder_(_unwrap(x2)) + x1.value.remainder_(unwrap(x2)) return x1 def pow(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.pow(_unwrap(x1), _unwrap(x2))) + return Array(torch.pow(unwrap(x1), unwrap(x2))) def ipow[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value.pow_(_unwrap(x2)) + x1.value.pow_(unwrap(x2)) return x1 def negative(self, x: Array) -> Array: @@ -279,61 +259,61 @@ def sqrt(self, x: Array) -> Array: # Comparisons def equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.eq(_unwrap(x1), _unwrap(x2))) + return Array(torch.eq(unwrap(x1), unwrap(x2))) def not_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.ne(_unwrap(x1), _unwrap(x2))) + return Array(torch.ne(unwrap(x1), unwrap(x2))) def less(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.lt(_unwrap(x1), _unwrap(x2))) + return Array(torch.lt(unwrap(x1), unwrap(x2))) def less_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.le(_unwrap(x1), _unwrap(x2))) + return Array(torch.le(unwrap(x1), unwrap(x2))) def greater(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.gt(_unwrap(x1), _unwrap(x2))) + return Array(torch.gt(unwrap(x1), unwrap(x2))) def greater_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(torch.ge(_unwrap(x1), _unwrap(x2))) + return Array(torch.ge(unwrap(x1), unwrap(x2))) # Bitwise def bitwise_and(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(torch.bitwise_and(_unwrap(x1), _unwrap(x2))) + return Array(torch.bitwise_and(unwrap(x1), unwrap(x2))) def iand[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value.bitwise_and_(_unwrap(x2)) + x1.value.bitwise_and_(unwrap(x2)) return x1 def bitwise_invert(self, x: Array) -> Array: return Array(torch.bitwise_not(x.value)) def bitwise_or(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(torch.bitwise_or(_unwrap(x1), _unwrap(x2))) + return Array(torch.bitwise_or(unwrap(x1), unwrap(x2))) def ior[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value.bitwise_or_(_unwrap(x2)) + x1.value.bitwise_or_(unwrap(x2)) return x1 def bitwise_xor(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(torch.bitwise_xor(_unwrap(x1), _unwrap(x2))) + return Array(torch.bitwise_xor(unwrap(x1), unwrap(x2))) def ixor[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value.bitwise_xor_(_unwrap(x2)) + x1.value.bitwise_xor_(unwrap(x2)) return x1 def bitwise_left_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(torch.bitwise_left_shift(_unwrap(x1), _unwrap(x2))) + return Array(torch.bitwise_left_shift(unwrap(x1), unwrap(x2))) def ilshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value.bitwise_left_shift_(_unwrap(x2)) + x1.value.bitwise_left_shift_(unwrap(x2)) return x1 def bitwise_right_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(torch.bitwise_right_shift(_unwrap(x1), _unwrap(x2))) + return Array(torch.bitwise_right_shift(unwrap(x1), unwrap(x2))) def irshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value.bitwise_right_shift_(_unwrap(x2)) + x1.value.bitwise_right_shift_(unwrap(x2)) return x1 # Operators @@ -342,7 +322,7 @@ def sign(self, x: Array) -> Array: return Array(torch.sign(x.value)) def maximum(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - a, b = _unwrap(x1), _unwrap(x2) + a, b = unwrap(x1), unwrap(x2) # torch.maximum requires both operands to be Tensors; lift Python scalars to # match the dtype/device of the tensor operand so the contract matches numpy. if not isinstance(a, torch.Tensor): @@ -359,7 +339,7 @@ def argmin(self, x: Array, axis: int | None = None, keepdims: bool = False) -> A return Array(torch.argmin(x.value, dim=axis, keepdim=keepdims)) def set_item(self, x: Array, key: ArrayKey, value: bool | int | float | complex | Array) -> None: - x.value[key] = _unwrap(value) + x.value[key] = unwrap(value) def get_item(self, x: Array, key: ArrayKey) -> Array: return Array(x.value[key]) @@ -422,5 +402,89 @@ def choice(self, x: Array, size: int, replace: bool = True) -> Array: indices = weights.multinomial(num_samples=size, replacement=replace, generator=self._generator) return Array(v[indices]) + # Dtypes + + @property + def bool_(self) -> torch.dtype: + return torch.bool + + @property + def uint8(self) -> torch.dtype: + return torch.uint8 + + @property + def int8(self) -> torch.dtype: + return torch.int8 + + @property + def int16(self) -> torch.dtype: + return torch.int16 + + @property + def int32(self) -> torch.dtype: + return torch.int32 + + @property + def int64(self) -> torch.dtype: + return torch.int64 + + @property + def float16(self) -> torch.dtype: + return torch.float16 + + @property + def float32(self) -> torch.dtype: + return torch.float32 + + @property + def float64(self) -> torch.dtype: + return torch.float64 + + @property + def complex64(self) -> torch.dtype: + return torch.complex64 + + @property + def complex128(self) -> torch.dtype: + return torch.complex128 + + @property + def qint8(self) -> Any: # noqa: ANN401 + return torch.qint8 + + @property + def qint32(self) -> Any: # noqa: ANN401 + return torch.qint32 + + @property + def quint8(self) -> Any: # noqa: ANN401 + return torch.quint8 + + @property + def bfloat16(self) -> torch.dtype: + return torch.bfloat16 + + # Constants + + @property + def e(self) -> Any: # noqa: ANN401 + """e = 2.71828...""" # noqa: D403 + return torch.e + + @property + def inf(self) -> Any: # noqa: ANN401 + """Infinity.""" + return torch.inf + + @property + def nan(self) -> Any: # noqa: ANN401 + """Not-a-number.""" + return torch.nan + + @property + def pi(self) -> Any: # noqa: ANN401 + """pi = 3.14159...""" # noqa: D403 + return torch.pi + -register_backend(SupportedFrameworks.PYTORCH, PyTorchBackend) +register_backend(Frameworks.PYTORCH, PyTorchBackend) diff --git a/decent_array/interoperability/_tensorflow/tensorflow_backend.py b/decent_array/interoperability/_tensorflow/tensorflow_backend.py index 4203b86..fe5b150 100644 --- a/decent_array/interoperability/_tensorflow/tensorflow_backend.py +++ b/decent_array/interoperability/_tensorflow/tensorflow_backend.py @@ -18,39 +18,18 @@ from numpy.typing import NDArray from decent_array import Array +from decent_array._utils import unwrap from decent_array.interoperability._abstracts import Backend from decent_array.interoperability._backend_manager import register_backend -from decent_array.types import ArrayKey, DTypes, SupportedArrayTypes, SupportedDevices, SupportedFrameworks +from decent_array.types import ArrayKey, ArrayTypes, Devices, Frameworks +from decent_array.types._dtypes import _ALL_DTYPES, dtype -def _unwrap(x: Any) -> Any: # noqa: ANN401 - """Return the underlying value of an :class:`Array`, or pass ``x`` through.""" - return x.value if type(x) is Array else x - - -_DTYPE_MAP = { - DTypes.BOOL: tf.bool, - DTypes.UINT8: tf.uint8, - DTypes.INT8: tf.int8, - DTypes.UINT16: tf.uint16, - DTypes.INT16: tf.int16, - DTypes.UINT32: tf.uint32, - DTypes.INT32: tf.int32, - DTypes.UINT64: tf.uint64, - DTypes.INT64: tf.int64, - DTypes.FLOAT16: tf.float16, - DTypes.FLOAT32: tf.float32, - DTypes.FLOAT64: tf.float64, - DTypes.COMPLEX64: tf.complex64, - DTypes.COMPLEX128: tf.complex128, -} - - -class TensorflowBackend(Backend): # noqa: PLR0904 +class TensorflowBackend(Backend): """TensorFlow implementation of :class:`Backend`.""" - def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: - super().__init__(device) + def __init__(self, device: Devices = Devices.CPU) -> None: + super().__init__(device, name=Frameworks.TENSORFLOW.value) self._native_device: str = self.device_to_native(device) self._generator: tf.random.Generator = tf.random.Generator.from_non_deterministic_state(alg="philox") @@ -74,23 +53,23 @@ def eye(self, n: int) -> Array: with tf.device(self._native_device): return Array(tf.eye(n)) - def device_to_native(self, device: SupportedDevices) -> str: - if device in {SupportedDevices.CPU, SupportedDevices.GPU}: + def device_to_native(self, device: Devices) -> str: + if device in {Devices.CPU, Devices.GPU}: return f"/{device.value}:0" raise ValueError(f"Unsupported device for TensorFlow: {device}") - def device_of(self, x: Array) -> SupportedDevices: + def device_of(self, x: Array) -> Devices: device_str = x.value.device.lower() if "gpu" in device_str or "cuda" in device_str: - return SupportedDevices.GPU - return SupportedDevices.CPU + return Devices.GPU + return Devices.CPU # Array manipulation def copy(self, x: Array) -> Array: return Array(tf.identity(x.value)) - def to_numpy(self, x: SupportedArrayTypes | Array) -> NDArray[Any]: + def to_numpy(self, x: ArrayTypes | Array) -> NDArray[Any]: """Return the value of an :class:`Array` as a NumPy array.""" v = x.value if type(x) is Array else x if isinstance(v, tf.Tensor): @@ -163,10 +142,10 @@ def diagonal(self, x: Array, offset: int = 0) -> Array: raise ValueError(f"diagonal requires a 2-D tensor, got rank {rank}") return Array(tf.linalg.diag_part(v, k=offset)) - def astype(self, x: Array, dtype: DTypes) -> Array: - if dtype not in _DTYPE_MAP: - raise ValueError(f"Unsupported dtype '{dtype.value}' for TensorFlow backend.") - return Array(tf.cast(x.value, dtype=_DTYPE_MAP[dtype])) + def astype(self, x: Array, dtype: dtype) -> Array: + if dtype not in _ALL_DTYPES.values(): + raise ValueError(f"Unsupported dtype '{dtype}' for TensorFlow backend.") + return Array(tf.cast(x.value, dtype=dtype.backend_dtype)) # Linalg @@ -227,52 +206,52 @@ def all(self, x: Array, axis: int | tuple[int, ...] | None = None, keepdims: boo # covers both because PEP 484's numeric tower implicitly admits ``int``. def add(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.add(_unwrap(x1), _unwrap(x2))) + return Array(tf.add(unwrap(x1), unwrap(x2))) def iadd[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = tf.add(x1.value, _unwrap(x2)) + x1.value = tf.add(x1.value, unwrap(x2)) return x1 def subtract(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.subtract(_unwrap(x1), _unwrap(x2))) + return Array(tf.subtract(unwrap(x1), unwrap(x2))) def isubtract[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = tf.subtract(x1.value, _unwrap(x2)) + x1.value = tf.subtract(x1.value, unwrap(x2)) return x1 def multiply(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.multiply(_unwrap(x1), _unwrap(x2))) + return Array(tf.multiply(unwrap(x1), unwrap(x2))) def imultiply[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = tf.multiply(x1.value, _unwrap(x2)) + x1.value = tf.multiply(x1.value, unwrap(x2)) return x1 def divide(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.divide(_unwrap(x1), _unwrap(x2))) + return Array(tf.divide(unwrap(x1), unwrap(x2))) def idivide[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = tf.divide(x1.value, _unwrap(x2)) + x1.value = tf.divide(x1.value, unwrap(x2)) return x1 def floor_divide(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(tf.math.floordiv(_unwrap(x1), _unwrap(x2))) + return Array(tf.math.floordiv(unwrap(x1), unwrap(x2))) def ifloordiv[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value = tf.math.floordiv(x1.value, _unwrap(x2)) + x1.value = tf.math.floordiv(x1.value, unwrap(x2)) return x1 def remainder(self, x1: int | float | Array, x2: int | float | Array) -> Array: - return Array(tf.math.floormod(_unwrap(x1), _unwrap(x2))) + return Array(tf.math.floormod(unwrap(x1), unwrap(x2))) def imod[T: Array](self, x1: T, x2: int | float | Array) -> T: - x1.value = tf.math.floormod(x1.value, _unwrap(x2)) + x1.value = tf.math.floormod(x1.value, unwrap(x2)) return x1 def pow(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.pow(_unwrap(x1), _unwrap(x2))) + return Array(tf.pow(unwrap(x1), unwrap(x2))) def ipow[T: Array](self, x1: T, x2: int | float | complex | Array) -> T: - x1.value = tf.pow(x1.value, _unwrap(x2)) + x1.value = tf.pow(x1.value, unwrap(x2)) return x1 def negative(self, x: Array) -> Array: @@ -287,22 +266,22 @@ def sqrt(self, x: Array) -> Array: # Comparisons def equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.equal(_unwrap(x1), _unwrap(x2))) + return Array(tf.equal(unwrap(x1), unwrap(x2))) def not_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.not_equal(_unwrap(x1), _unwrap(x2))) + return Array(tf.not_equal(unwrap(x1), unwrap(x2))) def less(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.less(_unwrap(x1), _unwrap(x2))) + return Array(tf.less(unwrap(x1), unwrap(x2))) def less_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.less_equal(_unwrap(x1), _unwrap(x2))) + return Array(tf.less_equal(unwrap(x1), unwrap(x2))) def greater(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.greater(_unwrap(x1), _unwrap(x2))) + return Array(tf.greater(unwrap(x1), unwrap(x2))) def greater_equal(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.greater_equal(_unwrap(x1), _unwrap(x2))) + return Array(tf.greater_equal(unwrap(x1), unwrap(x2))) # Bitwise — TF's native ``&`` dispatches to ``tf.math.logical_and`` for bool # tensors and ``tf.bitwise.bitwise_and`` for int tensors, matching numpy/torch/jax @@ -310,41 +289,41 @@ def greater_equal(self, x1: int | float | complex | Array, x2: int | float | com # us to one dtype family. def bitwise_and(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(_unwrap(x1) & _unwrap(x2)) + return Array(unwrap(x1) & unwrap(x2)) def iand[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value &= _unwrap(x2) + x1.value &= unwrap(x2) return x1 def bitwise_invert(self, x: Array) -> Array: return Array(~x.value) def bitwise_or(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(_unwrap(x1) | _unwrap(x2)) + return Array(unwrap(x1) | unwrap(x2)) def ior[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value |= _unwrap(x2) + x1.value |= unwrap(x2) return x1 def bitwise_xor(self, x1: bool | int | Array, x2: bool | int | Array) -> Array: - return Array(_unwrap(x1) ^ _unwrap(x2)) + return Array(unwrap(x1) ^ unwrap(x2)) def ixor[T: Array](self, x1: T, x2: bool | int | Array) -> T: - x1.value ^= _unwrap(x2) + x1.value ^= unwrap(x2) return x1 def bitwise_left_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(tf.bitwise.left_shift(_unwrap(x1), _unwrap(x2))) + return Array(tf.bitwise.left_shift(unwrap(x1), unwrap(x2))) def ilshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value = tf.bitwise.left_shift(x1.value, _unwrap(x2)) + x1.value = tf.bitwise.left_shift(x1.value, unwrap(x2)) return x1 def bitwise_right_shift(self, x1: int | Array, x2: int | Array) -> Array: - return Array(tf.bitwise.right_shift(_unwrap(x1), _unwrap(x2))) + return Array(tf.bitwise.right_shift(unwrap(x1), unwrap(x2))) def irshift[T: Array](self, x1: T, x2: int | Array) -> T: - x1.value = tf.bitwise.right_shift(x1.value, _unwrap(x2)) + x1.value = tf.bitwise.right_shift(x1.value, unwrap(x2)) return x1 # Operators @@ -353,7 +332,7 @@ def sign(self, x: Array) -> Array: return Array(tf.sign(x.value)) def maximum(self, x1: int | float | complex | Array, x2: int | float | complex | Array) -> Array: - return Array(tf.maximum(_unwrap(x1), _unwrap(x2))) + return Array(tf.maximum(unwrap(x1), unwrap(x2))) def argmax(self, x: Array, axis: int | None = None, keepdims: bool = False) -> Array: v = x.value @@ -388,7 +367,7 @@ def set_item(self, x: Array, key: ArrayKey, value: bool | int | float | complex # that hammer set_item in tight loops should consider numpy or pytorch. original = x.value np_array = original.numpy().copy() - np_array[key] = np.asarray(_unwrap(value)) + np_array[key] = np.asarray(unwrap(value)) with tf.device(self._native_device): x.value = tf.convert_to_tensor(np_array, dtype=original.dtype) @@ -434,5 +413,109 @@ def choice(self, x: Array, size: int, replace: bool = True) -> Array: indices = tf.cast(tf.math.top_k(scores, k=size).indices, tf.int32) return Array(tf.gather(v, indices)) + # Dtypes + + @property + def bool_(self) -> tf.dtypes.DType: + return tf.bool + + @property + def uint8(self) -> tf.dtypes.DType: + return tf.uint8 + + @property + def uint16(self) -> tf.dtypes.DType: + return tf.uint16 + + @property + def uint32(self) -> tf.dtypes.DType: + return tf.uint32 + + @property + def uint64(self) -> tf.dtypes.DType: + return tf.uint64 + + @property + def int8(self) -> tf.dtypes.DType: + return tf.int8 + + @property + def int16(self) -> tf.dtypes.DType: + return tf.int16 + + @property + def int32(self) -> tf.dtypes.DType: + return tf.int32 + + @property + def int64(self) -> tf.dtypes.DType: + return tf.int64 + + @property + def float16(self) -> tf.dtypes.DType: + return tf.float16 + + @property + def float32(self) -> tf.dtypes.DType: + return tf.float32 + + @property + def float64(self) -> tf.dtypes.DType: + return tf.float64 + + @property + def complex64(self) -> tf.dtypes.DType: + return tf.complex64 + + @property + def complex128(self) -> tf.dtypes.DType: + return tf.complex128 + + @property + def qint8(self) -> tf.dtypes.DType: + return tf.qint8 + + @property + def qint16(self) -> tf.dtypes.DType: + return tf.qint16 + + @property + def qint32(self) -> tf.dtypes.DType: + return tf.qint32 + + @property + def quint8(self) -> tf.dtypes.DType: + return tf.quint8 + + @property + def quint16(self) -> tf.dtypes.DType: + return tf.quint16 + + @property + def bfloat16(self) -> tf.dtypes.DType: + return tf.bfloat16 + + # Constants + + @property + def e(self) -> Any: # noqa: ANN401 + """e = 2.71828...""" # noqa: D403 + return tf.experimental.numpy.e + + @property + def inf(self) -> Any: # noqa: ANN401 + """Infinity.""" + return tf.experimental.numpy.inf + + @property + def nan(self) -> Any: # noqa: ANN401 + """Not-a-number.""" + return tf.experimental.numpy.nan + + @property + def pi(self) -> Any: # noqa: ANN401 + """pi = 3.14159...""" # noqa: D403 + return tf.experimental.numpy.pi + -register_backend(SupportedFrameworks.TENSORFLOW, TensorflowBackend) +register_backend(Frameworks.TENSORFLOW, TensorflowBackend) diff --git a/decent_array/types/__init__.py b/decent_array/types/__init__.py new file mode 100644 index 0000000..198067c --- /dev/null +++ b/decent_array/types/__init__.py @@ -0,0 +1,18 @@ +from decent_array.types._dtypes import dtype, dtypes +from decent_array.types._types import ( + ArrayKey, + ArrayLike, + ArrayTypes, + Devices, + Frameworks, +) + +__all__ = [ + "ArrayKey", + "ArrayLike", + "ArrayTypes", + "Devices", + "Frameworks", + "dtype", + "dtypes", +] diff --git a/decent_array/types/_dtypes.py b/decent_array/types/_dtypes.py new file mode 100644 index 0000000..7feccd2 --- /dev/null +++ b/decent_array/types/_dtypes.py @@ -0,0 +1,178 @@ +"""Data type definitions.""" + +from __future__ import annotations + +from typing import Any + +from decent_array.interoperability import _backend_manager + +_SUPPORTED = { + "bfloat16", + "bool_", + "bytes_", + "complex64", + "complex128", + "complex256", + "float16", + "float32", + "float64", + "float128", + "int8", + "int16", + "int32", + "int64", + "object_", + "qint8", + "qint16", + "qint32", + "quint8", + "quint16", + "uint8", + "uint16", + "uint32", + "uint64", + "unicode_", + "void", +} + + +class dtype: # noqa: N801 + """Base class for dtypes.""" + + def __init__(self, name: str): + # name doesn't map to any dtype + if name not in _SUPPORTED: + raise ValueError(f"dtype {name} is not supported. Supported dtypes: {', '.join(_SUPPORTED)}") + + # initialize with backend dtype; if backend is not initialized, it sets placeholder values and set_backend + # will then bind the backend dtypes + backend_instance = getattr(_backend_manager, "_BACKEND_INSTANCE", None) + self._name = name + self._backend_dtype: Any = None if backend_instance is None else getattr(backend_instance, name, None) + self._available = self._backend_dtype is not None + + @property + def name(self) -> str: + """Name of the dtype.""" + return self._name + + @property + def available(self) -> bool: + """Availability of the dtype (dependent on backend, device, backend settings, OS).""" + return self._available + + @property + def backend_dtype(self) -> Any: # noqa: ANN401 + """The corresponding backend dtype object.""" + return self._backend_dtype + + def __str__(self) -> str: + """Name of the dtype.""" + return self.name + + def __eq__(self, other: object) -> bool: + """Check equivalence by ``name`` attributes.""" + if not isinstance(other, dtype): + return NotImplemented + return self.name == other.name and self.available and other.available + + def __hash__(self) -> int: + """Hash of the dtype.""" + return hash(self.name) + + +# instantiate all the supported dtypes +bool_ = dtype("bool_") +_BOOL_DTYPES = {"bool_": bool_} + +int8 = dtype("int8") +int16 = dtype("int16") +int32 = dtype("int32") +int64 = dtype("int64") +_SIGNED_INT_DTYPES = {"int8": int8, "int16": int16, "int32": int32, "int64": int64} + +uint8 = dtype("uint8") +uint16 = dtype("uint16") +uint32 = dtype("uint32") +uint64 = dtype("uint64") +_UNSIGNED_INT_DTYPES = {"uint8": uint8, "uint16": uint16, "uint32": uint32, "uint64": uint64} + +float16 = dtype("float16") +bfloat16 = dtype("bfloat16") +float32 = dtype("float32") +float64 = dtype("float64") +float128 = dtype("float128") +_REAL_FLOATING_DTYPES = { + "float16": float16, + "bfloat16": bfloat16, + "float32": float32, + "float64": float64, + "float128": float128, +} + +complex64 = dtype("complex64") +complex128 = dtype("complex128") +complex256 = dtype("complex256") +_COMPLEX_FLOATING_DTYPES = {"complex64": complex64, "complex128": complex128, "complex256": complex256} + +qint8 = dtype("qint8") +qint16 = dtype("qint16") +qint32 = dtype("qint32") +_QUANTIZED_SIGNED_INT_DTYPES = {"qint8": qint8, "qint16": qint16, "qint32": qint32} + +quint8 = dtype("quint8") +quint16 = dtype("quint16") +_QUANTIZED_UNSIGNED_INT_DTYPES = {"quint8": quint8, "quint16": quint16} + +unicode_ = dtype("unicode_") +bytes_ = dtype("bytes_") +object_ = dtype("object_") +void = dtype("void") +_MISCELLANEOUS_DTYPES = {"unicode_": unicode_, "bytes_": bytes_, "object_": object_, "void": void} + +_INTEGRAL_DTYPES = ( + _SIGNED_INT_DTYPES | _UNSIGNED_INT_DTYPES | _QUANTIZED_SIGNED_INT_DTYPES | _QUANTIZED_UNSIGNED_INT_DTYPES +) +_NUMERIC_DTYPES = _INTEGRAL_DTYPES | _REAL_FLOATING_DTYPES | _COMPLEX_FLOATING_DTYPES +_ALL_DTYPES = _BOOL_DTYPES | _NUMERIC_DTYPES | _MISCELLANEOUS_DTYPES + + +_BACKEND_DTYPE_TO_DTYPE: dict[Any, dtype] = {} +for dt in _ALL_DTYPES.values(): + if dt.available: + _BACKEND_DTYPE_TO_DTYPE[dt._backend_dtype] = dt # noqa: SLF001 + + +_ALIASES = { + "bool": _BOOL_DTYPES, + "signed integer": _SIGNED_INT_DTYPES, + "unsigned integer": _UNSIGNED_INT_DTYPES, + "integral": _INTEGRAL_DTYPES, + "real floating": _REAL_FLOATING_DTYPES, + "complex floating": _COMPLEX_FLOATING_DTYPES, + "numeric": _NUMERIC_DTYPES, +} + + +def dtypes(*, kind: str | tuple[str, ...] | None = None) -> dict[str, dtype]: + """ + Return a dictionary of available dtypes. + + Args: + kind: kind of dtypes to be returned, either one string or a tuple of strings; available kinds are: `bool`, + `signed integer`, `unsigned integer`, `integral`, `real floating`, `complex floating`, `numeric`. If kind is + None, all available dtypes are included. If `kind` does not match any of the supported strings, an empty + dictionary is returned. + + Returns: + A dictionary of available dtypes, keyed by name, and filtered by `kind`. + + """ + if kind is None: + dtypes = _ALL_DTYPES + elif isinstance(kind, str): + dtypes = _ALIASES.get(kind, {}) + else: + dtypes = {k: v for alias in kind if alias in _ALIASES for k, v in _ALIASES[alias].items()} + + return {name: dt for name, dt in dtypes.items() if dt.available} diff --git a/decent_array/types.py b/decent_array/types/_types.py similarity index 67% rename from decent_array/types.py rename to decent_array/types/_types.py index b90fba0..ff5a081 100644 --- a/decent_array/types.py +++ b/decent_array/types/_types.py @@ -1,4 +1,4 @@ -"""Type definitions for optimization variables.""" +"""Type definitions.""" from __future__ import annotations @@ -13,13 +13,14 @@ from decent_array._array import Array + ArrayLike: TypeAlias = Union["numpy.ndarray", "torch.Tensor", "tf.Tensor", "jax.Array"] # noqa: UP040 """ Type alias for array-like types supported in decent-array, including NumPy arrays, PyTorch tensors, TensorFlow tensors, and JAX arrays. """ -SupportedArrayTypes: TypeAlias = bool | int | float | complex | ArrayLike # noqa: UP040 +ArrayTypes: TypeAlias = bool | int | float | complex | ArrayLike # noqa: UP040 """ Type alias for supported types for optimization variables in decent-array, including array-like types and scalars. @@ -36,7 +37,7 @@ # Its important that the enum values correspond to the folder names of the backends, # since those are used for dynamic imports in _backend_manager.py -class SupportedFrameworks(Enum): +class Frameworks(Enum): """Enum for supported frameworks in decent-array.""" NUMPY = "numpy" @@ -45,31 +46,9 @@ class SupportedFrameworks(Enum): JAX = "jax" -class SupportedDevices(Enum): +class Devices(Enum): """Enum for supported devices in decent-array.""" CPU = "cpu" GPU = "gpu" MPS = "mps" - - -class DTypes(Enum): - """Enum for supported dtypes in decent-array.""" - - BOOL = "bool" - UINT8 = "uint8" - UINT16 = "uint16" - UINT32 = "uint32" - UINT64 = "uint64" - INT8 = "int8" - INT16 = "int16" - INT32 = "int32" - INT64 = "int64" - FLOAT16 = "float16" - FLOAT32 = "float32" - FLOAT64 = "float64" - COMPLEX64 = "complex64" - COMPLEX128 = "complex128" - - -_STRING_TO_DTYPE = {dt.value: dt for dt in DTypes} diff --git a/docs/source/api/decent_array.types.rst b/docs/source/api/decent_array.types.rst index 6a29483..09cdf13 100644 --- a/docs/source/api/decent_array.types.rst +++ b/docs/source/api/decent_array.types.rst @@ -4,4 +4,7 @@ decent\_array.types .. automodule:: decent_array.types :members: :show-inheritance: - :undoc-members: \ No newline at end of file + :undoc-members: + +.. autodata:: decent_array.types.ArrayTypes + :annotation: \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 4c33da5..4d475a0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,6 +28,10 @@ "sphinx.ext.viewcode", # View source code ] +autodoc_type_aliases = { + "ArrayTypes": "decent_array.types.ArrayTypes", +} + nitpicky = True nitpick_ignore = [ ("py:class", "numpy.float64"), diff --git a/docs/source/user.rst b/docs/source/user.rst index 9c08874..235744e 100644 --- a/docs/source/user.rst +++ b/docs/source/user.rst @@ -8,3 +8,286 @@ Requires `Python 3.13+ `_ .. code-block:: bash pip install decent-array + + +Constants +--------- + +decent-array exposes `e`, `inf`, `nan`, `pi` constants (`from decent_array import e`), which are bound to the +corresponding framework-native constants. If the constants are not bound, they fall back on the `math` constants. + + +dtypes +------ + +decent-array exposes a number of dtypes which are bound to the corresponding framework-native dtypes. +decent-array exposes dtypes as instances of :class:`~decent_array.types.dtype`, which can be accessed as +`from decent_array import float32`. Additionally, dtypes can be accessed by `dtype("float32")`. + +dtypes have attributes: :attr:`~decent_array.types.dtype.name`, :attr:`~decent_array.types.dtype.available` (see +discussion below), :attr:`~decent_array.types.dtype.backend_dtype` (which binds the dtype to the corresponding +framework-native dtype). + +The following lists the dtypes exposed by decent-array. + +- Booleans + - `bool_` + +- Unsigned integers + - `uint8` + - `uint16` + - `uint32` + - `uint64` + +- Signed integers + - `int8` + - `int16` + - `int32` + - `int64` + +- Floating point + - `float16` + - `bfloat16` + - `float32` + - `float64` + - `float128` + +- Complex + - `complex64` + - `complex128` + - `complex256` + +- Quantized integers + - `qint8` + - `quint8` + - `qint16` + - `quint16` + - `qint32` + +- Miscellaneous + - `unicode_` + - `bytes_` + - `object_` + - `void` + + + +Availability +~~~~~~~~~~~~ + +Not all dtypes are available in all configurations. The following factors affect dtype availability: +framework, device, OS, framework settings; additionally, some framework-native operations only support a subset of +dtypes. + +There is no reliable way to check which dtypes are available in the current setting, and this is also subject to change +as frameworks develop. In decent-array, we make a best effort to determine whether a dtype is available and, if not, +set :attr:`~decent_array.types.dtype.available`. to `False`. Currently, we mark dtypes as un/available based on the +table below. Additionally, dtypes are marked as unavailable (and :attr:`~decent_array.types.dtype.backend_dtype` is +`None`) if the backend has not been initialized via `set_backend`. + +However, :attr:`~decent_array.types.dtype.available` is generally not a reliable indicator of availability. The most +reliable way is to try operations that involve the dtype and observe if the framework raises an error. + + +.. list-table:: dtype support across frameworks + :header-rows: 1 + :widths: 22 10 10 10 12 36 + + * - dtype + - NumPy + - JAX + - PyTorch + - TensorFlow + - Notes + * - ``bool_`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - **Integers** + - + - + - + - + - + * - ``int8`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``int16`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``int32`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``int64`` + - ✓ + - ⚠️ + - ✓ + - ✓ + - JAX requires ``jax_enable_x64=True`` + * - **Unsigned integers** + - + - + - + - + - + * - ``uint8`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``uint16`` + - ✓ + - ⚠️ + - ⚠️ + - ✓ + - JAX requires ``jax_enable_x64=True``; PyTorch support is limited/experimental + * - ``uint32`` + - ✓ + - ⚠️ + - ⚠️ + - ✓ + - JAX requires ``jax_enable_x64=True``; PyTorch support is limited/experimental + * - ``uint64`` + - ✓ + - ⚠️ + - ⚠️ + - ✓ + - JAX requires ``jax_enable_x64=True``; PyTorch support is limited/experimental + * - **Floating point** + - + - + - + - + - + * - ``float16`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``bfloat16`` + - ✗ + - ✓ + - ✓ + - ✓ + - + * - ``float32`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``float64`` + - ✓ + - ⚠️ + - ✓ + - ✓ + - JAX requires ``jax_enable_x64=True`` + * - ``float128`` + - ⚠️ + - ✗ + - ✗ + - ✗ + - NumPy support is platform-dependent + * - **Complex** + - + - + - + - + - + * - ``complex64`` + - ✓ + - ✓ + - ✓ + - ✓ + - + * - ``complex128`` + - ✓ + - ⚠️ + - ✓ + - ✓ + - JAX requires ``jax_enable_x64=True`` + * - ``complex256`` + - ⚠️ + - ✗ + - ✗ + - ✗ + - NumPy support is platform-dependent + * - **Quantized** + - + - + - + - + - + * - ``qint8`` + - ✗ + - ✗ + - ✓ + - ✓ + - + * - ``quint8`` + - ✗ + - ✗ + - ✓ + - ✓ + - + * - ``qint16`` + - ✗ + - ✗ + - ✗ + - ✓ + - + * - ``quint16`` + - ✗ + - ✗ + - ✗ + - ✓ + - + * - ``qint32`` + - ✗ + - ✗ + - ✓ + - ✓ + - + * - **Miscellaneous** + - + - + - + - + - + * - ``unicode_`` + - ✓ + - ✗ + - ✗ + - ✗ + - Equivalent to ``np.str_`` + * - ``bytes_`` + - ✓ + - ✗ + - ✗ + - ✓ + - Equivalent to ``np.bytes_`` and ``tf.string`` + * - ``object_`` + - ✓ + - ✗ + - ✗ + - ✗ + - + * - ``void`` + - ✓ + - ✗ + - ✗ + - ✗ + - \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 398c054..41daa86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -194,5 +194,5 @@ lint.ignore = [ "ICN001", # unconventional-import-alias, complains about common import aliases like `import numpy as np` but it doesn't work for most libraries "PYI041", # Removes warning for redundant types in cases like int | float | complex where complex already covers int and float, but we want to keep the union for readability and explicitness about supported types. ] -preview = true +preview = false line-length = 120 diff --git a/tests/conftest.py b/tests/conftest.py index 1ca38bf..8e83362 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Shared fixtures: parametrize tests across every (framework, device) combination. Each test using the ``backend`` fixture runs once per (framework, device) pair from -:class:`SupportedFrameworks` x :class:`SupportedDevices`. Combinations whose backend +:class:`Frameworks` x :class:`Devices`. Combinations whose backend package is missing or whose device is not present on the current host are marked ``skip`` so the test report stays interpretable on machines with partial accelerator support. @@ -15,50 +15,50 @@ import pytest from decent_array.interoperability._backend_manager import reset_backends -from decent_array.types import SupportedDevices, SupportedFrameworks +from decent_array.types import Devices, Frameworks if TYPE_CHECKING: from _pytest.fixtures import FixtureRequest -def _framework_importable(framework: SupportedFrameworks) -> bool: +def _framework_importable(framework: Frameworks) -> bool: try: - if framework == SupportedFrameworks.NUMPY: + if framework == Frameworks.NUMPY: import numpy # noqa: F401, PLC0415 - elif framework == SupportedFrameworks.PYTORCH: + elif framework == Frameworks.PYTORCH: import torch # noqa: F401, PLC0415 - elif framework == SupportedFrameworks.JAX: + elif framework == Frameworks.JAX: import jax # noqa: F401, PLC0415 - elif framework == SupportedFrameworks.TENSORFLOW: + elif framework == Frameworks.TENSORFLOW: import tensorflow # noqa: F401, PLC0415 except ImportError: return False return True -def _device_available(framework: SupportedFrameworks, device: SupportedDevices) -> bool: +def _device_available(framework: Frameworks, device: Devices) -> bool: """Return True iff this (framework, device) pair can run on the current host.""" if not _framework_importable(framework): return False - if framework == SupportedFrameworks.NUMPY: - return device == SupportedDevices.CPU - if framework == SupportedFrameworks.PYTORCH: + if framework == Frameworks.NUMPY: + return device == Devices.CPU + if framework == Frameworks.PYTORCH: import torch # noqa: PLC0415 - if device == SupportedDevices.CPU: + if device == Devices.CPU: return True - if device == SupportedDevices.GPU: + if device == Devices.GPU: try: return bool(torch.cuda.is_available()) except Exception: return False - if device == SupportedDevices.MPS: + if device == Devices.MPS: try: return bool(torch.backends.mps.is_available()) except Exception: return False - if framework == SupportedFrameworks.JAX: - if device == SupportedDevices.MPS: + if framework == Frameworks.JAX: + if device == Devices.MPS: return False import jax # noqa: PLC0415 @@ -67,14 +67,14 @@ def _device_available(framework: SupportedFrameworks, device: SupportedDevices) except Exception: return False return True - if framework == SupportedFrameworks.TENSORFLOW: - if device == SupportedDevices.MPS: + if framework == Frameworks.TENSORFLOW: + if device == Devices.MPS: return False import tensorflow as tf # noqa: PLC0415 - if device == SupportedDevices.CPU: + if device == Devices.CPU: return True - if device == SupportedDevices.GPU: + if device == Devices.GPU: try: return len(tf.config.list_physical_devices("GPU")) > 0 except Exception: @@ -84,8 +84,8 @@ def _device_available(framework: SupportedFrameworks, device: SupportedDevices) def _backend_params() -> list[pytest.ParameterSet]: params: list[pytest.ParameterSet] = [] - for framework in SupportedFrameworks: - for device in SupportedDevices: + for framework in Frameworks: + for device in Devices: test_id = f"{framework.value}-{device.value}" if _device_available(framework, device): params.append(pytest.param((framework, device), id=test_id)) @@ -104,7 +104,7 @@ def _backend_params() -> list[pytest.ParameterSet]: @pytest.fixture(params=BACKEND_PARAMS) -def backend(request: FixtureRequest) -> Iterator[tuple[SupportedFrameworks, SupportedDevices]]: +def backend(request: FixtureRequest) -> Iterator[tuple[Frameworks, Devices]]: """Activate the (framework, device) backend for this test, then reset on teardown.""" from decent_array.interoperability import set_backend # noqa: PLC0415 diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py index a0eaf4c..8e890c6 100644 --- a/tests/test_backend_manager.py +++ b/tests/test_backend_manager.py @@ -4,8 +4,10 @@ from typing import TYPE_CHECKING +import numpy as np import pytest +import decent_array._constants as constants from decent_array.interoperability import _backend_manager as backend_manager from decent_array.interoperability._abstracts import Backend from decent_array.interoperability._backend_manager import ( @@ -17,7 +19,8 @@ reset_backends, set_backend, ) -from decent_array.types import SupportedDevices, SupportedFrameworks +from decent_array.types import Devices, Frameworks +from decent_array.types._dtypes import dtype, float32 if TYPE_CHECKING: from collections.abc import Iterator @@ -40,11 +43,11 @@ def _isolate_listeners_and_backends() -> Iterator[None]: def test_normalize_accepts_enum() -> None: - assert _normalize(SupportedFrameworks.NUMPY) == SupportedFrameworks.NUMPY + assert _normalize(Frameworks.NUMPY) == Frameworks.NUMPY def test_normalize_accepts_string() -> None: - assert _normalize("numpy") == SupportedFrameworks.NUMPY + assert _normalize("numpy") == Frameworks.NUMPY def test_normalize_unknown_raises() -> None: @@ -57,19 +60,19 @@ def test_normalize_unknown_raises() -> None: def test_set_backend_with_string() -> None: set_backend("numpy") - assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + assert backend_manager._ACTIVE_BACKEND.get() == Frameworks.NUMPY def test_set_backend_with_enum() -> None: - set_backend(SupportedFrameworks.NUMPY) - assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + set_backend(Frameworks.NUMPY) + assert backend_manager._ACTIVE_BACKEND.get() == Frameworks.NUMPY def test_set_backend_idempotent_same_backend() -> None: set_backend("numpy") # Re-activating with the same backend+device must be a no-op (no exception). set_backend("numpy") - assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + assert backend_manager._ACTIVE_BACKEND.get() == Frameworks.NUMPY def test_set_backend_different_backend_raises() -> None: @@ -80,8 +83,8 @@ def test_set_backend_different_backend_raises() -> None: def test_set_backend_with_string_device() -> None: set_backend("numpy", "cpu") - instance = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) - assert instance.device == SupportedDevices.CPU + instance = _instantiate(Frameworks.NUMPY, Devices.CPU) + assert instance.device == Devices.CPU def test_set_backend_invalid_name_raises() -> None: @@ -89,6 +92,22 @@ def test_set_backend_invalid_name_raises() -> None: set_backend("not-a-backend") +def test_set_backend_instantiates_dtypes() -> None: + dt1 = dtype("float32") # backend not set, dtype has placeholder values + set_backend("numpy", "cpu") + dt2 = float32 # global dtype bound during set_dtype + dt3 = dtype("float32") # backend is set, so this is bound to backend dtype + + assert dt1 != dt2 # dtypes are not equal if not avaiable, and dt1 is not available + assert dt2 == dt3 # available because bound to backend, and equal + + dt4 = dtype("int16") + assert dt3 != dt4 + + assert dt2.backend_dtype == np.dtype("float32") # check global dtype is correctly bound to backend + assert dt3.backend_dtype == np.dtype("float32") # check dtypes instantiated after set_backend are correctly bound to backend + + # register_backend ------------------------------------------------------- @@ -97,21 +116,21 @@ class NotABackend: pass with pytest.raises(TypeError, match=r"subclass of Backend"): - register_backend(SupportedFrameworks.NUMPY, NotABackend) # type: ignore[arg-type] + register_backend(Frameworks.NUMPY, NotABackend) # type: ignore[arg-type] def test_register_backend_replaces_cached_instance() -> None: # First import registers the real backend; instantiate to populate cache. set_backend("numpy") - cached = backend_manager._BACKEND_INSTANCES.get(SupportedFrameworks.NUMPY) + cached = backend_manager._BACKEND_INSTANCES.get(Frameworks.NUMPY) assert cached is not None # Re-register the same class — cache should be cleared so next instantiate is fresh. from decent_array.interoperability._numpy.numpy_backend import NumpyBackend # noqa: PLC0415 reset_backends() - register_backend(SupportedFrameworks.NUMPY, NumpyBackend) - assert SupportedFrameworks.NUMPY not in backend_manager._BACKEND_INSTANCES + register_backend(Frameworks.NUMPY, NumpyBackend) + assert Frameworks.NUMPY not in backend_manager._BACKEND_INSTANCES # register_backend_listener --------------------------------------------- @@ -167,22 +186,22 @@ def test_reset_backends_clears_active() -> None: def test_reset_backends_clears_instance_cache() -> None: set_backend("numpy") - assert SupportedFrameworks.NUMPY in backend_manager._BACKEND_INSTANCES + assert Frameworks.NUMPY in backend_manager._BACKEND_INSTANCES reset_backends() - assert SupportedFrameworks.NUMPY not in backend_manager._BACKEND_INSTANCES + assert Frameworks.NUMPY not in backend_manager._BACKEND_INSTANCES # _instantiate ---------------------------------------------------------- def test_instantiate_caches_instance() -> None: - a = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) - b = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) + a = _instantiate(Frameworks.NUMPY, Devices.CPU) + b = _instantiate(Frameworks.NUMPY, Devices.CPU) assert a is b def test_set_backend_device_mismatch_raises() -> None: - set_backend("numpy", SupportedDevices.CPU) + set_backend("numpy", Devices.CPU) # NumPy backend rejects non-CPU devices at construction; check behavior via the # configured-mismatch path: re-set with a different device after first activation. with pytest.raises((RuntimeError, ValueError)): @@ -194,8 +213,8 @@ def test_set_backend_device_mismatch_raises() -> None: def test_default_device_returns_active_device() -> None: - set_backend("numpy", SupportedDevices.CPU) - assert default_device() == SupportedDevices.CPU + set_backend("numpy", Devices.CPU) + assert default_device() == Devices.CPU def test_default_device_raises_when_no_backend() -> None: diff --git a/tests/test_iop_functions.py b/tests/test_iop_functions.py index e4c46ee..c115504 100644 --- a/tests/test_iop_functions.py +++ b/tests/test_iop_functions.py @@ -6,11 +6,11 @@ import pytest import decent_array.interoperability as iop +import decent_array as da from decent_array import Array from decent_array.interoperability._backend_manager import reset_backends from decent_array.interoperability._iop.math import iadd, idivide, imultiply, isubtract from decent_array.interoperability._iop.utils import device_to_native, get_item, set_item -from decent_array.types import DTypes def _np(arr: Array) -> np.ndarray: @@ -219,7 +219,7 @@ def test_diagonal_with_offset(backend: tuple) -> None: def test_astype_to_float(backend: tuple) -> None: arr = iop.asarray(3.0) - out = iop.astype(arr, DTypes.FLOAT32) + out = iop.astype(arr, da.float32) np_out = _np(out) assert np_out.dtype == np.float32 assert np_out == pytest.approx(3.0) @@ -227,7 +227,7 @@ def test_astype_to_float(backend: tuple) -> None: def test_astype_to_int(backend: tuple) -> None: arr = iop.asarray(3.0) - out = iop.astype(arr, DTypes.INT32) + out = iop.astype(arr, da.int32) np_out = _np(out) assert np_out.dtype == np.int32 assert int(np_out) == 3 @@ -235,7 +235,7 @@ def test_astype_to_int(backend: tuple) -> None: def test_astype_to_bool(backend: tuple) -> None: arr = iop.asarray(1.0) - out = iop.astype(arr, DTypes.BOOL) + out = iop.astype(arr, da.bool_) np_out = _np(out) assert np_out.dtype == np.bool_ assert bool(np_out) is True @@ -697,7 +697,7 @@ def test_function_raises_when_no_backend() -> None: def test_to_array_round_trip_with_bool(backend: tuple) -> None: arr = iop.asarray(True) - out = iop.astype(arr, DTypes.BOOL) + out = iop.astype(arr, da.bool_) np_out = _np(out) assert np_out.dtype == np.bool_ assert bool(np_out) is True