@@ -39,14 +39,16 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
3939   * The imputation strategy. Currently "mean", "median" and "mode" are supported. 
4040   * If "mean", then replace missing values using the mean value of the feature. 
4141   * If "median", then replace missing values using the approximate median value of the feature. 
42+    * If "mode", then replace missing values using the most frequent value of the feature. 
4243   * Default: mean 
4344   * 
4445   * @group  param 
4546   */  
4647  final  val  strategy :  Param [String ] =  new  Param (this , " strategy" s " strategy for imputation.  "  + 
4748    s " If  ${Imputer .mean}, then replace missing values using the mean value of the feature.  "  + 
48-     s " If  ${Imputer .median}, then replace missing values using the median value of the feature. " ,
49-     ParamValidators .inArray[String ](Array (Imputer .mean, Imputer .median)))
49+     s " If  ${Imputer .median}, then replace missing values using the median value of the feature.  "  + 
50+     s " If  ${Imputer .mode}, then replace missing values using the most frequent value of  "  + 
51+     s " the feature. " , ParamValidators .inArray[String ](Imputer .supportedStrategies))
5052
5153  /**  @group  getParam */  
5254  def  getStrategy :  String  =  $(strategy)
@@ -104,7 +106,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInp
104106 * For example, if the input column is IntegerType (1, 2, 4, null), 
105107 * the output will be IntegerType (1, 2, 4, 2) after mean imputation. 
106108 * 
107-  * Note that the mean/median value is computed after filtering out missing values. 
109+  * Note that the mean/median/mode  value is computed after filtering out missing values. 
108110 * All Null values in the input columns are treated as missing, and so are also imputed. For 
109111 * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. 
110112 */  
@@ -132,7 +134,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
132134  def  setOutputCols (value : Array [String ]):  this .type  =  set(outputCols, value)
133135
134136  /**  
135-    * Imputation strategy. Available options are ["mean", "median"]. 
137+    * Imputation strategy. Available options are ["mean", "median", "mode" ]. 
136138   * @group  setParam 
137139   */  
138140  @ Since (" 2.2.0" 
@@ -151,39 +153,47 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
151153    val  spark  =  dataset.sparkSession
152154
153155    val  (inputColumns, _) =  getInOutCols()
154- 
155156    val  cols  =  inputColumns.map { inputCol => 
156157      when(col(inputCol).equalTo($(missingValue)), null )
157158        .when(col(inputCol).isNaN, null )
158159        .otherwise(col(inputCol))
159-         .cast(" double " 
160+         .cast(DoubleType )
160161        .as(inputCol)
161162    }
163+     val  numCols  =  cols.length
162164
163165    val  results  =  $(strategy) match  {
164166      case  Imputer .mean => 
165167        //  Function avg will ignore null automatically.
166168        //  For a column only containing null, avg will return null.
167169        val  row  =  dataset.select(cols.map(avg):  _* ).head()
168-         Array .range(0 , inputColumns.length).map { i => 
169-           if  (row.isNullAt(i)) {
170-             Double .NaN 
171-           } else  {
172-             row.getDouble(i)
173-           }
174-         }
170+         Array .tabulate(numCols)(i =>  if  (row.isNullAt(i)) Double .NaN  else  row.getDouble(i))
175171
176172      case  Imputer .median => 
177173        //  Function approxQuantile will ignore null automatically.
178174        //  For a column only containing null, approxQuantile will return an empty array.
179175        dataset.select(cols : _* ).stat.approxQuantile(inputColumns, Array (0.5 ), $(relativeError))
180-           .map { array => 
181-             if  (array.isEmpty) {
182-               Double .NaN 
183-             } else  {
184-               array.head
185-             }
176+           .map(_.headOption.getOrElse(Double .NaN ))
177+ 
178+       case  Imputer .mode => 
179+         val  modes  =  dataset.select(cols : _* ).rdd.flatMap { row => 
180+           Iterator .range(0 , numCols).flatMap { i => 
181+             //  Ignore null.
182+             if  (row.isNullAt(i)) Iterator .empty else  Iterator .single((i, row.getDouble(i)), 1L )
186183          }
184+         }.reduceByKey(_ +  _).map { case  ((i, v), c) =>  (i, (v, c))
185+         }.reduceByKey { case  ((v1, c1), (v2, c2)) => 
186+           if  (c1 >  c2) {
187+             (v1, c1)
188+           } else  if  (c1 <  c2) {
189+             (v2, c2)
190+           } else  {
191+             //  Keep in line with sklearn.impute.SimpleImputer (using scipy.stats.mode).
192+             //  If there is more than one mode, choose the smallest one.
193+             (math.min(v1, v2), c1)
194+           }
195+         }.mapValues(_._1).collectAsMap()
196+         Array .tabulate(numCols)(i =>  modes.getOrElse(i, Double .NaN ))
187197    }
188198
189199    val  emptyCols  =  inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
@@ -212,6 +222,10 @@ object Imputer extends DefaultParamsReadable[Imputer] {
212222  /**  strategy names that Imputer currently supports. */  
213223  private [feature] val  mean  =  " mean" 
214224  private [feature] val  median  =  " median" 
225+   private [feature] val  mode  =  " mode" 
226+ 
227+   /*  Set of strategies that Imputer supports */ 
228+   private [feature] val  supportedStrategies  =  Array (mean, median, mode)
215229
216230  @ Since (" 2.2.0" 
217231  override  def  load (path : String ):  Imputer  =  super .load(path)
0 commit comments