
Commit 94ace2d

Rework group by map type to fix bind reference exception
Parent: af70aaf
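
In query terms, the failure mode being fixed: the old rule rewrote a grouping expression `m` into `map_sort(m)` in place, so an Aggregate that also selected `m` could be left with output expressions that no longer lined up with the rewritten grouping key, surfacing as the bind reference exception this commit's title names. The rework instead computes the sorted map once, under an alias, in a Project beneath the Aggregate, and points both the grouping key and the output at that alias. A minimal sketch of the query shape involved (hypothetical repro, assuming a spark-shell session with `spark` and its implicits in scope):

    import org.apache.spark.sql.functions.count
    import spark.implicits._

    // Group by a map column while also selecting it: the grouping key must
    // become map_sort(m), yet the SELECT list still has to produce `m`.
    val df = Seq(Map(1 -> 1), Map(1 -> 1), Map(2 -> 2)).toDF("m")
    df.groupBy($"m").agg(count("*")).show()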

File tree: 4 files changed, +119 -49 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala

Lines changed: 55 additions & 20 deletions
@@ -17,37 +17,72 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedLambdaVariable}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, CreateNamedStruct, Expression, GetStructField, If, IsNull, LambdaFunction, Literal, MapFromArrays, MapKeys, MapSort, MapValues, NamedExpression, NamedLambdaVariable}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.types.{ArrayType, MapType, StructType}
 import org.apache.spark.util.ArrayImplicits.SparkArrayOps
 
 /**
- * Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
+ * Adds [[MapSort]] to group expressions containing map columns, as the key/value pairs need to be
  * in the correct order before grouping:
- *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
- *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ *
+ *   SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
+ *   SELECT _groupingmapsort as map_column, COUNT(*) FROM (
+ *     SELECT map_sort(map_column) as _groupingmapsort FROM TABLE
+ *   ) GROUP BY _groupingmapsort
  */
 object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
-    _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(groupingExpr, _, _) =>
-      val newGrouping = groupingExpr.map { expr =>
-        if (!expr.exists(_.isInstanceOf[MapSort])
-          && expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          insertMapSortRecursively(expr)
-        } else {
-          expr
+  private def shouldAddMapSort(expr: Expression): Boolean = {
+    expr.dataType.existsRecursively(_.isInstanceOf[MapType])
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(TreePattern.AGGREGATE)) {
+      return plan
+    }
+    val shouldRewrite = plan.exists {
+      case agg: Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
+      case _ => false
+    }
+    if (!shouldRewrite) {
+      return plan
+    }
+
+    plan transformUpWithNewOutput {
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
+          if agg.groupingExpressions.exists(shouldAddMapSort) =>
+        val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
+        val newGroupingKeys = groupingExprs.map { expr =>
+          val inserted = insertMapSortRecursively(expr)
+          if (expr.ne(inserted)) {
+            exprToMapSort.getOrElseUpdate(
+              expr.canonicalized, Alias(inserted, "_groupingmapsort")())
+              .toAttribute
+          } else {
+            expr
+          }
         }
-      }
-      a.copy(groupingExpressions = newGrouping)
+        val newAggregateExprs = aggregateExpressions.map {
+          case named if exprToMapSort.contains(named.canonicalized) =>
+            // If we replace the top-level named expr, then should add back the original name
+            exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
+          case other =>
+            other.transformUp {
+              case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
+            }.asInstanceOf[NamedExpression]
+        }
+        val newChild = Project(child.output ++ exprToMapSort.values, child)
+        val newAgg = Aggregate(newGroupingKeys, newAggregateExprs, newChild)
+        newAgg -> agg.output.zip(newAgg.output)
+    }
   }
 
-  /*
-  Inserts MapSort recursively taking into account when
-  it is nested inside a struct or array.
+  /**
+   * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
    */
   private def insertMapSortRecursively(e: Expression): Expression = {
     e.dataType match {
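
The core of the rework is the `exprToMapSort` map above: each grouping expression that actually changed gets exactly one Alias, keyed by its canonicalized form, so repeated occurrences across the grouping keys and the aggregate expressions all resolve to the same attribute, and `map_sort` is evaluated once in the child Project. A standalone sketch of that dedup pattern, with plain strings standing in for canonicalized expressions (names illustrative only):

    import scala.collection.mutable

    // Lazily create one alias per distinct canonical expression and reuse it.
    val exprToAlias = new mutable.HashMap[String, String]
    def aliasFor(canonical: String): String =
      exprToAlias.getOrElseUpdate(canonical, s"_groupingmapsort_${exprToAlias.size}")

    println(aliasFor("map_sort(m)"))  // allocates _groupingmapsort_0
    println(aliasFor("map_sort(m)"))  // reuses _groupingmapsort_0, nothing recomputed
    println(aliasFor("map_sort(n)"))  // allocates _groupingmapsort_1

Note that the real rule names every alias "_groupingmapsort"; the distinct expression IDs minted by `Alias(...)()`, not distinct names, keep them apart.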

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 3 deletions
@@ -150,7 +150,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   }
 
   val batches = (
-    Batch("Finish Analysis", Once, FinishAnalysis) ::
+    Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
     // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
     Batch("Rewrite With expression", Once, RewriteWithExpression) ::
@@ -246,8 +246,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
       CollapseProject,
       RemoveRedundantAliases,
       RemoveNoopOperators) :+
-    Batch("InsertMapSortInGroupingExpressions", Once,
-      InsertMapSortInGroupingExpressions) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
     Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -297,6 +295,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       ReplaceExpressions,
       RewriteNonCorrelatedExists,
       PullOutGroupingExpressions,
+      InsertMapSortInGroupingExpressions,
       ComputeCurrentTime,
       ReplaceCurrentLike(catalogManager),
       SpecialDatetimeValues,
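
Why `FixedPoint(1)` rather than `Once` for the Finish Analysis batch: Catalyst's RuleExecutor additionally asserts, under testing, that `Once` batches are idempotent, and the reworked rule mints fresh `Alias` expression IDs on every application, so running the batch twice yields an equivalent but not identical plan. `FixedPoint(1)` still executes the batch exactly once, just without that assertion. A toy model of the distinction (illustrative only, not the actual Catalyst strategy classes):

    // A strategy is a max iteration count plus whether idempotence is checked.
    sealed trait Strategy { def maxIterations: Int; def checkIdempotence: Boolean }
    case object Once extends Strategy {
      val maxIterations = 1
      val checkIdempotence = true   // batch must be a no-op the second time
    }
    final case class FixedPoint(maxIterations: Int) extends Strategy {
      val checkIdempotence = false  // just stop after maxIterations
    }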

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
       "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
-      "org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 62 additions & 25 deletions
@@ -2162,8 +2162,9 @@ class DataFrameAggregateSuite extends QueryTest
     )
   }
 
-  private def assertAggregateOnDataframe(df: DataFrame,
-      expected: Int, aggregateColumn: String): Unit = {
+  private def assertAggregateOnDataframe(
+      df: => DataFrame,
+      expected: Int): Unit = {
     val configurations = Seq(
       Seq.empty[(String, String)], // hash aggregate is used by default
       Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN",
@@ -2175,32 +2176,64 @@ class DataFrameAggregateSuite extends QueryTest
       Seq("spark.sql.test.forceApplySortAggregate" -> "true")
     )
 
-    for (conf <- configurations) {
-      withSQLConf(conf: _*) {
-        assert(createAggregate(df).count() == expected)
+    // Make tests faster
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
+      for (conf <- configurations) {
+        withSQLConf(conf: _*) {
+          assert(df.count() == expected, df.queryExecution.simpleString)
+        }
       }
     }
-
-    def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*"))
   }
 
   test("SPARK-47430 Support GROUP BY MapType") {
-    val numRows = 50
-
-    val dfSameInt = (0 until numRows)
-      .map(_ => Tuple1(Map(1 -> 1)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameInt, 1, "m0")
-
-    val dfSameFloat = (0 until numRows)
-      .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 )))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfSameFloat, 1, "m0")
-
-    val dfDifferent = (0 until numRows)
-      .map(i => Tuple1(Map(i -> i)))
-      .toDF("m0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "m0")
+    def genMapData(dataType: String): String = {
+      s"""
+         |case when id % 4 == 0 then map()
+         |when id % 4 == 1 then map(cast(0 as $dataType), cast(0 as $dataType))
+         |when id % 4 == 2 then map(cast(0 as $dataType), cast(0 as $dataType),
+         |  cast(1 as $dataType), cast(1 as $dataType))
+         |else map(cast(1 as $dataType), cast(1 as $dataType),
+         |  cast(0 as $dataType), cast(0 as $dataType))
+         |end
+         |""".stripMargin
+    }
+    Seq("int", "long", "float", "double", "decimal(10, 2)", "string", "varchar(6)").foreach { dt =>
+      withTempView("v") {
+        spark.range(20)
+          .selectExpr(
+            s"cast(1 as $dt) as c1",
+            s"${genMapData(dt)} as c2",
+            "map(c1, null) as c3",
+            s"cast(null as map<$dt, $dt>) as c4")
+          .createOrReplaceTempView("v")
+
+        assertAggregateOnDataframe(
+          spark.sql("SELECT count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c2, count(*) FROM v GROUP BY c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, count(*) FROM v GROUP BY c1, c2"),
+          3)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT map(c1, c1), count(*) FROM v GROUP BY map(c1, c1)"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c3, count(*) FROM v GROUP BY c3"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c4, count(*) FROM v GROUP BY c4"),
+          1)
+        assertAggregateOnDataframe(
+          spark.sql("SELECT c1, c2, c3, c4, count(*) FROM v GROUP BY c1, c2, c3, c4"),
+          3)
+      }
+    }
   }
 
   test("SPARK-46536 Support GROUP BY CalendarIntervalType") {
@@ -2209,12 +2242,16 @@ class DataFrameAggregateSuite extends QueryTest
     val dfSame = (0 until numRows)
      .map(_ => Tuple1(new CalendarInterval(1, 2, 3)))
      .toDF("c0")
-    assertAggregateOnDataframe(dfSame, 1, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfSame, 1)
 
     val dfDifferent = (0 until numRows)
       .map(i => Tuple1(new CalendarInterval(i, i, i)))
       .toDF("c0")
-    assertAggregateOnDataframe(dfDifferent, numRows, "c0")
+      .groupBy($"c0")
+      .count()
+    assertAggregateOnDataframe(dfDifferent, numRows)
   }
 
   test("SPARK-46779: Group by subquery with a cached relation") {

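Finally, a usage sketch of the behavior the suite above verifies, as one might try it in a spark-shell built from this commit (assumes `spark` in scope; expected output is inferred from the test expectations, not captured here):

    // Group by a map column while also selecting it: the shape the old
    // rule mishandled.
    spark.sql(
      """SELECT m, count(*) AS cnt
        |FROM (SELECT map(id % 3, 1) AS m FROM range(9))
        |GROUP BY m""".stripMargin).show()
    // Expected: 3 rows, one per distinct map, each with cnt = 3.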