Skip to content

Support DataFrame in IncrementalSearchCV #628

@TomAugspurger

Description

@TomAugspurger

https://stackoverflow.com/questions/60846247/this-estimator-does-not-support-dask-dataframes

import pandas as pd
import dask.dataframe as dd
from distributed import Client
from sklearn.linear_model import SGDClassifier
from dask_ml.model_selection import IncrementalSearchCV

df = pd.DataFrame({"A": [1, 2, 3, 4, 5], "B": [0, 0, 1, 1, 1]})
ddf = dd.from_pandas(df, 2)
X = df[["A"]]
y = df['B']

model = SGDClassifier()
params = dict(alpha=[0.1, 1])
search = IncrementalSearchCV(model, params)

client = Client()

search.fit(X, y)

If we accept dask dataframes by applying this patch:

diff --git a/dask_ml/model_selection/_incremental.py b/dask_ml/model_selection/_incremental.py
index 48f43be9..3219f8c1 100644
--- a/dask_ml/model_selection/_incremental.py
+++ b/dask_ml/model_selection/_incremental.py
@@ -447,8 +447,8 @@ class BaseIncrementalSearchCV(ParallelPostFit):
                 )
             )
 
-        X = self._check_array(X)
-        y = self._check_array(y, ensure_2d=False)
+        X = self._check_array(X, accept_dask_dataframe=True)
+        y = self._check_array(y, ensure_2d=False, accept_dask_dataframe=True)
         scorer = check_scoring(self.estimator, scoring=self.scoring)
         return X, y, scorer
 

We then fail in `_get_train_test_split`, because `X` is a NumPy array by that point and has no `npartitions`:

~/sandbox/dask-ml/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
    556     def _fit(self, X, y, **fit_params):
    557         X, y, scorer = self._validate_parameters(X, y)
--> 558         X_train, X_test, y_train, y_test = self._get_train_test_split(X, y)
    559 
    560         results = yield fit(

~/sandbox/dask-ml/dask_ml/model_selection/_incremental.py in _get_train_test_split(self, X, y, **kwargs)
    484         """
    485         if self.test_size is None:
--> 486             test_size = min(0.2, 1 / X.npartitions)
    487         else:
    488             test_size = self.test_size

AttributeError: 'numpy.ndarray' object has no attribute 'npartitions'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions