Skip to content

Commit 294f605

Browse files
imback82 authored and cloud-fan committed
[SPARK-31078][SQL] Respect aliases in output ordering
### What changes were proposed in this pull request?

Currently, in the following scenario, an unnecessary `Sort` node is introduced:

```scala
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
  val df = (0 until 20).toDF("i").as("df")
  df.repartition(8, df("i")).write.format("parquet")
    .bucketBy(8, "i").sortBy("i").saveAsTable("t")
  val t1 = spark.table("t")
  val t2 = t1.selectExpr("i as ii")
  t1.join(t2, t1("i") === t2("ii")).explain
}
```

```
== Physical Plan ==
*(3) SortMergeJoin [i#8], [ii#10], Inner
:- *(1) Project [i#8]
:  +- *(1) Filter isnotnull(i#8)
:     +- *(1) ColumnarToRow
:        +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
+- *(2) Sort [ii#10 ASC NULLS FIRST], false, 0    <==== UNNECESSARY
   +- *(2) Project [i#8 AS ii#10]
      +- *(2) Filter isnotnull(i#8)
         +- *(2) ColumnarToRow
            +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
```

Notice that `Sort [ii#10 ASC NULLS FIRST], false, 0` is introduced even though the underlying data is already sorted. This is because `outputOrdering` doesn't handle aliases correctly. This PR proposes to fix this issue.

### Why are the changes needed?

To better handle aliases in `outputOrdering`.

### Does this PR introduce any user-facing change?

Yes, now with the fix, the `explain` prints out the following:

```
== Physical Plan ==
*(3) SortMergeJoin [i#8], [ii#10], Inner
:- *(1) Project [i#8]
:  +- *(1) Filter isnotnull(i#8)
:     +- *(1) ColumnarToRow
:        +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
+- *(2) Project [i#8 AS ii#10]
   +- *(2) Filter isnotnull(i#8)
      +- *(2) ColumnarToRow
         +- FileScan parquet default.t[i#8] Batched: true, DataFilters: [isnotnull(i#8)], Format: Parquet, Location: InMemoryFileIndex[file:/..., PartitionFilters: [], PushedFilters: [IsNotNull(i)], ReadSchema: struct<i:int>, SelectedBucketsCount: 8 out of 8
```

### How was this patch tested?

Tests added.

Closes #27842 from imback82/alias_aware_sort_order.

Authored-by: Terry Kim <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 15df2a3 commit 294f605

File tree

6 files changed

+87
-23
lines changed

6 files changed

+87
-23
lines changed
Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,37 @@
1616
*/
1717
package org.apache.spark.sql.execution
1818

19-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
19+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
2020
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
2121

