
Commit f6be730

add feature discretizer
1 parent 271c4c6 commit f6be730

File tree

2 files changed: +214 −0 lines changed

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.collection.mutable

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.util.random.XORShiftRandom

/**
 * :: AlphaComponent ::
 * `FeatureDiscretizer` takes a column with continuous features and outputs a column with binned
 * categorical features.
 */
@AlphaComponent
class FeatureDiscretizer extends Transformer with HasInputCol with HasOutputCol {

  /**
   * Number of bins into which data points are grouped. Must be a positive integer.
   * @group param
   */
  val numBins = new IntParam(this, "numBins",
    "Number of bins into which data points are grouped. Must be a positive integer.")
  setDefault(numBins -> 1)

  /** @group getParam */
  def getNumBins: Int = getOrDefault(numBins)

  /** @group setParam */
  def setNumBins(value: Int): this.type = set(numBins, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = extractParamMap(paramMap)
    assert(map(numBins) >= 1, "Number of bins must be a positive integer.")
    SchemaUtils.checkColumnType(schema, map(inputCol), DoubleType)
    val inputFields = schema.fields
    val outputColName = map(outputCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")
    val attr = NominalAttribute.defaultAttr.withName(outputColName)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap)
    val map = extractParamMap(paramMap)
    val input = dataset.select(map(inputCol)).map { case Row(feature: Double) => feature }
    val samples = getSampledInput(input, map(numBins))
    // `numBins` bins are separated by `numBins - 1` interior split points.
    val splits = findSplits(samples, map(numBins) - 1)
    val discretizer = udf { feature: Double => binarySearchForBins(splits, feature) }
    val outputColName = map(outputCol)
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues(splits.map(_.toString)).toMetadata()
    dataset.select(col("*"),
      discretizer(dataset(map(inputCol))).as(outputColName, metadata))
  }

  /**
   * Binary search among the bins to place a single data point. Returns the zero-based bin
   * index, or -1 if no bin matches, which cannot happen for finite features because the
   * splits are wrapped in Double.MinValue/Double.MaxValue sentinels.
   */
  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var left = 0
    var right = wrappedSplits.length - 2
    while (left <= right) {
      val mid = left + (right - left) / 2
      val split = wrappedSplits(mid)
      if ((feature > split) && (feature <= wrappedSplits(mid + 1))) {
        return mid
      } else if (feature <= split) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    -1
  }

  /**
   * Sample the given dataset to collect quantile statistics.
   */
  private def getSampledInput(dataset: RDD[Double], numBins: Int): Array[Double] = {
    val totalSamples = dataset.count()
    assert(totalSamples > 0, "Dataset must not be empty.")
    val requiredSamples = math.max(numBins * numBins, 10000)
    // Floating-point division is essential here: integer division would truncate the
    // fraction to 0 whenever the dataset holds more than `requiredSamples` points.
    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
    dataset.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
  }

  /**
   * Compute split points with respect to the sample distribution.
   */
  private def findSplits(samples: Array[Double], numSplits: Int): Array[Double] = {
    val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
      m + ((x, m.getOrElse(x, 0) + 1))
    }
    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
    val possibleSplits = valueCounts.length
    if (possibleSplits <= numSplits) {
      // Fewer distinct values than requested splits: every distinct value is a threshold.
      valueCounts.map(_._1)
    } else {
      val stride: Double = samples.length.toDouble / (numSplits + 1)
      val splitsBuilder = mutable.ArrayBuilder.make[Double]
      var index = 1
      // currentCount: sum of counts of the values visited so far.
      var currentCount = valueCounts(0)._2
      // targetCount: target value for `currentCount`. If `currentCount` is the closest
      // value to `targetCount`, the current value is a split threshold. After a split
      // threshold is found, `targetCount` is increased by `stride`.
      var targetCount = stride
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        val previousGap = math.abs(previousCount - targetCount)
        val currentGap = math.abs(currentCount - targetCount)
        // If adding the count of the current value to `currentCount` widens the gap to
        // `targetCount`, the previous value is a split threshold.
        if (previousGap < currentGap) {
          splitsBuilder += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      splitsBuilder.result()
    }
  }
}
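To make the interaction of the two private helpers concrete, here is a minimal, Spark-free sketch that reproduces the stride-based split selection and the sentinel-wrapped binary search on a toy sample. The object name DiscretizerSketch and the sample values are illustrative only, not part of the commit:

object DiscretizerSketch extends App {
  // Toy sample: 12 points; ask for 2 splits, i.e. 3 bins.
  val samples = Array(1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0, 7.0, 8.0)
  val numSplits = 2

  // Stride-based split selection, as in findSplits above.
  val valueCounts = samples.groupBy(identity)
    .map { case (v, vs) => (v, vs.length) }.toSeq.sortBy(_._1).toArray
  val splits: Array[Double] =
    if (valueCounts.length <= numSplits) {
      valueCounts.map(_._1)
    } else {
      val stride = samples.length.toDouble / (numSplits + 1)
      val builder = Array.newBuilder[Double]
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      var index = 1
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // The previous value is a threshold when visiting the current value widens the gap.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          builder += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      builder.result()
    }
  println(splits.mkString(", ")) // 2.0, 5.0 -- roughly the 1/3 and 2/3 sample quantiles

  // Sentinel-wrapped binary search, as in binarySearchForBins above.
  def bin(feature: Double): Int = {
    val wrapped = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var left = 0
    var right = wrapped.length - 2
    while (left <= right) {
      val mid = left + (right - left) / 2
      if (feature > wrapped(mid) && feature <= wrapped(mid + 1)) return mid
      else if (feature <= wrapped(mid)) right = mid - 1
      else left = mid + 1
    }
    -1
  }
  println(samples.map(bin).mkString(", ")) // 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2
}

Each bin is a half-open interval (lower, upper]; the sentinels leave the lowest and highest bins open-ended.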
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}

class FeatureDiscretizerSuite extends FunSuite with MLlibTestSparkContext {

  test("Test feature discretizer") {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // The fixed seed pins down the sample, so the expected bin of each point is known up front.
    val random = new Random(47)
    val data = Array.fill[Double](10)(random.nextDouble())
    val result = Array[Double](2, 1, 0, 0, 1, 1, 1, 0, 2, 2)

    val df = sc.parallelize(data.zip(result)).toDF("data", "expected")

    val discretizer = new FeatureDiscretizer()
      .setInputCol("data")
      .setOutputCol("result")
      .setNumBins(3)

    val res = discretizer.transform(df)
    res.select("expected", "result").collect().foreach {
      case Row(expected: Double, result: Double) => assert(expected == result)
    }

    // The output column's nominal attribute records the split thresholds that were found.
    val attr = Attribute.fromStructField(res.schema("result")).asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("0.18847866977771732", "0.5309454508634242"))
  }
}
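For reference, the two attribute values asserted above are the split thresholds computed from the seeded sample, so with numBins = 3 the rule the transformer effectively applies to each data value can be written out directly. The helper name expectedBin is illustrative, not part of the commit:

// The rule induced by the two thresholds asserted above: bins are the
// half-open intervals (-inf, s1], (s1, s2], and (s2, +inf).
def expectedBin(x: Double): Double =
  if (x <= 0.18847866977771732) 0.0
  else if (x <= 0.5309454508634242) 1.0
  else 2.0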
