Skip to content

Commit c19680b

Browse files
committed
[SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data
## What changes were proposed in this pull request? This PR is to maintain API parity with changes made in SPARK-17498 to support a new option 'keep' in StringIndexer to handle unseen labels or NULL values with PySpark. Note: This is updated version of apache#17237 , the primary author of this PR is VinceShieh . ## How was this patch tested? Unit tests. Author: VinceShieh <[email protected]> Author: Yanbo Liang <[email protected]> Closes apache#18453 from yanboliang/spark-19852.
1 parent c605fee commit c19680b

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

python/pyspark/ml/feature.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
21322132
"frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
21332133
typeConverter=TypeConverters.toString)
21342134

2135+
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
2136+
"labels or NULL values). Options are 'skip' (filter out rows with " +
2137+
"invalid data), error (throw an error), or 'keep' (put invalid data " +
2138+
"in a special additional bucket, at index numLabels).",
2139+
typeConverter=TypeConverters.toString)
2140+
21352141
@keyword_only
21362142
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
21372143
stringOrderType="frequencyDesc"):

python/pyspark/ml/tests.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self):
551551
for i in range(0, len(expected)):
552552
self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
553553

554+
def test_string_indexer_handle_invalid(self):
555+
df = self.spark.createDataFrame([
556+
(0, "a"),
557+
(1, "d"),
558+
(2, None)], ["id", "label"])
559+
560+
si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep",
561+
stringOrderType="alphabetAsc")
562+
model1 = si1.fit(df)
563+
td1 = model1.transform(df)
564+
actual1 = td1.select("id", "indexed").collect()
565+
expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)]
566+
self.assertEqual(actual1, expected1)
567+
568+
si2 = si1.setHandleInvalid("skip")
569+
model2 = si2.fit(df)
570+
td2 = model2.transform(df)
571+
actual2 = td2.select("id", "indexed").collect()
572+
expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
573+
self.assertEqual(actual2, expected2)
574+
554575

555576
class HasInducedError(Params):
556577

0 commit comments

Comments
 (0)