
Commit c3336d8

add testSparkPlanMetricsWithPredicates and comments for sort time
1 parent 1e55f31 commit c3336d8

2 files changed: +55 -19 lines changed


sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 8 additions & 13 deletions
@@ -202,19 +202,14 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
     // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core,
     // so Project here is not collapsed into LocalTableScan.
     val df = Seq(1, 3, 2).toDF("id").sort('id)
-    val metrics = getSparkPlanMetrics(df, 2, Set(0))
-    assert(metrics.isDefined)
-    val sortMetrics = metrics.get.get(0).get
-    // Check node 0 is Sort node
-    val operatorName = sortMetrics._1
-    assert(operatorName == "Sort")
-    // Check metrics values
-    val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString
-    assert(timingMetricStats(sortTimeStr).forall { case (sortTime, _) => sortTime >= 0 })
-    val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString
-    assert(sizeMetricStats(peakMemoryStr).forall { case (peakMemory, _) => peakMemory > 0 })
-    val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString
-    assert(sizeMetricStats(spillSizeStr).forall { case (spillSize, _) => spillSize >= 0 })
+    testSparkPlanMetricsWithPredicates(df, 2, Map(
+      0L -> (("Sort", Map(
+        // In SortExec, sort time is collected as nanoseconds, but it is converted and stored as
+        // milliseconds. So sort time may be 0 if sort is executed very fast.
+        "sort time total (min, med, max)" -> timingMetricAllStatsShould(_ >= 0),
+        "peak memory total (min, med, max)" -> sizeMetricAllStatsShould(_ > 0),
+        "spill size total (min, med, max)" -> sizeMetricAllStatsShould(_ >= 0))))
+    ))
   }

   test("SortMergeJoin metrics") {

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala

Lines changed: 47 additions & 6 deletions
@@ -190,15 +190,34 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       df: DataFrame,
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
-    val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
+    val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) =>
+      (nodeName, nodeMetrics.mapValues(expectedMetricValue =>
+        (actualMetricValue: Any) => expectedMetricValue.toString === actualMetricValue)
+      )}
+    testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates)
+  }
+
+  /**
+   * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates.
+   * @param df `DataFrame` to run
+   * @param expectedNumOfJobs number of jobs that will run
+   * @param expectedMetricsPredicates the expected metrics predicates. The format is
+   *                                  `nodeId -> (operatorName, metric name -> metric value predicate)`.
+   */
+  protected def testSparkPlanMetricsWithPredicates(
+      df: DataFrame,
+      expectedNumOfJobs: Int,
+      expectedMetricsPredicates: Map[Long, (String, Map[String, Any => Boolean])]): Unit = {
+    val optActualMetrics =
+      getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet)
     optActualMetrics.foreach { actualMetrics =>
-      assert(expectedMetrics.keySet === actualMetrics.keySet)
-      for (nodeId <- expectedMetrics.keySet) {
-        val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
+      assert(expectedMetricsPredicates.keySet === actualMetrics.keySet)
+      for (nodeId <- expectedMetricsPredicates.keySet) {
+        val (expectedNodeName, expectedMetricsPredicatesMap) = expectedMetricsPredicates(nodeId)
         val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
         assert(expectedNodeName === actualNodeName)
-        for (metricName <- expectedMetricsMap.keySet) {
-          assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
+        for (metricName <- expectedMetricsPredicatesMap.keySet) {
+          assert(expectedMetricsPredicatesMap(metricName)(actualMetricsMap(metricName)))
         }
       }
     }
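With this refactoring, testSparkPlanMetrics keeps its exact-match behavior by lifting each expected value into an equality predicate, while suites that need looser checks can call testSparkPlanMetricsWithPredicates directly. A hypothetical call from a suite mixing in SQLMetricsTestUtils (the DataFrame, job count, and node id below are illustrative, not taken from this commit):

    val df = Seq(4, 1, 3).toDF("id").sort('id)
    testSparkPlanMetricsWithPredicates(df, 2, Map(
      0L -> (("Sort", Map(
        // Any `Any => Boolean` can serve as the check, not only the
        // helpers added in the next hunk:
        "sort time total (min, med, max)" -> ((v: Any) => v.toString.nonEmpty))))
    ))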
@@ -248,6 +267,28 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
   protected def timingMetricStats(metricStr: String): Seq[(Float, String)] = {
     metricStats(metricStr).map(stringToDuration)
   }
+
+  /**
+   * Returns a function to check whether all stats (sum, min, med and max) of a timing metric
+   * satisfy the specified predicate.
+   * @param predicate predicate to check stats
+   * @return function to check all stats of a timing metric
+   */
+  protected def timingMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = {
+    (timingMetric: Any) =>
+      timingMetricStats(timingMetric.toString).forall { case (duration, _) => predicate(duration) }
+  }
+
+  /**
+   * Returns a function to check whether all stats (sum, min, med and max) of a size metric satisfy
+   * the specified predicate.
+   * @param predicate predicate to check stats
+   * @return function to check all stats of a size metric
+   */
+  protected def sizeMetricAllStatsShould(predicate: Float => Boolean): Any => Boolean = {
+    (sizeMetric: Any) =>
+      sizeMetricStats(sizeMetric.toString).forall { case (bytes, _) => predicate(bytes) }
+  }
 }
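Both new helpers follow the same pattern: parse the rendered metric string into its component stats, then require the predicate to hold for every one of them. A self-contained sketch of that pattern, with a toy number extractor standing in for the trait's metricStats/stringToDuration machinery (the parser here is illustrative only):

    object AllStatsShouldSketch extends App {
      // Toy stand-in for timingMetricStats: pull every number out of a
      // rendered metric string such as "5 ms (0 ms, 2 ms, 3 ms)".
      def stats(metricStr: String): Seq[Float] =
        "[0-9.]+".r.findAllIn(metricStr).map(_.toFloat).toSeq

      // Same shape as timingMetricAllStatsShould: lift a predicate on a
      // single stat into a predicate on the whole rendered metric value.
      def allStatsShould(predicate: Float => Boolean): Any => Boolean =
        (metric: Any) => stats(metric.toString).forall(predicate)

      val nonNegative = allStatsShould(_ >= 0)
      println(nonNegative("5 ms (0 ms, 2 ms, 3 ms)")) // prints true: every stat is >= 0
    }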