diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index bd8dd6ea3fe0f..f3a71d5ab7ab7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -44,11 +44,18 @@ class ExperimentalMethods private[sql]() {
    */
   @volatile var extraStrategies: Seq[Strategy] = Nil
 
+  /**
+   * Extra optimizer rules that `SparkOptimizer` runs as its very first batch, ahead of all
+   * built-in optimization batches.
+   */
+  @volatile var extraPreOptimizations: Seq[Rule[LogicalPlan]] = Nil
+
   @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
 
   override def clone(): ExperimentalMethods = {
     val result = new ExperimentalMethods
     result.extraStrategies = extraStrategies
+    result.extraPreOptimizations = extraPreOptimizations
     result.extraOptimizations = extraOptimizations
     result
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 00ff4c8ac310b..b6c3a091ef2f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -28,12 +28,28 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog) {
 
-  override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+
+  /**
+   * Batch of user-supplied rules placed ahead of every other batch. This is a `def`, not a
+   * `val`, so `experimentalMethods.extraPreOptimizations` is re-read on each call and rules
+   * registered after this optimizer was constructed still take effect.
+   */
+  def experimentalPreOptimizations: Batch = Batch(
+    "User Provided Pre Optimizers", fixedPoint, experimentalMethods.extraPreOptimizations: _*)
+
+  /**
+   * Batch of user-supplied rules placed after every other batch. It is rebuilt on each call
+   * for the same reason: runtime changes to `experimentalMethods.extraOptimizations` must
+   * be visible to `batches`.
+   */
+  def experimentalPostOptimizations: Batch = Batch(
+    "User Provided Post Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+
+  override def batches: Seq[Batch] =
+    ((experimentalPreOptimizations +: preOptimizationBatches) ++ super.batches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
     Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++
-    postHocOptimizationBatches :+
-    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+    postHocOptimizationBatches :+ experimentalPostOptimizations
 
   /**
    * Optimization batches that are executed before the regular optimization batches (also before
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index a1799829932b8..60896f2406dde 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -27,7 +27,11 @@ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructT
 @deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0")
 class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
 
-  object DummyRule extends Rule[LogicalPlan] {
+  object DummyPostOptimizationRule extends Rule[LogicalPlan] {
+    def apply(p: LogicalPlan): LogicalPlan = p
+  }
+
+  object DummyPreOptimizationRule extends Rule[LogicalPlan] {
     def apply(p: LogicalPlan): LogicalPlan = p
   }
 
@@ -78,8 +82,14 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext {
 
   test("Catalyst optimization passes are modifiable at runtime") {
     val sqlContext = SQLContext.getOrCreate(sc)
-    sqlContext.experimental.extraOptimizations = Seq(DummyRule)
-    assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule))
+    sqlContext.experimental.extraOptimizations = Seq(DummyPostOptimizationRule)
+    sqlContext.experimental.extraPreOptimizations = Seq(DummyPreOptimizationRule)
+
+    val firstBatch = sqlContext.sessionState.optimizer.batches.head
+    val lastBatch = sqlContext.sessionState.optimizer.batches.last
+
+    assert(firstBatch.rules == Seq(DummyPreOptimizationRule))
+    assert(lastBatch.rules == Seq(DummyPostOptimizationRule))
   }
 
   test("get all tables") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
index c01666770720c..f65799ea82073 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -96,7 +96,8 @@ class SessionStateSuite extends SparkFunSuite
   }
 
   test("fork new session and inherit experimental methods") {
-    val originalExtraOptimizations = activeSession.experimental.extraOptimizations
+    val initialExtraPostOptimizations = activeSession.experimental.extraOptimizations
+    val initialExtraPreOptimizations = activeSession.experimental.extraPreOptimizations
     val originalExtraStrategies = activeSession.experimental.extraStrategies
     try {
       object DummyRule1 extends Rule[LogicalPlan] {
@@ -105,23 +106,35 @@ class SessionStateSuite extends SparkFunSuite
       object DummyRule2 extends Rule[LogicalPlan] {
         def apply(p: LogicalPlan): LogicalPlan = p
       }
-      val optimizations = List(DummyRule1, DummyRule2)
-      activeSession.experimental.extraOptimizations = optimizations
+      object DummyRule3 extends Rule[LogicalPlan] {
+        def apply(p: LogicalPlan): LogicalPlan = p
+      }
+      val preRules = List(DummyRule3)
+      val postRules = List(DummyRule1, DummyRule2)
+      activeSession.experimental.extraPreOptimizations = preRules
+      activeSession.experimental.extraOptimizations = postRules
       val forkedSession = activeSession.cloneSession()
 
       // inheritance
       assert(forkedSession ne activeSession)
       assert(forkedSession.experimental ne activeSession.experimental)
+      assert(forkedSession.experimental.extraPreOptimizations.toSet ==
+        activeSession.experimental.extraPreOptimizations.toSet)
       assert(forkedSession.experimental.extraOptimizations.toSet ==
         activeSession.experimental.extraOptimizations.toSet)
 
       // independence
+      forkedSession.experimental.extraPreOptimizations = List(DummyRule1)
       forkedSession.experimental.extraOptimizations = List(DummyRule2)
-      assert(activeSession.experimental.extraOptimizations == optimizations)
+      assert(activeSession.experimental.extraPreOptimizations == preRules)
+      assert(activeSession.experimental.extraOptimizations == postRules)
+      activeSession.experimental.extraPreOptimizations = List(DummyRule3)
       activeSession.experimental.extraOptimizations = List(DummyRule1)
+      assert(forkedSession.experimental.extraPreOptimizations == List(DummyRule1))
       assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
     } finally {
-      activeSession.experimental.extraOptimizations = originalExtraOptimizations
+      activeSession.experimental.extraPreOptimizations = initialExtraPreOptimizations
+      activeSession.experimental.extraOptimizations = initialExtraPostOptimizations
       activeSession.experimental.extraStrategies = originalExtraStrategies
     }
   }