Skip to content

Commit b6b9f65

Browse files
committed
Fix
1 parent 5b2b272 commit b6b9f65

File tree

5 files changed

+70
-42
lines changed

5 files changed

+70
-42
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,30 +54,23 @@ object ResolveHints {
5454

5555
def resolver: Resolver = conf.resolver
5656

57-
private def namePartsWithDatabase(nameParts: Seq[String], database: String): Seq[String] = {
58-
if (nameParts.size == 1) {
59-
database +: nameParts
60-
} else {
61-
nameParts
62-
}
63-
}
64-
57+
// Name resolution in hints follows three rules below:
58+
//
59+
// 1. table name matches if the hint table name only has one part
60+
// 2. table name and database name both match if the hint table name has two parts
61+
// 3. no match happens if the hint table name has more than two parts
62+
//
63+
// This means, `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN db2.t` will match both tables, and
64+
// `SELECT /*+ BROADCAST(default.t) */ * FROM t` matches no table.
6565
private def matchedTableIdentifier(
6666
nameParts: Seq[String],
67-
tableIdent: IdentifierWithDatabase): Boolean = {
68-
tableIdent.database match {
69-
case Some(db) if resolver(catalog.globalTempViewManager.database, db) =>
70-
val identifierList = db :: tableIdent.identifier :: Nil
71-
namePartsWithDatabase(nameParts, catalog.globalTempViewManager.database)
72-
.corresponds(identifierList)(resolver)
73-
case None if catalog.getTempView(tableIdent.identifier).isDefined =>
74-
nameParts.size == 1 && resolver(nameParts.head, tableIdent.identifier)
75-
case _ =>
76-
val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase)
77-
val identifierList = db :: tableIdent.identifier :: Nil
78-
namePartsWithDatabase(nameParts, catalog.getCurrentDatabase)
79-
.corresponds(identifierList)(resolver)
80-
}
67+
tableIdent: IdentifierWithDatabase): Boolean = nameParts match {
68+
case Seq(tableName) =>
69+
resolver(tableIdent.identifier, tableName)
70+
case Seq(dbName, tableName) if tableIdent.database.isDefined =>
71+
resolver(tableIdent.database.get, dbName) && resolver(tableIdent.identifier, tableName)
72+
case _ =>
73+
false
8174
}
8275

8376
private def applyBroadcastHint(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class ResolveHintsSuite extends AnalysisTest {
156156
Seq(errMsgRepa))
157157
}
158158

