Skip to content

Commit ff66add

Browse files
allisonwang-dbcloud-fan
authored andcommitted
[SPARK-40107][SQL][FOLLOW-UP] Update empty2null check
### What changes were proposed in this pull request? This PR is a follow-up for SPARK-40107. It updates the way we check the `empty2null` expression in a V1 write query plan. Previously, we only search for this expression in Project. But optimizer can change the position of this expression, for example collapsing projects with aggregates. As a result, we need to search the entire plan to see if `empty2null` has been added by `V1Writes`. ### Why are the changes needed? To prevent unnecessary `empty2null` projections from being added in FileFormatWriter. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests. Closes #37856 from allisonwang-db/spark-40107-followup. Authored-by: allisonwang-db <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent f92b494 commit ff66add

File tree

4 files changed

+50
-41
lines changed

4 files changed

+50
-41
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,7 @@ object FileFormatWriter extends Logging {
103103
.map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation))
104104
val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains)
105105

106-
val hasEmpty2Null = plan.find {
107-
case p: ProjectExec => V1WritesUtils.hasEmptyToNull(p.projectList)
108-
case _ => false
109-
}.isDefined
106+
val hasEmpty2Null = plan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
110107
val empty2NullPlan = if (hasEmpty2Null) {
111108
plan
112109
} else {
@@ -150,14 +147,7 @@ object FileFormatWriter extends Logging {
150147
// the sort order doesn't matter
151148
// Use the output ordering from the original plan before adding the empty2null projection.
152149
val actualOrdering = plan.outputOrdering.map(_.child)
153-
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
154-
false
155-
} else {
156-
requiredOrdering.zip(actualOrdering).forall {
157-
case (requiredOrder, childOutputOrder) =>
158-
requiredOrder.semanticEquals(childOutputOrder)
159-
}
160-
}
150+
val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering)
161151

