@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
2525import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan , Range , Repartition , Sort , Union }
2626import org .apache .spark .sql .catalyst .plans .physical ._
2727import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
28+ import org .apache .spark .sql .execution .aggregate .{HashAggregateExec , ObjectHashAggregateExec , SortAggregateExec }
2829import org .apache .spark .sql .execution .columnar .{InMemoryRelation , InMemoryTableScanExec }
2930import org .apache .spark .sql .execution .exchange .{EnsureRequirements , ReusedExchangeExec , ReuseExchange , ShuffleExchangeExec }
3031import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , SortMergeJoinExec }
@@ -990,11 +991,40 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
990991
991992 val agg1 = t1.groupBy(" k1" ).agg(count(lit(" 1" )).as(" cnt1" ))
992993 val agg2 = t2.groupBy(" k2" ).agg(count(lit(" 1" )).as(" cnt2" )).withColumnRenamed(" k2" , " k3" )
994+
993995 val planned = agg1.join(agg2, $" k1" === $" k3" ).queryExecution.executedPlan
996+
997+ assert(planned.collect { case h : HashAggregateExec => h }.nonEmpty)
998+
994999 val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
9951000 assert(exchanges.size == 2 )
9961001 }
9971002 }
1003+
test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
  // Plans the same join-of-aggregates query twice, once with USE_OBJECT_HASH_AGG enabled and
  // once with it disabled, and checks the shuffle count of the resulting physical plan.
  // AUTO_BROADCASTJOIN_THRESHOLD is set to "-1" for the whole test.
  // NOTE(review): presumably this disables broadcast joins so the join stays shuffle-based;
  // confirm against the Spark SQL configuration docs.
  withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
    Seq(true, false).foreach { useObjectHashAgg =>
      withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
        val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
        val t2 = spark.range(10).selectExpr("floor(id/4) as k2")

        // The second aggregate's grouping column is renamed (k2 -> k3) before the join;
        // that rename is the alias the test name refers to.
        val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
        val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")

        val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan

        // Guard: make sure the aggregate operator under test actually appears in the plan,
        // otherwise the shuffle count below would be checking a different code path.
        if (useObjectHashAgg) {
          assert(
            planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty,
            s"expected an ObjectHashAggregateExec in plan:\n$planned")
        } else {
          assert(
            planned.collect { case s: SortAggregateExec => s }.nonEmpty,
            s"expected a SortAggregateExec in plan:\n$planned")
        }

        // Exactly two shuffle exchanges are expected; per the test name, any more would mean
        // the aliased grouping column introduced an extra shuffle.
        val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
        assert(
          exchanges.size == 2,
          s"expected exactly 2 shuffle exchanges but found ${exchanges.size} in plan:\n$planned")
      }
    }
  }
}
9981028}
9991029
10001030// Used for unit-testing EnsureRequirements
0 commit comments