
Commit 7eca60d

EnricoMi authored and cloud-fan committed
[SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations
### What changes were proposed in this pull request?

Backport #39131 to branch-3.3.

Rule `PushDownLeftSemiAntiJoin` should not push an anti-join below an `Aggregate` when the join condition references an attribute that exists in its right plan and its left plan's child. This usually happens when the anti-join / semi-join is a self-join while `DeduplicateRelations` cannot deduplicate those attributes (in this example due to the projection of `value` to `id`).

This behaviour already exists for `Project` and `Union`, but `Aggregate` lacks this safety guard.

### Why are the changes needed?

Without this change, the optimizer creates an incorrect plan.

This example fails with `distinct()` (an aggregation) and succeeds without `distinct()`, although both queries are semantically equivalent:

```scala
// Run in spark-shell, where spark.implicits._ is already in scope.
val ids = Seq(1, 2, 3).toDF("id").distinct()
val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
assert(result.length == 1)
```

With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates a join condition `(value#907 + 1) = value#907`, which can never be true. This effectively removes the anti-join.

**Before this PR:**

The anti-join is fully removed from the plan.
```
== Physical Plan ==
AdaptiveSparkPlan (16)
+- == Final Plan ==
   LocalTableScan (1)

(16) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

This is caused by `PushDownLeftSemiAntiJoin` adding join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists in the right child of the join as well as in the left grandchild:

```
=== Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
!Join LeftAnti, (id#912 = id#910)                  Aggregate [id#910], [(id#910 + 1) AS id#912]
!:- Aggregate [id#910], [(id#910 + 1) AS id#912]   +- Project [value#907 AS id#910]
!:  +- Project [value#907 AS id#910]                  +- Join LeftAnti, ((value#907 + 1) = value#907)
!:     +- LocalRelation [value#907]                      :- LocalRelation [value#907]
!+- Aggregate [id#910], [id#910]                         +- Aggregate [id#910], [id#910]
!   +- Project [value#914 AS id#910]                        +- Project [value#914 AS id#910]
!      +- LocalRelation [value#914]                            +- LocalRelation [value#914]
```

The right child of the join and the left grandchild would become the children of the pushed-down join, which creates an invalid join condition.

**After this PR:**

The join condition `(id#912 = id#910)`, where `id#912` is defined as `(id#910 + 1) AS id#912`, is recognized as ambiguous because both sides of the prospective join contain `id#910`. Hence, the join is not pushed down and the rule is no longer applied.
The final plan contains the anti-join:

```
== Physical Plan ==
AdaptiveSparkPlan (24)
+- == Final Plan ==
   * BroadcastHashJoin LeftSemi BuildRight (14)
   :- * HashAggregate (7)
   :  +- AQEShuffleRead (6)
   :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
   :        +- Exchange (4)
   :           +- * HashAggregate (3)
   :              +- * Project (2)
   :                 +- * LocalTableScan (1)
   +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
      +- BroadcastExchange (12)
         +- * HashAggregate (11)
            +- AQEShuffleRead (10)
               +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                  +- ReusedExchange (8)

(8) ReusedExchange [Reuses operator id: 4]
Output [1]: [id#898]

(24) AdaptiveSparkPlan
Output [1]: [id#900]
Arguments: isFinalPlan=true
```

### Does this PR introduce _any_ user-facing change?

It fixes correctness.

### How was this patch tested?

Unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.

Closes #39409 from EnricoMi/branch-antijoin-selfjoin-fix-3.3.

Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit b97f79d)
Signed-off-by: Wenchen Fan <[email protected]>
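To make the origin of the impossible condition `(value#907 + 1) = value#907` easier to follow, here is a minimal sketch in plain Scala. It does not use Catalyst: `Expr`, `Attr`, `Lit`, `Add`, `EqualTo` and `substitute` are hypothetical stand-ins, and the two-step alias substitution is only an assumed approximation of what the rule does when it moves the condition below the `Aggregate` and the `Project`.

```scala
// Toy expression tree; these types are illustrative and are NOT Catalyst classes.
sealed trait Expr
final case class Attr(name: String, id: Int) extends Expr
final case class Lit(value: Int) extends Expr
final case class Add(left: Expr, right: Expr) extends Expr
final case class EqualTo(left: Expr, right: Expr) extends Expr

object AliasSubstitution {
  // Replace every attribute that has an alias definition with that definition.
  def substitute(e: Expr, aliases: Map[Attr, Expr]): Expr = e match {
    case a: Attr       => aliases.getOrElse(a, a)
    case l: Lit        => l
    case Add(l, r)     => Add(substitute(l, aliases), substitute(r, aliases))
    case EqualTo(l, r) => EqualTo(substitute(l, aliases), substitute(r, aliases))
  }

  // Render an expression in the name#id notation used by the plans above.
  def render(e: Expr): String = e match {
    case Attr(n, i)    => s"$n#$i"
    case Lit(v)        => v.toString
    case Add(l, r)     => s"(${render(l)} + ${render(r)})"
    case EqualTo(l, r) => s"(${render(l)} = ${render(r)})"
  }

  val id910: Attr = Attr("id", 910)
  val id912: Attr = Attr("id", 912)
  val value907: Attr = Attr("value", 907)

  // Original join condition: id#912 comes from the left leg, id#910 from the right leg.
  val joinCond: Expr = EqualTo(id912, id910)

  // Aliases defined by the left Aggregate and by the Project below it.
  val aggregateAliases: Map[Attr, Expr] = Map(id912 -> Add(id910, Lit(1)))
  val projectAliases: Map[Attr, Expr] = Map(id910 -> value907)

  // Moving the condition below the Aggregate, then below the Project. The right-hand
  // id#910 is rewritten too, because it cannot be told apart from the left leg's id#910.
  val belowAggregate: Expr = substitute(joinCond, aggregateAliases)
  val belowProject: Expr = substitute(belowAggregate, projectAliases)

  def main(args: Array[String]): Unit = {
    println(render(belowAggregate))
    println(render(belowProject))
  }
}
```

In this model the right-hand side of the condition ends up referring to the left leg's `value#907`, so the condition no longer compares the two legs at all.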
1 parent 0f5e231 commit 7eca60d

3 files changed: +63 -25 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala

Lines changed: 7 additions & 6 deletions
@@ -56,9 +56,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
       }

     // LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
-    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
       if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
         !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+        canPushThroughCondition(agg.children, joinCond, rightOp) &&
         canPlanAsBroadcastHashJoin(join, conf) =>
       val aliasMap = getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
@@ -105,11 +106,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
   }

   /**
-   * Check if we can safely push a join through a project or union by making sure that attributes
-   * referred in join condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when both sides of join refers to the same source (self join). This
-   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
-   * present in both the legs of the join) or else the resultant plan will be invalid.
+   * Check if we can safely push a join through a project, aggregate, or union by making sure that
+   * attributes referred in join condition do not contain the same attributes as the plan they are
+   * moved into. This can happen when both sides of join refers to the same source (self join).
+   * This function makes sure that the join condition refers to attributes that are not ambiguous
+   * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
    */
   private def canPushThroughCondition(
       plans: Seq[LogicalPlan],
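The body of `canPushThroughCondition` is not part of this diff, so the following is only a sketch of the check its doc comment describes, written against plain Scala sets rather than Catalyst types. `Attr`, the parameter names and the example attributes are assumptions made for illustration.

```scala
object PushdownGuardSketch {
  // Hypothetical stand-in for a Catalyst attribute: a name plus a unique expression id.
  final case class Attr(name: String, exprId: Long)

  // Push-down is treated as safe only if no attribute referenced by the join condition
  // is produced both by the right leg of the join and by the plans the join would be
  // moved into. A join without a condition is always treated as safe.
  def canPushThroughCondition(
      targetOutput: Set[Attr],
      conditionRefs: Option[Set[Attr]],
      rightOutput: Set[Attr]): Boolean =
    conditionRefs.forall { refs =>
      refs.intersect(rightOutput).intersect(targetOutput).isEmpty
    }

  // Self-join from the commit message: id#910 is produced by the aggregate's child
  // and by the right leg, and the condition (id#912 = id#910) references it.
  val id910: Attr = Attr("id", 910L)
  val id912: Attr = Attr("id", 912L)
  val selfJoinAllowed: Boolean = canPushThroughCondition(
    targetOutput = Set(id910),
    conditionRefs = Some(Set(id912, id910)),
    rightOutput = Set(id910))

  // Two unrelated relations, as in the existing tests: condition ('b === 'd).
  val b: Attr = Attr("b", 2L)
  val c: Attr = Attr("c", 3L)
  val d: Attr = Attr("d", 4L)
  val distinctRelationsAllowed: Boolean = canPushThroughCondition(
    targetOutput = Set(b, c),
    conditionRefs = Some(Set(b, d)),
    rightOutput = Set(d))

  def main(args: Array[String]): Unit = {
    println(s"self-join push-down allowed: $selfJoinAllowed")
    println(s"distinct relations push-down allowed: $distinctRelationsAllowed")
  }
}
```

Under these assumptions the self-join case is rejected while the join between unrelated relations is still pushed down, which matches the behaviour the new guard on the `Aggregate` case is meant to add.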

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala

Lines changed: 38 additions & 19 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType

-class LeftSemiPushdownSuite extends PlanTest {
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {

   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
   val testRelation1 = LocalRelation('d.int)
   val testRelation2 = LocalRelation('e.int)

-  test("Project: LeftSemiAnti join pushdown") {
+  test("Project: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+  test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
     val originalQuery = testRelation
       .select(Rand(1), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Project: LeftSemiAnti join non correlated scalar subq") {
+  test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+  test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
     val testRelation2 = LocalRelation('e.int, 'f.int)
     val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
     val subqExpr = ScalarSubquery(subqPlan)
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Aggregate: LeftSemiAnti join pushdown") {
+  test("Aggregate: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+  test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
     val originalQuery = testRelation
       .groupBy('b)('b, Rand(10).as('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("LeftSemiAnti join over aggregate - no pushdown") {
+  test("Aggregate: LeftSemi join no pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+  test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("LeftSemiAnti join over Window") {
+  test("Window: LeftSemi join pushdown") {
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))

     val originalQuery = testRelation
@@ -184,7 +184,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Window: LeftSemi partial pushdown") {
+  test("Window: LeftSemi join partial pushdown") {
     // Attributes from join condition which does not refer to the window partition spec
     // are kept up in the plan as a Filter operator above Window.
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -224,7 +224,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Union: LeftSemiAnti join pushdown") {
+  test("Union: LeftSemi join pushdown") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)

     val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -240,7 +240,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Union: LeftSemiAnti join pushdown in self join scenario") {
+  test("Union: LeftSemi join pushdown in self join scenario") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
     val attrX = testRelation2.output.head

@@ -259,7 +259,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemiAnti join pushdown") {
+  test("Unary: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -274,7 +274,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+  test("Unary: LeftSemi join pushdown - empty join condition") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -289,7 +289,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftSemi join pushdown - partial pushdown") {
+  test("Unary: LeftSemi join partial pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

-  test("Unary: LeftAnti join pushdown - no pushdown") {
+  test("Unary: LeftAnti join no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+  test("Unary: LeftSemi join - no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -325,7 +325,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }

-  test("Unary: LeftSemi join push down through Expand") {
+  test("Unary: LeftSemi join pushdown through Expand") {
     val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
       Seq('a, 'b, 'c), testRelation)
     val originalQuery = expand
@@ -431,6 +431,25 @@ class LeftSemiPushdownSuite extends PlanTest {
     }
   }

+  Seq(LeftSemi, LeftAnti).foreach { case jt =>
+    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+      val aggregation = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)('id, sum('c).as("sum"))
+
+      // reference "b" exists in left leg, and the children of the right leg of the join
+      val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
     Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
       test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }

+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1  // unsupported test type, test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)
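The expected row counts in the SPARK-41162 test (2 for `left_semi`, 1 for `left_anti`) can be checked by hand with ordinary Scala collections. The sketch below models the joins as set membership and does not involve Spark; `SelfJoinSemantics` and its members are names invented for this illustration, and `brokenAnti` is an assumed model of the incorrectly pushed-down plan described in the commit message.

```scala
object SelfJoinSemantics {
  // The aggregated dataframe: distinct ids 1, 2, 3.
  val ids: Seq[Int] = Seq(1, 2, 3).distinct

  // Left side of the self-join: every id shifted by one, i.e. 2, 3, 4.
  val shifted: Seq[Int] = ids.map(_ + 1)

  // left_semi keeps left rows that have a match on the right: 2 and 3.
  val leftSemi: Seq[Int] = shifted.filter(v => ids.contains(v))

  // left_anti keeps left rows without a match on the right: only 4.
  val leftAnti: Seq[Int] = shifted.filterNot(v => ids.contains(v))

  // Model of the broken plan: the anti-join runs below the projection with the
  // condition (v + 1) == v, which never holds, so no row is ever filtered out.
  val brokenAnti: Seq[Int] =
    ids.filterNot(v => ids.exists(_ => v + 1 == v)).map(_ + 1)

  def main(args: Array[String]): Unit = {
    println(s"left_semi rows: ${leftSemi.length}")
    println(s"left_anti rows: ${leftAnti.length}")
    println(s"left_anti rows with the broken condition: ${brokenAnti.length}")
  }
}
```

In this model the broken condition yields three rows instead of one, which corresponds to the anti-join being effectively removed from the plan.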
