
Commit b877de7

address PR comments
1 parent 323d4a7 commit b877de7

File tree

3 files changed: +49 -17 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputPartitioning.scala

Lines changed: 16 additions & 9 deletions
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 
 /**
@@ -27,14 +27,21 @@ trait AliasAwareOutputPartitioning extends UnaryExecNode {
   protected def outputExpressions: Seq[NamedExpression]
 
   final override def outputPartitioning: Partitioning = {
-    child.outputPartitioning match {
-      case HashPartitioning(expressions, numPartitions) =>
-        val newExpressions = expressions.map {
-          case a: AttributeReference =>
-            replaceAlias(a).getOrElse(a)
-          case other => other
-        }
-        HashPartitioning(newExpressions, numPartitions)
+    if (hasAlias) {
+      child.outputPartitioning match {
+        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
+        case other => other
+      }
+    } else {
+      child.outputPartitioning
+    }
+  }
+
+  private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+
+  private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map {
+      case a: AttributeReference => replaceAlias(a).getOrElse(a)
       case other => other
     }
   }
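For readers skimming the diff: the refactor splits the rewrite into a cheap hasAlias guard plus a replaceAliases helper, so the child's partitioning passes through untouched whenever the projection introduces no aliases. Below is a minimal, self-contained sketch of the same idea; Expr, Attr, Aliased, and HashPart are hypothetical stand-ins for Spark's Expression, AttributeReference, Alias, and HashPartitioning, not the actual classes.

// Hypothetical, simplified model of the alias-substitution above.
object AliasAwareSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Aliased(child: Attr, name: String) extends Expr

  case class HashPart(expressions: Seq[Expr], numPartitions: Int)

  // Rewrite the child's partitioning expressions in terms of the names the
  // node exposes, mirroring hasAlias + replaceAliases in the diff above.
  def outputPartitioning(output: Seq[Expr], childPart: HashPart): HashPart = {
    // Map each alias's child attribute to the attribute it is exposed as.
    val aliasMap: Map[Attr, Attr] =
      output.collect { case Aliased(child, name) => child -> Attr(name) }.toMap
    if (aliasMap.isEmpty) childPart // no alias: keep the child's partitioning
    else childPart.copy(expressions = childPart.expressions.map {
      case a: Attr => aliasMap.getOrElse(a, a)
      case other => other
    })
  }

  def main(args: Array[String]): Unit = {
    // The child is hash-partitioned by k2 and the projection renames k2 to k3,
    // so the node can report partitioning by k3; a join on k3 needs no shuffle.
    val childPart = HashPart(Seq(Attr("k2")), numPartitions = 8)
    val output = Seq(Aliased(Attr("k2"), "k3"))
    println(outputPartitioning(output, childPart)) // HashPart(List(Attr(k3)),8)
  }
}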

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 30 additions & 0 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -990,11 +991,40 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
       val agg1 = t1.groupBy("k1").agg(count(lit("1")).as("cnt1"))
       val agg2 = t2.groupBy("k2").agg(count(lit("1")).as("cnt2")).withColumnRenamed("k2", "k3")
+
       val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+      assert(planned.collect { case h: HashAggregateExec => h }.nonEmpty)
+
       val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
       assert(exchanges.size == 2)
     }
   }
+
+  test("aliases in the object hash/sort aggregate expressions should not introduce extra shuffle") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(true, false).foreach { useObjectHashAgg =>
+        withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> useObjectHashAgg.toString) {
+          val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
+          val t2 = spark.range(10).selectExpr("floor(id/4) as k2")
+
+          val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
+          val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")
+
+          val planned = agg1.join(agg2, $"k1" === $"k3").queryExecution.executedPlan
+
+          if (useObjectHashAgg) {
+            assert(planned.collect { case o: ObjectHashAggregateExec => o }.nonEmpty)
+          } else {
+            assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
+          }
+
+          val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
+          assert(exchanges.size == 2)
+        }
+      }
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
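What the new test pins down, restated as a hedged spark-shell sketch (it assumes a SparkSession bound to `spark`; collect_list is what steers planning toward ObjectHashAggregateExec, or SortAggregateExec when object hash aggregation is disabled, instead of HashAggregateExec):

import org.apache.spark.sql.functions.collect_list
import spark.implicits._ // brings the $"..." column syntax into scope

val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
val t2 = spark.range(10).selectExpr("floor(id/4) as k2")

val agg1 = t1.groupBy("k1").agg(collect_list("k1"))
// The rename puts an Alias over the grouping key in the aggregate's output.
val agg2 = t2.groupBy("k2").agg(collect_list("k2")).withColumnRenamed("k2", "k3")

// With the fix, k3 is treated as an alias of k2, so the join can reuse the
// aggregates' hash partitioning: expect the two aggregation shuffles in the
// plan, but no third ShuffleExchangeExec for the join itself.
agg1.join(agg2, $"k1" === $"k3").explain()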

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 3 additions & 8 deletions
@@ -608,16 +608,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
       withTable("t") {
         withView("v") {
-          val df = (0 until 20).map(i => (i, i)).toDF("i", "j").as("df")
-          df.write.format("parquet").bucketBy(8, "i").saveAsTable("t")
-
+          spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
           sql("CREATE VIEW v AS SELECT * FROM t").collect()
 
-          val plan1 = sql("SELECT * FROM t a JOIN t b ON a.i = b.i").queryExecution.executedPlan
-          assert(plan1.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
-
-          val plan2 = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
-          assert(plan2.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
+          val plan = sql("SELECT * FROM t a JOIN v b ON a.i = b.i").queryExecution.executedPlan
+          assert(plan.collect { case exchange: ShuffleExchangeExec => exchange }.isEmpty)
         }
       }
     }
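The simplified test keeps only the join through the view, which is the case that actually exercises alias-aware partitioning (the view's output attributes alias the table's columns). A rough spark-shell equivalent, assuming saveAsTable can write to the local warehouse:

// Disable broadcast joins so a sort-merge join is planned.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "0")

// Bucket the table by the join key; the view is a plain projection over it.
spark.range(20).selectExpr("id as i").write.bucketBy(8, "i").saveAsTable("t")
spark.sql("CREATE VIEW v AS SELECT * FROM t")

// Both join sides read the same table bucketed by i, and the view's output
// attributes are aliases of the table's, so no ShuffleExchangeExec should
// appear in the physical plan.
spark.sql("SELECT * FROM t a JOIN v b ON a.i = b.i").explain()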
