@@ -186,8 +186,7 @@ class ObjectAggregationIterator(
186186 processRow,
187187 mergeAggregationBuffers,
188188 createNewAggregationBuffer(aggregateFunctions),
189- createNewAggregationBuffer(sortBasedMergeAggFunctions),
190- aggregateFunctions)
189+ sortBasedMergeAggFunctions)
191190
192191 while (inputRows.hasNext) {
193192 // NOTE: The input row is always UnsafeRow
@@ -217,10 +216,8 @@ class ObjectAggregationIterator(
217216 * aggregation buffers
218217 * @param makeEmptyAggregationBufferForSortBasedUpdateAggFunctions Creates an empty aggregation
219218 * buffer for update operation
220- * @param makeEmptyAggregationBufferForSortBasedMergeAggFunctions Creates an empty aggregation
221- * buffer for merge operation
222- * @param sortBasedUpdateAggFunctions aggregate functions needed to serialize the
223- * aggregation buffer
219+ * @param sortBasedMergeAggFunctions aggregate functions needed to serialize the
220+ * aggregation buffer
224221 *
225222 * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec ]].
226223 */
@@ -231,8 +228,7 @@ class SortBasedAggregator(
231228 processRow : (InternalRow , InternalRow ) => Unit ,
232229 mergeAggregationBuffers : (InternalRow , InternalRow ) => Unit ,
233230 makeEmptyAggregationBufferForSortBasedUpdateAggFunctions : => InternalRow ,
234- makeEmptyAggregationBufferForSortBasedMergeAggFunctions : => InternalRow ,
235- sortBasedUpdateAggFunctions : Array [AggregateFunction ]) {
231+ sortBasedMergeAggFunctions : Array [AggregateFunction ]) {
236232
237233 // external sorter to sort the input (grouping key + input row) with grouping key.
238234 private val inputSorter = createExternalSorterForInput()
@@ -241,10 +237,6 @@ class SortBasedAggregator(
241237 def addInput (groupingKey : UnsafeRow , inputRow : UnsafeRow ): Unit = {
242238 inputSorter.insertKV(groupingKey, inputRow)
243239 }
244- private def serializeBuffer (buffer : InternalRow ): Unit = {
245- sortBasedUpdateAggFunctions.collect { case f : TypedImperativeAggregate [_] => f }.foreach(
246- _.serializeAggregateBufferInPlace(buffer))
247- }
248240
249241 /**
250242 * Returns a destructive iterator of AggregationBufferEntry.
@@ -255,18 +247,16 @@ class SortBasedAggregator(
255247 val inputIterator = inputSorter.sortedIterator()
256248 var hasNextInput : Boolean = inputIterator.next()
257249 var hasNextAggBuffer : Boolean = initialAggBufferIterator.next()
258- private var updateResult : AggregationBufferEntry = _
259- private var finalResult : AggregationBufferEntry = _
250+ private var result : AggregationBufferEntry = _
260251 private var groupingKey : UnsafeRow = _
261252
262253 override def hasNext (): Boolean = {
263- updateResult != null || finalResult != null || findNextSortedGroup()
254+ result != null || findNextSortedGroup()
264255 }
265256
266257 override def next (): AggregationBufferEntry = {
267- val returnResult = finalResult
268- updateResult = null
269- finalResult = null
258+ val returnResult = result
259+ result = null
270260 returnResult
271261 }
272262
@@ -275,31 +265,25 @@ class SortBasedAggregator(
275265 if (hasNextInput || hasNextAggBuffer) {
276266 // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator
277267 groupingKey = findGroupingKey()
278- updateResult = new AggregationBufferEntry (
268+ result = new AggregationBufferEntry (
279269 groupingKey, makeEmptyAggregationBufferForSortBasedUpdateAggFunctions)
280- finalResult = new AggregationBufferEntry (
281- groupingKey, makeEmptyAggregationBufferForSortBasedMergeAggFunctions)
282270
283271 // Firstly, update the aggregation buffer with input rows.
284272 while (hasNextInput &&
285273 groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0 ) {
286- processRow(updateResult .aggregationBuffer, inputIterator.getValue)
274+ processRow(result .aggregationBuffer, inputIterator.getValue)
287275 hasNextInput = inputIterator.next()
288276 }
289277
290- // This step ensures that the contents of the updateResult aggregation buffer are
291- // merged with the finalResult aggregation buffer to maintain consistency
292- if (hasNextAggBuffer) {
293- serializeBuffer(updateResult.aggregationBuffer)
294- mergeAggregationBuffers(finalResult.aggregationBuffer, updateResult.aggregationBuffer)
295- }
296278 // Secondly, merge the aggregation buffer with existing aggregation buffers.
297279 // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should
298280 // be called after calling processRow.
281+ sortBasedMergeAggFunctions.collect { case f : ImperativeAggregate => f }.foreach(
282+ _.initialize(result.aggregationBuffer))
299283 while (hasNextAggBuffer &&
300284 groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0 ) {
301285 mergeAggregationBuffers(
302- finalResult .aggregationBuffer, initialAggBufferIterator.getValue)
286+ result .aggregationBuffer, initialAggBufferIterator.getValue)
303287 hasNextAggBuffer = initialAggBufferIterator.next()
304288 }
305289
0 commit comments