Skip to content

Commit 4626614

Browse files
committed
init
init py nit
1 parent f5e3302 commit 4626614

File tree

3 files changed

+148
-120
lines changed

3 files changed

+148
-120
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
3939
* The imputation strategy. Currently only "mean", "median" and "mode" are supported.
4040
* If "mean", then replace missing values using the mean value of the feature.
4141
* If "median", then replace missing values using the approximate median value of the feature.
42+
* If "mode", then replace missing using the most frequent value of the feature.
4243
* Default: mean
4344
*
4445
* @group param
4546
*/
4647
final val strategy: Param[String] = new Param(this, "strategy", s"strategy for imputation. " +
4748
s"If ${Imputer.mean}, then replace missing values using the mean value of the feature. " +
48-
s"If ${Imputer.median}, then replace missing values using the median value of the feature.",
49-
ParamValidators.inArray[String](Array(Imputer.mean, Imputer.median)))
49+
s"If ${Imputer.median}, then replace missing values using the median value of the feature. " +
50+
s"If ${Imputer.mode}, then replace missing values using the most frequent value of " +
51+
s"the feature.", ParamValidators.inArray[String](Imputer.supportedStrategies))
5052

5153
/** @group getParam */
5254
def getStrategy: String = $(strategy)
@@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
104106
* For example, if the input column is IntegerType (1, 2, 4, null),
105107
* the output will be IntegerType (1, 2, 4, 2) after mean imputation.
106108
*
107-
* Note that the mean/median value is computed after filtering out missing values.
109+
* Note that the mean/median/mode value is computed after filtering out missing values.
108110
* All Null values in the input columns are treated as missing, and so are also imputed. For
109111
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
110112
*/
@@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
132134
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
133135

134136
/**
135-
* Imputation strategy. Available options are ["mean", "median"].
137+
* Imputation strategy. Available options are ["mean", "median", "mode"].
136138
* @group setParam
137139
*/
138140
@Since("2.2.0")
@@ -151,39 +153,47 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
151153
val spark = dataset.sparkSession
152154

153155
val (inputColumns, _) = getInOutCols()
154-
155156
val cols = inputColumns.map { inputCol =>
156157
when(col(inputCol).equalTo($(missingValue)), null)
157158
.when(col(inputCol).isNaN, null)
158159
.otherwise(col(inputCol))
159-
.cast("double")
160+
.cast(DoubleType)
160161
.as(inputCol)
161162
}
163+
val numCols = cols.length
162164

163165
val results = $(strategy) match {
164166
case Imputer.mean =>
165167
// Function avg will ignore null automatically.
166168
// For a column only containing null, avg will return null.
167169
val row = dataset.select(cols.map(avg): _*).head()
168-
Array.range(0, inputColumns.length).map { i =>
169-
if (row.isNullAt(i)) {
170-
Double.NaN
171-
} else {
172-
row.getDouble(i)
173-
}
174-
}
170+
Array.tabulate(numCols)(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i))
175171

176172
case Imputer.median =>
177173
// Function approxQuantile will ignore null automatically.
178174
// For a column only containing null, approxQuantile will return an empty array.
179175
dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError))
180-
.map { array =>
181-
if (array.isEmpty) {
182-
Double.NaN
183-
} else {
184-
array.head
185-
}
176+
.map(_.headOption.getOrElse(Double.NaN))
177+
178+
case Imputer.mode =>
179+
val modes = dataset.select(cols: _*).rdd.flatMap { row =>
180+
Iterator.range(0, numCols).flatMap { i =>
181+
// Ignore null.
182+
if (row.isNullAt(i)) Iterator.empty else Iterator.single((i, row.getDouble(i)), 1L)
186183
}
184+
}.reduceByKey(_ + _).map { case ((i, v), c) => (i, (v, c))
185+
}.reduceByKey { case ((v1, c1), (v2, c2)) =>
186+
if (c1 > c2) {
187+
(v1, c1)
188+
} else if (c1 < c2) {
189+
(v2, c2)
190+
} else {
191+
// Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
192+
// If there is more than one mode, choose the smallest one.
193+
(math.min(v1, v2), c1)
194+
}
195+
}.mapValues(_._1).collectAsMap()
196+
Array.tabulate(numCols)(i => modes.getOrElse(i, Double.NaN))
187197
}
188198

189199
val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
@@ -212,6 +222,10 @@ object Imputer extends DefaultParamsReadable[Imputer] {
212222
/** strategy names that Imputer currently supports. */
213223
private[feature] val mean = "mean"
214224
private[feature] val median = "median"
225+
private[feature] val mode = "mode"
226+
227+
/* Array of strategy names that Imputer supports */
228+
private[feature] val supportedStrategies = Array(mean, median, mode)
215229

216230
@Since("2.2.0")
217231
override def load(path: String): Imputer = super.load(path)

0 commit comments

Comments
 (0)