1717
1818package org .apache .spark .sql .catalyst .optimizer
1919
20- import org .apache .spark .sql .catalyst .expressions .{ArrayTransform , CreateNamedStruct , Expression , GetStructField , If , IsNull , LambdaFunction , Literal , MapFromArrays , MapKeys , MapSort , MapValues , NamedLambdaVariable }
21- import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan }
20+ import scala .collection .mutable
21+
22+ import org .apache .spark .sql .catalyst .expressions .{Alias , ArrayTransform , CreateNamedStruct , Expression , GetStructField , If , IsNull , LambdaFunction , Literal , MapFromArrays , MapKeys , MapSort , MapValues , NamedExpression , NamedLambdaVariable }
23+ import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan , Project }
2224import org .apache .spark .sql .catalyst .rules .Rule
23- import org .apache .spark .sql .catalyst .trees .TreePattern . AGGREGATE
25+ import org .apache .spark .sql .catalyst .trees .TreePattern
2426import org .apache .spark .sql .types .{ArrayType , MapType , StructType }
2527import org .apache .spark .util .ArrayImplicits .SparkArrayOps
2628
2729/**
28- * Adds MapSort to group expressions containing map columns, as the key/value paris need to be
30+ * Adds [[ MapSort ]] to group expressions containing map columns, as the key/value paris need to be
2931 * in the correct order before grouping:
30- * SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
31- * SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
32+ *
33+ * SELECT map_column, COUNT(*) FROM TABLE GROUP BY map_column =>
34+ * SELECT _groupingexpression as map_column, COUNT(*) FROM (
35+ * SELECT map_sort(map_column) as _groupingexpression FROM TABLE
36+ * ) GROUP BY _groupingexpression
3237 */
3338object InsertMapSortInGroupingExpressions extends Rule [LogicalPlan ] {
34- override def apply (plan : LogicalPlan ): LogicalPlan = plan.transformWithPruning(
35- _.containsPattern(AGGREGATE ), ruleId) {
36- case a @ Aggregate (groupingExpr, _, _) =>
37- val newGrouping = groupingExpr.map { expr =>
38- if (! expr.exists(_.isInstanceOf [MapSort ])
39- && expr.dataType.existsRecursively(_.isInstanceOf [MapType ])) {
40- insertMapSortRecursively(expr)
41- } else {
42- expr
39+ private def shouldAddMapSort (expr : Expression ): Boolean = {
40+ expr.dataType.existsRecursively(_.isInstanceOf [MapType ])
41+ }
42+
43+ override def apply (plan : LogicalPlan ): LogicalPlan = {
44+ if (! plan.containsPattern(TreePattern .AGGREGATE )) {
45+ return plan
46+ }
47+ val shouldRewrite = plan.exists {
48+ case agg : Aggregate if agg.groupingExpressions.exists(shouldAddMapSort) => true
49+ case _ => false
50+ }
51+ if (! shouldRewrite) {
52+ return plan
53+ }
54+
55+ plan transformUpWithNewOutput {
56+ case agg @ Aggregate (groupingExpr, aggregateExpressions, child)
57+ if agg.groupingExpressions.exists(shouldAddMapSort) =>
58+ val exprToMapSort = new mutable.HashMap [Expression , NamedExpression ]
59+ val newGroupingKeys = groupingExpr.map(replaceWithMapSortRecursively(_, exprToMapSort))
60+ val newAggregateExprs = aggregateExpressions.map {
61+ case named if exprToMapSort.contains(named.canonicalized) =>
62+ // If we replace the top-level named expr, then should add back the original name
63+ exprToMapSort(named.canonicalized).toAttribute.withName(named.name)
64+ case other =>
65+ other.transformUp {
66+ case e => exprToMapSort.get(e.canonicalized).map(_.toAttribute).getOrElse(e)
67+ }.asInstanceOf [NamedExpression ]
4368 }
44- }
45- a.copy(groupingExpressions = newGrouping)
69+ val newChild = Project (child.output ++ exprToMapSort.values, child)
70+ val newAgg = Aggregate (newGroupingKeys, newAggregateExprs, newChild)
71+ newAgg -> agg.output.zip(newAgg.output)
72+ }
4673 }
4774
48- /*
49- Inserts MapSort recursively taking into account when
50- it is nested inside a struct or array.
75+ /**
76+ * Inserts MapSort recursively taking into account when it is nested inside a struct or array.
5177 */
52- private def insertMapSortRecursively (e : Expression ): Expression = {
78+ private def replaceWithMapSortRecursively (
79+ e : Expression ,
80+ exprToMapSort : mutable.HashMap [Expression , NamedExpression ]): Expression = {
5381 e.dataType match {
5482 case m : MapType =>
5583 // Check if value type of MapType contains MapType (possibly nested)
5684 // and special handle this case.
5785 val mapSortExpr = if (m.valueType.existsRecursively(_.isInstanceOf [MapType ])) {
58- MapFromArrays (MapKeys (e), insertMapSortRecursively (MapValues (e)))
86+ MapFromArrays (MapKeys (e), replaceWithMapSortRecursively (MapValues (e), exprToMapSort ))
5987 } else {
6088 e
6189 }
62-
63- MapSort (mapSortExpr)
90+ exprToMapSort.getOrElseUpdate(
91+ e.canonicalized, Alias (MapSort (mapSortExpr), " _groupingexpression" )())
92+ .toAttribute
6493
6594 case StructType (fields)
6695 if fields.exists(_.dataType.existsRecursively(_.isInstanceOf [MapType ])) =>
6796 val struct = CreateNamedStruct (fields.zipWithIndex.flatMap { case (f, i) =>
68- Seq (Literal (f.name), insertMapSortRecursively (
69- GetStructField (e, i, Some (f.name))))
97+ Seq (Literal (f.name), replaceWithMapSortRecursively (
98+ GetStructField (e, i, Some (f.name)), exprToMapSort ))
7099 }.toImmutableArraySeq)
71100 if (struct.valExprs.forall(_.isInstanceOf [GetStructField ])) {
72101 // No field needs MapSort processing, just return the original expression.
@@ -79,12 +108,10 @@ object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
79108
80109 case ArrayType (et, containsNull) if et.existsRecursively(_.isInstanceOf [MapType ]) =>
81110 val param = NamedLambdaVariable (" x" , et, containsNull)
82- val funcBody = insertMapSortRecursively(param)
83-
111+ val funcBody = replaceWithMapSortRecursively(param, exprToMapSort)
84112 ArrayTransform (e, LambdaFunction (funcBody, Seq (param)))
85113
86114 case _ => e
87115 }
88116 }
89-
90117}
0 commit comments