162152
SQLExecution.checkSQLExecutionId(sparkSession)
163153

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ trait V1WriteCommand extends DataWritingCommand {
4747
* A rule that adds logical sorts to V1 data writing commands.
4848
*/
4949
object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
50+
51+
import V1WritesUtils._
52+
5053
override def apply(plan: LogicalPlan): LogicalPlan = {
5154
if (conf.plannedWriteEnabled) {
5255
plan.transformDown {
@@ -65,10 +68,11 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
6568
}
6669

6770
private def prepareQuery(write: V1WriteCommand, query: LogicalPlan): LogicalPlan = {
68-
val empty2NullPlan = if (hasEmptyToNull(query)) {
71+
val hasEmpty2Null = query.exists(p => hasEmptyToNull(p.expressions))
72+
val empty2NullPlan = if (hasEmpty2Null) {
6973
query
7074
} else {
71-
val projectList = V1WritesUtils.convertEmptyToNull(query.output, write.partitionColumns)
75+
val projectList = convertEmptyToNull(query.output, write.partitionColumns)
7276
if (projectList.isEmpty) query else Project(projectList, query)
7377
}
7478
assert(empty2NullPlan.output.length == query.output.length)
@@ -80,26 +84,13 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
8084
}.asInstanceOf[SortOrder])
8185
val outputOrdering = query.outputOrdering
8286
// Check if the ordering is already matched to ensure the idempotency of the rule.
83-
val orderingMatched = if (requiredOrdering.length > outputOrdering.length) {
84-
false
85-
} else {
86-
requiredOrdering.zip(outputOrdering).forall {
87-
case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
88-
}
89-
}
87+
val orderingMatched = isOrderingMatched(requiredOrdering, outputOrdering)
9088
if (orderingMatched) {
9189
empty2NullPlan
9290
} else {
9391
Sort(requiredOrdering, global = false, empty2NullPlan)
9492
}
9593
}
96-
97-
private def hasEmptyToNull(plan: LogicalPlan): Boolean = {
98-
plan.find {
99-
case p: Project => V1WritesUtils.hasEmptyToNull(p.projectList)
100-
case _ => false
101-
}.isDefined
102-
}
10394
}
10495

10596
object V1WritesUtils {
@@ -209,4 +200,16 @@ object V1WritesUtils {
209200
def hasEmptyToNull(expressions: Seq[Expression]): Boolean = {
210201
expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
211202
}
203+
204+
def isOrderingMatched(
205+
requiredOrdering: Seq[Expression],
206+
outputOrdering: Seq[Expression]): Boolean = {
207+
if (requiredOrdering.length > outputOrdering.length) {
208+
false
209+
} else {
210+
requiredOrdering.zip(outputOrdering).forall {
211+
case (requiredOrder, outputOrder) => requiredOrder.semanticEquals(outputOrder)
212+
}
213+
}
214+
}
212215
}

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,32 @@
1818
package org.apache.spark.sql.execution.datasources
1919

2020
import org.apache.spark.sql.{QueryTest, Row}
21-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
21+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
2222
import org.apache.spark.sql.execution.QueryExecution
2323
import org.apache.spark.sql.internal.SQLConf
2424
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
2525
import org.apache.spark.sql.util.QueryExecutionListener
2626

27-
abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
27+
trait V1WriteCommandSuiteBase extends SQLTestUtils {
2828

2929
import testImplicits._
3030

3131
setupTestData()
3232

33-
protected override def beforeAll(): Unit = {
33+
override def beforeAll(): Unit = {
3434
super.beforeAll()
3535
(0 to 20).map(i => (i, i % 5, (i % 10).toString))
3636
.toDF("i", "j", "k")
3737
.write
3838
.saveAsTable("t0")
3939
}
4040

41-
protected override def afterAll(): Unit = {
41+
override def afterAll(): Unit = {
4242
sql("drop table if exists t0")
4343
super.afterAll()
4444
}
4545

46-
protected def withPlannedWrite(testFunc: Boolean => Any): Unit = {
46+
def withPlannedWrite(testFunc: Boolean => Any): Unit = {
4747
Seq(true, false).foreach { enabled =>
4848
withSQLConf(SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) {
4949
testFunc(enabled)
@@ -87,19 +87,16 @@ abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
8787
s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
8888

8989
// Check empty2null conversion.
90-
val projection = optimizedPlan.collectFirst {
91-
case p: Project
92-
if p.projectList.exists(_.exists(_.isInstanceOf[V1WritesUtils.Empty2Null])) => p
93-
}
94-
assert(projection.isDefined == hasEmpty2Null,
95-
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: ${projection.isDefined}")
90+
val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
91+
assert(empty2nullExpr == hasEmpty2Null,
92+
s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
9693
}
9794

9895
spark.listenerManager.unregister(listener)
9996
}
10097
}
10198

102-
class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSession {
99+
class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1WriteCommandSuiteBase {
103100

104101
import testImplicits._
105102

@@ -277,4 +274,21 @@ class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSessio
277274
}
278275
}
279276
}
277+
278+
test("v1 write with empty2null in aggregate") {
279+
withPlannedWrite { enabled =>
280+
withTable("t") {
281+
executeAndCheckOrdering(
282+
hasLogicalSort = enabled, orderingMatched = enabled, hasEmpty2Null = enabled) {
283+
sql(
284+
"""
285+
|CREATE TABLE t USING PARQUET
286+
|PARTITIONED BY (k) AS
287+
|SELECT SUM(i) AS i, SUM(j) AS j, k
288+
|FROM t0 WHERE i > 0 GROUP BY k
289+
|""".stripMargin)
290+
}
291+
}
292+
}
293+
}
280294
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
package org.apache.spark.sql.hive.execution.command
1919

20+
import org.apache.spark.sql.QueryTest
2021
import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase
2122
import org.apache.spark.sql.hive.test.TestHiveSingleton
2223

23-
class V1WriteHiveCommandSuite extends V1WriteCommandSuiteBase with TestHiveSingleton {
24+
class V1WriteHiveCommandSuite
25+
extends QueryTest with TestHiveSingleton with V1WriteCommandSuiteBase {
2426

2527
test("create hive table as select - no partition column") {
2628
withPlannedWrite { enabled =>

0 commit comments

Comments
 (0)