@@ -18,7 +18,7 @@
 package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
-import java.util.{Set => JSet}
+import java.util.{Locale, Set => JSet}
 
 import org.apache.spark.KyuubiSparkContextHelper
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.metric.SparkMetricsTestUtils
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
@@ -41,7 +42,8 @@ import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 import org.apache.kyuubi.reflection.DynFields
 
-class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
+class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
+  with SparkMetricsTestUtils {
 
   override protected def jdbcUrl: String = getJdbcUrl
 
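Note: the `SparkMetricsTestUtils` trait mixed in above supplies the `getSparkPlanMetrics` helper used by the new driver-side metrics tests further down; its definition is not part of this diff. A rough sketch of the shape such a helper could take, assuming it simply walks the executed plan and collects each node's metric values (the real trait may instead go through Spark's SQL listener; the trait and method names below are illustrative only):

    import org.apache.spark.sql.DataFrame

    trait SparkMetricsTestUtilsSketch {
      // Assumed shape: plan-node id -> (node name, metric display name -> value)
      protected def getSparkPlanMetrics(df: DataFrame): Map[Long, (String, Map[String, String])] = {
        df.collect() // force execution so driver-side metrics are populated
        df.queryExecution.executedPlan.collect { case node =>
          val metricValues = node.metrics.map { case (key, metric) =>
            metric.name.getOrElse(key) -> metric.value.toString
          }
          node.id.toLong -> ((node.nodeName, metricValues))
        }.toMap
      }
    }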
@@ -58,6 +60,16 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     withJdbcStatement() { statement =>
       checkResultSetFormat(statement, "arrow")
     }
+    spark.catalog.listTables()
+      .collect()
+      .foreach { table =>
+        if (table.isTemporary) {
+          spark.catalog.dropTempView(table.name)
+        } else {
+          spark.sql(s"DROP TABLE IF EXISTS ${table.name}")
+        }
+        ()
+      }
   }
 
   test("detect resultSet format") {
@@ -288,13 +300,12 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       assert(nodeName == "org.apache.spark.sql.execution.CommandResultExec")
     }
     withJdbcStatement("table_1") { statement =>
-      statement.executeQuery(s"CREATE TABLE table_1 (id bigint) USING parquet")
+      statement.executeQuery("CREATE TABLE table_1 (id bigint) USING parquet")
       withSparkListener(listener) {
         withSparkListener(l2) {
           val resultSet = statement.executeQuery("SHOW TABLES")
           assert(resultSet.next())
           assert(resultSet.getString("tableName") == "table_1")
-          KyuubiSparkContextHelper.waitListenerBus(spark)
         }
       }
     }
@@ -348,6 +359,33 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(metrics("numOutputRows").value === 1)
   }
 
+  test("post LocalTableScanExec driver-side metrics") {
+    val expectedMetrics = Map(
+      0L -> (("LocalTableScan", Map("number of output rows" -> "2"))))
+    withTables("view_1") {
+      val s = spark
+      import s.implicits._
+      Seq((1, "a"), (2, "b")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+      val df = spark.sql("SELECT * FROM view_1")
+      val metrics = getSparkPlanMetrics(df)
+      assert(metrics == expectedMetrics)
+    }
+  }
+
+  test("post CommandResultExec driver-side metrics") {
+    spark.sql("show tables").show(truncate = false)
+    assume(SPARK_ENGINE_RUNTIME_VERSION >= "3.2")
+    val expectedMetrics = Map(
+      0L -> (("CommandResult", Map("number of output rows" -> "2"))))
+    withTables("table_1", "table_2") {
+      spark.sql("CREATE TABLE table_1 (id bigint) USING parquet")
+      spark.sql("CREATE TABLE table_2 (id bigint) USING parquet")
+      val df = spark.sql("SHOW TABLES")
+      val metrics = getSparkPlanMetrics(df)
+      assert(metrics == expectedMetrics)
+    }
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
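For context on what these two tests exercise: nodes such as LocalTableScanExec and CommandResultExec produce their results on the driver, so their metrics are not updated by tasks and have to be posted to the listener bus explicitly. A minimal sketch of that pattern using Spark's SQLMetrics helper (illustrative only; the function name, the `plan` argument, and its `numOutputRows` metric are assumptions, not the Kyuubi code path itself):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.execution.{SQLExecution, SparkPlan}
    import org.apache.spark.sql.execution.metric.SQLMetrics

    // Update a driver-side metric and publish it so SQL UI / metrics listeners can observe it.
    def postDriverSideNumOutputRows(sc: SparkContext, plan: SparkPlan, rowCount: Long): Unit = {
      val numOutputRows = plan.longMetric("numOutputRows") // assumes the node defines this metric
      numOutputRows.add(rowCount)
      SQLMetrics.postDriverMetricUpdates(
        sc,
        sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY),
        Seq(numOutputRows))
    }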
@@ -465,6 +503,20 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     }
   }
 
+  private def withTables[T](tableNames: String*)(f: => T): T = {
+    try {
+      f
+    } finally {
+      tableNames.foreach { name =>
+        if (name.toUpperCase(Locale.ROOT).startsWith("VIEW")) {
+          spark.sql(s"DROP VIEW IF EXISTS $name")
+        } else {
+          spark.sql(s"DROP TABLE IF EXISTS $name")
+        }
+      }
+    }
+  }
+
   /**
    * This method provides a reflection-based implementation of [[SQLConf.isStaticConfigKey]] to
    * adapt Spark-3.1.x