2222
/**
23-
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning`
24-
* that satisfies output distribution requirements.
23+
* A trait that provides functionality to handle aliases in the `outputExpressions`.
2524
*/
26-
trait AliasAwareOutputPartitioning extends UnaryExecNode {
25+
trait AliasAwareOutputExpression extends UnaryExecNode {
2726
// The named expressions this operator outputs; implementors supply them.
// The helpers below scan this list for `Alias` nodes.
protected def outputExpressions: Seq[NamedExpression]
2827

28+
// True when at least one output expression is an `Alias`.
protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
29+
30+
// Rewrites each top-level `AttributeReference` in `exprs` to the attribute of the alias that
// wraps it, when `replaceAlias` finds one. Every other expression is returned unchanged;
// this method does not descend into nested expressions.
protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
31+
exprs.map {
32+
// Keep the original attribute when no alias wraps it.
case a: AttributeReference => replaceAlias(a).getOrElse(a)
33+
case other => other
34+
}
35+
}
36+
37+
// Returns the attribute of the first output `Alias` whose child is an `AttributeReference`
// semantically equal to `attr`, or `None` if there is no such alias. Aliases over anything
// other than a bare `AttributeReference` are not matched.
protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
38+
outputExpressions.collectFirst {
39+
case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
40+
a.toAttribute
41+
}
42+
}
43+
}
44+
45+
/**
46+
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
47+
* satisfies distribution requirements.
48+
*/
49+
trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
2950
final override def outputPartitioning: Partitioning = {
3051
if (hasAlias) {
3152
child.outputPartitioning match {
@@ -36,20 +57,25 @@ trait AliasAwareOutputPartitioning extends UnaryExecNode {
3657
child.outputPartitioning
3758
}
3859
}
60+
}
3961

40-
private def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
41-
42-
private def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
43-
exprs.map {
44-
case a: AttributeReference => replaceAlias(a).getOrElse(a)
45-
case other => other
46-
}
47-
}
62+
/**
63+
* A trait that handles aliases in the `orderingExpressions` to produce `outputOrdering` that
64+
* satisfies ordering requirements.
65+
*/
66+
trait AliasAwareOutputOrdering extends AliasAwareOutputExpression {
67+
// The ordering before alias substitution; implementors supply it.
protected def orderingExpressions: Seq[SortOrder]
4868

49-
private def replaceAlias(attr: AttributeReference): Option[Attribute] = {
50-
outputExpressions.collectFirst {
51-
case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
52-
a.toAttribute
69+
// Reports `orderingExpressions` with aliased attributes substituted. Declared `final`, so
// implementors customise the result through `orderingExpressions` only.
final override def outputOrdering: Seq[SortOrder] = {
70+
if (hasAlias) {
71+
orderingExpressions.map { s =>
72+
s.child match {
73+
// Only a sort key that is itself a bare attribute is rewritten; `copy` changes just
// the `child` field and keeps the other `SortOrder` fields as they were.
case a: AttributeReference => s.copy(child = replaceAlias(a).getOrElse(a))
74+
// Any other sort key is passed through untouched.
case _ => s
75+
}
76+
}
77+
} else {
78+
// No alias in the output expressions: return the ordering unchanged.
orderingExpressions
5379
}
5480
}
5581
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ case class HashAggregateExec(
5353
initialInputBufferOffset: Int,
5454
resultExpressions: Seq[NamedExpression],
5555
child: SparkPlan)
56-
extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning {
56+
extends BaseAggregateExec
57+
with BlockingOperatorWithCodegen
58+
with AliasAwareOutputPartitioning {
5759

5860
private[this] val aggregateBufferAttributes = {
5961
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate._
2525
import org.apache.spark.sql.catalyst.plans.physical._
2626
import org.apache.spark.sql.catalyst.util.truncatedString
27-
import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan}
27+
import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, AliasAwareOutputPartitioning, SparkPlan}
2828
import org.apache.spark.sql.execution.metric.SQLMetrics
2929

3030
/**
@@ -38,7 +38,9 @@ case class SortAggregateExec(
3838
initialInputBufferOffset: Int,
3939
resultExpressions: Seq[NamedExpression],
4040
child: SparkPlan)
41-
extends BaseAggregateExec with AliasAwareOutputPartitioning {
41+
extends BaseAggregateExec
42+
with AliasAwareOutputPartitioning
43+
with AliasAwareOutputOrdering {
4244

4345
private[this] val aggregateBufferAttributes = {
4446
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -68,7 +70,7 @@ case class SortAggregateExec(
6870

6971
// Supplies the aggregate's result expressions as the `outputExpressions` that the mixed-in
// alias-aware traits scan for aliases.
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
7072

71-
override def outputOrdering: Seq[SortOrder] = {
73+
// Ordering before alias substitution: one ascending `SortOrder` per grouping expression.
// The `outputOrdering` actually reported is derived from this by `AliasAwareOutputOrdering`.
override protected def orderingExpressions: Seq[SortOrder] = {
7274
groupingExpressions.map(SortOrder(_, Ascending))
7375
}
7476

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
3939

4040
/** Physical plan for Project. */
4141
case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
42-
extends UnaryExecNode with CodegenSupport with AliasAwareOutputPartitioning {
42+
extends UnaryExecNode
43+
with CodegenSupport
44+
with AliasAwareOutputPartitioning
45+
with AliasAwareOutputOrdering {
4346

4447
// The output attributes are the attributes of the projection list entries, in list order.
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
4548

@@ -80,10 +83,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
8083
}
8184
}
8285

83-
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
84-
8586
// Supplies the projection list as the `outputExpressions` that the mixed-in alias-aware
// traits scan for aliases.
override protected def outputExpressions: Seq[NamedExpression] = projectList
8687

88+
// Ordering before alias substitution is the child's ordering; `AliasAwareOutputOrdering`
// derives the reported `outputOrdering` from it.
override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
89+
8790
override def verboseStringWithOperatorId(): String = {
8891
s"""
8992
|(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,25 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
975975
}
976976
}
977977
}
978+
979+
// Checks that renaming a sort-aggregate grouping column before a join leaves the planned
// number of SortExec nodes at four.
test("aliases in the sort aggregate expressions should not introduce extra sort") {
980+
// NOTE(review): "-1" is understood to disable broadcast joins - confirm against the
// SQLConf documentation.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
981+
// NOTE(review): presumably turned off so that the collect_list aggregation is planned as a
// SortAggregateExec; the first assertion below checks that one is present.
withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
982+
val t1 = spark.range(10).selectExpr("floor(id/4) as k1")
983+
val t2 = spark.range(20).selectExpr("floor(id/4) as k2")
984+
985+
// The grouping column k1 is renamed to k3 after aggregation; the join below uses k3.
val agg1 = t1.groupBy("k1").agg(collect_list("k1")).withColumnRenamed("k1", "k3")
986+
val agg2 = t2.groupBy("k2").agg(collect_list("k2"))
987+
988+
val planned = agg1.join(agg2, $"k3" === $"k2").queryExecution.executedPlan
989+
assert(planned.collect { case s: SortAggregateExec => s }.nonEmpty)
990+
991+
// We expect two SortExec nodes on each side of the join, i.e. four in total.
992+
val sorts = planned.collect { case s: SortExec => s }
993+
assert(sorts.size == 4)
994+
}
995+
}
996+
}
978997
}
979998

980999
// Used for unit-testing EnsureRequirements

sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,18 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
604604
}
605605
}
606606

607+
// Checks that joining a bucketed, sorted table with an aliased projection of itself plans no
// SortExec at all.
test("sort should not be introduced when aliases are used") {
608+
// NOTE(review): a threshold of "0" presumably keeps the planner from choosing a broadcast
// join - confirm against the SQLConf documentation.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
609+
withTable("t") {
610+
// `df1` is defined outside this test. The table is written with 8 buckets on "i" and
// sorted by "i".
df1.repartition(1).write.format("parquet").bucketBy(8, "i").sortBy("i").saveAsTable("t")
611+
val t1 = spark.table("t")
612+
// The same table with column i exposed under the alias ii.
val t2 = t1.selectExpr("i as ii")
613+
val plan = t1.join(t2, t1("i") === t2("ii")).queryExecution.executedPlan
614+
// No SortExec may appear anywhere in the executed plan.
assert(plan.collect { case sort: SortExec => sort }.isEmpty)
615+
}
616+
}
617+
}
618+
607619
test("bucket join should work with SubqueryAlias plan") {
608620
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
609621
withTable("t") {

0 commit comments

Comments
 (0)