Skip to content

Commit 87ff9c7

Browse files
committed
Add initial array_api framework
1 parent 7f7f995 commit 87ff9c7

File tree

5 files changed

+381
-0
lines changed

5 files changed

+381
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
__array_api_version__ = '2022.12'
16+
17+
from jax.experimental.array_api._constants import (
18+
e as e,
19+
inf as inf,
20+
nan as nan,
21+
newaxis as newaxis,
22+
pi as pi,
23+
)
24+
25+
from ._creation_functions import (
26+
arange as arange,
27+
asarray as asarray,
28+
empty as empty,
29+
empty_like as empty_like,
30+
eye as eye,
31+
from_dlpack as from_dlpack,
32+
full as full,
33+
full_like as full_like,
34+
linspace as linspace,
35+
meshgrid as meshgrid,
36+
ones as ones,
37+
ones_like as ones_like,
38+
tril as tril,
39+
triu as triu,
40+
zeros as zeros,
41+
zeros_like as zeros_like,
42+
)
43+
44+
from ._data_type_functions import (
45+
astype as astype,
46+
can_cast as can_cast,
47+
finfo as finfo,
48+
iinfo as iinfo,
49+
isdtype as isdtype,
50+
result_type as result_type,
51+
)
52+
53+
from ._dtypes import (
54+
bool as bool,
55+
int8 as int8,
56+
int16 as int16,
57+
int32 as int32,
58+
int64 as int64,
59+
uint8 as uint8,
60+
uint16 as uint16,
61+
uint32 as uint32,
62+
uint64 as uint64,
63+
float32 as float32,
64+
float64 as float64,
65+
complex64 as complex64,
66+
complex128 as complex128,
67+
)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import numpy as np
2+
3+
e = np.e
4+
inf = np.inf
5+
nan = np.nan
6+
newaxis = np.newaxis
7+
pi = np.pi
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
import jax.numpy as jnp
17+
18+
19+
def arange(start, /, stop=None, step=1, *, dtype=None, device=None):
20+
return jax.device_put(jnp.arange(start, stop, step, dtype=dtype), device=device)
21+
22+
def asarray(obj, /, *, dtype=None, device=None, copy=None):
23+
return jax.device_put(jnp.array(obj, dtype=dtype, copy=copy), device=device)
24+
25+
def empty(shape, *, dtype=None, device=None):
26+
return jax.device_put(jnp.empty(shape, dtype=dtype), device=device)
27+
28+
def empty_like(x, /, *, dtype=None, device=None):
29+
return jax.device_put(jnp.empty_like(x, dtype=dtype), device=device)
30+
31+
def eye(n_rows, n_cols=None, /, *, k=0, dtype=None, device=None):
32+
return jax.device_put(jnp.eye(n_rows, n_cols, k=k, dtype=dtype), device=device)
33+
34+
def from_dlpack(x, /):
35+
return jnp.from_dlpack(x)
36+
37+
def full(shape, fill_value, *, dtype=None, device=None):
38+
return jax.device_put(jnp.full(shape, fill_value, dtype=dtype), device=device)
39+
40+
def full_like(x, /, fill_value, *, dtype=None, device=None):
41+
return jax.device_put(jnp.full_like(x, fill_value=fill_value, dtype=dtype), device=device)
42+
43+
def linspace(start, stop, /, num, *, dtype=None, device=None, endpoint=True):
44+
return jax.device_put(jnp.linspace(start, stop, num=num, dtype=dtype, endpoint=endpoint), device=device)
45+
46+
def meshgrid(*arrays, indexing='xy'):
47+
return jnp.meshgrid(*arrays, indexing=indexing)
48+
49+
def ones(shape, *, dtype=None, device=None):
50+
return jax.device_put(jnp.ones(shape, dtype=dtype), device=device)
51+
52+
def ones_like(x, /, *, dtype=None, device=None):
53+
return jax.device_put(jnp.ones_like(x, dtype=dtype), device=device)
54+
55+
def tril(x, /, *, k=0):
56+
return jnp.tril(x, k=k)
57+
58+
def triu(x, /, *, k=0):
59+
return jnp.triu(x, k=k)
60+
61+
def zeros(shape, *, dtype=None, device=None):
62+
return jax.device_put(jnp.zeros(shape, dtype=dtype), device=device)
63+
64+
def zeros_like(x, /, *, dtype=None, device=None):
65+
return jax.device_put(jnp.zeros_like(x, dtype=dtype), device=device)
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import functools
17+
from typing import NamedTuple
18+
import jax
19+
import jax.numpy as jnp
20+
21+
22+
from jax.experimental.array_api._dtypes import (
23+
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
24+
float32, float64, complex64, complex128
25+
)
26+
27+
_valid_dtypes = {
28+
bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64,
29+
float32, float64, complex64, complex128
30+
}
31+
32+
_promotion_table = {
33+
(bool, bool): bool,
34+
(int8, int8): int8,
35+
(int8, int16): int16,
36+
(int8, int32): int32,
37+
(int8, int64): int64,
38+
(int8, uint8): int16,
39+
(int8, uint16): int32,
40+
(int8, uint32): int64,
41+
(int16, int8): int16,
42+
(int16, int16): int16,
43+
(int16, int32): int32,
44+
(int16, int64): int64,
45+
(int16, uint8): int32,
46+
(int16, uint16): int32,
47+
(int16, uint32): int64,
48+
(int32, int8): int32,
49+
(int32, int16): int32,
50+
(int32, int32): int32,
51+
(int32, int64): int64,
52+
(int32, uint8): int64,
53+
(int32, uint16): int64,
54+
(int32, uint32): int64,
55+
(int64, int8): int64,
56+
(int64, int16): int64,
57+
(int64, int32): int64,
58+
(int64, int64): int64,
59+
(int64, uint8): int64,
60+
(int64, uint16): int64,
61+
(int64, uint32): int64,
62+
(uint8, int8): int16,
63+
(uint8, int16): int32,
64+
(uint8, int32): int64,
65+
(uint8, uint8): uint8,
66+
(uint8, uint16): uint16,
67+
(uint8, uint32): uint32,
68+
(uint8, uint64): uint64,
69+
(uint16, int8): int32,
70+
(uint16, int16): int32,
71+
(uint16, int32): int64,
72+
(uint16, uint8): uint16,
73+
(uint16, uint16): uint16,
74+
(uint16, uint32): uint32,
75+
(uint16, uint64): uint64,
76+
(uint32, int8): int64,
77+
(uint32, int16): int64,
78+
(uint32, int32): int64,
79+
(uint32, uint8): uint32,
80+
(uint32, uint16): uint32,
81+
(uint32, uint32): uint32,
82+
(uint32, uint64): uint64,
83+
(uint64, uint8): uint64,
84+
(uint64, uint16): uint64,
85+
(uint64, uint32): uint64,
86+
(uint64, uint64): uint64,
87+
(float32, float32): float32,
88+
(float32, float64): float64,
89+
(float32, complex64): complex64,
90+
(float32, complex128): complex128,
91+
(float64, float32): float64,
92+
(float64, float64): float64,
93+
(float64, complex64): complex128,
94+
(float64, complex128): complex128,
95+
(complex64, float32): complex64,
96+
(complex64, float64): complex128,
97+
(complex64, complex64): complex64,
98+
(complex64, complex128): complex128,
99+
(complex128, float32): complex128,
100+
(complex128, float64): complex128,
101+
(complex128, complex64): complex128,
102+
(complex128, complex128): complex128,
103+
}
104+
105+
106+
def _is_valid_dtype(t):
107+
try:
108+
return t in _valid_dtypes
109+
except TypeError:
110+
return False
111+
112+
113+
def _promote_types(t1, t2):
114+
if not _is_valid_dtype(t1):
115+
raise ValueError(f"{t1} is not a valid dtype")
116+
if not _is_valid_dtype(t2):
117+
raise ValueError(f"{t2} is not a valid dtype")
118+
if result := _promotion_table.get((t1, t2), None):
119+
return result
120+
else:
121+
raise ValueError("No promotion path for {t1} & {t2}")
122+
123+
124+
def astype(x, dtype, /, *, copy=True):
125+
return jnp.asarray(x, dtype=dtype, copy=copy)
126+
127+
128+
def can_cast(from_, to, /):
129+
if not _is_valid_dtype(from_):
130+
raise ValueError(f"{from_} is not a valid dtype")
131+
if not _is_valid_dtype(to):
132+
raise ValueError(f"{to} is not a valid dtype")
133+
try:
134+
result = _promote_types(from_, to)
135+
except ValueError:
136+
return False
137+
else:
138+
return result == to
139+
140+
141+
class FInfo(NamedTuple):
142+
bits: int
143+
eps: float
144+
max: float
145+
min: float
146+
smallest_normal: float
147+
148+
149+
class IInfo(NamedTuple):
150+
bits: int
151+
max: int
152+
min: int
153+
154+
155+
def finfo(type, /) -> FInfo:
156+
info = jnp.finfo(type)
157+
return FInfo(
158+
bits=info.bits,
159+
eps=float(info.eps),
160+
max=float(info.max),
161+
min=float(info.min),
162+
smallest_normal=float(info.smallest_normal),
163+
)
164+
165+
166+
def iinfo(type, /) -> IInfo:
167+
info = jnp.iinfo(type)
168+
return IInfo(bits=info.bits, max=info.max, min=info.min)
169+
170+
171+
_dtype_kinds = {
172+
'bool': {bool},
173+
'signed integer': {int8, int16, int32, int64},
174+
'unsigned integer': {uint8, uint16, uint32, uint64},
175+
'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64},
176+
'real floating': {float32, float64},
177+
'complex floating': {complex64, complex128},
178+
'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64,
179+
float32, float64, complex64, complex128},
180+
}
181+
182+
def isdtype(dtype, kind):
183+
if not _is_valid_dtype(dtype):
184+
raise ValueError(f"{dtype} is not a valid dtype.")
185+
if isinstance(kind, tuple):
186+
return any(_isdtype(dtype, k) for k in kind)
187+
return _isdtype(dtype, kind)
188+
189+
def _isdtype(dtype, kind):
190+
if isinstance(kind, jnp.dtype):
191+
return dtype == kind
192+
elif isinstance(kind, str):
193+
if kind not in _dtype_kinds:
194+
raise ValueError(f"Unrecognized {kind=!r}")
195+
return dtype in _dtype_kinds[kind]
196+
else:
197+
raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.")
198+
199+
200+
def result_type(*arrays_and_dtypes):
201+
dtypes = []
202+
for val in arrays_and_dtypes:
203+
if isinstance(val, jax.Array):
204+
val = val.dtype
205+
if _is_valid_dtype(val):
206+
dtypes.append(val)
207+
else:
208+
raise ValueError(f"{val} is not a valid dtype")
209+
if len(dtypes) == 0:
210+
raise ValueError("result_type requires at least one argument")
211+
if len(dtypes) == 1:
212+
return dtypes[0]
213+
return functools.reduce(_promote_types, dtypes)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2023 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
17+
bool = np.dtype('bool')
18+
int8 = np.dtype('int8')
19+
int16 = np.dtype('int16')
20+
int32 = np.dtype('int32')
21+
int64 = np.dtype('int64')
22+
uint8 = np.dtype('uint8')
23+
uint16 = np.dtype('uint16')
24+
uint32 = np.dtype('uint32')
25+
uint64 = np.dtype('uint64')
26+
float32 = np.dtype('float32')
27+
float64 = np.dtype('float64')
28+
complex64 = np.dtype('complex64')
29+
complex128 = np.dtype('complex128')

0 commit comments

Comments
 (0)