From db8a62d98c5354aa6c47b0ddbdcdbe3c8b2510dc Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 17:17:43 -0700
Subject: [PATCH 01/33] Fix: Init commit

---
 .../apache/spark/sql/internal/SQLConf.scala  |  9 ++++
 .../aggregate/HashAggregateExec.scala        | 48 +++++++++++++------
 .../execution/WholeStageCodegenSuite.scala   | 21 ++++++++
 3 files changed, 64 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6bbeb2de7538c..332e0773b2ad6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2196,6 +2196,13 @@ object SQLConf {
     .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
     .createWithDefault(16)
 
+  val SPILL_PARTIAL_AGGREGATE_DISABLED =
+    buildConf("spark.sql.aggregate.spill.partialaggregate.disabled")
+      .internal()
+      .doc("Avoid sort/spill to disk during partial aggregation")
+      .booleanConf
+      .createWithDefault(false)
+
   val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
     .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
       "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")
@@ -2922,6 +2929,8 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
+  def spillInPartialAggregationDisabled: Boolean = getConf(SPILL_PARTIAL_AGGREGATE_DISABLED)
+
   def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)
 
   def uiExplainMode: String = getConf(UI_EXPLAIN_MODE)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9c07ea10a87e7..ad750e7d53b22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -409,6 +409,9 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
+  private var avoidSpillInPartialAggregateTerm: String = _
+  private var outputFunc: String = _
+
   // whether a vectorized hashmap is used instead
   // we have decided to always use the row-based hashmap,
   // but the vectorized hashmap can still be switched on for testing and benchmarking purposes.
@@ -680,6 +683,8 @@ case class HashAggregateExec(
 
   private def doProduceWithKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
+    avoidSpillInPartialAggregateTerm = ctx.
+      addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate")
     if (sqlContext.conf.enableTwoLevelAggMap) {
       enableTwoLevelHashMap(ctx)
     } else if (sqlContext.conf.enableVectorizedHashMap) {
@@ -750,6 +755,7 @@ case class HashAggregateExec(
       finishRegularHashMap
     }
 
+    outputFunc = generateResultFunction(ctx)
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
          |private void $doAgg() throws java.io.IOException {
@@ -761,7 +767,6 @@ case class HashAggregateExec(
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputFunc = generateResultFunction(ctx)
 
     def outputFromFastHashMap: String = {
       if (isFastHashMapEnabled) {
@@ -879,6 +884,9 @@ case class HashAggregateExec(
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
 
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
+
     val findOrInsertRegularHashMap: String =
       s"""
         |// generate grouping key
         |${unsafeRowKeyCode.code}
         |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
         |if ($checkFallbackForBytesToBytesMap) {
         |  // try to get the buffer from hash map
         |  $unsafeRowBuffer =
         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
         |}
@@ -892,19 +900,25 @@ case class HashAggregateExec(
         |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
         |// aggregation after processing all input rows.
         |if ($unsafeRowBuffer == null) {
-        |  if ($sorterTerm == null) {
-        |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+        |  // If sort/spill to disk is disabled, do not create the sorter
+        |  if (!$avoidSpillInPartialAggregateTerm && $spillInPartialAggregateDisabled) {
+        |    $avoidSpillInPartialAggregateTerm = true;
+        |    $unsafeRowBuffer = (UnsafeRow) $thisPlan.getEmptyAggregationBuffer();
         |  } else {
-        |    $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-        |  }
-        |  $resetCounter
-        |  // the hash map had be spilled, it should have enough memory now,
-        |  // try to allocate buffer again.
-        |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
-        |    $unsafeRowKeys, $unsafeRowKeyHash);
-        |  if ($unsafeRowBuffer == null) {
-        |    // failed to allocate the first page
-        |    throw new $oomeClassName("No enough memory for aggregation");
+        |    if ($sorterTerm == null) {
+        |      $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+        |    } else {
+        |      $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+        |    }
+        |    $resetCounter
+        |    // the hash map had be spilled, it should have enough memory now,
+        |    // try to allocate buffer again.
+        |    $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+        |      $unsafeRowKeys, $unsafeRowKeyHash);
+        |    if ($unsafeRowBuffer == null) {
+        |      // failed to allocate the first page
+        |      throw new $oomeClassName("No enough memory for aggregation");
+        |    }
         |  }
         |}
       """.stripMargin
@@ -1005,7 +1019,7 @@ case class HashAggregateExec(
     }
 
     val updateRowInHashMap: String = {
-      if (isFastHashMapEnabled) {
+      val updateRowinMap = if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
           val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
@@ -1080,6 +1094,12 @@ case class HashAggregateExec(
       } else {
         updateRowInRegularHashMap
       }
+      s"""
+         |$updateRowinMap
+         |if ($avoidSpillInPartialAggregateTerm) {
+         |  $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer);
+         |}
+         |""".stripMargin
     }
 
     val declareRowBuffer: String = if (isFastHashMapEnabled) {

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f7396ee2a89c8..074dec8475ba8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAnd
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.functions._
@@ -165,6 +166,26 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
+  test("SPARK-: Avoid spill in partial aggregation " +
+    "when spark.sql.aggregate.spill.partialaggregate.disabled is set") {
+    withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true"),
+      (SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "false")) {
+      // Create Dataframes
+      val arrayData = Seq(("James", 1), ("James", 1), ("Phil", 1))
+      val srcDF = arrayData.toDF("name", "values")
+      val aggDF = srcDF.groupBy("name").sum("values")
+
+      val results = aggDF.collect()
+      val hashAggNode = aggDF.queryExecution.executedPlan.find {
+        case h: HashAggregateExec => !h.child.isInstanceOf[ShuffleExchangeExec]
+        case _ => false
+      }
+      assert(hashAggNode.isDefined)
+      assert(hashAggNode.get.metrics("spillSize").value == 0)
+      assert(results === Seq(Row("James", 2), Row("Phil", 1)))
+    }
+  }
+
   def genGroupByCode(caseNum: Int): CodeAndComment = {
     val caseExp = (1 to caseNum).map { i =>
       s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i"
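Note (illustrative, not part of the patch series): on a build that includes PATCH 01, the new internal flag is an ordinary session conf, so it can be exercised from a spark-shell; the session name `spark` and the sample frame below are assumptions for the sketch:

    // Assumes an active SparkSession named `spark` on a patched build.
    import spark.implicits._
    spark.conf.set("spark.sql.aggregate.spill.partialaggregate.disabled", "true")

    val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")
    // The map-side HashAggregateExec now emits rows directly instead of
    // creating an external sorter when its hash map cannot grow.
    df.groupBy("k").sum("v").collect()
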
From 9a59925961ad7423c515f0bb1e561ded3ff894b7 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 17:22:55 -0700
Subject: [PATCH 02/33] Fix: Fix UT name

---
 .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 074dec8475ba8..933d8972bb9d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -166,7 +166,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
-  test("SPARK-: Avoid spill in partial aggregation " +
+  test("SPARK-31973: Avoid spill in partial aggregation " +
     "when spark.sql.aggregate.spill.partialaggregate.disabled is set") {
     withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true"),
       (SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "false")) {

From feacdcf7819b19f0e5d855314798c4244421f9f3 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 18:07:38 -0700
Subject: [PATCH 03/33] Fix: Fix codegen

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index ad750e7d53b22..e1a032f81bf43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -835,9 +835,11 @@ case class HashAggregateExec(
 
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
+    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
     s"""
       |if (!$initAgg) {
       |  $initAgg = true;
+      |  $avoidSpillInPartialAggregateTerm = $spillInPartialAggregateDisabled;
      |  $createFastHashMap
      |  $hashMapTerm = $thisPlan.createHashMap();
      |  long $beforeAgg = System.nanoTime();
@@ -885,7 +887,6 @@ case class HashAggregateExec(
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
 
     val thisPlan = ctx.addReferenceObj("plan", this)
-    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
 
     val findOrInsertRegularHashMap: String =
       s"""
@@ -900,9 +901,8 @@ case class HashAggregateExec(
        |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
        |// aggregation after processing all input rows.
        |if ($unsafeRowBuffer == null) {
-       |  // If sort/spill to disk is disabled, do not create the sorter
-       |  if (!$avoidSpillInPartialAggregateTerm && $spillInPartialAggregateDisabled) {
-       |    $avoidSpillInPartialAggregateTerm = true;
+       |  if ($avoidSpillInPartialAggregateTerm) {
+       |    // If sort/spill to disk is disabled, do not sort/spill to disk
        |    $unsafeRowBuffer = (UnsafeRow) $thisPlan.getEmptyAggregationBuffer();
        |  } else {
        |    if ($sorterTerm == null) {

From ab98ea4f934b0d457915e8af60c4eedbb4a7cbe6 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 18:14:14 -0700
Subject: [PATCH 04/33] Revert "Fix: Fix codegen"

This reverts commit 086ba42847f164a2200be13af48f3cce9dda794f.

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index e1a032f81bf43..ad750e7d53b22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -835,11 +835,9 @@ case class HashAggregateExec(
 
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
-    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
     s"""
       |if (!$initAgg) {
       |  $initAgg = true;
-      |  $avoidSpillInPartialAggregateTerm = $spillInPartialAggregateDisabled;
      |  $createFastHashMap
      |  $hashMapTerm = $thisPlan.createHashMap();
      |  long $beforeAgg = System.nanoTime();
@@ -887,6 +885,7 @@ case class HashAggregateExec(
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
 
     val thisPlan = ctx.addReferenceObj("plan", this)
+    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
 
     val findOrInsertRegularHashMap: String =
       s"""
@@ -900,8 +900,9 @@ case class HashAggregateExec(
        |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
        |// aggregation after processing all input rows.
        |if ($unsafeRowBuffer == null) {
-       |  if ($avoidSpillInPartialAggregateTerm) {
-       |    // If sort/spill to disk is disabled, do not sort/spill to disk
+       |  // If sort/spill to disk is disabled, do not create the sorter
+       |  if (!$avoidSpillInPartialAggregateTerm && $spillInPartialAggregateDisabled) {
+       |    $avoidSpillInPartialAggregateTerm = true;
        |    $unsafeRowBuffer = (UnsafeRow) $thisPlan.getEmptyAggregationBuffer();
        |  } else {
        |    if ($sorterTerm == null) {
From 2e102d1defe9d257b8843747ed853b852d8f465a Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 18:22:50 -0700
Subject: [PATCH 05/33] Fix: Fix codegen logic

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index ad750e7d53b22..2e4658d53f2d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -900,10 +900,10 @@ case class HashAggregateExec(
        |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
        |// aggregation after processing all input rows.
        |if ($unsafeRowBuffer == null) {
-       |  // If sort/spill to disk is disabled, do not create the sorter
+       |  // If sort/spill to disk is disabled, nothing is done.
+       |  // Aggregation buffer is created later
       |  if (!$avoidSpillInPartialAggregateTerm && $spillInPartialAggregateDisabled) {
       |    $avoidSpillInPartialAggregateTerm = true;
-       |    $unsafeRowBuffer = (UnsafeRow) $thisPlan.getEmptyAggregationBuffer();
       |  } else {
       |    if ($sorterTerm == null) {
       |      $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
@@ -921,6 +921,10 @@ case class HashAggregateExec(
        |    }
        |  }
        |}
+       |// Create an empty aggregation buffer
+       |if ($avoidSpillInPartialAggregateTerm) {
+       |  $unsafeRowBuffer = (UnsafeRow) $thisPlan.getEmptyAggregationBuffer();
+       |}
      """.stripMargin
 
     val findOrInsertHashMap: String = {

From 5fa601b44865690e9f34c7df034318bf544cff34 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 11 Jun 2020 18:26:07 -0700
Subject: [PATCH 06/33] Fix: Fix codegen logic

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 2e4658d53f2d9..729a6c1616d42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -899,10 +899,10 @@ case class HashAggregateExec(
        |}
        |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
        |// aggregation after processing all input rows.
-       |if ($unsafeRowBuffer == null) {
+       |if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
        |  // If sort/spill to disk is disabled, nothing is done.
        |  // Aggregation buffer is created later
-       |  if (!$avoidSpillInPartialAggregateTerm && $spillInPartialAggregateDisabled) {
+       |  if ($spillInPartialAggregateDisabled) {
       |    $avoidSpillInPartialAggregateTerm = true;
       |  } else {
       |    if ($sorterTerm == null) {

From 220eaed99368ec61df5ec9fa0e673184fbb5e9dd Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 12:44:13 -0700
Subject: [PATCH 07/33] Fix: Fix codegen logic

---
 .../sql/execution/aggregate/AggUtils.scala   |  4 +
 .../aggregate/HashAggregateExec.scala        | 96 ++++++++++++-------
 .../execution/WholeStageCodegenSuite.scala   | 44 ++++-----
 3 files changed, 88 insertions(+), 56 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 56a287d4d0279..0bf4f6c3b06e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -353,4 +353,8 @@ object AggUtils {
 
     finalAndCompleteAggregate :: Nil
   }
+
+  def areAggExpressionsPartial(exprs: Seq[AggregateExpression]): Boolean = {
+    exprs.forall(e => e.mode == Partial || e.mode == PartialMerge)
+  }
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 729a6c1616d42..4ba8a8d202d56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -63,6 +63,8 @@ case class HashAggregateExec(
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
+  override def needStopCheck: Boolean = sqlContext.conf.spillInPartialAggregationDisabled
+
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
@@ -72,6 +74,8 @@ case class HashAggregateExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
+    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext, "Num records" +
+      " skipped partial aggregation skipped"),
     "avgHashProbe" ->
       SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
 
@@ -252,6 +256,7 @@ case class HashAggregateExec(
     s"""
        |while (!$initAgg) {
        |  $initAgg = true;
+       |  $avoidSpillInPartialAggregateTerm = ${Utils.isTesting} && $isPartial;
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
@@ -335,6 +340,7 @@ case class HashAggregateExec(
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
+
     // To individually generate code for each aggregate function, an element in `updateExprs` holds
     // all the expressions for the buffer of an aggregation function.
     val updateExprs = aggregateExpressions.map { e =>
@@ -409,7 +415,9 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
+  private val isPartial = AggUtils.areAggExpressionsPartial(aggregateExpressions)
   private var avoidSpillInPartialAggregateTerm: String = _
+  private var childrenConsumed: String = _
   private var outputFunc: String = _
 
   // whether a vectorized hashmap is used instead
@@ -685,6 +693,8 @@ case class HashAggregateExec(
     val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg")
     avoidSpillInPartialAggregateTerm = ctx.
       addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate")
+    childrenConsumed = ctx.
+      addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed")
     if (sqlContext.conf.enableTwoLevelAggMap) {
       enableTwoLevelHashMap(ctx)
     } else if (sqlContext.conf.enableVectorizedHashMap) {
@@ -760,6 +770,7 @@ case class HashAggregateExec(
       s"""
          |private void $doAgg() throws java.io.IOException {
          |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |  $childrenConsumed = true;
         |  $finishHashMap
         |}
       """.stripMargin)
@@ -838,11 +849,17 @@ case class HashAggregateExec(
     s"""
        |if (!$initAgg) {
        |  $initAgg = true;
+       |  $avoidSpillInPartialAggregateTerm = ${Utils.isTesting} && $isPartial;
       |  $createFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);
+       |  $shouldStopCheckCode;
+       |}
+       |if (!$childrenConsumed) {
+       |  $doAggFuncName();
+       |  $shouldStopCheckCode;
       |}
       |// output the result
       |$outputFromFastHashMap
@@ -889,50 +906,49 @@ case class HashAggregateExec(
     val findOrInsertRegularHashMap: String =
       s"""
-        |// generate grouping key
-        |${unsafeRowKeyCode.code}
-        |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
-        |if ($checkFallbackForBytesToBytesMap) {
-        |  // try to get the buffer from hash map
-        |  $unsafeRowBuffer =
-        |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
-        |}
-        |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
-        |// aggregation after processing all input rows.
-        |if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
-        |  // If sort/spill to disk is disabled, nothing is done.
-        |  // Aggregation buffer is created later
-        |  if ($spillInPartialAggregateDisabled) {
-        |    $avoidSpillInPartialAggregateTerm = true;
-        |  } else {
-        |    if ($sorterTerm == null) {
-        |      $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+        |if (!$avoidSpillInPartialAggregateTerm) {
+        |  // generate grouping key
+        |  ${unsafeRowKeyCode.code}
+        |  int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode();
+        |  if ($checkFallbackForBytesToBytesMap) {
+        |    // try to get the buffer from hash map
+        |    $unsafeRowBuffer =
+        |      $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash);
+        |  }
+        |  // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
+        |  // aggregation after processing all input rows.
+        |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
+        |    // If sort/spill to disk is disabled, nothing is done.
+        |    // Aggregation buffer is created later
+        |    if ($spillInPartialAggregateDisabled && $isPartial) {
+        |      $avoidSpillInPartialAggregateTerm = true;
         |    } else {
-        |      $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-        |    }
-        |    $resetCounter
-        |    // the hash map had be spilled, it should have enough memory now,
-        |    // try to allocate buffer again.
-        |    $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
-        |      $unsafeRowKeys, $unsafeRowKeyHash);
-        |    if ($unsafeRowBuffer == null) {
-        |      // failed to allocate the first page
-        |      throw new $oomeClassName("No enough memory for aggregation");
+        |      if ($sorterTerm == null) {
+        |        $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+        |      } else {
+        |        $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+        |      }
+        |      $resetCounter
+        |      // the hash map had be spilled, it should have enough memory now,
+        |      // try to allocate buffer again.
+        |      $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+        |        $unsafeRowKeys, $unsafeRowKeyHash);
+        |      if ($unsafeRowBuffer == null) {
+        |        // failed to allocate the first page
+        |        throw new $oomeClassName("No enough memory for aggregation");
+        |      }
         |    }
         |  }
         |}
       """.stripMargin
 
+    val partTerm = metricTerm(ctx, "partialAggSkipped")
     val findOrInsertHashMap: String = {
-      if (isFastHashMapEnabled) {
+      val insertCode = if (isFastHashMapEnabled) {
         // If fast hash map is on, we first generate code to probe and update the fast hash map.
         // If the probe is successful the corresponding fast row buffer will hold the mutable row.
         s"""
-           |if ($checkFallbackForGeneratedHashMap) {
+           |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) {
           |  ${fastRowKeys.map(_.code).mkString("\n")}
           |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
           |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
@@ -947,6 +963,18 @@ case class HashAggregateExec(
       } else {
         findOrInsertRegularHashMap
       }
+      val initExpr = declFunctions.flatMap(f => f.initialValues)
+      val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr)
+      s"""
+         |$insertCode
+         |// Create an empty aggregation buffer
+         |if ($avoidSpillInPartialAggregateTerm) {
+         |  ${unsafeRowKeyCode.code}
+         |  ${emptyBufferKeyCode.code}
+         |  $unsafeRowBuffer = ${emptyBufferKeyCode.value};
+         |  $partTerm.add(1);
+         |}
+         |""".stripMargin
     }
 
     val inputAttr = aggregateBufferAttributes ++ inputAttributes

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 933d8972bb9d3..07a6dbcd74e0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.functions._
@@ -52,6 +51,27 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(df.collect() === Array(Row(9, 4.5)))
   }
 
+  test(s"Avoid spill in partial aggregation" ) {
+    withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true")) {
+      // Create Dataframes
+      val arrayData = Seq(("James", 1), ("James", 1), ("Phil", 1))
+      val srcDF = arrayData.toDF("name", "values")
+      val aggDF = srcDF.groupBy("name").sum("values")
+      val a = aggDF.queryExecution.debug.codegenToSeq()
+      val partAggNode = aggDF.queryExecution.executedPlan.find {
+        case h: HashAggregateExec
+          if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
+        case _ => false
+      }
+      checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
+      assert(partAggNode.isDefined,
+        "No HashAggregate node with partial aggregate expression found")
+      assert(partAggNode.get.metrics("partialAggSkipped").value == 3,
+        "Partial aggregation got triggrered in partial hash aggregate node")
+    }
+  }
+
+
   test("Aggregate with grouping keys should be included in WholeStageCodegen") {
     val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)
     val plan = df.queryExecution.executedPlan
@@ -166,26 +186,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
-  test("SPARK-31973: Avoid spill in partial aggregation " +
-    "when spark.sql.aggregate.spill.partialaggregate.disabled is set") {
-    withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true"),
-      (SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key, "false")) {
-      // Create Dataframes
-      val arrayData = Seq(("James", 1), ("James", 1), ("Phil", 1))
-      val srcDF = arrayData.toDF("name", "values")
-      val aggDF = srcDF.groupBy("name").sum("values")
-
-      val results = aggDF.collect()
-      val hashAggNode = aggDF.queryExecution.executedPlan.find {
-        case h: HashAggregateExec => !h.child.isInstanceOf[ShuffleExchangeExec]
-        case _ => false
-      }
-      assert(hashAggNode.isDefined)
-      assert(hashAggNode.get.metrics("spillSize").value == 0)
-      assert(results === Seq(Row("James", 2), Row("Phil", 1)))
-    }
-  }
-
   def genGroupByCode(caseNum: Int): CodeAndComment = {
     val caseExp = (1 to caseNum).map { i =>
       s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i"

From 452b6328eba3bea4a23b7386e5cce6ae9ccf099b Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 13:01:58 -0700
Subject: [PATCH 08/33] Fix: clean up

---
 .../spark/sql/execution/WholeStageCodegenSuite.scala | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 07a6dbcd74e0e..a30be3d1eb889 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -54,10 +54,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   test(s"Avoid spill in partial aggregation" ) {
     withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true")) {
       // Create Dataframes
-      val arrayData = Seq(("James", 1), ("James", 1), ("Phil", 1))
-      val srcDF = arrayData.toDF("name", "values")
-      val aggDF = srcDF.groupBy("name").sum("values")
-      val a = aggDF.queryExecution.debug.codegenToSeq()
+      val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
+      val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
       val partAggNode = aggDF.queryExecution.executedPlan.find {
         case h: HashAggregateExec
           if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
        case _ => false
      }
      checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
      assert(partAggNode.isDefined,
        "No HashAggregate node with partial aggregate expression found")
-      assert(partAggNode.get.metrics("partialAggSkipped").value == 3,
+      assert(partAggNode.get.metrics("partialAggSkipped").value == data.size,
         "Partial aggregation got triggrered in partial hash aggregate node")
     }
   }
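Note (illustrative, not part of the patch series): the `partialAggSkipped` metric introduced in PATCH 07 can be read straight off the map-side aggregate node, mirroring what the test above asserts; the session `spark` and the sample data are assumptions for the sketch:

    import org.apache.spark.sql.catalyst.expressions.aggregate.Partial
    import org.apache.spark.sql.execution.aggregate.HashAggregateExec
    import spark.implicits._

    val aggDF = Seq(("James", 1), ("James", 1), ("Phil", 1))
      .toDF("name", "values").groupBy("name").sum("values")
    aggDF.collect()  // run the query so the metric gets populated

    // Locate the partial (map-side) HashAggregateExec in the executed plan.
    val partialNode = aggDF.queryExecution.executedPlan.collectFirst {
      case h: HashAggregateExec if h.aggregateExpressions.forall(_.mode == Partial) => h
    }
    // Equals the input row count when every record skipped partial aggregation.
    partialNode.foreach(h => println(h.metrics("partialAggSkipped").value))
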
From 68dd5a38f1c9138dd811bea9bce138c2f47b7f3c Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 13:05:52 -0700
Subject: [PATCH 09/33] Fix: remove partialmerge

---
 .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0bf4f6c3b06e8..90ee0602eb3a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -355,6 +355,6 @@ object AggUtils {
   }
 
   def areAggExpressionsPartial(exprs: Seq[AggregateExpression]): Boolean = {
-    exprs.forall(e => e.mode == Partial || e.mode == PartialMerge)
+    exprs.forall(e => e.mode == Partial)
   }
 }

From 692fd1be475188c030a1eb7eae66232b0b7f38d9 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 13:08:05 -0700
Subject: [PATCH 10/33] Fix: fix typo, remove blank lines

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala      | 1 -
 .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4ba8a8d202d56..c5dcb4d1f6f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -340,7 +340,6 @@ case class HashAggregateExec(
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ inputAttributes
-
     // To individually generate code for each aggregate function, an element in `updateExprs` holds
     // all the expressions for the buffer of an aggregation function.
     val updateExprs = aggregateExpressions.map { e =>

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index a30be3d1eb889..f378f8f9993ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -65,7 +65,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
       assert(partAggNode.isDefined,
         "No HashAggregate node with partial aggregate expression found")
       assert(partAggNode.get.metrics("partialAggSkipped").value == data.size,
-        "Partial aggregation got triggrered in partial hash aggregate node")
+        "Partial aggregation got triggered in partial hash aggregate node")
     }
   }

From f1b6ac13508ccd283cc8f1e72368ec802ffec305 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 22:15:38 -0700
Subject: [PATCH 11/33] Fix: Fix UT attempt

---
 .../apache/spark/sql/execution/aggregate/HashAggregateExec.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index c5dcb4d1f6f18..f5b5f2b1cf442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -256,7 +256,6 @@ case class HashAggregateExec(
     s"""
        |while (!$initAgg) {
        |  $initAgg = true;
-       |  $avoidSpillInPartialAggregateTerm = ${Utils.isTesting} && $isPartial;
       |  long $beforeAgg = System.nanoTime();
       |  $doAggFuncName();
       |  $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS);

From 05c891f4961b60260eff4b8a69680a6cbdc0f163 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Wed, 17 Jun 2020 22:21:36 -0700
Subject: [PATCH 12/33] Fix: Address review comments

---
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala  | 6 +++---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala   | 4 ++--
 .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 332e0773b2ad6..4f01ae8b1c96f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2196,8 +2196,8 @@ object SQLConf {
     .checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
     .createWithDefault(16)
 
-  val SPILL_PARTIAL_AGGREGATE_DISABLED =
-    buildConf("spark.sql.aggregate.spill.partialaggregate.disabled")
+  val SKIP_PARTIAL_AGGREGATE_ENABLED =
+    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
       .internal()
       .doc("Avoid sort/spill to disk during partial aggregation")
       .booleanConf
       .createWithDefault(false)
@@ -2929,7 +2929,7 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
-  def spillInPartialAggregationDisabled: Boolean = getConf(SPILL_PARTIAL_AGGREGATE_DISABLED)
+  def spillInPartialAggregationDisabled: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)
 
   def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index f5b5f2b1cf442..2130b9724c50a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -74,8 +74,8 @@ case class HashAggregateExec(
     "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
     "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
-    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext, "Num records" +
-      " skipped partial aggregation skipped"),
+    "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
+      "number of skipped records for partial aggregates"),
     "avgHashProbe" ->
       SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f378f8f9993ce..105ca4a74bcbc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -52,7 +52,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
   test(s"Avoid spill in partial aggregation" ) {
-    withSQLConf((SQLConf.SPILL_PARTIAL_AGGREGATE_DISABLED.key, "true")) {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
       // Create Dataframes
       val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
       val aggDF = data.toDF("name", "values").groupBy("name").sum("values")

From dd3c56a77ba1c3a72dd53ea513b448432d635505 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 19 Jun 2020 11:08:28 -0700
Subject: [PATCH 13/33] Fix: UT fixes, refactoring

---
 .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++--
 .../org/apache/spark/sql/execution/aggregate/AggUtils.scala | 4 ++--
 .../spark/sql/execution/aggregate/HashAggregateExec.scala  | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4f01ae8b1c96f..6e1bae72835b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2197,7 +2197,7 @@ object SQLConf {
     .createWithDefault(16)
 
   val SKIP_PARTIAL_AGGREGATE_ENABLED =
-    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
+      buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
       .internal()
       .doc("Avoid sort/spill to disk during partial aggregation")
       .booleanConf
       .createWithDefault(false)
@@ -2929,7 +2929,7 @@ class SQLConf extends Serializable with Logging {
 
   def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
 
-  def spillInPartialAggregationDisabled: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)
+  def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)
 
   def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 90ee0602eb3a9..90403cee4887f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -354,7 +354,7 @@ object AggUtils {
     finalAndCompleteAggregate :: Nil
   }
 
-  def areAggExpressionsPartial(exprs: Seq[AggregateExpression]): Boolean = {
-    exprs.forall(e => e.mode == Partial)
+  def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = {
+    modes.nonEmpty && modes.forall(_ == Partial)
   }
 }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 2130b9724c50a..8c07fe75217c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -63,7 +63,7 @@ case class HashAggregateExec(
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
-  override def needStopCheck: Boolean = sqlContext.conf.spillInPartialAggregationDisabled
+  override def needStopCheck: Boolean = sqlContext.conf.skipPartialAggregate
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -413,7 +413,7 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
-  private val isPartial = AggUtils.areAggExpressionsPartial(aggregateExpressions)
+  private val isPartial = AggUtils.areAggExpressionsPartial(modes)
   private var avoidSpillInPartialAggregateTerm: String = _
   private var childrenConsumed: String = _
   private var outputFunc: String = _
@@ -900,7 +900,7 @@ case class HashAggregateExec(
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
 
    val thisPlan = ctx.addReferenceObj("plan", this)
-    val spillInPartialAggregateDisabled = sqlContext.conf.spillInPartialAggregationDisabled
+    val spillInPartialAggregateDisabled = sqlContext.conf.skipPartialAggregate
 
     val findOrInsertRegularHashMap: String =

From cb8b922e64549240ccc5f9929d2afe37cadd39b5 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 19 Jun 2020 11:52:30 -0700
Subject: [PATCH 14/33] Fix: fix indent

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6e1bae72835b8..a8765a406ec02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2197,7 +2197,7 @@ object SQLConf {
     .createWithDefault(16)
 
   val SKIP_PARTIAL_AGGREGATE_ENABLED =
-      buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
+    buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
       .internal()
       .doc("Avoid sort/spill to disk during partial aggregation")
       .booleanConf
       .createWithDefault(false)
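Note (illustrative, not part of the patch series): after PATCH 09 and PATCH 13 the helper keys off `AggregateMode` alone, and a node only counts as map-side when every function runs in `Partial` mode; restated standalone for clarity:

    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Partial, PartialMerge}

    // Same logic as AggUtils.areAggExpressionsPartial as of PATCH 13.
    def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean =
      modes.nonEmpty && modes.forall(_ == Partial)

    areAggExpressionsPartial(Seq(Partial))                // true
    areAggExpressionsPartial(Seq(Partial, PartialMerge))  // false: PATCH 09 dropped PartialMerge
    areAggExpressionsPartial(Nil)                         // false: guards nodes with no aggregate functions
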
From 7952aa707a91bac1df9e84ad17acacd006c7e850 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 19 Jun 2020 12:01:56 -0700
Subject: [PATCH 15/33] UT: Add more tests

---
 .../execution/WholeStageCodegenSuite.scala   | 31 +++++++++++++++++--
 1 file changed, 29 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 105ca4a74bcbc..087bb0a1517d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -57,8 +57,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
       val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
       val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
       val partAggNode = aggDF.queryExecution.executedPlan.find {
-        case h: HashAggregateExec
-          if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
+        case h: HashAggregateExec =>
+          AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode))
         case _ => false
       }
       checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
@@ -69,6 +69,33 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
+  test(s"Partial aggregation should not happen when no Aggregate expr" ) {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
+      checkAnswer(aggDF, Row(6))
+      assert(aggNodes.nonEmpty)
+      Thread.sleep(1000000)
+      assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }
+
+  test(s"Distinct: Partial aggregation should happen for" +
+    s" HashAggregate nodes performing partial Aggregate operations " ) {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
+      val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec])
+      checkAnswer(aggDF, Row(6, 9))
+      assert(baseNodes.size == 1 )
+      assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
+      assert(other.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }
 
   test("Aggregate with grouping keys should be included in WholeStageCodegen") {
     val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)

From 56c95e242126d7aacdb4862adc5e094b4e29561b Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 19 Jun 2020 16:16:00 -0700
Subject: [PATCH 16/33] Fix UT attempt

---
 .../sql/execution/aggregate/HashAggregateExec.scala | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8c07fe75217c5..ce08855f53963 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -63,7 +63,7 @@ case class HashAggregateExec(
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
-  override def needStopCheck: Boolean = sqlContext.conf.skipPartialAggregate
+  override def needStopCheck: Boolean = skipPartialAggregate
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -415,6 +415,7 @@ case class HashAggregateExec(
   private val isPartial = AggUtils.areAggExpressionsPartial(modes)
   private var avoidSpillInPartialAggregateTerm: String = _
+  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate
   private var childrenConsumed: String = _
   private var outputFunc: String = _
@@ -847,7 +848,8 @@ case class HashAggregateExec(
     s"""
        |if (!$initAgg) {
        |  $initAgg = true;
-       |  $avoidSpillInPartialAggregateTerm = ${Utils.isTesting} && $isPartial;
+       |  $avoidSpillInPartialAggregateTerm =
+       |    ${Utils.isTesting} && $isPartial && $skipPartialAggregate;
       |  $createFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
@@ -899,8 +901,7 @@ case class HashAggregateExec(
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
-
-
     val findOrInsertRegularHashMap: String =
       s"""
       |if (!$avoidSpillInPartialAggregateTerm) {
@@ -918,7 +919,7 @@ case class HashAggregateExec(
        |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
        |    // If sort/spill to disk is disabled, nothing is done.
        |    // Aggregation buffer is created later
-       |    if ($spillInPartialAggregateDisabled && $isPartial) {
+       |    if ($skipPartialAggregate && $isPartial) {
       |      $avoidSpillInPartialAggregateTerm = true;
       |    } else {
       |      if ($sorterTerm == null) {

From 43237ba31354f1816df7b98a881e52863d1c65a9 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Sat, 20 Jun 2020 11:37:46 -0700
Subject: [PATCH 17/33] Enabling the conf to run all tests with the feature

---
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a8765a406ec02..ab9fb6654aec0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2201,7 +2201,7 @@ object SQLConf {
       .internal()
       .doc("Avoid sort/spill to disk during partial aggregation")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)
 
   val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
     .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
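Note (illustrative, not part of the patch series): with PATCH 17 flipping the default to true so the entire suite exercises the feature, tests that depend on partial aggregation actually running have to opt out explicitly; the body below is a placeholder, but the pattern is the one PATCH 20 applies to the controlled-fallback suite:

    // Inside a suite that mixes in SQLTestUtils / SQLHelper.
    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "false")) {
      // assertions that assume map-side aggregation ran (and possibly spilled)
    }
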
From 99c1d2226d170f789dfa534ffc658f7fc430c38d Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Tue, 23 Jun 2020 19:54:23 -0700
Subject: [PATCH 18/33] Unit test fix attempt

---
 .../execution/aggregate/HashAggregateExec.scala | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index ce08855f53963..6d4f31601a7e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -413,9 +413,9 @@ case class HashAggregateExec(
   private var fastHashMapTerm: String = _
   private var isFastHashMapEnabled: Boolean = false
 
-  private val isPartial = AggUtils.areAggExpressionsPartial(modes)
   private var avoidSpillInPartialAggregateTerm: String = _
-  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate
+  private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate &&
+    AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty
   private var childrenConsumed: String = _
   private var outputFunc: String = _
 
@@ -638,6 +638,8 @@ case class HashAggregateExec(
          |${consume(ctx, resultVars)}
        """.stripMargin
     }
+
+
     ctx.addNewFunction(funcName,
       s"""
          |private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
@@ -848,8 +850,8 @@ case class HashAggregateExec(
     s"""
        |if (!$initAgg) {
        |  $initAgg = true;
-       |  $avoidSpillInPartialAggregateTerm =
-       |    ${Utils.isTesting} && $isPartial && $skipPartialAggregate;
+       |  $avoidSpillInPartialAggregateTerm =
+       |    ${Utils.isTesting} && $skipPartialAggregate;
       |  $createFastHashMap
       |  $hashMapTerm = $thisPlan.createHashMap();
       |  long $beforeAgg = System.nanoTime();
@@ -900,9 +902,6 @@ case class HashAggregateExec(
     }
 
     val oomeClassName = classOf[SparkOutOfMemoryError].getName
-
-
-
     val findOrInsertRegularHashMap: String =
       s"""
       |if (!$avoidSpillInPartialAggregateTerm) {
@@ -918,7 +917,7 @@ case class HashAggregateExec(
        |  if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) {
        |    // If sort/spill to disk is disabled, nothing is done.
        |    // Aggregation buffer is created later
-       |    if ($skipPartialAggregate && $isPartial) {
+       |    if ($skipPartialAggregate) {
       |      $avoidSpillInPartialAggregateTerm = true;
       |    } else {
       |      if ($sorterTerm == null) {

From d2873a3fb26280f2a81bd4180debf538c707484d Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Tue, 23 Jun 2020 23:26:53 -0700
Subject: [PATCH 19/33] UT fix attempt

---
 .../execution/metric/SQLMetricsSuite.scala   | 87 ++++++++++---------
 1 file changed, 46 insertions(+), 41 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 50652690339a8..916772ee5991d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -142,48 +142,53 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
   }
 
   test("Aggregate metrics: track avg probe") {
-    // The executed plan looks like:
-    // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L])
-    // +- Exchange hashpartitioning(a#61, 5)
-    //    +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L])
-    //       +- Exchange RoundRobinPartitioning(1)
-    //          +- LocalTableScan [a#61]
-    //
-    // Assume the execution plan with node id is:
-    // Wholestage disabled:
-    // HashAggregate(nodeId = 0)
-    //   Exchange(nodeId = 1)
-    //     HashAggregate(nodeId = 2)
-    //       Exchange (nodeId = 3)
-    //         LocalTableScan(nodeId = 4)
-    //
-    // Wholestage enabled:
-    // WholeStageCodegen(nodeId = 0)
-    //   HashAggregate(nodeId = 1)
-    //     Exchange(nodeId = 2)
-    //       WholeStageCodegen(nodeId = 3)
-    //         HashAggregate(nodeId = 4)
-    //           Exchange(nodeId = 5)
-    //             LocalTableScan(nodeId = 6)
-    Seq(true, false).foreach { enableWholeStage =>
-      val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
-      val nodeIds = if (enableWholeStage) {
-        Set(4L, 1L)
-      } else {
-        Set(2L, 0L)
-      }
-      val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
-      nodeIds.foreach { nodeId =>
-        val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString
-        if (!probes.contains("\n")) {
-          // It's a single metrics value
-          assert(probes.toDouble > 1.0)
+    val skipPartialAgg = spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED)
+    if (skipPartialAgg) {
+      logInfo("Skipping, since partial Aggregation is disabled")
+    } else {
+      // The executed plan looks like:
+      // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L])
+      // +- Exchange hashpartitioning(a#61, 5)
+      //    +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L])
+      //       +- Exchange RoundRobinPartitioning(1)
+      //          +- LocalTableScan [a#61]
+      //
+      // Assume the execution plan with node id is:
+      // Wholestage disabled:
+      // HashAggregate(nodeId = 0)
+      //   Exchange(nodeId = 1)
+      //     HashAggregate(nodeId = 2)
+      //       Exchange (nodeId = 3)
+      //         LocalTableScan(nodeId = 4)
+      //
+      // Wholestage enabled:
+      // WholeStageCodegen(nodeId = 0)
+      //   HashAggregate(nodeId = 1)
+      //     Exchange(nodeId = 2)
+      //       WholeStageCodegen(nodeId = 3)
+      //         HashAggregate(nodeId = 4)
+      //           Exchange(nodeId = 5)
+      //             LocalTableScan(nodeId = 6)
+      Seq(true, false).foreach { enableWholeStage =>
+        val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
+        val nodeIds = if (enableWholeStage) {
+          Set(4L, 1L)
         } else {
-          val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
-          // Extract min, med, max from the string and strip off everthing else.
-          val index = mainValue.indexOf(" (", 0)
-          mainValue.slice(0, index).split(", ").foreach {
-            probe => assert(probe.toDouble > 1.0)
+          Set(2L, 0L)
+        }
+        val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
+        nodeIds.foreach { nodeId =>
+          val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString
+          if (!probes.contains("\n")) {
+            // It's a single metrics value
+            assert(probes.toDouble > 1.0)
+          } else {
+            val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
+            // Extract min, med, max from the string and strip off everthing else.
+            val index = mainValue.indexOf(" (", 0)
+            mainValue.slice(0, index).split(", ").foreach {
+              probe => assert(probe.toDouble > 1.0)
+            }
           }
         }
       }
     }
   }
addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 087bb0a1517d5..387cb9455a6f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -77,7 +77,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } checkAnswer(aggDF, Row(6)) assert(aggNodes.nonEmpty) - Thread.sleep(1000000) assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index fac981267f4d7..19160c09ea012 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1057,38 +1057,43 @@ class HashAggregationQuerySuite extends AggregationQuerySuite class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { - Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> - enableTwoLevelMaps) { - Seq(4, 8).foreach { uaoSize => - UnsafeAlignedOffset.setUaoSize(uaoSize) - (1 to 3).foreach { fallbackStartsAt => - withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> - s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { - // Create a new df to make sure its physical operator picks up - // spark.sql.TungstenAggregate.testFallbackStartsAt. - // todo: remove it? - val newActual = Dataset.ofRows(spark, actual.logicalPlan) - - QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { - case Some(errorMessage) => - val newErrorMessage = - s""" - |The following aggregation query failed when using HashAggregate with - |controlled fallback (it falls back to bytes to bytes map once it has - |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation - |once it has processed $fallbackStartsAt input rows). - |The query is ${actual.queryExecution} - |$errorMessage + // The HashAggregationQueryWithControlledFallbackSuite is dependent on ordering and also + // assumes partial aggregation to have happened. + // disabling the flag that skips partial aggregation + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "false")) { + Seq("true", "false").foreach { enableTwoLevelMaps => + withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key -> + enableTwoLevelMaps) { + Seq(4, 8).foreach { uaoSize => + UnsafeAlignedOffset.setUaoSize(uaoSize) + (1 to 3).foreach { fallbackStartsAt => + withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> + s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") { + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? 
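
A note for readers reproducing the test setup above: the suites pin the new flag for a block of work with withSQLConf. The same pinning can be done on a plain session. The sketch below is illustrative only: the session setup and sample data are not part of this series, and the conf key is the one the series defines at this point (a later patch in the series renames it).

import org.apache.spark.sql.SparkSession

object DisableSkipPartialAggSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical standalone session; only the conf key comes from this series.
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("disable-skip-partial-agg")
      .config("spark.sql.aggregate.partialaggregate.skip.enabled", "false")
      .getOrCreate()
    import spark.implicits._

    // With the flag off, the partial HashAggregate keeps its usual behavior:
    // aggregate map-side and fall back to sort/spill under memory pressure.
    val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")
    df.groupBy("k").sum("v").show()
    spark.stop()
  }
}
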
+ val newActual = Dataset.ofRows(spark, actual.logicalPlan) + + QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using HashAggregate with + |controlled fallback (it falls back to bytes to bytes map once it has + |processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation + |once it has processed $fallbackStartsAt input rows). + |The query is ${actual.queryExecution} + |$errorMessage """.stripMargin - fail(newErrorMessage) - case None => // Success + fail(newErrorMessage) + case None => // Success + } } } + // reset static uaoSize to avoid affect other tests + UnsafeAlignedOffset.setUaoSize(0) } - // reset static uaoSize to avoid affect other tests - UnsafeAlignedOffset.setUaoSize(0) } } } From 7766401f81dbde6d2941cacd69b51ad9acc5d855 Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Fri, 3 Jul 2020 08:23:27 -0700 Subject: [PATCH 21/33] Add heuristic --- .../org/apache/spark/sql/internal/SQLConf.scala | 16 ++++++++++++++++ .../execution/aggregate/HashAggregateExec.scala | 13 ++++++++++++- .../execution/aggregate/HashMapGenerator.scala | 8 ++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ab9fb6654aec0..2e1dd89905a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2203,6 +2203,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SKIP_PARTIAL_AGGREGATE_THRESHOLD = + buildConf("spark.sql.aggregate.partialaggregate.skip.threshold") + .internal() + .longConf + .createWithDefault(100000) + + val SKIP_PARTIAL_AGGREGATE_RATIO = + buildConf("spark.sql.aggregate.partialaggregate.skip.ratio") + .internal() + .doubleConf + .createWithDefault(0.5) + val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec") .doc("Compression codec used in writing of AVRO files. Supported codecs: " + "uncompressed, deflate, snappy, bzip2 and xz. 
Default codec is snappy.") @@ -2931,6 +2943,10 @@ class SQLConf extends Serializable with Logging { def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED) + def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_THRESHOLD) + + def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_RATIO) + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 42cbac782129b..a3bfa2f653866 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -416,6 +416,7 @@ case class HashAggregateExec( private var avoidSpillInPartialAggregateTerm: String = _ private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate && AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty + private var rowCountTerm: String = _ private var outputFunc: String = _ // whether a vectorized hashmap is used instead @@ -900,7 +901,11 @@ case class HashAggregateExec( ("true", "true", "", "") } + val skipPartialAggregateThreshold = sqlContext.conf.skipPartialAggregateThreshold + val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio + val oomeClassName = classOf[SparkOutOfMemoryError].getName + val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val findOrInsertRegularHashMap: String = s""" |if (!$avoidSpillInPartialAggregateTerm) { @@ -917,7 +922,11 @@ case class HashAggregateExec( | if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) { | // If sort/spill to disk is disabled, nothing is done. | // Aggregation buffer is created later - | if ($skipPartialAggregate) { + | $countTerm = $countTerm + $hashMapTerm.getNumRows(); + | boolean skipPartAgg = + | !($rowCountTerm < $skipPartialAggregateThreshold) && + | ($countTerm/$rowCountTerm) > $skipPartialAggRatio; + | if ($skipPartialAggregate && skipPartAgg) { | $avoidSpillInPartialAggregateTerm = true; | } else { | if ($sorterTerm == null) { @@ -940,6 +949,7 @@ case class HashAggregateExec( """.stripMargin val partTerm = metricTerm(ctx, "partialAggSkipped") + val findOrInsertHashMap: String = { val insertCode = if (isFastHashMapEnabled) { // If fast hash map is on, we first generate code to probe and update the fast hash map. @@ -954,6 +964,7 @@ case class HashAggregateExec( |} |// Cannot find the key in fast hash map, try regular hash map. 
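
The generated Java above wires the skip decision into codegen. Restated as plain Scala, with the default values taken from the two conf definitions introduced in this patch (the helper and its parameter names are illustrative, not code from the patch):

// Skip partial aggregation once enough rows have been seen and the hash map is
// not reducing the data: mapRows is the number of entries accumulated in the
// map(s), totalRows the number of input rows consumed so far.
def shouldSkipPartialAgg(
    mapRows: Long,
    totalRows: Long,
    minRows: Long = 100000L, // spark.sql.aggregate.partialaggregate.skip.threshold
    ratio: Double = 0.5 // spark.sql.aggregate.partialaggregate.skip.ratio
): Boolean = {
  // Note the floating-point division: at this point in the series the generated
  // code divides two longs, which a later patch corrects with a (float) cast.
  !(totalRows < minRows) && mapRows.toDouble / totalRows > ratio
}

With the defaults, a task must process at least 100,000 rows and retain more than half of them as distinct map entries before the operator gives up on partial aggregation.
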
|if ($fastRowBuffer == null) { + | $countTerm = $countTerm + $fastHashMapTerm.getNumRows(); | $findOrInsertRegularHashMap |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index e1c85823259b1..94ea91d23e33f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -136,6 +136,14 @@ abstract class HashMapGenerator( """.stripMargin } + protected final def generateNumRows(): String = { + s""" + |public int getNumRows() { + | return batch.numRows(); + |} + """.stripMargin + } + protected final def genComputeHash( ctx: CodegenContext, input: String, From 75125d95ee1e011dd2348273dd84de453a93cd92 Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Mon, 6 Jul 2020 11:57:59 -0700 Subject: [PATCH 22/33] Fix: Include missing change, remove unnecessary changes, handle comments --- .../UnsafeFixedWidthAggregationMap.java | 14 +++ .../aggregate/HashAggregateExec.scala | 10 +-- .../aggregate/HashMapGenerator.scala | 2 + .../execution/WholeStageCodegenSuite.scala | 3 +- .../execution/metric/SQLMetricsSuite.scala | 87 +++++++++---------- 5 files changed, 64 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 117e98f33a0ec..061e09543ac83 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -63,6 +63,14 @@ public final class UnsafeFixedWidthAggregationMap { */ private final UnsafeRow currentAggregationBuffer; + /** + * Number of rows that were added to the map + * This includes the elements that were passed on sorter + * using {@link #destructAndCreateExternalSorter()} + * + */ + private long numRowsAdded = 0L; + /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. @@ -147,6 +155,8 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow key, int hash) { ); if (!putSucceeded) { return null; + } else { + numRowsAdded = numRowsAdded + 1; } } @@ -249,4 +259,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), map); } + + public long getNumRows() { + return numRowsAdded; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a3bfa2f653866..b9c001b22bca6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -638,8 +638,6 @@ case class HashAggregateExec( |${consume(ctx, resultVars)} """.stripMargin } - - ctx.addNewFunction(funcName, s""" |private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm) @@ -696,6 +694,8 @@ case class HashAggregateExec( addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate") val childrenConsumed = ctx. 
addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") + rowCountTerm = ctx. + addMutableState(CodeGenerator.JAVA_LONG, "rowCount") if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else if (sqlContext.conf.enableVectorizedHashMap) { @@ -857,11 +857,11 @@ case class HashAggregateExec( | long $beforeAgg = System.nanoTime(); | $doAggFuncName(); | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); - | $shouldStopCheckCode; + | if (shouldStop()) return; |} |if (!$childrenConsumed) { | $doAggFuncName(); - | $shouldStopCheckCode; + | if (shouldStop()) return; |} |// output the result |$outputFromFastHashMap @@ -904,8 +904,8 @@ case class HashAggregateExec( val skipPartialAggregateThreshold = sqlContext.conf.skipPartialAggregateThreshold val skipPartialAggRatio = sqlContext.conf.skipPartialAggregateRatio - val oomeClassName = classOf[SparkOutOfMemoryError].getName val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") + val oomeClassName = classOf[SparkOutOfMemoryError].getName val findOrInsertRegularHashMap: String = s""" |if (!$avoidSpillInPartialAggregateTerm) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 94ea91d23e33f..d7b01ef39a654 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -75,6 +75,8 @@ abstract class HashMapGenerator( | |${generateRowIterator()} | + |${generateNumRows()} + | |${generateClose()} |} """.stripMargin diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 387cb9455a6f8..ee4a589ad4c02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -52,7 +52,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } test(s"Avoid spill in partial aggregation" ) { - withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) { + withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true"), + ("spark.sql.aggregate.partialaggregate.skip.threshold", "2")) { // Create Dataframes val data = Seq(("James", 1), ("James", 1), ("Phil", 1)) val aggDF = data.toDF("name", "values").groupBy("name").sum("values") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 916772ee5991d..50652690339a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -142,53 +142,48 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("Aggregate metrics: track avg probe") { - val skipPartialAgg = spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED) - if (skipPartialAgg) { - logInfo("Skipping, since partial Aggregation is disabled") - } else { - // The executed plan looks like: - // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) - // +- Exchange hashpartitioning(a#61, 5) - // +- HashAggregate(keys=[a#61], 
functions=[partial_count(1)], output=[a#61, count#76L]) - // +- Exchange RoundRobinPartitioning(1) - // +- LocalTableScan [a#61] - // - // Assume the execution plan with node id is: - // Wholestage disabled: - // HashAggregate(nodeId = 0) - // Exchange(nodeId = 1) - // HashAggregate(nodeId = 2) - // Exchange (nodeId = 3) - // LocalTableScan(nodeId = 4) - // - // Wholestage enabled: - // WholeStageCodegen(nodeId = 0) - // HashAggregate(nodeId = 1) - // Exchange(nodeId = 2) - // WholeStageCodegen(nodeId = 3) - // HashAggregate(nodeId = 4) - // Exchange(nodeId = 5) - // LocalTableScan(nodeId = 6) - Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(1).groupBy('a).count() - val nodeIds = if (enableWholeStage) { - Set(4L, 1L) + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) + } else { + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + if (!probes.contains("\n")) { + // It's a single metrics value + assert(probes.toDouble > 1.0) } else { - Set(2L, 0L) - } - val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString - if (!probes.contains("\n")) { - // It's a single metrics value - assert(probes.toDouble > 1.0) - } else { - val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") - // Extract min, med, max from the string and strip off everthing else. - val index = mainValue.indexOf(" (", 0) - mainValue.slice(0, index).split(", ").foreach { - probe => assert(probe.toDouble > 1.0) - } + val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") + // Extract min, med, max from the string and strip off everthing else. 
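
The assertions above pull min/med/max out of a multi-line metric rendering. A self-contained sketch of that parsing, assuming a rendering of the shape the test implies (the sample string here is made up for illustration):

// A multi-line SQL metric renders roughly as "\n(min, med, max (stage: task))".
val probes = "\n(1.2, 2.5, 4.0 (stage 3.0: task 12))"
// Keep the second line, drop the outer parentheses, then keep only the part
// before " (" and split it into the three numbers.
val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")")
val index = mainValue.indexOf(" (", 0)
mainValue.slice(0, index).split(", ").foreach { probe =>
  assert(probe.toDouble > 1.0)
}
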
+ val index = mainValue.indexOf(" (", 0) + mainValue.slice(0, index).split(", ").foreach { + probe => assert(probe.toDouble > 1.0) } } } From 3ca81ae8d381509f33c49dd3a81f57856e5bd264 Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Wed, 8 Jul 2020 12:16:57 -0700 Subject: [PATCH 23/33] Refactor: avoid additional code on reducer, fix tests, --- .../aggregate/HashAggregateExec.scala | 205 ++++++++++++------ .../execution/metric/SQLMetricsSuite.scala | 87 ++++---- 2 files changed, 180 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b9c001b22bca6..7b01be472155e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -63,8 +63,6 @@ case class HashAggregateExec( require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) - override def needStopCheck: Boolean = skipPartialAggregate - override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) @@ -690,8 +688,12 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") - avoidSpillInPartialAggregateTerm = ctx. - addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate") + if (skipPartialAggregate) { + avoidSpillInPartialAggregateTerm = ctx. + addMutableState(CodeGenerator.JAVA_BOOLEAN, + "avoidPartialAggregate", + term => s"$term = ${Utils.isTesting};") + } val childrenConsumed = ctx. addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") rowCountTerm = ctx. @@ -850,8 +852,6 @@ case class HashAggregateExec( s""" |if (!$initAgg) { | $initAgg = true; - | $avoidSpillInPartialAggregateTerm = - | ${Utils.isTesting} && $skipPartialAggregate; | $createFastHashMap | $hashMapTerm = $thisPlan.createHashMap(); | long $beforeAgg = System.nanoTime(); @@ -869,6 +869,8 @@ case class HashAggregateExec( """.stripMargin } + override def needStopCheck: Boolean = skipPartialAggregate + private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( @@ -906,82 +908,133 @@ case class HashAggregateExec( val countTerm = ctx.addMutableState(CodeGenerator.JAVA_LONG, "count") val oomeClassName = classOf[SparkOutOfMemoryError].getName - val findOrInsertRegularHashMap: String = - s""" - |if (!$avoidSpillInPartialAggregateTerm) { - | // generate grouping key - | ${unsafeRowKeyCode.code} - | int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); - | if ($checkFallbackForBytesToBytesMap) { - | // try to get the buffer from hash map - | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); - | } - | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - | // aggregation after processing all input rows. - | if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) { - | // If sort/spill to disk is disabled, nothing is done. 
- | // Aggregation buffer is created later - | $countTerm = $countTerm + $hashMapTerm.getNumRows(); - | boolean skipPartAgg = - | !($rowCountTerm < $skipPartialAggregateThreshold) && - | ($countTerm/$rowCountTerm) > $skipPartialAggRatio; - | if ($skipPartialAggregate && skipPartAgg) { - | $avoidSpillInPartialAggregateTerm = true; - | } else { - | if ($sorterTerm == null) { - | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - | } else { - | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - | } - | $resetCounter - | // the hash map had be spilled, it should have enough memory now, - | // try to allocate buffer again. - | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, $unsafeRowKeyHash); - | if ($unsafeRowBuffer == null) { - | // failed to allocate the first page - | throw new $oomeClassName("No enough memory for aggregation"); - | } - | } - | } - |} - """.stripMargin + val findOrInsertRegularHashMap: String = { + def getAggBufferFromMap = { + s""" + |// generate grouping key + |${unsafeRowKeyCode.code} + |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); + |if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); + |} + """.stripMargin + } - val partTerm = metricTerm(ctx, "partialAggSkipped") + def addToSorter: String = { + s""" + |if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + |} else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + |} + |$resetCounter + |// the hash map had be spilled, it should have enough memory now, + |// try to allocate buffer again. + |$unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( + | $unsafeRowKeys, $unsafeRowKeyHash); + |if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new $oomeClassName("No enough memory for aggregation"); + |}""".stripMargin + } + + def getHeuristicToAvoidAgg: String = { + s""" + |!($rowCountTerm < $skipPartialAggregateThreshold) && + | ($countTerm/$rowCountTerm) > $skipPartialAggRatio; + |""".stripMargin + } + + if (skipPartialAggregate) { + s""" + |if (!$avoidSpillInPartialAggregateTerm) { + | $getAggBufferFromMap + | // Can't allocate buffer from the hash map. + | // Check if we can avoid partial aggregation. Otherwise, Spill the map and fallback to sort-based + | // aggregation after processing all input rows. + | if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) { + | $countTerm = $countTerm + $hashMapTerm.getNumRows(); + | boolean skipPartAgg = $getHeuristicToAvoidAgg + | if (skipPartAgg) { + | // Aggregation buffer is created later + | $avoidSpillInPartialAggregateTerm = true; + | } else { + | $addToSorter + | } + | } + |} + """.stripMargin + } else { + s""" + | $getAggBufferFromMap + | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + | // aggregation after processing all input rows. + | if ($unsafeRowBuffer == null) { + | $addToSorter + | } + """.stripMargin + } + } val findOrInsertHashMap: String = { val insertCode = if (isFastHashMapEnabled) { + def findOrInsertIntoFastHashMap = { + s""" + |${fastRowKeys.map(_.code).mkString("\n")} + |if (${fastRowKeys.map("!" 
+ _.isNull).mkString(" && ")}) { + | $fastRowBuffer = $fastHashMapTerm.findOrInsert( + | ${fastRowKeys.map(_.value).mkString(", ")}); + |} + |""".stripMargin + } + val insertFastMap = if (skipPartialAggregate) { + s""" + |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) { + | $findOrInsertIntoFastHashMap + |} + |$countTerm = $fastHashMapTerm.getNumRows(); + |""".stripMargin + } else { + s""" + |if ($checkFallbackForGeneratedHashMap) { + | $findOrInsertIntoFastHashMap + |} + |""".stripMargin + } // If fast hash map is on, we first generate code to probe and update the fast hash map. // If the probe is successful the corresponding fast row buffer will hold the mutable row. s""" - |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) { - | ${fastRowKeys.map(_.code).mkString("\n")} - | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) { - | $fastRowBuffer = $fastHashMapTerm.findOrInsert( - | ${fastRowKeys.map(_.value).mkString(", ")}); - | } - |} + |$insertFastMap |// Cannot find the key in fast hash map, try regular hash map. |if ($fastRowBuffer == null) { - | $countTerm = $countTerm + $fastHashMapTerm.getNumRows(); | $findOrInsertRegularHashMap |} """.stripMargin } else { findOrInsertRegularHashMap } - val initExpr = declFunctions.flatMap(f => f.initialValues) - val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) + def createEmptyAggBufferAndUpdateMetrics: String = { + if (skipPartialAggregate) { + val partTerm = metricTerm(ctx, "partialAggSkipped") + val initExpr = declFunctions.flatMap(f => f.initialValues) + val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) + s""" + |// Create an empty aggregation buffer + |if ($avoidSpillInPartialAggregateTerm) { + | ${unsafeRowKeyCode.code} + | ${emptyBufferKeyCode.code} + | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; + | $partTerm.add(1); + |} + |""".stripMargin + } else "" + } + s""" |$insertCode - |// Create an empty aggregation buffer - |if ($avoidSpillInPartialAggregateTerm) { - | ${unsafeRowKeyCode.code} - | ${emptyBufferKeyCode.code} - | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; - | $partTerm.add(1); - |} + |$createEmptyAggBufferAndUpdateMetrics |""".stripMargin } @@ -1059,7 +1112,7 @@ case class HashAggregateExec( } val updateRowInHashMap: String = { - val updateRowinMap = if (isFastHashMapEnabled) { + val updateRowInMap = if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc => @@ -1134,12 +1187,22 @@ case class HashAggregateExec( } else { updateRowInRegularHashMap } + + def outputRow: String = { + if (skipPartialAggregate) { + s""" + |if ($avoidSpillInPartialAggregateTerm) { + | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); + |} + |""".stripMargin + } else "" + } + s""" - |$updateRowinMap - |if ($avoidSpillInPartialAggregateTerm) { - | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); - |} - |""".stripMargin + |$updateRowInMap + |$outputRow + |$rowCountTerm = $rowCountTerm + 1; + |""".stripMargin } val declareRowBuffer: String = if (isFastHashMapEnabled) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 50652690339a8..d1220e6e3ca1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -142,52 +142,57 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("Aggregate metrics: track avg probe") { - // The executed plan looks like: - // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) - // +- Exchange hashpartitioning(a#61, 5) - // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) - // +- Exchange RoundRobinPartitioning(1) - // +- LocalTableScan [a#61] - // - // Assume the execution plan with node id is: - // Wholestage disabled: - // HashAggregate(nodeId = 0) - // Exchange(nodeId = 1) - // HashAggregate(nodeId = 2) - // Exchange (nodeId = 3) - // LocalTableScan(nodeId = 4) - // - // Wholestage enabled: - // WholeStageCodegen(nodeId = 0) - // HashAggregate(nodeId = 1) - // Exchange(nodeId = 2) - // WholeStageCodegen(nodeId = 3) - // HashAggregate(nodeId = 4) - // Exchange(nodeId = 5) - // LocalTableScan(nodeId = 6) - Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(1).groupBy('a).count() - val nodeIds = if (enableWholeStage) { - Set(4L, 1L) - } else { - Set(2L, 0L) - } - val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get - nodeIds.foreach { nodeId => - val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString - if (!probes.contains("\n")) { - // It's a single metrics value - assert(probes.toDouble > 1.0) + if (spark.sessionState.conf.getConf(SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED)) { + logInfo("Skipping, since partial Aggregation is disabled") + } else { + // The executed plan looks like: + // HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L]) + // +- Exchange hashpartitioning(a#61, 5) + // +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L]) + // +- Exchange RoundRobinPartitioning(1) + // +- LocalTableScan [a#61] + // + // Assume the execution plan with node id is: + // Wholestage disabled: + // HashAggregate(nodeId = 0) + // Exchange(nodeId = 1) + // HashAggregate(nodeId = 2) + // Exchange (nodeId = 3) + // LocalTableScan(nodeId = 4) + // + // Wholestage enabled: + // WholeStageCodegen(nodeId = 0) + // HashAggregate(nodeId = 1) + // Exchange(nodeId = 2) + // WholeStageCodegen(nodeId = 3) + // HashAggregate(nodeId = 4) + // Exchange(nodeId = 5) + // LocalTableScan(nodeId = 6) + Seq(true, false).foreach { enableWholeStage => + val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val nodeIds = if (enableWholeStage) { + Set(4L, 1L) } else { - val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") - // Extract min, med, max from the string and strip off everthing else. - val index = mainValue.indexOf(" (", 0) - mainValue.slice(0, index).split(", ").foreach { - probe => assert(probe.toDouble > 1.0) + Set(2L, 0L) + } + val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get + nodeIds.foreach { nodeId => + val probes = metrics(nodeId)._2("avg hash probe bucket list iters").toString + if (!probes.contains("\n")) { + // It's a single metrics value + assert(probes.toDouble > 1.0) + } else { + val mainValue = probes.split("\n").apply(1).stripPrefix("(").stripSuffix(")") + // Extract min, med, max from the string and strip off everthing else. 
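
Outside these suites, the new counter can be read off an executed plan the same way the tests do. A hedged sketch, where the helper and its name are illustrative rather than part of the patch:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

// Run the query, then report the "partialAggSkipped" counter wherever a
// HashAggregate node carries it. After this series the metric exists only on
// partial-aggregate nodes, hence the Option-based lookup.
def reportSkippedRows(df: DataFrame): Unit = {
  df.collect() // execute so operator metrics are populated
  val aggNodes = df.queryExecution.executedPlan.collect {
    case agg: HashAggregateExec => agg
  }
  aggNodes.foreach { agg =>
    agg.metrics.get("partialAggSkipped").foreach { metric =>
      println(s"${agg.nodeName}: skipped ${metric.value} rows")
    }
  }
}
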
+ val index = mainValue.indexOf(" (", 0) + mainValue.slice(0, index).split(", ").foreach { + probe => assert(probe.toDouble > 1.0) + } } } } } + } test("ObjectHashAggregate metrics") { From 8850777f41dbf6d0a526a8409c426f156e7f7fa7 Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Wed, 8 Jul 2020 12:56:31 -0700 Subject: [PATCH 24/33] gst --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 7b01be472155e..df543603ad7e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -952,7 +952,8 @@ case class HashAggregateExec( |if (!$avoidSpillInPartialAggregateTerm) { | $getAggBufferFromMap | // Can't allocate buffer from the hash map. - | // Check if we can avoid partial aggregation. Otherwise, Spill the map and fallback to sort-based + | // Check if we can avoid partial aggregation. + | // Otherwise, Spill the map and fallback to sort-based | // aggregation after processing all input rows. | if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) { | $countTerm = $countTerm + $hashMapTerm.getNumRows(); From 26a2fd63c9410d274a8ed26aeccff0e0a6a4f79f Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Wed, 15 Jul 2020 15:15:17 -0700 Subject: [PATCH 25/33] Address review comments --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++--- .../sql/execution/aggregate/HashAggregateExec.scala | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1dd89905a65..16081c2bf785e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2199,19 +2199,23 @@ object SQLConf { val SKIP_PARTIAL_AGGREGATE_ENABLED = buildConf("spark.sql.aggregate.partialaggregate.skip.enabled") .internal() - .doc("Avoid sort/spill to disk during partial aggregation") + .doc("Avoid sorter(sort/spill) during partial aggregation") .booleanConf .createWithDefault(true) val SKIP_PARTIAL_AGGREGATE_THRESHOLD = buildConf("spark.sql.aggregate.partialaggregate.skip.threshold") .internal() + .doc("Number of records after which aggregate operator checks if " + + "partial aggregation phase can be avoided") .longConf .createWithDefault(100000) - val SKIP_PARTIAL_AGGREGATE_RATIO = + val SKIP_PARTIAL_AGGREGATE_REDUCTION_RATIO = buildConf("spark.sql.aggregate.partialaggregate.skip.ratio") .internal() + .doc("Ratio of number of records present in map of Aggregate operator" + + "to the total number of records processed by the Aggregate operator") .doubleConf .createWithDefault(0.5) @@ -2945,7 +2949,7 @@ class SQLConf extends Serializable with Logging { def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_THRESHOLD) - def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_RATIO) + def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_REDUCTION_RATIO) def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index df543603ad7e3..81b901d49d723 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -943,7 +943,7 @@ case class HashAggregateExec(
 def getHeuristicToAvoidAgg: String = {
 s"""
 |!($rowCountTerm < $skipPartialAggregateThreshold) &&
- | ($countTerm/$rowCountTerm) > $skipPartialAggRatio;
+ | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio;
 |""".stripMargin
 }

From c08881600cb837f2d3d2e5ffa16cc367b14c55de Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Thu, 6 Aug 2020 16:54:42 -0700
Subject: [PATCH 26/33] Address review comments

---
 .../apache/spark/sql/internal/SQLConf.scala | 20 +++--
 .../aggregate/HashAggregateExec.scala | 74 ++++++++++---------
 .../execution/WholeStageCodegenSuite.scala | 4 +-
 .../execution/metric/SQLMetricsSuite.scala | 1 -
 4 files changed, 53 insertions(+), 46 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 16081c2bf785e..841ffff8358ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2197,22 +2197,26 @@ object SQLConf {
 .createWithDefault(16)

 val SKIP_PARTIAL_AGGREGATE_ENABLED =
- buildConf("spark.sql.aggregate.partialaggregate.skip.enabled")
+ buildConf("spark.sql.aggregate.skipPartialAggregate")
 .internal()
- .doc("Avoid sorter(sort/spill) during partial aggregation")
+ .doc("When enabled, the partial aggregation is skipped when the following " +
+ "two conditions are met. 1. When the total number of records processed is greater " +
 s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. 
When the ratio " +
+ "of record count in map to the total records is greater than the value defined by " +
+ s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}")
 .booleanConf
 .createWithDefault(true)

- val SKIP_PARTIAL_AGGREGATE_THRESHOLD =
- buildConf("spark.sql.aggregate.partialaggregate.skip.threshold")
+ val SKIP_PARTIAL_AGGREGATE_MINROWS =
+ buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows")
 .internal()
 .doc("Number of records after which aggregate operator checks if " +
 "partial aggregation phase can be avoided")
 .longConf
 .createWithDefault(100000)

- val SKIP_PARTIAL_AGGREGATE_REDUCTION_RATIO =
- buildConf("spark.sql.aggregate.partialaggregate.skip.ratio")
+ val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO =
+ buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio")
 .internal()
 .doc("Ratio of number of records present in map of Aggregate operator" +
 "to the total number of records processed by the Aggregate operator")
 .doubleConf
 .createWithDefault(0.5)

 val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
 .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
 "uncompressed, deflate, snappy, bzip2 and xz. Default codec is snappy.")
@@ -2947,9 +2951,9 @@ class SQLConf extends Serializable with Logging {

 def skipPartialAggregate: Boolean = getConf(SKIP_PARTIAL_AGGREGATE_ENABLED)

- def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_THRESHOLD)
+ def skipPartialAggregateThreshold: Long = getConf(SKIP_PARTIAL_AGGREGATE_MINROWS)

- def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_REDUCTION_RATIO)
+ def skipPartialAggregateRatio: Double = getConf(SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO)

 def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 81b901d49d723..c5607eb949ea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -67,15 +67,22 @@ case class HashAggregateExec(
 child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
 aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)

- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
- "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
- "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
- "partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
- "number of skipped records for partial aggregates"),
- "avgHashProbe" ->
- SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
+ override lazy val metrics = {
+ val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
+ "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
+ "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in aggregation build"),
+ "avgHashProbe" ->
+ SQLMetrics.createAverageMetric(sparkContext, "avg hash probe bucket list iters"))
+ if (skipPartialAggregateEnabled) {
+ metrics ++ Map("partialAggSkipped" -> SQLMetrics.createMetric(sparkContext,
+ "number of skipped records for partial aggregates"))
+ } else {
+ metrics
+ }
+ }
+

 override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

@@ -412,7 +419,7 @@ case class HashAggregateExec(
 private var 
isFastHashMapEnabled: Boolean = false private var avoidSpillInPartialAggregateTerm: String = _ - private val skipPartialAggregate = sqlContext.conf.skipPartialAggregate && + private val skipPartialAggregateEnabled = sqlContext.conf.skipPartialAggregate && AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty private var rowCountTerm: String = _ private var outputFunc: String = _ @@ -688,7 +695,7 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") - if (skipPartialAggregate) { + if (skipPartialAggregateEnabled) { avoidSpillInPartialAggregateTerm = ctx. addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate", @@ -781,7 +788,6 @@ case class HashAggregateExec( // generate code for output val keyTerm = ctx.freshName("aggKey") val bufferTerm = ctx.freshName("aggBuffer") - def outputFromFastHashMap: String = { if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { @@ -869,7 +875,7 @@ case class HashAggregateExec( """.stripMargin } - override def needStopCheck: Boolean = skipPartialAggregate + override def needStopCheck: Boolean = skipPartialAggregateEnabled private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key @@ -940,24 +946,22 @@ case class HashAggregateExec( |}""".stripMargin } - def getHeuristicToAvoidAgg: String = { - s""" - |!($rowCountTerm < $skipPartialAggregateThreshold) && - | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio; - |""".stripMargin - } - - if (skipPartialAggregate) { - s""" + if (skipPartialAggregateEnabled) { + val checkIfPartialAggSkipped = + s""" + |!($rowCountTerm < $skipPartialAggregateThreshold) && + | ((float)$countTerm/$rowCountTerm) > $skipPartialAggRatio; + |""".stripMargin + s""" |if (!$avoidSpillInPartialAggregateTerm) { | $getAggBufferFromMap | // Can't allocate buffer from the hash map. | // Check if we can avoid partial aggregation. | // Otherwise, Spill the map and fallback to sort-based | // aggregation after processing all input rows. - | if ($unsafeRowBuffer == null && !$avoidSpillInPartialAggregateTerm) { + | if ($unsafeRowBuffer == null) { | $countTerm = $countTerm + $hashMapTerm.getNumRows(); - | boolean skipPartAgg = $getHeuristicToAvoidAgg + | boolean skipPartAgg = $checkIfPartialAggSkipped | if (skipPartAgg) { | // Aggregation buffer is created later | $avoidSpillInPartialAggregateTerm = true; @@ -969,12 +973,12 @@ case class HashAggregateExec( """.stripMargin } else { s""" - | $getAggBufferFromMap - | // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based - | // aggregation after processing all input rows. - | if ($unsafeRowBuffer == null) { - | $addToSorter - | } + |$getAggBufferFromMap + |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based + |// aggregation after processing all input rows. 
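
The skipPartialAggregateEnabled val above gates the whole feature. As a free-standing restatement of its three conditions (the predicate and its boolean inputs are illustrative; AggregateMode and Partial are the real Catalyst types, and the last input stands in for the plan scan for ExpandExec):

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Partial}

// Partial aggregation may only be skipped when the feature flag is on, every
// aggregate expression runs in Partial mode, and no Expand node (grouping
// sets) appears below this operator in the plan.
def canSkipPartialAggregation(
    confEnabled: Boolean,
    modes: Seq[AggregateMode],
    containsExpand: Boolean): Boolean = {
  confEnabled && modes.nonEmpty && modes.forall(_ == Partial) && !containsExpand
}
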
+ |if ($unsafeRowBuffer == null) { + | $addToSorter + |} """.stripMargin } } @@ -990,7 +994,7 @@ case class HashAggregateExec( |} |""".stripMargin } - val insertFastMap = if (skipPartialAggregate) { + val insertFastMap = if (skipPartialAggregateEnabled) { s""" |if ($checkFallbackForGeneratedHashMap && !$avoidSpillInPartialAggregateTerm) { | $findOrInsertIntoFastHashMap @@ -1017,8 +1021,8 @@ case class HashAggregateExec( findOrInsertRegularHashMap } def createEmptyAggBufferAndUpdateMetrics: String = { - if (skipPartialAggregate) { - val partTerm = metricTerm(ctx, "partialAggSkipped") + if (skipPartialAggregateEnabled) { + val numAggSkippedRows = metricTerm(ctx, "partialAggSkipped") val initExpr = declFunctions.flatMap(f => f.initialValues) val emptyBufferKeyCode = GenerateUnsafeProjection.createCode(ctx, initExpr) s""" @@ -1027,7 +1031,7 @@ case class HashAggregateExec( | ${unsafeRowKeyCode.code} | ${emptyBufferKeyCode.code} | $unsafeRowBuffer = ${emptyBufferKeyCode.value}; - | $partTerm.add(1); + | $numAggSkippedRows.add(1); |} |""".stripMargin } else "" @@ -1190,7 +1194,7 @@ case class HashAggregateExec( } def outputRow: String = { - if (skipPartialAggregate) { + if (skipPartialAggregateEnabled) { s""" |if ($avoidSpillInPartialAggregateTerm) { | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index ee4a589ad4c02..8829d13ebebfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,9 +51,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession assert(df.collect() === Array(Row(9, 4.5))) } - test(s"Avoid spill in partial aggregation" ) { + test("Avoid spill in partial aggregation" ) { withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true"), - ("spark.sql.aggregate.partialaggregate.skip.threshold", "2")) { + (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key, "2")) { // Create Dataframes val data = Seq(("James", 1), ("James", 1), ("Phil", 1)) val aggDF = data.toDF("name", "values").groupBy("name").sum("values") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d1220e6e3ca1f..8f5646283a9d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -192,7 +192,6 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } } } - } test("ObjectHashAggregate metrics") { From c49f106b205f215812af79f82d2a89e03015143b Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Thu, 6 Aug 2020 16:57:50 -0700 Subject: [PATCH 27/33] Fix forward reference --- .../apache/spark/sql/internal/SQLConf.scala | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 841ffff8358ab..685efa2e01b9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2196,19 +2196,9 @@ object SQLConf { 
.checkValue(bit => bit >= 10 && bit <= 30, "The bit value must be in [10, 30].")
 .createWithDefault(16)

- val SKIP_PARTIAL_AGGREGATE_ENABLED =
- buildConf("spark.sql.aggregate.skipPartialAggregate")
- .internal()
- .doc("When enabled, the partial aggregation is skipped when the following " +
- "two conditions are met. 1. When the total number of records processed is greater " +
- s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. When the ratio " +
- "of record count in map to the total records is greater than the value defined by " +
- s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}")
- .booleanConf
- .createWithDefault(true)

 val SKIP_PARTIAL_AGGREGATE_MINROWS =
- buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows")
+ buildConf("spark.sql.aggregate.skipPartialAggregate.minNumRows")
 .internal()
 .doc("Number of records after which aggregate operator checks if " +
 "partial aggregation phase can be avoided")
 .longConf
 .createWithDefault(100000)

 val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO =
- buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio")
+ buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio")
 .internal()
 .doc("Ratio of number of records present in map of Aggregate operator" +
 "to the total number of records processed by the Aggregate operator")
 .doubleConf
 .createWithDefault(0.5)

+ val SKIP_PARTIAL_AGGREGATE_ENABLED =
+ buildConf("spark.sql.aggregate.skipPartialAggregate")
+ .internal()
+ .doc("When enabled, the partial aggregation is skipped when the following " +
+ "two conditions are met. 1. When the total number of records processed is greater " +
+ s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. When the ratio " +
+ "of record count in map to the total records is greater than the value defined by " +
+ s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}")
+ .booleanConf
+ .createWithDefault(true)
+
 val AVRO_COMPRESSION_CODEC = buildConf("spark.sql.avro.compression.codec")
 .doc("Compression codec used in writing of AVRO files. Supported codecs: " +
 "uncompressed, deflate, snappy, bzip2 and xz. 
Default codec is snappy.") From 69f1d71aa3de5cc5359bc533d63d826335d7c6d6 Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Fri, 7 Aug 2020 10:22:25 -0700 Subject: [PATCH 28/33] UT fixes, address review comments --- .../sql/execution/aggregate/AggUtils.scala | 4 --- .../aggregate/HashAggregateExec.scala | 34 +++++++++++++------ .../execution/WholeStageCodegenSuite.scala | 17 ++-------- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 90403cee4887f..56a287d4d0279 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -353,8 +353,4 @@ object AggUtils { finalAndCompleteAggregate :: Nil } - - def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean = { - modes.nonEmpty && modes.forall(_ == Partial) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index c5607eb949ea6..cd20f2331b84f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -420,7 +420,7 @@ case class HashAggregateExec( private var avoidSpillInPartialAggregateTerm: String = _ private val skipPartialAggregateEnabled = sqlContext.conf.skipPartialAggregate && - AggUtils.areAggExpressionsPartial(modes) && find(_.isInstanceOf[ExpandExec]).isEmpty + modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty private var rowCountTerm: String = _ private var outputFunc: String = _ @@ -695,16 +695,19 @@ case class HashAggregateExec( private def doProduceWithKeys(ctx: CodegenContext): String = { val initAgg = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initAgg") + + var childrenConsumed: String = null if (skipPartialAggregateEnabled) { avoidSpillInPartialAggregateTerm = ctx. addMutableState(CodeGenerator.JAVA_BOOLEAN, "avoidPartialAggregate", term => s"$term = ${Utils.isTesting};") + rowCountTerm = ctx. + addMutableState(CodeGenerator.JAVA_LONG, "rowCount") + childrenConsumed = ctx. + addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") } - val childrenConsumed = ctx. - addMutableState(CodeGenerator.JAVA_BOOLEAN, "childrenConsumed") - rowCountTerm = ctx. 
- addMutableState(CodeGenerator.JAVA_LONG, "rowCount") + if (sqlContext.conf.enableTwoLevelAggMap) { enableTwoLevelHashMap(ctx) } else if (sqlContext.conf.enableVectorizedHashMap) { @@ -776,11 +779,14 @@ case class HashAggregateExec( } outputFunc = generateResultFunction(ctx) + val genChildrenConsumedCode = if (skipPartialAggregateEnabled) { + s"${childrenConsumed} = true;" + } else "" val doAggFuncName = ctx.addNewFunction(doAgg, s""" |private void $doAgg() throws java.io.IOException { | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | $childrenConsumed = true; + | $genChildrenConsumedCode | $finishHashMap |} """.stripMargin) @@ -855,6 +861,15 @@ case class HashAggregateExec( val aggTime = metricTerm(ctx, "aggTime") val beforeAgg = ctx.freshName("beforeAgg") + val genCodePostInitCode = + if (skipPartialAggregateEnabled) { + s""" + |if (!$childrenConsumed) { + | $doAggFuncName(); + | if (shouldStop()) return; + |} + """.stripMargin + } else "" s""" |if (!$initAgg) { | $initAgg = true; @@ -865,10 +880,7 @@ case class HashAggregateExec( | $aggTime.add((System.nanoTime() - $beforeAgg) / $NANOS_PER_MILLIS); | if (shouldStop()) return; |} - |if (!$childrenConsumed) { - | $doAggFuncName(); - | if (shouldStop()) return; - |} + |$genCodePostInitCode |// output the result |$outputFromFastHashMap |$outputFromRegularHashMap @@ -1199,6 +1211,7 @@ case class HashAggregateExec( |if ($avoidSpillInPartialAggregateTerm) { | $outputFunc(${unsafeRowKeyCode.value}, $unsafeRowBuffer); |} + |$rowCountTerm = $rowCountTerm + 1; |""".stripMargin } else "" } @@ -1206,7 +1219,6 @@ case class HashAggregateExec( s""" |$updateRowInMap |$outputRow - |$rowCountTerm = $rowCountTerm + 1; |""".stripMargin } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 8829d13ebebfc..589854cad5773 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite @@ -59,7 +60,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val aggDF = data.toDF("name", "values").groupBy("name").sum("values") val partAggNode = aggDF.queryExecution.executedPlan.find { case h: HashAggregateExec => - AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode)) + h.aggregateExpressions.map(_.mode).forall(_ == Partial) case _ => false } checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1))) @@ -70,18 +71,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } } - test(s"Partial aggregation should not happen when no Aggregate expr" ) { - withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) { - val aggDF = testData2.select(sumDistinct($"a")) - val aggNodes = aggDF.queryExecution.executedPlan.collect { - case h: HashAggregateExec => h - } - checkAnswer(aggDF, Row(6)) - assert(aggNodes.nonEmpty) - assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0)) - } - } - test(s"Distinct: Partial aggregation should happen for" + s" HashAggregate nodes 
performing partial Aggregate operations " ) {
 withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
@@ -93,7 +82,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
 checkAnswer(aggDF, Row(6, 9))
 assert(baseNodes.size == 1 )
 assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
- assert(other.forall(_.metrics("partialAggSkipped").value == 0))
+ assert(other.forall(!_.metrics.contains("partialAggSkipped")))
 }
 }

From c9a415de201a784756a3e3ef1a39ef069b7f0b4e Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 7 Aug 2020 10:31:17 -0700
Subject: [PATCH 29/33] Address review comments

---
 .../spark/sql/execution/UnsafeFixedWidthAggregationMap.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 061e09543ac83..3046423ad7658 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -67,7 +67,6 @@ public final class UnsafeFixedWidthAggregationMap {
 * Number of rows that were added to the map
 * This includes the elements that were passed on sorter
 * using {@link #destructAndCreateExternalSorter()}
- *
 */
 private long numRowsAdded = 0L;

From ceaa4e52d558d21964e5ea84a236519680202115 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 7 Aug 2020 11:12:31 -0700
Subject: [PATCH 30/33] Fix style check

---
 .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 589854cad5773..d4d8881143b24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,8 +17,8 @@

 package org.apache.spark.sql.execution

-import org.apache.spark.sql.catalyst.expressions.aggregate.Partial
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Partial
 import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
 import org.apache.spark.sql.execution.aggregate.{AggUtils, HashAggregateExec}

From 2ae5525294718ea43cee419bc095a3d382b6c085 Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Fri, 7 Aug 2020 13:10:26 -0700
Subject: [PATCH 31/33] Fix UT

---
 .../spark/sql/execution/aggregate/HashAggregateExec.scala | 2 +-
 .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index cd20f2331b84f..b59f46dc3fe66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -420,7 +420,7 @@ case class HashAggregateExec(
 private var avoidSpillInPartialAggregateTerm: String = _
 private 
val skipPartialAggregateEnabled = sqlContext.conf.skipPartialAggregate && - modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty + !modes.exists(_ != Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty private var rowCountTerm: String = _ private var outputFunc: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index d4d8881143b24..1687f462d56f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -60,9 +60,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val aggDF = data.toDF("name", "values").groupBy("name").sum("values") val partAggNode = aggDF.queryExecution.executedPlan.find { case h: HashAggregateExec => - h.aggregateExpressions.map(_.mode).forall(_ == Partial) + !h.aggregateExpressions.map(_.mode).exists(_ != Partial) case _ => false } + checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1))) assert(partAggNode.isDefined, "No HashAggregate node with partial aggregate expression found") From 0a186f0eb5d71732ec3abc6b42a12dae6594277f Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Fri, 7 Aug 2020 14:39:57 -0700 Subject: [PATCH 32/33] UT fix --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 6 ++++-- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index b59f46dc3fe66..70e768502e98d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -419,8 +419,10 @@ case class HashAggregateExec( private var isFastHashMapEnabled: Boolean = false private var avoidSpillInPartialAggregateTerm: String = _ - private val skipPartialAggregateEnabled = sqlContext.conf.skipPartialAggregate && - !modes.exists(_ != Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty + private val skipPartialAggregateEnabled = { + sqlContext.conf.skipPartialAggregate && + modes.nonEmpty && modes.forall(_ == Partial) && find(_.isInstanceOf[ExpandExec]).isEmpty + } private var rowCountTerm: String = _ private var outputFunc: String = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 1687f462d56f5..6ceb24073141d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -60,7 +60,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession val aggDF = data.toDF("name", "values").groupBy("name").sum("values") val partAggNode = aggDF.queryExecution.executedPlan.find { case h: HashAggregateExec => - !h.aggregateExpressions.map(_.mode).exists(_ != Partial) + val modes = h.aggregateExpressions.map(_.mode) + modes.nonEmpty && modes.forall(_ == Partial) case _ => false } From 11572a105ef870c9e95b4302e6613b7f0e73d0de Mon Sep 17 00:00:00 2001 From: Karuppayya Rajendran Date: Tue, 11 Aug 2020 09:53:42 -0700 
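
A note on the nonEmpty guard introduced in PATCH 32: Scala's forall is
vacuously true on an empty collection, so both the original
modes.forall(_ == Partial) and the intermediate !modes.exists(_ != Partial)
would classify a HashAggregateExec that carries no aggregate expressions as
partial-only. The stand-alone sketch below illustrates the pitfall;
AggregateMode, Partial, and Final here are local stand-ins for the catalyst
types, not Spark's actual classes.

    // Illustrative only: local stand-ins for Spark's catalyst AggregateMode values.
    sealed trait AggregateMode
    case object Partial extends AggregateMode
    case object Final extends AggregateMode

    object ForallPitfall extends App {
      val noModes: Seq[AggregateMode] = Seq.empty
      val partialOnly: Seq[AggregateMode] = Seq(Partial, Partial)

      // forall is vacuously true on an empty sequence, so a node with no
      // aggregate expressions would wrongly qualify for the skip path.
      assert(noModes.forall(_ == Partial))
      // The nonEmpty guard rejects the empty case explicitly.
      assert(!(noModes.nonEmpty && noModes.forall(_ == Partial)))
      assert(partialOnly.nonEmpty && partialOnly.forall(_ == Partial))
    }
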
From 11572a105ef870c9e95b4302e6613b7f0e73d0de Mon Sep 17 00:00:00 2001
From: Karuppayya Rajendran
Date: Tue, 11 Aug 2020 09:53:42 -0700
Subject: [PATCH 33/33] Address review comments

---
 .../org/apache/spark/sql/internal/SQLConf.scala  | 14 +++++++++++---
 .../sql/execution/WholeStageCodegenSuite.scala   | 10 +++++-----
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 685efa2e01b9c..f2ae777c5b9c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2202,15 +2202,22 @@ object SQLConf {
       .internal()
       .doc("Number of records after which aggregate operator checks if " +
         "partial aggregation phase can be avoided")
+      .version("3.1.0")
       .longConf
       .createWithDefault(100000)

   val SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO =
     buildConf("spark.sql.aggregate.skipPartialAggregate.aggregateRatio")
       .internal()
-      .doc("Ratio of number of records present in map of Aggregate operator" +
-        "to the total number of records processed by the Aggregate operator")
+      .doc("Ratio beyond which the partial aggregation is skipped. " +
+        "This is computed by taking the ratio of the number of records present" +
+        " in the map of the Aggregate operator to the total number of records processed" +
+        " by the Aggregate operator.")
+      .version("3.1.0")
       .doubleConf
+      .checkValue(ratio => ratio > 0 && ratio < 1, "Invalid value for " +
+        "spark.sql.aggregate.skipPartialAggregate.aggregateRatio. Valid value needs" +
+        " to be between 0 and 1")
       .createWithDefault(0.5)

   val SKIP_PARTIAL_AGGREGATE_ENABLED =
@@ -2219,8 +2226,9 @@ object SQLConf {
       .doc("When enabled, the partial aggregation is skipped when the following" +
         "two conditions are met. 1. When the total number of records processed is greater" +
         s"than threshold defined by ${SKIP_PARTIAL_AGGREGATE_MINROWS.key} 2. When the ratio" +
-        "of recornd count in map to the total records is less that value defined by " +
+        "of record count in map to the total records is less than the value defined by " +
         s"${SKIP_PARTIAL_AGGREGATE_AGGREGATE_RATIO.key}")
+      .version("3.1.0")
       .booleanConf
       .createWithDefault(true)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6ceb24073141d..54f1746c1ad90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -53,8 +53,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }

   test("Avoid spill in partial aggregation" ) {
-    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true"),
-      (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key, "2")) {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true"),
+      (SQLConf.SKIP_PARTIAL_AGGREGATE_MINROWS.key -> "2")) {
       // Create Dataframes
       val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
       val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
@@ -73,9 +73,9 @@
     }
   }

-  test(s"Distinct: Partial aggregation should happen for" +
-    s" HashAggregate nodes performing partial Aggregate operations " ) {
-    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+  test(s"Distinct: Partial aggregation should happen for " +
+    "HashAggregate nodes performing partial Aggregate operations " ) {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key -> "true")) {
       val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
       val aggNodes = aggDF.queryExecution.executedPlan.collect {
         case h: HashAggregateExec => h