10 changes: 10 additions & 0 deletions .github/workflows/lint.yaml
@@ -0,0 +1,10 @@
+name: Linting
+on: [pull_request, push]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+      - uses: pre-commit/[email protected]
44 changes: 44 additions & 0 deletions .github/workflows/tests.yaml
@@ -0,0 +1,44 @@
+name: Tests
+
+on: [push, pull_request]
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        # os: ["windows-latest", "ubuntu-latest", "macos-latest"]
+        os: ["ubuntu-latest"]
+        python-version: ["3.7", "3.8", "3.9"]
+
+    env:
+      PYTHON_VERSION: ${{ matrix.python-version }}
+      PARALLEL: "true"
+      COVERAGE: "true"
+
+    steps:
+      - name: Checkout source
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 0 # Needed by codecov.io
+
+      - name: Setup Conda Environment
+        uses: conda-incubator/setup-miniconda@v2
+        with:
+          miniforge-variant: Mambaforge
+          miniforge-version: latest
+          use-mamba: true
+          channel-priority: strict
+          python-version: ${{ matrix.python-version }}
+          environment-file: ci/environment-${{ matrix.python-version }}.yaml
+          activate-environment: test-environment
+          auto-activate-base: false
+
+      - name: Install
+        shell: bash -l {0}
+        run: source ci/install.sh
+
+      - name: Run tests
+        shell: bash -l {0}
+        run: pytest -v
12 changes: 0 additions & 12 deletions azure-pipelines.yml

This file was deleted.

18 changes: 0 additions & 18 deletions ci/code_checks.sh

This file was deleted.

32 changes: 0 additions & 32 deletions ci/environment-3.6.yaml

This file was deleted.

24 changes: 6 additions & 18 deletions ci/environment-3.7.yaml
@@ -1,23 +1,14 @@
-name: dask-ml-test
+name: dask-ml-3.7
 channels:
   - conda-forge
   - defaults
 dependencies:
-  - black==19.10b0
-  - coverage
-  - codecov
-  # dask 2021.3.0 introduced a regression which causes tests to fail.
-  # The issue has been resolved upstream in dask and will be included
-  # in the next release. We temporarily apply a dask version contraint
-  # to allow CI to pass
-  - dask !=2021.3.0
-  - dask-glm >=0.2.0
-  - flake8
-  - isort==4.3.21
+  - dask
+  - dask-glm
   - multipledispatch >=0.4.9
   - mypy
   - numba
-  - numpy >=1.16.3
+  - numpy
   - numpydoc
   - packaging
   - pandas
@@ -26,10 +17,7 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - python=3.7.*
-  - scikit-learn>=0.23.0
+  - scikit-learn>=1.0.0
   - scipy
   - sparse
-  - toolz
-  - pip
-  - pip:
-      - pytest-azurepipelines
+  - toolz
24 changes: 6 additions & 18 deletions ci/environment-3.8.yaml
@@ -1,23 +1,14 @@
-name: dask-ml-test
+name: dask-ml-3.8
 channels:
   - conda-forge
   - defaults
 dependencies:
-  - black==19.10b0
-  - coverage
-  - codecov
-  # dask 2021.3.0 introduced a regression which causes tests to fail.
-  # The issue has been resolved upstream in dask and will be included
-  # in the next release. We temporarily apply a dask version contraint
-  # to allow CI to pass
-  - dask !=2021.3.0
-  - dask-glm >=0.2.0
-  - flake8
-  - isort==4.3.21
+  - dask
+  - dask-glm
   - multipledispatch >=0.4.9
   - mypy
   - numba
-  - numpy >=1.16.3
+  - numpy
   - numpydoc
   - packaging
   - pandas
@@ -26,10 +17,7 @@ dependencies:
   - pytest-cov
   - pytest-mock
   - python=3.8.*
-  - scikit-learn>=0.23.0
+  - scikit-learn>=1.0.0
   - scipy
   - sparse
-  - toolz
-  - pip
-  - pip:
-      - pytest-azurepipelines
+  - toolz
23 changes: 23 additions & 0 deletions ci/environment-3.9.yaml
@@ -0,0 +1,23 @@
+name: dask-ml-3.9
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - dask
+  - dask-glm
+  - multipledispatch >=0.4.9
+  - mypy
+  - numba
+  - numpy
+  - numpydoc
+  - packaging
+  - pandas
+  - psutil
+  - pytest
+  - pytest-cov
+  - pytest-mock
+  - python=3.9.*
+  - scikit-learn>=1.0.0
+  - scipy
+  - sparse
+  - toolz
12 changes: 3 additions & 9 deletions ci/environment-docs.yaml
@@ -5,12 +5,10 @@ channels:
 dependencies:
   - black
   - coverage
-  - flake8
   - graphviz
   - heapdict
   - ipykernel
   - ipython
-  - isort==4.3.21
   - multipledispatch
   - mypy
   - nbsphinx
@@ -21,18 +19,14 @@ dependencies:
   - numpydoc
   - pandas
   - psutil
-  - pytest
-  - pytest-cov
-  - pytest-mock
-  - python=3.7
+  - python=3.8
   - sortedcontainers
-  - scikit-learn>=0.23.1
+  - scikit-learn>=1.0.0
   - scipy
   - sparse
-  - sphinx==1.8.5
+  - sphinx
   - sphinx_rtd_theme
   - sphinx-gallery
-  - testpath<0.4
   - tornado
   - toolz
   - xgboost
4 changes: 4 additions & 0 deletions ci/install.sh
@@ -0,0 +1,4 @@
+python -m pip install --quiet --no-deps -e .
+
+echo mamba list
+mamba list
12 changes: 4 additions & 8 deletions dask_ml/_compat.py
@@ -16,8 +16,6 @@
 PANDAS_VERSION = packaging.version.parse(pandas.__version__)
 DISTRIBUTED_VERSION = packaging.version.parse(distributed.__version__)
 
-SK_0_23_2 = SK_VERSION >= packaging.version.parse("0.23.2")
-SK_024 = SK_VERSION >= packaging.version.parse("0.24.0.dev0")
 DASK_240 = DASK_VERSION >= packaging.version.parse("2.4.0")
 DASK_2130 = DASK_VERSION >= packaging.version.parse("2.13.0")
 DASK_2_20_0 = DASK_VERSION >= packaging.version.parse("2.20.0")
@@ -49,9 +47,7 @@ def _check_multimetric_scoring(estimator, scoring=None):
     from sklearn.metrics._scorer import _check_multimetric_scoring
     from sklearn.metrics import check_scoring
 
-    if SK_024:
-        if callable(scoring) or isinstance(scoring, (type(None), str)):
-            scorers = {"score": check_scoring(estimator, scoring=scoring)}
-            return scorers, False
-        return _check_multimetric_scoring(estimator, scoring), True
-    return _check_multimetric_scoring(estimator, scoring)
+    if callable(scoring) or isinstance(scoring, (type(None), str)):
+        scorers = {"score": check_scoring(estimator, scoring=scoring)}
+        return scorers, False
+    return _check_multimetric_scoring(estimator, scoring), True
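With the scikit-learn 0.24 branch gone, the simplified shim has only two paths: a single metric (string, None, or callable) is wrapped as a one-entry scorer dict, and anything else is delegated to scikit-learn's own checker. A minimal sketch, assuming scikit-learn >= 1.0 and this dask_ml checkout are importable; the dataset and estimator below are illustrative only:

# Illustrative only -- not part of the diff above.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from dask_ml._compat import _check_multimetric_scoring

X, y = make_classification(n_samples=100, random_state=0)
estimator = LogisticRegression().fit(X, y)

# A single metric (str, None, or callable) is wrapped as {"score": scorer}.
scorers, multimetric = _check_multimetric_scoring(estimator, scoring="accuracy")
print(sorted(scorers), multimetric)  # ['score'] False

# Anything else (e.g. a list of metric names) goes through scikit-learn's
# checker and is flagged as multimetric.
scorers, multimetric = _check_multimetric_scoring(estimator, scoring=["accuracy", "f1"])
print(sorted(scorers), multimetric)  # ['accuracy', 'f1'] True

The trailing boolean preserves the (scorers, is_multimetric) contract that callers of the compat helper already rely on.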
28 changes: 19 additions & 9 deletions dask_ml/_partial.py
@@ -160,7 +160,7 @@ def _blocks_and_name(obj):
 
 
 def _predict(model, x):
-    return model.predict(x)[:, None]
+    return model.predict(x)
 
 
 def predict(model, x):
@@ -173,15 +173,25 @@ def predict(model, x):
 
     See docstring for ``da.learn.fit``
    """
-    if not hasattr(x, "chunks") and hasattr(x, "to_dask_array"):
-        x = x.to_dask_array()
-    assert x.ndim == 2
-    if len(x.chunks[1]) > 1:
-        x = x.rechunk(chunks=(x.chunks[0], sum(x.chunks[1])))
     func = partial(_predict, model)
-    xx = np.zeros((1, x.shape[1]), dtype=x.dtype)
-    dt = model.predict(xx).dtype
-    return x.map_blocks(func, chunks=(x.chunks[0], (1,)), dtype=dt).squeeze()
+
+    if getattr(model, "feature_names_in_", None) is not None:
+        meta = model.predict(x._meta_nonempty)
+        return x.map_partitions(func, meta=meta)
+    else:
+        if len(x.chunks[1]) > 1:
+            x = x.rechunk(chunks=(x.chunks[0], sum(x.chunks[1])))
+
+        xx = np.zeros((1, x.shape[1]), dtype=x.dtype)
+        meta = model.predict(xx)
+
+        if meta.ndim > 1:
+            chunks = (x.chunks[0], (1,))
+            drop_axis = None
+        else:
+            chunks = (x.chunks[0],)
+            drop_axis = 1
+        return x.map_blocks(func, chunks=chunks, meta=meta, drop_axis=drop_axis)
 
 
 def _copy_partial_doc(cls):
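After this change, blockwise predictions keep their natural dimensionality instead of being padded to a column and squeezed, and models fitted with feature names are routed through map_partitions. A minimal usage sketch, assuming this dask_ml checkout plus dask, numpy, and scikit-learn are installed; the data and model are illustrative only:

# Illustrative only -- not part of the diff above.
import numpy as np
import dask.array as da
from sklearn.linear_model import LogisticRegression

from dask_ml._partial import predict

rng = np.random.RandomState(0)
X_small = rng.random_sample((200, 5))
y_small = rng.randint(0, 2, 200)
model = LogisticRegression().fit(X_small, y_small)  # fit on an in-memory sample

X = da.from_array(rng.random_sample((1000, 5)), chunks=(100, 5))
y_hat = predict(model, X)  # model.predict mapped over each block
print(y_hat.shape)         # (1000,) -- one prediction per row, no trailing axis
print(y_hat[:5].compute())

With a model fitted on a plain ndarray, feature_names_in_ is unset, so the array branch runs: the meta prediction is 1-D, the feature axis is dropped, and the result is a 1-D dask array of per-row predictions.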