Commit 1a65125

cfmcgrady authored and ulysses-you committed
[KYUUBI #4662] [ARROW] Arrow serialization should not introduce extra shuffle for outermost limit
### _Why are the changes needed?_

The fundamental concept is to execute the job in a way similar to how `CollectLimitExec.executeCollect()` operates.

```sql
select * from parquet.`parquet/tpcds/sf1000/catalog_sales` limit 20;
```

Before this PR:

![截屏2023-04-04 下午3 20 34](https://user-images.githubusercontent.com/8537877/229717946-87c480c6-9550-4d00-bc96-14d59d7ce9f7.png)

![截屏2023-04-04 下午3 20 54](https://user-images.githubusercontent.com/8537877/229717973-bf6da5af-74e7-422a-b9fa-8b7bebd43320.png)

After this PR:

![截屏2023-04-04 下午3 17 05](https://user-images.githubusercontent.com/8537877/229718016-6218d019-b223-4deb-b596-6f0431d33d2a.png)

![截屏2023-04-04 下午3 17 16](https://user-images.githubusercontent.com/8537877/229718046-ea07cd1f-5ffc-42ba-87d5-08085feb4b16.png)

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4662 from cfmcgrady/arrow-collect-limit-exec-2.

Closes #4662

82c912e [Fu Chen] close vector
130bcb1 [Fu Chen] finally close
facc13f [Fu Chen] exclude rule OptimizeLimitZero
3700839 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
6064ab9 [Fu Chen] limit = 0 test case
6d596fc [Fu Chen] address comment
8280783 [Fu Chen] add `isStaticConfigKey` to adapt Spark-3.1.x
22cc70f [Fu Chen] add ut
b72bc6f [Fu Chen] add offset support to adapt Spark-3.4.x
9ffb44f [Fu Chen] make toBatchIterator private
c83cf3f [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
573a262 [Fu Chen] fix
4cef204 [Fu Chen] SparkArrowbasedOperationSuite adapt Spark-3.1.x
d70aee3 [Fu Chen] SparkPlan.session -> SparkSession.active to adapt Spark-3.1.x
e3bf84c [Fu Chen] refactor
81886f0 [Fu Chen] address comment
2286afc [Fu Chen] reflective calla AdaptiveSparkPlanExec.finalPhysicalPlan
03d0747 [Fu Chen] address comment
25e4f05 [Fu Chen] add docs
885cf2c [Fu Chen] infer row size by schema.defaultSize
4e7ca54 [Fu Chen] unnecessarily changes
ee5a756 [Fu Chen] revert unnecessarily changes
6c5b1eb [Fu Chen] add ut
4212a89 [Fu Chen] refactor and add ut
ed8c692 [Fu Chen] refactor
0088671 [Fu Chen] refine
8593d85 [Fu Chen] driver slice last batch
a584943 [Fu Chen] arrow take

Authored-by: Fu Chen <[email protected]>
Signed-off-by: ulyssesyou <[email protected]>
1 parent 5faebb1 commit 1a65125
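
As an illustrative aside (not part of the commit message): a minimal spark-shell sketch of the behavior this patch targets. The table name is a placeholder, and the plan shapes are what Spark 3.x typically produces for an outermost `LIMIT`.

```scala
// Hedged sketch: why the old Arrow path shuffled and the new one does not.
val df = spark.sql("select * from catalog_sales limit 20")

// An outermost LIMIT typically plans as CollectLimitExec at the top of the physical plan.
println(df.queryExecution.executedPlan.treeString)

// Old Arrow path (roughly): re-apply the limit, then collect the serialized batches via an RDD.
// Materializing the limited Dataset as an RDD plans a GlobalLimit, which typically shuffles
// every partition into a single one before taking the 20 rows.
val viaRdd = df.limit(20).rdd.collect()

// CollectLimitExec.executeCollect() instead launches incremental jobs from the driver,
// scanning a few partitions at a time until 20 rows are buffered - no shuffle stage.
val viaExecuteCollect = df.queryExecution.executedPlan.executeCollect()
```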

File tree

7 files changed: +753 -33 lines changed

externals/kyuubi-spark-sql-engine/pom.xml

Lines changed: 7 additions & 0 deletions
@@ -65,6 +65,13 @@
             <scope>provided</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-repl_${scala.binary.version}</artifactId>

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala

Lines changed: 5 additions & 27 deletions
@@ -21,10 +21,8 @@ import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper._
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
@@ -187,42 +185,22 @@ class ArrowBasedExecuteStatement(
     handle) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
-    }
+    toArrowBatchLocalIterator(convertComplexType(resultDF))
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-      rdd.collect()
-    }
-  }
-
-  /**
-   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
-   * operation, so that we can track the arrow-based queries on the UI tab.
-   */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
-    }
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
  }
 
   override protected def isArrowBasedOperation: Boolean = true
 
   override val resultFormat = "arrow"
 
   private def convertComplexType(df: DataFrame): DataFrame = {
-    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+    convertTopLevelComplexTypeToHiveString(df, timestampAsString)
   }
-
 }
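
The `executeCollect` used above comes from `SparkDatasetHelper`, whose diff is one of the 7 changed files but is not shown in this excerpt. Below is a minimal, hedged sketch of how such a helper could route an outermost limit to the new Arrow take path. Only `SparkDatasetHelper.toArrowBatchRdd` and `KyuubiArrowConverters.takeAsArrowBatches` come from this commit; the sketch object, its method shape, and the batch-size values are assumptions.

```scala
// Hedged sketch only - the real SparkDatasetHelper.executeCollect in this commit may differ.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.kyuubi.SparkDatasetHelper

object ArrowCollectSketch {
  // Collect a DataFrame as serialized Arrow batches without the single-partition
  // shuffle that re-applying limit + RDD collect would introduce.
  def executeCollect(df: DataFrame, timeZoneId: String): Array[Array[Byte]] =
    df.queryExecution.executedPlan match {
      case limit: CollectLimitExec =>
        // Incrementally scan partitions from the driver, CollectLimitExec.executeCollect() style.
        KyuubiArrowConverters
          .takeAsArrowBatches(limit, 10000L, 4L * 1024 * 1024, timeZoneId)
          .map(_._1) // drop the per-batch row count
      case _ =>
        // Fallback: serialize the whole result through the regular Arrow batch RDD path.
        SparkDatasetHelper.toArrowBatchRdd(df).collect()
    }
}
```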
externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala

Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.arrow

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.vector._
import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils

object KyuubiArrowConverters extends SQLConfHelper with Logging {

  type Batch = (Array[Byte], Long)

  /**
   * Slices the input Arrow record batch byte array `bytes`, starting from `start` and
   * taking `length` elements.
   */
  def slice(
      schema: StructType,
      timeZoneId: String,
      bytes: Array[Byte],
      start: Int,
      length: Int): Array[Byte] = {
    val in = new ByteArrayInputStream(bytes)
    val out = new ByteArrayOutputStream(bytes.length)

    var vectorSchemaRoot: VectorSchemaRoot = null
    var slicedVectorSchemaRoot: VectorSchemaRoot = null

    val sliceAllocator = ArrowUtils.rootAllocator.newChildAllocator(
      "slice",
      0,
      Long.MaxValue)
    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
    vectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, sliceAllocator)
    try {
      val recordBatch = MessageSerializer.deserializeRecordBatch(
        new ReadChannel(Channels.newChannel(in)),
        sliceAllocator)
      val vectorLoader = new VectorLoader(vectorSchemaRoot)
      vectorLoader.load(recordBatch)
      recordBatch.close()
      slicedVectorSchemaRoot = vectorSchemaRoot.slice(start, length)

      val unloader = new VectorUnloader(slicedVectorSchemaRoot)
      val writeChannel = new WriteChannel(Channels.newChannel(out))
      val batch = unloader.getRecordBatch()
      MessageSerializer.serialize(writeChannel, batch)
      batch.close()
      out.toByteArray()
    } finally {
      in.close()
      out.close()
      if (vectorSchemaRoot != null) {
        vectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
        vectorSchemaRoot.close()
      }
      if (slicedVectorSchemaRoot != null) {
        slicedVectorSchemaRoot.getFieldVectors.asScala.foreach(_.close())
        slicedVectorSchemaRoot.close()
      }
      sliceAllocator.close()
    }
  }

  /**
   * Forked from `org.apache.spark.sql.execution.SparkPlan#executeTake()`. The algorithm can be
   * summarized in the following steps:
   * 1. If the limit specified in the CollectLimitExec object is 0, the function returns an empty
   *    array of batches.
   * 2. Otherwise, execute the child query plan of the CollectLimitExec object to obtain an RDD of
   *    data to collect.
   * 3. Use an iterative approach to collect data in batches until the specified limit is reached.
   *    In each iteration, it selects a subset of the partitions of the RDD to scan and tries to
   *    collect data from them.
   * 4. For each partition subset, we use the runJob method of the Spark context to execute a
   *    closure that scans the partition data and converts it to Arrow batches.
   * 5. Check if the collected data reaches the specified limit. If not, it selects another subset
   *    of partitions to scan and repeats the process until the limit is reached or all partitions
   *    have been scanned.
   * 6. Return an array of all the collected Arrow batches.
   *
   * Note that:
   * 1. The total row count of the returned Arrow batches may be >= `limit`, if the input
   *    DataFrame has more than `limit` rows.
   * 2. We don't implement the `takeFromEnd` logic.
   *
   * @return
   */
  def takeAsArrowBatches(
      collectLimitExec: CollectLimitExec,
      maxRecordsPerBatch: Long,
      maxEstimatedBatchSize: Long,
      timeZoneId: String): Array[Batch] = {
    val n = collectLimitExec.limit
    val schema = collectLimitExec.schema
    if (n == 0) {
      return new Array[Batch](0)
    } else {
      val limitScaleUpFactor = Math.max(conf.limitScaleUpFactor, 2)
      // TODO: refactor and reuse the code from RDD's take()
      val childRDD = collectLimitExec.child.execute()
      val buf = new ArrayBuffer[Batch]
      var bufferedRowSize = 0L
      val totalParts = childRDD.partitions.length
      var partsScanned = 0
      while (bufferedRowSize < n && partsScanned < totalParts) {
        // The number of partitions to try in this iteration. It is ok for this number to be
        // greater than totalParts because we actually cap it at totalParts in runJob.
        var numPartsToTry = limitInitialNumPartitions
        if (partsScanned > 0) {
          // If we didn't find any rows after the previous iteration, multiply by
          // limitScaleUpFactor and retry. Otherwise, interpolate the number of partitions we need
          // to try, but overestimate it by 50%. We also cap the estimation in the end.
          if (buf.isEmpty) {
            numPartsToTry = partsScanned * limitScaleUpFactor
          } else {
            val left = n - bufferedRowSize
            // As left > 0, numPartsToTry is always >= 1
            numPartsToTry = Math.ceil(1.5 * left * partsScanned / bufferedRowSize).toInt
            numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
          }
        }

        val partsToScan =
          partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)

        // TODO: SparkPlan.session was introduced in SPARK-35798; replace SparkSession.active
        // with SparkPlan.session once we drop Spark-3.1.x support.
        val sc = SparkSession.active.sparkContext
        val res = sc.runJob(
          childRDD,
          (it: Iterator[InternalRow]) => {
            val batches = toBatchIterator(
              it,
              schema,
              maxRecordsPerBatch,
              maxEstimatedBatchSize,
              n,
              timeZoneId)
            batches.map(b => b -> batches.rowCountInLastBatch).toArray
          },
          partsToScan)

        var i = 0
        while (bufferedRowSize < n && i < res.length) {
          var j = 0
          val batches = res(i)
          while (j < batches.length && n > bufferedRowSize) {
            val batch = batches(j)
            val (_, batchSize) = batch
            buf += batch
            bufferedRowSize += batchSize
            j += 1
          }
          i += 1
        }
        partsScanned += partsToScan.size
      }

      buf.toArray
    }
  }

  /**
   * Spark introduced the config `spark.sql.limit.initialNumPartitions` in 3.4.0; see SPARK-40211.
   */
  private def limitInitialNumPartitions: Int = {
    conf.getConfString("spark.sql.limit.initialNumPartitions", "1")
      .toInt
  }

  /**
   * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
   * each output arrow batch carries its own row count.
   */
  private def toBatchIterator(
      rowIter: Iterator[InternalRow],
      schema: StructType,
      maxRecordsPerBatch: Long,
      maxEstimatedBatchSize: Long,
      limit: Long,
      timeZoneId: String): ArrowBatchIterator = {
    new ArrowBatchIterator(
      rowIter,
      schema,
      maxRecordsPerBatch,
      maxEstimatedBatchSize,
      limit,
      timeZoneId,
      TaskContext.get)
  }

  /**
   * This class ArrowBatchIterator is derived from
   * [[org.apache.spark.sql.execution.arrow.ArrowConverters.ArrowBatchWithSchemaIterator]],
   * with two key differences:
   *   1. there is no requirement to write the schema at the batch header
   *   2. iteration halts when `rowCount` equals `limit`
   */
  private[sql] class ArrowBatchIterator(
      rowIter: Iterator[InternalRow],
      schema: StructType,
      maxRecordsPerBatch: Long,
      maxEstimatedBatchSize: Long,
      limit: Long,
      timeZoneId: String,
      context: TaskContext)
    extends Iterator[Array[Byte]] {

    protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
    private val allocator =
      ArrowUtils.rootAllocator.newChildAllocator(
        s"to${this.getClass.getSimpleName}",
        0,
        Long.MaxValue)

    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
    protected val unloader = new VectorUnloader(root)
    protected val arrowWriter = ArrowWriter.create(root)

    Option(context).foreach {
      _.addTaskCompletionListener[Unit] { _ =>
        root.close()
        allocator.close()
      }
    }

    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
      root.close()
      allocator.close()
      false
    }

    var rowCountInLastBatch: Long = 0
    var rowCount: Long = 0

    override def next(): Array[Byte] = {
      val out = new ByteArrayOutputStream()
      val writeChannel = new WriteChannel(Channels.newChannel(out))

      rowCountInLastBatch = 0
      var estimatedBatchSize = 0L
      Utils.tryWithSafeFinally {

        // Always write the first row.
        while (rowIter.hasNext && (
            // For maxBatchSize and maxRecordsPerBatch, respect whichever is smaller.
            // If the size in bytes is positive (set properly), always write the first row.
            rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 ||
              // If the size in bytes of rows is 0 or negative, do not limit it.
              estimatedBatchSize <= 0 ||
              estimatedBatchSize < maxEstimatedBatchSize ||
              // If the number of rows is 0 or negative, do not limit it.
              maxRecordsPerBatch <= 0 ||
              rowCountInLastBatch < maxRecordsPerBatch ||
              rowCount < limit)) {
          val row = rowIter.next()
          arrowWriter.write(row)
          estimatedBatchSize += (row match {
            case ur: UnsafeRow => ur.getSizeInBytes
            // Trying to estimate the size of the current row
            case _: InternalRow => schema.defaultSize
          })
          rowCountInLastBatch += 1
          rowCount += 1
        }
        arrowWriter.finish()
        val batch = unloader.getRecordBatch()
        MessageSerializer.serialize(writeChannel, batch)

        // Always write the IPC options at the end.
        ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)

        batch.close()
      } {
        arrowWriter.reset()
      }

      out.toByteArray
    }
  }

  // for testing
  def fromBatchIterator(
      arrowBatchIter: Iterator[Array[Byte]],
      schema: StructType,
      timeZoneId: String,
      context: TaskContext): Iterator[InternalRow] = {
    ArrowConverters.fromBatchIterator(arrowBatchIter, schema, timeZoneId, context)
  }
}
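
A hedged usage sketch of the two public entry points above, in spark-shell style. The table name, batch-size values, and trimming loop are illustrative assumptions; the driver-side trim of the final batch corresponds to the "driver slice last batch" commit in the squashed history.

```scala
// Hedged usage sketch, assuming an active SparkSession and a table `catalog_sales`.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters

val spark = SparkSession.active
val df = spark.sql("select * from catalog_sales limit 20")

df.queryExecution.executedPlan match {
  case plan: CollectLimitExec =>
    val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
    // Each element is (serialized Arrow batch, row count of that batch).
    val batches = KyuubiArrowConverters.takeAsArrowBatches(
      plan,
      maxRecordsPerBatch = 1000L,
      maxEstimatedBatchSize = 4L * 1024 * 1024,
      timeZoneId)

    // The batches may hold slightly more than `limit` rows in total, so trim the
    // last batch down to exactly the number of rows still needed.
    val limit = plan.limit
    var taken = 0L
    val trimmed = batches.map { case (bytes, rows) =>
      val keep = math.min(rows, limit - taken).toInt
      taken += keep
      if (keep.toLong == rows) bytes
      else KyuubiArrowConverters.slice(df.schema, timeZoneId, bytes, 0, keep)
    }
    println(s"collected ${trimmed.length} arrow batches, $taken rows")
  case other =>
    println(s"not an outermost limit plan: ${other.nodeName}")
}
```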
