@@ -25,46 +25,55 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * :: AlphaComponent ::
  * `Bucketizer` maps a column of continuous features to a column of feature buckets.
  */
 @AlphaComponent
-final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
+private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
   extends Model[Bucketizer] with HasInputCol with HasOutputCol {
 
-  /**
-   * The given buckets should match 1) its size is larger than zero; 2) it is ordered in a non-DESC
-   * way.
-   */
-  private def checkBuckets(buckets: Array[Double]): Boolean = {
-    if (buckets.size == 0) false
-    else if (buckets.size == 1) true
-    else {
-      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator & prevValue <= currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
-    }
-  }
+  def this() = this(null)
 
   /**
-   * Parameter for mapping continuous features into buckets.
+   * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
+   * A bucket defined by splits x,y holds values in the range (x,y].
    * @group param
    */
-  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Split points for mapping continuous features into buckets.", checkBuckets)
+  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+    "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
+      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+    Bucketizer.checkSplits)
 
   /** @group getParam */
-  def getBuckets: Array[Double] = $(buckets)
+  def getSplits: Array[Double] = $(splits)
 
   /** @group setParam */
-  def setBuckets(value: Array[Double]): this.type = set(buckets, value)
+  def setSplits(value: Array[Double]): this.type = set(splits, value)
+
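
For intuition, a small sketch of the new splits semantics (the values are hypothetical). With n splits the column is divided into n+1 buckets, because transform below pads the splits with Double.MinValue and Double.MaxValue before searching:

    // Two splits yield three buckets once transform pads the endpoints:
    //   bucket 0: (Double.MinValue, 0.0]
    //   bucket 1: (0.0, 10.0]
    //   bucket 2: (10.0, Double.MaxValue]
    val splits = Array(0.0, 10.0)
    // -5.0 maps to bucket 0, 7.3 to bucket 1, 42.0 to bucket 2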
+  /** @group param */
+  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
+    "An indicator of the inclusiveness of negative infinity.")
+  setDefault(lowerInclusive -> true)
+
+  /** @group getParam */
+  def getLowerInclusive: Boolean = $(lowerInclusive)
+
+  /** @group setParam */
+  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
+
+  /** @group param */
+  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
+    "An indicator of the inclusiveness of positive infinity.")
+  setDefault(upperInclusive -> true)
+
+  /** @group getParam */
+  def getUpperInclusive: Boolean = $(upperInclusive)
+
+  /** @group setParam */
+  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
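
These two flags are only declared with defaults in this revision; transform below does not yet consult them, so this sketch reflects the intended semantics from the param doc strings rather than current behavior:

    // Hypothetical configuration; endpoint handling is not wired into
    // transform in this diff, so both flags default to true and are inert.
    val bucketizer = new Bucketizer()
      .setLowerInclusive(true)  // treat (-infinity, firstSplit] as a valid bucket
      .setUpperInclusive(true)  // treat (lastSplit, +infinity) as a valid bucket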
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -74,45 +83,68 @@ final class Bucketizer(override val parent: Estimator[Bucketizer] = null)
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
-    val outputColName = $(outputCol)
-    val metadata = NominalAttribute.defaultAttr
-      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
-    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
+    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
+    val bucketizer = udf { feature: Double =>
+      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+    val newCol = bucketizer(dataset($(inputCol)))
+    val newField = prepOutputField(dataset.schema)
+    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+  }
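
An end-to-end sketch of the reworked transform (assuming a SQLContext named sqlContext and a caller inside the ml package, since the class is now private[ml]; the column names are hypothetical):

    import sqlContext.implicits._

    val df = Seq(-0.5, 0.2, 0.7, 1.5).map(Tuple1.apply).toDF("features")
    val bucketed = new Bucketizer()
      .setInputCol("features")
      .setOutputCol("bucket")
      .setSplits(Array(0.0, 0.5, 1.0))
      .transform(df)
    // With wrapped splits (MinValue, 0.0, 0.5, 1.0, MaxValue), the four rows
    // land in buckets 0.0, 1.0, 2.0, 3.0 respectively.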
+
+  private def prepOutputField(schema: StructType): StructField = {
+    val attr = new NominalAttribute(
+      name = Some($(outputCol)),
+      isOrdinal = Some(true),
+      numValues = Some($(splits).size),
+      values = Some($(splits).map(_.toString)))
+
+    attr.toStructField()
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    require(schema.fields.forall(_.name != $(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    StructType(schema.fields :+ prepOutputField(schema))
+  }
+}
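
transformSchema now fails fast instead of producing a duplicate column; a sketch of the expected behavior, with hypothetical field names:

    val schema = StructType(Seq(StructField("features", DoubleType)))
    bucketizer.transformSchema(schema)  // ok: returns schema plus the "bucket" field
    // Calling it again on the returned schema would throw, because the
    // require above rejects a pre-existing output column.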
+
+object Bucketizer {
+  /**
+   * The given splits are valid if 1) the array is non-empty and 2) its values are strictly
+   * increasing.
+   */
+  private def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.size == 0) false
+    else if (splits.size == 1) true
+    else {
+      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue < currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
   }
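
A few illustrative inputs for the tightened validator (note that the switch from & and <= to && and < now rejects duplicate split points):

    checkSplits(Array.empty[Double])   // false: no splits given
    checkSplits(Array(0.5))            // true: a single split is always accepted
    checkSplits(Array(0.0, 0.5, 1.0))  // true: strictly increasing
    checkSplits(Array(0.0, 0.5, 0.5))  // false: duplicates are no longer allowed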
 
   /**
    * Binary search among the buckets to place each data point.
    */
-  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     var left = 0
-    var right = wrappedSplits.length - 2
+    var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
-      val split = wrappedSplits(mid)
-      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
+      val split = splits(mid)
+      if ((feature > split) && (feature <= splits(mid + 1))) {
         return mid
       } else if (feature <= split) {
         right = mid - 1
       } else {
         left = mid + 1
       }
     }
-    -1
-  }
-
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-
-    val attr = NominalAttribute.defaultAttr.withName(outputColName)
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    throw new Exception("Failed to find a bucket.")
   }
 }
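
A worked trace of the search, using hypothetical values and splits already padded by transform:

    val wrapped = Array(Double.MinValue, 0.0, 10.0, Double.MaxValue)
    // feature = 7.3: left = 0, right = 2
    //   mid = 1, split = 0.0, 7.3 > 0.0 && 7.3 <= 10.0  -> return 1
    binarySearchForBuckets(wrapped, 7.3)   // 1.0
    binarySearchForBuckets(wrapped, -3.0)  // 0.0: falls in (MinValue, 0.0]
    binarySearchForBuckets(wrapped, 99.0)  // 2.0: falls in (10.0, MaxValue]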