
Commit f4b66c9

Merge branch 'main' into pa/lazy

2 parents a828753 + d459cca

File tree

18 files changed: +786 -306 lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 1 deletion
@@ -30,8 +30,10 @@ jobs:
       - run: uv pip install --system -e .[test${{ matrix.extras == 'full' && ',full' || '' }}]
       - run: |
           coverage run -m pytest -m "not benchmark"
-          coverage xml
           coverage report
+          # https://github.com/codecov/codecov-cli/issues/648
+          coverage xml
+          rm test-data/.coverage
       - uses: codecov/codecov-action@v5
         with:
           fail_ci_if_error: true

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ repos:
       - id: trailing-whitespace
       - id: no-commit-to-branch
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.9.6
+    rev: v0.9.7
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]

docs/conf.py

Lines changed: 19 additions & 0 deletions
@@ -32,12 +32,20 @@
     "sphinx.ext.autodoc",
     "sphinx.ext.autosummary",
     "scanpydoc.elegant_typehints",
+    "sphinx_autofixture",
 ]
 
 # API documentation when building
 nitpicky = True
 autosummary_generate = True
 autodoc_member_order = "bysource"
+autodoc_default_options = {
+    "special-members": True,
+    # everything except __call__ really, to avoid having to write autosummary templates
+    "exclude-members": (
+        "__setattr__,__delattr__,__repr__,__eq__,__or__,__ror__,__hash__,__weakref__,__init__,__new__"
+    ),
+}
 napoleon_google_docstring = False
 napoleon_numpy_docstring = True
 todo_include_todos = False
@@ -52,12 +60,15 @@
 )
 # Try overriding type paths
 qualname_overrides = autodoc_type_aliases = {
+    "np.bool": ("py:data", "numpy.bool"),
     "np.dtype": "numpy.dtype",
     "np.number": "numpy.number",
     "np.integer": "numpy.integer",
+    "np.random.Generator": "numpy.random.Generator",
     "ArrayLike": "numpy.typing.ArrayLike",
     "DTypeLike": "numpy.typing.DTypeLike",
     "NDArray": "numpy.typing.NDArray",
+    "_pytest.fixtures.FixtureRequest": "pytest.FixtureRequest",
     **{
         k: v
         for k_plain, v in {
@@ -74,10 +85,18 @@
 # If that doesn’t work, ignore them
 nitpick_ignore = {
     ("py:class", "fast_array_utils.types.T_co"),
+    ("py:class", "Arr"),
+    ("py:class", "testing.fast_array_utils._array_type.Arr"),
+    ("py:class", "testing.fast_array_utils._array_type.Inner"),
+    ("py:class", "_DTypeLikeFloat32"),
+    ("py:class", "_DTypeLikeFloat64"),
     # sphinx bugs, should be covered by `autodoc_type_aliases` above
+    ("py:class", "Array"),
     ("py:class", "ArrayLike"),
     ("py:class", "DTypeLike"),
     ("py:class", "NDArray"),
+    ("py:class", "np.bool"),
+    ("py:class", "_pytest.fixtures.FixtureRequest"),
 }
 
 # Options for HTML output

docs/index.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
``fast_array_utils``
22
====================
33

4+
.. toctree::
5+
:hidden:
6+
7+
fast-array-utils <self>
8+
testing
9+
410
.. automodule:: fast_array_utils
511
:members:
612

7-
813
``fast_array_utils.conv``
914
-------------------------
1015

docs/testing.rst

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+``testing.fast_array_utils``
+============================
+
+.. automodule:: testing.fast_array_utils
+   :members:
+
+``testing.fast_array_utils.pytest``
+-----------------------------------
+
+.. automodule:: testing.fast_array_utils.pytest
+   :members:

pyproject.toml

Lines changed: 21 additions & 3 deletions
@@ -23,20 +23,33 @@ classifiers = [
 ]
 dynamic = [ "description", "version" ]
 dependencies = [ "numba", "numpy" ]
-optional-dependencies.doc = [ "furo", "scanpydoc>=0.15.2", "sphinx>=8", "sphinx-autodoc-typehints" ]
+optional-dependencies.doc = [
+  "furo",
+  "pytest",
+  "scanpydoc>=0.15.2",
+  "sphinx>=8",
+  "sphinx-autodoc-typehints",
+  "sphinx-autofixture",
+]
 optional-dependencies.full = [ "dask", "fast-array-utils[sparse]", "h5py", "zarr" ]
 optional-dependencies.sparse = [ "scipy>=1.8" ]
 optional-dependencies.test = [ "coverage[toml]", "pytest", "pytest-codspeed" ]
 urls.'Documentation' = "https://icb-fast-array-utils.readthedocs-hosted.com/"
 urls.'Issue Tracker' = "https://github.com/scverse/fast-array-utils/issues"
 urls.'Source Code' = "https://github.com/scverse/fast-array-utils"
 
-[tool.hatch.metadata.hooks.docstring-description]
+entry_points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest"
 
 [tool.hatch.version]
 source = "vcs"
 raw-options = { local_scheme = "no-local-version" } # be able to publish dev version
 
+# TODO: support setting main package in the plugin
+# [tool.hatch.metadata.hooks.docstring-description]
+
+[tool.hatch.build.targets.wheel]
+packages = [ "src/testing", "src/fast_array_utils" ]
+
 [tool.hatch.envs.default]
 installer = "uv"
 
@@ -47,7 +60,7 @@ scripts.clean = "git clean -fdX docs"
 
 [tool.hatch.envs.hatch-test]
 features = [ "test" ]
-extra-dependencies = [ "ipykernel" ]
+extra-dependencies = [ "ipykernel", "ipycytoscape" ]
 env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
 overrides.matrix.extras.features = [
   { if = [ "full" ], value = "full" },
@@ -86,20 +99,25 @@ lint.per-file-ignores."tests/**/test_*.py" = [
   "S101", # tests use `assert`
 ]
 lint.allowed-confusables = [ "×", "" ]
+lint.flake8-bugbear.extend-immutable-calls = [ "testing.fast_array_utils.Flags" ]
+
 lint.flake8-copyright.notice-rgx = "SPDX-License-Identifier: MPL-2\\.0"
 lint.flake8-type-checking.exempt-modules = [ ]
 lint.flake8-type-checking.strict = true
 lint.isort.known-first-party = [ "fast_array_utils" ]
 lint.isort.lines-after-imports = 2
 lint.isort.required-imports = [ "from __future__ import annotations" ]
+lint.pydocstyle.convention = "numpy"
 
 [tool.pytest.ini_options]
 addopts = [
   "--import-mode=importlib",
   "--strict-markers",
+  "--doctest-modules",
   "--pyargs",
   "-ptesting.fast_array_utils.pytest",
 ]
+testpaths = [ "./tests", "fast_array_utils" ]
 filterwarnings = [
   "error",
   # codspeed seems to break this dtype added by h5py
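
The new entry_points.pytest11.fast_array_utils line registers testing.fast_array_utils.pytest as a pytest plugin, so its fixtures load automatically wherever the wheel is installed; for in-repo runs the explicit "-ptesting.fast_array_utils.pytest" entry in addopts does the same. As a rough sketch of the mechanism only (the module and fixture names below are hypothetical, not the actual contents of testing.fast_array_utils.pytest):

    # hypothetical_plugin.py -- illustrative pytest11 plugin module only
    from __future__ import annotations

    import pytest


    @pytest.fixture
    def example_array() -> list[int]:
        # fixtures defined in a module registered under the pytest11 entry
        # point become injectable into tests without an explicit import,
        # because pytest loads such entry points at startup
        return [0, 1, 2]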

src/fast_array_utils/_validation.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: MPL-2.0
+from __future__ import annotations
+
+from numbers import Integral
+
+
+def validate_axis(axis: int | None) -> None:
+    if axis is None:
+        return
+    if not isinstance(axis, Integral):  # pragma: no cover
+        msg = "axis must be integer or None."
+        raise TypeError(msg)
+    if axis not in (0, 1):  # pragma: no cover
+        msg = "We only support axis 0 and 1 at the moment"
+        raise NotImplementedError(msg)
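
As a quick sketch of the helper's behavior (the import path is inferred from "from .._validation import validate_axis" in the new stats module below):

    from fast_array_utils._validation import validate_axis

    validate_axis(None)  # returns None: axis=None is always accepted
    validate_axis(0)     # returns None: axes 0 and 1 are supported
    validate_axis(2)     # raises NotImplementedError: "We only support axis 0 and 1 at the moment"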

src/fast_array_utils/conv/_asarray.py

Lines changed: 5 additions & 6 deletions
@@ -1,18 +1,17 @@
 # SPDX-License-Identifier: MPL-2.0
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast
 
 import numpy as np
+from numpy.typing import NDArray
 
 from .._import import lazy_singledispatch
 from ..types import OutOfCoreDataset
 
 
 if TYPE_CHECKING:
-    from typing import Any
-
-    from numpy.typing import ArrayLike, NDArray
+    from numpy.typing import ArrayLike
 
     from .. import types
 
@@ -66,9 +65,9 @@ def _(x: types.OutOfCoreDataset[types.CSBase | NDArray[Any]]) -> NDArray[Any]:
 
 @asarray.register("cupy:ndarray")
 def _(x: types.CupyArray) -> NDArray[Any]:
-    return x.get()  # type: ignore[no-any-return]
+    return cast(NDArray[Any], x.get())
 
 
 @asarray.register("cupyx.scipy.sparse:spmatrix")
 def _(x: types.CupySparseMatrix) -> NDArray[Any]:
-    return x.toarray().get()  # type: ignore[no-any-return]
+    return cast(NDArray[Any], x.toarray().get())
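
Replacing "# type: ignore[no-any-return]" with typing.cast narrows the untyped .get() result explicitly instead of silencing the checker for the whole statement. A minimal standalone sketch of the pattern (not this module's code):

    from typing import Any, cast

    from numpy.typing import NDArray


    def to_host(x: Any) -> NDArray[Any]:
        # x.get() is typed as Any (e.g. cupy without stubs); cast() records
        # the intended return type while keeping mypy's Any-return check active
        return cast(NDArray[Any], x.get())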

src/fast_array_utils/stats/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 
 from __future__ import annotations
 
+from ._is_constant import is_constant
 from ._sum import sum
 
 
-__all__ = ["sum"]
+__all__ = ["is_constant", "sum"]

src/fast_array_utils/stats/_is_constant.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+# SPDX-License-Identifier: MPL-2.0
+from __future__ import annotations
+
+from functools import partial, singledispatch
+from typing import TYPE_CHECKING, Any, cast, overload
+
+import numba
+import numpy as np
+from numpy.typing import NDArray
+
+from .. import types
+from .._validation import validate_axis
+
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from typing import Literal, TypeVar
+
+    C = TypeVar("C", bound=Callable[..., Any])
+
+
+@overload
+def is_constant(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray: ...
+@overload
+def is_constant(a: NDArray[Any] | types.CSBase, /, *, axis: None = None) -> bool: ...
+@overload
+def is_constant(a: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]) -> NDArray[np.bool]: ...
+
+
+def is_constant(
+    a: NDArray[Any] | types.CSBase | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
+) -> bool | NDArray[np.bool] | types.DaskArray:
+    """Check whether values in array are constant.
+
+    Params
+    ------
+    a
+        Array to check
+    axis
+        Axis to reduce over.
+
+    Returns
+    -------
+    If ``axis`` is :data:`None`, return if all values were constant.
+    Else returns a boolean array with :data:`True` representing constant columns/rows.
+
+    Example
+    -------
+    >>> a = np.array([[0, 1], [0, 0]])
+    >>> a
+    array([[0, 1],
+           [0, 0]])
+    >>> is_constant(a)
+    False
+    >>> is_constant(a, axis=0)
+    array([ True, False])
+    >>> is_constant(a, axis=1)
+    array([False,  True])
+
+    """
+    validate_axis(axis)
+    return _is_constant(a, axis=axis)
+
+
+@singledispatch
+def _is_constant(
+    a: NDArray[Any] | types.CSBase | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
+) -> bool | NDArray[np.bool]:  # pragma: no cover
+    raise NotImplementedError
+
+
+@_is_constant.register(np.ndarray)
+def _(a: NDArray[Any], /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
+    # Should eventually support nd, not now.
+    match axis:
+        case None:
+            return bool((a == a.flat[0]).all())
+        case 0:
+            return _is_constant_rows(a.T)
+        case 1:
+            return _is_constant_rows(a)
+
+
+def _is_constant_rows(a: NDArray[Any]) -> NDArray[np.bool]:
+    b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
+    return cast(NDArray[np.bool], (a == b).all(axis=1))
+
+
+@_is_constant.register(types.CSBase)
+def _(a: types.CSBase, /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
+    n_row, n_col = a.shape
+    if axis is None:
+        if len(a.data) == n_row * n_col:
+            return is_constant(cast(NDArray[Any], a.data))
+        return bool((a.data == 0).all())
+    shape = (n_row, n_col) if axis == 1 else (n_col, n_row)
+    match axis, a.format:
+        case 0, "csr":
+            a = a.T.tocsr()
+        case 1, "csc":
+            a = a.T.tocsc()
+    return _is_constant_csr_rows(a.data, a.indptr, shape)
+
+
+@numba.njit(cache=True)
+def _is_constant_csr_rows(
+    data: NDArray[np.number[Any]],
+    indptr: NDArray[np.integer[Any]],
+    shape: tuple[int, int],
+) -> NDArray[np.bool]:
+    n = len(indptr) - 1
+    result = np.ones(n, dtype=np.bool)
+    for i in numba.prange(n):
+        start = indptr[i]
+        stop = indptr[i + 1]
+        val = data[start] if stop - start == shape[1] else 0
+        for j in range(start, stop):
+            if data[j] != val:
+                result[i] = False
+                break
+    return result
+
+
+@_is_constant.register(types.DaskArray)
+def _(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray:
+    if TYPE_CHECKING:
+        from dask.array.core import map_blocks
+        from dask.array.overlap import map_overlap
+    else:
+        from dask.array import map_blocks, map_overlap
+
+    if axis is not None:
+        return cast(
+            types.DaskArray,
+            map_blocks(  # type: ignore[no-untyped-call]
+                partial(is_constant, axis=axis), a, drop_axis=axis, meta=np.array([], dtype=np.bool)
+            ),
+        )
+
+    rv = cast(
+        types.DaskArray,
+        (a == a[0, 0].compute()).all()
+        if isinstance(a._meta, np.ndarray)  # noqa: SLF001
+        else map_overlap(  # type: ignore[no-untyped-call]
+            lambda a: np.array([[is_constant(a)]]),
+            a,
+            # use asymmetric overlaps to avoid unnecessary computation
+            depth={d: (0, 1) for d in range(a.ndim)},
+            trim=False,
+            meta=np.array([], dtype=bool),
+        ).all(),
+    )
+    return cast(
+        types.DaskArray,
+        map_blocks(bool, rv, meta=np.array([], dtype=bool)),  # type: ignore[no-untyped-call]
+    )
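
For reference, a short usage sketch of the new function on dense and sparse inputs (mirroring the doctest above; the sparse call assumes scipy CSR matrices are covered by types.CSBase, as the registration suggests):

    import numpy as np
    from scipy import sparse

    from fast_array_utils.stats import is_constant

    a = np.array([[0, 1], [0, 0]])
    print(is_constant(a))                             # False: not all entries are equal
    print(is_constant(a, axis=0))                     # [ True False]: column 0 is constant
    print(is_constant(sparse.csr_matrix(a), axis=1))  # [False  True]: row 1 is all zeros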
