Skip to content

Make lightgbm work with HyperbandSearchCV #838

@vecorro

Description

@vecorro

These libraries don't seem to work together. I think that supporting (or claiming integration with) any new ML library should include support for hyperparameter tuning — that's definitely part of an MVP.

Here are a code sample and an error dump to back up my point:

import dask
import dask.dataframe as dd
from distributed import Client
from dask_ml.model_selection import HyperbandSearchCV
from dask_ml import datasets
import lightgbm as lgb

client = Client('10.118.232.173:8786')

X, y = datasets.make_classification(chunks=50)

model = lgb.DaskLGBMRegressor(client=client)


param_space = {
    'n_estimators': range(100, 200, 50),
    'max_depth': range(3, 6, 2),
    'booster': ('gbtree', 'dart'),
}

search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
search.fit(X, y)

And the error message:

/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=81. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=34. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(
/opt/conda/lib/python3.8/site-packages/sklearn/model_selection/_search.py:285: UserWarning: The total space of parameters 8 is smaller than n_iter=15. Running 8 iterations. For exhaustive searches, use GridSearchCV.
warnings.warn(

[CV, bracket=0] For training there are between 47 and 47 examples in each chunk
[CV, bracket=1] For training there are between 47 and 47 examples in each chunk


AttributeError Traceback (most recent call last)
in
10
11 search = HyperbandSearchCV(model, param_space, random_state=0, patience=True, verbose=True, test_size=0.05)
---> 12 search.fit(X, y)

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
715 client = default_client()
716 if not client.asynchronous:
--> 717 return client.sync(self._fit, X, y, **fit_params)
718 return self._fit(X, y, **fit_params)
719

/opt/conda/lib/python3.8/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
849 return future
850 else:
--> 851 return sync(
852 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
853 )

/opt/conda/lib/python3.8/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
352 if error[0]:
353 typ, exc, tb = error[0]
--> 354 raise exc.with_traceback(tb)
355 else:
356 return result[0]

/opt/conda/lib/python3.8/site-packages/distributed/utils.py in f()
335 if callback_timeout is not None:
336 future = asyncio.wait_for(future, callback_timeout)
--> 337 result[0] = yield future
338 except Exception as exc:
339 error[0] = sys.exc_info()

/opt/conda/lib/python3.8/site-packages/tornado/gen.py in run(self)
760
761 try:
--> 762 value = future.result()
763 except Exception:
764 exc_info = sys.exc_info()

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_hyperband.py in _fit(self, X, y, **fit_params)
399 _brackets_ids = list(reversed(sorted(SHAs)))
400
--> 401 _SHAs = await asyncio.gather(
402 *[SHAs[b]._fit(X, y, **fit_params) for b in _brackets_ids]
403 )

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
661
662 with context:
--> 663 results = await fit(
664 self.estimator,
665 self._get_params(),

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
475 A history of all models scores over time
476 """
--> 477 return await _fit(
478 model,
479 params,

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
266 # async for future, result in seq:
267 for _i in itertools.count():
--> 268 metas = await client.gather(new_scores)
269
270 if log_delay and _i % int(log_delay) == 0:

/opt/conda/lib/python3.8/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1846 exc = CancelledError(key)
1847 else:
-> 1848 raise exception.with_traceback(traceback)
1849 raise exc
1850 if errors == "skip":

/opt/conda/lib/python3.8/site-packages/dask_ml/model_selection/_incremental.py in _partial_fit()
101 if len(X):
102 model = deepcopy(model)
--> 103 model.partial_fit(X, y, **(fit_params or {}))
104
105 meta = dict(meta)

AttributeError: 'DaskLGBMRegressor' object has no attribute 'partial_fit'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions