Skip to content

Commit bbb9cb3

Browse files
authored
Add sklearn contrib (#44)
1 parent 0a7d476 commit bbb9cb3

File tree

7 files changed

+122
-0
lines changed

7 files changed

+122
-0
lines changed

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@
232232
'torch': ('https://pytorch.org/docs/stable', None),
233233
'numpy': ('https://numpy.org/doc/stable', None),
234234
'optuna': ('https://optuna.readthedocs.io/en/latest', None),
235+
'sklearn': ('https://scikit-learn.org/stable', None),
235236
}
236237

237238
autoclass_content = 'both'

docs/source/contrib/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ External Package Integrations
99
torch
1010
numpy
1111
optuna
12+
sklearn

docs/source/contrib/sklearn.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Scikit-learn
2+
============
3+
.. automodule:: class_resolver.contrib.sklearn
4+
:members:

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ optuna =
8282
optuna
8383
numpy =
8484
numpy
85+
sklearn =
86+
scikit-learn
8587

8688
[options.entry_points]
8789
class_resolver_demo =
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
Scikit-learn is a generic machine learning package with implementations of
3+
algorithms for classification, regression, dimensionality reduction, clustering,
4+
as well as other generic tooling.
5+
6+
The ``class-resolver`` provides several class resolvers for instantiating various
7+
implementations, such as those of linear models.
8+
""" # noqa:D205,D400
9+
10+
from sklearn.base import BaseEstimator
11+
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
12+
from sklearn.linear_model import (
13+
LogisticRegression,
14+
LogisticRegressionCV,
15+
PassiveAggressiveClassifier,
16+
Perceptron,
17+
RidgeClassifier,
18+
RidgeClassifierCV,
19+
SGDClassifier,
20+
)
21+
from sklearn.tree import DecisionTreeClassifier
22+
23+
from ..api import ClassResolver
24+
25+
__all__ = [
26+
"classifier_resolver",
27+
]
28+
29+
classifier_resolver: ClassResolver[BaseEstimator] = ClassResolver(
30+
[
31+
LogisticRegression,
32+
LogisticRegressionCV,
33+
PassiveAggressiveClassifier,
34+
Perceptron,
35+
RidgeClassifier,
36+
RidgeClassifierCV,
37+
SGDClassifier,
38+
DecisionTreeClassifier,
39+
RandomForestClassifier,
40+
GradientBoostingClassifier,
41+
],
42+
base=BaseEstimator,
43+
base_as_suffix=False,
44+
default=LogisticRegression,
45+
)
46+
"""A resolver for classifiers.
47+
48+
The default value is :class:`sklearn.linear_model.LogisticRegression`.
49+
This resolver can be used like in the following:
50+
51+
.. code-block:: python
52+
53+
from sklearn import datasets
54+
from sklearn.model_selection import train_test_split
55+
56+
from class_resolver.contrib.sklearn import classifier_resolver
57+
58+
# Prepare a dataset
59+
x, y = datasets.load_iris(return_X_y=True)
60+
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42)
61+
62+
# Lookup with a string
63+
classifier = classifier_resolver.make("LogisticRegression")
64+
classifier.fit(x_train, y_train)
65+
assert 0.7 < classifier.score(x_test, y_test)
66+
67+
# Default lookup gives logistic regression
68+
classifier = classifier_resolver.make(None)
69+
classifier.fit(x_train, y_train)
70+
assert 0.7 < classifier.score(x_test, y_test)
71+
72+
.. seealso:: https://scikit-learn.org/stable/modules/classes.html#linear-classifiers
73+
"""

tests/test_contrib/test_sklearn.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""Tests for the scikit-learn contribution module."""
4+
5+
import unittest
6+
7+
try:
8+
import sklearn
9+
except ImportError: # pragma: no cover
10+
sklearn = None # pragma: no cover
11+
12+
13+
@unittest.skipUnless(sklearn, "Can not test sklearn contrib without ``pip install scikit-learn``.")
14+
class TestSklearn(unittest.TestCase):
15+
"""Test for the scikit-learn contribution module."""
16+
17+
def test_classifier_resolver(self):
18+
"""Tests for the classifier resolver."""
19+
from sklearn import datasets
20+
from sklearn.model_selection import train_test_split
21+
22+
from class_resolver.contrib.sklearn import classifier_resolver
23+
24+
x, y = datasets.load_iris(return_X_y=True)
25+
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42)
26+
27+
classifier = classifier_resolver.make("LogisticRegression")
28+
classifier.fit(x_train, y_train)
29+
accuracy = classifier.score(x_test, y_test)
30+
self.assertLessEqual(0.0, accuracy)
31+
self.assertGreaterEqual(1.0, accuracy)
32+
33+
for name, cls in classifier_resolver.lookup_dict.items():
34+
with self.subTest(name=name):
35+
classifier = cls()
36+
classifier.fit(x_train, y_train)
37+
accuracy = classifier.score(x_test, y_test)
38+
self.assertLessEqual(0.0, accuracy)
39+
self.assertGreaterEqual(1.0, accuracy)

tox.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ extras =
3636
numpy
3737
optuna
3838
docdata
39+
sklearn
3940

4041
[testenv:coverage-clean]
4142
deps = coverage
@@ -114,6 +115,7 @@ extras =
114115
torch
115116
numpy
116117
optuna
118+
sklearn
117119
commands =
118120
python -m sphinx -W -b html -d docs/build/doctrees docs/source docs/build/html
119121

0 commit comments

Comments
 (0)