Skip to content

Commit 676f1ac

Browse files
committed
Add configurable maximum number of pivot values when none are given to prevent unintended OOM errors.
1 parent 12a8270 commit 676f1ac

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,23 @@ class GroupedData protected[sql](
309309
s"The values of a pivot must be literals, found $other")
310310
}
311311
} else {
312+
// This is to prevent unintended OOM errors when the number of distinct values is large
313+
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
312314
// Get the distinct values of the column and sort them so it's consistent
313-
df.select(pivotColumn)
315+
val values = df.select(pivotColumn)
314316
.distinct()
315317
.sort(pivotColumn)
316318
.map(_.get(0))
317-
.collect()
319+
.take(maxValues + 1)
318320
.map(Literal(_)).toSeq
321+
if (values.length > maxValues) {
322+
throw new RuntimeException(
323+
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
324+
"this could indicate an error. " +
325+
"If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
326+
s"to at least the number of distinct values of the pivot column.")
327+
}
328+
values
319329
}
320330
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
321331
case _ =>

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,13 @@ private[spark] object SQLConf {
453453
defaultValue = Some(true),
454454
isPublic = false)
455455

456+
val DATAFRAME_PIVOT_MAX_VALUES = intConf(
457+
"spark.sql.pivotMaxValues",
458+
defaultValue = Some(10000),
459+
doc = "When doing a pivot without specifying values for the pivot column, this is the maximum " +
460+
"number of (distinct) values that will be collected without error."
461+
)
462+
456463
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
457464
defaultValue = Some(true),
458465
isPublic = false,

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,13 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
7575
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
7676
)
7777
}
78+
79+
test("pivot max values enforced") {
80+
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
81+
intercept[RuntimeException](
82+
courseSales.groupBy($"year").pivot($"course")
83+
)
84+
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
85+
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
86+
}
7887
}

0 commit comments

Comments
 (0)