Skip to content

Commit dd6b841

Browse files
committed
normalize outputPartitioning of Project to handle aliases after inner join
1 parent 84dc374 commit dd6b841

File tree

2 files changed

+80
-19
lines changed

2 files changed

+80
-19
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package org.apache.spark.sql.execution
1818

19-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
19+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, Expression, NamedExpression, SortOrder}
2020
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
2121

2222
/**
@@ -25,20 +25,14 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition
2525
trait AliasAwareOutputExpression extends UnaryExecNode {
2626
protected def outputExpressions: Seq[NamedExpression]
2727

28-
protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
28+
lazy val aliasMap = AttributeMap(outputExpressions.collect {
29+
case a @ Alias(child: AttributeReference, _) => (child, a.toAttribute)
30+
})
2931

30-
protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
31-
exprs.map {
32-
case a: AttributeReference => replaceAlias(a).getOrElse(a)
33-
case other => other
34-
}
35-
}
32+
protected def hasAlias: Boolean = aliasMap.nonEmpty
3633

3734
protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
38-
outputExpressions.collectFirst {
39-
case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
40-
a.toAttribute
41-
}
35+
aliasMap.get(attr)
4236
}
4337
}
4438

@@ -48,13 +42,13 @@ trait AliasAwareOutputExpression extends UnaryExecNode {
4842
*/
4943
trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
5044
final override def outputPartitioning: Partitioning = {
51-
if (hasAlias) {
52-
child.outputPartitioning match {
53-
case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
54-
case other => other
55-
}
56-
} else {
57-
child.outputPartitioning
45+
child.outputPartitioning match {
46+
case e: Expression if hasAlias =>
47+
val normalizedExp = e.transformDown {
48+
case attr: AttributeReference => replaceAlias(attr).getOrElse(attr)
49+
}
50+
normalizedExp.asInstanceOf[Partitioning]
51+
case other => other
5852
}
5953
}
6054
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -895,6 +895,73 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
895895
}
896896
}
897897

898+
test("No extra exchanges in case of [Inner Join -> Project with aliases -> Inner join]") {
899+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
900+
withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
901+
withTempView("t1", "t2", "t3") {
902+
spark.range(10).repartition($"id").createTempView("t1")
903+
spark.range(20).repartition($"id").createTempView("t2")
904+
spark.range(30).repartition($"id").createTempView("t3")
905+
val planned = sql(
906+
"""
907+
|SELECT t2id, t3.id as t3id
908+
|FROM (
909+
| SELECT t1.id as t1id, t2.id as t2id
910+
| FROM t1, t2
911+
| WHERE t1.id = t2.id
912+
|) t12, t3
913+
|WHERE t1id = t3.id
914+
""".stripMargin).queryExecution.executedPlan
915+
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
916+
assert(exchanges.size == 3)
917+
}
918+
}
919+
}
920+
}
921+
922+
test("No extra exchanges in case of [LeftSemi Join -> Project with aliases -> Inner join]") {
923+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
924+
withTempView("t1", "t2", "t3") {
925+
spark.range(10).repartition($"id").createTempView("t1")
926+
spark.range(20).repartition($"id").createTempView("t2")
927+
spark.range(30).repartition($"id").createTempView("t3")
928+
val planned = sql(
929+
"""
930+
|SELECT t1id, t3.id as t3id
931+
|FROM (
932+
| SELECT t1.id as t1id
933+
| FROM t1 LEFT SEMI JOIN t2
934+
| ON t1.id = t2.id
935+
|) t12 INNER JOIN t3
936+
|WHERE t1id = t3.id
937+
""".stripMargin).queryExecution.executedPlan
938+
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
939+
assert(exchanges.size == 3)
940+
}
941+
}
942+
}
943+
944+
test("No extra exchanges in case of [Inner Join -> Project with aliases -> HashAggregate]") {
945+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
946+
withTempView("t1", "t2") {
947+
spark.range(10).repartition($"id").createTempView("t1")
948+
spark.range(20).repartition($"id").createTempView("t2")
949+
val planned = sql(
950+
"""
951+
|SELECT t1id, t2id
952+
|FROM (
953+
| SELECT t1.id as t1id, t2.id as t2id
954+
| FROM t1 INNER JOIN t2
955+
| WHERE t1.id = t2.id
956+
|) t12
957+
|GROUP BY t1id, t2id
958+
""".stripMargin).queryExecution.executedPlan
959+
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
960+
assert(exchanges.size == 2)
961+
}
962+
}
963+
}
964+
898965
test("aliases to expressions should not be replaced") {
899966
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
900967
withTempView("df1", "df2") {

0 commit comments

Comments
 (0)