
Commit 4a9034b

VinceShieh authored and jkbradley committed
[SPARK-17498][ML] StringIndexer enhancement for handling unseen labels
## What changes were proposed in this pull request?

This PR enhances ML `StringIndexer`. Before this PR, `StringIndexer` supported only the "skip" and "error" options for dealing with unseen labels. But those unseen records may still be useful, and in certain use cases users would like to keep the unseen labels. This PR enables `StringIndexer` to keep unseen labels by indexing them at `numLabels`.

Before:

    StringIndexer().setHandleInvalid("skip")
    StringIndexer().setHandleInvalid("error")

After (supports a third option, "keep"):

    StringIndexer().setHandleInvalid("keep")

## How was this patch tested?

Tests added in StringIndexerSuite.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <[email protected]>

Closes #16883 from VinceShieh/spark-17498.
1 parent c05baab commit 4a9034b
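The before/after behavior described in the commit message can be sketched in plain Python. This is a minimal simulation of the `StringIndexer` fit/transform semantics, not Spark's implementation; the helper names `fit_string_indexer` and `transform` are illustrative only, and the tie-breaking order for equal-frequency labels is an assumption.

```python
from collections import Counter

def fit_string_indexer(labels):
    """Assign indices [0, numLabels) ordered by descending frequency.

    Ties are broken alphabetically here (an assumption; Spark does not
    guarantee a particular tie order)."""
    freq = Counter(labels)
    ordered = sorted(freq, key=lambda l: (-freq[l], l))
    return {label: float(i) for i, label in enumerate(ordered)}

def transform(label_to_index, labels, handle_invalid="error"):
    """Map labels to indices under the three handleInvalid strategies."""
    num_labels = len(label_to_index)
    out = []
    for label in labels:
        if label in label_to_index:
            out.append(label_to_index[label])
        elif handle_invalid == "keep":
            out.append(float(num_labels))  # unseen bucket at index numLabels
        elif handle_invalid == "skip":
            continue                        # drop the row entirely
        else:
            raise ValueError(f"Unseen label: {label}.")
    return out
```

For example, fitting on `["a", "b", "b"]` gives `{"b": 0.0, "a": 1.0}`; transforming `["a", "b", "c"]` with `"keep"` yields `[1.0, 0.0, 2.0]`, with `"skip"` yields `[1.0, 0.0]`, and with `"error"` raises on the unseen `"c"`.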

File tree: 4 files changed, +95 -30 lines changed

docs/ml-features.md

Lines changed: 20 additions & 2 deletions
@@ -503,6 +503,7 @@ for more details on the API.
 `StringIndexer` encodes a string column of labels to a column of label indices.
 The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
+The unseen labels will be put at index numLabels if the user chooses to keep them.
 If the input column is numeric, we cast it to string and index the string
 values. When downstream pipeline components such as `Estimator` or
 `Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +543,13 @@ column, we should get the following:
 "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
 index `2`.
 
-Additionally, there are two strategies regarding how `StringIndexer` will handle
+Additionally, there are three strategies regarding how `StringIndexer` will handle
 unseen labels when you have fit a `StringIndexer` on one dataset and then use it
 to transform another:
 
 - throw an exception (which is the default)
 - skip the row containing the unseen label entirely
+- put unseen labels in a special additional bucket, at index numLabels
 
 **Examples**
 
@@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined
  1 | b
  2 | c
  3 | d
+ 4 | e
 ~~~~
 
 If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +579,22 @@ will be generated:
  2 | c | 1.0
 ~~~~
 
-Notice that the row containing "d" does not appear.
+Notice that the rows containing "d" or "e" do not appear.
+
+If you call `setHandleInvalid("keep")`, the following dataset
+will be generated:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0  | a        | 0.0
+ 1  | b        | 2.0
+ 2  | c        | 1.0
+ 3  | d        | 3.0
+ 4  | e        | 3.0
+~~~~
+
+Notice that the rows containing "d" or "e" are mapped to index "3.0".
 
 <div class="codetabs">
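The "keep" example table in the docs diff can be checked with a short script. This is a plain-Python sketch, not Spark; the fitted mapping `{"a": 0.0, "c": 1.0, "b": 2.0}` comes from the earlier docs example, where "a" is the most frequent label.

```python
# Fitted mapping from the docs example: a -> 0.0, c -> 1.0, b -> 2.0.
label_to_index = {"a": 0.0, "c": 1.0, "b": 2.0}
rows = ["a", "b", "c", "d", "e"]

# Under "keep", unseen labels fall into the extra bucket at index numLabels.
num_labels = len(label_to_index)
indexed = [label_to_index.get(label, float(num_labels)) for label in rows]
print(indexed)  # -> [0.0, 2.0, 1.0, 3.0, 3.0]
```

Both "d" and "e" land at index 3.0, matching the generated dataset shown above.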

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 49 additions & 16 deletions
@@ -17,14 +17,16 @@
 
 package org.apache.spark.ml.feature
 
+import scala.language.existentials
+
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
@@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
-  with HasHandleInvalid {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
+   * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special
+   * additional bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("1.6.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
+    "'error' (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
+    "at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+
+  setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+
+  /** @group getParam */
+  @Since("1.6.0")
+  def getHandleInvalid: String = $(handleInvalid)
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") (
   /** @group setParam */
   @Since("1.6.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
 
   /** @group setParam */
   @Since("1.4.0")
@@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") (
 
 @Since("1.6.0")
 object StringIndexer extends DefaultParamsReadable[StringIndexer] {
+  private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
+  private[feature] val ERROR_UNSEEN_LABEL: String = "error"
+  private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)
@@ -144,7 +169,6 @@ class StringIndexerModel (
   /** @group setParam */
   @Since("1.6.0")
   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
 
   /** @group setParam */
   @Since("1.4.0")
@@ -163,25 +187,34 @@ class StringIndexerModel (
     }
     transformSchema(dataset.schema, logging = true)
 
-    val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
-      } else {
-        throw new SparkException(s"Unseen label: $label.")
-      }
+    val filteredLabels = getHandleInvalid match {
+      case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+      case _ => labels
     }
 
     val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(labels).toMetadata()
+      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val filteredDataset = getHandleInvalid match {
-      case "skip" =>
+    val (filteredDataset, keepInvalid) = getHandleInvalid match {
+      case StringIndexer.SKIP_UNSEEN_LABEL =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
-        dataset.where(filterer(dataset($(inputCol))))
-      case _ => dataset
+        (dataset.where(filterer(dataset($(inputCol)))), false)
+      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
     }
+
+    val indexer = udf { label: String =>
+      if (labelToIndex.contains(label)) {
+        labelToIndex(label)
+      } else if (keepInvalid) {
+        labels.length
+      } else {
+        throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+          s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+      }
+    }
+
     filteredDataset.select(col("*"),
       indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
   }
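Two details in the diff above are worth isolating: under "keep", the model appends a `"__unknown"` value to the output column's attribute metadata so the attribute values cover index `numLabels`, and the UDF maps every unseen label to `labels.length`. A minimal Python sketch of that bookkeeping (illustrative helper name, not a Spark API):

```python
def attribute_values(labels, handle_invalid):
    """Attribute values for the output column: 'keep' adds an extra bucket."""
    return labels + ["__unknown"] if handle_invalid == "keep" else list(labels)

def index_of(label, label_to_index, keep_invalid):
    """Mirror of the indexer UDF: known label, unseen bucket, or error."""
    if label in label_to_index:
        return label_to_index[label]
    if keep_invalid:
        return float(len(label_to_index))  # labels.length, i.e. numLabels
    raise ValueError(f"Unseen label: {label}. "
                     "To handle unseen labels, set Param handleInvalid to keep.")
```

Note that the metadata and the UDF stay consistent: the extra `"__unknown"` value sits at exactly the index (`numLabels`) that unseen labels receive.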

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 22 additions & 12 deletions
@@ -64,7 +64,7 @@ class StringIndexerSuite
 
   test("StringIndexerUnseen") {
     val data = Seq((0, "a"), (1, "b"), (4, "b"))
-    val data2 = Seq((0, "a"), (1, "b"), (2, "c"))
+    val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))
     val df = data.toDF("id", "label")
     val df2 = data2.toDF("id", "label")
     val indexer = new StringIndexer()
@@ -75,22 +75,32 @@ class StringIndexerSuite
     intercept[SparkException] {
       indexer.transform(df2).collect()
     }
-    val indexerSkipInvalid = new StringIndexer()
-      .setInputCol("label")
-      .setOutputCol("labelIndex")
-      .setHandleInvalid("skip")
-      .fit(df)
+
+    indexer.setHandleInvalid("skip")
     // Verify that we skip the c and d records
-    val transformed = indexerSkipInvalid.transform(df2)
-    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
+    val transformedSkip = indexer.transform(df2)
+    val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex"))
       .asInstanceOf[NominalAttribute]
-    assert(attr.values.get === Array("b", "a"))
-    val output = transformed.select("id", "labelIndex").rdd.map { r =>
+    assert(attrSkip.values.get === Array("b", "a"))
+    val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
       (r.getInt(0), r.getDouble(1))
     }.collect().toSet
     // a -> 1, b -> 0
-    val expected = Set((0, 1.0), (1, 0.0))
-    assert(output === expected)
+    val expectedSkip = Set((0, 1.0), (1, 0.0))
+    assert(outputSkip === expectedSkip)
+
+    indexer.setHandleInvalid("keep")
+    // Verify that we keep the unseen records
+    val transformedKeep = indexer.transform(df2)
+    val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+    val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+      (r.getInt(0), r.getDouble(1))
+    }.collect().toSet
+    // a -> 1, b -> 0; the unseen c and d both map to numLabels = 2
+    val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0))
+    assert(outputKeep === expectedKeep)
   }
 
   test("StringIndexer with a numeric input column") {
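The suite's expected sets can be re-derived in plain Python (an assumption-laden simulation, not Spark): fitting on `(a, b, b)` makes "b" the most frequent label, so the fitted mapping is `{"b": 0.0, "a": 1.0}`, and both unseen labels in `data2` share the keep-bucket index `2.0`.

```python
# Fitted on (a, b, b): b is most frequent -> 0.0, a -> 1.0.
label_to_index = {"b": 0.0, "a": 1.0}
data2 = [(0, "a"), (1, "b"), (2, "c"), (3, "d")]

# "skip": drop rows with unseen labels entirely.
skip = {(i, label_to_index[l]) for i, l in data2 if l in label_to_index}

# "keep": unseen labels get index numLabels (= 2 here).
keep = {(i, label_to_index.get(l, float(len(label_to_index)))) for i, l in data2}

print(skip)  # -> {(0, 1.0), (1, 0.0)}
print(keep)  # -> {(0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)}
```

This matches `expectedSkip` and `expectedKeep` in the test above, including "c" and "d" collapsing to the same index 2.0.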

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions
@@ -914,6 +914,10 @@ object MimaExcludes {
   ) ++ Seq(
     // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
+  ) ++ Seq(
+    // [SPARK-17498] StringIndexer enhancement for handling unseen labels
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
   ) ++ Seq(
     // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
     ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
