Commit 4024cf1

add test suite
1 parent 5fe190e

2 files changed: +50 -7 lines

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 6 additions & 7 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.attribute.{NominalAttribute, BinaryAttribute}
+import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util.SchemaUtils
@@ -29,18 +29,17 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 
 /**
  * :: AlphaComponent ::
- * Binarize a column of continuous features given a threshold.
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
 final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
   /**
-   * Param for threshold used to binarize continuous features.
-   * The features greater than the threshold, will be binarized to 1.0.
-   * The features equal to or less than the threshold, will be binarized to 0.0.
+   * Parameter for mapping continuous features into buckets.
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets", "")
+  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
+    "Map continuous features into buckets.")
 
   /** @group getParam */
   def getBuckets: Array[Double] = $(buckets)
@@ -64,7 +63,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
   }
 
   /**
-   * Binary searching in several bins to place each data point.
+   * Binary searching in several buckets to place each data point.
    */
   private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
     val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
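Note on the logic: the hunk above shows only the signature and first line of binarySearchForBins, so the search itself is not visible in this diff. The standalone Scala sketch below reconstructs one plausible implementation under the same Double.MinValue/MaxValue wrapping; the right-inclusive bucket boundaries are inferred from the expected values in the test suite further down, and the object name and loop structure are illustrative, not the committed code.

// Illustrative sketch only: the commit shows just the method signature and
// its first line. Bucket semantics are inferred from the test data: with
// splits (-0.5, 0.0, 0.5), bucket 0 = (-inf, -0.5], bucket 1 = (-0.5, 0.0],
// bucket 2 = (0.0, 0.5], bucket 3 = (0.5, +inf).
object BucketLookupSketch {
  def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
    // Wrap the user-supplied splits so every feature falls into some bucket.
    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var lo = 0
    var hi = wrappedSplits.length - 2  // bucket indices run 0 .. splits.length
    while (lo <= hi) {
      val mid = (lo + hi) / 2
      if (feature <= wrappedSplits(mid)) {
        hi = mid - 1                   // boundary values go to the lower bucket
      } else if (feature > wrappedSplits(mid + 1)) {
        lo = mid + 1
      } else {
        return mid.toDouble            // wrappedSplits(mid) < feature <= wrappedSplits(mid + 1)
      }
    }
    -1.0  // only reachable for the degenerate input feature == Double.MinValue
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(-0.5, 0.0, 0.5)
    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
    // Prints the bucket indices the test suite expects: 2, 0, 2, 1, 3, 3, 1, 1
    data.foreach(x => println(s"$x -> ${binarySearchForBins(splits, x)}"))
  }
}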
BucketizerSuite.scala (new file)

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.scalatest.FunSuite
+
+class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
+
+  test("Bucket continuous features with setter") {
+    val sqlContext = new SQLContext(sc)
+    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
+    val buckets = Array(-0.5, 0.0, 0.5)
+    val bucketizedData = Array(2.0, 0.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0)
+    val dataFrame: DataFrame = sqlContext.createDataFrame(
+      data.zip(bucketizedData)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setBuckets(buckets)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y, "The feature value is not correct after bucketing.")
+    }
+  }
+}
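For readers trying the class outside the suite, here is a minimal usage sketch derived from the test above. It assumes an existing SparkContext `sc` (the suite gets one from MLlibTestSparkContext) and uses only the setters the test itself exercises; the sample values below are illustrative, not from the commit.

import org.apache.spark.sql.SQLContext

// Assumes a live SparkContext `sc`, as provided by MLlibTestSparkContext above.
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(
  Seq(-0.7, -0.2, 0.3, 0.9).map(Tuple1.apply)).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucket")
  .setBuckets(Array(-0.5, 0.0, 0.5))

// Boundary values fall into the lower bucket, matching the expected values
// in the test: -0.7 -> 0.0, -0.2 -> 1.0, 0.3 -> 2.0, 0.9 -> 3.0.
bucketizer.transform(df).show()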
