@@ -895,6 +895,73 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
895895 }
896896 }
897897
898+ test(" No extra exchanges in case of [Inner Join -> Project with aliases -> Inner join]" ) {
899+ withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
900+ withSQLConf(SQLConf .CONSTRAINT_PROPAGATION_ENABLED .key -> " false" ) {
901+ withTempView(" t1" , " t2" , " t3" ) {
902+ spark.range(10 ).repartition($" id" ).createTempView(" t1" )
903+ spark.range(20 ).repartition($" id" ).createTempView(" t2" )
904+ spark.range(30 ).repartition($" id" ).createTempView(" t3" )
905+ val planned = sql(
906+ """
907+ |SELECT t2id, t3.id as t3id
908+ |FROM (
909+ | SELECT t1.id as t1id, t2.id as t2id
910+ | FROM t1, t2
911+ | WHERE t1.id = t2.id
912+ |) t12, t3
913+ |WHERE t1id = t3.id
914+ """ .stripMargin).queryExecution.executedPlan
915+ val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
916+ assert(exchanges.size == 3 )
917+ }
918+ }
919+ }
920+ }
921+
922+ test(" No extra exchanges in case of [LeftSemi Join -> Project with aliases -> Inner join]" ) {
923+ withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
924+ withTempView(" t1" , " t2" , " t3" ) {
925+ spark.range(10 ).repartition($" id" ).createTempView(" t1" )
926+ spark.range(20 ).repartition($" id" ).createTempView(" t2" )
927+ spark.range(30 ).repartition($" id" ).createTempView(" t3" )
928+ val planned = sql(
929+ """
930+ |SELECT t1id, t3.id as t3id
931+ |FROM (
932+ | SELECT t1.id as t1id
933+ | FROM t1 LEFT SEMI JOIN t2
934+ | ON t1.id = t2.id
935+ |) t12 INNER JOIN t3
936+ |WHERE t1id = t3.id
937+ """ .stripMargin).queryExecution.executedPlan
938+ val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
939+ assert(exchanges.size == 3 )
940+ }
941+ }
942+ }
943+
944+ test(" No extra exchanges in case of [Inner Join -> Project with aliases -> HashAggregate]" ) {
945+ withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
946+ withTempView(" t1" , " t2" ) {
947+ spark.range(10 ).repartition($" id" ).createTempView(" t1" )
948+ spark.range(20 ).repartition($" id" ).createTempView(" t2" )
949+ val planned = sql(
950+ """
951+ |SELECT t1id, t2id
952+ |FROM (
953+ | SELECT t1.id as t1id, t2.id as t2id
954+ | FROM t1 INNER JOIN t2
955+ | WHERE t1.id = t2.id
956+ |) t12
957+ |GROUP BY t1id, t2id
958+ """ .stripMargin).queryExecution.executedPlan
959+ val exchanges = planned.collect { case s : ShuffleExchangeExec => s }
960+ assert(exchanges.size == 2 )
961+ }
962+ }
963+ }
964+
898965 test(" aliases to expressions should not be replaced" ) {
899966 withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
900967 withTempView(" df1" , " df2" ) {
// NOTE(review): removed stray web-page residue ("0 commit comments") left over
// from copying this hunk out of a GitHub diff view — it is not part of the file.