|
| 1 | +# Copyright 2023 The JAX Authors. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +import functools |
| 17 | +from typing import NamedTuple |
| 18 | +import jax |
| 19 | +import jax.numpy as jnp |
| 20 | + |
| 21 | + |
| 22 | +from jax.experimental.array_api._dtypes import ( |
| 23 | + bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, |
| 24 | + float32, float64, complex64, complex128 |
| 25 | +) |
| 26 | + |
| 27 | +_valid_dtypes = { |
| 28 | + bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, |
| 29 | + float32, float64, complex64, complex128 |
| 30 | +} |
| 31 | + |
| 32 | +_promotion_table = { |
| 33 | + (bool, bool): bool, |
| 34 | + (int8, int8): int8, |
| 35 | + (int8, int16): int16, |
| 36 | + (int8, int32): int32, |
| 37 | + (int8, int64): int64, |
| 38 | + (int8, uint8): int16, |
| 39 | + (int8, uint16): int32, |
| 40 | + (int8, uint32): int64, |
| 41 | + (int16, int8): int16, |
| 42 | + (int16, int16): int16, |
| 43 | + (int16, int32): int32, |
| 44 | + (int16, int64): int64, |
| 45 | + (int16, uint8): int32, |
| 46 | + (int16, uint16): int32, |
| 47 | + (int16, uint32): int64, |
| 48 | + (int32, int8): int32, |
| 49 | + (int32, int16): int32, |
| 50 | + (int32, int32): int32, |
| 51 | + (int32, int64): int64, |
| 52 | + (int32, uint8): int64, |
| 53 | + (int32, uint16): int64, |
| 54 | + (int32, uint32): int64, |
| 55 | + (int64, int8): int64, |
| 56 | + (int64, int16): int64, |
| 57 | + (int64, int32): int64, |
| 58 | + (int64, int64): int64, |
| 59 | + (int64, uint8): int64, |
| 60 | + (int64, uint16): int64, |
| 61 | + (int64, uint32): int64, |
| 62 | + (uint8, int8): int16, |
| 63 | + (uint8, int16): int32, |
| 64 | + (uint8, int32): int64, |
| 65 | + (uint8, uint8): uint8, |
| 66 | + (uint8, uint16): uint16, |
| 67 | + (uint8, uint32): uint32, |
| 68 | + (uint8, uint64): uint64, |
| 69 | + (uint16, int8): int32, |
| 70 | + (uint16, int16): int32, |
| 71 | + (uint16, int32): int64, |
| 72 | + (uint16, uint8): uint16, |
| 73 | + (uint16, uint16): uint16, |
| 74 | + (uint16, uint32): uint32, |
| 75 | + (uint16, uint64): uint64, |
| 76 | + (uint32, int8): int64, |
| 77 | + (uint32, int16): int64, |
| 78 | + (uint32, int32): int64, |
| 79 | + (uint32, uint8): uint32, |
| 80 | + (uint32, uint16): uint32, |
| 81 | + (uint32, uint32): uint32, |
| 82 | + (uint32, uint64): uint64, |
| 83 | + (uint64, uint8): uint64, |
| 84 | + (uint64, uint16): uint64, |
| 85 | + (uint64, uint32): uint64, |
| 86 | + (uint64, uint64): uint64, |
| 87 | + (float32, float32): float32, |
| 88 | + (float32, float64): float64, |
| 89 | + (float32, complex64): complex64, |
| 90 | + (float32, complex128): complex128, |
| 91 | + (float64, float32): float64, |
| 92 | + (float64, float64): float64, |
| 93 | + (float64, complex64): complex128, |
| 94 | + (float64, complex128): complex128, |
| 95 | + (complex64, float32): complex64, |
| 96 | + (complex64, float64): complex128, |
| 97 | + (complex64, complex64): complex64, |
| 98 | + (complex64, complex128): complex128, |
| 99 | + (complex128, float32): complex128, |
| 100 | + (complex128, float64): complex128, |
| 101 | + (complex128, complex64): complex128, |
| 102 | + (complex128, complex128): complex128, |
| 103 | +} |
| 104 | + |
| 105 | + |
| 106 | +def _is_valid_dtype(t): |
| 107 | + try: |
| 108 | + return t in _valid_dtypes |
| 109 | + except TypeError: |
| 110 | + return False |
| 111 | + |
| 112 | + |
| 113 | +def _promote_types(t1, t2): |
| 114 | + if not _is_valid_dtype(t1): |
| 115 | + raise ValueError(f"{t1} is not a valid dtype") |
| 116 | + if not _is_valid_dtype(t2): |
| 117 | + raise ValueError(f"{t2} is not a valid dtype") |
| 118 | + if result := _promotion_table.get((t1, t2), None): |
| 119 | + return result |
| 120 | + else: |
| 121 | + raise ValueError("No promotion path for {t1} & {t2}") |
| 122 | + |
| 123 | + |
| 124 | +def astype(x, dtype, /, *, copy=True): |
| 125 | + return jnp.asarray(x, dtype=dtype, copy=copy) |
| 126 | + |
| 127 | + |
| 128 | +def can_cast(from_, to, /): |
| 129 | + if not _is_valid_dtype(from_): |
| 130 | + raise ValueError(f"{from_} is not a valid dtype") |
| 131 | + if not _is_valid_dtype(to): |
| 132 | + raise ValueError(f"{to} is not a valid dtype") |
| 133 | + try: |
| 134 | + result = _promote_types(from_, to) |
| 135 | + except ValueError: |
| 136 | + return False |
| 137 | + else: |
| 138 | + return result == to |
| 139 | + |
| 140 | + |
| 141 | +class FInfo(NamedTuple): |
| 142 | + bits: int |
| 143 | + eps: float |
| 144 | + max: float |
| 145 | + min: float |
| 146 | + smallest_normal: float |
| 147 | + |
| 148 | + |
| 149 | +class IInfo(NamedTuple): |
| 150 | + bits: int |
| 151 | + max: int |
| 152 | + min: int |
| 153 | + |
| 154 | + |
| 155 | +def finfo(type, /) -> FInfo: |
| 156 | + info = jnp.finfo(type) |
| 157 | + return FInfo( |
| 158 | + bits=info.bits, |
| 159 | + eps=float(info.eps), |
| 160 | + max=float(info.max), |
| 161 | + min=float(info.min), |
| 162 | + smallest_normal=float(info.smallest_normal), |
| 163 | + ) |
| 164 | + |
| 165 | + |
| 166 | +def iinfo(type, /) -> IInfo: |
| 167 | + info = jnp.iinfo(type) |
| 168 | + return IInfo(bits=info.bits, max=info.max, min=info.min) |
| 169 | + |
| 170 | + |
| 171 | +_dtype_kinds = { |
| 172 | + 'bool': {bool}, |
| 173 | + 'signed integer': {int8, int16, int32, int64}, |
| 174 | + 'unsigned integer': {uint8, uint16, uint32, uint64}, |
| 175 | + 'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64}, |
| 176 | + 'real floating': {float32, float64}, |
| 177 | + 'complex floating': {complex64, complex128}, |
| 178 | + 'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64, |
| 179 | + float32, float64, complex64, complex128}, |
| 180 | +} |
| 181 | + |
| 182 | +def isdtype(dtype, kind): |
| 183 | + if not _is_valid_dtype(dtype): |
| 184 | + raise ValueError(f"{dtype} is not a valid dtype.") |
| 185 | + if isinstance(kind, tuple): |
| 186 | + return any(_isdtype(dtype, k) for k in kind) |
| 187 | + return _isdtype(dtype, kind) |
| 188 | + |
| 189 | +def _isdtype(dtype, kind): |
| 190 | + if isinstance(kind, jnp.dtype): |
| 191 | + return dtype == kind |
| 192 | + elif isinstance(kind, str): |
| 193 | + if kind not in _dtype_kinds: |
| 194 | + raise ValueError(f"Unrecognized {kind=!r}") |
| 195 | + return dtype in _dtype_kinds[kind] |
| 196 | + else: |
| 197 | + raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.") |
| 198 | + |
| 199 | + |
| 200 | +def result_type(*arrays_and_dtypes): |
| 201 | + dtypes = [] |
| 202 | + for val in arrays_and_dtypes: |
| 203 | + if isinstance(val, jax.Array): |
| 204 | + val = val.dtype |
| 205 | + if _is_valid_dtype(val): |
| 206 | + dtypes.append(val) |
| 207 | + else: |
| 208 | + raise ValueError(f"{val} is not a valid dtype") |
| 209 | + if len(dtypes) == 0: |
| 210 | + raise ValueError("result_type requires at least one argument") |
| 211 | + if len(dtypes) == 1: |
| 212 | + return dtypes[0] |
| 213 | + return functools.reduce(_promote_types, dtypes) |
0 commit comments