Skip to content

Commit 8b3e3a5

Browse files
mengxr authored and nemccarthy committed
[SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist
This is just a workaround for a bigger problem. Some pipeline stages may not be effective during prediction, and they should not complain about missing required columns, e.g. `StringIndexerModel`.

jkbradley

Author: Xiangrui Meng <[email protected]>

Closes apache#6595 from mengxr/SPARK-8051 and squashes the following commits:

b6a36b9 [Xiangrui Meng] add doc
f143fd4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8051
8ee7c7e [Xiangrui Meng] use SparkFunSuite
e112394 [Xiangrui Meng] make StringIndexerModel silent if input column does not exist
1 parent c58ac66 commit 8b3e3a5

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
8888
/**
8989
* :: Experimental ::
9090
* Model fitted by [[StringIndexer]].
91+
* NOTE: During transformation, if the input column does not exist,
92+
* [[StringIndexerModel.transform]] would return the input dataset unmodified.
93+
* This is a temporary fix for the case when target labels do not exist during prediction.
9194
*/
9295
@Experimental
9396
class StringIndexerModel private[ml] (
@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
112115
def setOutputCol(value: String): this.type = set(outputCol, value)
113116

114117
override def transform(dataset: DataFrame): DataFrame = {
118+
if (!dataset.schema.fieldNames.contains($(inputCol))) {
119+
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
120+
"Skip StringIndexerModel.")
121+
return dataset
122+
}
123+
115124
val indexer = udf { label: String =>
116125
if (labelToIndex.contains(label)) {
117126
labelToIndex(label)
@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
128137
}
129138

130139
override def transformSchema(schema: StructType): StructType = {
131-
validateAndTransformSchema(schema)
140+
if (schema.fieldNames.contains($(inputCol))) {
141+
validateAndTransformSchema(schema)
142+
} else {
143+
// If the input column does not exist during transformation, we skip StringIndexerModel.
144+
schema
145+
}
132146
}
133147
}

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
6060
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
6161
assert(output === expected)
6262
}
63+
64+
test("StringIndexerModel should keep silent if the input column does not exist.") {
65+
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
66+
.setInputCol("label")
67+
.setOutputCol("labelIndex")
68+
val df = sqlContext.range(0L, 10L)
69+
assert(indexerModel.transform(df).eq(df))
70+
}
6371
}

0 commit comments

Comments
 (0)