Skip to content

Commit a38d623

Browse files
author
Davies Liu
committed
embed condition into SMJ and BroadcastHashJoin
1 parent 34dbc8a commit a38d623

File tree

5 files changed

+88
-72
lines changed

5 files changed

+88
-72
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20+
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
2021
import org.apache.spark.sql.{execution, Strategy}
2122
import org.apache.spark.sql.catalyst.InternalRow
2223
import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
7778
*/
7879
object EquiJoinSelection extends Strategy with PredicateHelper {
7980

80-
private[this] def makeBroadcastHashJoin(
81-
leftKeys: Seq[Expression],
82-
rightKeys: Seq[Expression],
83-
left: LogicalPlan,
84-
right: LogicalPlan,
85-
condition: Option[Expression],
86-
side: joins.BuildSide): Seq[SparkPlan] = {
87-
val broadcastHashJoin = execution.joins.BroadcastHashJoin(
88-
leftKeys, rightKeys, side, planLater(left), planLater(right))
89-
condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
90-
}
91-
9281
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
9382

9483
// --- Inner joins --------------------------------------------------------------------------
9584

9685
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
97-
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
86+
joins.BroadcastHashJoin(
87+
leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
9888

9989
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
100-
makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
90+
joins.BroadcastHashJoin(
91+
leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
10192

10293
case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
10394
if RowOrdering.isOrderable(leftKeys) =>
104-
val mergeJoin =
105-
joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
106-
condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
95+
joins.SortMergeJoin(
96+
leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
10797

10898
// --- Outer joins --------------------------------------------------------------------------
10999

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
3939
leftKeys: Seq[Expression],
4040
rightKeys: Seq[Expression],
4141
buildSide: BuildSide,
42+
condition: Option[Expression],
4243
left: SparkPlan,
4344
right: SparkPlan)
4445
extends BinaryNode with HashJoin {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.execution.joins
1919

20+
import java.util.NoSuchElementException
21+
2022
import org.apache.spark.sql.catalyst.InternalRow
2123
import org.apache.spark.sql.catalyst.expressions._
2224
import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
2931
val leftKeys: Seq[Expression]
3032
val rightKeys: Seq[Expression]
3133
val buildSide: BuildSide
34+
val condition: Option[Expression]
3235
val left: SparkPlan
3336
val right: SparkPlan
3437

@@ -50,6 +53,9 @@ trait HashJoin {
5053
protected def streamSideKeyGenerator: Projection =
5154
UnsafeProjection.create(streamedKeys, streamedPlan.output)
5255

56+
@transient private[this] lazy val boundCondition =
57+
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
58+
5359
protected def hashJoin(
5460
streamIter: Iterator[InternalRow],
5561
numStreamRows: LongSQLMetric,
@@ -68,44 +74,50 @@ trait HashJoin {
6874

6975
private[this] val joinKeys = streamSideKeyGenerator
7076

71-
override final def hasNext: Boolean =
72-
(currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
73-
(streamIter.hasNext && fetchNext())
74-
75-
override final def next(): InternalRow = {
76-
val ret = buildSide match {
77-
case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
78-
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
79-
}
80-
currentMatchPosition += 1
81-
numOutputRows += 1
82-
resultProjection(ret)
83-
}
77+
hasNext // find the initial match
78+
79+
override final def hasNext: Boolean = {
80+
while (currentMatchPosition >= 0) {
81+
82+
// check if it's end of current matches
83+
if (currentMatchPosition == currentHashMatches.length) {
84+
currentHashMatches = null
85+
currentMatchPosition = -1
86+
87+
while (currentHashMatches == null && streamIter.hasNext) {
88+
currentStreamedRow = streamIter.next()
89+
numStreamRows += 1
90+
val key = joinKeys(currentStreamedRow)
91+
if (!key.anyNull) {
92+
currentHashMatches = hashedRelation.get(key)
93+
}
94+
}
95+
if (currentHashMatches == null) {
96+
return false
97+
}
98+
currentMatchPosition = 0
99+
}
84100

85-
/**
86-
* Searches the streamed iterator for the next row that has at least one match in hashtable.
87-
*
88-
* @return true if the search is successful, and false if the streamed iterator runs out of
89-
* tuples.
90-
*/
91-
private final def fetchNext(): Boolean = {
92-
currentHashMatches = null
93-
currentMatchPosition = -1
94-
95-
while (currentHashMatches == null && streamIter.hasNext) {
96-
currentStreamedRow = streamIter.next()
97-
numStreamRows += 1
98-
val key = joinKeys(currentStreamedRow)
99-
if (!key.anyNull) {
100-
currentHashMatches = hashedRelation.get(key)
101+
// found some matches
102+
buildSide match {
103+
case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
104+
case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
105+
}
106+
if (boundCondition(joinRow)) {
107+
return true
101108
}
102109
}
110+
false
111+
}
103112

104-
if (currentHashMatches == null) {
105-
false
113+
override final def next(): InternalRow = {
114+
// next() could be called without calling hasNext()
115+
if (hasNext) {
116+
currentMatchPosition += 1
117+
numOutputRows += 1
118+
resultProjection(joinRow)
106119
} else {
107-
currentMatchPosition = 0
108-
true
120+
throw new NoSuchElementException
109121
}
110122
}
111123
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
3232
case class SortMergeJoin(
3333
leftKeys: Seq[Expression],
3434
rightKeys: Seq[Expression],
35+
condition: Option[Expression],
3536
left: SparkPlan,
3637
right: SparkPlan) extends BinaryNode {
3738

@@ -64,6 +65,13 @@ case class SortMergeJoin(
6465
val numOutputRows = longMetric("numOutputRows")
6566

6667
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
68+
val boundCondition: (InternalRow) => Boolean = {
69+
condition.map { cond =>
70+
newPredicate(cond, left.output ++ right.output)
71+
}.getOrElse {
72+
(r: InternalRow) => true
73+
}
74+
}
6775
new RowIterator {
6876
// The projection used to extract keys from input rows of the left child.
6977
private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
8997
private[this] val resultProjection: (InternalRow) => InternalRow =
9098
UnsafeProjection.create(schema)
9199

100+
if (smjScanner.findNextInnerJoinRows()) {
101+
currentRightMatches = smjScanner.getBufferedMatches
102+
currentLeftRow = smjScanner.getStreamedRow
103+
currentMatchIdx = 0
104+
}
105+
92106
override def advanceNext(): Boolean = {
93-
if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
94-
if (smjScanner.findNextInnerJoinRows()) {
95-
currentRightMatches = smjScanner.getBufferedMatches
96-
currentLeftRow = smjScanner.getStreamedRow
97-
currentMatchIdx = 0
98-
} else {
99-
currentRightMatches = null
100-
currentLeftRow = null
101-
currentMatchIdx = -1
107+
while (currentMatchIdx >= 0) {
108+
if (currentMatchIdx == currentRightMatches.length) {
109+
if (smjScanner.findNextInnerJoinRows()) {
110+
currentRightMatches = smjScanner.getBufferedMatches
111+
currentLeftRow = smjScanner.getStreamedRow
112+
currentMatchIdx = 0
113+
} else {
114+
currentRightMatches = null
115+
currentLeftRow = null
116+
currentMatchIdx = -1
117+
return false
118+
}
102119
}
103-
}
104-
if (currentLeftRow != null) {
105120
joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
106121
currentMatchIdx += 1
107-
numOutputRows += 1
108-
true
109-
} else {
110-
false
122+
if (boundCondition(joinRow)) {
123+
numOutputRows += 1
124+
return true
125+
}
111126
}
127+
false
112128
}
113129

114130
override def getRow: InternalRow = resultProjection(joinRow)

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717

1818
package org.apache.spark.sql.execution.joins
1919

20-
import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
2120
import org.apache.spark.sql.catalyst.expressions.Expression
2221
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2322
import org.apache.spark.sql.catalyst.plans.Inner
2423
import org.apache.spark.sql.catalyst.plans.logical.Join
2524
import org.apache.spark.sql.execution._
2625
import org.apache.spark.sql.test.SharedSQLContext
2726
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
27+
import org.apache.spark.sql.{DataFrame, Row, SQLConf}
2828

2929
class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
3030
import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
8888
leftPlan: SparkPlan,
8989
rightPlan: SparkPlan,
9090
side: BuildSide) = {
91-
val broadcastHashJoin =
92-
execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
93-
boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
91+
joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
9492
}
9593

9694
def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
10098
leftPlan: SparkPlan,
10199
rightPlan: SparkPlan) = {
102100
val sortMergeJoin =
103-
execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
104-
val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
105-
EnsureRequirements(sqlContext).apply(filteredJoin)
101+
joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
102+
EnsureRequirements(sqlContext).apply(sortMergeJoin)
106103
}
107104

108105
test(s"$testName using BroadcastHashJoin (build=left)") {

0 commit comments

Comments
 (0)