
Commit 0088671

refine

1 parent 8593d85

6 files changed, +104 −79 lines

extensions/spark/kyuubi-spark-connector-tpcds/src/test/scala/org/apache/kyuubi/spark/connector/tpcds/TPCDSQuerySuite.scala

Lines changed: 2 additions & 0 deletions
@@ -88,4 +88,6 @@ class TPCDSQuerySuite extends KyuubiFunSuite {
       }
     }
   }
+
+  test("aa") {}
 }

externals/kyuubi-spark-sql-engine/pom.xml

Lines changed: 7 additions & 0 deletions
@@ -71,6 +71,13 @@
             <scope>provided</scope>
         </dependency>
 
+        <dependency>
+            <groupId>com.google.guava</groupId>
+            <artifactId>guava</artifactId>
+            <version>14.0.1</version>
+            <scope>provided</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-compiler</artifactId>

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala

Lines changed: 52 additions & 56 deletions
@@ -22,14 +22,13 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, SQLExecution}
+import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
-import org.apache.kyuubi.{KyuubiSQLException, Logging}
-import org.apache.spark.sql.execution.arrow.{ArrowCollectLimitExec, KyuubiArrowUtils}
 
+import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
 import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationHandle, OperationState}
@@ -189,73 +188,70 @@ class ArrowBasedExecuteStatement(
     handle) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.toLocalIterator
+    val df = convertComplexType(resultDF)
+    withNewExecutionId(df) {
+      SparkDatasetHelper.toArrowBatchRdd(df).toLocalIterator
     }
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    collectAsArrow(convertComplexType(resultDF)) { rdd =>
-      rdd.collect()
-    }
+    executeCollect(convertComplexType(resultDF))
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
-    // this will introduce shuffle and hurt performance
-    val limitedResult = resultDF.limit(maxRows)
-    // collectAsArrow(convertComplexType(limitedResult)) { rdd =>
-    //   rdd.collect()
-    // }
-    val df = convertComplexType(limitedResult)
-    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
-      df.queryExecution.executedPlan.resetMetrics()
-      df.queryExecution.executedPlan match {
-        case collectLimit @ CollectLimitExec(limit, _) =>
-          val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
-          val batches = ArrowCollectLimitExec.takeAsArrowBatches(collectLimit, df.schema, 1000, 1024 * 1024, timeZoneId)
-          // .map(_._1)
-          val result = ArrayBuffer[Array[Byte]]()
-          var i = 0
-          var rest = limit
-          println(s"batch....size... ${batches.length}")
-          while (i < batches.length && rest > 0) {
-            val (batch, size) = batches(i)
-            if (size < rest) {
-              result += batch
-              // TODO: toInt
-              rest = rest - size.toInt
-            } else if (size == rest) {
-              result += batch
-              rest = 0
-            } else { // size > rest
-              println(s"size......${size}....rest......${rest}")
-              // result += KyuubiArrowUtils.slice(batch, 0, rest)
-              result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
-              rest = 0
-            }
-            i += 1
-          }
-          result.toArray
-
-        case takeOrderedAndProjectExec @ TakeOrderedAndProjectExec(limit, _, _, _) =>
-          val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
-          ArrowCollectLimitExec.taskOrdered(takeOrderedAndProjectExec, df.schema, 1000, 1024 * 1024, timeZoneId)
-            .map(_._1)
-        case _ =>
-          println("yyyy")
-          SparkDatasetHelper.toArrowBatchRdd(df).collect()
-      }
-    }
+    executeCollect(convertComplexType(resultDF.limit(maxRows)))
   }
 
   /**
    * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
    * operation, so that we can track the arrow-based queries on the UI tab.
    */
-  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
+  private def withNewExecutionId[T](df: DataFrame)(body: => T): T = {
     SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
       df.queryExecution.executedPlan.resetMetrics()
-      action(SparkDatasetHelper.toArrowBatchRdd(df))
+      body
+    }
+  }
+
+  def executeCollect(df: DataFrame): Array[Array[Byte]] = withNewExecutionId(df) {
+    executeArrowBatchCollect(df).getOrElse {
+      SparkDatasetHelper.toArrowBatchRdd(df).collect()
+    }
+  }
+
+  private def executeArrowBatchCollect(df: DataFrame): Option[Array[Array[Byte]]] = {
+    df.queryExecution.executedPlan match {
+      case collectLimit @ CollectLimitExec(limit, _) =>
+        val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+        val maxRecordsPerBatch = spark.conf.getOption(
+          "spark.sql.execution.arrow.maxRecordsPerBatch").map(_.toInt).getOrElse(10000)
+        // val maxBatchSize =
+        //   (spark.sessionState.conf.getConf(SPARK_CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
+        val maxBatchSize = 1024 * 1024 * 4
+        val batches = ArrowCollectLimitExec.takeAsArrowBatches(
+          collectLimit,
+          df.schema,
+          maxRecordsPerBatch,
+          maxBatchSize,
+          timeZoneId)
+        val result = ArrayBuffer[Array[Byte]]()
+        var i = 0
+        var rest = limit
+        while (i < batches.length && rest > 0) {
+          val (batch, size) = batches(i)
+          if (size <= rest) {
+            // the whole batch fits within the remaining limit, so it is safe to keep as-is
+            result += batch
+            rest -= size.toInt
+          } else { // size > rest
+            result += KyuubiArrowUtils.sliceV2(df.schema, timeZoneId, batch, 0, rest)
+            rest = 0
+          }
+          i += 1
+        }
+        Option(result.toArray)
+      case _ =>
+        None
     }
   }
 

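The heart of the new executeArrowBatchCollect path is the row-accounting loop over pre-sized Arrow batches: whole batches are kept while they still fit under the remaining limit, and only the batch that crosses the boundary is sliced. Below is a minimal, self-contained sketch of that loop, not part of the commit; `sliceBatch` here is a hypothetical stand-in for KyuubiArrowUtils.sliceV2.

import scala.collection.mutable.ArrayBuffer

object LimitAwareCollect {
  // batches: (serialized Arrow batch, row count of that batch), in plan output order
  // sliceBatch(batch, n): stand-in that keeps only the first n rows of a batch
  def takeBatches(
      batches: Seq[(Array[Byte], Long)],
      limit: Int)(sliceBatch: (Array[Byte], Int) => Array[Byte]): Array[Array[Byte]] = {
    val result = ArrayBuffer[Array[Byte]]()
    var rest = limit
    val it = batches.iterator
    while (it.hasNext && rest > 0) {
      val (batch, size) = it.next()
      if (size <= rest) {
        // the whole batch fits within the remaining limit, keep it untouched
        result += batch
        rest -= size.toInt
      } else {
        // this batch crosses the limit, keep only the first `rest` rows
        result += sliceBatch(batch, rest)
        rest = 0
      }
    }
    result.toArray
  }
}

For example, with limit = 17 and batches of 10 and 10 rows, the first batch is kept whole and the second is sliced down to 7 rows.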
externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowCollectLimitExec.scala

Lines changed: 0 additions & 17 deletions
@@ -124,21 +124,4 @@ object ArrowCollectLimitExec extends SQLConfHelper {
       buf.toArray
     }
   }
-
-  def taskOrdered(
-      takeOrdered: TakeOrderedAndProjectExec,
-      schema: StructType,
-      maxRecordsPerBatch: Long,
-      maxEstimatedBatchSize: Long,
-      timeZoneId: String): Array[Batch] = {
-    val batches = ArrowConvertersHelper.toBatchWithSchemaIterator(
-      takeOrdered.executeCollect().iterator,
-      schema,
-      maxEstimatedBatchSize,
-      maxEstimatedBatchSize,
-      takeOrdered.limit,
-      timeZoneId)
-    batches.map(b => b -> batches.rowCountInLastBatch).toArray
-  }
 }
-

externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowUtils.scala

Lines changed: 9 additions & 6 deletions
@@ -21,9 +21,9 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import java.nio.channels.Channels
 
 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
 import org.apache.arrow.vector.ipc.message.MessageSerializer
-import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
 
@@ -58,8 +58,12 @@ object KyuubiArrowUtils {
     }
   }
 
-  def sliceV2(schema: StructType,
-      timeZoneId: String, bytes: Array[Byte], start: Int, length: Int): Array[Byte] = {
+  def sliceV2(
+      schema: StructType,
+      timeZoneId: String,
+      bytes: Array[Byte],
+      start: Int,
+      length: Int): Array[Byte] = {
     val in = new ByteArrayInputStream(bytes)
     val out = new ByteArrayOutputStream()
 
@@ -71,17 +75,16 @@ object KyuubiArrowUtils {
     // println("rowCount......" + reader.getVectorSchemaRoot.getRowCount)
     // val root = reader.getVectorSchemaRoot.slice(start, length)
 
-
     val recordBatch = MessageSerializer.deserializeRecordBatch(
-      new ReadChannel(Channels.newChannel(in)), rootAllocator)
+      new ReadChannel(Channels.newChannel(in)),
+      rootAllocator)
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
 
     val root = VectorSchemaRoot.create(arrowSchema, rootAllocator)
     val vectorLoader = new VectorLoader(root)
     vectorLoader.load(recordBatch)
    recordBatch.close()
 
-
     val unloader = new VectorUnloader(root.slice(start, length))
     val writeChannel = new WriteChannel(Channels.newChannel(out))
     val batch = unloader.getRecordBatch()

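The reshaped sliceV2 works on a single serialized IPC record batch: deserialize it, load it into a VectorSchemaRoot, slice out the requested rows, and re-serialize the slice. The following standalone sketch of that round trip is not part of the commit; it mirrors the calls used in the diff above, with error handling and allocator lifecycle simplified.

import java.io.ByteArrayOutputStream
import java.nio.channels.Channels

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
import org.apache.arrow.vector.types.pojo.Schema

object SliceSketch {
  // Keep rows [start, start + length) of an already-deserialized record batch and
  // serialize the result back into IPC bytes, mirroring what sliceV2 does after
  // MessageSerializer.deserializeRecordBatch.
  def sliceRecordBatch(
      arrowSchema: Schema,
      allocator: RootAllocator,
      recordBatch: ArrowRecordBatch,
      start: Int,
      length: Int): Array[Byte] = {
    val root = VectorSchemaRoot.create(arrowSchema, allocator)
    new VectorLoader(root).load(recordBatch) // materialize the batch into vectors
    val slicedBatch = new VectorUnloader(root.slice(start, length)).getRecordBatch()
    val out = new ByteArrayOutputStream()
    try {
      // write only the record batch message, as sliceV2 does
      MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), slicedBatch)
    } finally {
      slicedBatch.close()
      root.close()
    }
    out.toByteArray
  }
}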
externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -20,8 +20,10 @@ package org.apache.kyuubi.engine.spark.operation
 import java.sql.Statement
 
 import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.util.QueryExecutionListener
 
 import org.apache.kyuubi.config.KyuubiConf
@@ -138,6 +140,20 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(metrics("numOutputRows").value === 1)
   }
 
+  test("aa") {
+
+    withJdbcStatement() { statement =>
+      loadPartitionedTable()
+      val n = 17
+      statement.executeQuery(s"SET kyuubi.operation.result.max.rows=$n")
+      val result = statement.executeQuery("select * from t_1")
+      for (i <- 0 until n) {
+        assert(result.next())
+      }
+      assert(!result.next())
+    }
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -177,4 +193,22 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       .allSessions()
       .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
   }
+
+  private def loadPartitionedTable(): Unit = {
+    SparkSQLEngine.currentEngine.get
+      .backendService
+      .sessionManager
+      .allSessions()
+      .map(_.asInstanceOf[SparkSessionImpl].spark)
+      .foreach { spark =>
+        spark.range(1000)
+          .repartitionByRange(100, col("id"))
+          .createOrReplaceTempView("t_1")
+        spark.sql("select * from t_1")
+          .foreachPartition { p: Iterator[Row] =>
+            assert(p.length == 10)
+            ()
+          }
+      }
+  }
 }
