@@ -190,15 +190,34 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
190190 df : DataFrame ,
191191 expectedNumOfJobs : Int ,
192192 expectedMetrics : Map [Long , (String , Map [String , Any ])]): Unit = {
193- val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
193+ val expectedMetricsPredicates = expectedMetrics.mapValues { case (nodeName, nodeMetrics) =>
194+ (nodeName, nodeMetrics.mapValues(expectedMetricValue =>
195+ (actualMetricValue : Any ) => expectedMetricValue.toString === actualMetricValue)
196+ )}
197+ testSparkPlanMetricsWithPredicates(df, expectedNumOfJobs, expectedMetricsPredicates)
198+ }
199+
200+ /**
201+ * Call `df.collect()` and verify if the collected metrics satisfy the specified predicates.
202+ * @param df `DataFrame` to run
203+ * @param expectedNumOfJobs number of jobs that will run
204+ * @param expectedMetricsPredicates the expected metrics predicates. The format is
205+ * `nodeId -> (operatorName, metric name -> metric value predicate)`.
206+ */
207+ protected def testSparkPlanMetricsWithPredicates (
208+ df : DataFrame ,
209+ expectedNumOfJobs : Int ,
210+ expectedMetricsPredicates : Map [Long , (String , Map [String , Any => Boolean ])]): Unit = {
211+ val optActualMetrics =
212+ getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetricsPredicates.keySet)
194213 optActualMetrics.foreach { actualMetrics =>
195- assert(expectedMetrics .keySet === actualMetrics.keySet)
196- for (nodeId <- expectedMetrics .keySet) {
197- val (expectedNodeName, expectedMetricsMap ) = expectedMetrics (nodeId)
214+ assert(expectedMetricsPredicates .keySet === actualMetrics.keySet)
215+ for (nodeId <- expectedMetricsPredicates .keySet) {
216+ val (expectedNodeName, expectedMetricsPredicatesMap ) = expectedMetricsPredicates (nodeId)
198217 val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
199218 assert(expectedNodeName === actualNodeName)
200- for (metricName <- expectedMetricsMap .keySet) {
201- assert(expectedMetricsMap (metricName).toString === actualMetricsMap(metricName))
219+ for (metricName <- expectedMetricsPredicatesMap .keySet) {
220+ assert(expectedMetricsPredicatesMap (metricName)( actualMetricsMap(metricName) ))
202221 }
203222 }
204223 }
@@ -248,6 +267,28 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
248267 protected def timingMetricStats (metricStr : String ): Seq [(Float , String )] = {
249268 metricStats(metricStr).map(stringToDuration)
250269 }
270+
271+ /**
272+ * Returns a function to check whether all stats (sum, min, med and max) of a timing metric
273+ * satisfy the specified predicate.
274+ * @param predicate predicate to check stats
275+ * @return function to check all stats of a timing metric
276+ */
277+ protected def timingMetricAllStatsShould (predicate : Float => Boolean ): Any => Boolean = {
278+ (timingMetric : Any ) =>
279+ timingMetricStats(timingMetric.toString).forall { case (duration, _) => predicate(duration) }
280+ }
281+
282+ /**
283+ * Returns a function to check whether all stats (sum, min, med and max) of a size metric satisfy
284+ * the specified predicate.
285+ * @param predicate predicate to check stats
286+ * @return function to check all stats of a size metric
287+ */
288+ protected def sizeMetricAllStatsShould (predicate : Float => Boolean ): Any => Boolean = {
289+ (sizeMetric : Any ) =>
290+ sizeMetricStats(sizeMetric.toString).forall { case (bytes, _) => predicate(bytes)}
291+ }
251292}
252293
253294
0 commit comments