@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType

/**
* Performs (external) sorting.
@@ -71,36 +72,8 @@ case class SortExec(
* should make it public.
*/
def createSorter(): UnsafeExternalRowSorter = {
val ordering = RowOrdering.create(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)

// The generator for prefix
val prefixExpr = SortPrefix(boundSortExpression)
val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
val prefix = prefixProjection.apply(row)
result.isNull = prefix.isNullAt(0)
result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
result
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
rowSorter = SortExec.createSorter(
sortOrder, output, schema, enableRadixSort, testSpillFrequency)
rowSorter
}

@@ -206,3 +179,43 @@ case class SortExec(
override protected def withNewChildInternal(newChild: SparkPlan): SortExec =
copy(child = newChild)
}
object SortExec {
def createSorter(
Contributor Author commented:

This change is needed because maxConcurrentOutputFileWriters has to create the sorter inside FileFormatWriter (see the usage sketch after this file's diff).

sortOrder: Seq[SortOrder],
output: Seq[Attribute],
schema: StructType,
enableRadixSort: Boolean,
testSpillFrequency: Int = 0): UnsafeExternalRowSorter = {
val ordering = RowOrdering.create(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)

// The generator for prefix
val prefixExpr = SortPrefix(boundSortExpression)
val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
val prefix = prefixProjection.apply(row)
result.isNull = prefix.isNullAt(0)
result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
result
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
rowSorter
}
}
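A minimal usage sketch of the extracted helper, mirroring the call added in FileFormatWriter further down in this diff. The names requiredSortOrder and maxWriters stand in for the values computed there (getSortOrder(...) and the maxConcurrentOutputFileWriters conf) and are assumptions, not code from the PR:

// Sketch only: build the concurrent-writer spec without instantiating a SortExec node.
val concurrentWriterSpec = ConcurrentOutputWriterSpec(
  maxWriters,
  () => SortExec.createSorter(
    sortOrder = requiredSortOrder,                     // assumed Seq[SortOrder] for the write
    output = empty2NullPlan.output,                    // attributes of the plan being written
    schema = empty2NullPlan.schema,                    // row schema to sort on
    enableRadixSort = sparkSession.sessionState.conf.enableRadixSort))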
@@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{V2ScanPartitioning, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -38,6 +37,7 @@ class SparkOptimizer(
override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
Seq(SchemaPruning) :+
V1Writes :+
V2ScanRelationPushDown :+
V2ScanPartitioning :+
V2Writes :+
@@ -78,6 +78,7 @@ class SparkOptimizer(
ExtractPythonUDFFromJoinCondition.ruleName :+
ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+
ExtractPythonUDFs.ruleName :+
V1Writes.ruleName :+
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioning.ruleName :+
V2Writes.ruleName
@@ -21,6 +21,7 @@ import java.net.URI

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -141,7 +142,18 @@ case class CreateDataSourceTableAsSelectCommand(
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
extends DataWritingCommand {
extends V1Write {

override lazy val partitionColumns: Seq[Attribute] = {
table.partitionColumnNames.map { name =>
query.resolve(name :: Nil, SparkSession.active.sessionState.analyzer.resolver).getOrElse {
throw QueryCompilationErrors.cannotResolveAttributeError(
name, query.output.map(_.name).mkString(", "))
}.asInstanceOf[Attribute]
}
}
override lazy val bucketSpec: Option[BucketSpec] = table.bucketSpec
override lazy val options: Map[String, String] = table.storage.properties

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
@@ -34,9 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
@@ -47,7 +45,7 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}


/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
object FileFormatWriter extends Logging with V1WritesHelper {
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
outputPath: String,
@@ -78,6 +76,7 @@ object FileFormatWriter extends Logging {
maxWriters: Int,
createSorter: () => UnsafeExternalRowSorter)

// scalastyle:off argcount
/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -100,6 +99,7 @@
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
staticPartitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String])
@@ -126,39 +126,7 @@
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan

val writerBucketSpec = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true") {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
// columns is negative. See Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix is following Hive, Presto and Trino conversion, so this
// makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
// expression, so that we can guarantee the data distribution is same between shuffle and
// bucketed data source, which enables us to only shuffle one side when join a bucketed
// table and a normal one.
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
.partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val writerBucketSpec = getWriterBucketSpec(bucketSpec, dataColumns, options)
val caseInsensitiveOptions = CaseInsensitiveMap(options)

val dataSchema = dataColumns.toStructType
@@ -184,20 +152,6 @@
statsTrackers = statsTrackers
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering =
partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
ulysses-you (Contributor, Author) commented on Mar 2, 2022:

There is an issue here now that we have AQE: the plan is an AdaptiveSparkPlanExec, which has no outputOrdering, so for dynamic partition writes this code always adds an extra sort.

This PR resolves that issue as well. @cloud-fan @c21
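// Illustration (not from the PR) of the issue described in the comment above, with AQE enabled:
//   empty2NullPlan                                     // an AdaptiveSparkPlanExec under AQE
//   empty2NullPlan.outputOrdering                      // Nil, so actualOrdering is empty
//   requiredOrdering.length > actualOrdering.length    // true whenever any ordering is required
// Hence orderingMatched below is always false and a redundant SortExec is planned.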

val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

SQLExecution.checkSQLExecutionId(sparkSession)

// propagate the description UUID into the jobs, so that committers
@@ -208,29 +162,26 @@
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

val sortColumns = getBucketSortColumns(bucketSpec, dataColumns)
try {
val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
(empty2NullPlan.execute(), None)
val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
val concurrentOutputWriterSpec = if (concurrentWritersEnabled) {
val output = empty2NullPlan.output
val enableRadixSort = sparkSession.sessionState.conf.enableRadixSort
val outputSchema = empty2NullPlan.schema
Some(ConcurrentOutputWriterSpec(maxWriters,
() => SortExec.createSorter(
Contributor commented:

I feel this refactoring (SortExec.createSorter) is not really necessary. Why can't we create a SortExec operator and call createSorter() as before? What is the advantage of the current code over the previous version?

Contributor Author commented:

Looking at the previous code, we created and evaluated a SortExec mainly to obtain the ordering for dynamic partition writes. For the concurrent writers we only need the sorter, so now that the sort has been pulled out, creating a new SortExec seems like overkill.

getSortOrder(output, partitionColumns, staticPartitionColumns.size,
bucketSpec, options),
output,
outputSchema,
enableRadixSort
)))
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = bindReferences(
requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns)
val sortPlan = SortExec(
orderingExpr,
global = false,
child = empty2NullPlan)

val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
if (concurrentWritersEnabled) {
(empty2NullPlan.execute(),
Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
} else {
(sortPlan.execute(), None)
}
None
}
val rdd = empty2NullPlan.execute()

// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
// partition rdd to make sure we at least set up one write task to write the metadata.
@@ -278,6 +229,7 @@
throw QueryExecutionErrors.jobAbortedError(cause)
}
}
// scalastyle:on argcount

/** Writes data out in a single Spark task. */
private def executeTask(
@@ -48,16 +48,16 @@ case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
staticPartitions: TablePartitionSpec,
ifPartitionNotExists: Boolean,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
override val partitionColumns: Seq[Attribute],
override val bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
override val options: Map[String, String],
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
extends DataWritingCommand {
extends V1Write {

private lazy val parameters = CaseInsensitiveMap(options)

@@ -74,6 +74,8 @@
staticPartitions.size < partitionColumns.length
}

override lazy val numStaticPartitions: Int = staticPartitions.size

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
// Most formats don't do well with duplicate columns, so lets not allow that
SchemaUtils.checkColumnNameDuplication(
@@ -181,6 +183,7 @@
committerOutputPath.toString, customPartitionLocations, outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
staticPartitionColumns = partitionColumns.take(staticPartitions.size),
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options)