From a5849430a315bd8d13738dc6af7fdd7972aec3ca Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 23 Mar 2023 18:40:17 +0800 Subject: [PATCH 01/28] arrow take --- .../spark/operation/ExecuteStatement.scala | 24 ++- .../arrow/ArrowCollectLimitExec.scala | 127 ++++++++++++++++ .../arrow/ArrowConvertersHelper.scala | 143 ++++++++++++++++++ kyuubi-server/pom.xml | 4 + 4 files changed, 294 insertions(+), 4 deletions(-) create mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala create mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index b29d2ca9a7e..1108f46f7fd 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -23,11 +23,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution} import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.types._ - import org.apache.kyuubi.{KyuubiSQLException, Logging} +import org.apache.spark.sql.execution.arrow.ArrowCollectLimitExec + import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._ import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationHandle, OperationState} @@ -201,8 +202,23 @@ class ArrowBasedExecuteStatement( override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = { // this will introduce shuffle and hurt performance val limitedResult = resultDF.limit(maxRows) - collectAsArrow(convertComplexType(limitedResult)) { rdd => - rdd.collect() +// collectAsArrow(convertComplexType(limitedResult)) { rdd => +// rdd.collect() +// } + val df = convertComplexType(limitedResult) + SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) { + df.queryExecution.executedPlan.resetMetrics() + df.queryExecution.executedPlan match { + case collectLimit @ CollectLimitExec(limit, _) => + // scalastyle:off + println("ddddd") + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId) + .map(_._1) + case _ => + println("yyyy") + SparkDatasetHelper.toArrowBatchRdd(df).collect() + } } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala new file mode 100644 index 00000000000..c250e520ca4 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.execution.CollectLimitExec +import org.apache.spark.sql.types.StructType + +object ArrowCollectLimitExec extends SQLConfHelper { + + type Batch = (Array[Byte], Long) + + def takeAsArrowBatches( + collectLimitExec: CollectLimitExec, + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + timeZoneId: String): Array[Batch] = { + val n = collectLimitExec.limit + // TODO + val takeFromEnd = false + if (n == 0) { + return new Array[Batch](0) + } else { + val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) + // // TODO: refactor and reuse the code from RDD's take() + val childRDD = collectLimitExec.child.execute() + val buf = if (takeFromEnd) new ListBuffer[Batch] else new ArrayBuffer[Batch] + var bufferedRowSize = 0L + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (bufferedRowSize < n && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. +// var numPartsToTry = conf.limitInitialNumPartitions + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the previous iteration, multiply by + // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need + // to try, but overestimate it by 50%. We also cap the estimation in the end. + if (buf.isEmpty) { + numPartsToTry = partsScanned * limitScaleUpFactor + } else { + val left = n - bufferedRowSize + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt + numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) + } + } + + val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) + val partsToScan = if (takeFromEnd) { + // Reverse partitions to scan. So, if parts was [1, 2, 3] in 200 partitions (0 to 199), + // it becomes [198, 197, 196]. 
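          // The map below simply rewrites each partition index p as (totalParts - 1) - p.
          // Note that `takeFromEnd` is hard-coded to false further up, so this reversed
          // branch is not exercised in this version of the code.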
+ parts.map(p => (totalParts - 1) - p) + } else { + parts + } + + val sc = collectLimitExec.session.sparkContext + val res = sc.runJob( + childRDD, + (it: Iterator[InternalRow]) => { + val batches = ArrowConvertersHelper.toBatchWithSchemaIterator( + it, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + collectLimitExec.limit, + timeZoneId) + batches.map(b => b -> batches.rowCountInLastBatch).toArray + }, + partsToScan) + + var i = 0 + if (takeFromEnd) { +// while (buf.length < n && i < res.length) { +// val rows = decodeUnsafeRows(res(i)._2) +// if (n - buf.length >= res(i)._1) { +// buf.prepend(rows.toArray[InternalRow]: _*) +// } else { +// val dropUntil = res(i)._1 - (n - buf.length) +// // Same as Iterator.drop but this only takes a long. +// var j: Long = 0L +// while (j < dropUntil) { rows.next(); j += 1L} +// buf.prepend(rows.toArray[InternalRow]: _*) +// } +// i += 1 +// } + } else { + while (bufferedRowSize < n && i < res.length) { + var j = 0 + val batches = res(i) + while (j < batches.length && n > bufferedRowSize) { + val batch = batches(j) + val (_, batchSize) = batch + buf += batch + bufferedRowSize += batchSize + j += 1 + } + i += 1 + } + } + partsScanned += partsToScan.size + } + + buf.toArray + } + } +} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala new file mode 100644 index 00000000000..c7dcf6bfc7d --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import java.io.ByteArrayOutputStream +import java.nio.channels.Channels + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.{SizeEstimator, Utils} + +object ArrowConvertersHelper extends Logging { + + /** + * Convert the input rows into fully contained arrow batches. + * Different from [[toBatchIterator]], each output arrow batch starts with the schema. 
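   * The returned iterator also exposes `rowCountInLastBatch`, which callers read after
   * consuming each batch to learn how many rows that batch carries (this is how
   * `takeAsArrowBatches` pairs every serialized batch with its row count). A minimal usage
   * sketch, with `rowIter` and `schema` standing in for a real row iterator and schema:
   * {{{
   *   val batchIter = ArrowConvertersHelper.toBatchWithSchemaIterator(
   *     rowIter, schema, maxRecordsPerBatch = 1000, maxEstimatedBatchSize = 4 * 1024 * 1024,
   *     limit = 100, timeZoneId = "UTC")
   *   val batches = batchIter.map(bytes => bytes -> batchIter.rowCountInLastBatch).toArray
   * }}}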
+ */ + private[sql] def toBatchWithSchemaIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String): ArrowBatchWithSchemaIterator = { + new ArrowBatchWithSchemaIterator( + rowIter, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + limit, + timeZoneId, + TaskContext.get) + } + + private[sql] class ArrowBatchWithSchemaIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String, + context: TaskContext) + extends Iterator[Array[Byte]] { + + protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected val unloader = new VectorUnloader(root) + protected val arrowWriter = ArrowWriter.create(root) + + Option(context).foreach { + _.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + } + } + + override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || { + root.close() + allocator.close() + false + } + + private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) + var rowCountInLastBatch: Long = 0 + var rowCount: Long = 0 + + override def next(): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + rowCountInLastBatch = 0 +// var estimatedBatchSize = arrowSchemaSize + var estimatedBatchSize = 0 + Utils.tryWithSafeFinally { + // Always write the schema. +// MessageSerializer.serialize(writeChannel, arrowSchema) + + // Always write the first row. + while (rowIter.hasNext && ( + // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. + // If the size in bytes is positive (set properly), always write the first row. + rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || + // If the size in bytes of rows are 0 or negative, unlimit it. + estimatedBatchSize <= 0 || + estimatedBatchSize < maxEstimatedBatchSize || + // If the size of rows are 0 or negative, unlimit it. + maxRecordsPerBatch <= 0 || + rowCountInLastBatch < maxRecordsPerBatch || + rowCount < limit)) { + val row = rowIter.next() + arrowWriter.write(row) + estimatedBatchSize += (row match { + case ur: UnsafeRow => ur.getSizeInBytes + // Trying to estimate the size of the current row, assuming 16 bytes per value. + case ir: InternalRow => ir.numFields * 16 + }) + rowCountInLastBatch += 1 + rowCount += 1 + } + arrowWriter.finish() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + + // Always write the Ipc options at the end. + ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) + + batch.close() + } { + arrowWriter.reset() + } + + out.toByteArray + } + } +} diff --git a/kyuubi-server/pom.xml b/kyuubi-server/pom.xml index 7408ac5dd00..905e0375fca 100644 --- a/kyuubi-server/pom.xml +++ b/kyuubi-server/pom.xml @@ -466,7 +466,11 @@ io.delta delta-core_${scala.binary.version} +<<<<<<< HEAD test +======= + +>>>>>>> 2a4e53609... 
arrow take From 8593d856a38485cfd7e66d7d4963bd0daa9c8efa Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 24 Mar 2023 18:19:54 +0800 Subject: [PATCH 02/28] driver slice last batch --- .../spark/operation/ExecuteStatement.scala | 36 ++++++- .../arrow/ArrowCollectLimitExec.scala | 19 +++- .../execution/arrow/KyuubiArrowUtils.scala | 96 +++++++++++++++++++ 3 files changed, 145 insertions(+), 6 deletions(-) create mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index 1108f46f7fd..a57a711ffd6 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -20,14 +20,15 @@ package org.apache.kyuubi.engine.spark.operation import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution} +import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution, TakeOrderedAndProjectExec} import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.types._ import org.apache.kyuubi.{KyuubiSQLException, Logging} -import org.apache.spark.sql.execution.arrow.ArrowCollectLimitExec +import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils} import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._ @@ -210,10 +211,35 @@ class ArrowBasedExecuteStatement( df.queryExecution.executedPlan.resetMetrics() df.queryExecution.executedPlan match { case collectLimit @ CollectLimitExec(limit, _) => - // scalastyle:off - println("ddddd") val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone - ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId) + val batches = ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId) +// .map(_._1) + val result = ArrayBuffer[Array[Byte]]() + var i = 0 + var rest = limit + println(s"batch....size... 
${batches.length}") + while (i < batches.length && rest > 0) { + val (batch, size) = batches(i) + if (size < rest) { + result += batch + // TODO: toInt + rest = rest - size.toInt + } else if (size == rest) { + result += batch + rest = 0 + } else { // size > rest + println(s"size......${size}....rest......${rest}") +// result += KyuubiArrowUtils.slice(batch, 0, rest) + result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest) + rest = 0 + } + i += 1 + } + result.toArray + + case takeOrderedAndProjectExec @ TakeOrderedAndProjectExec(limit, _, _, _) => + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + ArrowCollectLimitExec.taskOrdered(takeOrderedAndProjectExec, df.schema, 1000, 1024 * 1024, timeZoneId) .map(_._1) case _ => println("yyyy") diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala index c250e520ca4..8a804b47191 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import scala.collection.mutable.{ArrayBuffer, ListBuffer} import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} -import org.apache.spark.sql.execution.CollectLimitExec +import org.apache.spark.sql.execution.{CollectLimitExec, TakeOrderedAndProjectExec} import org.apache.spark.sql.types.StructType object ArrowCollectLimitExec extends SQLConfHelper { @@ -124,4 +124,21 @@ object ArrowCollectLimitExec extends SQLConfHelper { buf.toArray } } + + def taskOrdered( + takeOrdered: TakeOrderedAndProjectExec, + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + timeZoneId: String): Array[Batch] = { + val batches = ArrowConvertersHelper.toBatchWithSchemaIterator( + takeOrdered.executeCollect().iterator, + schema, + maxEstimatedBatchSize, + maxEstimatedBatchSize, + takeOrdered.limit, + timeZoneId) + batches.map(b => b -> batches.rowCountInLastBatch).toArray + } } + diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala new file mode 100644 index 00000000000..bc2d3d24b51 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.arrow + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils + +object KyuubiArrowUtils { + val rootAllocator = new RootAllocator(Long.MaxValue) + .newChildAllocator("ReadIntTest", 0, Long.MaxValue) + // BufferAllocator allocator = + // ArrowUtils.rootAllocator.newChildAllocator("ReadIntTest", 0, Long.MAX_VALUE); + def slice(bytes: Array[Byte], start: Int, length: Int): Array[Byte] = { + val in = new ByteArrayInputStream(bytes) + val out = new ByteArrayOutputStream() + + var reader: ArrowStreamReader = null + try { + reader = new ArrowStreamReader(in, rootAllocator) +// reader.getVectorSchemaRoot.getSchema + reader.loadNextBatch() + val root = reader.getVectorSchemaRoot.slice(start, length) +// val loader = new VectorLoader(root) + val writer = new ArrowStreamWriter(root, null, out) + writer.start() + writer.writeBatch() + writer.end() + writer.close() + out.toByteArray + } finally { + if (reader != null) { + reader.close() + } + in.close() + out.close() + } + } + + def sliceV2(schema: StructType, + timeZoneId: String, bytes: Array[Byte], start: Int, length: Int): Array[Byte] = { + val in = new ByteArrayInputStream(bytes) + val out = new ByteArrayOutputStream() + + try { +// reader = new ArrowStreamReader(in, rootAllocator) +// // reader.getVectorSchemaRoot.getSchema +// reader.loadNextBatch() +// println("bytes......" + bytes.length) +// println("rowCount......" 
+ reader.getVectorSchemaRoot.getRowCount) +// val root = reader.getVectorSchemaRoot.slice(start, length) + + + val recordBatch = MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), rootAllocator) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + + val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(recordBatch) + recordBatch.close() + + + val unloader = new VectorUnloader(root.slice(start, length)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() + out.toByteArray() + } finally { + in.close() + out.close() + } + } +} From 008867122d99fe3b986a030b92fe31c1253381f0 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Wed, 29 Mar 2023 18:59:46 +0800 Subject: [PATCH 03/28] refine --- .../connector/tpcds/TPCDSQuerySuite.scala | 2 + externals/kyuubi-spark-sql-engine/pom.xml | 7 ++ .../spark/operation/ExecuteStatement.scala | 108 +++++++++--------- .../arrow/ArrowCollectLimitExec.scala | 17 --- .../execution/arrow/KyuubiArrowUtils.scala | 15 ++- .../SparkArrowbasedOperationSuite.scala | 34 ++++++ 6 files changed, 104 insertions(+), 79 deletions(-) diff --git a/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala b/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala index 83679989a79..3a1fd87a5ce 100644 --- a/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala +++ b/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala @@ -88,4 +88,6 @@ class TPCDSQuerySuite extends KyuubiFunSuite { } } } + + test("aa") {} } diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml index 5b227cb5e29..942a8ec6c27 100644 --- a/externals/kyuubi-spark-sql-engine/pom.xml +++ b/externals/kyuubi-spark-sql-engine/pom.xml @@ -71,6 +71,13 @@ provided + + com.google.guava + guava + 14.0.1 + provided + + org.scala-lang scala-compiler diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index a57a711ffd6..e323d2d7316 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -22,14 +22,13 @@ import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution, TakeOrderedAndProjectExec} +import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution} +import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils} import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.types._ -import org.apache.kyuubi.{KyuubiSQLException, Logging} -import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils} +import 
org.apache.kyuubi.{KyuubiSQLException, Logging} import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._ import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationHandle, OperationState} @@ -189,73 +188,70 @@ class ArrowBasedExecuteStatement( handle) { override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = { - collectAsArrow(convertComplexType(resultDF)) { rdd => - rdd.toLocalIterator + val df = convertComplexType(resultDF) + withNewExecutionId(df) { + SparkDatasetHelper.toArrowBatchRdd(df).toLocalIterator } } override protected def fullCollectResult(resultDF: DataFrame): Array[_] = { - collectAsArrow(convertComplexType(resultDF)) { rdd => - rdd.collect() - } + executeCollect(convertComplexType(resultDF)) } override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = { - // this will introduce shuffle and hurt performance - val limitedResult = resultDF.limit(maxRows) -// collectAsArrow(convertComplexType(limitedResult)) { rdd => -// rdd.collect() -// } - val df = convertComplexType(limitedResult) - SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) { - df.queryExecution.executedPlan.resetMetrics() - df.queryExecution.executedPlan match { - case collectLimit @ CollectLimitExec(limit, _) => - val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone - val batches = ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId) -// .map(_._1) - val result = ArrayBuffer[Array[Byte]]() - var i = 0 - var rest = limit - println(s"batch....size... ${batches.length}") - while (i < batches.length && rest > 0) { - val (batch, size) = batches(i) - if (size < rest) { - result += batch - // TODO: toInt - rest = rest - size.toInt - } else if (size == rest) { - result += batch - rest = 0 - } else { // size > rest - println(s"size......${size}....rest......${rest}") -// result += KyuubiArrowUtils.slice(batch, 0, rest) - result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest) - rest = 0 - } - i += 1 - } - result.toArray - - case takeOrderedAndProjectExec @ TakeOrderedAndProjectExec(limit, _, _, _) => - val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone - ArrowCollectLimitExec.taskOrdered(takeOrderedAndProjectExec, df.schema, 1000, 1024 * 1024, timeZoneId) - .map(_._1) - case _ => - println("yyyy") - SparkDatasetHelper.toArrowBatchRdd(df).collect() - } - } + executeCollect(convertComplexType(resultDF.limit(maxRows))) } /** * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based * operation, so that we can track the arrow-based queries on the UI tab. 
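   * It also resets the metrics of the executed plan before running `body`, so the metrics
   * reported for this execution reflect only the current collect.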
*/ - private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = { + private def withNewExecutionId[T](df: DataFrame)(body: => T): T = { SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) { df.queryExecution.executedPlan.resetMetrics() - action(SparkDatasetHelper.toArrowBatchRdd(df)) + body + } + } + + def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { + executeArrowBatchCollect(df).getOrElse { + SparkDatasetHelper.toArrowBatchRdd(df).collect() + } + } + + private def executeArrowBatchCollect(df: DataFrame): Option[Array[Array[Byte]]] = { + df.queryExecution.executedPlan match { + case collectLimit @ CollectLimitExec(limit, _) => + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + val maxRecordsPerBatch = spark.conf.getOption( + "spark.sql.execution.arrow.maxRecordsPerBatch").map(_.toInt).getOrElse(10000) + // val maxBatchSize = + // (spark.sessionState.conf.getConf(SPARK_CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong + val maxBatchSize = 1024 * 1024 * 4 + val batches = ArrowCollectLimitExec.takeAsArrowBatches( + collectLimit, + df.schema, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + val result = ArrayBuffer[Array[Byte]]() + var i = 0 + var rest = limit + while (i < batches.length && rest > 0) { + val (batch, size) = batches(i) + if (size <= rest) { + result += batch + // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion + rest -= size.toInt + } else { // size > rest + result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest) + rest = 0 + } + i += 1 + } + Option(result.toArray) + case _ => + None } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala index 8a804b47191..86492877e97 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala @@ -124,21 +124,4 @@ object ArrowCollectLimitExec extends SQLConfHelper { buf.toArray } } - - def taskOrdered( - takeOrdered: TakeOrderedAndProjectExec, - schema: StructType, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - timeZoneId: String): Array[Batch] = { - val batches = ArrowConvertersHelper.toBatchWithSchemaIterator( - takeOrdered.executeCollect().iterator, - schema, - maxEstimatedBatchSize, - maxEstimatedBatchSize, - takeOrdered.limit, - timeZoneId) - batches.map(b => b -> batches.rowCountInLastBatch).toArray - } } - diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala index bc2d3d24b51..674a0c8ccee 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala @@ -21,9 +21,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.channels.Channels import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, 
ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.MessageSerializer -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -58,8 +58,12 @@ object KyuubiArrowUtils { } } - def sliceV2(schema: StructType, - timeZoneId: String, bytes: Array[Byte], start: Int, length: Int): Array[Byte] = { + def sliceV2( + schema: StructType, + timeZoneId: String, + bytes: Array[Byte], + start: Int, + length: Int): Array[Byte] = { val in = new ByteArrayInputStream(bytes) val out = new ByteArrayOutputStream() @@ -71,9 +75,9 @@ object KyuubiArrowUtils { // println("rowCount......" + reader.getVectorSchemaRoot.getRowCount) // val root = reader.getVectorSchemaRoot.slice(start, length) - val recordBatch = MessageSerializer.deserializeRecordBatch( - new ReadChannel(Channels.newChannel(in)), rootAllocator) + new ReadChannel(Channels.newChannel(in)), + rootAllocator) val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) @@ -81,7 +85,6 @@ object KyuubiArrowUtils { vectorLoader.load(recordBatch) recordBatch.close() - val unloader = new VectorUnloader(root.slice(start, length)) val writeChannel = new WriteChannel(Channels.newChannel(out)) val batch = unloader.getRecordBatch() diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index ae6237bb59c..e48b1f46ffe 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -20,8 +20,10 @@ package org.apache.kyuubi.engine.spark.operation import java.sql.Statement import org.apache.spark.KyuubiSparkContextHelper +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.functions.col import org.apache.spark.sql.util.QueryExecutionListener import org.apache.kyuubi.config.KyuubiConf @@ -138,6 +140,20 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp assert(metrics("numOutputRows").value === 1) } + test("aa") { + + withJdbcStatement() { statement => + loadPartitionedTable() + val n = 17 + statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$n") + val result = statement.executeQuery("select * from t_1") + for (i <- 0 until n) { + assert(result.next()) + } + assert(!result.next()) + } + } + private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = { val query = s""" @@ -177,4 +193,22 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp .allSessions() .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener)) } + + private def loadPartitionedTable(): Unit = { + SparkSQLEngine.currentEngine.get + .backendService + .sessionManager + .allSessions() + .map(_.asInstanceOf[SparkSessionImpl].spark) + .foreach { spark => + spark.range(1000) + .repartitionByRange(100, col("id")) + .createOrReplaceTempView("t_1") + spark.sql("select * from t_1") + .foreachPartition { p: Iterator[Row] => + assert(p.length == 10) + 
() + } + } + } } From ed8c6928baeda334773a3067ac08a84666f5a463 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 3 Apr 2023 12:21:09 +0800 Subject: [PATCH 04/28] refactor --- .../spark/operation/ExecuteStatement.scala | 28 +++-- ...imitExec.scala => ArrowCollectUtils.scala} | 100 ++++++++++-------- .../arrow/ArrowConvertersHelper.scala | 18 ++-- .../execution/arrow/KyuubiArrowUtils.scala | 46 ++------ .../SparkArrowbasedOperationSuite.scala | 31 ++++-- 5 files changed, 110 insertions(+), 113 deletions(-) rename externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/{ArrowCollectLimitExec.scala => ArrowCollectUtils.scala} (54%) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index e323d2d7316..9afdd92f2cf 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -22,9 +22,10 @@ import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution} -import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils} +import org.apache.spark.sql.execution.arrow.{ArrowCollectUtils, KyuubiArrowUtils} import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.types._ @@ -213,7 +214,7 @@ class ArrowBasedExecuteStatement( } } - def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { + private def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { executeArrowBatchCollect(df).getOrElse { SparkDatasetHelper.toArrowBatchRdd(df).collect() } @@ -223,17 +224,16 @@ class ArrowBasedExecuteStatement( df.queryExecution.executedPlan match { case collectLimit @ CollectLimitExec(limit, _) => val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone - val maxRecordsPerBatch = spark.conf.getOption( - "spark.sql.execution.arrow.maxRecordsPerBatch").map(_.toInt).getOrElse(10000) - // val maxBatchSize = - // (spark.sessionState.conf.getConf(SPARK_CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong - val maxBatchSize = 1024 * 1024 * 4 - val batches = ArrowCollectLimitExec.takeAsArrowBatches( + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + + val batches = ArrowCollectUtils.takeAsArrowBatches( collectLimit, - df.schema, maxRecordsPerBatch, maxBatchSize, timeZoneId) + + // note that the number of rows in the returned arrow batches may be >= `limit`, performing + // the slicing operation of result val result = ArrayBuffer[Array[Byte]]() var i = 0 var rest = limit @@ -244,7 +244,7 @@ class ArrowBasedExecuteStatement( // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion rest -= size.toInt } else { // size > rest - result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest) + result += KyuubiArrowUtils.slice(df.schema, timeZoneId, batch, 0, rest) rest = 0 } i += 1 @@ -263,4 +263,12 @@ class ArrowBasedExecuteStatement( SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString) } + private lazy val 
maxBatchSize: Long = { + // respect spark connect config + spark.sparkContext.getConf.getOption("spark.connect.grpc.arrow.maxBatchSize") + .orElse(Option("4m")) + .map(JavaUtils.byteStringAs(_, ByteUnit.MiB)) + .get + } + } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala similarity index 54% rename from externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala rename to externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala index 86492877e97..86fc28eec71 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala @@ -17,40 +17,60 @@ package org.apache.spark.sql.execution.arrow -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} -import org.apache.spark.sql.execution.{CollectLimitExec, TakeOrderedAndProjectExec} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.CollectLimitExec -object ArrowCollectLimitExec extends SQLConfHelper { +object ArrowCollectUtils extends SQLConfHelper { type Batch = (Array[Byte], Long) + /** + * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be + * summarized in the following steps: + * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty + * array of batches. + * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of + * data to collect. + * 3. Use an iterative approach to collect data in batches until the specified limit is reached. + * In each iteration, it selects a subset of the partitions of the RDD to scan and tries to + * collect data from them. + * 4. For each partition subset, we use the runJob method of the Spark context to execute a + * closure that scans the partition data and converts it to Arrow batches. + * 5. Check if the collected data reaches the specified limit. If not, it selects another subset + * of partitions to scan and repeats the process until the limit is reached or all partitions + * have been scanned. + * 6. Return an array of all the collected Arrow batches. + * + * Note that: + * 1. The returned Arrow batches row count >= limit, if the input df has more than the `limit` + * row count + * 2. 
We don't implement the `takeFromEnd` logical + * + * @return + */ def takeAsArrowBatches( collectLimitExec: CollectLimitExec, - schema: StructType, maxRecordsPerBatch: Long, maxEstimatedBatchSize: Long, timeZoneId: String): Array[Batch] = { val n = collectLimitExec.limit - // TODO - val takeFromEnd = false + val schema = collectLimitExec.schema if (n == 0) { return new Array[Batch](0) } else { val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) - // // TODO: refactor and reuse the code from RDD's take() + // TODO: refactor and reuse the code from RDD's take() val childRDD = collectLimitExec.child.execute() - val buf = if (takeFromEnd) new ListBuffer[Batch] else new ArrayBuffer[Batch] + val buf = new ArrayBuffer[Batch] var bufferedRowSize = 0L val totalParts = childRDD.partitions.length var partsScanned = 0 while (bufferedRowSize < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. -// var numPartsToTry = conf.limitInitialNumPartitions - var numPartsToTry = 1 + var numPartsToTry = limitInitialNumPartitions if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, multiply by // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need @@ -65,58 +85,36 @@ object ArrowCollectLimitExec extends SQLConfHelper { } } - val parts = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) - val partsToScan = if (takeFromEnd) { - // Reverse partitions to scan. So, if parts was [1, 2, 3] in 200 partitions (0 to 199), - // it becomes [198, 197, 196]. - parts.map(p => (totalParts - 1) - p) - } else { - parts - } + val partsToScan = + partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = collectLimitExec.session.sparkContext val res = sc.runJob( childRDD, (it: Iterator[InternalRow]) => { - val batches = ArrowConvertersHelper.toBatchWithSchemaIterator( + val batches = ArrowConvertersHelper.toBatchIterator( it, schema, maxRecordsPerBatch, maxEstimatedBatchSize, - collectLimitExec.limit, + n, timeZoneId) batches.map(b => b -> batches.rowCountInLastBatch).toArray }, partsToScan) var i = 0 - if (takeFromEnd) { -// while (buf.length < n && i < res.length) { -// val rows = decodeUnsafeRows(res(i)._2) -// if (n - buf.length >= res(i)._1) { -// buf.prepend(rows.toArray[InternalRow]: _*) -// } else { -// val dropUntil = res(i)._1 - (n - buf.length) -// // Same as Iterator.drop but this only takes a long. -// var j: Long = 0L -// while (j < dropUntil) { rows.next(); j += 1L} -// buf.prepend(rows.toArray[InternalRow]: _*) -// } -// i += 1 -// } - } else { - while (bufferedRowSize < n && i < res.length) { - var j = 0 - val batches = res(i) - while (j < batches.length && n > bufferedRowSize) { - val batch = batches(j) - val (_, batchSize) = batch - buf += batch - bufferedRowSize += batchSize - j += 1 - } - i += 1 + while (bufferedRowSize < n && i < res.length) { + var j = 0 + val batches = res(i) + while (j < batches.length && n > bufferedRowSize) { + val batch = batches(j) + val (_, batchSize) = batch + buf += batch + bufferedRowSize += batchSize + j += 1 } + i += 1 } partsScanned += partsToScan.size } @@ -124,4 +122,12 @@ object ArrowCollectLimitExec extends SQLConfHelper { buf.toArray } } + + /** + * Spark introduced the config `spark.sql.limit.initialNumPartitions` since 3.4.0. 
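   * (on older versions the key is absent and this helper falls back to a default of 1);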
see SPARK-40211 + */ + def limitInitialNumPartitions: Int = { + conf.getConfString("spark.sql.limit.initialNumPartitions", "1") + .toInt + } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala index c7dcf6bfc7d..e7e77799658 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -29,22 +29,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.{SizeEstimator, Utils} +import org.apache.spark.util.Utils object ArrowConvertersHelper extends Logging { /** - * Convert the input rows into fully contained arrow batches. - * Different from [[toBatchIterator]], each output arrow batch starts with the schema. + * Different from [[org.apache.spark.sql.execution.arrow.ArrowConvertersHelper.toBatchIterator]], + * each output arrow batch contains this batch row count. */ - private[sql] def toBatchWithSchemaIterator( + def toBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, maxEstimatedBatchSize: Long, limit: Long, - timeZoneId: String): ArrowBatchWithSchemaIterator = { - new ArrowBatchWithSchemaIterator( + timeZoneId: String): ArrowBatchIterator = { + new ArrowBatchIterator( rowIter, schema, maxRecordsPerBatch, @@ -54,7 +54,7 @@ object ArrowConvertersHelper extends Logging { TaskContext.get) } - private[sql] class ArrowBatchWithSchemaIterator( + private[sql] class ArrowBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, @@ -88,7 +88,6 @@ object ArrowConvertersHelper extends Logging { false } - private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) var rowCountInLastBatch: Long = 0 var rowCount: Long = 0 @@ -97,11 +96,8 @@ object ArrowConvertersHelper extends Logging { val writeChannel = new WriteChannel(Channels.newChannel(out)) rowCountInLastBatch = 0 -// var estimatedBatchSize = arrowSchemaSize var estimatedBatchSize = 0 Utils.tryWithSafeFinally { - // Always write the schema. -// MessageSerializer.serialize(writeChannel, arrowSchema) // Always write the first row. 
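        // The disjunction below appends the next row to the current batch when ANY of the
        // following holds: the batch is still empty (and a byte cap is configured), the running
        // size estimate is non-positive or below `maxEstimatedBatchSize`, the record cap is
        // unset or not yet reached, or fewer than `limit` rows have been emitted in total.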
while (rowIter.hasNext && ( diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala index 674a0c8ccee..ba94e255c9c 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala @@ -20,45 +20,20 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.channels.Channels -import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.MessageSerializer import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils object KyuubiArrowUtils { - val rootAllocator = new RootAllocator(Long.MaxValue) - .newChildAllocator("ReadIntTest", 0, Long.MaxValue) - // BufferAllocator allocator = - // ArrowUtils.rootAllocator.newChildAllocator("ReadIntTest", 0, Long.MAX_VALUE); - def slice(bytes: Array[Byte], start: Int, length: Int): Array[Byte] = { - val in = new ByteArrayInputStream(bytes) - val out = new ByteArrayOutputStream() - - var reader: ArrowStreamReader = null - try { - reader = new ArrowStreamReader(in, rootAllocator) -// reader.getVectorSchemaRoot.getSchema - reader.loadNextBatch() - val root = reader.getVectorSchemaRoot.slice(start, length) -// val loader = new VectorLoader(root) - val writer = new ArrowStreamWriter(root, null, out) - writer.start() - writer.writeBatch() - writer.end() - writer.close() - out.toByteArray - } finally { - if (reader != null) { - reader.close() - } - in.close() - out.close() - } - } - def sliceV2( + private val rootAllocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + def slice( schema: StructType, timeZoneId: String, bytes: Array[Byte], @@ -68,13 +43,6 @@ object KyuubiArrowUtils { val out = new ByteArrayOutputStream() try { -// reader = new ArrowStreamReader(in, rootAllocator) -// // reader.getVectorSchemaRoot.getSchema -// reader.loadNextBatch() -// println("bytes......" + bytes.length) -// println("rowCount......" 
+ reader.getVectorSchemaRoot.getRowCount) -// val root = reader.getVectorSchemaRoot.slice(start, length) - val recordBatch = MessageSerializer.deserializeRecordBatch( new ReadChannel(Channels.newChannel(in)), rootAllocator) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index e48b1f46ffe..017fb3b312c 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -142,15 +142,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp test("aa") { + val returnSize = Seq( + 7, + 10, + 13, + 20, + 29) + + withJdbcStatement() { statement => + loadPartitionedTable() + returnSize.foreach { size => + statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$size") + val result = statement.executeQuery("select * from t_1") + for (i <- 0 until size) { + assert(result.next()) + } + assert(!result.next()) + } + } + withJdbcStatement() { statement => loadPartitionedTable() - val n = 17 - statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$n") - val result = statement.executeQuery("select * from t_1") - for (i <- 0 until n) { - assert(result.next()) + returnSize.foreach { size => + val result = statement.executeQuery(s"select * from t_1 limit $size") + for (i <- 0 until size) { + assert(result.next()) + } + assert(!result.next()) } - assert(!result.next()) } } From 4212a8967c8a10fd259901092414d79248e919ad Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 14:05:29 +0800 Subject: [PATCH 05/28] refactor and add ut --- .../spark/operation/ExecuteStatement.scala | 74 +------------- .../arrow/ArrowConvertersHelper.scala | 9 ++ .../spark/sql/kyuubi/SparkDatasetHelper.scala | 98 +++++++++++++++++++ .../engine/spark/WithSparkSQLEngine.scala | 3 +- .../SparkArrowbasedOperationSuite.scala | 95 ++++++++++-------- .../spark/KyuubiSparkContextHelper.scala | 2 + 6 files changed, 167 insertions(+), 114 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index 9afdd92f2cf..ca30f53001f 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -20,13 +20,9 @@ package org.apache.kyuubi.engine.spark.operation import java.util.concurrent.RejectedExecutionException import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution} -import org.apache.spark.sql.execution.arrow.{ArrowCollectUtils, KyuubiArrowUtils} -import org.apache.spark.sql.kyuubi.SparkDatasetHelper +import org.apache.spark.sql.kyuubi.SparkDatasetHelper._ import org.apache.spark.sql.types._ import org.apache.kyuubi.{KyuubiSQLException, Logging} @@ -189,10 +185,7 @@ class ArrowBasedExecuteStatement( handle) { 
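  // With `SparkDatasetHelper._` imported above, `toArrowBatchLocalIterator` and `executeCollect`
  // below resolve to the helper object, so the arrow-specific collect logic now lives in
  // SparkDatasetHelper rather than in this class.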
override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = { - val df = convertComplexType(resultDF) - withNewExecutionId(df) { - SparkDatasetHelper.toArrowBatchRdd(df).toLocalIterator - } + toArrowBatchLocalIterator(convertComplexType(resultDF)) } override protected def fullCollectResult(resultDF: DataFrame): Array[_] = { @@ -203,72 +196,11 @@ class ArrowBasedExecuteStatement( executeCollect(convertComplexType(resultDF.limit(maxRows))) } - /** - * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based - * operation, so that we can track the arrow-based queries on the UI tab. - */ - private def withNewExecutionId[T](df: DataFrame)(body: => T): T = { - SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) { - df.queryExecution.executedPlan.resetMetrics() - body - } - } - - private def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { - executeArrowBatchCollect(df).getOrElse { - SparkDatasetHelper.toArrowBatchRdd(df).collect() - } - } - - private def executeArrowBatchCollect(df: DataFrame): Option[Array[Array[Byte]]] = { - df.queryExecution.executedPlan match { - case collectLimit @ CollectLimitExec(limit, _) => - val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone - val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch - - val batches = ArrowCollectUtils.takeAsArrowBatches( - collectLimit, - maxRecordsPerBatch, - maxBatchSize, - timeZoneId) - - // note that the number of rows in the returned arrow batches may be >= `limit`, performing - // the slicing operation of result - val result = ArrayBuffer[Array[Byte]]() - var i = 0 - var rest = limit - while (i < batches.length && rest > 0) { - val (batch, size) = batches(i) - if (size <= rest) { - result += batch - // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion - rest -= size.toInt - } else { // size > rest - result += KyuubiArrowUtils.slice(df.schema, timeZoneId, batch, 0, rest) - rest = 0 - } - i += 1 - } - Option(result.toArray) - case _ => - None - } - } - override protected def isArrowBasedOperation: Boolean = true override val resultFormat = "arrow" private def convertComplexType(df: DataFrame): DataFrame = { - SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString) - } - - private lazy val maxBatchSize: Long = { - // respect spark connect config - spark.sparkContext.getConf.getOption("spark.connect.grpc.arrow.maxBatchSize") - .orElse(Option("4m")) - .map(JavaUtils.byteStringAs(_, ByteUnit.MiB)) - .get + convertTopLevelComplexTypeToHiveString(df, timestampAsString) } - } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala index e7e77799658..de9d39e27c7 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -136,4 +136,13 @@ object ArrowConvertersHelper extends Logging { out.toByteArray } } + + // for testing + def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { + ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context) + } } diff 
--git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 1a542937338..7df14e77795 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -17,18 +17,95 @@ package org.apache.spark.sql.kyuubi +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.arrow.{ArrowCollectUtils, ArrowConverters, KyuubiArrowUtils} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.kyuubi.engine.spark.KyuubiSparkUtil import org.apache.kyuubi.engine.spark.schema.RowSet object SparkDatasetHelper { + def toArrowBatchRdd[T](ds: Dataset[T]): RDD[Array[Byte]] = { ds.toArrowBatchRdd } + /** + * Forked from [[Dataset.toArrowBatchRdd(plan: SparkPlan)]]. + * Convert to an RDD of serialized ArrowRecordBatches. + */ + def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { + val schemaCaptured = plan.schema + val maxRecordsPerBatch = plan.session.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = plan.session.sessionState.conf.sessionLocalTimeZone + plan.execute().mapPartitionsInternal { iter => + val context = TaskContext.get() + ArrowConverters.toBatchIterator( + iter, + schemaCaptured, + maxRecordsPerBatch, + timeZoneId, + context) + } + } + + def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = { + val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone + val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch + + val batches = ArrowCollectUtils.takeAsArrowBatches( + collectLimit, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + + // note that the number of rows in the returned arrow batches may be >= `limit`, preform + // the slicing operation of result + val result = ArrayBuffer[Array[Byte]]() + var i = 0 + var rest = collectLimit.limit + while (i < batches.length && rest > 0) { + val (batch, size) = batches(i) + if (size <= rest) { + result += batch + // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion + rest -= size.toInt + } else { // size > rest + result += KyuubiArrowUtils.slice(collectLimit.schema, timeZoneId, batch, 0, rest) + rest = 0 + } + i += 1 + } + result.toArray + } + + def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { + executeArrowBatchCollect(df.queryExecution.executedPlan) + } + + def toArrowBatchLocalIterator(df: DataFrame): Iterator[Array[Byte]] = { + withNewExecutionId(df) { + toArrowBatchRdd(df).toLocalIterator + } + } + + def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = { + case adaptiveSparkPlan: AdaptiveSparkPlanExec => + executeArrowBatchCollect(adaptiveSparkPlan.finalPhysicalPlan) + case collectLimit: CollectLimitExec => + doCollectLimit(collectLimit) + case plan: SparkPlan => + toArrowBatchRdd(plan).collect() + } + def convertTopLevelComplexTypeToHiveString( df: DataFrame, timestampAsString: 
Boolean): DataFrame = { @@ -75,4 +152,25 @@ object SparkDatasetHelper { s"`${part.replace("`", "``")}`" } } + + private lazy val maxBatchSize: Long = { + // respect spark connect config + KyuubiSparkUtil.globalSparkContext + .getConf + .getOption("spark.connect.grpc.arrow.maxBatchSize") + .orElse(Option("4m")) + .map(JavaUtils.byteStringAs(_, ByteUnit.MiB)) + .get + } + + /** + * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based + * operation, so that we can track the arrow-based queries on the UI tab. + */ + private def withNewExecutionId[T](df: DataFrame)(body: => T): T = { + SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) { + df.queryExecution.executedPlan.resetMetrics() + body + } + } } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala index 629a8374b12..03d1102b322 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala @@ -58,7 +58,8 @@ trait WithSparkSQLEngine extends KyuubiFunSuite { s"jdbc:derby:;databaseName=$metastorePath;create=true") System.setProperty("spark.sql.warehouse.dir", warehousePath.toString) System.setProperty("spark.sql.hive.metastore.sharedPrefixes", "org.apache.hive.jdbc") - System.setProperty("spark.ui.enabled", "false") + System.setProperty("spark.ui.enabled", "true") + System.setProperty("spark.ui.port", "4040") withKyuubiConf.foreach { case (k, v) => System.setProperty(k, v) } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 017fb3b312c..22f5e4c2489 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -22,8 +22,11 @@ import java.sql.Statement import org.apache.spark.KyuubiSparkContextHelper import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.execution.arrow.ArrowConvertersHelper import org.apache.spark.sql.functions.col +import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.util.QueryExecutionListener import org.apache.kyuubi.config.KyuubiConf @@ -143,33 +146,59 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp test("aa") { val returnSize = Seq( - 7, - 10, - 13, - 20, - 29) - - withJdbcStatement() { statement => - loadPartitionedTable() - returnSize.foreach { size => - statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$size") - val result = statement.executeQuery("select * from t_1") - for (i <- 0 until size) { - assert(result.next()) - } - assert(!result.next()) + 7, // less than one partition + 10, // equal to one partition + 13, // between one and two 
partitions, run two jobs + 20, // equal to two partitions + 29) // between two and three partitions + + // aqe + // outermost AdaptiveSparkPlanExec + spark.range(1000) + .repartitionByRange(100, col("id")) + .createOrReplaceTempView("t_1") + spark.sql("select * from t_1") + .foreachPartition { p: Iterator[Row] => + assert(p.length == 10) + () } + returnSize.foreach { size => + val df = spark.sql(s"select * from t_1 limit $size") + val headPlan = df.queryExecution.executedPlan.collectLeaves().head + assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) + assert( + headPlan.asInstanceOf[AdaptiveSparkPlanExec].finalPhysicalPlan.isInstanceOf[ + CollectLimitExec]) + + val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) + + val rows = ArrowConvertersHelper.fromBatchIterator( + arrowBinary.iterator, + df.schema, + "", + KyuubiSparkContextHelper.dummyTaskContext()) + assert(rows.size == size) } - withJdbcStatement() { statement => - loadPartitionedTable() - returnSize.foreach { size => - val result = statement.executeQuery(s"select * from t_1 limit $size") - for (i <- 0 until size) { - assert(result.next()) - } - assert(!result.next()) + // outermost CollectLimitExec + spark.range(0, 1000, 1, numPartitions = 100) + .createOrReplaceTempView("t_2") + spark.sql("select * from t_2") + .foreachPartition { p: Iterator[Row] => + assert(p.length == 10) + () } + returnSize.foreach { size => + val df = spark.sql(s"select * from t_2 limit $size") + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[CollectLimitExec]) + val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) + val rows = ArrowConvertersHelper.fromBatchIterator( + arrowBinary.iterator, + df.schema, + "", + KyuubiSparkContextHelper.dummyTaskContext()) + assert(rows.size == size) } } @@ -212,22 +241,4 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp .allSessions() .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener)) } - - private def loadPartitionedTable(): Unit = { - SparkSQLEngine.currentEngine.get - .backendService - .sessionManager - .allSessions() - .map(_.asInstanceOf[SparkSessionImpl].spark) - .foreach { spark => - spark.range(1000) - .repartitionByRange(100, col("id")) - .createOrReplaceTempView("t_1") - spark.sql("select * from t_1") - .foreachPartition { p: Iterator[Row] => - assert(p.length == 10) - () - } - } - } } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala index 8293123ead7..1b662eadf96 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala @@ -27,4 +27,6 @@ object KyuubiSparkContextHelper { def waitListenerBus(spark: SparkSession): Unit = { spark.sparkContext.listenerBus.waitUntilEmpty() } + + def dummyTaskContext(): TaskContextImpl = TaskContext.empty() } From 6c5b1eb615b8d5e3232e4815c845a717df6926a5 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 14:34:02 +0800 Subject: [PATCH 06/28] add ut --- .../SparkArrowbasedOperationSuite.scala | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git 
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 22f5e4c2489..75f1cc8e22f 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -20,7 +20,8 @@ package org.apache.kyuubi.engine.spark.operation import java.sql.Statement import org.apache.spark.KyuubiSparkContextHelper -import org.apache.spark.sql.Row +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec @@ -143,8 +144,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp assert(metrics("numOutputRows").value === 1) } - test("aa") { - + test("SparkDatasetHelper.executeArrowBatchCollect should return expect row count") { val returnSize = Seq( 7, // less than one partition 10, // equal to one partition @@ -202,6 +202,25 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp } } + test("arrow serialization should not introduce extra shuffle for outermost limit") { + var numStages = 0 + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numStages = jobStart.stageInfos.length + } + } + withJdbcStatement() { statement => + withSparkListener(listener) { + withPartitionedTable("t_3") { + statement.executeQuery("select * from t_3 limit 1000") + } + KyuubiSparkContextHelper.waitListenerBus(spark) + } + } + // Should be only one stage since there is no shuffle. 
+ assert(numStages == 1) + } + private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = { val query = s""" @@ -241,4 +260,37 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp .allSessions() .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener)) } + + private def withSparkListener[T](listener: SparkListener)(body: => T): T = { + withAllSessions(s => s.sparkContext.addSparkListener(listener)) + try { + body + } finally { + withAllSessions(s => s.sparkContext.removeSparkListener(listener)) + } + + } + + private def withPartitionedTable[T](viewName: String)(body: => T): T = { + withAllSessions { spark => + spark.range(0, 1000, 1, numPartitions = 100) + .createOrReplaceTempView(viewName) + } + try { + body + } finally { + withAllSessions { spark => + spark.sql(s"DROP VIEW IF EXISTS $viewName") + } + } + } + + private def withAllSessions(op: SparkSession => Unit): Unit = { + SparkSQLEngine.currentEngine.get + .backendService + .sessionManager + .allSessions() + .map(_.asInstanceOf[SparkSessionImpl].spark) + .foreach(op(_)) + } } From ee5a7567a212ee3baf964d2d364e0b39060e1b78 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 15:06:47 +0800 Subject: [PATCH 07/28] revert unnecessarily changes --- .../kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala | 2 -- externals/kyuubi-spark-sql-engine/pom.xml | 7 ------- .../apache/kyuubi/engine/spark/WithSparkSQLEngine.scala | 3 +-- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala b/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala index 3a1fd87a5ce..83679989a79 100644 --- a/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala +++ b/extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala @@ -88,6 +88,4 @@ class TPCDSQuerySuite extends KyuubiFunSuite { } } } - - test("aa") {} } diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml index 942a8ec6c27..5b227cb5e29 100644 --- a/externals/kyuubi-spark-sql-engine/pom.xml +++ b/externals/kyuubi-spark-sql-engine/pom.xml @@ -71,13 +71,6 @@ provided - - com.google.guava - guava - 14.0.1 - provided - - org.scala-lang scala-compiler diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala index 03d1102b322..629a8374b12 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/WithSparkSQLEngine.scala @@ -58,8 +58,7 @@ trait WithSparkSQLEngine extends KyuubiFunSuite { s"jdbc:derby:;databaseName=$metastorePath;create=true") System.setProperty("spark.sql.warehouse.dir", warehousePath.toString) System.setProperty("spark.sql.hive.metastore.sharedPrefixes", "org.apache.hive.jdbc") - System.setProperty("spark.ui.enabled", "true") - System.setProperty("spark.ui.port", "4040") + System.setProperty("spark.ui.enabled", "false") withKyuubiConf.foreach { case (k, v) => System.setProperty(k, v) } From 4e7ca54df30a2b05a05a6786137edc59debdf92c 
Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 15:08:03 +0800 Subject: [PATCH 08/28] unnecessarily changes --- kyuubi-server/pom.xml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/kyuubi-server/pom.xml b/kyuubi-server/pom.xml index 905e0375fca..7408ac5dd00 100644 --- a/kyuubi-server/pom.xml +++ b/kyuubi-server/pom.xml @@ -466,11 +466,7 @@ io.delta delta-core_${scala.binary.version} -<<<<<<< HEAD test -======= - ->>>>>>> 2a4e53609... arrow take From 885cf2c71f1e06302b68747f3b01d78430da0361 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 19:43:50 +0800 Subject: [PATCH 09/28] infer row size by schema.defaultSize --- .../spark/sql/execution/arrow/ArrowConvertersHelper.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala index de9d39e27c7..52c96ad40c5 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -115,8 +115,8 @@ object ArrowConvertersHelper extends Logging { arrowWriter.write(row) estimatedBatchSize += (row match { case ur: UnsafeRow => ur.getSizeInBytes - // Trying to estimate the size of the current row, assuming 16 bytes per value. - case ir: InternalRow => ir.numFields * 16 + // Trying to estimate the size of the current row + case _: InternalRow => schema.defaultSize }) rowCountInLastBatch += 1 rowCount += 1 From 25e4f056c0983c09235d513b94063071066bfebb Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Tue, 4 Apr 2023 20:22:35 +0800 Subject: [PATCH 10/28] add docs --- .../spark/sql/execution/arrow/ArrowConvertersHelper.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala index 52c96ad40c5..fc819410dca 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -54,6 +54,13 @@ object ArrowConvertersHelper extends Logging { TaskContext.get) } + /** + * This class ArrowBatchIterator is derived from + * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]], + * with two key differences: + * 1. there is no requirement to write the schema at the batch header + * 2. 
iteration halts when `rowCount` equals `limit` + */ private[sql] class ArrowBatchIterator( rowIter: Iterator[InternalRow], schema: StructType, From 03d074732d48726519d56eccd9882a467e1324fe Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 10:07:11 +0800 Subject: [PATCH 11/28] address comment --- .../scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 7df14e77795..adfc0dc1f9d 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -67,7 +67,7 @@ object SparkDatasetHelper { maxBatchSize, timeZoneId) - // note that the number of rows in the returned arrow batches may be >= `limit`, preform + // note that the number of rows in the returned arrow batches may be >= `limit`, perform // the slicing operation of result val result = ArrayBuffer[Array[Byte]]() var i = 0 From 2286afc6b947eba480f5fc8188bf36298bad5d36 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 11:25:00 +0800 Subject: [PATCH 12/28] reflective calla AdaptiveSparkPlanExec.finalPhysicalPlan --- .../spark/sql/kyuubi/SparkDatasetHelper.scala | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index adfc0dc1f9d..ba55812e161 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.kyuubi.engine.spark.KyuubiSparkUtil import org.apache.kyuubi.engine.spark.schema.RowSet +import org.apache.kyuubi.reflection.DynMethods object SparkDatasetHelper { @@ -99,7 +100,7 @@ object SparkDatasetHelper { def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = { case adaptiveSparkPlan: AdaptiveSparkPlanExec => - executeArrowBatchCollect(adaptiveSparkPlan.finalPhysicalPlan) + executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan)) case collectLimit: CollectLimitExec => doCollectLimit(collectLimit) case plan: SparkPlan => @@ -163,6 +164,36 @@ object SparkDatasetHelper { .get } + /** + * This method provides a reflection-based implementation of + * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime + * without patching SPARK-41914. + * + * TODO: Once we drop support for Spark 3.1.x, we can directly call + * [[AdaptiveSparkPlanExec.finalPhysicalPlan]]. + */ + private def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = { + withFinalPlanUpdate(adaptiveSparkPlanExec, identity) + } + + /** + * A reflection-based implementation of [[AdaptiveSparkPlanExec.withFinalPlanUpdate]]. 
+ */ + private def withFinalPlanUpdate[T]( + adaptiveSparkPlanExec: AdaptiveSparkPlanExec, + fun: SparkPlan => T): T = { + val getFinalPhysicalPlan = DynMethods.builder("getFinalPhysicalPlan") + .hiddenImpl(adaptiveSparkPlanExec.getClass) + .build() + val plan = getFinalPhysicalPlan.invoke[SparkPlan](adaptiveSparkPlanExec) + val result = fun(plan) + val finalPlanUpdate = DynMethods.builder("finalPlanUpdate") + .hiddenImpl(adaptiveSparkPlanExec.getClass) + .build(adaptiveSparkPlanExec) + finalPlanUpdate.invoke[Unit](adaptiveSparkPlanExec) + result + } + /** * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based * operation, so that we can track the arrow-based queries on the UI tab. From 81886f01c0bf19863509ae0c97f6d344752a4115 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 11:26:06 +0800 Subject: [PATCH 13/28] address comment --- .../spark/sql/execution/arrow/ArrowConvertersHelper.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala index fc819410dca..0a20189c778 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala @@ -103,7 +103,7 @@ object ArrowConvertersHelper extends Logging { val writeChannel = new WriteChannel(Channels.newChannel(out)) rowCountInLastBatch = 0 - var estimatedBatchSize = 0 + var estimatedBatchSize = 0L Utils.tryWithSafeFinally { // Always write the first row. From e3bf84c0385a354c6c5f8d197b52c3a594700edb Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 12:24:20 +0800 Subject: [PATCH 14/28] refactor --- .../execution/arrow/ArrowCollectUtils.scala | 133 -------- .../arrow/ArrowConvertersHelper.scala | 155 --------- .../arrow/KyuubiArrowConverters.scala | 306 ++++++++++++++++++ .../execution/arrow/KyuubiArrowUtils.scala | 67 ---- .../spark/sql/kyuubi/SparkDatasetHelper.scala | 90 +++--- .../SparkArrowbasedOperationSuite.scala | 6 +- 6 files changed, 354 insertions(+), 403 deletions(-) delete mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala delete mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala create mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala delete mode 100644 externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala deleted file mode 100644 index 86fc28eec71..00000000000 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectUtils.scala +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
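For context on what the DynMethods-based helper above does under the hood, a rough plain java.lang.reflect equivalent is sketched below (illustration only; `getFinalPhysicalPlan` and `finalPlanUpdate` are the same private members the patch resolves, and the sketch assumes they remain reachable under those names at runtime):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

    def withFinalPlanUpdateViaReflection[T](
        aqe: AdaptiveSparkPlanExec,
        fun: SparkPlan => T): T = {
      // resolve and open up the private getFinalPhysicalPlan(): SparkPlan method
      val getFinalPhysicalPlan = aqe.getClass.getDeclaredMethod("getFinalPhysicalPlan")
      getFinalPhysicalPlan.setAccessible(true)
      val plan = getFinalPhysicalPlan.invoke(aqe).asInstanceOf[SparkPlan]
      val result = fun(plan)
      // trigger the finalPlanUpdate lazy val so the final plan change is reported
      val finalPlanUpdate = aqe.getClass.getDeclaredMethod("finalPlanUpdate")
      finalPlanUpdate.setAccessible(true)
      finalPlanUpdate.invoke(aqe)
      result
    }
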
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.arrow - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} -import org.apache.spark.sql.execution.CollectLimitExec - -object ArrowCollectUtils extends SQLConfHelper { - - type Batch = (Array[Byte], Long) - - /** - * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be - * summarized in the following steps: - * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty - * array of batches. - * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of - * data to collect. - * 3. Use an iterative approach to collect data in batches until the specified limit is reached. - * In each iteration, it selects a subset of the partitions of the RDD to scan and tries to - * collect data from them. - * 4. For each partition subset, we use the runJob method of the Spark context to execute a - * closure that scans the partition data and converts it to Arrow batches. - * 5. Check if the collected data reaches the specified limit. If not, it selects another subset - * of partitions to scan and repeats the process until the limit is reached or all partitions - * have been scanned. - * 6. Return an array of all the collected Arrow batches. - * - * Note that: - * 1. The returned Arrow batches row count >= limit, if the input df has more than the `limit` - * row count - * 2. We don't implement the `takeFromEnd` logical - * - * @return - */ - def takeAsArrowBatches( - collectLimitExec: CollectLimitExec, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - timeZoneId: String): Array[Batch] = { - val n = collectLimitExec.limit - val schema = collectLimitExec.schema - if (n == 0) { - return new Array[Batch](0) - } else { - val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) - // TODO: refactor and reuse the code from RDD's take() - val childRDD = collectLimitExec.child.execute() - val buf = new ArrayBuffer[Batch] - var bufferedRowSize = 0L - val totalParts = childRDD.partitions.length - var partsScanned = 0 - while (bufferedRowSize < n && partsScanned < totalParts) { - // The number of partitions to try in this iteration. It is ok for this number to be - // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = limitInitialNumPartitions - if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, multiply by - // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need - // to try, but overestimate it by 50%. We also cap the estimation in the end. 
- if (buf.isEmpty) { - numPartsToTry = partsScanned * limitScaleUpFactor - } else { - val left = n - bufferedRowSize - // As left > 0, numPartsToTry is always >= 1 - numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt - numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) - } - } - - val partsToScan = - partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) - - val sc = collectLimitExec.session.sparkContext - val res = sc.runJob( - childRDD, - (it: Iterator[InternalRow]) => { - val batches = ArrowConvertersHelper.toBatchIterator( - it, - schema, - maxRecordsPerBatch, - maxEstimatedBatchSize, - n, - timeZoneId) - batches.map(b => b -> batches.rowCountInLastBatch).toArray - }, - partsToScan) - - var i = 0 - while (bufferedRowSize < n && i < res.length) { - var j = 0 - val batches = res(i) - while (j < batches.length && n > bufferedRowSize) { - val batch = batches(j) - val (_, batchSize) = batch - buf += batch - bufferedRowSize += batchSize - j += 1 - } - i += 1 - } - partsScanned += partsToScan.size - } - - buf.toArray - } - } - - /** - * Spark introduced the config `spark.sql.limit.initialNumPartitions` since 3.4.0. see SPARK-40211 - */ - def limitInitialNumPartitions: Int = { - conf.getConfString("spark.sql.limit.initialNumPartitions", "1") - .toInt - } -} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala deleted file mode 100644 index 0a20189c778..00000000000 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersHelper.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.arrow - -import java.io.ByteArrayOutputStream -import java.nio.channels.Channels - -import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} -import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} -import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.Utils - -object ArrowConvertersHelper extends Logging { - - /** - * Different from [[org.apache.spark.sql.execution.arrow.ArrowConvertersHelper.toBatchIterator]], - * each output arrow batch contains this batch row count. 
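The batch iterator being moved in this refactor caps batches by both a row count and an estimated byte size; for rows that are not UnsafeRow the estimate falls back to schema.defaultSize (introduced in PATCH 09 above). A small illustration of that static estimate, with hypothetical field names chosen for the example (values are Spark's fixed per-type defaults, not measured row sizes):

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("id", LongType),      // defaultSize = 8
      StructField("name", StringType))) // defaultSize = 20 (fixed estimate)
    // the per-row fallback used when a row is not an UnsafeRow
    assert(schema.defaultSize == 28)
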
- */ - def toBatchIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - limit: Long, - timeZoneId: String): ArrowBatchIterator = { - new ArrowBatchIterator( - rowIter, - schema, - maxRecordsPerBatch, - maxEstimatedBatchSize, - limit, - timeZoneId, - TaskContext.get) - } - - /** - * This class ArrowBatchIterator is derived from - * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]], - * with two key differences: - * 1. there is no requirement to write the schema at the batch header - * 2. iteration halts when `rowCount` equals `limit` - */ - private[sql] class ArrowBatchIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - limit: Long, - timeZoneId: String, - context: TaskContext) - extends Iterator[Array[Byte]] { - - protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - private val allocator = - ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", - 0, - Long.MaxValue) - - private val root = VectorSchemaRoot.create(arrowSchema, allocator) - protected val unloader = new VectorUnloader(root) - protected val arrowWriter = ArrowWriter.create(root) - - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - } - } - - override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || { - root.close() - allocator.close() - false - } - - var rowCountInLastBatch: Long = 0 - var rowCount: Long = 0 - - override def next(): Array[Byte] = { - val out = new ByteArrayOutputStream() - val writeChannel = new WriteChannel(Channels.newChannel(out)) - - rowCountInLastBatch = 0 - var estimatedBatchSize = 0L - Utils.tryWithSafeFinally { - - // Always write the first row. - while (rowIter.hasNext && ( - // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. - // If the size in bytes is positive (set properly), always write the first row. - rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || - // If the size in bytes of rows are 0 or negative, unlimit it. - estimatedBatchSize <= 0 || - estimatedBatchSize < maxEstimatedBatchSize || - // If the size of rows are 0 or negative, unlimit it. - maxRecordsPerBatch <= 0 || - rowCountInLastBatch < maxRecordsPerBatch || - rowCount < limit)) { - val row = rowIter.next() - arrowWriter.write(row) - estimatedBatchSize += (row match { - case ur: UnsafeRow => ur.getSizeInBytes - // Trying to estimate the size of the current row - case _: InternalRow => schema.defaultSize - }) - rowCountInLastBatch += 1 - rowCount += 1 - } - arrowWriter.finish() - val batch = unloader.getRecordBatch() - MessageSerializer.serialize(writeChannel, batch) - - // Always write the Ipc options at the end. 
- ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) - - batch.close() - } { - arrowWriter.reset() - } - - out.toByteArray - } - } - - // for testing - def fromBatchIterator( - arrowBatchIter: Iterator[Array[Byte]], - schema: StructType, - timeZoneId: String, - context: TaskContext): Iterator[InternalRow] = { - ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context) - } -} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala new file mode 100644 index 00000000000..e3cd5770127 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.channels.Channels + +import scala.collection.mutable.ArrayBuffer + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.CollectLimitExec +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils + +object KyuubiArrowConverters extends SQLConfHelper with Logging { + + type Batch = (Array[Byte], Long) + + private val rootAllocator = ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + + /** + * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]], + * each output arrow batch contains this batch row count. + */ + def toBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String): ArrowBatchIterator = { + new ArrowBatchIterator( + rowIter, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + limit, + timeZoneId, + TaskContext.get) + } + + /** + * This class ArrowBatchIterator is derived from + * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]], + * with two key differences: + * 1. there is no requirement to write the schema at the batch header + * 2. 
iteration halts when `rowCount` equals `limit` + */ + private[sql] class ArrowBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String, + context: TaskContext) + extends Iterator[Array[Byte]] { + + protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected val unloader = new VectorUnloader(root) + protected val arrowWriter = ArrowWriter.create(root) + + Option(context).foreach { + _.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + } + } + + override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || { + root.close() + allocator.close() + false + } + + var rowCountInLastBatch: Long = 0 + var rowCount: Long = 0 + + override def next(): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + rowCountInLastBatch = 0 + var estimatedBatchSize = 0L + Utils.tryWithSafeFinally { + + // Always write the first row. + while (rowIter.hasNext && ( + // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. + // If the size in bytes is positive (set properly), always write the first row. + rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || + // If the size in bytes of rows are 0 or negative, unlimit it. + estimatedBatchSize <= 0 || + estimatedBatchSize < maxEstimatedBatchSize || + // If the size of rows are 0 or negative, unlimit it. + maxRecordsPerBatch <= 0 || + rowCountInLastBatch < maxRecordsPerBatch || + rowCount < limit)) { + val row = rowIter.next() + arrowWriter.write(row) + estimatedBatchSize += (row match { + case ur: UnsafeRow => ur.getSizeInBytes + // Trying to estimate the size of the current row + case _: InternalRow => schema.defaultSize + }) + rowCountInLastBatch += 1 + rowCount += 1 + } + arrowWriter.finish() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + + // Always write the Ipc options at the end. + ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) + + batch.close() + } { + arrowWriter.reset() + } + + out.toByteArray + } + } + + /** + * this method is to slice the input Arrow record batch byte array `bytes`, starting from `start` + * and taking `length` number of elements. 
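A test-style sketch of the intended slice semantics (not part of the patch; `batches`, `schema` and `timeZoneId` are assumed to come from a takeAsArrowBatches call, and the decoding helpers are the test-only ones defined in this file and in KyuubiSparkContextHelper):

    // keep only the first 10 rows of an over-full batch, then decode and verify
    val (batch, rowCount) = batches.head
    assert(rowCount >= 10)
    val firstTen = KyuubiArrowConverters.slice(schema, timeZoneId, batch, 0, 10)
    val decoded = KyuubiArrowConverters.fromBatchIterator(
      Iterator(firstTen),
      schema,
      timeZoneId,
      KyuubiSparkContextHelper.dummyTaskContext())
    assert(decoded.size == 10)

This is the operation doCollectLimit relies on to trim the last buffered batch so that at most `limit` rows are returned to the client.
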
+ */ + def slice( + schema: StructType, + timeZoneId: String, + bytes: Array[Byte], + start: Int, + length: Int): Array[Byte] = { + val in = new ByteArrayInputStream(bytes) + val out = new ByteArrayOutputStream() + + try { + val recordBatch = MessageSerializer.deserializeRecordBatch( + new ReadChannel(Channels.newChannel(in)), + rootAllocator) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + + val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(recordBatch) + recordBatch.close() + + val unloader = new VectorUnloader(root.slice(start, length)) + val writeChannel = new WriteChannel(Channels.newChannel(out)) + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + batch.close() + out.toByteArray() + } finally { + in.close() + out.close() + } + } + + /** + * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`, the algorithm can be + * summarized in the following steps: + * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty + * array of batches. + * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of + * data to collect. + * 3. Use an iterative approach to collect data in batches until the specified limit is reached. + * In each iteration, it selects a subset of the partitions of the RDD to scan and tries to + * collect data from them. + * 4. For each partition subset, we use the runJob method of the Spark context to execute a + * closure that scans the partition data and converts it to Arrow batches. + * 5. Check if the collected data reaches the specified limit. If not, it selects another subset + * of partitions to scan and repeats the process until the limit is reached or all partitions + * have been scanned. + * 6. Return an array of all the collected Arrow batches. + * + * Note that: + * 1. The returned Arrow batches row count >= limit, if the input df has more than the `limit` + * row count + * 2. We don't implement the `takeFromEnd` logical + * + * @return + */ + def takeAsArrowBatches( + collectLimitExec: CollectLimitExec, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + timeZoneId: String): Array[Batch] = { + val n = collectLimitExec.limit + val schema = collectLimitExec.schema + if (n == 0) { + return new Array[Batch](0) + } else { + val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2) + // TODO: refactor and reuse the code from RDD's take() + val childRDD = collectLimitExec.child.execute() + val buf = new ArrayBuffer[Batch] + var bufferedRowSize = 0L + val totalParts = childRDD.partitions.length + var partsScanned = 0 + while (bufferedRowSize < n && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = limitInitialNumPartitions + if (partsScanned > 0) { + // If we didn't find any rows after the previous iteration, multiply by + // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need + // to try, but overestimate it by 50%. We also cap the estimation in the end. 
+ if (buf.isEmpty) { + numPartsToTry = partsScanned * limitScaleUpFactor + } else { + val left = n - bufferedRowSize + // As left > 0, numPartsToTry is always >= 1 + numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt + numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) + } + } + + val partsToScan = + partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) + + val sc = collectLimitExec.session.sparkContext + val res = sc.runJob( + childRDD, + (it: Iterator[InternalRow]) => { + val batches = KyuubiArrowConverters.toBatchIterator( + it, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + n, + timeZoneId) + batches.map(b => b -> batches.rowCountInLastBatch).toArray + }, + partsToScan) + + var i = 0 + while (bufferedRowSize < n && i < res.length) { + var j = 0 + val batches = res(i) + while (j < batches.length && n > bufferedRowSize) { + val batch = batches(j) + val (_, batchSize) = batch + buf += batch + bufferedRowSize += batchSize + j += 1 + } + i += 1 + } + partsScanned += partsToScan.size + } + + buf.toArray + } + } + + /** + * Spark introduced the config `spark.sql.limit.initialNumPartitions` since 3.4.0. see SPARK-40211 + */ + private def limitInitialNumPartitions: Int = { + conf.getConfString("spark.sql.limit.initialNumPartitions", "1") + .toInt + } + + // for testing + def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = { + ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context) + } +} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala deleted file mode 100644 index ba94e255c9c..00000000000 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
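To make the partition scale-up logic in takeAsArrowBatches easier to follow, here is a self-contained sketch of the same loop with partitions modelled as plain in-memory sequences (no Spark, no Arrow; unlike the real code, which buffers whole Arrow batches and slices the excess afterwards, this version only takes the rows it still needs):

    import scala.collection.mutable.ArrayBuffer

    def takeSketch(partitions: Seq[Seq[Int]], n: Int, limitScaleUpFactor: Int = 4): Seq[Int] = {
      val buf = ArrayBuffer[Int]()
      val totalParts = partitions.length
      var partsScanned = 0
      while (buf.size < n && partsScanned < totalParts) {
        // first round scans a single partition (limitInitialNumPartitions = 1)
        var numPartsToTry = 1
        if (partsScanned > 0) {
          if (buf.isEmpty) {
            // nothing found yet: scale up aggressively
            numPartsToTry = partsScanned * limitScaleUpFactor
          } else {
            // interpolate from the observed row density, overestimate by 50%, then cap
            val left = n - buf.size
            numPartsToTry = math.min(
              math.ceil(1.5 * left * partsScanned / buf.size).toInt,
              partsScanned * limitScaleUpFactor)
          }
        }
        val partsToScan = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts))
        partsToScan.foreach(p => buf ++= partitions(p).take(n - buf.size))
        partsScanned += partsToScan.size
      }
      buf.toSeq
    }

    // 100 partitions of 10 rows each, limit 29: round one scans partition 0 (10 rows),
    // round two estimates ceil(1.5 * 19 * 1 / 10) = 3 more partitions and reaches 29 rows.
    val parts = Seq.tabulate(100)(p => Seq.tabulate(10)(i => p * 10 + i))
    assert(takeSketch(parts, 29).size == 29)
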
- */ - -package org.apache.spark.sql.execution.arrow - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.nio.channels.Channels - -import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader} -import org.apache.arrow.vector.ipc.{ReadChannel, WriteChannel} -import org.apache.arrow.vector.ipc.message.MessageSerializer -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils - -object KyuubiArrowUtils { - - private val rootAllocator = - ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", - 0, - Long.MaxValue) - def slice( - schema: StructType, - timeZoneId: String, - bytes: Array[Byte], - start: Int, - length: Int): Array[Byte] = { - val in = new ByteArrayInputStream(bytes) - val out = new ByteArrayOutputStream() - - try { - val recordBatch = MessageSerializer.deserializeRecordBatch( - new ReadChannel(Channels.newChannel(in)), - rootAllocator) - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - - val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) - val vectorLoader = new VectorLoader(root) - vectorLoader.load(recordBatch) - recordBatch.close() - - val unloader = new VectorUnloader(root.slice(start, length)) - val writeChannel = new WriteChannel(Channels.newChannel(out)) - val batch = unloader.getRecordBatch() - MessageSerializer.serialize(writeChannel, batch) - batch.close() - out.toByteArray() - } finally { - in.close() - out.close() - } - } -} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index ba55812e161..038444a87ee 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.arrow.{ArrowCollectUtils, ArrowConverters, KyuubiArrowUtils} +import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -35,6 +35,19 @@ import org.apache.kyuubi.reflection.DynMethods object SparkDatasetHelper { + def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { + executeArrowBatchCollect(df.queryExecution.executedPlan) + } + + def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = { + case adaptiveSparkPlan: AdaptiveSparkPlanExec => + executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan)) + case collectLimit: CollectLimitExec => + doCollectLimit(collectLimit) + case plan: SparkPlan => + toArrowBatchRdd(plan).collect() + } + def toArrowBatchRdd[T](ds: Dataset[T]): RDD[Array[Byte]] = { ds.toArrowBatchRdd } @@ -58,55 +71,12 @@ object SparkDatasetHelper { } } - def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = { - val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone - val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch - - val batches = ArrowCollectUtils.takeAsArrowBatches( - collectLimit, - maxRecordsPerBatch, - maxBatchSize, - 
timeZoneId) - - // note that the number of rows in the returned arrow batches may be >= `limit`, perform - // the slicing operation of result - val result = ArrayBuffer[Array[Byte]]() - var i = 0 - var rest = collectLimit.limit - while (i < batches.length && rest > 0) { - val (batch, size) = batches(i) - if (size <= rest) { - result += batch - // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion - rest -= size.toInt - } else { // size > rest - result += KyuubiArrowUtils.slice(collectLimit.schema, timeZoneId, batch, 0, rest) - rest = 0 - } - i += 1 - } - result.toArray - } - - def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { - executeArrowBatchCollect(df.queryExecution.executedPlan) - } - def toArrowBatchLocalIterator(df: DataFrame): Iterator[Array[Byte]] = { withNewExecutionId(df) { toArrowBatchRdd(df).toLocalIterator } } - def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = { - case adaptiveSparkPlan: AdaptiveSparkPlanExec => - executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan)) - case collectLimit: CollectLimitExec => - doCollectLimit(collectLimit) - case plan: SparkPlan => - toArrowBatchRdd(plan).collect() - } - def convertTopLevelComplexTypeToHiveString( df: DataFrame, timestampAsString: Boolean): DataFrame = { @@ -146,7 +116,7 @@ object SparkDatasetHelper { * Fork from Apache Spark-3.3.1 org.apache.spark.sql.catalyst.util.quoteIfNeeded to adapt to * Spark-3.1.x */ - def quoteIfNeeded(part: String): String = { + private def quoteIfNeeded(part: String): String = { if (part.matches("[a-zA-Z0-9_]+") && !part.matches("\\d+")) { part } else { @@ -164,6 +134,36 @@ object SparkDatasetHelper { .get } + private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = { + val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone + val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch + + val batches = KyuubiArrowConverters.takeAsArrowBatches( + collectLimit, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + + // note that the number of rows in the returned arrow batches may be >= `limit`, perform + // the slicing operation of result + val result = ArrayBuffer[Array[Byte]]() + var i = 0 + var rest = collectLimit.limit + while (i < batches.length && rest > 0) { + val (batch, size) = batches(i) + if (size <= rest) { + result += batch + // returned ArrowRecordBatch has less than `limit` row count, safety to do conversion + rest -= size.toInt + } else { // size > rest + result += KyuubiArrowConverters.slice(collectLimit.schema, timeZoneId, batch, 0, rest) + rest = 0 + } + i += 1 + } + result.toArray + } + /** * This method provides a reflection-based implementation of * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 75f1cc8e22f..3f8d7d7e663 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession} import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec -import org.apache.spark.sql.execution.arrow.ArrowConvertersHelper +import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters import org.apache.spark.sql.functions.col import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.util.QueryExecutionListener @@ -172,7 +172,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) - val rows = ArrowConvertersHelper.fromBatchIterator( + val rows = KyuubiArrowConverters.fromBatchIterator( arrowBinary.iterator, df.schema, "", @@ -193,7 +193,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp val plan = df.queryExecution.executedPlan assert(plan.isInstanceOf[CollectLimitExec]) val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) - val rows = ArrowConvertersHelper.fromBatchIterator( + val rows = KyuubiArrowConverters.fromBatchIterator( arrowBinary.iterator, df.schema, "", From d70aee36b921e3278fbab08a57ac700c2c101129 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 14:23:14 +0800 Subject: [PATCH 15/28] SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x --- .../execution/arrow/KyuubiArrowConverters.scala | 5 ++++- .../spark/sql/kyuubi/SparkDatasetHelper.scala | 14 +++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index e3cd5770127..9b9054cd974 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -27,6 +27,7 @@ import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.CollectLimitExec @@ -252,7 +253,9 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { val partsToScan = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) - val sc = collectLimitExec.session.sparkContext + // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we + // drop Spark-3.1.x support. 
+ val sc = SparkSession.active.sparkContext val res = sc.runJob( childRDD, (it: Iterator[InternalRow]) => { diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 038444a87ee..b27a76783cb 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters} @@ -58,8 +58,10 @@ object SparkDatasetHelper { */ def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = plan.schema - val maxRecordsPerBatch = plan.session.sessionState.conf.arrowMaxRecordsPerBatch - val timeZoneId = plan.session.sessionState.conf.sessionLocalTimeZone + // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we + // drop Spark-3.1.x support. + val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toBatchIterator( @@ -135,8 +137,10 @@ object SparkDatasetHelper { } private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = { - val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone - val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch + // TODO: SparkPlan.session introduced in SPARK-35798, replace with SparkPlan.session once we + // drop Spark-3.1.x support. + val timeZoneId = SparkSession.active.sessionState.conf.sessionLocalTimeZone + val maxRecordsPerBatch = SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch val batches = KyuubiArrowConverters.takeAsArrowBatches( collectLimit, From 4cef20481386fe24781c0c80cd85863a4fcfb978 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 15:10:55 +0800 Subject: [PATCH 16/28] SparkArrowbasedOperationSuite adapt Spark-3.1.x --- .../org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala | 2 +- .../spark/operation/SparkArrowbasedOperationSuite.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index b27a76783cb..5e84c6ae23d 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -176,7 +176,7 @@ object SparkDatasetHelper { * TODO: Once we drop support for Spark 3.1.x, we can directly call * [[AdaptiveSparkPlanExec.finalPhysicalPlan]]. 
*/ - private def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = { + def finalPhysicalPlan(adaptiveSparkPlanExec: AdaptiveSparkPlanExec): SparkPlan = { withFinalPlanUpdate(adaptiveSparkPlanExec, identity) } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 3f8d7d7e663..f0a030e3807 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -166,9 +166,8 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp val df = spark.sql(s"select * from t_1 limit $size") val headPlan = df.queryExecution.executedPlan.collectLeaves().head assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) - assert( - headPlan.asInstanceOf[AdaptiveSparkPlanExec].finalPhysicalPlan.isInstanceOf[ - CollectLimitExec]) + val finalPhysicalPlan = SparkDatasetHelper.finalPhysicalPlan(AdaptiveSparkPlanExec) + assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) From 573a262ed95e7fd029e6186bd6fff15470672933 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 15:13:31 +0800 Subject: [PATCH 17/28] fix --- .../engine/spark/operation/SparkArrowbasedOperationSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index f0a030e3807..e1daea5ddd2 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -166,7 +166,8 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp val df = spark.sql(s"select * from t_1 limit $size") val headPlan = df.queryExecution.executedPlan.collectLeaves().head assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) - val finalPhysicalPlan = SparkDatasetHelper.finalPhysicalPlan(AdaptiveSparkPlanExec) + val finalPhysicalPlan = + SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) From c83cf3f5e486c64f3a145d71fbea527fb3dc72e9 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 16:08:35 +0800 Subject: [PATCH 18/28] SparkArrowbasedOperationSuite adapt Spark-3.1.x --- .../operation/SparkArrowbasedOperationSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index e1daea5ddd2..08ba7c8de19 100644 --- 
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -165,10 +165,12 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp returnSize.foreach { size => val df = spark.sql(s"select * from t_1 limit $size") val headPlan = df.queryExecution.executedPlan.collectLeaves().head - assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) - val finalPhysicalPlan = - SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) - assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) + if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") { + assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) + val finalPhysicalPlan = + SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) + assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) + } val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) From 9ffb44fb20c8e8a2bce8656f9c17ad41d1214d8c Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 18:25:00 +0800 Subject: [PATCH 19/28] make toBatchIterator private --- .../arrow/KyuubiArrowConverters.scala | 224 +++++++++--------- 1 file changed, 112 insertions(+), 112 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index 9b9054cd974..6e1ac5c2aaf 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -44,117 +44,6 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { 0, Long.MaxValue) - /** - * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]], - * each output arrow batch contains this batch row count. - */ - def toBatchIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - limit: Long, - timeZoneId: String): ArrowBatchIterator = { - new ArrowBatchIterator( - rowIter, - schema, - maxRecordsPerBatch, - maxEstimatedBatchSize, - limit, - timeZoneId, - TaskContext.get) - } - - /** - * This class ArrowBatchIterator is derived from - * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]], - * with two key differences: - * 1. there is no requirement to write the schema at the batch header - * 2. 
iteration halts when `rowCount` equals `limit` - */ - private[sql] class ArrowBatchIterator( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - maxEstimatedBatchSize: Long, - limit: Long, - timeZoneId: String, - context: TaskContext) - extends Iterator[Array[Byte]] { - - protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - private val allocator = - ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", - 0, - Long.MaxValue) - - private val root = VectorSchemaRoot.create(arrowSchema, allocator) - protected val unloader = new VectorUnloader(root) - protected val arrowWriter = ArrowWriter.create(root) - - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - } - } - - override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || { - root.close() - allocator.close() - false - } - - var rowCountInLastBatch: Long = 0 - var rowCount: Long = 0 - - override def next(): Array[Byte] = { - val out = new ByteArrayOutputStream() - val writeChannel = new WriteChannel(Channels.newChannel(out)) - - rowCountInLastBatch = 0 - var estimatedBatchSize = 0L - Utils.tryWithSafeFinally { - - // Always write the first row. - while (rowIter.hasNext && ( - // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. - // If the size in bytes is positive (set properly), always write the first row. - rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || - // If the size in bytes of rows are 0 or negative, unlimit it. - estimatedBatchSize <= 0 || - estimatedBatchSize < maxEstimatedBatchSize || - // If the size of rows are 0 or negative, unlimit it. - maxRecordsPerBatch <= 0 || - rowCountInLastBatch < maxRecordsPerBatch || - rowCount < limit)) { - val row = rowIter.next() - arrowWriter.write(row) - estimatedBatchSize += (row match { - case ur: UnsafeRow => ur.getSizeInBytes - // Trying to estimate the size of the current row - case _: InternalRow => schema.defaultSize - }) - rowCountInLastBatch += 1 - rowCount += 1 - } - arrowWriter.finish() - val batch = unloader.getRecordBatch() - MessageSerializer.serialize(writeChannel, batch) - - // Always write the Ipc options at the end. - ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) - - batch.close() - } { - arrowWriter.reset() - } - - out.toByteArray - } - } - /** * this method is to slice the input Arrow record batch byte array `bytes`, starting from `start` * and taking `length` number of elements. @@ -259,7 +148,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { val res = sc.runJob( childRDD, (it: Iterator[InternalRow]) => { - val batches = KyuubiArrowConverters.toBatchIterator( + val batches = toBatchIterator( it, schema, maxRecordsPerBatch, @@ -298,6 +187,117 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { .toInt } + /** + * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]], + * each output arrow batch contains this batch row count. 
+ */ + private def toBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String): ArrowBatchIterator = { + new ArrowBatchIterator( + rowIter, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + limit, + timeZoneId, + TaskContext.get) + } + + /** + * This class ArrowBatchIterator is derived from + * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]], + * with two key differences: + * 1. there is no requirement to write the schema at the batch header + * 2. iteration halts when `rowCount` equals `limit` + */ + private[sql] class ArrowBatchIterator( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + maxEstimatedBatchSize: Long, + limit: Long, + timeZoneId: String, + context: TaskContext) + extends Iterator[Array[Byte]] { + + protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + + private val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected val unloader = new VectorUnloader(root) + protected val arrowWriter = ArrowWriter.create(root) + + Option(context).foreach { + _.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + } + } + + override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || { + root.close() + allocator.close() + false + } + + var rowCountInLastBatch: Long = 0 + var rowCount: Long = 0 + + override def next(): Array[Byte] = { + val out = new ByteArrayOutputStream() + val writeChannel = new WriteChannel(Channels.newChannel(out)) + + rowCountInLastBatch = 0 + var estimatedBatchSize = 0L + Utils.tryWithSafeFinally { + + // Always write the first row. + while (rowIter.hasNext && ( + // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. + // If the size in bytes is positive (set properly), always write the first row. + rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || + // If the size in bytes of rows are 0 or negative, unlimit it. + estimatedBatchSize <= 0 || + estimatedBatchSize < maxEstimatedBatchSize || + // If the size of rows are 0 or negative, unlimit it. + maxRecordsPerBatch <= 0 || + rowCountInLastBatch < maxRecordsPerBatch || + rowCount < limit)) { + val row = rowIter.next() + arrowWriter.write(row) + estimatedBatchSize += (row match { + case ur: UnsafeRow => ur.getSizeInBytes + // Trying to estimate the size of the current row + case _: InternalRow => schema.defaultSize + }) + rowCountInLastBatch += 1 + rowCount += 1 + } + arrowWriter.finish() + val batch = unloader.getRecordBatch() + MessageSerializer.serialize(writeChannel, batch) + + // Always write the Ipc options at the end. 
+ ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) + + batch.close() + } { + arrowWriter.reset() + } + + out.toByteArray + } + } + // for testing def fromBatchIterator( arrowBatchIter: Iterator[Array[Byte]], From b72bc6fb2da63bd91b12ab7f237a848b37b5b1ce Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Thu, 6 Apr 2023 18:58:09 +0800 Subject: [PATCH 20/28] add offset support to adapt Spark-3.4.x --- .../spark/sql/kyuubi/SparkDatasetHelper.scala | 18 +++++++++++++++-- .../SparkArrowbasedOperationSuite.scala | 20 +++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 5e84c6ae23d..34be1d36174 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -42,7 +42,8 @@ object SparkDatasetHelper { def executeArrowBatchCollect: SparkPlan => Array[Array[Byte]] = { case adaptiveSparkPlan: AdaptiveSparkPlanExec => executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan)) - case collectLimit: CollectLimitExec => + // TODO: avoid extra shuffle if `offset` > 0 + case collectLimit: CollectLimitExec if offset(collectLimit) <= 0 => doCollectLimit(collectLimit) case plan: SparkPlan => toArrowBatchRdd(plan).collect() @@ -193,11 +194,24 @@ object SparkDatasetHelper { val result = fun(plan) val finalPlanUpdate = DynMethods.builder("finalPlanUpdate") .hiddenImpl(adaptiveSparkPlanExec.getClass) - .build(adaptiveSparkPlanExec) + .build() finalPlanUpdate.invoke[Unit](adaptiveSparkPlanExec) result } + /** + * offset support was add since Spark-3.4(set SPARK-28330), to ensure backward compatibility with + * earlier versions of Spark, this function uses reflective calls to the "offset". + */ + private def offset(collectLimitExec: CollectLimitExec): Int = { + val offset = DynMethods.builder("offset") + .impl(collectLimitExec.getClass) + .orNoop() + .build() + Option(offset.invoke[Int](collectLimitExec)) + .getOrElse(0) + } + /** * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based * operation, so that we can track the arrow-based queries on the UI tab. 
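A note on the reflective `offset()` helper added above, for readers who have not seen the pattern: SPARK-28330 only introduces `CollectLimitExec.offset` in Spark 3.4, so the helper probes for the member at runtime and treats its absence as `offset = 0`. The sketch below shows the same idea with plain Java reflection; it is illustrative only, not the DynMethods-based code in the patch, and `planNode` is a hypothetical stand-in for any physical plan object.

    // Illustrative sketch: read an optional `offset` member via reflection,
    // falling back to 0 when the method does not exist (Spark < 3.4).
    def offsetOf(planNode: AnyRef): Int =
      try planNode.getClass.getMethod("offset").invoke(planNode).asInstanceOf[Int]
      catch { case _: ReflectiveOperationException => 0 }
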
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 08ba7c8de19..ef1a26e8a0e 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -204,6 +204,26 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp } } + test("result offset support") { + assume(SPARK_ENGINE_RUNTIME_VERSION > "3.3") + var numStages = 0 + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numStages = jobStart.stageInfos.length + } + } + withJdbcStatement() { statement => + withSparkListener(listener) { + withPartitionedTable("t_3") { + statement.executeQuery("select * from t_3 limit 10 offset 10") + } + KyuubiSparkContextHelper.waitListenerBus(spark) + } + } + // the extra shuffle be introduced if the `offset` > 0 + assert(numStages == 2) + } + test("arrow serialization should not introduce extra shuffle for outermost limit") { var numStages = 0 val listener = new SparkListener { From 22cc70fbae969d8a1d1d5afaf7d3a4d7301cc9e0 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 01:57:40 +0800 Subject: [PATCH 21/28] add ut --- externals/kyuubi-spark-sql-engine/pom.xml | 7 + .../spark/sql/kyuubi/SparkDatasetHelper.scala | 10 +- .../SparkArrowbasedOperationSuite.scala | 205 +++++++++++++----- 3 files changed, 171 insertions(+), 51 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml index 5b227cb5e29..8c984e4cab4 100644 --- a/externals/kyuubi-spark-sql-engine/pom.xml +++ b/externals/kyuubi-spark-sql-engine/pom.xml @@ -65,6 +65,13 @@ provided + + org.apache.spark + spark-sql_${scala.binary.version} + test-jar + test + + org.apache.spark spark-repl_${scala.binary.version} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 34be1d36174..7409e781d82 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kyuubi import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} @@ -33,7 +34,7 @@ import org.apache.kyuubi.engine.spark.KyuubiSparkUtil import org.apache.kyuubi.engine.spark.schema.RowSet import org.apache.kyuubi.reflection.DynMethods -object SparkDatasetHelper { +object SparkDatasetHelper extends Logging { def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) { executeArrowBatchCollect(df.queryExecution.executedPlan) @@ -43,8 +44,13 @@ object SparkDatasetHelper { case adaptiveSparkPlan: AdaptiveSparkPlanExec => executeArrowBatchCollect(finalPhysicalPlan(adaptiveSparkPlan)) // TODO: avoid extra shuffle if `offset` > 0 - case 
collectLimit: CollectLimitExec if offset(collectLimit) <= 0 => + case collectLimit: CollectLimitExec if offset(collectLimit) > 0 => + logWarning("unsupported offset > 0, an extra shuffle will be introduced.") + toArrowBatchRdd(collectLimit).collect() + case collectLimit: CollectLimitExec if collectLimit.limit >= 0 => doCollectLimit(collectLimit) + case collectLimit: CollectLimitExec if collectLimit.limit < 0 => + executeArrowBatchCollect(collectLimit.child) case plan: SparkPlan => toArrowBatchRdd(plan).collect() } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index ef1a26e8a0e..fcc432c3e4d 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -21,15 +21,19 @@ import java.sql.Statement import org.apache.spark.KyuubiSparkContextHelper import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution} +import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters +import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions.col +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kyuubi.SparkDatasetHelper import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.kyuubi.KyuubiException import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine} import org.apache.kyuubi.engine.spark.session.SparkSessionImpl @@ -150,57 +154,111 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp 10, // equal to one partition 13, // between one and two partitions, run two jobs 20, // equal to two partitions - 29) // between two and three partitions - - // aqe - // outermost AdaptiveSparkPlanExec - spark.range(1000) - .repartitionByRange(100, col("id")) - .createOrReplaceTempView("t_1") - spark.sql("select * from t_1") - .foreachPartition { p: Iterator[Row] => - assert(p.length == 10) - () - } - returnSize.foreach { size => - val df = spark.sql(s"select * from t_1 limit $size") - val headPlan = df.queryExecution.executedPlan.collectLeaves().head - if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") { - assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) - val finalPhysicalPlan = - SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) - assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) - } + 29, // between two and three partitions + 1000, // all partitions + 1001) // more than total row count +// -1) // all + + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.EliminateLimits", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> + 
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits") { + // aqe + // outermost AdaptiveSparkPlanExec + spark.range(1000) + .repartitionByRange(100, col("id")) + .createOrReplaceTempView("t_1") + spark.sql("select * from t_1") + .foreachPartition { p: Iterator[Row] => + assert(p.length == 10) + () + } + returnSize.foreach { size => + val df = spark.sql(s"select * from t_1 limit $size") + val headPlan = df.queryExecution.executedPlan.collectLeaves().head + if (SPARK_ENGINE_RUNTIME_VERSION >= "3.2") { + assert(headPlan.isInstanceOf[AdaptiveSparkPlanExec]) + val finalPhysicalPlan = + SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) + assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) + } - val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) + val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution + .executedPlan) + + val rows = KyuubiArrowConverters.fromBatchIterator( + arrowBinary.iterator, + df.schema, + "", + KyuubiSparkContextHelper.dummyTaskContext()) + if (size > 1000) { + assert(rows.size == 1000) + } else { + assert(rows.size == size) + } + } - val rows = KyuubiArrowConverters.fromBatchIterator( - arrowBinary.iterator, - df.schema, - "", - KyuubiSparkContextHelper.dummyTaskContext()) - assert(rows.size == size) + // outermost CollectLimitExec + spark.range(0, 1000, 1, numPartitions = 100) + .createOrReplaceTempView("t_2") + spark.sql("select * from t_2") + .foreachPartition { p: Iterator[Row] => + assert(p.length == 10) + () + } + returnSize.foreach { size => + val df = spark.sql(s"select * from t_2 limit $size") + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[CollectLimitExec]) + val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution + .executedPlan) + val rows = KyuubiArrowConverters.fromBatchIterator( + arrowBinary.iterator, + df.schema, + "", + KyuubiSparkContextHelper.dummyTaskContext()) + if (size > 1000) { + assert(rows.size == 1000) + } else { + assert(rows.size == size) + } + } } + } - // outermost CollectLimitExec - spark.range(0, 1000, 1, numPartitions = 100) - .createOrReplaceTempView("t_2") - spark.sql("select * from t_2") - .foreachPartition { p: Iterator[Row] => - assert(p.length == 10) - () - } - returnSize.foreach { size => - val df = spark.sql(s"select * from t_2 limit $size") - val plan = df.queryExecution.executedPlan - assert(plan.isInstanceOf[CollectLimitExec]) - val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution.executedPlan) - val rows = KyuubiArrowConverters.fromBatchIterator( - arrowBinary.iterator, - df.schema, - "", - KyuubiSparkContextHelper.dummyTaskContext()) - assert(rows.size == size) + test("aqe should work properly") { + + val s = spark + import s.implicits._ + + spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))).toDF() + .createOrReplaceTempView("testData") + spark.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil, + 2).toDF() + .createOrReplaceTempView("testData2") + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM( + | SELECT * FROM testData join testData2 ON key = a where value = '1' + |) LIMIT 1 + |""".stripMargin) + val smj = 
plan.collect { case smj: SortMergeJoinExec => smj } + val bhj = adaptivePlan.collect { case bhj: BroadcastHashJoinExec => bhj } + assert(smj.size == 1) + assert(bhj.size == 1) } } @@ -315,4 +373,53 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp .map(_.asInstanceOf[SparkSessionImpl].spark) .foreach(op(_)) } + + private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { + val dfAdaptive = spark.sql(query) + val planBefore = dfAdaptive.queryExecution.executedPlan + val result = dfAdaptive.collect() + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = spark.sql(query) + QueryTest.checkAnswer(df, df.collect().toSeq) + } + val planAfter = dfAdaptive.queryExecution.executedPlan + val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val exchanges = adaptivePlan.collect { + case e: Exchange => e + } + assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") + (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) + } + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL + * configurations. + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val conf = SQLConf.get + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (conf.contains(key)) { + Some(conf.getConfString(key)) + } else { + None + } + } + (keys, values).zipped.foreach { (k, v) => + if (SQLConf.isStaticConfigKey(k)) { + throw new KyuubiException(s"Cannot modify the value of a static config: $k") + } + conf.setConfString(k, v) + } + try f + finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConfString(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } } + +case class TestData(key: Int, value: String) +case class TestData2(a: Int, b: Int) From 8280783c31ef36107ab44475724152dc82a7d91e Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 10:09:14 +0800 Subject: [PATCH 22/28] add `isStaticConfigKey` to adapt Spark-3.1.x --- .../SparkArrowbasedOperationSuite.scala | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index fcc432c3e4d..2257d6877b1 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -38,6 +38,7 @@ import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine} import org.apache.kyuubi.engine.spark.session.SparkSessionImpl import org.apache.kyuubi.operation.SparkDataTypeTests +import org.apache.kyuubi.reflection.DynFields class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests { @@ -406,7 +407,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp } } (keys, values).zipped.foreach { (k, v) => - if (SQLConf.isStaticConfigKey(k)) { + if (isStaticConfigKey(k)) { throw new KyuubiException(s"Cannot modify the value of a static config: $k") } conf.setConfString(k, v) @@ -419,6 +420,21 @@ class SparkArrowbasedOperationSuite extends 
WithSparkSQLEngine with SparkDataTyp } } } + + /** + * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to + * adapt Spark-3.1.x + * + * TODO: Once we drop support for Spark 3.1.x, we can directly call + * [[SQLConf.isStaticConfigKey()]]. + */ + private def isStaticConfigKey(key: String): Boolean = { + val staticConfKeys = DynFields.builder() + .hiddenImpl(SQLConf.getClass, "staticConfKeys") + .build[java.util.Set[String]](SQLConf) + .get() + staticConfKeys.contains(key) + } } case class TestData(key: Int, value: String) From 6d596fcce46ff4db0bbb4599acb7b940a62daaf9 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 10:46:51 +0800 Subject: [PATCH 23/28] address comment --- .../sql/execution/arrow/KyuubiArrowConverters.scala | 2 +- .../apache/spark/sql/kyuubi/SparkDatasetHelper.scala | 11 ++++++----- .../operation/SparkArrowbasedOperationSuite.scala | 3 ++- pom.xml | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index 6e1ac5c2aaf..75e6431b1bc 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -55,7 +55,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { start: Int, length: Int): Array[Byte] = { val in = new ByteArrayInputStream(bytes) - val out = new ByteArrayOutputStream() + val out = new ByteArrayOutputStream(bytes.length) try { val recordBatch = MessageSerializer.deserializeRecordBatch( diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 7409e781d82..1c8d32c4850 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -210,11 +210,12 @@ object SparkDatasetHelper extends Logging { * earlier versions of Spark, this function uses reflective calls to the "offset". 
*/ private def offset(collectLimitExec: CollectLimitExec): Int = { - val offset = DynMethods.builder("offset") - .impl(collectLimitExec.getClass) - .orNoop() - .build() - Option(offset.invoke[Int](collectLimitExec)) + Option( + DynMethods.builder("offset") + .impl(collectLimitExec.getClass) + .orNoop() + .build() + .invoke[Int](collectLimitExec)) .getOrElse(0) } diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 2257d6877b1..a67aeadc808 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -18,6 +18,7 @@ package org.apache.kyuubi.engine.spark.operation import java.sql.Statement +import java.util.{Set => JSet} import org.apache.spark.KyuubiSparkContextHelper import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} @@ -431,7 +432,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp private def isStaticConfigKey(key: String): Boolean = { val staticConfKeys = DynFields.builder() .hiddenImpl(SQLConf.getClass, "staticConfKeys") - .build[java.util.Set[String]](SQLConf) + .build[JSet[String]](SQLConf) .get() staticConfKeys.contains(key) } diff --git a/pom.xml b/pom.xml index b2b0341e2e9..e77e6d55d24 100644 --- a/pom.xml +++ b/pom.xml @@ -538,8 +538,8 @@ hadoop-client From 6064ab961a6e7da4576b1cf206a9e27b6e0840e7 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 15:26:01 +0800 Subject: [PATCH 24/28] limit = 0 test case --- .../SparkArrowbasedOperationSuite.scala | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index a67aeadc808..4ea3429b240 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -152,6 +152,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp test("SparkDatasetHelper.executeArrowBatchCollect should return expect row count") { val returnSize = Seq( + 0, // spark optimizer guaranty the `limit != 0`, it's just for the sanity check 7, // less than one partition 10, // equal to one partition 13, // between one and two partitions, run two jobs @@ -159,13 +160,23 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp 29, // between two and three partitions 1000, // all partitions 1001) // more than total row count -// -1) // all + def runAndCheck(sparkPlan: SparkPlan, expectSize: Int): Unit = { + val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(sparkPlan) + val rows = KyuubiArrowConverters.fromBatchIterator( + arrowBinary.iterator, + sparkPlan.schema, + "", + KyuubiSparkContextHelper.dummyTaskContext()) + assert(rows.size == expectSize) + } + + val excludedRules = Seq( + 
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits", + "org.apache.spark.sql.execution.adaptive.AQEPropagateEmptyRelation").mkString(",") withSQLConf( - SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> - "org.apache.spark.sql.catalyst.optimizer.EliminateLimits", - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> - "org.apache.spark.sql.catalyst.optimizer.EliminateLimits") { + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules, + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) { // aqe // outermost AdaptiveSparkPlanExec spark.range(1000) @@ -185,19 +196,10 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp SparkDatasetHelper.finalPhysicalPlan(headPlan.asInstanceOf[AdaptiveSparkPlanExec]) assert(finalPhysicalPlan.isInstanceOf[CollectLimitExec]) } - - val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution - .executedPlan) - - val rows = KyuubiArrowConverters.fromBatchIterator( - arrowBinary.iterator, - df.schema, - "", - KyuubiSparkContextHelper.dummyTaskContext()) if (size > 1000) { - assert(rows.size == 1000) + runAndCheck(df.queryExecution.executedPlan, 1000) } else { - assert(rows.size == size) + runAndCheck(df.queryExecution.executedPlan, size) } } @@ -221,9 +223,9 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp "", KyuubiSparkContextHelper.dummyTaskContext()) if (size > 1000) { - assert(rows.size == 1000) + runAndCheck(df.queryExecution.executedPlan, 1000) } else { - assert(rows.size == size) + runAndCheck(df.queryExecution.executedPlan, size) } } } From 3700839109aa3be791fdea947ce43463dedad8ab Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 15:48:08 +0800 Subject: [PATCH 25/28] SparkArrowbasedOperationSuite adapt Spark-3.1.x --- .../operation/SparkArrowbasedOperationSuite.scala | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index 4ea3429b240..df3d7518bbd 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -214,14 +214,10 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp returnSize.foreach { size => val df = spark.sql(s"select * from t_2 limit $size") val plan = df.queryExecution.executedPlan - assert(plan.isInstanceOf[CollectLimitExec]) - val arrowBinary = SparkDatasetHelper.executeArrowBatchCollect(df.queryExecution - .executedPlan) - val rows = KyuubiArrowConverters.fromBatchIterator( - arrowBinary.iterator, - df.schema, - "", - KyuubiSparkContextHelper.dummyTaskContext()) + // rule PropagateEmptyRelation can't be excluded in the Spark-3.1.x, skipped. 
+ if (!(SPARK_ENGINE_RUNTIME_VERSION < "3.2" && size == 0)) { + assert(plan.isInstanceOf[CollectLimitExec]) + } if (size > 1000) { runAndCheck(df.queryExecution.executedPlan, 1000) } else { From facc13f788bd959005628474ae206be5781f6765 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 16:59:12 +0800 Subject: [PATCH 26/28] exclude rule OptimizeLimitZero --- .../spark/operation/SparkArrowbasedOperationSuite.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala index df3d7518bbd..2ef29b398a3 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala @@ -173,6 +173,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp val excludedRules = Seq( "org.apache.spark.sql.catalyst.optimizer.EliminateLimits", + "org.apache.spark.sql.catalyst.optimizer.OptimizeLimitZero", "org.apache.spark.sql.execution.adaptive.AQEPropagateEmptyRelation").mkString(",") withSQLConf( SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules, @@ -214,10 +215,7 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp returnSize.foreach { size => val df = spark.sql(s"select * from t_2 limit $size") val plan = df.queryExecution.executedPlan - // rule PropagateEmptyRelation can't be excluded in the Spark-3.1.x, skipped. - if (!(SPARK_ENGINE_RUNTIME_VERSION < "3.2" && size == 0)) { - assert(plan.isInstanceOf[CollectLimitExec]) - } + assert(plan.isInstanceOf[CollectLimitExec]) if (size > 1000) { runAndCheck(df.queryExecution.executedPlan, 1000) } else { From 130bcb141e58bc7618786c39b9a383ba70a410db Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Fri, 7 Apr 2023 21:02:32 +0800 Subject: [PATCH 27/28] finally close --- .../execution/arrow/KyuubiArrowConverters.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index 75e6431b1bc..9209a31562b 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -39,11 +39,6 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { type Batch = (Array[Byte], Long) - private val rootAllocator = ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", - 0, - Long.MaxValue) - /** * this method is to slice the input Arrow record batch byte array `bytes`, starting from `start` * and taking `length` number of elements. 
@@ -57,13 +52,16 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { val in = new ByteArrayInputStream(bytes) val out = new ByteArrayOutputStream(bytes.length) + val rootAllocator = ArrowUtils.rootAllocator.newChildAllocator( + s"slice", + 0, + Long.MaxValue) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) try { val recordBatch = MessageSerializer.deserializeRecordBatch( new ReadChannel(Channels.newChannel(in)), rootAllocator) - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - - val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) val vectorLoader = new VectorLoader(root) vectorLoader.load(recordBatch) recordBatch.close() @@ -77,6 +75,8 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { } finally { in.close() out.close() + root.close() + rootAllocator.close() } } From 82c912ed6d67fd98f70683f6f33363caa0fb4d8c Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Sat, 8 Apr 2023 13:01:19 +0800 Subject: [PATCH 28/28] close vector --- .../arrow/KyuubiArrowConverters.scala | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index 9209a31562b..dd6163ec97c 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.channels.Channels +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.arrow.vector._ @@ -52,21 +53,25 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { val in = new ByteArrayInputStream(bytes) val out = new ByteArrayOutputStream(bytes.length) - val rootAllocator = ArrowUtils.rootAllocator.newChildAllocator( - s"slice", + var vectorSchemaRoot: VectorSchemaRoot = null + var slicedVectorSchemaRoot: VectorSchemaRoot = null + + val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator( + "slice", 0, Long.MaxValue) val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val root = VectorSchemaRoot.create(arrowSchema, rootAllocator) + vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator) try { val recordBatch = MessageSerializer.deserializeRecordBatch( new ReadChannel(Channels.newChannel(in)), - rootAllocator) - val vectorLoader = new VectorLoader(root) + sliceAllocator) + val vectorLoader = new VectorLoader(vectorSchemaRoot) vectorLoader.load(recordBatch) recordBatch.close() + slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length) - val unloader = new VectorUnloader(root.slice(start, length)) + val unloader = new VectorUnloader(slicedVectorSchemaRoot) val writeChannel = new WriteChannel(Channels.newChannel(out)) val batch = unloader.getRecordBatch() MessageSerializer.serialize(writeChannel, batch) @@ -75,8 +80,15 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { } finally { in.close() out.close() - root.close() - rootAllocator.close() + if (vectorSchemaRoot != null) { + vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close()) + 
vectorSchemaRoot.close() + } + if (slicedVectorSchemaRoot != null) { + slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close()) + slicedVectorSchemaRoot.close() + } + sliceAllocator.close() } }
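
A closing note on the allocator hygiene that the last two patches ("finally close" and "close vector") enforce: Arrow buffers are reference counted, so every vector and VectorSchemaRoot must be released before its child allocator is closed, otherwise the allocator reports leaked memory on close(). The stand-alone sketch below shows the same try/finally discipline on a throwaway IntVector; it is an illustration of the pattern under that assumption, not code from the patch.

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.IntVector

    val exampleAllocator = new RootAllocator(Long.MaxValue)
    val childAllocator = exampleAllocator.newChildAllocator("slice-example", 0, Long.MaxValue)
    val vector = new IntVector("id", childAllocator)
    try {
      vector.allocateNew(3)                          // allocate buffers for three values
      (0 until 3).foreach(i => vector.setSafe(i, i)) // write 0, 1, 2
      vector.setValueCount(3)
    } finally {
      vector.close()           // release the vector's buffers first
      childAllocator.close()   // then the child allocator; fails if anything is still allocated
      exampleAllocator.close()
    }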