@@ -86,7 +86,13 @@ object ResolveHints {

     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
-        applyBroadcastHint(h.child, h.parameters.toSet)
+        if (h.parameters.isEmpty) {
+          // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
+          BroadcastHint(h.child)
+        } else {
+          // Otherwise, find within the subtree query plans that should be broadcasted.
+          applyBroadcastHint(h.child, h.parameters.toSet)
+        }
     }
   }

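For review context, a minimal sketch (not part of the diff) of how the two branches above behave: a parameterless hint wraps the entire child subtree in a BroadcastHint, while a parameterized hint only marks the matching relations inside it. The local SparkSession setup, and the assumption that the SQL /*+ BROADCAST(t2) */ hint form reaches this same rule, are illustrative rather than taken from the diff.

import org.apache.spark.sql.SparkSession

object BroadcastHintSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("hint-sketch").getOrCreate()
    import spark.implicits._

    val df1 = Seq((1, "a"), (2, "b")).toDF("id", "v1")
    val df2 = Seq((1, "x"), (2, "y")).toDF("id", "v2")

    // Parameterless hint: Hint("broadcast", Nil, <df2's plan>) is rewritten by the
    // rule above into a BroadcastHint over df2's whole subtree.
    df1.join(df2.hint("broadcast"), "id").explain(true)

    // Parameterized hint via SQL (assumed to feed the same Hint node): only relations
    // matching "t2" are marked, through applyBroadcastHint(child, Set("t2")).
    df1.createOrReplaceTempView("t1")
    df2.createOrReplaceTempView("t2")
    spark.sql("SELECT /*+ BROADCAST(t2) */ * FROM t1 JOIN t2 ON t1.id = t2.id").explain(true)

    spark.stop()
  }
}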
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (16 additions, 0 deletions)
@@ -1160,6 +1160,22 @@ class Dataset[T] private[sql](
    */
   def apply(colName: String): Column = col(colName)
 
+  /**
+   * Specifies some hint on the current Dataset. As an example, the following code specifies
+   * that one of the plans can be broadcast:
+   *
+   * {{{
+   *   df1.join(df2.hint("broadcast"))
+   * }}}
+   *
+   * @group basic
+   * @since 2.2.0
+   */
+  @scala.annotation.varargs
+  def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan {
+    Hint(name, parameters, logicalPlan)
+  }
+
   /**
    * Selects column based on the column name and return it as a [[Column]].
    *
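A hedged usage sketch for the new API (everything except hint itself is illustrative): hint returns a Dataset[T], so it composes with the rest of the fluent API, and the name is matched case-insensitively by the analyzer rule above. Printing the analyzed plan should show the hint surviving analysis as a BroadcastHint node before the planner turns it into a broadcast join; the exact plan output is an assumption.

import org.apache.spark.sql.SparkSession

object DatasetHintWalkthrough {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dataset-hint").getOrCreate()

    val big = spark.range(1000000L)
    val small = spark.range(100L)

    // Case does not matter: "BROADCAST", "broadcast", etc. resolve to the same hint,
    // since the rule uppercases the name before checking it.
    val joined = big.join(small.hint("BROADCAST"), "id")

    // Expect a BroadcastHint node in the analyzed plan and a broadcast join in the
    // physical plan (assumed plan shapes, based on the rule and test in this diff).
    println(joined.queryExecution.analyzed)
    joined.explain()

    spark.stop()
  }
}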
@@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
   }
 
-  test("broadcast join hint") {
+  test("broadcast join hint using broadcast function") {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value")
 
@@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("broadcast join hint using Dataset.hint") {
+    // make sure a giant join is not broadcastable
+    val plan1 =
+      spark.range(10e10.toLong)
+        .join(spark.range(10e10.toLong), "id")
+        .queryExecution.executedPlan
+    assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0)
+
+    // now with a hint it should be broadcasted
+    val plan2 =
+      spark.range(10e10.toLong)
+        .join(spark.range(10e10.toLong).hint("broadcast"), "id")
+        .queryExecution.executedPlan
+    assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1)
+  }
+
   test("join - outer join conversion") {
     val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a")
     val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b")
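A possible follow-up check, sketched here rather than taken from the PR: inside DataFrameJoinSuite, the same hint can be asserted one step earlier, on the analyzed plan, where ResolveHints should have rewritten the Hint node into a BroadcastHint. The import path for BroadcastHint is an assumption.

import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint

test("broadcast hint appears in the analyzed plan") {
  val hinted = spark.range(10).hint("broadcast")
  val analyzed = hinted.queryExecution.analyzed
  // ResolveHints runs during analysis, so the BroadcastHint node should already be present here.
  assert(analyzed.collect { case b: BroadcastHint => b }.nonEmpty)
}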