forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_delegation.py
More file actions
319 lines (273 loc) · 10.9 KB
/
_delegation.py
File metadata and controls
319 lines (273 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
"""Delegation to existing implementations for Public API Functions."""
from collections.abc import Sequence
from types import ModuleType
from typing import Literal
from ._lib import _funcs
from ._lib._utils._compat import (
array_namespace,
is_cupy_namespace,
is_dask_namespace,
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._lib._utils._compat import device as get_device
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType
__all__ = ["isclose", "one_hot", "pad", "quantile"]
def isclose(
    a: Array | complex,
    b: Array | complex,
    *,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
    xp: ModuleType | None = None,
) -> Array:
    """
    Test element-wise whether two arrays are equal within a tolerance.

    Two values are deemed close when the absolute value of their difference does
    not exceed the absolute tolerance `atol` plus the relative tolerance scaled by
    the magnitude of `b` (``rtol * abs(b)``). Both tolerances are positive and are
    normally very small.

    NaNs compare as equal only when ``equal_nan=True`` and they occupy the same
    position. Infinities compare as equal when they occupy the same position and
    have the same sign in both inputs.

    Parameters
    ----------
    a, b : Array | int | float | complex | bool
        Objects to compare. At least one of them must be an array.
    rtol : float, optional
        Relative tolerance (see Notes). Default: ``1e-05``.
    atol : float, optional
        Absolute tolerance (see Notes). Default: ``1e-08``.
    equal_nan : bool, optional
        If True, a NaN in `a` is considered equal to a NaN at the same position
        in `b`. Default: False.
    xp : array_namespace, optional
        The standard-compatible namespace for `a` and `b`. Default: infer.

    Returns
    -------
    Array
        Boolean array whose shape is the broadcast of `a` and `b`. It holds
        ``True`` where `a` is close to `b` and ``False`` elsewhere.

    Warnings
    --------
    The default `atol` is not suitable for values whose magnitude is much smaller
    than one (see Notes).

    See Also
    --------
    math.isclose : Similar function in stdlib for Python scalars.

    Notes
    -----
    For finite values, the test applied is::

     absolute(a - b) <= (atol + rtol * absolute(b))

    Unlike the built-in `math.isclose`, this inequality is not symmetric in `a`
    and `b`. In rare cases ``isclose(a, b)`` can therefore differ from
    ``isclose(b, a)``.

    The default `atol` is a poor fit when the reference value `b` is smaller than
    one in magnitude. For example, ``a = 1e-9`` and ``b = 2e-9`` are unlikely to
    be meant as "close", yet ``isclose(1e-9, 2e-9)`` is ``True`` with the default
    settings. Choose `atol` for the problem at hand, particularly to fix the
    threshold below which a non-zero value in `a` counts as "close" to a tiny or
    zero value in `b`.

    `a` and `b` are compared with standard broadcasting, so they need not share a
    shape for ``isclose(a, b)`` to evaluate to ``True``.

    `isclose` is not defined for non-numeric data types; ``bool`` counts as a
    numeric data type for this purpose.
    """
    if xp is None:
        xp = array_namespace(a, b)

    # Tolerance settings are forwarded unchanged to whichever implementation runs.
    tolerances = {"rtol": rtol, "atol": atol, "equal_nan": equal_nan}

    # These backends provide their own ``isclose``; the inputs are passed through
    # as given.
    native_backends = (
        is_numpy_namespace,
        is_cupy_namespace,
        is_dask_namespace,
        is_jax_namespace,
    )
    if any(check(xp) for check in native_backends):
        return xp.isclose(a, b, **tolerances)

    if is_torch_namespace(xp):
        # Coerce both operands through ``asarrays`` first (Array API 2024.12
        # scalar support); presumably torch's ``isclose`` needs tensor
        # operands - NOTE(review): confirm.
        lhs, rhs = asarrays(a, b, xp=xp)
        return xp.isclose(lhs, rhs, **tolerances)

    # Generic Array API implementation for every other namespace.
    return _funcs.isclose(a, b, xp=xp, **tolerances)
def one_hot(
    x: Array,
    /,
    num_classes: int,
    *,
    dtype: DType | None = None,
    axis: int = -1,
    xp: ModuleType | None = None,
) -> Array:
    """
    One-hot encode an array of integer indices.

    Every index in `x` becomes a vector of length `num_classes` that is zero
    everywhere except for a one at that index.

    Parameters
    ----------
    x : array
        Array with an integral dtype. Its values are expected to lie in
        ``[0, num_classes - 1]``.
    num_classes : int
        Length of the new one-hot dimension.
    dtype : DType, optional
        Dtype of the result. When omitted, the namespace's default floating-point
        dtype on the device of `x` is used (usually float64).
    axis : int, optional
        Position of the new axis in the output. Default: -1, which appends it.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        Array with the shape of `x` plus one extra axis of size `num_classes`,
        inserted at position `axis`.

        Indices with ``x < 0`` or ``x >= num_classes`` are not validated. The
        result for them is undefined: the call may raise an exception or even
        leave things in a bad state.

    Raises
    ------
    TypeError
        If `x` does not have an integral dtype.

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.one_hot(xp.asarray([1, 2, 0]), 3)
    Array([[0., 1., 0.],
     [0., 0., 1.],
     [1., 0., 0.]], dtype=array_api_strict.float64)
    """
    # --- Input validation -------------------------------------------------
    xp = array_namespace(x) if xp is None else xp
    if not xp.isdtype(x.dtype, "integral"):
        message = "x must have an integral dtype."
        raise TypeError(message)
    if dtype is None:
        dtype = _funcs.default_dtype(xp, device=get_device(x))

    # --- Delegation -------------------------------------------------------
    if is_jax_namespace(xp):
        # JAX's one_hot handles dtype and axis placement itself.
        from jax.nn import one_hot as jax_one_hot

        return jax_one_hot(x, num_classes, dtype=dtype, axis=axis)

    if is_torch_namespace(xp):
        from torch.nn.functional import one_hot as torch_one_hot

        indices = xp.astype(x, xp.int64)  # PyTorch only supports int64 here.
        try:
            encoded = torch_one_hot(indices, num_classes)
        except RuntimeError as err:
            # PyTorch reports bad indices as RuntimeError (presumably e.g.
            # out-of-range values); surface them as IndexError instead.
            raise IndexError from err
    else:
        encoded = _funcs.one_hot(x, num_classes, xp=xp)

    # The encoding above places the class axis last and in its own dtype;
    # cast, then relocate the axis if the caller asked for another position.
    encoded = xp.astype(encoded, dtype, copy=False)
    return encoded if axis == -1 else xp.moveaxis(encoded, -1, axis)
def pad(
    x: Array,
    pad_width: int | tuple[int, int] | Sequence[tuple[int, int]],
    mode: Literal["constant"] = "constant",
    *,
    constant_values: complex = 0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Pad an array with a constant value.

    Parameters
    ----------
    x : array
        Array to pad.
    pad_width : int or tuple of ints or sequence of pairs of ints
        Number of elements to add on each side.

        - An int applies the same width before and after every axis.
        - A single pair ``(before, after)`` is reused for each of the ``x.ndim``
          axes.
        - A sequence ``[(before_0, after_0), ... (before_N, after_N)]`` gives one
          pair per axis of ``x``.
    mode : str, optional
        Padding mode. Only ``"constant"`` is supported; it fills with
        `constant_values`.
    constant_values : python scalar, optional
        Fill value for the padded region. Default: zero.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        `x` with ``pad_width`` elements equal to ``constant_values`` added
        around it.

    Raises
    ------
    NotImplementedError
        If `mode` is anything other than ``"constant"``.
    """
    if xp is None:
        xp = array_namespace(x)

    if mode != "constant":
        msg = "Only `'constant'` mode is currently supported"
        raise NotImplementedError(msg)

    # These namespaces expose a NumPy-style ``pad`` that accepts the arguments
    # as-is.
    has_numpy_style_pad = (
        is_numpy_namespace(xp)
        or is_cupy_namespace(xp)
        or is_jax_namespace(xp)
        or is_pydata_sparse_namespace(xp)
    )
    if has_numpy_style_pad:
        return xp.pad(x, pad_width, mode, constant_values=constant_values)

    if is_torch_namespace(xp):
        # Same conversion as torch._numpy's pad:
        # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
        # Expand every accepted `pad_width` form into an (ndim, 2) table of
        # (before, after) rows. torch.nn.functional.pad reads its flat pad
        # sequence starting from the last axis, so reverse the rows before
        # flattening.
        per_axis = xp.broadcast_to(xp.asarray(pad_width), (x.ndim, 2))
        flat_widths = xp.flip(per_axis, axis=(0,)).flatten()
        # NOTE(review): the tuple holds 0-d tensors rather than ints, hence the
        # type-checker suppressions.
        return xp.nn.functional.pad(x, tuple(flat_widths), value=constant_values)  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

    # Generic Array API implementation for every other namespace.
    return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
def _has_array_api_scipy_quantile() -> bool:
    """
    Report whether ``scipy.stats.quantile`` is available for delegation.

    Returns True only when all of these hold:

    - SciPy and ``packaging`` import successfully;
    - the installed SciPy is version 1.16 or newer;
    - ``scipy.stats`` exposes ``quantile``.

    Any ``ImportError`` or ``AttributeError`` raised while *probing* is treated
    as "not available". No computation happens here, so errors raised later by
    the quantile computation itself are never masked by this check.
    """
    try:
        import scipy
        from packaging import version

        if version.parse(scipy.__version__) < version.parse("1.16"):
            return False
        import scipy.stats
    except (ImportError, AttributeError):
        return False
    return hasattr(scipy.stats, "quantile")


def quantile(
    x: Array,
    q: Array | float,
    /,
    *,
    axis: int | None = None,
    keepdims: bool = False,
    method: str = "linear",
    xp: ModuleType | None = None,
) -> Array:
    """
    Compute the q-th quantile(s) of the data along the specified axis.

    Parameters
    ----------
    x : array of real numbers
        Data array.
    q : array of float or float
        Probability or sequence of probabilities of the quantiles to compute.
        Values must be between 0 and 1 (inclusive). Must have length 1 along
        `axis` unless ``keepdims=True``.
    axis : int or None, default: None
        Axis along which the quantiles are computed. ``None`` ravels both `x`
        and `q` before performing the calculation.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result will
        broadcast correctly against the original array `x`.
    method : str, default: 'linear'
        The method to use for estimating the quantile. The available options are:
        'inverted_cdf', 'averaged_inverted_cdf', 'closest_observation',
        'interpolated_inverted_cdf', 'hazen', 'weibull', 'linear' (default),
        'median_unbiased', 'normal_unbiased', 'harrell-davis'.
    xp : array_namespace, optional
        The standard-compatible namespace for `x` and `q`. Default: infer.

    Returns
    -------
    array
        An array with the quantiles of the data.

    Notes
    -----
    When SciPy >= 1.16 is installed, the computation is delegated to
    ``scipy.stats.quantile``. Otherwise the generic implementation in this
    package is used. Errors raised by SciPy during the computation propagate
    unchanged; they are not turned into a silent fallback.

    Examples
    --------
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx
    >>> x = xp.asarray([[10, 8, 7, 5, 4], [0, 1, 2, 3, 5]])
    >>> xpx.quantile(x, 0.5, axis=-1)
    Array([7., 2.], dtype=array_api_strict.float64)
    >>> xpx.quantile(x, [0.25, 0.75], axis=-1)
    Array([[5., 8.],
     [1., 3.]], dtype=array_api_strict.float64)
    """
    # Infer (and thereby validate) the namespace first, on every path.
    xp = array_namespace(x, q) if xp is None else xp

    # The quantile function in SciPy 1.16 supports the array API directly, so
    # there is no need for our own implementation there. Availability is probed
    # separately so the SciPy call below is NOT inside a try/except: an
    # ImportError/AttributeError raised during the computation now propagates.
    # Previously such an error was swallowed and the work silently redone by the
    # fallback.
    # NOTE(review): SciPy's array API support is opt-in (SCIPY_ARRAY_API=1);
    # without it, non-NumPy inputs may be coerced to NumPy. Confirm this
    # delegation is wanted for non-NumPy `xp`.
    if _has_array_api_scipy_quantile():
        from scipy.stats import quantile as scipy_quantile

        return scipy_quantile(x, p=q, axis=axis, keepdims=keepdims, method=method)

    return _funcs.quantile(x, q, axis=axis, keepdims=keepdims, method=method, xp=xp)