diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6bbeb2de7538c..f2ae777c5b9c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2196,6 +2196,42 @@ object SQLConf { .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].") .createWithDefault(16) + + val SKIP_PARTIAL_AGGREGATE_MINROWS = + buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows") + .internal() + .doc("Number of records after which the aggregate operator checks whether " + + "the partial aggregation phase can be skipped") + .version("3.1.0") + .longConf + .createWithDefault(100000) + + val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO = + buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio") + .internal() + .doc("Ratio beyond which the partial aggregation is skipped. " + + "This is computed as the ratio of the number of records present " + + "in the map of the Aggregate operator to the total number of records " + + "processed by the Aggregate operator.") + .version("3.1.0") + .doubleConf + .checkValue(ratio => ratio > 0 && ratio < 1, "Invalid value for " + + "spark.sql.aggregate.skipPartialAggregate.aggregateRatio. Valid value needs" + + " to be between 0 and 1") + .createWithDefault(0.5) + + val SKIP_PARTIAL_AGGREGATE_ENABLED = + buildConf("spark.sql.aggregate.skipPartialAggregate") + .internal() + .doc("When enabled, the partial aggregation is skipped when the following " + + "two conditions are met. 1. The total number of records processed is greater " + + s"than the threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key}. 2. The ratio " + + "of the record count in the map to the total record count is greater than " + + s"the value defined by ${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}") + .version("3.1.0") + .booleanConf + .createWithDefault(true) + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + "uncompressed, deflate, snappy, bzip2 and xz. 
Default codec is snappy.") @@ -2922,6 +2958,12 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED) + + def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_MINROWS) + + def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO) + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 117e98f33a0ec..3046423ad7658 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -63,6 +63,13 @@ public final class UnsafeFixedWidthAggregationMap { */ private final UnsafeRow currentAggregationBuffer; + /** + * Number of rows that were added to the map. + * This includes rows that were passed on to the sorter + * via {@link #destructAndCreateExternalSorter()}. + */ + private long numRowsAdded = 0L; + /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. */ @@ -147,6 +154,8 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { ); if (!putSucceeded) { return null; + } else { + numRowsAdded = numRowsAdded + 1; } } @@ -249,4 +258,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), map); } + + public long getNumRows() { + return numRowsAdded; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9c07ea10a87e7..70e768502e98d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -67,13 +67,22 @@ case class HashAggregateExec( child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), - "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), - "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"), - "avgHashProbe" -> - SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters")) + override lazy val metrics = { + val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), + "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"), + "avgHashProbe" -> + SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters")) + if (skipPartialAggregateEnabled) { + metrics ++ Map("partialAggSkipped" -> 
SQLMetrics.createMetric(sparkContext, + "number of skipped records for partial aggregates")) + } else { + metrics + } + } + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -409,6 +418,14 @@ case class HashAggregateExec( private var fastHashMapTerm: String = _ private var isFastHashMapEnabled: Boolean = false + private var avoidSpillInPartialAggregateTerm: String = _ + private val skipPartialAggregateEnabled = { + sqlContext.conf.skipPartialAggregate && + modes.nonEmpty && modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty + } + private var rowCountTerm: String = _ + private var outputFunc: String = _ + // whether a vectorized hashmap is used instead // we have decided to always use the row-based hashmap, // but the vectorized hashmap can still be switched on for testing and benchmarking purposes. @@ -680,6 +697,19 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") + + var childrenConsumed: String = null + if (skipPartialAggregateEnabled) { + avoidSpillInPartialAggregateTerm = ctx. + addMutableState(CodeGenerator.JAVA_BOOLEAN, + "avoidPartialAggregate", + term => s"$term = ${Utils.isTesting};") + rowCountTerm = ctx. + addMutableState(CodeGenerator.JAVA_LONG, "rowCount") + childrenConsumed = ctx. + addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") + } + if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else if (sqlContext.conf.enableVectorizedHashMap) { @@ -750,10 +780,15 @@ case class HashAggregateExec( finishRegularHashMap } + outputFunc = generateResultFunction(ctx) + val genChildrenConsumedCode = if (skipPartialAggregateEnabled) { + s"${childrenConsumed} = true;" + } else "" val doAggFuncName = ctx.addNewFunction(doAgg, s""" |private void $doAgg() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | $genChildrenConsumedCode | $finishHashMap |} """.stripMargin) @@ -761,8 +796,6 @@ case class HashAggregateExec( // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - val outputFunc = generateResultFunction(ctx) - def outputFromFastHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { @@ -830,6 +863,15 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") + val genCodePostInitCode = + if (skipPartialAggregateEnabled) { + s""" + |if (!$childrenConsumed) { + | $doAggFuncName(); + | if (shouldStop()) return; + |} + """.stripMargin + } else "" s""" |if (!$initAgg) { | $initAgg = true; @@ -838,13 +880,17 @@ case class HashAggregateExec( | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); + | if (shouldStop()) return; |} + |$genCodePostInitCode |// output the result |$outputFromFastHashMap |$outputFromRegularHashMap """.stripMargin } + override def needStopCheck: Boolean = skipPartialAggregateEnabled + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( @@ -877,50 +923,109 @@ case class HashAggregateExec( ("true", "true", "", "") } + val skipPartialAggregateThreshold = sqlContext.conf.skipPartialAggregateThreshold + val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio + + val countTerm = 
ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val oomeClassName = classOf[SparkOutOfMemoryError].getName + val findOrInsertRegularHashMap: String = { + def getAggBufferFromMap = { + s""" + |// generate grouping key + |${unsafeRowKeyCode.code} + |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); + |if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); + |} + """.stripMargin + } - val findOrInsertRegularHashMap: String = - s""" - |// generate grouping key - |${unsafeRowKeyCode.code} - |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); - |if ($checkFallbackForBytesToBytesMap) { - | // try to get the buffer from hash map - | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); - |} - |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - |// aggregation after processing all input rows. - |if ($unsafeRowBuffer == null) { - | if ($sorterTerm == null) { - | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - | } else { - | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - | } - | $resetCounter - | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. - | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, $unsafeRowKeyHash); - | if ($unsafeRowBuffer == null) { - | // failed to allocate the first page - | throw new $oomeClassName("No enough memory for aggregation"); - | } - |} + def addToSorter: String = { + s""" + |if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + |} else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + |} + |$resetCounter + |// the hash map had been spilled, it should have enough memory now, + |// try to allocate buffer again. + |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( + | $unsafeRowKeys, $unsafeRowKeyHash); + |if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new $oomeClassName("No enough memory for aggregation"); + |}""".stripMargin + } + + if (skipPartialAggregateEnabled) { + val checkIfPartialAggSkipped = + s""" + |!($rowCountTerm < $skipPartialAggregateThreshold) && + | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio; + |""".stripMargin + s""" + |if (!$avoidSpillInPartialAggregateTerm) { + | $getAggBufferFromMap + | // Can't allocate buffer from the hash map. + | // Check if we can avoid partial aggregation. + | // Otherwise, spill the map and fall back to sort-based + | // aggregation after processing all input rows. + | if ($unsafeRowBuffer == null) { + | $countTerm = $countTerm + $hashMapTerm.getNumRows(); + | boolean skipPartAgg = $checkIfPartialAggSkipped + | if (skipPartAgg) { + | // Aggregation buffer is created later + | $avoidSpillInPartialAggregateTerm = true; + | } else { + | $addToSorter + | } + | } + |} """.stripMargin + } else { + s""" + |$getAggBufferFromMap + |// Can't allocate buffer from the hash map. Spill the map and fall back to sort-based + |// aggregation after processing all input rows. 
+ |if ($unsafeRowBuffer == null) { + | $addToSorter + |} + """.stripMargin + } + } val findOrInsertHashMap: String = { - if (isFastHashMapEnabled) { + val insertCode = if (isFastHashMapEnabled) { + def findOrInsertIntoFastHashMap = { + s""" + |${fastRowKeys.map(_.code).mkString("\n")} + |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); + |} + |""".stripMargin + } + val insertFastMap = if (skipPartialAggregateEnabled) { + s""" + |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) { + | $findOrInsertIntoFastHashMap + |} + |$countTerm = $fastHashMapTerm.getNumRows(); + |""".stripMargin + } else { + s""" + |if ($checkFallbackForGeneratedHashMap) { + | $findOrInsertIntoFastHashMap + |} + |""".stripMargin + } // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" - |if ($checkFallbackForGeneratedHashMap) { - | ${fastRowKeys.map(_.code).mkString("\n")} - | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - | } - |} + |$insertFastMap |// Cannot find the key in fast hash map, try regular hash map. |if ($fastRowBuffer == null) { | $findOrInsertRegularHashMap @@ -929,6 +1034,27 @@ case class HashAggregateExec( } else { findOrInsertRegularHashMap } + def createEmptyAggBufferAndUpdateMetrics: String = { + if (skipPartialAggregateEnabled) { + val numAggSkippedRows = metricTerm(ctx, "partialAggSkipped") + val initExpr = declFunctions.flatMap(f => f.initialValues) + val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) + s""" + |// Create an empty aggregation buffer + |if ($avoidSpillInPartialAggregateTerm) { + | ${unsafeRowKeyCode.code} + | ${emptyBufferKeyCode.code} + | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; + | $numAggSkippedRows.add(1); + |} + |""".stripMargin + } else "" + } + + s""" + |$insertCode + |$createEmptyAggBufferAndUpdateMetrics + |""".stripMargin } val inputAttr = aggregateBufferAttributes ++ inputAttributes @@ -1005,7 +1131,7 @@ case class HashAggregateExec( } val updateRowInHashMap: String = { - if (isFastHashMapEnabled) { + val updateRowInMap = if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => @@ -1080,6 +1206,22 @@ case class HashAggregateExec( } else { updateRowInRegularHashMap } + + def outputRow: String = { + if (skipPartialAggregateEnabled) { + s""" + |if ($avoidSpillInPartialAggregateTerm) { + | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); + |} + |$rowCountTerm = $rowCountTerm + 1; + |""".stripMargin + } else "" + } + + s""" + |$updateRowInMap + |$outputRow + |""".stripMargin } val declareRowBuffer: String = if (isFastHashMapEnabled) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index e1c85823259b1..d7b01ef39a654 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -75,6 +75,8 @@ abstract class HashMapGenerator( | |${generateRowIterator()} | + |${generateNumRows()} + | 
|${generateClose()} |} """.stripMargin @@ -136,6 +138,14 @@ abstract class HashMapGenerator( """.stripMargin } + protected final def generateNumRows(): String = { + s""" + |public int getNumRows() { + | return batch.numRows(); + |} + """.stripMargin + } + protected final def genComputeHash( ctx: CodegenContext, input: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f7396ee2a89c8..54f1746c1ad90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec @@ -51,6 +52,42 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert(df.collect() === Array(Row(9, 4.5))) } + test("Avoid spill in partial aggregation") { + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true"), + (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key -> "2")) { + // Create a DataFrame with duplicate grouping keys + val data = Seq(("James", 1), ("James", 1), ("Phil", 1)) + val aggDF = data.toDF("name", "values").groupBy("name").sum("values") + val partAggNode = aggDF.queryExecution.executedPlan.find { + case h: HashAggregateExec => + val modes = h.aggregateExpressions.map(_.mode) + modes.nonEmpty && modes.forall(_ == Partial) + case _ => false + } + + checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1))) + assert(partAggNode.isDefined, + "No HashAggregate node with partial aggregate expression found") + assert(partAggNode.get.metrics("partialAggSkipped").value == data.size, + "Partial aggregation was not skipped by the partial hash aggregate node") + } + } + + test("Distinct: Partial aggregation should be skipped for " + "HashAggregate nodes performing partial Aggregate operations") { + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true")) { + val aggDF = testData2.select(sumDistinct($"a"), sum($"b")) + val aggNodes = aggDF.queryExecution.executedPlan.collect { + case h: HashAggregateExec => h + } + val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec]) + checkAnswer(aggDF, Row(6, 9)) + assert(baseNodes.size == 1) + assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count()) + assert(other.forall(!_.metrics.contains("partialAggSkipped"))) + } + } + test("Aggregate with grouping keys should be included in WholeStageCodegen") { val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 50652690339a8..8f5646283a9d2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -142,48 +142,52 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("Aggregate metrics: track avg probe") { - // The executed plan looks like: - // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) - // +- Exchange hashpartitioning(a#61, 5) - // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) - // +- Exchange RoundRobinPartitioning(1) - // +- LocalTableScan [a#61] - // - // Assume the execution plan with node id is: - // Wholestage disabled: - // HashAggregate(nodeId = 0) - // Exchange(nodeId = 1) - // HashAggregate(nodeId = 2) - // Exchange (nodeId = 3) - // LocalTableScan(nodeId = 4) - // - // Wholestage enabled: - // WholeStageCodegen(nodeId = 0) - // HashAggregate(nodeId = 1) - // Exchange(nodeId = 2) - // WholeStageCodegen(nodeId = 3) - // HashAggregate(nodeId = 4) - // Exchange(nodeId = 5) - // LocalTableScan(nodeId = 6) - Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(1).groupBy('a).count() - val nodeIds = if (enableWholeStage) { - Set(4L, 1L) - } else { - Set(2L, 0L) - } - val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString - if (!probes.contains("\n")) { - // It's a single metrics value - assert(probes.toDouble > 1.0) + if (spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED)) { + logInfo("Skipping, since skipping of partial aggregation is enabled") + } else { + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) } else { - val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") - // Extract min, med, max from the string and strip off everthing else. - val index = mainValue.indexOf(" (", 0) - mainValue.slice(0, index).split(", ").foreach { - probe => assert(probe.toDouble > 1.0) + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + if (!probes.contains("\n")) { + // It's a single metrics value + assert(probes.toDouble > 1.0) + } else { + val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") + // Extract min, med, max from the string and strip off everything else. 
+ val index = mainValue.indexOf(" (", 0) + mainValue.slice(0, index).split(", ").foreach { + probe => assert(probe.toDouble > 1.0) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index fac981267f4d7..19160c09ea012 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1057,38 +1057,43 @@ class HashAggregationQuerySuite extends AggregationQuerySuite class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> - enableTwoLevelMaps) { - Seq(4, 8).foreach { uaoSize => - UnsafeAlignedOffset.setUaoSize(uaoSize) - (1 to 3).foreach { fallbackStartsAt => - withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> - s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = Dataset.ofRows(spark, actual.logicalPlan) - - QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using HashAggregate with - |controlled fallback (it falls back to bytes to bytes map once it has - |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation - |once it has processed $fallbackStartsAt input rows). - |The query is ${actual.queryExecution} - |$errorMessage + // HashAggregationQueryWithControlledFallbackSuite depends on the fallback ordering and + // assumes that partial aggregation has happened, so disable the flag + // that allows partial aggregation to be skipped. + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "false")) { + Seq("true", "false").foreach { enableTwoLevelMaps => + withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> + enableTwoLevelMaps) { + Seq(4, 8).foreach { uaoSize => + UnsafeAlignedOffset.setUaoSize(uaoSize) + (1 to 3).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> + s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? + val newActual = Dataset.ofRows(spark, actual.logicalPlan) + + QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using HashAggregate with + |controlled fallback (it falls back to bytes to bytes map once it has + |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation + |once it has processed $fallbackStartsAt input rows). + |The query is ${actual.queryExecution} + |$errorMessage """.stripMargin - fail(newErrorMessage) - case None => // Success + fail(newErrorMessage) + case None => // Success + } } } + // reset static uaoSize to avoid affecting other tests + UnsafeAlignedOffset.setUaoSize(0) } - // reset static uaoSize to avoid affect other tests - UnsafeAlignedOffset.setUaoSize(0) } } }
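Reviewer note: a minimal sketch of the skip heuristic that the generated code above evaluates when the hash map fails to allocate an aggregation buffer (plain Scala with illustrative names; these are not the actual codegen terms):

// Sketch only: partial aggregation is skipped once enough rows have been
// consumed and the hash map is barely reducing the input, i.e. the number of
// entries accumulated in the map(s) stays close to the number of rows seen.
def shouldSkipPartialAggregate(
    rowsProcessed: Long, // rowCountTerm: input rows consumed so far
    rowsInMap: Long,     // countTerm: entries accumulated in the hash map(s)
    minNumRows: Long,    // spark.sql.aggregate.skipPartialAggregate.minNumRows
    ratio: Double        // spark.sql.aggregate.skipPartialAggregate.aggregateRatio
): Boolean = {
  rowsProcessed >= minNumRows && rowsInMap.toDouble / rowsProcessed > ratio
}

When this holds, the operator sets the generated avoidPartialAggregate flag, emits each subsequent input row directly with an empty aggregation buffer through the generated result function, and increments the partialAggSkipped metric, instead of spilling the map and falling back to sort-based aggregation.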