diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml
index 5b227cb5e29..8c984e4cab4 100644
--- a/externals/kyuubi-spark-sql-engine/pom.xml
+++ b/externals/kyuubi-spark-sql-engine/pom.xml
@@ -65,6 +65,13 @@
       <scope>provided</scope>
     </dependency>
 
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-repl_${scala.binary.version}</artifactId>
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index b29d2ca9a7e..ca30f53001f 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -187,34 +185,15 @@ class ArrowBasedExecuteStatement(
     handle) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
-    }
+    toArrowBatchLocalIterator(convertComplexType(resultDF))
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-      rdd.collect()
-    }
-  }
-
-  /**
-   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
-   * operation, so that we can track the arrow-based queries on the UI tab.
-   */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
-    }
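+    // Descriptive note (editorial): `resultDF.limit(maxRows)` plans to an outermost
+    // `CollectLimitExec`, which `executeCollect` collects partition-by-partition via
+    // `SparkDatasetHelper.executeArrowBatchCollect`, so the extra shuffle that the old
+    // `df.limit(n)` + full collect path introduced here is avoided.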
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
   }
 
   override protected def isArrowBasedOperation: Boolean = true
@@ -222,7 +201,6 @@ class ArrowBasedExecuteStatement(
   override val resultFormat = "arrow"
 
   private def convertComplexType(df: DataFrame): DataFrame = {
-    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+    convertTopLevelComplexTypeToHiveString(df, timestampAsString)
   }
-
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
new file mode 100644
index 00000000000..dd6163ec97c
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.arrow
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.channels.Channels
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
+import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.CollectLimitExec
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.Utils
+
+object KyuubiArrowConverters extends SQLConfHelper with Logging {
+
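+  // A `Batch` pairs a serialized Arrow record batch with its row count, letting callers
+  // track how many rows have been collected without deserializing the bytes.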
+  type Batch = (Array[Byte], Long)
+
+  /**
+   * Slices the input serialized Arrow record batch `bytes`, keeping `length` rows starting
+   * at row `start`.
+   */
+  def slice(
+      schema: StructType,
+      timeZoneId: String,
+      bytes: Array[Byte],
+      start: Int,
+      length: Int): Array[Byte] = {
+    val in = new ByteArrayInputStream(bytes)
+    val out = new ByteArrayOutputStream(bytes.length)
+
+    var vectorSchemaRoot: VectorSchemaRoot = null
+    var slicedVectorSchemaRoot: VectorSchemaRoot = null
+
+    val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "slice",
+      0,
+      Long.MaxValue)
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator)
+    try {
+      val recordBatch = MessageSerializer.deserializeRecordBatch(
+        new ReadChannel(Channels.newChannel(in)),
+        sliceAllocator)
+      val vectorLoader = new VectorLoader(vectorSchemaRoot)
+      vectorLoader.load(recordBatch)
+      recordBatch.close()
+      slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length)
+
+      val unloader = new VectorUnloader(slicedVectorSchemaRoot)
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+      val batch = unloader.getRecordBatch()
+      MessageSerializer.serialize(writeChannel, batch)
+      batch.close()
+      out.toByteArray()
+    } finally {
+      in.close()
+      out.close()
+      if (vectorSchemaRoot != null) {
+        vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+        vectorSchemaRoot.close()
+      }
+      if (slicedVectorSchemaRoot != null) {
+        slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
+        slicedVectorSchemaRoot.close()
+      }
+      sliceAllocator.close()
+    }
+  }
+
+  /**
+   * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`; the algorithm can be
+   * summarized in the following steps:
+   * 1. If the limit specified in the CollectLimitExec object is 0, return an empty array of
+   *    batches.
+   * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD
+   *    of data to collect.
+   * 3. Use an iterative approach to collect data in batches until the specified limit is reached.
+   *    In each iteration, select a subset of the partitions of the RDD to scan and try to
+   *    collect data from them.
+   * 4. For each partition subset, use the runJob method of the Spark context to execute a
+   *    closure that scans the partition data and converts it to Arrow batches.
+   * 5. Check whether the collected data reaches the specified limit. If not, select another
+   *    subset of partitions to scan and repeat the process until the limit is reached or all
+   *    partitions have been scanned.
+   * 6. Return an array of all the collected Arrow batches.
+   *
+   * Note that:
+   * 1. The total row count of the returned Arrow batches may exceed `limit` if the input
+   *    DataFrame contains more than `limit` rows.
+   * 2. We don't implement the `takeFromEnd` logic.
+   */
+  def takeAsArrowBatches(
+      collectLimitExec: CollectLimitExec,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      timeZoneId: String): Array[Batch] = {
+    val n = collectLimitExec.limit
+    val schema = collectLimitExec.schema
+    if (n == 0) {
+      return new Array[Batch](0)
+    } else {
+      val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
+      // TODO: refactor and reuse the code from RDD's take()
+      val childRDD = collectLimitExec.child.execute()
+      val buf = new ArrayBuffer[Batch]
+      var bufferedRowSize = 0L
+      val totalParts = childRDD.partitions.length
+      var partsScanned = 0
+      while (bufferedRowSize < n && partsScanned < totalParts) {
+        // The number of partitions to try in this iteration. It is ok for this number to be
+        // greater than totalParts because we actually cap it at totalParts in runJob.
+        var numPartsToTry = limitInitialNumPartitions
+        if (partsScanned > 0) {
+          // If we didn't find any rows after the previous iteration, multiply by
+          // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
+          // to try, but overestimate it by 50%. We also cap the estimation in the end.
+          if (buf.isEmpty) {
+            numPartsToTry = partsScanned * limitScaleUpFactor
+          } else {
+            val left = n - bufferedRowSize
+            // As left > 0, numPartsToTry is always >= 1
+            numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
+            numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
+          }
+        }
+
+        val partsToScan =
+          partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
+
+        // TODO: `SparkPlan.session` was introduced in SPARK-35798; replace `SparkSession.active`
+        // with `SparkPlan.session` once we drop Spark 3.1.x support.
+        val sc = SparkSession.active.sparkContext
+        val res = sc.runJob(
+          childRDD,
+          (it: Iterator[InternalRow]) => {
+            val batches = toBatchIterator(
+              it,
+              schema,
+              maxRecordsPerBatch,
+              maxEstimatedBatchSize,
+              n,
+              timeZoneId)
+            batches.map(b => b -> batches.rowCountInLastBatch).toArray
+          },
+          partsToScan)
+
+        var i = 0
+        while (bufferedRowSize < n && i < res.length) {
+          var j = 0
+          val batches = res(i)
+          while (j < batches.length && n > bufferedRowSize) {
+            val batch = batches(j)
+            val (_, batchSize) = batch
+            buf += batch
+            bufferedRowSize += batchSize
+            j += 1
+          }
+          i += 1
+        }
+        partsScanned += partsToScan.size
+      }
+
+      buf.toArray
+    }
+  }
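+
+  // Illustrative call (editorial sketch; the values below are hypothetical, and
+  // `SparkDatasetHelper.doCollectLimit` is the real caller):
+  //
+  //   val batches: Array[Batch] = KyuubiArrowConverters.takeAsArrowBatches(
+  //     collectLimitExec,
+  //     maxRecordsPerBatch = 10000,
+  //     maxEstimatedBatchSize = 4 << 20,
+  //     timeZoneId = "UTC")
+  //   // each element is (serializedBatchBytes, rowCountInThisBatch)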
+
+  /**
+   * Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0 (see
+   * SPARK-40211); on earlier versions we fall back to the default value 1.
+   */
+  private def limitInitialNumPartitions: Int = {
+    conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
+      .toInt
+  }
+
+  /**
+   * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
+   * the iterator returned here also exposes the row count of each output arrow batch.
+   */
+  private def toBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      limit: Long,
+      timeZoneId: String): ArrowBatchIterator = {
+    new ArrowBatchIterator(
+      rowIter,
+      schema,
+      maxRecordsPerBatch,
+      maxEstimatedBatchSize,
+      limit,
+      timeZoneId,
+      TaskContext.get)
+  }
+
+  /**
+   * This class ArrowBatchIterator is derived from
+   * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]],
+   * with two key differences:
+   * 1. there is no requirement to write the schema at the batch header, and
+   * 2. iteration halts as soon as `rowCount` reaches `limit`.
+   */
+  private[sql] class ArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      maxEstimatedBatchSize: Long,
+      limit: Long,
+      timeZoneId: String,
+      context: TaskContext)
+    extends Iterator[Array[Byte]] {
+
+    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    private val allocator =
+      ArrowUtils.rootAllocator.newChildAllocator(
+        s"to${this.getClass.getSimpleName}",
+        0,
+        Long.MaxValue)
+
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    protected val unloader = new VectorUnloader(root)
+    protected val arrowWriter = ArrowWriter.create(root)
+
+    Option(context).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+      root.close()
+      allocator.close()
+      false
+    }
+
+    var rowCountInLastBatch: Long = 0
+    var rowCount: Long = 0
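+
+    // `rowCount` counts rows across every batch emitted by this iterator (it is what halts
+    // iteration at `limit`), while `rowCountInLastBatch` is reset per batch and is the count
+    // that `takeAsArrowBatches` pairs with each batch's bytes.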
+
+    override def next(): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+      rowCountInLastBatch = 0
+      var estimatedBatchSize = 0L
+      Utils.tryWithSafeFinally {
+
+        // Always write the first row.
+        while (rowIter.hasNext && (
+            // For maxBatchSize and maxRecordsPerBatch, respect whichever is smaller.
+            // If the size in bytes is positive (set properly), always write the first row.
+            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
+              // If the max batch size in bytes is 0 or negative, don't limit the batch size.
+              estimatedBatchSize <= 0 ||
+              estimatedBatchSize < maxEstimatedBatchSize ||
+              // If the max number of rows per batch is 0 or negative, don't limit it.
+              maxRecordsPerBatch <= 0 ||
+              rowCountInLastBatch < maxRecordsPerBatch ||
+              rowCount < limit)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          estimatedBatchSize += (row match {
+            case ur: UnsafeRow => ur.getSizeInBytes
+            // Roughly estimate the size of the current row.
+            case _: InternalRow => schema.defaultSize
+          })
+          rowCountInLastBatch += 1
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        val batch = unloader.getRecordBatch()
+        MessageSerializer.serialize(writeChannel, batch)
+
+        // Always write the IPC options at the end.
+        ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+        batch.close()
+      } {
+        arrowWriter.reset()
+      }
+
+      out.toByteArray
+    }
+  }
+
+  // for testing
+  def fromBatchIterator(
+      arrowBatchIter: Iterator[Array[Byte]],
+      schema: StructType,
+      timeZoneId: String,
+      context: TaskContext): Iterator[InternalRow] = {
+    ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
+  }
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 1a542937338..1c8d32c4850 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -17,18 +17,75 @@
 
 package org.apache.spark.sql.kyuubi
 
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
+import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
+import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
 import org.apache.kyuubi.engine.spark.schema.RowSet
+import org.apache.kyuubi.reflection.DynMethods
+
+object SparkDatasetHelper extends Logging {
+
+  def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
+    executeArrowBatchCollect(df.queryExecution.executedPlan)
+  }
+
+  def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = {
+    case adaptiveSparkPlan: AdaptiveSparkPlanExec =>
+      executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan))
+    // TODO: avoid the extra shuffle when `offset` > 0
+    case collectLimit: CollectLimitExec if offset(collectLimit) > 0 =>
+      logWarning("unsupported offset > 0, an extra shuffle will be introduced.")
+      toArrowBatchRdd(collectLimit).collect()
+    case collectLimit: CollectLimitExec if collectLimit.limit >= 0 =>
+      doCollectLimit(collectLimit)
+    case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
+      executeArrowBatchCollect(collectLimit.child)
+    case plan: SparkPlan =>
+      toArrowBatchRdd(plan).collect()
+  }
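+
+  // Dispatch summary: AQE plans are first unwrapped to their final physical plan; a
+  // `CollectLimitExec` with offset 0 and a non-negative limit takes the shuffle-free
+  // `doCollectLimit` path; a negative limit means "no limit", so the child plan is collected
+  // directly; everything else falls back to collecting the whole Arrow batch RDD.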
-object SparkDatasetHelper {
 
   def toArrowBatchRdd[T](ds: Dataset[T]): RDD[Array[Byte]] = {
     ds.toArrowBatchRdd
   }
 
+  /**
+   * Forked from [[Dataset.toArrowBatchRdd(plan: SparkPlan)]].
+   * Convert to an RDD of serialized ArrowRecordBatches.
+   */
+  def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
+    val schemaCaptured = plan.schema
+    // TODO: `SparkPlan.session` was introduced in SPARK-35798; replace `SparkSession.active`
+    // with `SparkPlan.session` once we drop Spark 3.1.x support.
+    val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+    val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    plan.execute().mapPartitionsInternal { iter =>
+      val context = TaskContext.get()
+      ArrowConverters.toBatchIterator(
+        iter,
+        schemaCaptured,
+        maxRecordsPerBatch,
+        timeZoneId,
+        context)
+    }
+  }
+
+  def toArrowBatchLocalIterator(df: DataFrame): Iterator[Array[Byte]] = {
+    withNewExecutionId(df) {
+      toArrowBatchRdd(df).toLocalIterator
+    }
+  }
+
   def convertTopLevelComplexTypeToHiveString(
       df: DataFrame,
       timestampAsString: Boolean): DataFrame = {
@@ -68,11 +125,108 @@ object SparkDatasetHelper {
    * Fork from Apache Spark-3.3.1 org.apache.spark.sql.catalyst.util.quoteIfNeeded to adapt to
    * Spark-3.1.x
    */
-  def quoteIfNeeded(part: String): String = {
+  private def quoteIfNeeded(part: String): String = {
     if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) {
       part
     } else {
       s"`${part.replace("`", "``")}`"
     }
   }
+
+  private lazy val maxBatchSize: Long = {
+    // respect the Spark Connect config when present
+    KyuubiSparkUtil.globalSparkContext
+      .getConf
+      .getOption("spark.connect.grpc.arrow.maxBatchSize")
+      .orElse(Option("4m"))
+      .map(JavaUtils.byteStringAs(_, ByteUnit.MiB))
+      .get
+  }
+
+  private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = {
+    // TODO: `SparkPlan.session` was introduced in SPARK-35798; replace `SparkSession.active`
+    // with `SparkPlan.session` once we drop Spark 3.1.x support.
+    val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone
+    val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch
+
+    val batches = KyuubiArrowConverters.takeAsArrowBatches(
+      collectLimit,
+      maxRecordsPerBatch,
+      maxBatchSize,
+      timeZoneId)
+
+    // the total row count of the returned arrow batches may exceed `limit`, so slice the result
+    val result = ArrayBuffer[Array[Byte]]()
+    var i = 0
+    var rest = collectLimit.limit
+    while (i < batches.length && rest > 0) {
+      val (batch, size) = batches(i)
+      if (size <= rest) {
+        result += batch
+        // this batch fits within the remaining row budget, so it is safe to take it whole
+        rest -= size.toInt
+      } else { // size > rest
+        result += KyuubiArrowConverters.slice(collectLimit.schema, timeZoneId, batch, 0, rest)
+        rest = 0
+      }
+      i += 1
+    }
+    result.toArray
+  }
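+
+  // Worked example of the slicing loop above (hypothetical numbers): with `limit` = 25 and
+  // collected batches of 10, 10 and 10 rows, the first two batches are taken whole
+  // (rest: 25 -> 15 -> 5) and the third is sliced to its first 5 rows, so exactly 25 rows
+  // are returned.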
+
+  /**
+   * This method provides a reflection-based implementation of
+   * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime
+   * without patching SPARK-41914.
+   *
+   * TODO: Once we drop support for Spark 3.1.x, we can directly call
+   * [[AdaptiveSparkPlanExec.finalPhysicalPlan]].
+   */
+  def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = {
+    withFinalPlanUpdate(adaptiveSparkPlanExec, identity)
+  }
+
+  /**
+   * A reflection-based implementation of [[AdaptiveSparkPlanExec.withFinalPlanUpdate]].
+   */
+  private def withFinalPlanUpdate[T](
+      adaptiveSparkPlanExec: AdaptiveSparkPlanExec,
+      fun: SparkPlan => T): T = {
+    val getFinalPhysicalPlan = DynMethods.builder("getFinalPhysicalPlan")
+      .hiddenImpl(adaptiveSparkPlanExec.getClass)
+      .build()
+    val plan = getFinalPhysicalPlan.invoke[SparkPlan](adaptiveSparkPlanExec)
+    val result = fun(plan)
+    val finalPlanUpdate = DynMethods.builder("finalPlanUpdate")
+      .hiddenImpl(adaptiveSparkPlanExec.getClass)
+      .build()
+    finalPlanUpdate.invoke[Unit](adaptiveSparkPlanExec)
+    result
+  }
+
+  /**
+   * Offset support was added in Spark 3.4.0 (see SPARK-28330); to stay compatible with earlier
+   * Spark versions, this function reads `CollectLimitExec.offset` reflectively.
+   */
+  private def offset(collectLimitExec: CollectLimitExec): Int = {
+    Option(
+      DynMethods.builder("offset")
+        .impl(collectLimitExec.getClass)
+        .orNoop()
+        .build()
+        .invoke[Int](collectLimitExec))
+      .getOrElse(0)
+  }
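+
+  // On Spark < 3.4 `CollectLimitExec` has no `offset` method, so the no-op fallback built by
+  // `orNoop()` returns null and `Option(...).getOrElse(0)` yields 0, i.e. "no offset".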
+
+  /**
+   * Refer to org.apache.spark.sql.Dataset#withAction(): assign a new execution id to the
+   * arrow-based operation so that arrow-based queries can be tracked on the UI tab.
+   */
+  private def withNewExecutionId[T](df: DataFrame)(body: => T): T = {
+    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
+      df.queryExecution.executedPlan.resetMetrics()
+      body
+    }
+  }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index ae6237bb59c..2ef29b398a3 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -18,16 +18,28 @@
 package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
+import java.util.{Set => JSet}
 
 import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.util.QueryExecutionListener
 
+import org.apache.kyuubi.KyuubiException
 import org.apache.kyuubi.config.KyuubiConf
 import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
 import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
+import org.apache.kyuubi.reflection.DynFields
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -138,6 +150,155 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(metrics("numOutputRows").value === 1)
   }
 
+  test("SparkDatasetHelper.executeArrowBatchCollect should return the expected row count") {
+    val returnSize = Seq(
+      0, // the Spark optimizer guarantees `limit != 0`; this value is just a sanity check
+      7, // less than one partition
+      10, // equal to one partition
+      13, // between one and two partitions, runs two jobs
+      20, // equal to two partitions
+      29, // between two and three partitions
+      1000, // all partitions
+      1001) // more than the total row count
+
+    def runAndCheck(sparkPlan: SparkPlan, expectSize: Int): Unit = {
+      val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(sparkPlan)
+      val rows = KyuubiArrowConverters.fromBatchIterator(
+        arrowBinary.iterator,
+        sparkPlan.schema,
+        "",
+        KyuubiSparkContextHelper.dummyTaskContext())
+      assert(rows.size == expectSize)
+    }
+
+    val excludedRules = Seq(
+      "org.apache.spark.sql.catalyst.optimizer.EliminateLimits",
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeLimitZero",
+      "org.apache.spark.sql.execution.adaptive.AQEPropagateEmptyRelation").mkString(",")
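+    // These rules would rewrite the plans under test (e.g. fold `limit 0` into an empty
+    // relation), so they are disabled to keep a `CollectLimitExec` in the executed plan.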
+    withSQLConf(
+      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules,
+      SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
+      // aqe
+      // outermost AdaptiveSparkPlanExec
+      spark.range(1000)
+        .repartitionByRange(100, col("id"))
+        .createOrReplaceTempView("t_1")
+      spark.sql("select * from t_1")
+        .foreachPartition { p: Iterator[Row] =>
+          assert(p.length == 10)
+          ()
+        }
+      returnSize.foreach { size =>
+        val df = spark.sql(s"select * from t_1 limit $size")
+        val headPlan = df.queryExecution.executedPlan.collectLeaves().head
+        if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") {
+          assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec])
+          val finalPhysicalPlan =
+            SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec])
+          assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec])
+        }
+        if (size > 1000) {
+          runAndCheck(df.queryExecution.executedPlan, 1000)
+        } else {
+          runAndCheck(df.queryExecution.executedPlan, size)
+        }
+      }
+
+      // outermost CollectLimitExec
+      spark.range(0, 1000, 1, numPartitions = 100)
+        .createOrReplaceTempView("t_2")
+      spark.sql("select * from t_2")
+        .foreachPartition { p: Iterator[Row] =>
+          assert(p.length == 10)
+          ()
+        }
+      returnSize.foreach { size =>
+        val df = spark.sql(s"select * from t_2 limit $size")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[CollectLimitExec])
+        if (size > 1000) {
+          runAndCheck(df.queryExecution.executedPlan, 1000)
+        } else {
+          runAndCheck(df.queryExecution.executedPlan, size)
+        }
+      }
+    }
+  }
+
+  test("aqe should work properly") {
+
+    val s = spark
+    import s.implicits._
+
+    spark.sparkContext.parallelize(
+      (1 to 100).map(i => TestData(i, i.toString))).toDF()
+      .createOrReplaceTempView("testData")
+    spark.sparkContext.parallelize(
+      TestData2(1, 1) ::
+        TestData2(1, 2) ::
+        TestData2(2, 1) ::
+        TestData2(2, 2) ::
+        TestData2(3, 1) ::
+        TestData2(3, 2) :: Nil,
+      2).toDF()
+      .createOrReplaceTempView("testData2")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        """
+          |SELECT * FROM(
+          |  SELECT * FROM testData join testData2 ON key = a where value = '1'
+          |) LIMIT 1
+          |""".stripMargin)
+      val smj = plan.collect { case smj: SortMergeJoinExec => smj }
+      val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj }
+      assert(smj.size == 1)
+      assert(bhj.size == 1)
+    }
+  }
+
+  test("result offset support") {
+    assume(SPARK_ENGINE_RUNTIME_VERSION > "3.3")
+    var numStages = 0
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numStages = jobStart.stageInfos.length
+      }
+    }
+    withJdbcStatement() { statement =>
+      withSparkListener(listener) {
+        withPartitionedTable("t_3") {
+          statement.executeQuery("select * from t_3 limit 10 offset 10")
+        }
+        KyuubiSparkContextHelper.waitListenerBus(spark)
+      }
+    }
+    // an extra shuffle is introduced when `offset` > 0
+    assert(numStages == 2)
+  }
+
+  test("arrow serialization should not introduce extra shuffle for outermost limit") {
+    var numStages = 0
+    val listener = new SparkListener {
+      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+        numStages = jobStart.stageInfos.length
+      }
+    }
+    withJdbcStatement() { statement =>
+      withSparkListener(listener) {
+        withPartitionedTable("t_3") {
+          statement.executeQuery("select * from t_3 limit 1000")
+        }
+        KyuubiSparkContextHelper.waitListenerBus(spark)
+      }
+    }
+    // Should be only one stage since there is no shuffle.
+    assert(numStages == 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -177,4 +338,101 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       .allSessions()
       .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
   }
+
+  private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
+    withAllSessions(s => s.sparkContext.addSparkListener(listener))
+    try {
+      body
+    } finally {
+      withAllSessions(s => s.sparkContext.removeSparkListener(listener))
+    }
+  }
+
+  private def withPartitionedTable[T](viewName: String)(body: => T): T = {
+    withAllSessions { spark =>
+      spark.range(0, 1000, 1, numPartitions = 100)
+        .createOrReplaceTempView(viewName)
+    }
+    try {
+      body
+    } finally {
+      withAllSessions { spark =>
+        spark.sql(s"DROP VIEW IF EXISTS $viewName")
+      }
+    }
+  }
+
+  private def withAllSessions(op: SparkSession => Unit): Unit = {
+    SparkSQLEngine.currentEngine.get
+      .backendService
+      .sessionManager
+      .allSessions()
+      .map(_.asInstanceOf[SparkSessionImpl].spark)
+      .foreach(op(_))
+  }
+
+  private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = {
+    val dfAdaptive = spark.sql(query)
+    val planBefore = dfAdaptive.queryExecution.executedPlan
+    val result = dfAdaptive.collect()
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      val df = spark.sql(query)
+      QueryTest.checkAnswer(df, df.collect().toSeq)
+    }
+    val planAfter = dfAdaptive.queryExecution.executedPlan
+    val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+    val exchanges = adaptivePlan.collect {
+      case e: Exchange => e
+    }
+    assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.")
+    (dfAdaptive.queryExecution.sparkPlan, adaptivePlan)
+  }
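+
+  // Mirrors the `runAdaptiveAndVerifyResult` helper in Spark's AdaptiveQueryExecSuite: run the
+  // query with AQE enabled, cross-check the result against an AQE-disabled run, and return the
+  // (initial, final adaptive) plans.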
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (isStaticConfigKey(k)) {
+        throw new KyuubiException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f
+    finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
+
+  /**
+   * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to
+   * adapt to Spark 3.1.x.
+   *
+   * TODO: Once we drop support for Spark 3.1.x, we can directly call
+   * [[SQLConf.isStaticConfigKey()]].
+   */
+  private def isStaticConfigKey(key: String): Boolean = {
+    val staticConfKeys = DynFields.builder()
+      .hiddenImpl(SQLConf.getClass, "staticConfKeys")
+      .build[JSet[String]](SQLConf)
+      .get()
+    staticConfKeys.contains(key)
+  }
 }
+
+case class TestData(key: Int, value: String)
+case class TestData2(a: Int, b: Int)
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
index 8293123ead7..1b662eadf96 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
@@ -27,4 +27,6 @@ object KyuubiSparkContextHelper {
   def waitListenerBus(spark: SparkSession): Unit = {
     spark.sparkContext.listenerBus.waitUntilEmpty()
   }
+
+  def dummyTaskContext(): TaskContextImpl = TaskContext.empty()
 }
diff --git a/pom.xml b/pom.xml
index b2b0341e2e9..e77e6d55d24 100644
--- a/pom.xml
+++ b/pom.xml
@@ -538,8 +538,8 @@
       <artifactId>hadoop-client</artifactId>