From 400db3d1dad4b92e0a70b89c2c687a42017d589b Mon Sep 17 00:00:00 2001 From: pgandhi Date: Tue, 19 Mar 2019 17:26:25 -0400 Subject: [PATCH 01/13] [SPARK-27207] : Ensure aggregate buffers are initialized again for SortBasedAggregate Normally, the aggregate operations that are invoked for an aggregation buffer for User Defined Aggregate Functions (UDAFs) follow the order initialize(), update(), eval() or initialize(), merge(), eval(). However, once the threshold configured by spark.sql.objectHashAggregate.sortBased.fallbackThreshold is reached, ObjectHashAggregate falls back to SortBasedAggregator, which invokes the merge or update operation without calling initialize() on the aggregate buffer. The fix here is to initialize the aggregate buffers again when the fallback to the SortBasedAggregate operator happens. --- .../aggregate/ObjectAggregationIterator.scala | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 43514f5271ac8..9f91fc0b689ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -59,7 +59,8 @@ class ObjectAggregationIterator( private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _ // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers - private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { + var (sortBasedAggExpressions, sortBasedAggFunctions): ( + Seq[AggregateExpression], Array[AggregateFunction]) = { val newExpressions = aggregateExpressions.map { case agg @ AggregateExpression(_, Partial, _, _) => agg.copy(mode = PartialMerge) @@ -67,9 +68,12 @@ class ObjectAggregationIterator( agg.copy(mode = Final) case other => other } - val newFunctions = initializeAggregateFunctions(newExpressions, 0) - val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) - generateProcessRow(newExpressions, newFunctions, newInputAttributes) + (newExpressions, initializeAggregateFunctions(newExpressions, 0)) + } + + private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { + val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes) + generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes) } /** @@ -93,7 +97,7 @@ class ObjectAggregationIterator( */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { if (groupingExpressions.isEmpty) { - val defaultAggregationBuffer = createNewAggregationBuffer() + val defaultAggregationBuffer = createNewAggregationBuffer(aggregateFunctions) generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer) } else { throw new IllegalStateException( @@ -106,18 +110,20 @@ class ObjectAggregationIterator( // // - when creating aggregation buffer for a new group in the hash map, and // - when creating the re-used buffer for sort-based aggregation - private def createNewAggregationBuffer(): SpecificInternalRow = { - val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType)) + private def createNewAggregationBuffer( + functions: Array[AggregateFunction]): SpecificInternalRow = { + val bufferFieldTypes = functions.flatMap(_.aggBufferAttributes.map(_.dataType)) val
buffer = new SpecificInternalRow(bufferFieldTypes) - initAggregationBuffer(buffer) + initAggregationBuffer(buffer, functions) buffer } - private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = { + private def initAggregationBuffer( + buffer: SpecificInternalRow, functions: Array[AggregateFunction]): Unit = { // Initializes declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initializes imperative aggregates' buffer values - aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + functions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) } private def getAggregationBufferByKey( @@ -125,7 +131,7 @@ class ObjectAggregationIterator( var aggBuffer = hashMap.getAggregationBuffer(groupingKey) if (aggBuffer == null) { - aggBuffer = createNewAggregationBuffer() + aggBuffer = createNewAggregationBuffer(aggregateFunctions) hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer) } @@ -183,7 +189,7 @@ class ObjectAggregationIterator( StructType.fromAttributes(groupingAttributes), processRow, mergeAggregationBuffers, - createNewAggregationBuffer()) + createNewAggregationBuffer(sortBasedAggFunctions)) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow From ea050f7547597f82e591cc98085feefe58c4ed79 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Thu, 21 Mar 2019 11:25:43 -0500 Subject: [PATCH 02/13] [SPARK-27207] : Adding Unit Test and addressing reviews Adding unit test and refactoring code --- .../aggregate/ObjectAggregationIterator.scala | 13 +-- .../sql/TypedImperativeAggregateSuite.scala | 101 ++++++++++++++++++ 2 files changed, 105 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 9f91fc0b689ec..4373726c21580 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -58,8 +58,7 @@ class ObjectAggregationIterator( private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _ - // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers - var (sortBasedAggExpressions, sortBasedAggFunctions): ( + val (sortBasedAggExpressions, sortBasedAggFunctions): ( Seq[AggregateExpression], Array[AggregateFunction]) = { val newExpressions = aggregateExpressions.map { case agg @ AggregateExpression(_, Partial, _, _) => @@ -71,6 +70,7 @@ class ObjectAggregationIterator( (newExpressions, initializeAggregateFunctions(newExpressions, 0)) } + // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes) generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes) @@ -111,19 +111,14 @@ class ObjectAggregationIterator( // - when creating aggregation buffer for a new group in the hash map, and // - when creating the re-used buffer for sort-based aggregation private def createNewAggregationBuffer( - functions: Array[AggregateFunction]): SpecificInternalRow = { + functions: Array[AggregateFunction]): SpecificInternalRow = { val bufferFieldTypes = 
functions.flatMap(_.aggBufferAttributes.map(_.dataType)) val buffer = new SpecificInternalRow(bufferFieldTypes) - initAggregationBuffer(buffer, functions) - buffer - } - - private def initAggregationBuffer( - buffer: SpecificInternalRow, functions: Array[AggregateFunction]): Unit = { // Initializes declarative aggregates' buffer values expressionAggInitialProjection.target(buffer)(EmptyRow) // Initializes imperative aggregates' buffer values functions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + buffer } private def getAggregationBufferByKey( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index c5fb17345222a..7f445befe0fc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax +import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate @@ -210,6 +211,20 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } + test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") { + withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") { + val df = data.toDF("value", "key").coalesce(2) + val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value")) + val expected = data.groupBy(_._2).toSeq.map { group => + val (key, values) = group + val valueMax = values.map(_._1).max + val countValue = values.size + Row(key, valueMax, countValue, valueMax) + } + checkAnswer(query, expected) + } + } + private def typedMax(column: Column): Column = { val max = TypedMax(column.expr, nullable = false) Column(max.toAggregateExpression()) @@ -219,6 +234,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { val max = TypedMax(column.expr, nullable = true) Column(max.toAggregateExpression()) } + + private def typedMax2(column: Column): Column = { + val max = TypedMax2(column.expr, nullable = false) + Column(max.toAggregateExpression()) + } } object TypedImperativeAggregateSuite { @@ -299,5 +319,86 @@ object TypedImperativeAggregateSuite { } } + /** + * Calculate the max value with object aggregation buffer. This stores class MaxValue + * in aggregation buffer. 
+ */ + private case class TypedMax2( + child: Expression, + nullable: Boolean = false, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes { + + + var maxValueBuffer: MaxValue = null + override def createAggregationBuffer(): MaxValue = { + // Returns Int.MinValue if all inputs are null + maxValueBuffer = new MaxValue(Int.MinValue) + maxValueBuffer + } + + override def update(buffer: MaxValue, input: InternalRow): MaxValue = { + child.eval(input) match { + case inputValue: Int => + if (inputValue > buffer.value) { + buffer.value = inputValue + buffer.isValueSet = true + } + case null => // skip + } + buffer + } + + override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = { + // The below if condition will throw a Null Pointer Exception if initialize() is not called + if (maxValueBuffer.isValueSet) { + // do nothing + } + if (inputMax.value > bufferMax.value) { + bufferMax.value = inputMax.value + bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet + } + bufferMax + } + + override def eval(bufferMax: MaxValue): Any = { + if (nullable && bufferMax.isValueSet == false) { + null + } else { + bufferMax.value + } + } + + override lazy val deterministic: Boolean = true + + override def children: Seq[Expression] = Seq(child) + + override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) + + override def dataType: DataType = IntegerType + + override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = + copy(inputAggBufferOffset = newOffset) + + override def serialize(buffer: MaxValue): Array[Byte] = { + val out = new ByteArrayOutputStream() + val stream = new DataOutputStream(out) + stream.writeBoolean(buffer.isValueSet) + stream.writeInt(buffer.value) + out.toByteArray + } + + override def deserialize(storageFormat: Array[Byte]): MaxValue = { + val in = new ByteArrayInputStream(storageFormat) + val stream = new DataInputStream(in) + val isValueSet = stream.readBoolean() + val value = stream.readInt() + new MaxValue(value, isValueSet) + } + } private class MaxValue(var value: Int, var isValueSet: Boolean = false) } From 07148763e5f8aa83f846ccdc165fd1fb4907f2da Mon Sep 17 00:00:00 2001 From: pgandhi Date: Thu, 21 Mar 2019 11:32:25 -0500 Subject: [PATCH 03/13] [SPARK-27207] : Fixing Scalastyle tests --- .../apache/spark/sql/TypedImperativeAggregateSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 7f445befe0fc2..103aba0b9a0cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -320,9 +320,9 @@ object TypedImperativeAggregateSuite { } /** - * Calculate the max value with object aggregation buffer. This stores class MaxValue - * in aggregation buffer. - */ + * Calculate the max value with object aggregation buffer. This stores class MaxValue + * in aggregation buffer. 
+ */ private case class TypedMax2( child: Expression, nullable: Boolean = false, From 4dc1007df9f5fe31a56163c4190eec37b11ea1ce Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 29 Mar 2019 13:28:54 -0500 Subject: [PATCH 04/13] [SPARK-27207] : Fix SortBasedAggregator to run with different aggregate functions and write unit test Fix SortBasedAggregator to ensure that update and merge are performed with two different sets of aggregate functions, one for update and one for merge respectively. --- .../aggregate/ObjectAggregationIterator.scala | 52 ++++++--- .../sql/TypedImperativeAggregateSuite.scala | 100 ------------------ .../SortBasedAggregationStoreSuite.scala | 5 +- .../sql/hive/execution/HiveUDAFSuite.scala | 30 ++++++ 4 files changed, 73 insertions(+), 114 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 4373726c21580..4f331a65d93f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -58,7 +58,7 @@ class ObjectAggregationIterator( private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _ - val (sortBasedAggExpressions, sortBasedAggFunctions): ( + val (sortBasedMergeAggExpressions, sortBasedMergeAggFunctions): ( Seq[AggregateExpression], Array[AggregateFunction]) = { val newExpressions = aggregateExpressions.map { case agg @ AggregateExpression(_, Partial, _, _) => @@ -72,8 +72,9 @@ class ObjectAggregationIterator( // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { - val newInputAttributes = sortBasedAggFunctions.flatMap(_.inputAggBufferAttributes) - generateProcessRow(sortBasedAggExpressions, sortBasedAggFunctions, newInputAttributes) + val newInputAttributes = sortBasedMergeAggFunctions.flatMap(_.inputAggBufferAttributes) + generateProcessRow( + sortBasedMergeAggExpressions, sortBasedMergeAggFunctions, newInputAttributes) } /** @@ -184,7 +185,9 @@ class ObjectAggregationIterator( StructType.fromAttributes(groupingAttributes), processRow, mergeAggregationBuffers, - createNewAggregationBuffer(sortBasedAggFunctions)) + createNewAggregationBuffer(aggregateFunctions), + createNewAggregationBuffer(sortBasedMergeAggFunctions), + aggregateFunctions) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow @@ -212,7 +215,12 @@ class ObjectAggregationIterator( * @param processRow Function to update the aggregation buffer with input rows * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing * aggregation buffers - * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer + * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation + * buffer for update operation + * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation + * buffer for merge operation + * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the + * aggregation buffer * * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]]. 
*/ @@ -222,7 +230,9 @@ class SortBasedAggregator( groupingSchema: StructType, processRow: (InternalRow, InternalRow) => Unit, mergeAggregationBuffers: (InternalRow, InternalRow) => Unit, - makeEmptyAggregationBuffer: => InternalRow) { + makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow, + makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow, + sortBasedUpdateAggFunctions: Array[AggregateFunction]) { // external sorter to sort the input (grouping key + input row) with grouping key. private val inputSorter = createExternalSorterForInput() @@ -231,6 +241,10 @@ class SortBasedAggregator( def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = { inputSorter.insertKV(groupingKey, inputRow) } + private def serializeBuffer(buffer: InternalRow): Unit = { + sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( + _.serializeAggregateBufferInPlace(buffer)) + } /** * Returns a destructive iterator of AggregationBufferEntry. @@ -241,16 +255,18 @@ class SortBasedAggregator( val inputIterator = inputSorter.sortedIterator() var hasNextInput: Boolean = inputIterator.next() var hasNextAggBuffer: Boolean = initialAggBufferIterator.next() - private var result: AggregationBufferEntry = _ + private var updateResult: AggregationBufferEntry = _ + private var finalResult: AggregationBufferEntry = _ private var groupingKey: UnsafeRow = _ override def hasNext(): Boolean = { - result != null || findNextSortedGroup() + updateResult != null || finalResult != null || findNextSortedGroup() } override def next(): AggregationBufferEntry = { - val returnResult = result - result = null + val returnResult = finalResult + updateResult = null + finalResult = null returnResult } @@ -259,21 +275,31 @@ class SortBasedAggregator( if (hasNextInput || hasNextAggBuffer) { // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator groupingKey = findGroupingKey() - result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer) + updateResult = new AggregationBufferEntry( + groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions) + finalResult = new AggregationBufferEntry( + groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions) // Firstly, update the aggregation buffer with input rows. while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - processRow(result.aggregationBuffer, inputIterator.getValue) + processRow(updateResult.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } + // This step ensures that the contents of the updateResult aggregation buffer are + // merged with the finalResult aggregation buffer to maintain consistency + if (hasNextAggBuffer) { + serializeBuffer(updateResult.aggregationBuffer) + mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) + } // Secondly, merge the aggregation buffer with existing aggregation buffers. // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should // be called after calling processRow. 
while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { - mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue) + mergeAggregationBuffers( + finalResult.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 103aba0b9a0cd..2ceaa4b57c7df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax -import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate @@ -211,20 +210,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } - test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") { - withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") { - val df = data.toDF("value", "key").coalesce(2) - val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value")) - val expected = data.groupBy(_._2).toSeq.map { group => - val (key, values) = group - val valueMax = values.map(_._1).max - val countValue = values.size - Row(key, valueMax, countValue, valueMax) - } - checkAnswer(query, expected) - } - } - private def typedMax(column: Column): Column = { val max = TypedMax(column.expr, nullable = false) Column(max.toAggregateExpression()) @@ -235,10 +220,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { Column(max.toAggregateExpression()) } - private def typedMax2(column: Column): Column = { - val max = TypedMax2(column.expr, nullable = false) - Column(max.toAggregateExpression()) - } } object TypedImperativeAggregateSuite { @@ -319,86 +300,5 @@ object TypedImperativeAggregateSuite { } } - /** - * Calculate the max value with object aggregation buffer. This stores class MaxValue - * in aggregation buffer. 
- */ - private case class TypedMax2( - child: Expression, - nullable: Boolean = false, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes { - - - var maxValueBuffer: MaxValue = null - override def createAggregationBuffer(): MaxValue = { - // Returns Int.MinValue if all inputs are null - maxValueBuffer = new MaxValue(Int.MinValue) - maxValueBuffer - } - - override def update(buffer: MaxValue, input: InternalRow): MaxValue = { - child.eval(input) match { - case inputValue: Int => - if (inputValue > buffer.value) { - buffer.value = inputValue - buffer.isValueSet = true - } - case null => // skip - } - buffer - } - - override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = { - // The below if condition will throw a Null Pointer Exception if initialize() is not called - if (maxValueBuffer.isValueSet) { - // do nothing - } - if (inputMax.value > bufferMax.value) { - bufferMax.value = inputMax.value - bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet - } - bufferMax - } - - override def eval(bufferMax: MaxValue): Any = { - if (nullable && bufferMax.isValueSet == false) { - null - } else { - bufferMax.value - } - } - - override lazy val deterministic: Boolean = true - - override def children: Seq[Expression] = Seq(child) - - override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType) - - override def dataType: DataType = IntegerType - - override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = - copy(mutableAggBufferOffset = newOffset) - - override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] = - copy(inputAggBufferOffset = newOffset) - - override def serialize(buffer: MaxValue): Array[Byte] = { - val out = new ByteArrayOutputStream() - val stream = new DataOutputStream(out) - stream.writeBoolean(buffer.isValueSet) - stream.writeInt(buffer.value) - out.toByteArray - } - - override def deserialize(storageFormat: Array[Byte]): MaxValue = { - val in = new ByteArrayInputStream(storageFormat) - val stream = new DataInputStream(in) - val isValueSet = stream.readBoolean() - val value = stream.readInt() - new MaxValue(value, isValueSet) - } - } private class MaxValue(var value: Int, var isValueSet: Boolean = false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index dc67446460877..9b39cf5195227 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.unsafe.KVIterator @@ -78,7 +79,9 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte groupingSchema, updateInputRow, mergeAggBuffer, - createNewAggregationBuffer) + createNewAggregationBuffer, + createNewAggregationBuffer, + sortBasedUpdateAggFunctions = new Array[AggregateFunction](5)) (5000 to 
100000).foreach { _ => randomKV(inputRow, group) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index ef40323c61315..905d2a4cc43ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -49,6 +49,16 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { (2: Integer) -> null, (3: Integer) -> null ).toDF("key", "value").repartition(2).createOrReplaceTempView("t") + Seq( + (0: Integer) -> "val_0", + (1: Integer) -> "val_1", + (2: Integer) -> "val_2", + (3: Integer) -> "val_3", + (4: Integer) -> "val_4", + (5: Integer) -> "val_5", + (6: Integer) -> null, + (7: Integer) -> null + ).toDF("key", "value").repartition(2).createOrReplaceTempView("t2") } protected override def afterAll(): Unit = { @@ -111,6 +121,26 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { )) } + test("SPARK-27207: customized Hive UDAF with two aggregation buffers for Sort" + + " Based Aggregation") { + withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "2") { + val df = sql("SELECT key % 2, mock2(value) FROM t2 GROUP BY key % 2") + + val aggs = df.queryExecution.executedPlan.collect { + case agg: ObjectHashAggregateExec => agg + } + + // There should be two aggregate operators, one for partial aggregation, and the other for + // global aggregation. + assert(aggs.length == 2) + + checkAnswer(df, Seq( + Row(0, Row(3, 1)), + Row(1, Row(3, 1)) + )) + } + } + test("call JAVA UDAF") { withTempView("temp") { withUserDefinedFunction("myDoubleAvg" -> false) { From 088cbc692822a545a5063c482defed0f5491fcfc Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 29 Mar 2019 13:31:18 -0500 Subject: [PATCH 05/13] [SPARK-27207] : Fix new line for TypedImperativeAggregateSuite --- .../org/apache/spark/sql/TypedImperativeAggregateSuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 2ceaa4b57c7df..c5fb17345222a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -219,7 +219,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { val max = TypedMax(column.expr, nullable = true) Column(max.toAggregateExpression()) } - } object TypedImperativeAggregateSuite { From 6a5ed71081060688fabef19ceae39b597980dfc2 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 29 Mar 2019 18:41:46 -0500 Subject: [PATCH 06/13] [SPARK-27207] : Changing design to use one buffer but initializing for different aggregate functions --- .../aggregate/ObjectAggregationIterator.scala | 42 ++++++------------- .../SortBasedAggregationStoreSuite.scala | 3 +- .../sql/hive/execution/HiveUDAFSuite.scala | 4 +- 3 files changed, 16 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 4f331a65d93f2..105e3816b629c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -186,8 +186,7 @@ class ObjectAggregationIterator( processRow, mergeAggregationBuffers, createNewAggregationBuffer(aggregateFunctions), - createNewAggregationBuffer(sortBasedMergeAggFunctions), - aggregateFunctions) + sortBasedMergeAggFunctions) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow @@ -217,10 +216,8 @@ class ObjectAggregationIterator( * aggregation buffers * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation * buffer for update operation - * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation - * buffer for merge operation - * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the - * aggregation buffer + * @param sortBasedMergeAggFunctions aggregate functions needed to serialize the + * aggregation buffer * * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]]. */ @@ -231,8 +228,7 @@ class SortBasedAggregator( processRow: (InternalRow, InternalRow) => Unit, mergeAggregationBuffers: (InternalRow, InternalRow) => Unit, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow, - makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow, - sortBasedUpdateAggFunctions: Array[AggregateFunction]) { + sortBasedMergeAggFunctions: Array[AggregateFunction]) { // external sorter to sort the input (grouping key + input row) with grouping key. private val inputSorter = createExternalSorterForInput() @@ -241,10 +237,6 @@ class SortBasedAggregator( def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = { inputSorter.insertKV(groupingKey, inputRow) } - private def serializeBuffer(buffer: InternalRow): Unit = { - sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( - _.serializeAggregateBufferInPlace(buffer)) - } /** * Returns a destructive iterator of AggregationBufferEntry. @@ -255,18 +247,16 @@ class SortBasedAggregator( val inputIterator = inputSorter.sortedIterator() var hasNextInput: Boolean = inputIterator.next() var hasNextAggBuffer: Boolean = initialAggBufferIterator.next() - private var updateResult: AggregationBufferEntry = _ - private var finalResult: AggregationBufferEntry = _ + private var result: AggregationBufferEntry = _ private var groupingKey: UnsafeRow = _ override def hasNext(): Boolean = { - updateResult != null || finalResult != null || findNextSortedGroup() + result != null || findNextSortedGroup() } override def next(): AggregationBufferEntry = { - val returnResult = finalResult - updateResult = null - finalResult = null + val returnResult = result + result = null returnResult } @@ -275,31 +265,25 @@ class SortBasedAggregator( if (hasNextInput || hasNextAggBuffer) { // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator groupingKey = findGroupingKey() - updateResult = new AggregationBufferEntry( + result = new AggregationBufferEntry( groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions) - finalResult = new AggregationBufferEntry( - groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions) // Firstly, update the aggregation buffer with input rows. 
while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - processRow(updateResult.aggregationBuffer, inputIterator.getValue) + processRow(result.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } - // This step ensures that the contents of the updateResult aggregation buffer are - // merged with the finalResult aggregation buffer to maintain consistency - if (hasNextAggBuffer) { - serializeBuffer(updateResult.aggregationBuffer) - mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) - } // Secondly, merge the aggregation buffer with existing aggregation buffers. // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should // be called after calling processRow. + sortBasedMergeAggFunctions.collect { case f: ImperativeAggregate => f }.foreach( + _.initialize(result.aggregationBuffer)) while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { mergeAggregationBuffers( - finalResult.aggregationBuffer, initialAggBufferIterator.getValue) + result.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 9b39cf5195227..3d4e2407506c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -80,8 +80,7 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte updateInputRow, mergeAggBuffer, createNewAggregationBuffer, - createNewAggregationBuffer, - sortBasedUpdateAggFunctions = new Array[AggregateFunction](5)) + sortBasedMergeAggFunctions = new Array[AggregateFunction](5)) (5000 to 100000).foreach { _ => randomKV(inputRow, group) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 905d2a4cc43ea..07182c3e87a01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -135,8 +135,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(aggs.length == 2) checkAnswer(df, Seq( - Row(0, Row(3, 1)), - Row(1, Row(3, 1)) + Row(0, Row(2, 1)), + Row(1, Row(2, 0)) )) } } From fb9fea82e36bcc98e864c6699d4629fd540a72a2 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Mon, 1 Apr 2019 09:29:42 -0500 Subject: [PATCH 07/13] Revert "[SPARK-27207] : Changing design to use one buffer but initializing for different aggregate functions" This reverts commit 6a5ed71081060688fabef19ceae39b597980dfc2. 
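For context on the revert: the single-buffer design re-initialized the shared aggregation buffer between the update phase and the merge phase, which can discard the values just accumulated from input rows (plausibly why the HiveUDAFSuite expectations had to change above). A minimal standalone sketch of that hazard, assuming a toy count aggregate; none of these names are Spark classes:

// Toy model of the reverted single-buffer design: re-initializing the shared
// buffer after the update phase loses the updates (illustrative only).
object SharedBufferHazard {
  final class CountBuffer(var n: Long)

  def initialize(b: CountBuffer): Unit = b.n = 0L
  def update(b: CountBuffer): Unit = b.n += 1L
  def merge(b: CountBuffer, partial: Long): Unit = b.n += partial

  def main(args: Array[String]): Unit = {
    val shared = new CountBuffer(0L)
    initialize(shared)
    update(shared); update(shared) // update phase sees two input rows: n == 2
    initialize(shared)             // re-initialized before the merge phase...
    merge(shared, 1L)              // ...then a spilled partial count of 1 is merged
    println(shared.n)              // prints 1, not the expected 3
  }
}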
Reverting to previous commit --- .../aggregate/ObjectAggregationIterator.scala | 42 +++++++++++++------ .../SortBasedAggregationStoreSuite.scala | 3 +- .../sql/hive/execution/HiveUDAFSuite.scala | 4 +- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 105e3816b629c..4f331a65d93f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -186,7 +186,8 @@ class ObjectAggregationIterator( processRow, mergeAggregationBuffers, createNewAggregationBuffer(aggregateFunctions), - sortBasedMergeAggFunctions) + createNewAggregationBuffer(sortBasedMergeAggFunctions), + aggregateFunctions) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow @@ -216,8 +217,10 @@ class ObjectAggregationIterator( * aggregation buffers * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation * buffer for update operation - * @param sortBasedMergeAggFunctions aggregate functions needed to serialize the - * aggregation buffer + * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation + * buffer for merge operation + * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the + * aggregation buffer * * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]]. */ @@ -228,7 +231,8 @@ class SortBasedAggregator( processRow: (InternalRow, InternalRow) => Unit, mergeAggregationBuffers: (InternalRow, InternalRow) => Unit, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow, - sortBasedMergeAggFunctions: Array[AggregateFunction]) { + makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow, + sortBasedUpdateAggFunctions: Array[AggregateFunction]) { // external sorter to sort the input (grouping key + input row) with grouping key. private val inputSorter = createExternalSorterForInput() @@ -237,6 +241,10 @@ class SortBasedAggregator( def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = { inputSorter.insertKV(groupingKey, inputRow) } + private def serializeBuffer(buffer: InternalRow): Unit = { + sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( + _.serializeAggregateBufferInPlace(buffer)) + } /** * Returns a destructive iterator of AggregationBufferEntry. 
@@ -247,16 +255,18 @@ class SortBasedAggregator( val inputIterator = inputSorter.sortedIterator() var hasNextInput: Boolean = inputIterator.next() var hasNextAggBuffer: Boolean = initialAggBufferIterator.next() - private var result: AggregationBufferEntry = _ + private var updateResult: AggregationBufferEntry = _ + private var finalResult: AggregationBufferEntry = _ private var groupingKey: UnsafeRow = _ override def hasNext(): Boolean = { - result != null || findNextSortedGroup() + updateResult != null || finalResult != null || findNextSortedGroup() } override def next(): AggregationBufferEntry = { - val returnResult = result - result = null + val returnResult = finalResult + updateResult = null + finalResult = null returnResult } @@ -265,25 +275,31 @@ class SortBasedAggregator( if (hasNextInput || hasNextAggBuffer) { // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator groupingKey = findGroupingKey() - result = new AggregationBufferEntry( + updateResult = new AggregationBufferEntry( groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions) + finalResult = new AggregationBufferEntry( + groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions) // Firstly, update the aggregation buffer with input rows. while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - processRow(result.aggregationBuffer, inputIterator.getValue) + processRow(updateResult.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } + // This step ensures that the contents of the updateResult aggregation buffer are + // merged with the finalResult aggregation buffer to maintain consistency + if (hasNextAggBuffer) { + serializeBuffer(updateResult.aggregationBuffer) + mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) + } // Secondly, merge the aggregation buffer with existing aggregation buffers. // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should // be called after calling processRow. 
- sortBasedMergeAggFunctions.collect { case f: ImperativeAggregate => f }.foreach( - _.initialize(result.aggregationBuffer)) while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { mergeAggregationBuffers( - result.aggregationBuffer, initialAggBufferIterator.getValue) + finalResult.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 3d4e2407506c6..9b39cf5195227 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -80,7 +80,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte updateInputRow, mergeAggBuffer, createNewAggregationBuffer, - sortBasedMergeAggFunctions = new Array[AggregateFunction](5)) + createNewAggregationBuffer, + sortBasedUpdateAggFunctions = new Array[AggregateFunction](5)) (5000 to 100000).foreach { _ => randomKV(inputRow, group) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 07182c3e87a01..905d2a4cc43ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -135,8 +135,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(aggs.length == 2) checkAnswer(df, Seq( - Row(0, Row(2, 1)), - Row(1, Row(2, 0)) + Row(0, Row(3, 1)), + Row(1, Row(3, 1)) )) } } From 433b1bb5e1a4a53f604bd93868d341e8d14ceb62 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Fri, 5 Apr 2019 14:17:46 -0500 Subject: [PATCH 08/13] [SPARK-27207] : Revert to previous commit and serialize the buffer with sortBasedMergeAggregateFunctions Use two aggregate buffers as in previous commit and serialize updateResult on sortBasedMergeAggFunctions before merging it with finalResult. 
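Per sort-merged group, the intended flow is: initialize an update buffer, fold the group's input rows into it with the update-mode functions, serialize that buffer in place, and only then merge it, together with any spilled buffers (which arrive already serialized), into a separately initialized final buffer via the merge-mode functions. A minimal self-contained sketch of that ordering, assuming a toy typed max aggregate; TwoBufferFlowSketch and its members are illustrative stand-ins, not Spark internals:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Toy model of the two-buffer flow: update into an object buffer, serialize it
// in place, then merge serialized buffers into a separately created final buffer.
object TwoBufferFlowSketch {
  final class MaxValue(var value: Int, var isValueSet: Boolean = false)

  def newBuffer(): MaxValue = new MaxValue(Int.MinValue) // initialize()

  def update(buffer: MaxValue, input: Int): Unit = {
    if (input > buffer.value) { buffer.value = input; buffer.isValueSet = true }
  }

  // Stand-in for serializeAggregateBufferInPlace: object form to binary form.
  def serialize(buffer: MaxValue): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val stream = new DataOutputStream(out)
    stream.writeBoolean(buffer.isValueSet)
    stream.writeInt(buffer.value)
    out.toByteArray
  }

  def deserialize(bytes: Array[Byte]): MaxValue = {
    val stream = new DataInputStream(new ByteArrayInputStream(bytes))
    val isValueSet = stream.readBoolean()
    new MaxValue(stream.readInt(), isValueSet)
  }

  // Merge a serialized buffer into the object-based final buffer, as the
  // hacked PartialMerge/Final-mode functions do with their binary inputs.
  def merge(finalBuffer: MaxValue, serialized: Array[Byte]): Unit = {
    val other = deserialize(serialized)
    if (other.value > finalBuffer.value) finalBuffer.value = other.value
    finalBuffer.isValueSet = finalBuffer.isValueSet || other.isValueSet
  }

  def main(args: Array[String]): Unit = {
    val updateResult = newBuffer()                // buffer for the update phase
    Seq(3, 7, 5).foreach(update(updateResult, _)) // fold the group's input rows
    val finalResult = newBuffer()                 // separate buffer for merging
    merge(finalResult, serialize(updateResult))   // serialize first, then merge
    val spilled = newBuffer(); update(spilled, 9) // a spilled partial result
    merge(finalResult, serialize(spilled))        // spills are already serialized
    assert(finalResult.value == 9 && finalResult.isValueSet)
  }
}

The serialize step is what makes the merge legal: the merge-mode functions read their input buffers in the binary form produced by serializeAggregateBufferInPlace, while update works on the live object form, which is why the patch serializes updateResult before folding it into finalResult.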
--- .../aggregate/ObjectAggregationIterator.scala | 14 ++++++-------- .../aggregate/SortBasedAggregationStoreSuite.scala | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 4f331a65d93f2..5e5705214e475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -187,7 +187,7 @@ class ObjectAggregationIterator( mergeAggregationBuffers, createNewAggregationBuffer(aggregateFunctions), createNewAggregationBuffer(sortBasedMergeAggFunctions), - aggregateFunctions) + sortBasedMergeAggFunctions) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow @@ -219,7 +219,7 @@ class ObjectAggregationIterator( * buffer for update operation * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation * buffer for merge operation - * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the + * @param sortBasedMergeAggFunctions aggregate functions needed to serialize the * aggregation buffer * * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]]. @@ -232,7 +232,7 @@ class SortBasedAggregator( mergeAggregationBuffers: (InternalRow, InternalRow) => Unit, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow, makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow, - sortBasedUpdateAggFunctions: Array[AggregateFunction]) { + sortBasedMergeAggFunctions: Array[AggregateFunction]) { // external sorter to sort the input (grouping key + input row) with grouping key. private val inputSorter = createExternalSorterForInput() @@ -242,7 +242,7 @@ class SortBasedAggregator( inputSorter.insertKV(groupingKey, inputRow) } private def serializeBuffer(buffer: InternalRow): Unit = { - sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( + sortBasedMergeAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( _.serializeAggregateBufferInPlace(buffer)) } @@ -289,10 +289,8 @@ class SortBasedAggregator( // This step ensures that the contents of the updateResult aggregation buffer are // merged with the finalResult aggregation buffer to maintain consistency - if (hasNextAggBuffer) { - serializeBuffer(updateResult.aggregationBuffer) - mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) - } + serializeBuffer(updateResult.aggregationBuffer) + mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) // Secondly, merge the aggregation buffer with existing aggregation buffers. // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should // be called after calling processRow. 
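The suite change below threads the new constructor arguments through; the algorithm it exercises is a group-by-group merge of two key-sorted streams (raw input rows and spilled partial buffers), where each group starts from a freshly created buffer. A standalone model of that pattern, assuming integer keys and plain sums in place of real rows and aggregation buffers:

// Standalone model of SortBasedAggregator's destructive iterator: consume two
// key-sorted streams one grouping key at a time (illustrative, not Spark's API).
object SortedGroupMergeSketch {
  def mergeSortedGroups(
      inputs: List[(Int, Int)],  // (groupingKey, inputValue), sorted by key
      spilled: List[(Int, Int)]  // (groupingKey, partialSum), sorted by key
  ): List[(Int, Int)] = {
    var in = inputs
    var sp = spilled
    val out = List.newBuilder[(Int, Int)]
    while (in.nonEmpty || sp.nonEmpty) {
      // findGroupingKey(): the smaller head key of the two sorted streams.
      val key = math.min(
        in.headOption.map(_._1).getOrElse(Int.MaxValue),
        sp.headOption.map(_._1).getOrElse(Int.MaxValue))
      var acc = 0 // freshly created buffer for this group
      while (in.nonEmpty && in.head._1 == key) { acc += in.head._2; in = in.tail } // update phase
      while (sp.nonEmpty && sp.head._1 == key) { acc += sp.head._2; sp = sp.tail } // merge phase
      out += key -> acc
    }
    out.result()
  }

  def main(args: Array[String]): Unit =
    // prints List((1,15), (2,1), (3,4))
    println(mergeSortedGroups(List(1 -> 2, 1 -> 3, 2 -> 1), List(1 -> 10, 3 -> 4)))
}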
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 9b39cf5195227..ac8f5b8077168 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -81,7 +81,7 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte mergeAggBuffer, createNewAggregationBuffer, createNewAggregationBuffer, - sortBasedUpdateAggFunctions = new Array[AggregateFunction](5)) + sortBasedMergeAggFunctions = new Array[AggregateFunction](5)) (5000 to 100000).foreach { _ => randomKV(inputRow, group) From 8f5c6b02e6ce789b61395b40d78962f57b1d05b3 Mon Sep 17 00:00:00 2001 From: pgandhi Date: Tue, 30 Apr 2019 09:41:19 -0500 Subject: [PATCH 09/13] [SPARK-27207] : Reverting the two-buffer logic and simplifying the code Since https://github.com/apache/spark/pull/24459 fixes the init-update-merge issue, the fix here is reverted. --- .../aggregate/ObjectAggregationIterator.scala | 43 +++++-------------- .../SortBasedAggregationStoreSuite.scala | 5 +-- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 5e5705214e475..90689f0e8fa5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -185,9 +185,7 @@ class ObjectAggregationIterator( StructType.fromAttributes(groupingAttributes), processRow, mergeAggregationBuffers, - createNewAggregationBuffer(aggregateFunctions), - createNewAggregationBuffer(sortBasedMergeAggFunctions), - sortBasedMergeAggFunctions) + createNewAggregationBuffer(sortBasedMergeAggFunctions)) while (inputRows.hasNext) { // NOTE: The input row is always UnsafeRow @@ -215,12 +213,7 @@ class ObjectAggregationIterator( * @param processRow Function to update the aggregation buffer with input rows * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing * aggregation buffers - * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation - * buffer for update operation - * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation - * buffer for merge operation - * @param sortBasedMergeAggFunctions aggregate functions needed to serialize the - * aggregation buffer + * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer * * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]]. */ @@ -230,9 +223,7 @@ class SortBasedAggregator( groupingSchema: StructType, processRow: (InternalRow, InternalRow) => Unit, mergeAggregationBuffers: (InternalRow, InternalRow) => Unit, - makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow, - makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow, - sortBasedMergeAggFunctions: Array[AggregateFunction]) { + makeEmptyAggregationBuffer: => InternalRow) { // external sorter to sort the input (grouping key + input row) with grouping key.
private val inputSorter = createExternalSorterForInput() @@ -241,10 +232,6 @@ class SortBasedAggregator( def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = { inputSorter.insertKV(groupingKey, inputRow) } - private def serializeBuffer(buffer: InternalRow): Unit = { - sortBasedMergeAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach( - _.serializeAggregateBufferInPlace(buffer)) - } /** * Returns a destructive iterator of AggregationBufferEntry. @@ -255,18 +242,16 @@ class SortBasedAggregator( val inputIterator = inputSorter.sortedIterator() var hasNextInput: Boolean = inputIterator.next() var hasNextAggBuffer: Boolean = initialAggBufferIterator.next() - private var updateResult: AggregationBufferEntry = _ - private var finalResult: AggregationBufferEntry = _ + private var result: AggregationBufferEntry = _ private var groupingKey: UnsafeRow = _ override def hasNext(): Boolean = { - updateResult != null || finalResult != null || findNextSortedGroup() + result != null || findNextSortedGroup() } override def next(): AggregationBufferEntry = { - val returnResult = finalResult - updateResult = null - finalResult = null + val returnResult = result + result = null returnResult } @@ -275,29 +260,21 @@ class SortBasedAggregator( if (hasNextInput || hasNextAggBuffer) { // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator groupingKey = findGroupingKey() - updateResult = new AggregationBufferEntry( - groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions) - finalResult = new AggregationBufferEntry( - groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions) + result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer) // Firstly, update the aggregation buffer with input rows. while (hasNextInput && groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) { - processRow(updateResult.aggregationBuffer, inputIterator.getValue) + processRow(result.aggregationBuffer, inputIterator.getValue) hasNextInput = inputIterator.next() } - // This step ensures that the contents of the updateResult aggregation buffer are - // merged with the finalResult aggregation buffer to maintain consistency - serializeBuffer(updateResult.aggregationBuffer) - mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer) // Secondly, merge the aggregation buffer with existing aggregation buffers. // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should // be called after calling processRow. 
while (hasNextAggBuffer && groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) { - mergeAggregationBuffers( - finalResult.aggregationBuffer, initialAggBufferIterator.getValue) + mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue) hasNextAggBuffer = initialAggBufferIterator.next() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index ac8f5b8077168..dc67446460877 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.unsafe.KVIterator @@ -79,9 +78,7 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte groupingSchema, updateInputRow, mergeAggBuffer, - createNewAggregationBuffer, - createNewAggregationBuffer, - sortBasedMergeAggFunctions = new Array[AggregateFunction](5)) + createNewAggregationBuffer) (5000 to 100000).foreach { _ => randomKV(inputRow, group) From df330fa1cf45e4533d93a8aebade12a504b6726b Mon Sep 17 00:00:00 2001 From: pgandhi Date: Mon, 6 May 2019 16:03:49 -0500 Subject: [PATCH 10/13] [SPARK-27207] : Coming up with a unit test for custom UDAF --- .../sql/TypedImperativeAggregateSuite.scala | 102 ++++++++++++++++++ .../sql/hive/execution/HiveUDAFSuite.scala | 30 ------ 2 files changed, 102 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index c5fb17345222a..f50ef8d3be806 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax +import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate @@ -210,6 +211,20 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { checkAnswer(query, expected) } + test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") { + withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") { + val df = data.toDF("value", "key").coalesce(2) + val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value")) + val expected = data.groupBy(_._2).toSeq.map { group => + val (key, values) = group + val valueMax = values.map(_._1).max + val countValue = values.size + Row(key, valueMax, 
countValue, valueMax)
+      }
+      checkAnswer(query, expected)
+    }
+  }
+
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr, nullable = false)
     Column(max.toAggregateExpression())
   }
@@ -219,6 +234,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     val max = TypedMax(column.expr, nullable = true)
     Column(max.toAggregateExpression())
   }
+
+  private def typedMax2(column: Column): Column = {
+    val max = TypedMax2(column.expr, nullable = false)
+    Column(max.toAggregateExpression())
+  }
 }
 
 object TypedImperativeAggregateSuite {
@@ -299,5 +319,87 @@ object TypedImperativeAggregateSuite {
     }
   }
 
+  /**
+  * Calculate the max value with object aggregation buffer. This stores class MaxValue
+  * in aggregation buffer.
+  */
+  private case class TypedMax2(
+    child: Expression,
+    nullable: Boolean = false,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
+
+
+    var maxValueBuffer: MaxValue = null
+    override def createAggregationBuffer(): MaxValue = {
+      // Returns Int.MinValue if all inputs are null
+      maxValueBuffer = new MaxValue(Int.MinValue)
+      maxValueBuffer
+    }
+
+    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
+      child.eval(input) match {
+        case inputValue: Int =>
+          if (inputValue > buffer.value) {
+            buffer.value = inputValue
+            buffer.isValueSet = true
+          }
+        case null => // skip
+      }
+      buffer
+    }
+
+    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
+      // The check below will throw a NullPointerException if initialize() is not called
+      if (maxValueBuffer.isValueSet) {
+        // do nothing
+      }
+      if (inputMax.value > bufferMax.value) {
+        bufferMax.value = inputMax.value
+        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
+      }
+      bufferMax
+    }
+
+    override def eval(bufferMax: MaxValue): Any = {
+      if (nullable && bufferMax.isValueSet == false) {
+        null
+      } else {
+        bufferMax.value
+      }
+    }
+
+    override lazy val deterministic: Boolean = true
+
+    override def children: Seq[Expression] = Seq(child)
+
+    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
+
+    override def dataType: DataType = IntegerType
+
+    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(mutableAggBufferOffset = newOffset)
+
+    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
+      copy(inputAggBufferOffset = newOffset)
+
+    override def serialize(buffer: MaxValue): Array[Byte] = {
+      val out = new ByteArrayOutputStream()
+      val stream = new DataOutputStream(out)
+      stream.writeBoolean(buffer.isValueSet)
+      stream.writeInt(buffer.value)
+      out.toByteArray
+    }
+
+    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
+      val in = new ByteArrayInputStream(storageFormat)
+      val stream = new DataInputStream(in)
+      val isValueSet = stream.readBoolean()
+      val value = stream.readInt()
+      new MaxValue(value, isValueSet)
+    }
+  }
+
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index 01e4a0106b07c..3252cdafa1be1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -49,16 +49,6 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       (2: Integer) -> null,
       (3: Integer) -> null
     ).toDF("key", "value").repartition(2).createOrReplaceTempView("t")
-    Seq(
-      (0: Integer) -> "val_0",
-      (1: Integer) -> "val_1",
-      (2: Integer) -> "val_2",
-      (3: Integer) -> "val_3",
-      (4: Integer) -> "val_4",
-      (5: Integer) -> "val_5",
-      (6: Integer) -> null,
-      (7: Integer) -> null
-    ).toDF("key", "value").repartition(2).createOrReplaceTempView("t2")
   }
 
   protected override def afterAll(): Unit = {
@@ -133,26 +123,6 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     }
   }
 
-  test("SPARK-27207: customized Hive UDAF with two aggregation buffers for Sort" +
-    " Based Aggregation") {
-    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "2") {
-      val df = sql("SELECT key % 2, mock2(value) FROM t2 GROUP BY key % 2")
-
-      val aggs = df.queryExecution.executedPlan.collect {
-        case agg: ObjectHashAggregateExec => agg
-      }
-
-      // There should be two aggregate operators, one for partial aggregation, and the other for
-      // global aggregation.
-      assert(aggs.length == 2)
-
-      checkAnswer(df, Seq(
-        Row(0, Row(3, 1)),
-        Row(1, Row(3, 1))
-      ))
-    }
-  }
-
   test("call JAVA UDAF") {
     withTempView("temp") {
       withUserDefinedFunction("myDoubleAvg" -> false) {

From 5bd474cd55f6f77fbd003c532bd77343e42750bf Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Mon, 6 May 2019 16:14:45 -0500
Subject: [PATCH 11/13] [SPARK-27207] : Fixing Scalastyle Tests

---
 .../apache/spark/sql/TypedImperativeAggregateSuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index f50ef8d3be806..cb732f71c7baf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -320,9 +320,9 @@ object TypedImperativeAggregateSuite {
   }
 
   /**
-  * Calculate the max value with object aggregation buffer. This stores class MaxValue
-  * in aggregation buffer.
-  */
+   * Calculate the max value with object aggregation buffer. This stores class MaxValue
+   * in aggregation buffer.
+   */
   private case class TypedMax2(
     child: Expression,
     nullable: Boolean = false,

From 006616ed3c82ef2f7cac1631d9306e34fd069cff Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Tue, 7 May 2019 11:17:07 -0500
Subject: [PATCH 12/13] [SPARK-27207] : Simplifying unit test and indentation

---
 .../sql/TypedImperativeAggregateSuite.scala | 69 +++++--------------
 1 file changed, 19 insertions(+), 50 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index cb732f71c7baf..3d6ced7cd14f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -215,13 +215,7 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") {
       val df = data.toDF("value", "key").coalesce(2)
       val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value"))
-      val expected = data.groupBy(_._2).toSeq.map { group =>
-        val (key, values) = group
-        val valueMax = values.map(_._1).max
-        val countValue = values.size
-        Row(key, valueMax, countValue, valueMax)
-      }
-      checkAnswer(query, expected)
+      query.show(10, false)
     }
   }
 
@@ -320,54 +314,37 @@ object TypedImperativeAggregateSuite {
   }
 
   /**
-   * Calculate the max value with object aggregation buffer. This stores class MaxValue
-   * in aggregation buffer.
+   * SPARK-27207: Dummy UDAF to test whether aggregate buffers are reinitialized when
+   * ObjectHashAggregate falls back to sort-based aggregation.
    */
   private case class TypedMax2(
-    child: Expression,
-    nullable: Boolean = false,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0)
+      child: Expression,
+      nullable: Boolean = false,
+      mutableAggBufferOffset: Int = 0,
+      inputAggBufferOffset: Int = 0)
     extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
 
 
-    var maxValueBuffer: MaxValue = null
+    var initialized = false
     override def createAggregationBuffer(): MaxValue = {
-      // Returns Int.MinValue if all inputs are null
-      maxValueBuffer = new MaxValue(Int.MinValue)
-      maxValueBuffer
+      initialized = true
+      null
     }
 
     override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
-      child.eval(input) match {
-        case inputValue: Int =>
-          if (inputValue > buffer.value) {
-            buffer.value = inputValue
-            buffer.isValueSet = true
-          }
-        case null => // skip
-      }
-      buffer
+      // The assertion below will fail if initialize() is not called
+      assert(initialized)
+      null
     }
 
     override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
-      // The check below will throw a NullPointerException if initialize() is not called
-      if (maxValueBuffer.isValueSet) {
-        // do nothing
-      }
-      if (inputMax.value > bufferMax.value) {
-        bufferMax.value = inputMax.value
-        bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
-      }
-      bufferMax
+      // The assertion below will fail if initialize() is not called
+      assert(initialized)
+      null
     }
 
     override def eval(bufferMax: MaxValue): Any = {
-      if (nullable && bufferMax.isValueSet == false) {
-        null
-      } else {
-        bufferMax.value
-      }
+      null
     }
 
     override lazy val deterministic: Boolean = true
@@ -385,19 +362,11 @@ object TypedImperativeAggregateSuite {
       copy(inputAggBufferOffset = newOffset)
 
     override def serialize(buffer: MaxValue): Array[Byte] = {
-      val out = new ByteArrayOutputStream()
-      val stream = new DataOutputStream(out)
-      stream.writeBoolean(buffer.isValueSet)
-      stream.writeInt(buffer.value)
-      out.toByteArray
+      null
     }
 
     override def deserialize(storageFormat: Array[Byte]): MaxValue = {
-      val in = new ByteArrayInputStream(storageFormat)
-      val stream = new DataInputStream(in)
-      val isValueSet = stream.readBoolean()
-      val value = stream.readInt()
-      new MaxValue(value, isValueSet)
+      null
     }
   }
 

From c8959f4624760d5bca6ea449f0eeccad2e168ea5 Mon Sep 17 00:00:00 2001
From: pgandhi
Date: Wed, 8 May 2019 09:35:04 -0500
Subject: [PATCH 13/13] [SPARK-27207] : Reverting changes and updating doc

---
 .../expressions/aggregate/interfaces.scala    |  2 +-
 .../aggregate/ObjectAggregationIterator.scala | 34 +++++----
 .../sql/TypedImperativeAggregateSuite.scala   | 71 -------------------
 3 files changed, 17 insertions(+), 90 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 56c2ee6b53fe5..6fc20530f084e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -158,7 +158,7 @@ case class AggregateExpression(
  * ([[aggBufferAttributes]]) of an aggregation buffer which is used to hold partial aggregate
  * results. At runtime, multiple aggregate functions are evaluated by the same operator using a
  * combined aggregation buffer which concatenates the aggregation buffers of the individual
- * aggregate functions.
+ * aggregate functions. Please note that aggregate functions should be stateless.
  *
  * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
  * aggregate functions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 90689f0e8fa5d..43514f5271ac8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -58,8 +58,8 @@ class ObjectAggregationIterator(
 
   private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
 
-  val (sortBasedMergeAggExpressions, sortBasedMergeAggFunctions): (
-    Seq[AggregateExpression], Array[AggregateFunction]) = {
+  // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
+  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
     val newExpressions = aggregateExpressions.map {
       case agg @ AggregateExpression(_, Partial, _, _) =>
         agg.copy(mode = PartialMerge)
@@ -67,14 +67,9 @@ class ObjectAggregationIterator(
         agg.copy(mode = Final)
       case other => other
     }
-    (newExpressions, initializeAggregateFunctions(newExpressions, 0))
-  }
-
-  // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
-  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
-    val newInputAttributes = sortBasedMergeAggFunctions.flatMap(_.inputAggBufferAttributes)
-    generateProcessRow(
-      sortBasedMergeAggExpressions, sortBasedMergeAggFunctions, newInputAttributes)
+    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
   }
 
   /**
@@ -98,7 +93,7 @@ class ObjectAggregationIterator(
    */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     if (groupingExpressions.isEmpty) {
-      val defaultAggregationBuffer = createNewAggregationBuffer(aggregateFunctions)
+      val defaultAggregationBuffer = createNewAggregationBuffer()
       generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
     } else {
       throw new IllegalStateException(
@@ -111,15 +106,18 @@ class ObjectAggregationIterator(
   //
   // - when creating aggregation buffer for a new group in the hash map, and
   // - when creating the re-used buffer for sort-based aggregation
-  private def createNewAggregationBuffer(
-      functions: Array[AggregateFunction]): SpecificInternalRow = {
-    val bufferFieldTypes = functions.flatMap(_.aggBufferAttributes.map(_.dataType))
+  private def createNewAggregationBuffer(): SpecificInternalRow = {
+    val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
     val buffer = new SpecificInternalRow(bufferFieldTypes)
+    initAggregationBuffer(buffer)
+    buffer
+  }
+
+  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
     // Initializes declarative aggregates' buffer values
     expressionAggInitialProjection.target(buffer)(EmptyRow)
     // Initializes imperative aggregates' buffer values
-    functions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
-    buffer
+    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
   }
 
   private def getAggregationBufferByKey(
@@ -127,7 +125,7 @@
     var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
 
     if (aggBuffer == null) {
-      aggBuffer = createNewAggregationBuffer(aggregateFunctions)
+      aggBuffer = createNewAggregationBuffer()
      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
     }
 
@@ -185,7 +183,7 @@ class ObjectAggregationIterator(
       StructType.fromAttributes(groupingAttributes),
       processRow,
       mergeAggregationBuffers,
-      createNewAggregationBuffer(sortBasedMergeAggFunctions))
+      createNewAggregationBuffer())
 
     while (inputRows.hasNext) {
       // NOTE: The input row is always UnsafeRow
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 3d6ced7cd14f9..c5fb17345222a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
-import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax2
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
@@ -211,14 +210,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     checkAnswer(query, expected)
   }
 
-  test("SPARK-27207: Ensure aggregate buffers are initialized again for SortBasedAggregate") {
-    withSQLConf("spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "5") {
-      val df = data.toDF("value", "key").coalesce(2)
-      val query = df.groupBy($"key").agg(typedMax2($"value"), count($"value"), typedMax2($"value"))
-      query.show(10, false)
-    }
-  }
-
   private def typedMax(column: Column): Column = {
     val max = TypedMax(column.expr, nullable = false)
     Column(max.toAggregateExpression())
@@ -228,11 +219,6 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
     val max = TypedMax(column.expr, nullable = true)
     Column(max.toAggregateExpression())
   }
-
-  private def typedMax2(column: Column): Column = {
-    val max = TypedMax2(column.expr, nullable = false)
-    Column(max.toAggregateExpression())
-  }
 }
 
 object TypedImperativeAggregateSuite {
@@ -313,62 +299,5 @@ object TypedImperativeAggregateSuite {
     }
   }
 
-  /**
-   * SPARK-27207: Dummy UDAF to test whether aggregate buffers are reinitialized when
-   * ObjectHashAggregate falls back to sort-based aggregation.
-   */
-  private case class TypedMax2(
-      child: Expression,
-      nullable: Boolean = false,
-      mutableAggBufferOffset: Int = 0,
-      inputAggBufferOffset: Int = 0)
-    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
-
-
-    var initialized = false
-    override def createAggregationBuffer(): MaxValue = {
-      initialized = true
-      null
-    }
-
-    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
-      // The assertion below will fail if initialize() is not called
-      assert(initialized)
-      null
-    }
-
-    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
-      // The assertion below will fail if initialize() is not called
-      assert(initialized)
-      null
-    }
-
-    override def eval(bufferMax: MaxValue): Any = {
-      null
-    }
-
-    override lazy val deterministic: Boolean = true
-
-    override def children: Seq[Expression] = Seq(child)
-
-    override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
-
-    override def dataType: DataType = IntegerType
-
-    override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
-      copy(mutableAggBufferOffset = newOffset)
-
-    override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxValue] =
-      copy(inputAggBufferOffset = newOffset)
-
-    override def serialize(buffer: MaxValue): Array[Byte] = {
-      null
-    }
-
-    override def deserialize(storageFormat: Array[Byte]): MaxValue = {
-      null
-    }
-  }
-
   private class MaxValue(var value: Int, var isValueSet: Boolean = false)
 }
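
A closing note on PATCH 13/13: the only change that survives is the one-sentence doc update in
interfaces.scala, which makes the statelessness requirement explicit rather than changing
ObjectAggregationIterator. To make that requirement concrete, below is a minimal, hypothetical
sketch of the anti-pattern the new sentence warns against. It is not part of the patch series;
the names StatefulMax and MaxBuf are invented for illustration, while the overridden methods
follow the same TypedImperativeAggregate contract as TypedMax2 above. The single flaw is the
per-instance `cached` field; everything else is a conventional typed max over integers.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
    import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
    import org.apache.spark.sql.types.{AbstractDataType, DataType, IntegerType}

    // Mutable buffer object, analogous to MaxValue in the suite above.
    class MaxBuf(var value: Int, var isSet: Boolean = false)

    case class StatefulMax(
        child: Expression,
        nullable: Boolean = false,
        mutableAggBufferOffset: Int = 0,
        inputAggBufferOffset: Int = 0)
      extends TypedImperativeAggregate[MaxBuf] with ImplicitCastInputTypes {

      // Anti-pattern: state cached on the function instance instead of kept solely in the
      // aggregation buffer. Once spark.sql.objectHashAggregate.sortBased.fallbackThreshold is
      // crossed, buffers may be merged through a different set of function instances (the
      // mergeAggregationBuffers path above) on which createAggregationBuffer() may never have
      // run, leaving this field null.
      private var cached: MaxBuf = null

      override def createAggregationBuffer(): MaxBuf = {
        cached = new MaxBuf(Int.MinValue)
        cached
      }

      override def update(buffer: MaxBuf, input: InternalRow): MaxBuf = {
        child.eval(input) match {
          case v: Int =>
            if (v > buffer.value) buffer.value = v
            buffer.isSet = true
          case null => // skip nulls
        }
        buffer
      }

      override def merge(buffer: MaxBuf, input: MaxBuf): MaxBuf = {
        // The flaw in action: dereferencing `cached` throws a NullPointerException when merge()
        // runs on an instance that never saw createAggregationBuffer(). Correct code would
        // consult only the two buffer arguments.
        if (cached.isSet || input.isSet) {
          if (input.value > buffer.value) buffer.value = input.value
          buffer.isSet = buffer.isSet || input.isSet
        }
        buffer
      }

      override def eval(buffer: MaxBuf): Any =
        if (nullable && !buffer.isSet) null else buffer.value

      override def serialize(buffer: MaxBuf): Array[Byte] = {
        val out = new ByteArrayOutputStream()
        val stream = new DataOutputStream(out)
        stream.writeBoolean(buffer.isSet)
        stream.writeInt(buffer.value)
        out.toByteArray
      }

      override def deserialize(storageFormat: Array[Byte]): MaxBuf = {
        val stream = new DataInputStream(new ByteArrayInputStream(storageFormat))
        val isSet = stream.readBoolean()
        new MaxBuf(stream.readInt(), isSet)
      }

      override def children: Seq[Expression] = Seq(child)
      override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
      override def dataType: DataType = IntegerType
      override def withNewMutableAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxBuf] =
        copy(mutableAggBufferOffset = newOffset)
      override def withNewInputAggBufferOffset(newOffset: Int): TypedImperativeAggregate[MaxBuf] =
        copy(inputAggBufferOffset = newOffset)
    }

The fix is mechanical: drop `cached` and read or write only the buffer passed in, exactly as the
original TypedMax in the suite does. Any state a UDAF needs between calls has to live in the
aggregation buffer, because the engine is free to create, spill, and merge buffers through
different function instances.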