@@ -72,6 +72,42 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
     }
   }
 
+  def checkJoin(join: DataFrame, spark: SparkSession): Unit = {
+    // Before Execution, there is one SortMergeJoin
+    val smjBeforeExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjBeforeExecution.length === 1)
+
+    // Check the answer.
+    val expectedAnswer =
+      spark
+        .range(0, 1000)
+        .selectExpr("id % 500 as key", "id as value")
+        .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
+    checkAnswer(
+      join,
+      expectedAnswer.collect())
+
+    // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
+    val smjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: SortMergeJoinExec => smj
+    }
+    assert(smjAfterExecution.length === 0)
+
+    val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
+      case smj: BroadcastHashJoinExec => smj
+    }.length
+    assert(numBhjAfterExecution === 1)
+
+    // Both shuffles should be local shuffles
+    val queryStageInputs = join.queryExecution.executedPlan.collect {
+      case q: ShuffleQueryStageInput => q
+    }
+    assert(queryStageInputs.length === 2)
+    assert(queryStageInputs.forall(_.isLocalShuffle) === true)
+  }
+
   test("1 sort merge join to broadcast join") {
     withSparkSession(defaultSparkSession) { spark: SparkSession =>
       val df1 =
@@ -83,39 +119,12 @@ class QueryStageSuite extends SparkFunSuite with BeforeAndAfterAll {
           .range(0, 1000, 1, numInputPartitions)
           .selectExpr("id % 500 as key2", "id as value2")
 
-      val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
-
-      // Before Execution, there is one SortMergeJoin
-      val smjBeforeExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjBeforeExecution.length === 1)
-
-      // Check the answer.
-      val expectedAnswer =
-        spark
-          .range(0, 1000)
-          .selectExpr("id % 500 as key", "id as value")
-          .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-      checkAnswer(
-        join,
-        expectedAnswer.collect())
-
-      // During execution, the SortMergeJoin is changed to BroadcastHashJoinExec
-      val smjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: SortMergeJoinExec => smj
-      }
-      assert(smjAfterExecution.length === 0)
+      val innerJoin = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2"))
+      checkJoin(innerJoin, spark)
 
-      val numBhjAfterExecution = join.queryExecution.executedPlan.collect {
-        case smj: BroadcastHashJoinExec => smj
-      }.length
-      assert(numBhjAfterExecution === 1)
-
-      val queryStageInputs = join.queryExecution.executedPlan.collect {
-        case q: QueryStageInput => q
-      }
-      assert(queryStageInputs.length === 2)
+      val leftJoin =
+        df1.join(df2, col("key1") === col("key2"), "left").select(col("key1"), col("value1"))
+      checkJoin(leftJoin, spark)
     }
   }
 