diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 538690725fa01..8b8f0d82732a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -213,7 +213,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// This batch pushes filters and projections into scan nodes. Before this batch, the logical
// plan may contain nodes that do not report stats. Anything that uses stats must run after
// this batch.
- Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*) :+
+ Batch("Early Scan Push-Down", Once, earlyScanPushDownRules: _*) :+
Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats) :+
// Since join costs in AQP can change between multiple runs, there is no reason that we have an
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index a4ec48142cfe2..e5be5be27923d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
 * Check Analysis Rules.
* Optimizer Rules.
* Pre CBO Rules.
+ * Early Scan Push-Down Rules.
* Planning Strategies.
* Customized Parser.
* (External) Catalog listeners.
@@ -226,6 +227,24 @@ class SparkSessionExtensions {
preCBORules += builder
}
+ private[this] val earlyScanPushDownRules = mutable.Buffer.empty[RuleBuilder]
+
+ private[sql] def buildEarlyScanPushDownRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+ earlyScanPushDownRules.map(_.apply(session)).toSeq
+ }
+
+ /**
+ * Inject an optimizer `Rule` builder that rewrites logical plans into the [[SparkSession]].
+ * The injected rules will be executed once, after the operator optimization batch and
+ * after the built-in scan push-down rules.
+ * 'Pre CBO Rules' run before `V2ScanRelationPushDown`, while 'Early Scan Push-Down' rules
+ * run after it, so users can apply custom push-down related rules for cases where
+ * `V2ScanRelationPushDown` fails to push an operation down.
+ */
+ def injectEarlyScanPushDownRule(builder: RuleBuilder): Unit = {
+ earlyScanPushDownRules += builder
+ }
+
private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]
private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
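
For illustration, here is a minimal sketch of how an extension author could use the new injectEarlyScanPushDownRule API. MyScanPushDownRule is a hypothetical name and its body is a no-op placeholder; neither is part of this change.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical rule for illustration only; a real rule would rewrite scan-related
    // plan nodes that V2ScanRelationPushDown could not handle.
    case class MyScanPushDownRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan  // no-op placeholder
    }

    // Register the rule builder when constructing the session; the case class companion
    // object serves as the SparkSession => Rule[LogicalPlan] builder.
    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions(_.injectEarlyScanPushDownRule(MyScanPushDownRule))
      .getOrCreate()
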
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 1e60cb8b1db2a..9d959c257de79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -270,7 +270,9 @@ abstract class BaseSessionStateBuilder(
*
* Note that this may NOT depend on the `optimizer` function.
*/
- protected def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil
+ protected def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildEarlyScanPushDownRules(session)
+ }
/**
* Custom rules for rewriting plans after operator optimization and before CBO.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 4994968fdd6ba..e8fe80c92e4f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -95,6 +95,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}
+ test("SPARK-37518: inject a early scan push down rule") {
+ withSession(Seq(_.injectEarlyScanPushDownRule(MyRule))) { session =>
+ assert(session.sessionState.optimizer.earlyScanPushDownRules.contains(MyRule(session)))
+ }
+ }
+
test("inject spark planner strategy") {
withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
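
The new test registers the rule through the suite's withSession helper. Outside of tests, the same builder could also be wired up through the spark.sql.extensions configuration, which accepts class names implementing Function1[SparkSessionExtensions, Unit]. A sketch, reusing the hypothetical MyScanPushDownRule from the earlier example and a made-up class name:

    import org.apache.spark.sql.SparkSessionExtensions

    // Hypothetical extension class; ship it on the classpath and point
    // spark.sql.extensions at its fully qualified name, e.g.
    //   spark-submit --conf spark.sql.extensions=com.example.MyScanPushDownExtensions ...
    class MyScanPushDownExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectEarlyScanPushDownRule(MyScanPushDownRule)
      }
    }
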