-
-
Notifications
You must be signed in to change notification settings - Fork 260
Description
These libraries don't seem to work together. I think that supporting or claiming integration with any new ML library should include support for hyperparameter tuning; that's definitely an MVP requirement.
Here is a code and error dump to back up my point:
import dask
import dask.dataframe as dd
from distributed import Client
from dask_ml.model_selection import HyperbandSearchCV
from dask_ml import datasets
import lightgbm as lgb
client = Client('10.118.232.173:8786')
X, y = datasets.make_classification(chunks=50)
model = lgb.DaskLGBMRegressor(client=client)
param_space = {
'n_estimators': range(100, 200, 50),
'max_depth': range(3, 6, 2),
'booster': ('gbtree', 'dart'),
}
search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
search.fit(X, y)
And the error message:
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=81. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=34. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=15. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(
[CV, bracket=0] For training there are between 47 and 47 examples in each chunk
[CV, bracket=1] For training there are between 47 and 47 examples in each chunk
AttributeError Traceback (most recent call last)
in
10
11 search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
---> 12 search.fit(X, y)
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
715 client = default_client()
716 if not client.asynchronous:
--> 717 return client.sync(self._fit, X, y, **fit_params)
718 return self._fit(X, y, **fit_params)
719
/opt/conda/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
849 return future
850 else:
--> 851 return sync(
852 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
853 )
/opt/conda/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
352 if error[0]:
353 typ, exc, tb = error[0]
--> 354 raise exc.with_traceback(tb)
355 else:
356 return result[0]
/opt/conda/lib/python3.8/site-packages/distributed/utils.py in f()
335 if callback_timeout is not None:
336 future = asyncio.wait_for(future, callback_timeout)
--> 337 result[0] = yield future
338 except Exception as exc:
339 error[0] = sys.exc_info()
/opt/conda/lib/python3.8/site-packages/tornado/gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_hyperband.py in _fit(self, X, y, **fit_params)
399 _brackets_ids = list(reversed(sorted(SHAs)))
400
--> 401 _SHAs = await asyncio.gather(
402 *[SHAs[b]._fit(X, y, **fit_params) for b in _brackets_ids]
403 )
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
661
662 with context:
--> 663 results = await fit(
664 self.estimator,
665 self._get_params(),
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
475 A history of all models scores over time
476 """
--> 477 return await _fit(
478 model,
479 params,
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
266 # async for future, result in seq:
267 for _i in itertools.count():
--> 268 metas = await client.gather(new_scores)
269
270 if log_delay and _i % int(log_delay) == 0:
/opt/conda/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1846 exc = CancelledError(key)
1847 else:
-> 1848 raise exception.with_traceback(traceback)
1849 raise exc
1850 if errors == "skip":
/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _partial_fit()
101 if len(X):
102 model = deepcopy(model)
--> 103 model.partial_fit(X, y, **(fit_params or {}))
104
105 meta = dict(meta)
AttributeError: 'DaskLGBMRegressor' object has no attribute 'partial_fit'