Skip to content

Commit d393b5c

Browse files
committed
Fixed a bug in Pipeline (a typo from the last commit). Updated the CrossValidator and Params examples for spark.ml.
1 parent c38469c commit d393b5c

File tree

7 files changed

+19
-19
lines changed

7 files changed

+19
-19
lines changed

examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,8 +107,6 @@ public static void main(String[] args) {
107107

108108
// Run cross-validation, and choose the best set of parameters.
109109
CrossValidatorModel cvModel = crossval.fit(training);
110-
// Get the best LogisticRegression model (with the best set of parameters from paramGrid).
111-
Model lrModel = cvModel.bestModel();
112110

113111
// Prepare test documents, which are unlabeled.
114112
List<Document> localTest = Lists.newArrayList(

examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,8 @@ public static void main(String[] args) {
4747
JavaSQLContext jsql = new JavaSQLContext(jsc);
4848

4949
// Prepare training data.
50-
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
51-
// into SchemaRDDs, where it uses the case class metadata to infer the schema.
50+
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
51+
// into SchemaRDDs, where it uses the bean metadata to infer the schema.
5252
List<LabeledPoint> localTraining = Lists.newArrayList(
5353
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
5454
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -75,13 +75,13 @@ public static void main(String[] args) {
7575

7676
// We may alternatively specify parameters using a ParamMap.
7777
ParamMap paramMap = new ParamMap();
78-
paramMap.put(lr.maxIter(), 20); // Specify 1 Param.
78+
paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
7979
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
80-
paramMap.put(lr.regParam(), 0.1);
80+
paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
8181

8282
// One can also combine ParamMaps.
8383
ParamMap paramMap2 = new ParamMap();
84-
paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name.
84+
paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
8585
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
8686

8787
// Now learn a new model using the paramMapCombined parameters.

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -91,8 +91,6 @@ object CrossValidatorExample {
9191

9292
// Run cross-validation, and choose the best set of parameters.
9393
val cvModel = crossval.fit(training)
94-
// Get the best LogisticRegression model (with the best set of parameters from paramGrid).
95-
val lrModel = cvModel.bestModel
9694

9795
// Prepare test documents, which are unlabeled.
9896
val test = sparkContext.parallelize(Seq(

examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.examples.ml
1919

2020
import org.apache.spark.{SparkConf, SparkContext}
21+
import org.apache.spark.SparkContext._
2122
import org.apache.spark.ml.classification.LogisticRegression
2223
import org.apache.spark.ml.param.ParamMap
2324
import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -40,8 +41,8 @@ object SimpleParamsExample {
4041
import sqlContext._
4142

4243
// Prepare training data.
43-
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
44-
// into SchemaRDDs, where it uses the case class metadata to infer the schema.
44+
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
45+
// into SchemaRDDs, where it uses the bean metadata to infer the schema.
4546
val training = sparkContext.parallelize(Seq(
4647
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
4748
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -69,10 +70,10 @@ object SimpleParamsExample {
6970
// which supports several methods for specifying parameters.
7071
val paramMap = ParamMap(lr.maxIter -> 20)
7172
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
72-
paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params.
73+
paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
7374

7475
// One can also combine ParamMaps.
75-
val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name.
76+
val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
7677
val paramMapCombined = paramMap ++ paramMap2
7778

7879
// Now learn a new model using the paramMapCombined parameters.

examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
2020
import scala.beans.BeanInfo
2121

2222
import org.apache.spark.{SparkConf, SparkContext}
23+
import org.apache.spark.SparkContext._
2324
import org.apache.spark.ml.Pipeline
2425
import org.apache.spark.ml.classification.LogisticRegression
2526
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -163,14 +163,14 @@ class PipelineModel private[ml] (
163163

164164
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
165165
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
166-
val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
166+
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
167167
transformSchema(dataset.schema, map, logging = true)
168168
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
169169
}
170170

171171
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
172172
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
173-
val map = (fittingParamMap ++ this.paramMap) ++ fittingParamMap
173+
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
174174
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
175175
}
176176
}

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,12 @@
1717

1818
package org.apache.spark.ml.param
1919

20-
import java.lang.reflect.Modifier
21-
22-
import org.apache.spark.annotation.AlphaComponent
23-
2420
import scala.annotation.varargs
2521
import scala.collection.mutable
2622

23+
import java.lang.reflect.Modifier
24+
25+
import org.apache.spark.annotation.AlphaComponent
2726
import org.apache.spark.ml.Identifiable
2827

2928
/**
@@ -223,6 +222,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
223222
* Puts a list of param pairs (overwrites if the input params exists).
224223
* Not usable from Java
225224
*/
225+
@varargs
226226
def put(paramPairs: ParamPair[_]*): this.type = {
227227
paramPairs.foreach { p =>
228228
put(p.param.asInstanceOf[Param[Any]], p.value)
@@ -283,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
283283
* where the latter overwrites this if there exists conflicts.
284284
*/
285285
def ++(other: ParamMap): ParamMap = {
286+
// TODO: Provide a better method name for Java users.
286287
new ParamMap(this.map ++ other.map)
287288
}
288289

@@ -291,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
291292
* Adds all parameters from the input param map into this param map.
292293
*/
293294
def ++=(other: ParamMap): this.type = {
295+
// TODO: Provide a better method name for Java users.
294296
this.map ++= other.map
295297
this
296298
}

0 commit comments

Comments (0)