@@ -22,14 +22,13 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution}
+import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
-import org.apache.kyuubi.{KyuubiSQLException, Logging}
-import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}
 
+import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
 import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationHandle, OperationState}
@@ -189,73 +188,70 @@ class ArrowBasedExecuteStatement(
     handle) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
+    val df = convertComplexType(resultDF)
+    withNewExecutionId(df) {
+      SparkDatasetHelper.toArrowBatchRdd(df).toLocalIterator
     }
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    // collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-    //   rdd.collect()
-    // }
-    val df = convertComplexType(limitedResult)
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      df.queryExecution.executedPlan match {
-        case collectLimit @ CollectLimitExec(limit, _) =>
-          val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
-          val batches = ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId)
-          //   .map(_._1)
-          val result = ArrayBuffer[Array[Byte]]()
-          var i = 0
-          var rest = limit
-          println(s"batch....size...${batches.length}")
-          while (i < batches.length && rest > 0) {
-            val (batch, size) = batches(i)
-            if (size < rest) {
-              result += batch
-              // TODO: toInt
-              rest = rest - size.toInt
-            } else if (size == rest) {
-              result += batch
-              rest = 0
-            } else { // size > rest
-              println(s"size......${size}....rest......${rest}")
-              // result += KyuubiArrowUtils.slice(batch, 0, rest)
-              result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
-              rest = 0
-            }
-            i += 1
-          }
-          result.toArray
-
-        case takeOrderedAndProjectExec @ TakeOrderedAndProjectExec(limit, _, _, _) =>
-          val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
-          ArrowCollectLimitExec.taskOrdered(takeOrderedAndProjectExec, df.schema, 1000, 1024 * 1024, timeZoneId)
-            .map(_._1)
-        case _ =>
-          println("yyyy")
-          SparkDatasetHelper.toArrowBatchRdd(df).collect()
-      }
-    }
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
   }
 
   /**
    * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
    * operation, so that we can track the arrow-based queries on the UI tab.
    */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
+  private def withNewExecutionId[T](df: DataFrame)(body: => T): T = {
     SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
       df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
+      body
+    }
+  }
+
+  def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
+    executeArrowBatchCollect(df).getOrElse {
+      SparkDatasetHelper.toArrowBatchRdd(df).collect()
+    }
+  }
+
+  private def executeArrowBatchCollect(df: DataFrame): Option[Array[Array[Byte]]] = {
+    df.queryExecution.executedPlan match {
+      case collectLimit @ CollectLimitExec(limit, _) =>
+        val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+        val maxRecordsPerBatch = spark.conf.getOption(
+          "spark.sql.execution.arrow.maxRecordsPerBatch").map(_.toInt).getOrElse(10000)
+        // val maxBatchSize =
+        //   (spark.sessionState.conf.getConf(SPARK_CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
+        val maxBatchSize = 1024 * 1024 * 4
+        val batches = ArrowCollectLimitExec.takeAsArrowBatches(
+          collectLimit,
+          df.schema,
+          maxRecordsPerBatch,
+          maxBatchSize,
+          timeZoneId)
+        val result = ArrayBuffer[Array[Byte]]()
+        var i = 0
+        var rest = limit
+        while (i < batches.length && rest > 0) {
+          val (batch, size) = batches(i)
+          if (size <= rest) {
+            result += batch
+            // the returned ArrowRecordBatch has no more than `limit` rows, so the toInt conversion is safe
+            rest -= size.toInt
+          } else { // size > rest
+            result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
+            rest = 0
+          }
+          i += 1
+        }
+        Option(result.toArray)
+      case _ =>
+        None
     }
   }
 
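
For reviewers, here is a self-contained sketch of the limit-trimming loop that `executeArrowBatchCollect` introduces above. The `trimToLimit` helper, the toy batches, and the byte-level `slice` stand-in are illustrative assumptions only; in the patch the `(batch, rowCount)` pairs come from `ArrowCollectLimitExec.takeAsArrowBatches` and the real slicing is done by `KyuubiArrowUtils.sliceV2`.

```scala
// Minimal sketch of the limit-trimming loop (illustrative, not part of the patch).
// `batches` stands in for the (serialized batch, row count) pairs returned by
// takeAsArrowBatches; `slice` stands in for sliceV2, which keeps the first `n` rows.
object LimitTrimSketch {

  def trimToLimit(
      batches: Seq[(Array[Byte], Long)],
      limit: Int,
      slice: (Array[Byte], Int) => Array[Byte]): Array[Array[Byte]] = {
    val result = scala.collection.mutable.ArrayBuffer[Array[Byte]]()
    var i = 0
    var rest = limit
    while (i < batches.length && rest > 0) {
      val (batch, size) = batches(i)
      if (size <= rest) {
        // the whole batch fits within the remaining row budget
        result += batch
        rest -= size.toInt
      } else {
        // only the first `rest` rows of this batch are needed
        result += slice(batch, rest)
        rest = 0
      }
      i += 1
    }
    result.toArray
  }

  def main(args: Array[String]): Unit = {
    // toy "batches": the payload bytes don't matter here, only the row counts do
    val batches = Seq(
      (Array[Byte](1, 1, 1), 3L),
      (Array[Byte](2, 2, 2, 2), 4L),
      (Array[Byte](3, 3), 2L))
    // pretend slice: keep one byte per kept row, just to make the truncation visible
    val trimmed = trimToLimit(batches, limit = 5, slice = (b, n) => b.take(n))
    // first batch kept whole (3 rows), second batch sliced to 2 rows, third dropped
    println(trimmed.map(_.length).mkString(", ")) // prints: 3, 2
  }
}
```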