Skip to content

Commit 6a5ed71

Browse files
author
pgandhi
committed
[SPARK-27207] : Changing design to use one buffer but initializing for different aggregate functions
1 parent 088cbc6 commit 6a5ed71

File tree

3 files changed

+16
-33
lines changed

3 files changed

+16
-33
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,7 @@ class ObjectAggregationIterator(
186186
processRow,
187187
mergeAggregationBuffers,
188188
createNewAggregationBuffer(aggregateFunctions),
189-
createNewAggregationBuffer(sortBasedMergeAggFunctions),
190-
aggregateFunctions)
189+
sortBasedMergeAggFunctions)
191190

192191
while (inputRows.hasNext) {
193192
// NOTE: The input row is always UnsafeRow
@@ -217,10 +216,8 @@ class ObjectAggregationIterator(
217216
* aggregation buffers
218217
* @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation
219218
* buffer for update operation
220-
* @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation
221-
* buffer for merge operation
222-
* @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the
223-
* aggregation buffer
219+
* @param sortBasedMergeAggFunctions aggregate functions needed to serialize the
220+
* aggregation buffer
224221
*
225222
* @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]].
226223
*/
@@ -231,8 +228,7 @@ class SortBasedAggregator(
231228
processRow: (InternalRow, InternalRow) => Unit,
232229
mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
233230
makeEmptyAggregationBufferForSortBasedUpdateAggFunctions: => InternalRow,
234-
makeEmptyAggregationBufferForSortBasedMergeAggFunctions: => InternalRow,
235-
sortBasedUpdateAggFunctions: Array[AggregateFunction]) {
231+
sortBasedMergeAggFunctions: Array[AggregateFunction]) {
236232

237233
// external sorter to sort the input (grouping key + input row) with grouping key.
238234
private val inputSorter = createExternalSorterForInput()
@@ -241,10 +237,6 @@ class SortBasedAggregator(
241237
def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
242238
inputSorter.insertKV(groupingKey, inputRow)
243239
}
244-
private def serializeBuffer(buffer: InternalRow): Unit = {
245-
sortBasedUpdateAggFunctions.collect { case f: TypedImperativeAggregate[_] => f }.foreach(
246-
_.serializeAggregateBufferInPlace(buffer))
247-
}
248240

249241
/**
250242
* Returns a destructive iterator of AggregationBufferEntry.
@@ -255,18 +247,16 @@ class SortBasedAggregator(
255247
val inputIterator = inputSorter.sortedIterator()
256248
var hasNextInput: Boolean = inputIterator.next()
257249
var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
258-
private var updateResult: AggregationBufferEntry = _
259-
private var finalResult: AggregationBufferEntry = _
250+
private var result: AggregationBufferEntry = _
260251
private var groupingKey: UnsafeRow = _
261252

262253
override def hasNext(): Boolean = {
263-
updateResult != null || finalResult != null || findNextSortedGroup()
254+
result != null || findNextSortedGroup()
264255
}
265256

266257
override def next(): AggregationBufferEntry = {
267-
val returnResult = finalResult
268-
updateResult = null
269-
finalResult = null
258+
val returnResult = result
259+
result = null
270260
returnResult
271261
}
272262

@@ -275,31 +265,25 @@ class SortBasedAggregator(
275265
if (hasNextInput || hasNextAggBuffer) {
276266
// Find smaller key of the inputIterator and initialAggBufferIterator
277267
groupingKey = findGroupingKey()
278-
updateResult = new AggregationBufferEntry(
268+
result = new AggregationBufferEntry(
279269
groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions)
280-
finalResult = new AggregationBufferEntry(
281-
groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions)
282270

283271
// Firstly, update the aggregation buffer with input rows.
284272
while (hasNextInput &&
285273
groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
286-
processRow(updateResult.aggregationBuffer, inputIterator.getValue)
274+
processRow(result.aggregationBuffer, inputIterator.getValue)
287275
hasNextInput = inputIterator.next()
288276
}
289277

290-
// This step ensures that the contents of the updateResult aggregation buffer are
291-
// merged with the finalResult aggregation buffer to maintain consistency
292-
if (hasNextAggBuffer) {
293-
serializeBuffer(updateResult.aggregationBuffer)
294-
mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer)
295-
}
296278
// Secondly, merge the aggregation buffer with existing aggregation buffers.
297279
// NOTE: the ordering of these two while-blocks matters, mergeAggregationBuffer() should
298280
// be called after calling processRow.
281+
sortBasedMergeAggFunctions.collect { case f: ImperativeAggregate => f }.foreach(
282+
_.initialize(result.aggregationBuffer))
299283
while (hasNextAggBuffer &&
300284
groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
301285
mergeAggregationBuffers(
302-
finalResult.aggregationBuffer, initialAggBufferIterator.getValue)
286+
result.aggregationBuffer, initialAggBufferIterator.getValue)
303287
hasNextAggBuffer = initialAggBufferIterator.next()
304288
}
305289

sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte
8080
updateInputRow,
8181
mergeAggBuffer,
8282
createNewAggregationBuffer,
83-
createNewAggregationBuffer,
84-
sortBasedUpdateAggFunctions = new Array[AggregateFunction](5))
83+
sortBasedMergeAggFunctions = new Array[AggregateFunction](5))
8584

8685
(5000 to 100000).foreach { _ =>
8786
randomKV(inputRow, group)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
135135
assert(aggs.length == 2)
136136

137137
checkAnswer(df, Seq(
138-
Row(0, Row(3, 1)),
139-
Row(1, Row(3, 1))
138+
Row(0, Row(2, 1)),
139+
Row(1, Row(2, 0))
140140
))
141141
}
142142
}

0 commit comments

Comments (0)