diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 67c57ec2787c3..7ef98083eb589 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -122,6 +122,14 @@ object ExternalCatalogUtils {
     }
   }
 
+  def getPartitionSpecString(value: String): String = {
+    if (value == null || value.isEmpty) {
+      null
+    } else {
+      value
+    }
+  }
+
   def getPartitionValueString(value: String): String = {
     if (value == null || value.isEmpty) {
       DEFAULT_PARTITION_NAME
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index de25c19a26eb8..633fea1ab7553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -420,6 +420,15 @@
     .booleanConf
     .createWithDefault(true)
 
+  val ELIMINATE_DYNAMIC_PARTITION_WRITES =
+    buildConf("spark.sql.optimizer.eliminateDynamicPartitionWrites")
+      .internal()
+      .doc("When set to true, the Spark optimizer will infer whether a dynamic partition " +
+        "column has a constant value and convert it to a static partition.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
     .doc("When set to true Spark SQL will automatically select a compression codec for each " +
       "column based on statistics of the data.")
@@ -4684,6 +4693,9 @@ class SQLConf extends Serializable with Logging {
 
   def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)
 
+  def eliminateDynamicPartitionWrites: Boolean =
+    getConf(SQLConf.ELIMINATE_DYNAMIC_PARTITION_WRITES)
+
   def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)
 
   def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 72bdab409a9e6..c38a0f0575a5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
+import org.apache.spark.sql.execution.datasources.{EliminateV1DynamicPartitionWrites, PruneFileSourcePartitions, SchemaPruning, V1Writes}
 import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
 import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
 import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -38,6 +38,7 @@ class SparkOptimizer(
     // TODO: move SchemaPruning into catalyst
     Seq(SchemaPruning) :+
       GroupBasedRowLevelOperationScanPlanning :+
+      EliminateV1DynamicPartitionWrites :+
       V1Writes :+
       V2ScanRelationPushDown :+
       V2ScanPartitioningAndOrdering :+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 6cc356488393c..2aeb685f39561 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -142,7 +142,8 @@ case class CreateDataSourceTableAsSelectCommand(
     table: CatalogTable,
     mode: SaveMode,
     query: LogicalPlan,
-    outputColumnNames: Seq[String])
+    outputColumnNames: Seq[String],
+    staticPartitions: Map[String, String] = Map.empty)
   extends V1WriteCommand {
 
   override lazy val partitionColumns: Seq[Attribute] = {
@@ -156,7 +157,8 @@ case class CreateDataSourceTableAsSelectCommand(
 
   override def requiredOrdering: Seq[SortOrder] = {
     val options = table.storage.properties
-    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
+    V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options,
+      numStaticPartitionCols)
   }
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -181,8 +183,8 @@ case class CreateDataSourceTableAsSelectCommand(
         return Seq.empty
       }
 
-      saveDataIntoTable(
-        sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true)
+      saveDataIntoTable(sparkSession, table, table.storage.locationUri, child, SaveMode.Append,
+        tableExists = true, staticPartitions)
     } else {
       table.storage.locationUri.foreach { p =>
         DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf)
@@ -194,8 +196,8 @@ case class CreateDataSourceTableAsSelectCommand(
       } else {
         table.storage.locationUri
       }
-      val result = saveDataIntoTable(
-        sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false)
+      val result = saveDataIntoTable(sparkSession, table, tableLocation, child, SaveMode.Overwrite,
+        tableExists = false, staticPartitions)
       val tableSchema = CharVarcharUtils.getRawSchema(result.schema, sessionState.conf)
       val newTable = table.copy(
         storage = table.storage.copy(locationUri = tableLocation),
@@ -229,7 +231,8 @@ case class CreateDataSourceTableAsSelectCommand(
       tableLocation: Option[URI],
       physicalPlan: SparkPlan,
       mode: SaveMode,
-      tableExists: Boolean): BaseRelation = {
+      tableExists: Boolean,
+      staticPartitionSpec: Map[String, String]): BaseRelation = {
     // Create the relation based on the input logical plan: `query`.
     val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_))
     val dataSource = DataSource(
@@ -241,7 +244,8 @@ case class CreateDataSourceTableAsSelectCommand(
       catalogTable = if (tableExists) Some(table) else None)
 
     try {
-      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics)
+      dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics,
+        staticPartitionSpec)
     } catch {
       case ex: AnalysisException =>
         logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
@@ -251,4 +255,7 @@
 
   override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
     copy(query = newChild)
+
+  override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand =
+    copy(staticPartitions = partitionSpec)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d50fd88f65c2e..62a137fe06485 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -450,7 +450,10 @@
    * The returned command is unresolved and need to be analyzed.
    */
   private def planForWritingFileFormat(
-      format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = {
+      format: FileFormat,
+      mode: SaveMode,
+      data: LogicalPlan,
+      staticPartitions: Map[String, String] = Map.empty): InsertIntoHadoopFsRelationCommand = {
     // Don't glob path for the write path. The contracts here are:
     //  1. Only one output path can be specified on the write path;
     //  2. Output path must be a legal HDFS style file system path;
@@ -477,7 +480,7 @@
     // will be adjusted within InsertIntoHadoopFsRelation.
     InsertIntoHadoopFsRelationCommand(
       outputPath = outputPath,
-      staticPartitions = Map.empty,
+      staticPartitions = staticPartitions,
       ifPartitionNotExists = false,
       partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
       bucketSpec = bucketSpec,
@@ -504,13 +507,15 @@
    *                     command with this physical plan instead of creating a new physical plan,
    *                     so that the metrics can be correctly linked to the given physical plan and
    *                     shown in the web UI.
+   * @param staticPartitions The static partitions for this write.
    */
   def writeAndRead(
       mode: SaveMode,
       data: LogicalPlan,
       outputColumnNames: Seq[String],
       physicalPlan: SparkPlan,
-      metrics: Map[String, SQLMetric]): BaseRelation = {
+      metrics: Map[String, SQLMetric],
+      staticPartitions: Map[String, String] = Map.empty): BaseRelation = {
     val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
     providingInstance() match {
       case dataSource: CreatableRelationProvider =>
@@ -519,7 +524,7 @@ case class DataSource(
           sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
       case format: FileFormat =>
         disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
-        val cmd = planForWritingFileFormat(format, mode, data)
+        val cmd = planForWritingFileFormat(format, mode, data, staticPartitions)
         val resolvedPartCols = DataSource.resolvePartitionColumns(cmd.partitionColumns,
           outputColumns, data, equality)
         val resolved = cmd.copy(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 0216503fba0f4..337a05c1cf45d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -179,7 +179,7 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] {
     // Let's say that we have a table "t", which is created by
     // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
     // The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
-    // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
+    // will be converted to "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 2, 3".
     //
     // Basically, we will put those partition columns having a assigned value back
     // to the SELECT clause. The output of the SELECT clause is organized as
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 41b55a3b6e936..86bce5c469bea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -76,7 +76,7 @@ case class InsertIntoHadoopFsRelationCommand(
 
   override def requiredOrdering: Seq[SortOrder] =
     V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options,
-      staticPartitions.size)
+      numStaticPartitionCols)
 
   override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
     // Most formats don't do well with duplicate columns, so lets not allow that
@@ -188,7 +188,7 @@ case class InsertIntoHadoopFsRelationCommand(
         bucketSpec = bucketSpec,
         statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
         options = options,
-        numStaticPartitionCols = staticPartitions.size)
+        numStaticPartitionCols = numStaticPartitionCols)
 
       // update metastore partition metadata
 
@@ -278,4 +278,7 @@ case class InsertIntoHadoopFsRelationCommand(
 
   override protected def withNewChildInternal(
     newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)
+
+  override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand =
+    copy(staticPartitions = partitionSpec)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index d3cac32ae6632..dada4d3de913a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -17,11 +17,15 @@
 package org.apache.spark.sql.execution.datasources
 
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder, String2StringExpression, UnaryExpression}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Cast, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder, String2StringExpression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, RebalancePartitions, RepartitionOperation, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.command.DataWritingCommand
@@ -30,6 +34,16 @@
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.unsafe.types.UTF8String
 
 trait V1WriteCommand extends DataWritingCommand {
+  /**
+   * Specify the static partitions of the V1 write command.
+   */
+  def staticPartitions: Map[String, String]
+
+  /**
+   * The number of static partition columns in `partitionColumns`.
+   * Note that static partition columns must come before any dynamic partition columns.
+   */
+  final def numStaticPartitionCols: Int = staticPartitions.size
 
   /**
    * Specify the partition columns of the V1 write command.
@@ -41,6 +55,11 @@ trait V1WriteCommand extends DataWritingCommand {
    * add SortExec if necessary when the requiredOrdering is empty.
    */
   def requiredOrdering: Seq[SortOrder]
+
+  /**
+   * Replace the static partition spec for the V1 write command.
+   */
+  def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand
 }
 
 /**
@@ -102,6 +121,83 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
   }
 }
 
+/**
+ * This rule converts dynamic partitions to static partitions for v1 writes when the partition
+ * values are foldable, so that the sort required for dynamic partition writes can be avoided.
+ *
+ * For example, the SQL:
+ * {{{
+ *   INSERT INTO TABLE t1 PARTITION(p) SELECT c, 'a' as p FROM t2
+ *   =>
+ *   INSERT INTO TABLE t1 PARTITION(p='a') SELECT c FROM t2
+ * }}}
+ */
+object EliminateV1DynamicPartitionWrites extends Rule[LogicalPlan] {
+
+  @tailrec
+  private def queryOutput(p: LogicalPlan): Seq[NamedExpression] = p match {
+    case p: Project => p.projectList
+    case f: Filter => queryOutput(f.child)
+    case r: RepartitionOperation => queryOutput(r.child)
+    case r: RebalancePartitions => queryOutput(r.child)
+    case s: Sort => queryOutput(s.child)
+    case l: LocalLimit => queryOutput(l.child)
+    case l: GlobalLimit => queryOutput(l.child)
+    case _ => Seq.empty
+  }
+
+  private def getPartitionSpecString(part: Any): String = {
+    if (part == null) {
+      null
+    } else {
+      assert(part.isInstanceOf[UTF8String])
+      ExternalCatalogUtils.getPartitionSpecString(part.asInstanceOf[UTF8String].toString)
+    }
+  }
+
+  private def tryEvalStaticPartition(named: NamedExpression): Option[(String, String)] = {
+    named match {
+      case Alias(l: Literal, name) =>
+        Some((name, getPartitionSpecString(
+          Cast(l, StringType, Option(conf.sessionLocalTimeZone)).eval())))
+      case _ => None
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.eliminateDynamicPartitionWrites) {
+      return plan
+    }
+
+    val resolver = SparkSession.active.sessionState.analyzer.resolver
+    plan.transformDown {
+      case v1Writes: V1WriteCommand =>
+        val output = queryOutput(v1Writes.query)
+
+        // We cannot infer a static partition that comes after a dynamic partition column,
+        // for example:
+        //   INSERT INTO TABLE t PARTITION(p1, p2)
+        //   SELECT c, p1, 'a' as p2
+        var previousStaticPartition = true
+        val newStaticPartitionSpec = new mutable.HashMap[String, String]()
+        val it = v1Writes.partitionColumns.drop(v1Writes.numStaticPartitionCols)
+          .map(attr => output.find(o => resolver(attr.name, o.name))).iterator
+        while (previousStaticPartition && it.hasNext) {
+          it.next().flatMap(part => tryEvalStaticPartition(part)) match {
+            case Some((name, partitionValue)) => newStaticPartitionSpec.put(name, partitionValue)
+            case None => previousStaticPartition = false
+          }
+        }
+
+        if (newStaticPartitionSpec.nonEmpty) {
+          v1Writes.withNewStaticPartitionSpec(v1Writes.staticPartitions ++ newStaticPartitionSpec)
+        } else {
+          v1Writes
+        }
+    }
+  }
+}
+
 object V1WritesUtils {
 
   /** A function that converts the empty string to null for partition values. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index c7fa365abbdeb..6ff550672a9dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1539,7 +1539,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
       assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
       assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0)
       spark.sql(
-        s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, 'b' FROM t WHERE 1 = 0")
+        s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, a FROM t WHERE 1 = 0")
       assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0)
       assert(!new File(dir, "c=c/b=b").exists())
       checkAnswer(spark.table(tableName), Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index c18396b554d74..ad4d4708db5ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -17,9 +17,11 @@
 package org.apache.spark.sql.execution.datasources
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{CommandResultExec, QueryExecution}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.util.QueryExecutionListener
@@ -97,6 +99,33 @@ abstract class V1WriteCommandSuiteBase extends QueryTest with SQLTestUtils {
 
     spark.listenerManager.unregister(listener)
   }
+
+  private def getV1WriteCommand(df: DataFrame): V1WriteCommand = {
+    val plan = df.queryExecution.sparkPlan
+      .asInstanceOf[CommandResultExec].commandPhysicalPlan
+    val dataWritingCommandExec = plan match {
+      case aqe: AdaptiveSparkPlanExec => aqe.inputPlan
+      case _ => plan
+    }
+    val v1WriteCommand = dataWritingCommandExec.asInstanceOf[DataWritingCommandExec].cmd
+    assert(v1WriteCommand.isInstanceOf[V1WriteCommand])
+    v1WriteCommand.asInstanceOf[V1WriteCommand]
+  }
+
+  protected def checkStaticPartitions(
+      expectedStaticPartitions: Map[String, String],
+      hasLogicalSort: Boolean,
+      hasEmpty2Null: Boolean = false)(query: => DataFrame): Unit = {
+    executeAndCheckOrdering(hasLogicalSort, true, hasEmpty2Null) {
+      val df = query
+      val v1writes = getV1WriteCommand(df)
+      val actualStaticPartitions = v1writes.staticPartitions
+      assert(actualStaticPartitions.size == expectedStaticPartitions.size)
+      actualStaticPartitions.foreach { case (k, v) =>
+        assert(expectedStaticPartitions.contains(k) && expectedStaticPartitions(k) == v)
+      }
+    }
+  }
 }
 
 class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSession {
@@ -277,4 +306,112 @@ class V1WriteCommandSuite extends V1WriteCommandSuiteBase with SharedSparkSessio
       }
     }
   }
+
+  test("SPARK-40354: Support eliminate dynamic partition for v1 writes - insert") {
+    withTable("t") {
+      sql(
+ """ + |CREATE TABLE t(key INT, value STRING) USING PARQUET + |PARTITIONED BY (p1 INT, p2 STRING) + |""".stripMargin) + + // optimize all dynamic partition to static with special partition value + checkStaticPartitions(Map("p1" -> null, "p2" -> null), hasLogicalSort = false) { + sql( + """ + |INSERT INTO t PARTITION(p1, p2) + |SELECT key, value, cast(null as int) as p1, '' as p2 FROM testData + |""".stripMargin) + } + + Seq("WHERE key = 1", "DISTRIBUTE BY key", "ORDER BY key", "LIMIT 10").foreach { extra => + val hasLogicalSort = extra.contains("ORDER") + // optimize all dynamic partition to static + checkStaticPartitions(Map("p1" -> "1", "p2" -> "b"), hasLogicalSort = hasLogicalSort) { + sql( + s""" + |INSERT INTO t PARTITION(p1, p2) + |SELECT key, value, 1 as p1, 'b' as p2 FROM testData + |$extra + |""".stripMargin) + } + } + + // static partition ahead of dynamic + checkStaticPartitions(Map("p1" -> "1"), hasLogicalSort = true, hasEmpty2Null = true) { + sql( + """ + |INSERT INTO t PARTITION(p1, p2) + |SELECT key, value, 1 as p1, value as p2 FROM testData + |""".stripMargin) + } + + // dynamic partition ahead of static + checkStaticPartitions(Map.empty, hasLogicalSort = true) { + sql( + """ + |INSERT INTO t PARTITION(p1, p2) + |SELECT key, value, key as p1, 'b' as p2 FROM testData + |""".stripMargin) + } + + // all partition columns are dynamic + checkStaticPartitions(Map.empty, hasLogicalSort = true, hasEmpty2Null = true) { + sql( + """ + |INSERT INTO t PARTITION(p1, p2) + |SELECT key, value, key as p1, value p2 FROM testData + |""".stripMargin) + } + } + } + + test("SPARK-40354: Support eliminate dynamic partition for v1 writes - ctas") { + withTable("t1", "t2", "t3", "t4", "t5") { + // optimize all dynamic partition to static with special partition value + checkStaticPartitions(Map("p1" -> null, "p2" -> null), hasLogicalSort = false) { + sql( + """ + |CREATE TABLE t1 USING PARQUET PARTITIONED BY(p1, p2) AS + |SELECT key, value, cast(null as int) as p1, '' as p2 FROM testData + |""".stripMargin) + } + + // optimize all dynamic partition to static + checkStaticPartitions(Map("p1" -> "1", "p2" -> "b"), hasLogicalSort = false) { + sql( + """ + |CREATE TABLE t2 USING PARQUET PARTITIONED BY(p1, p2) AS + |SELECT key, value, 1 as p1, 'b' as p2 FROM testData + |""".stripMargin) + } + + // static partition ahead of dynamic + checkStaticPartitions(Map("p1" -> "1"), hasLogicalSort = true, hasEmpty2Null = true) { + sql( + """ + |CREATE TABLE t3 USING PARQUET PARTITIONED BY(p1, p2) AS + |SELECT key, value, 1 as p1, value as p2 FROM testData + |""".stripMargin) + } + + // dynamic partition ahead of static + checkStaticPartitions(Map.empty, hasLogicalSort = true) { + sql( + """ + |CREATE TABLE t4 USING PARQUET PARTITIONED BY(p1, p2) AS + |SELECT key, value, key as p1, 'b' as p2 FROM testData + |""".stripMargin) + } + + // all partition columns are dynamic + checkStaticPartitions(Map.empty, hasLogicalSort = true, hasEmpty2Null = true) { + sql( + """ + |CREATE TABLE t5 USING PARQUET PARTITIONED BY(p1, p2) AS + |SELECT key, value, key as p1, value as p2 FROM testData + |""".stripMargin) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index ce32077502770..7718fee5d9e23 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ 
@@ -36,6 +36,7 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils
   val query: LogicalPlan
   val outputColumnNames: Seq[String]
   val mode: SaveMode
+  override def staticPartitions: Map[String, String] = Map.empty
 
   protected val tableIdentifier = tableDesc.identifier
 
@@ -157,6 +158,8 @@ case class CreateHiveTableAsSelectCommand(
 
   override protected def withNewChildInternal(
     newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild)
+
+  override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand = this
 }
 
 /**
@@ -207,4 +210,6 @@ case class OptimizedCreateHiveTableAsSelectCommand(
 
   override protected def withNewChildInternal(
     newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild)
+
+  override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand = this
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 8c3aa0a80c1b7..16fa82f7f83bc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -85,6 +85,9 @@ case class InsertIntoHiveTable(
     V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
   }
 
+  def staticPartitions: Map[String, String] =
+    partition.filter(_._2.isDefined).map(kv => kv._1 -> kv._2.get)
+
   /**
    * Inserts all the rows in the table into Hive. Row objects are properly serialized with the
    * `org.apache.hadoop.hive.serde2.SerDe` and the
@@ -293,4 +296,6 @@ case class InsertIntoHiveTable(
 
   override protected def withNewChildInternal(newChild: LogicalPlan): InsertIntoHiveTable =
     copy(query = newChild)
+
+  override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand = this
 }