159-
test("Supports multi-part table names for broadcast hint resolution") {
159+
test("supports multi-part table names for broadcast hint resolution") {
160160
// local temp table
161161
checkAnalysis(
162162
UnresolvedHint("MAPJOIN", Seq("table", "table2"),

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 47 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@ package org.apache.spark.sql
1919

2020
import org.apache.spark.sql.catalyst.TableIdentifier
2121
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
22-
import org.apache.spark.sql.catalyst.plans.logical.Join
22+
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
2323
import org.apache.spark.sql.execution.FileSourceScanExec
24-
import org.apache.spark.sql.execution.datasources.LogicalRelation
2524
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
2625
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
2726
import org.apache.spark.sql.functions._
@@ -195,43 +194,78 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
195194
assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1)
196195
}
197196

198-
test("SPARK-25121 Supports multi-part names for broadcast hint resolution") {
197+
test("SPARK-25121 supports multi-part names for broadcast hint resolution") {
199198
val (table1Name, table2Name) = ("t1", "t2")
199+
200200
withTempDatabase { dbName =>
201201
withTable(table1Name, table2Name) {
202202
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
203203
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
204204
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
205+
205206
// First, makes sure a join is not broadcastable
206207
val plan = sql(s"SELECT * FROM $dbName.$table1Name, $dbName.$table2Name " +
207208
s"WHERE $table1Name.id = $table2Name.id")
208209
.queryExecution.executedPlan
209-
assert(plan.collect { case p: BroadcastHashJoinExec => p }.size == 0)
210+
assert(plan.collect { case p: BroadcastHashJoinExec => p }.isEmpty)
210211

211-
// Uses multi-part table names for broadcast hints
212212
def checkIfHintApplied(tableName: String, hintTableName: String): Unit = {
213213
val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
214214
s"FROM $tableName, $dbName.$table2Name " +
215215
s"WHERE $tableName.id = $table2Name.id")
216216
.queryExecution.executedPlan
217-
val broadcastHashJoin = p.collect { case p: BroadcastHashJoinExec => p }
218-
assert(broadcastHashJoin.size == 1)
219-
val broadcastExchange = broadcastHashJoin.head.collect {
217+
val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p }
218+
assert(broadcastHashJoins.size == 1)
219+
val broadcastExchanges = broadcastHashJoins.head.collect {
220220
case p: BroadcastExchangeExec => p
221221
}
222-
assert(broadcastExchange.size == 1)
223-
val table = broadcastExchange.head.collect {
222+
assert(broadcastExchanges.size == 1)
223+
val tables = broadcastExchanges.head.collect {
224224
case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
225225
}
226-
assert(table.size == 1)
227-
assert(table.head === TableIdentifier(table1Name, Some(dbName)))
226+
assert(tables.size == 1)
227+
assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
228+
}
229+
230+
def checkIfHintNotApplied(tableName: String, hintTableName: String): Unit = {
231+
val p = sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
232+
s"FROM $tableName, $dbName.$table2Name " +
233+
s"WHERE $tableName.id = $table2Name.id")
234+
.queryExecution.executedPlan
235+
val broadcastHashJoins = p.collect { case p: BroadcastHashJoinExec => p }
236+
assert(broadcastHashJoins.isEmpty)
228237
}
229238

230239
sql(s"USE $dbName")
231240
checkIfHintApplied(table1Name, table1Name)
232241
checkIfHintApplied(s"$dbName.$table1Name", s"$dbName.$table1Name")
233-
checkIfHintApplied(table1Name, s"$dbName.$table1Name")
234242
checkIfHintApplied(s"$dbName.$table1Name", table1Name)
243+
checkIfHintNotApplied(table1Name, s"$dbName.$table1Name")
244+
checkIfHintNotApplied(s"$dbName.$table1Name", s"$dbName.$table1Name.id")
245+
}
246+
}
247+
}
248+
}
249+
250+
test("SPARK-25121 the same table name exists in two databases for broadcast hint resolution") {
251+
val (db1Name, db2Name) = ("db1", "db2")
252+
253+
withDatabase(db1Name, db2Name) {
254+
withTable("t") {
255+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
256+
sql(s"CREATE DATABASE $db1Name")
257+
sql(s"CREATE DATABASE $db2Name")
258+
spark.range(1).write.saveAsTable(s"$db1Name.t")
259+
spark.range(1).write.saveAsTable(s"$db2Name.t")
260+
261+
// Checks that the broadcast hint is applied to both sides of the join
262+
val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
263+
s"WHERE $db1Name.t.id = $db2Name.t.id"
264+
sql(statement).queryExecution.optimizedPlan match {
265+
case Join(_, _, _, _, JoinHint(Some(leftHint), Some(rightHint))) =>
266+
assert(leftHint.broadcast && rightHint.broadcast)
267+
case _ => fail("broadcast hint not found in both tables")
268+
}
235269
}
236270
}
237271
}

sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
2020
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
2121
import org.apache.spark.sql.catalog.Table
2222
import org.apache.spark.sql.catalyst.TableIdentifier
23-
import org.apache.spark.sql.catalyst.plans.logical.{Join, ResolvedHint}
23+
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
2424
import org.apache.spark.sql.internal.SQLConf
2525
import org.apache.spark.sql.test.SharedSQLContext
2626
import org.apache.spark.sql.types.StructType
@@ -169,9 +169,10 @@ class GlobalTempViewSuite extends QueryTest with SharedSQLContext {
169169
"SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id",
170170
"SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 WHERE v1.id = v2.id"
171171
).foreach { statement =>
172-
val plan = sql(statement).queryExecution.optimizedPlan
173-
assert(plan.asInstanceOf[Join].left.isInstanceOf[ResolvedHint])
174-
assert(!plan.asInstanceOf[Join].right.isInstanceOf[ResolvedHint])
172+
sql(statement).queryExecution.optimizedPlan match {
173+
case Join(_, _, _, _, JoinHint(Some(leftHint), None)) => assert(leftHint.broadcast)
174+
case _ => fail("broadcast hint not found in a left-side table")
175+
}
175176
}
176177
}
177178
}

sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,8 +738,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils {
738738
assert(broadcastData.head.identifier === "tv")
739739

740740
val sparkPlan = df.queryExecution.executedPlan
741-
val broadcastHashJoin = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
742-
assert(broadcastHashJoin.size == 1)
741+
val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
742+
assert(broadcastHashJoins.size == 1)
743743
}
744744
}
745745
}

0 commit comments

Comments
 (0)