Skip to content

Commit 68dbdd7

Browse files
committed
check Sort metrics values
1 parent 42c77f2 commit 68dbdd7

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,23 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
197197
// Assume the execution plan with node id is
198198
// Sort(nodeId = 0)
199199
// Exchange(nodeId = 1)
200-
// LocalTableScan(nodeId = 2)
200+
// Project(nodeId = 2)
201+
// LocalTableScan(nodeId = 3)
202+
// Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core,
203+
// so Project here is not collapsed into LocalTableScan.
201204
val df = Seq(1, 3, 2).toDF("id").sort('id)
202-
testSparkPlanMetrics(df, 2, Map.empty)
203-
assert(df.queryExecution.executedPlan.find(_.isInstanceOf[SortExec]).isDefined)
205+
val metrics = getSparkPlanMetrics(df, 2, Set(0))
206+
val sortMetrics = metrics.get.get(0).get
207+
// Check node 0 is Sort node
208+
val operatorName = sortMetrics._1
209+
assert(operatorName == "Sort")
210+
// Check metrics values
211+
val sortTimeStr = sortMetrics._2.get("sort time total (min, med, max)").get.toString
212+
timingMetricStats(sortTimeStr).foreach { case (sortTime, _) => assert(sortTime >= 0) }
213+
val peakMemoryStr = sortMetrics._2.get("peak memory total (min, med, max)").get.toString
214+
sizeMetricStats(peakMemoryStr).foreach { case (peakMemory, _) => assert(peakMemory > 0) }
215+
val spillSizeStr = sortMetrics._2.get("spill size total (min, med, max)").get.toString
216+
sizeMetricStats(spillSizeStr).foreach { case (spillSize, _) => assert(spillSize >= 0) }
204217
}
205218

206219
test("SortMergeJoin metrics") {

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.execution.metric
1919

2020
import java.io.File
21+
import java.util.regex.Pattern
2122

2223
import scala.collection.mutable.HashMap
2324

@@ -198,6 +199,51 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
198199
}
199200
}
200201
}
202+
203+
private def metricStats(metricStr: String): Seq[String] = {
204+
val sum = metricStr.substring(0, metricStr.indexOf("(")).stripPrefix("\n").stripSuffix(" ")
205+
val minMedMax = metricStr.substring(metricStr.indexOf("(") + 1, metricStr.indexOf(")"))
206+
.split(", ").toSeq
207+
(sum +: minMedMax)
208+
}
209+
210+
private def stringToBytes(str: String): (Float, String) = {
211+
val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) (EB|PB|TB|GB|MB|KB|B)").matcher(str)
212+
if (matcher.matches()) {
213+
(matcher.group(1).toFloat, matcher.group(3))
214+
} else {
215+
throw new NumberFormatException("Failed to parse byte string: " + str)
216+
}
217+
}
218+
219+
private def stringToDuration(str: String): (Float, String) = {
220+
val matcher = Pattern.compile("([0-9]+(\\.[0-9]+)?) (ms|s|m|h)").matcher(str)
221+
if (matcher.matches()) {
222+
(matcher.group(1).toFloat, matcher.group(3))
223+
} else {
224+
throw new NumberFormatException("Failed to parse time string: " + str)
225+
}
226+
}
227+
228+
/**
229+
* Convert a size metric string to a sequence of stats, including sum, min, med and max in order,
230+
* each a tuple of (value, unit).
231+
* @param metricStr size metric string, e.g. "\n96.2 MB (32.1 MB, 32.1 MB, 32.1 MB)"
232+
* @return A sequence of stats, e.g. ((96.2,MB), (32.1,MB), (32.1,MB), (32.1,MB))
233+
*/
234+
protected def sizeMetricStats(metricStr: String): Seq[(Float, String)] = {
235+
metricStats(metricStr).map(stringToBytes)
236+
}
237+
238+
/**
239+
* Convert a timing metric string to a sequence of stats, including sum, min, med and max in
240+
* order, each a tuple of (value, unit).
241+
* @param metricStr timing metric string, e.g. "\n2.0 ms (1.0 ms, 1.0 ms, 1.0 ms)"
242+
* @return A sequence of stats, e.g. ((2.0,ms), (1.0,ms), (1.0,ms), (1.0,ms))
243+
*/
244+
protected def timingMetricStats(metricStr: String): Seq[(Float, String)] = {
245+
metricStats(metricStr).map(stringToDuration)
246+
}
201247
}
202248

203249

0 commit comments

Comments
 (0)