Skip to content

Commit feabc1f

Browse files
add nullability check
1 parent 6d5cade commit feabc1f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20+
import org.apache.spark.sql.catalyst.expressions.Attribute
2021
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, PartialMerge}
2122
import org.apache.spark.sql.catalyst.rules.Rule
2223
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
@@ -68,6 +69,13 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] {
6869
}
6970
}
7071

72+
/**
73+
* Check if the nullability change is positive. It catches the case when the project output
74+
* attribute is not nullable, but the child output attribute is nullable.
75+
*/
76+
private def checkNullability(output: Seq[Attribute], childOutput: Seq[Attribute]): Boolean =
77+
output.zip(childOutput).forall { case (attr1, attr2) => attr1.nullable || !attr2.nullable }
78+
7179
private def isRedundant(
7280
project: ProjectExec,
7381
child: SparkPlan,
@@ -78,9 +86,11 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] {
7886
case d: DataSourceV2ScanExecBase if !d.supportsColumnar => false
7987
case _ =>
8088
if (requireOrdering) {
81-
project.output.map(_.exprId.id) == child.output.map(_.exprId.id)
89+
project.output.map(_.exprId.id) == child.output.map(_.exprId.id) &&
90+
checkNullability(project.output, child.output)
8291
} else {
83-
project.output.map(_.exprId.id).sorted == child.output.map(_.exprId.id).sorted
92+
project.output.map(_.exprId.id).sorted == child.output.map(_.exprId.id).sorted &&
93+
checkNullability(project.output, child.output)
8494
}
8595
}
8696
}

0 commit comments

Comments
 (0)