@@ -19,9 +19,8 @@ package org.apache.spark.sql
1919
2020import org .apache .spark .sql .catalyst .TableIdentifier
2121import org .apache .spark .sql .catalyst .plans .{Inner , LeftOuter , RightOuter }
22- import org .apache .spark .sql .catalyst .plans .logical .Join
22+ import org .apache .spark .sql .catalyst .plans .logical .{ Join , JoinHint }
2323import org .apache .spark .sql .execution .FileSourceScanExec
24- import org .apache .spark .sql .execution .datasources .LogicalRelation
2524import org .apache .spark .sql .execution .exchange .BroadcastExchangeExec
2625import org .apache .spark .sql .execution .joins .BroadcastHashJoinExec
2726import org .apache .spark .sql .functions ._
@@ -195,43 +194,78 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
195194 assert(plan2.collect { case p : BroadcastHashJoinExec => p }.size == 1 )
196195 }
197196
198- test(" SPARK-25121 Supports multi-part names for broadcast hint resolution" ) {
197+ test(" SPARK-25121 supports multi-part names for broadcast hint resolution" ) {
199198 val (table1Name, table2Name) = (" t1" , " t2" )
199+
200200 withTempDatabase { dbName =>
201201 withTable(table1Name, table2Name) {
202202 withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
203203 spark.range(50 ).write.saveAsTable(s " $dbName. $table1Name" )
204204 spark.range(100 ).write.saveAsTable(s " $dbName. $table2Name" )
205+
205206 // First, makes sure a join is not broadcastable
206207 val plan = sql(s " SELECT * FROM $dbName. $table1Name, $dbName. $table2Name " +
207208 s " WHERE $table1Name.id = $table2Name.id " )
208209 .queryExecution.executedPlan
209- assert(plan.collect { case p : BroadcastHashJoinExec => p }.size == 0 )
210+ assert(plan.collect { case p : BroadcastHashJoinExec => p }.isEmpty )
210211
211- // Uses multi-part table names for broadcast hints
212212 def checkIfHintApplied (tableName : String , hintTableName : String ): Unit = {
213213 val p = sql(s " SELECT /*+ BROADCASTJOIN( $hintTableName) */ * " +
214214 s " FROM $tableName, $dbName. $table2Name " +
215215 s " WHERE $tableName.id = $table2Name.id " )
216216 .queryExecution.executedPlan
217- val broadcastHashJoin = p.collect { case p : BroadcastHashJoinExec => p }
218- assert(broadcastHashJoin .size == 1 )
219- val broadcastExchange = broadcastHashJoin .head.collect {
217+ val broadcastHashJoins = p.collect { case p : BroadcastHashJoinExec => p }
218+ assert(broadcastHashJoins .size == 1 )
219+ val broadcastExchanges = broadcastHashJoins .head.collect {
220220 case p : BroadcastExchangeExec => p
221221 }
222- assert(broadcastExchange .size == 1 )
223- val table = broadcastExchange .head.collect {
222+ assert(broadcastExchanges .size == 1 )
223+ val tables = broadcastExchanges .head.collect {
224224 case FileSourceScanExec (_, _, _, _, _, _, Some (tableIdent)) => tableIdent
225225 }
226- assert(table.size == 1 )
227- assert(table.head === TableIdentifier (table1Name, Some (dbName)))
226+ assert(tables.size == 1 )
227+ assert(tables.head === TableIdentifier (table1Name, Some (dbName)))
228+ }
229+
230+ def checkIfHintNotApplied (tableName : String , hintTableName : String ): Unit = {
231+ val p = sql(s " SELECT /*+ BROADCASTJOIN( $hintTableName) */ * " +
232+ s " FROM $tableName, $dbName. $table2Name " +
233+ s " WHERE $tableName.id = $table2Name.id " )
234+ .queryExecution.executedPlan
235+ val broadcastHashJoins = p.collect { case p : BroadcastHashJoinExec => p }
236+ assert(broadcastHashJoins.isEmpty)
228237 }
229238
230239 sql(s " USE $dbName" )
231240 checkIfHintApplied(table1Name, table1Name)
232241 checkIfHintApplied(s " $dbName. $table1Name" , s " $dbName. $table1Name" )
233- checkIfHintApplied(table1Name, s " $dbName. $table1Name" )
234242 checkIfHintApplied(s " $dbName. $table1Name" , table1Name)
243+ checkIfHintNotApplied(table1Name, s " $dbName. $table1Name" )
244+ checkIfHintNotApplied(s " $dbName. $table1Name" , s " $dbName. $table1Name.id " )
245+ }
246+ }
247+ }
248+ }
249+
250+ test(" SPARK-25121 the same table name exists in two databases for broadcast hint resolution" ) {
251+ val (db1Name, db2Name) = (" db1" , " db2" )
252+
253+ withDatabase(db1Name, db2Name) {
254+ withTable(" t" ) {
255+ withSQLConf(SQLConf .AUTO_BROADCASTJOIN_THRESHOLD .key -> " -1" ) {
256+ sql(s " CREATE DATABASE $db1Name" )
257+ sql(s " CREATE DATABASE $db2Name" )
258+ spark.range(1 ).write.saveAsTable(s " $db1Name.t " )
259+ spark.range(1 ).write.saveAsTable(s " $db2Name.t " )
260+
261+ // Checks if a broadcast hint applied in both sides
262+ val statement = s " SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
263+ s " WHERE $db1Name.t.id = $db2Name.t.id "
264+ sql(statement).queryExecution.optimizedPlan match {
265+ case Join (_, _, _, _, JoinHint (Some (leftHint), Some (rightHint))) =>
266+ assert(leftHint.broadcast && rightHint.broadcast)
267+ case _ => fail(" broadcast hint not found in both tables" )
268+ }
235269 }
236270 }
237271 }
0 commit comments