@@ -122,6 +122,14 @@ object ExternalCatalogUtils {
}
}

def getPartitionSpecString(value: String): String = {
if (value == null || value.isEmpty) {
null
} else {
value
}
}

def getPartitionValueString(value: String): String = {
if (value == null || value.isEmpty) {
DEFAULT_PARTITION_NAME
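A rough sketch of how the two catalog helpers differ, based on the branches visible above (the escaping of non-empty values in getPartitionValueString is collapsed here, so treat those result comments as assumptions rather than guarantees):

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

// New helper: null and empty values stay null in the partition spec.
ExternalCatalogUtils.getPartitionSpecString(null)   // null
ExternalCatalogUtils.getPartitionSpecString("")     // null
ExternalCatalogUtils.getPartitionSpecString("2022") // "2022"

// Existing helper: null and empty values map to the default partition name
// (__HIVE_DEFAULT_PARTITION__) used in the physical partition path.
ExternalCatalogUtils.getPartitionValueString(null)  // "__HIVE_DEFAULT_PARTITION__"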
@@ -420,6 +420,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val ELIMINATE_DYNAMIC_PARTITION_WRITES =
buildConf("spark.sql.optimizer.eliminateDynamicPartitionWrites")
.internal()
.doc("When set to true, Spark optimizer will infer if the partition column is static and " +
"convert it to static partition.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val COMPRESS_CACHED = buildConf("spark.sql.inMemoryColumnarStorage.compressed")
.doc("When set to true Spark SQL will automatically select a compression codec for each " +
"column based on statistics of the data.")
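For context, a minimal sketch of how the new flag could be toggled from a session; it is an internal conf, so it would normally be left at its default (the table and query below are made up for illustration):

// Default (true): a foldable partition value such as 'a' below is folded into
// the static partition spec, so the planned write can skip the per-partition sort.
spark.conf.set("spark.sql.optimizer.eliminateDynamicPartitionWrites", "true")
spark.sql("INSERT INTO TABLE t1 PARTITION (p) SELECT c, 'a' AS p FROM t2")

// Opting out keeps the original dynamic-partition write behavior.
spark.conf.set("spark.sql.optimizer.eliminateDynamicPartitionWrites", "false")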
@@ -4684,6 +4693,9 @@ class SQLConf extends Serializable with Logging {

def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)

def eliminateDynamicPartitionWrites: Boolean =
getConf(SQLConf.ELIMINATE_DYNAMIC_PARTITION_WRITES)

def inferDictAsStruct: Boolean = getConf(SQLConf.INFER_NESTED_DICT_AS_STRUCT)

def legacyInferArrayTypeFromFirstElement: Boolean = getConf(
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.{EliminateV1DynamicPartitionWrites, PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -38,6 +38,7 @@ class SparkOptimizer(
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
GroupBasedRowLevelOperationScanPlanning :+
EliminateV1DynamicPartitionWrites :+
V1Writes :+
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
@@ -142,7 +142,8 @@ case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
outputColumnNames: Seq[String],
staticPartitions: Map[String, String] = Map.empty)
extends V1WriteCommand {

override lazy val partitionColumns: Seq[Attribute] = {
@@ -156,7 +157,8 @@

override def requiredOrdering: Seq[SortOrder] = {
val options = table.storage.properties
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options,
numStaticPartitionCols)
}

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -181,8 +183,8 @@
return Seq.empty
}

saveDataIntoTable(
sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true)
saveDataIntoTable(sparkSession, table, table.storage.locationUri, child, SaveMode.Append,
tableExists = true, staticPartitions)
} else {
table.storage.locationUri.foreach { p =>
DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf)
@@ -194,8 +196,8 @@
} else {
table.storage.locationUri
}
val result = saveDataIntoTable(
sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false)
val result = saveDataIntoTable(sparkSession, table, tableLocation, child, SaveMode.Overwrite,
tableExists = false, staticPartitions)
val tableSchema = CharVarcharUtils.getRawSchema(result.schema, sessionState.conf)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
@@ -229,7 +231,8 @@ case class CreateDataSourceTableAsSelectCommand(
tableLocation: Option[URI],
physicalPlan: SparkPlan,
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
tableExists: Boolean,
staticPartitionSpec: Map[String, String]): BaseRelation = {
// Create the relation based on the input logical plan: `query`.
val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_))
val dataSource = DataSource(
@@ -241,7 +244,8 @@
catalogTable = if (tableExists) Some(table) else None)

try {
dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics)
dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics,
staticPartitionSpec)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
@@ -251,4 +255,7 @@

override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
copy(query = newChild)

override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand =
copy(staticPartitions = partitionSpec)
}
@@ -450,7 +450,10 @@ case class DataSource(
* The returned command is unresolved and needs to be analyzed.
*/
private def planForWritingFileFormat(
format: FileFormat, mode: SaveMode, data: LogicalPlan): InsertIntoHadoopFsRelationCommand = {
format: FileFormat,
mode: SaveMode,
data: LogicalPlan,
staticPartitions: Map[String, String] = Map.empty): InsertIntoHadoopFsRelationCommand = {
// Don't glob path for the write path. The contracts here are:
// 1. Only one output path can be specified on the write path;
// 2. Output path must be a legal HDFS style file system path;
@@ -477,7 +480,7 @@
// will be adjusted within InsertIntoHadoopFsRelation.
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
staticPartitions = staticPartitions,
Contributor:
Can you explain more about the difference between write with dynamic partition columns and without? The code can be simplified quite a lot if we just need to remove sort.

Contributor Author:
I think the main benefit is that we can save a local sort. Yes, I agree the current implementation is overly complex. The reason is that I guessed some downstream projects or extensions may depend on the static partitions (e.g. adding a repartition for dynamic partition writes), so I merged the inferred static partitions into the original ones. I'm fine with simplifying the code to just remove the sort.

Contributor (cloud-fan, Sep 9, 2022):
The downstream projects can detect the real dynamic partition columns by themselves. I'd prefer a simpler solution here that just removes the local sort.
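A hypothetical way to observe the difference the thread is discussing (the table and data are made up): with the rule enabled, the constant partition value is treated as static, so the planned write no longer needs the local sort on the partition column that is otherwise required for a dynamic partition.

// Hypothetical repro: compare the planned write with the rule on and off.
spark.sql("CREATE TABLE t1 (c INT) USING parquet PARTITIONED BY (p STRING)")

// With spark.sql.optimizer.eliminateDynamicPartitionWrites=true the plan for
// this insert should not contain a Sort on p; with it set to false, a local
// sort on p may be added because p is still considered dynamic.
spark.sql(
  "EXPLAIN FORMATTED INSERT INTO TABLE t1 PARTITION (p) " +
    "SELECT CAST(id AS INT) AS c, 'a' AS p FROM range(10)").show(truncate = false)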

ifPartitionNotExists = false,
partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
bucketSpec = bucketSpec,
@@ -504,13 +507,15 @@
* command with this physical plan instead of creating a new physical plan,
* so that the metrics can be correctly linked to the given physical plan and
* shown in the web UI.
* @param staticPartitions The static partition spec for this write.
*/
def writeAndRead(
mode: SaveMode,
data: LogicalPlan,
outputColumnNames: Seq[String],
physicalPlan: SparkPlan,
metrics: Map[String, SQLMetric]): BaseRelation = {
metrics: Map[String, SQLMetric],
staticPartitions: Map[String, String] = Map.empty): BaseRelation = {
val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
providingInstance() match {
case dataSource: CreatableRelationProvider =>
@@ -519,7 +524,7 @@
sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
case format: FileFormat =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
val cmd = planForWritingFileFormat(format, mode, data)
val cmd = planForWritingFileFormat(format, mode, data, staticPartitions)
val resolvedPartCols =
DataSource.resolvePartitionColumns(cmd.partitionColumns, outputColumns, data, equality)
val resolved = cmd.copy(
@@ -179,7 +179,7 @@ case class DataSourceAnalysis(analyzer: Analyzer) extends Rule[LogicalPlan] {
// Let's say that we have a table "t", which is created by
// CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
// The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
// will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
// will be converted to "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 2, 3".
//
// Basically, we will put those partition columns having an assigned value back
// to the SELECT clause. The output of the SELECT clause is organized as
@@ -76,7 +76,7 @@ case class InsertIntoHadoopFsRelationCommand(

override def requiredOrdering: Seq[SortOrder] =
V1WritesUtils.getSortOrder(outputColumns, partitionColumns, bucketSpec, options,
staticPartitions.size)
numStaticPartitionCols)

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
@@ -188,7 +188,7 @@ case class InsertIntoHadoopFsRelationCommand(
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options,
numStaticPartitionCols = staticPartitions.size)
numStaticPartitionCols = numStaticPartitionCols)


// update metastore partition metadata
@@ -278,4 +278,7 @@ case class InsertIntoHadoopFsRelationCommand(

override protected def withNewChildInternal(
newChild: LogicalPlan): InsertIntoHadoopFsRelationCommand = copy(query = newChild)

override def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand =
copy(staticPartitions = partitionSpec)
}
@@ -17,11 +17,15 @@

package org.apache.spark.sql.execution.datasources

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder, String2StringExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Cast, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder, String2StringExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, RebalancePartitions, RepartitionOperation, Sort}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DataWritingCommand
@@ -30,6 +34,16 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

trait V1WriteCommand extends DataWritingCommand {
/**
* Specify the static partitions of the V1 write command.
*/
def staticPartitions: Map[String, String]

/**
* The number of static partition columns in `partitionColumns`.
* Note that the static partitions must come before the dynamic partition columns.
*/
final def numStaticPartitionCols: Int = staticPartitions.size

/**
* Specify the partition columns of the V1 write command.
@@ -41,6 +55,11 @@ trait V1WriteCommand extends DataWritingCommand {
* add SortExec if necessary when the requiredOrdering is empty.
*/
def requiredOrdering: Seq[SortOrder]

/**
* Replace the static partition spec for the V1 write command.
*/
def withNewStaticPartitionSpec(partitionSpec: Map[String, String]): V1WriteCommand
}
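To illustrate the "static partitions come first" constraint documented above, on a hypothetical table t partitioned by (p1, p2) with a made-up source src: a constant value can only be folded while every partition column to its left is already static.

// p1 is constant and leads the partition columns, so it can be folded:
// effectively PARTITION (p1='a', p2), with only p2 remaining dynamic.
spark.sql("INSERT INTO TABLE t PARTITION (p1, p2) SELECT c, 'a' AS p1, p2 FROM src")

// p2 is constant but sits behind the dynamic p1, so nothing can be folded
// and both partition columns stay dynamic.
spark.sql("INSERT INTO TABLE t PARTITION (p1, p2) SELECT c, p1, 'b' AS p2 FROM src")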

/**
@@ -102,6 +121,83 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
}
}

/**
* This rule converts dynamic partitions into static partitions for v1 writes when the
* partition values are foldable, so that the unnecessary sort for dynamic partitions can be avoided.
*
* For example, a pure SQL:
* {{{
* INSERT INTO TABLE t1 PARTITION(p) SELECT c, 'a' as p FROM t2
* =>
* INSERT INTO TABLE t1 PARTITION(p='a') SELECT c FROM t2
* }}}
*/
object EliminateV1DynamicPartitionWrites extends Rule[LogicalPlan] {

@tailrec
private def queryOutput(p: LogicalPlan): Seq[NamedExpression] = p match {
case p: Project => p.projectList
case f: Filter => queryOutput(f.child)
case r: RepartitionOperation => queryOutput(r.child)
case r: RebalancePartitions => queryOutput(r.child)
case s: Sort => queryOutput(s.child)
case l: LocalLimit => queryOutput(l.child)
case l: GlobalLimit => queryOutput(l.child)
case _ => Seq.empty
}

private def getPartitionSpecString(part: Any): String = {
if (part == null) {
null
} else {
assert(part.isInstanceOf[UTF8String])
ExternalCatalogUtils.getPartitionSpecString(part.asInstanceOf[UTF8String].toString)
Contributor Author:
Null and empty partition values are treated as null, and the final partition path will be __HIVE_DEFAULT_PARTITION__.

}
}

private def tryEvalStaticPartition(named: NamedExpression): Option[(String, String)] = {
named match {
case Alias(l: Literal, name) =>
Some((name, getPartitionSpecString(
Cast(l, StringType, Option(conf.sessionLocalTimeZone)).eval())))
case _ => None
}
}

override def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.eliminateDynamicPartitionWrites) {
return plan
}

val resolver = SparkSession.active.sessionState.analyzer.resolver
plan.transformDown {
case v1Writes: V1WriteCommand =>
val output = queryOutput(v1Writes.query)

// We cannot infer a static partition that comes after a dynamic partition column,
// for example:
// INSERT INTO TABLE t PARTITION BY(p1, p2)
// SELECT c, p1, 'a' as p2
var previousStaticPartition = true
val newStaticPartitionSpec = new mutable.HashMap[String, String]()
val it = v1Writes.partitionColumns.drop(v1Writes.numStaticPartitionCols)
.map(attr => output.find(o => resolver(attr.name, o.name))).iterator
while (previousStaticPartition && it.hasNext) {
it.next().flatMap(part => tryEvalStaticPartition(part)) match {
case Some((name, partitionValue)) => newStaticPartitionSpec.put(name, partitionValue)
case None => previousStaticPartition = false
}
}

if (newStaticPartitionSpec.nonEmpty) {
v1Writes.withNewStaticPartitionSpec(v1Writes.staticPartitions ++ newStaticPartitionSpec)
} else {
v1Writes
}
}
}
}

object V1WritesUtils {

/** A function that converts the empty string to null for partition values. */
@@ -1539,7 +1539,7 @@ abstract class DDLSuite extends QueryTest with DDLSuiteBase {
assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0)
spark.sql(
s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, 'b' FROM t WHERE 1 = 0")
s"INSERT INTO TABLE $tableName PARTITION(c='c', b) SELECT *, a FROM t WHERE 1 = 0")
Contributor Author:
This test is meant to validate a partial static partition, so use a column instead of a literal to keep partition b dynamic.

assert(spark.sql(s"SHOW PARTITIONS $tableName").count() == 0)
assert(!new File(dir, "c=c/b=b").exists())
checkAnswer(spark.table(tableName), Nil)
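For contrast, a hedged sketch of why the literal had to go (the table names are placeholders, not the ones used in the suite): with the new rule a literal value for b would be folded into the static spec, so only a genuine column reference keeps the statement on the partial-static-partition code path this test exists to cover.

// Folded by EliminateV1DynamicPartitionWrites into PARTITION (c='c', b='b'):
spark.sql("INSERT INTO TABLE target PARTITION (c='c', b) SELECT a, 'b' FROM src")

// b stays dynamic because its value comes from a column, as in the updated test:
spark.sql("INSERT INTO TABLE target PARTITION (c='c', b) SELECT a, a FROM src")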