Commit b7012aa

cfmcgrady authored and ulysses-you committed
[KYUUBI #4710][ARROW][FOLLOWUP] Post driver-side metrics for LocalTableScanExec/CommandResultExec
### _Why are the changes needed?_

To resolve #4710 (comment).

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4769 from cfmcgrady/arrow-send-driver-metrics.

Closes #4710

a952d08 [Fu Chen] refactor
a5645de [Fu Chen] address comment
6749630 [Fu Chen] update
2dff41e [Fu Chen] add SparkMetricsTestUtils
8c772bc [Fu Chen] ut
4e3cd7d [Fu Chen] metrics

Authored-by: Fu Chen <[email protected]>
Signed-off-by: ulyssesyou <[email protected]>
1 parent 0b72a61 commit b7012aa
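For context on the mechanism this change relies on: LocalTableScanExec and CommandResultExec produce their rows entirely on the driver, so their SQLMetrics are never reported through task completion and must be posted to the listener bus explicitly under the current SQL execution id. Below is a minimal sketch of that pattern, mirroring the `sendDriverMetrics` helper added in this commit; the object name `DriverMetricsExample` is illustrative and not part of the change.

```scala
// A minimal, illustrative sketch of the driver-side metrics pattern used in this commit.
// Living under org.apache.spark.sql (like SparkDatasetHelper itself) is what grants
// access to Spark's internal SQLExecution / SQLMetrics APIs.
package org.apache.spark.sql.kyuubi

import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

object DriverMetricsExample {
  def postDriverMetrics(sc: SparkContext, metrics: Map[String, SQLMetric]): Unit = {
    // SQLExecution.withNewExecutionId stores the current execution id as a
    // thread-local property on the SparkContext.
    val executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    // Metrics updated on the driver are not shipped back by tasks, so post them
    // to the listener bus; the SQL UI then associates them with this execution.
    SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)
  }
}
```

`SQLMetrics.postDriverMetricUpdates` tolerates a null execution id, so calling it outside `SQLExecution.withNewExecutionId` is simply a no-op.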

3 files changed: +134 -16 lines changed


externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala

Lines changed: 25 additions & 12 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.kyuubi
 
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, Spa
 import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -184,26 +186,36 @@ object SparkDatasetHelper extends Logging {
     result.toArray
   }
 
-  def doCommandResultExec(command: SparkPlan): Array[Array[Byte]] = {
+  private lazy val commandResultExecRowsMethod = DynMethods.builder("rows")
+    .impl("org.apache.spark.sql.execution.CommandResultExec")
+    .build()
+
+  private def doCommandResultExec(command: SparkPlan): Array[Array[Byte]] = {
+    val spark = SparkSession.active
+    // TODO: replace with `command.rows` once we drop Spark 3.1 support.
+    val rows = commandResultExecRowsMethod.invoke[Seq[InternalRow]](command)
+    command.longMetric("numOutputRows").add(rows.size)
+    sendDriverMetrics(spark.sparkContext, command.metrics)
     KyuubiArrowConverters.toBatchIterator(
-      // TODO: replace with `command.rows.iterator` once we drop Spark 3.1 support.
-      commandResultExecRowsMethod.invoke[Seq[InternalRow]](command).iterator,
+      rows.iterator,
       command.schema,
-      SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
+      spark.sessionState.conf.arrowMaxRecordsPerBatch,
       maxBatchSize,
      -1,
-      SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
+      spark.sessionState.conf.sessionLocalTimeZone).toArray
   }
 
-  def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
+  private def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
+    val spark = SparkSession.active
     localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
+    sendDriverMetrics(spark.sparkContext, localTableScan.metrics)
     KyuubiArrowConverters.toBatchIterator(
       localTableScan.rows.iterator,
       localTableScan.schema,
-      SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
+      spark.sessionState.conf.arrowMaxRecordsPerBatch,
       maxBatchSize,
      -1,
-      SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
+      spark.sessionState.conf.sessionLocalTimeZone).toArray
   }
 
   /**
@@ -268,10 +280,6 @@
     sparkPlan.getClass.getName == "org.apache.spark.sql.execution.CommandResultExec"
   }
 
-  private lazy val commandResultExecRowsMethod = DynMethods.builder("rows")
-    .impl("org.apache.spark.sql.execution.CommandResultExec")
-    .build()
-
   /**
    * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
    * operation, so that we can track the arrow-based queries on the UI tab.
@@ -282,4 +290,9 @@
       body
     }
   }
+
+  private def sendDriverMetrics(sc: SparkContext, metrics: Map[String, SQLMetric]): Unit = {
+    val executionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)
+  }
 }

externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala

Lines changed: 56 additions & 4 deletions
@@ -18,7 +18,7 @@
 package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
-import java.util.{Set => JSet}
+import java.util.{Locale, Set => JSet}
 
 import org.apache.spark.KyuubiSparkContextHelper
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.metric.SparkMetricsTestUtils
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
@@ -41,7 +42,8 @@ import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 import org.apache.kyuubi.reflection.DynFields
 
-class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
+class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
+  with SparkMetricsTestUtils {
 
   override protected def jdbcUrl: String = getJdbcUrl
 
@@ -58,6 +60,16 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     withJdbcStatement() { statement =>
       checkResultSetFormat(statement, "arrow")
     }
+    spark.catalog.listTables()
+      .collect()
+      .foreach { table =>
+        if (table.isTemporary) {
+          spark.catalog.dropTempView(table.name)
+        } else {
+          spark.sql(s"DROP TABLE IF EXISTS ${table.name}")
+        }
+        ()
+      }
   }
 
   test("detect resultSet format") {
@@ -288,13 +300,12 @@
         assert(nodeName == "org.apache.spark.sql.execution.CommandResultExec")
       }
     withJdbcStatement("table_1") { statement =>
-      statement.executeQuery(s"CREATE TABLE table_1 (id bigint) USING parquet")
+      statement.executeQuery("CREATE TABLE table_1 (id bigint) USING parquet")
       withSparkListener(listener) {
         withSparkListener(l2) {
           val resultSet = statement.executeQuery("SHOW TABLES")
           assert(resultSet.next())
           assert(resultSet.getString("tableName") == "table_1")
-          KyuubiSparkContextHelper.waitListenerBus(spark)
         }
       }
     }
@@ -348,6 +359,33 @@
     assert(metrics("numOutputRows").value === 1)
   }
 
+  test("post LocalTableScanExec driver-side metrics") {
+    val expectedMetrics = Map(
+      0L -> (("LocalTableScan", Map("number of output rows" -> "2"))))
+    withTables("view_1") {
+      val s = spark
+      import s.implicits._
+      Seq((1, "a"), (2, "b")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+      val df = spark.sql("SELECT * FROM view_1")
+      val metrics = getSparkPlanMetrics(df)
+      assert(metrics == expectedMetrics)
+    }
+  }
+
+  test("post CommandResultExec driver-side metrics") {
+    spark.sql("show tables").show(truncate = false)
+    assume(SPARK_ENGINE_RUNTIME_VERSION >= "3.2")
+    val expectedMetrics = Map(
+      0L -> (("CommandResult", Map("number of output rows" -> "2"))))
+    withTables("table_1", "table_2") {
+      spark.sql("CREATE TABLE table_1 (id bigint) USING parquet")
+      spark.sql("CREATE TABLE table_2 (id bigint) USING parquet")
+      val df = spark.sql("SHOW TABLES")
+      val metrics = getSparkPlanMetrics(df)
+      assert(metrics == expectedMetrics)
+    }
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -465,6 +503,20 @@
     }
   }
 
+  private def withTables[T](tableNames: String*)(f: => T): T = {
+    try {
+      f
+    } finally {
+      tableNames.foreach { name =>
+        if (name.toUpperCase(Locale.ROOT).startsWith("VIEW")) {
+          spark.sql(s"DROP VIEW IF EXISTS $name")
+        } else {
+          spark.sql(s"DROP TABLE IF EXISTS $name")
+        }
+      }
+    }
+  }
+
   /**
    * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to
    * adapt Spark-3.1.x

externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/execution/metric/SparkMetricsTestUtils.scala

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.metric
+
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.ui.SparkPlanGraph
+import org.apache.spark.sql.kyuubi.SparkDatasetHelper
+
+import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+
+trait SparkMetricsTestUtils {
+  this: WithSparkSQLEngine =>
+
+  private lazy val statusStore = spark.sharedState.statusStore
+  private def currentExecutionIds(): Set[Long] = {
+    spark.sparkContext.listenerBus.waitUntilEmpty(10000)
+    statusStore.executionsList.map(_.executionId).toSet
+  }
+
+  protected def getSparkPlanMetrics(df: DataFrame): Map[Long, (String, Map[String, Any])] = {
+    val previousExecutionIds = currentExecutionIds()
+    SparkDatasetHelper.executeCollect(df)
+    spark.sparkContext.listenerBus.waitUntilEmpty(10000)
+    val executionIds = currentExecutionIds().diff(previousExecutionIds)
+    assert(executionIds.size === 1)
+    val executionId = executionIds.head
+    val metricValues = statusStore.executionMetrics(executionId)
+    SparkPlanGraph(SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan)).allNodes
+      .map { node =>
+        val nodeMetrics = node.metrics.map { metric =>
+          val metricValue = metricValues(metric.accumulatorId)
+          (metric.name, metricValue)
+        }.toMap
+        (node.id, node.name -> nodeMetrics)
+      }.toMap
+  }
+}
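For readers skimming the new helper, here is a hedged usage sketch mirroring the two tests added above. The class name `DriverSideMetricsSuite` and the `withKyuubiConf` value are hypothetical; `spark`, `test`, and `assert` are assumed to come from the Kyuubi test harness that `WithSparkSQLEngine` provides.

```scala
// Hypothetical suite showing how the new trait is meant to be mixed in; the class
// name and the withKyuubiConf value are illustrative, not part of this commit.
package org.apache.spark.sql.execution.metric

import org.apache.kyuubi.engine.spark.WithSparkSQLEngine

class DriverSideMetricsSuite extends WithSparkSQLEngine with SparkMetricsTestUtils {

  // Extra engine configuration requested by the Kyuubi test harness (assumed member).
  override def withKyuubiConf: Map[String, String] = Map.empty

  test("LocalTableScan posts its row count to the SQL UI") {
    val s = spark
    import s.implicits._
    val df = Seq((1, "a"), (2, "b")).toDF("c1", "c2")
    // getSparkPlanMetrics triggers a collect through SparkDatasetHelper and returns
    // the driver-posted SQLMetrics keyed by plan-node id.
    val metrics = getSparkPlanMetrics(df)
    assert(metrics.values.exists { case (name, m) =>
      name == "LocalTableScan" && m("number of output rows") == "2"
    })
  }
}
```

Comparing against the formatted string "2" matches how the commit's own tests assert on values reported by the SQL status store.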
