-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_numpy1p0.pyi
More file actions
81 lines (60 loc) · 2.46 KB
/
test_numpy1p0.pyi
File metadata and controls
81 lines (60 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# mypy: disable-error-code="no-redef"
from types import ModuleType
from typing import Any, assert_type
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
from numpy import dtype
import array_api_typing as xpt
# Define `numpy.array_api` arrays against which we can test the protocols.
# None of these fixtures is annotated, so the type checker infers each type
# from its initializer; the exact constructor calls therefore determine what
# the protocol checks further down actually exercise.
# Note that `np.array_api` doesn't support boolean arrays.
# 2x2 array from `eye`, no explicit dtype (uses the function's default).
nparr = np.eye(2)
# 1-element arrays built with an explicit `np.array_api` dtype object.
nparr_i32 = np.asarray([1], dtype=np.int32)
nparr_f32 = np.asarray([1.0], dtype=np.float32)
# =========================================================
# `xpt.HasArrayNamespace`
# Every fixture must be assignable to the protocol parametrized by
# `ModuleType`. `_` is re-annotated on each line; the file-level
# `# mypy: disable-error-code="no-redef"` keeps mypy from reporting that.
_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
# Check `__array_namespace__` method: calling it through the protocol type
# must produce a value assignable to `ModuleType`.
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()
# Known gap: the assignment below uses a wrong namespace type parameter
# (`dict[str, int]`), yet it is NOT reported by the checker. The original
# note states that such mismatches only surface once `__array_namespace__`
# is actually called and its result used, with the error then backpropagated
# to the type of the annotated variable (as `a_ns` is above).
# NOTE(review): confirm this still matches current mypy/pyright behaviour.
_: xpt.HasArrayNamespace[dict[str, int]] = nparr  # not caught: no error expected here
# =========================================================
# `xpt.HasDType`
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.
# Each fixture is therefore only checked against the generic
# `numpy.dtype[Any]` (imported from `numpy`, not from `np.array_api`).
_: xpt.HasDType[dtype[Any]] = nparr
_: xpt.HasDType[dtype[Any]] = nparr_i32
_: xpt.HasDType[dtype[Any]] = nparr_f32
# =========================================================
# `xpt.Array`
# Check NamespaceT_co assignment
# `a_ns` was already annotated in the `HasArrayNamespace` section; it is
# deliberately re-annotated here (allowed by the file-level `no-redef`
# suppression). This uses the three-argument form of `xpt.Array`, with the
# namespace type in the third position.
a_ns: xpt.Array[Any, Any, ModuleType] = nparr
# Check DTypeT_co assignment
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
# type annotate specific dtypes like `np.float32` or `np.int32`.
# The single-argument form `xpt.Array[dtype[Any]]` presumably relies on
# defaults for the remaining type parameters.
# NOTE(review): confirm against the `xpt.Array` definition.
_: xpt.Array[dtype[Any]] = nparr
x_f32: xpt.Array[dtype[Any]] = nparr_f32
x_i32: xpt.Array[dtype[Any]] = nparr_i32
# The assertions below pin the attribute types that the `xpt.Array` protocol
# exposes when accessed through the annotated variables `x_f32` / `x_i32`.
# Check Attribute `.dtype`: the dtype type parameter is propagated.
assert_type(x_f32.dtype, dtype[Any])
assert_type(x_i32.dtype, dtype[Any])
# Check Attribute `.device`: typed as plain `object`.
assert_type(x_f32.device, object)
assert_type(x_i32.device, object)
# Check Attribute `.mT`: returns an `xpt.Array` with the same dtype parameter.
assert_type(x_f32.mT, xpt.Array[dtype[Any]])
assert_type(x_i32.mT, xpt.Array[dtype[Any]])
# Check Attribute `.ndim`
assert_type(x_f32.ndim, int)
assert_type(x_i32.ndim, int)
# Check Attribute `.shape`: a variable-length tuple whose entries may be `None`.
assert_type(x_f32.shape, tuple[int | None, ...])
assert_type(x_i32.shape, tuple[int | None, ...])
# Check Attribute `.size`: may be `None`.
assert_type(x_f32.size, int | None)
assert_type(x_i32.size, int | None)
# Check Attribute `.T`: returns an `xpt.Array` with the same dtype parameter.
assert_type(x_f32.T, xpt.Array[dtype[Any]])
assert_type(x_i32.T, xpt.Array[dtype[Any]])