Commit d710991

[SPARK-5769] Set params in constructors and in setParams in Python ML pipelines
This PR allows Python users to set params in constructors and in setParams, where we use the `keyword_only` decorator to force keyword arguments. The trade-off is discussed in the design doc of SPARK-4586.

Generated doc: ![screen shot 2015-02-12 at 3 06 58 am](https://cloud.githubusercontent.com/assets/829644/6166491/9cfcd06a-b265-11e4-99ea-473d866634fc.png)

CC: davies rxin

Author: Xiangrui Meng <[email protected]>

Closes #4564 from mengxr/py-pipeline-kw and squashes the following commits:

fedf720 [Xiangrui Meng] use toDF
d565f2c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into py-pipeline-kw
cbc15d3 [Xiangrui Meng] fix style
5032097 [Xiangrui Meng] update pipeline signature
950774e [Xiangrui Meng] simplify keyword_only and update constructor/setParams signatures
fdde5fc [Xiangrui Meng] fix style
c9384b8 [Xiangrui Meng] fix sphinx doc
8e59180 [Xiangrui Meng] add setParams and make constructors take params, where we force keyword args

(cherry picked from commit cd4a153)
Signed-off-by: Xiangrui Meng <[email protected]>
Parent: 4e099d7
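
A minimal usage sketch of the API this commit introduces (not part of the commit itself; assumes the pyspark shell at this commit, with a SparkContext already running — the classes and params are the ones modified in the diffs below):

    from pyspark.ml.feature import Tokenizer, HashingTF
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.pipeline import Pipeline

    # Params can now be set directly in the constructor as keyword arguments ...
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    hashingTF = HashingTF(numFeatures=10, inputCol=tokenizer.getOutputCol(), outputCol="features")
    lr = LogisticRegression(maxIter=10, regParam=0.01)
    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

    # ... or changed afterwards via setParams, which also only accepts keyword arguments.
    lr.setParams(regParam=0.1)

    # Positional arguments are rejected by the keyword_only decorator:
    # lr.setParams(0.1)  # TypeError: Method setParams forces keyword arguments.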

File tree: 7 files changed (+153, -53 lines)


examples/src/main/python/ml/simple_text_classification_pipeline.py

Lines changed: 17 additions & 27 deletions
@@ -36,43 +36,33 @@
     sqlCtx = SQLContext(sc)

     # Prepare training documents, which are labeled.
-    LabeledDocument = Row('id', 'text', 'label')
-    training = sqlCtx.inferSchema(
-        sc.parallelize([(0L, "a b c d e spark", 1.0),
-                        (1L, "b d", 0.0),
-                        (2L, "spark f g h", 1.0),
-                        (3L, "hadoop mapreduce", 0.0)])
-        .map(lambda x: LabeledDocument(*x)))
+    LabeledDocument = Row("id", "text", "label")
+    training = sc.parallelize([(0L, "a b c d e spark", 1.0),
+                               (1L, "b d", 0.0),
+                               (2L, "spark f g h", 1.0),
+                               (3L, "hadoop mapreduce", 0.0)]) \
+        .map(lambda x: LabeledDocument(*x)).toDF()

     # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr.
-    tokenizer = Tokenizer() \
-        .setInputCol("text") \
-        .setOutputCol("words")
-    hashingTF = HashingTF() \
-        .setInputCol(tokenizer.getOutputCol()) \
-        .setOutputCol("features")
-    lr = LogisticRegression() \
-        .setMaxIter(10) \
-        .setRegParam(0.01)
-    pipeline = Pipeline() \
-        .setStages([tokenizer, hashingTF, lr])
+    tokenizer = Tokenizer(inputCol="text", outputCol="words")
+    hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+    lr = LogisticRegression(maxIter=10, regParam=0.01)
+    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

     # Fit the pipeline to training documents.
     model = pipeline.fit(training)

     # Prepare test documents, which are unlabeled.
-    Document = Row('id', 'text')
-    test = sqlCtx.inferSchema(
-        sc.parallelize([(4L, "spark i j k"),
-                        (5L, "l m n"),
-                        (6L, "mapreduce spark"),
-                        (7L, "apache hadoop")])
-        .map(lambda x: Document(*x)))
+    Document = Row("id", "text")
+    test = sc.parallelize([(4L, "spark i j k"),
+                           (5L, "l m n"),
+                           (6L, "mapreduce spark"),
+                           (7L, "apache hadoop")]) \
+        .map(lambda x: Document(*x)).toDF()

     # Make predictions on test documents and print columns of interest.
     prediction = model.transform(test)
-    prediction.registerTempTable("prediction")
-    selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
+    selected = prediction.select("id", "text", "prediction")
     for row in selected.collect():
         print row

python/docs/conf.py

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@
 # If true, keep warnings as "system message" paragraphs in the built documents.
 #keep_warnings = False

+# -- Options for autodoc --------------------------------------------------
+
+# Look at the first line of the docstring for function and method signatures.
+autodoc_docstring_signature = True

 # -- Options for HTML output ----------------------------------------------
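
Why this setting matters for the rest of the PR: `functools.wraps` preserves a method's name and docstring but not its argument list, so once `keyword_only` wraps `__init__`/`setParams` their runtime signature is just `(*args, **kwargs)`. The real signature only survives as the first line of each docstring, and `autodoc_docstring_signature = True` tells Sphinx to read it from there. A standalone sketch of that effect (hypothetical names, no Sphinx required):

    from functools import wraps

    def force_kwargs(func):              # hypothetical stand-in for keyword_only
        @wraps(func)                     # copies __name__ and __doc__, but not the signature
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    @force_kwargs
    def setParams(self, inputCol="input", outputCol="output"):
        """
        setParams(self, inputCol="input", outputCol="output")
        """

    # Introspection now sees only (*args, **kwargs); the docstring keeps the real signature,
    # which is what autodoc picks up when autodoc_docstring_signature is enabled.
    print(setParams.__doc__.splitlines()[1].strip())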

python/pyspark/ml/classification.py

Lines changed: 34 additions & 10 deletions
@@ -15,7 +15,7 @@
 # limitations under the License.
 #

-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import inherit_doc, keyword_only
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
     HasRegParam
@@ -32,22 +32,46 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti

     >>> from pyspark.sql import Row
     >>> from pyspark.mllib.linalg import Vectors
-    >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
-        Row(label=1.0, features=Vectors.dense(1.0)), \
-        Row(label=0.0, features=Vectors.sparse(1, [], []))]))
-    >>> lr = LogisticRegression() \
-        .setMaxIter(5) \
-        .setRegParam(0.01)
-    >>> model = lr.fit(dataset)
-    >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+    >>> df = sc.parallelize([
+    ...     Row(label=1.0, features=Vectors.dense(1.0)),
+    ...     Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
+    >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
+    >>> model = lr.fit(df)
+    >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
     >>> print model.transform(test0).head().prediction
     0.0
-    >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+    >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
     >>> print model.transform(test1).head().prediction
     1.0
+    >>> lr.setParams("vector")
+    Traceback (most recent call last):
+        ...
+    TypeError: Method setParams forces keyword arguments.
     """
     _java_class = "org.apache.spark.ml.classification.LogisticRegression"

+    @keyword_only
+    def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 maxIter=100, regParam=0.1):
+        """
+        __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 maxIter=100, regParam=0.1)
+        """
+        super(LogisticRegression, self).__init__()
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  maxIter=100, regParam=0.1):
+        """
+        setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  maxIter=100, regParam=0.1)
+        Sets params for logistic regression.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set_params(**kwargs)
+
     def _create_model(self, java_model):
         return LogisticRegressionModel(java_model)

python/pyspark/ml/feature.py

Lines changed: 58 additions & 14 deletions
@@ -16,7 +16,7 @@
 #

 from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import inherit_doc, keyword_only
 from pyspark.ml.wrapper import JavaTransformer

 __all__ = ['Tokenizer', 'HashingTF']
@@ -29,18 +29,45 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
     splits it by white spaces.

     >>> from pyspark.sql import Row
-    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
-    >>> tokenizer = Tokenizer() \
-        .setInputCol("text") \
-        .setOutputCol("words")
-    >>> print tokenizer.transform(dataset).head()
+    >>> df = sc.parallelize([Row(text="a b c")]).toDF()
+    >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
+    >>> print tokenizer.transform(df).head()
     Row(text=u'a b c', words=[u'a', u'b', u'c'])
-    >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
+    >>> # Change a parameter.
+    >>> print tokenizer.setParams(outputCol="tokens").transform(df).head()
     Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+    >>> # Temporarily modify a parameter.
+    >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
+    Row(text=u'a b c', words=[u'a', u'b', u'c'])
+    >>> print tokenizer.transform(df).head()
+    Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+    >>> # Must use keyword arguments to specify params.
+    >>> tokenizer.setParams("text")
+    Traceback (most recent call last):
+        ...
+    TypeError: Method setParams forces keyword arguments.
     """

     _java_class = "org.apache.spark.ml.feature.Tokenizer"

+    @keyword_only
+    def __init__(self, inputCol="input", outputCol="output"):
+        """
+        __init__(self, inputCol="input", outputCol="output")
+        """
+        super(Tokenizer, self).__init__()
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol="input", outputCol="output"):
+        """
+        setParams(self, inputCol="input", outputCol="output")
+        Sets params for this Tokenizer.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set_params(**kwargs)
+

 @inherit_doc
 class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
@@ -49,20 +76,37 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
     hashing trick.

     >>> from pyspark.sql import Row
-    >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
-    >>> hashingTF = HashingTF() \
-        .setNumFeatures(10) \
-        .setInputCol("words") \
-        .setOutputCol("features")
-    >>> print hashingTF.transform(dataset).head().features
+    >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF()
+    >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+    >>> print hashingTF.transform(df).head().features
+    (10,[7,8,9],[1.0,1.0,1.0])
+    >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
     (10,[7,8,9],[1.0,1.0,1.0])
     >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
-    >>> print hashingTF.transform(dataset, params).head().vector
+    >>> print hashingTF.transform(df, params).head().vector
     (5,[2,3,4],[1.0,1.0,1.0])
     """

     _java_class = "org.apache.spark.ml.feature.HashingTF"

+    @keyword_only
+    def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+        """
+        __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+        """
+        super(HashingTF, self).__init__()
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+        """
+        setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+        Sets params for this HashingTF.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set_params(**kwargs)
+

 if __name__ == "__main__":
     import doctest

python/pyspark/ml/param/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -80,3 +80,11 @@ def _dummy():
         dummy = Params()
         dummy.uid = "undefined"
         return dummy
+
+    def _set_params(self, **kwargs):
+        """
+        Sets params.
+        """
+        for param, value in kwargs.iteritems():
+            self.paramMap[getattr(self, param)] = value
+        return self
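
The new helper is just a keyword-to-Param mapping: each keyword name is resolved to the same-named `Param` attribute on the instance via `getattr`, and the value is recorded in `paramMap`, which is why `setParams` in the classes above reduces to a one-liner. A standalone sketch of the pattern with stand-in names (not pyspark code):

    class FakeParams(object):                  # hypothetical stand-in for ml.param.Params
        def __init__(self):
            self.paramMap = {}
            self.maxIter = "Param(maxIter)"    # stand-in for a real Param object

        def _set_params(self, **kwargs):
            # same shape as the method added above (iteritems vs items aside)
            for param, value in kwargs.items():
                self.paramMap[getattr(self, param)] = value
            return self

    print(FakeParams()._set_params(maxIter=20).paramMap)   # {'Param(maxIter)': 20}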

python/pyspark/ml/pipeline.py

Lines changed: 17 additions & 2 deletions
@@ -18,7 +18,7 @@
 from abc import ABCMeta, abstractmethod

 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import inherit_doc, keyword_only


 __all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
@@ -89,10 +89,16 @@ class Pipeline(Estimator):
     identity transformer.
     """

-    def __init__(self):
+    @keyword_only
+    def __init__(self, stages=[]):
+        """
+        __init__(self, stages=[])
+        """
         super(Pipeline, self).__init__()
         #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)

     def setStages(self, value):
         """
@@ -110,6 +116,15 @@ def getStages(self):
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]

+    @keyword_only
+    def setParams(self, stages=[]):
+        """
+        setParams(self, stages=[])
+        Sets params for Pipeline.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set_params(**kwargs)
+
     def fit(self, dataset, params={}):
         paramMap = self._merge_params(params)
         stages = paramMap[self.stages]

python/pyspark/ml/util.py

Lines changed: 15 additions & 0 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 #

+from functools import wraps
 import uuid


@@ -32,6 +33,20 @@ def inherit_doc(cls):
     return cls


+def keyword_only(func):
+    """
+    A decorator that forces keyword arguments in the wrapped method
+    and saves actual input keyword arguments in `_input_kwargs`.
+    """
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        if len(args) > 1:
+            raise TypeError("Method %s forces keyword arguments." % func.__name__)
+        wrapper._input_kwargs = kwargs
+        return func(*args, **kwargs)
+    return wrapper
+
+
 class Identifiable(object):
     """
     Object with a unique ID.
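
For reference, a standalone sketch of how the `keyword_only` decorator above behaves (the `Demo` class is hypothetical, not part of the commit): positional arguments beyond `self` are rejected, and the keyword arguments the caller actually supplied are captured in `_input_kwargs`, which is what the new `__init__`/`setParams` methods read back:

    from functools import wraps

    def keyword_only(func):
        # same decorator as added to pyspark/ml/util.py above
        @wraps(func)
        def wrapper(*args, **kwargs):
            if len(args) > 1:
                raise TypeError("Method %s forces keyword arguments." % func.__name__)
            wrapper._input_kwargs = kwargs
            return func(*args, **kwargs)
        return wrapper

    class Demo(object):
        @keyword_only
        def setParams(self, a=1, b=2):
            # only the keywords the caller supplied, not the defaults
            return self.setParams._input_kwargs

    print(Demo().setParams(a=10))   # {'a': 10}
    # Demo().setParams(10)          # TypeError: Method setParams forces keyword arguments.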
