diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 1f4ff9c4b184e..6084028a5c810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -45,7 +45,7 @@ abstract class Collect extends ImperativeAggregate { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def supportsPartial: Boolean = false + override def forceSortAggregate: Boolean = true override def aggBufferAttributes: Seq[AttributeReference] = Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 504cea52797de..ee8217f939e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -174,10 +174,10 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def inputAggBufferAttributes: Seq[AttributeReference] /** - * Indicates if this function supports partial aggregation. - * Currently Hive UDAF is the only one that doesn't support partial aggregation. + * Indicates if this function needs to aggregate values group-by-group in a single step. + * If true, we must always use a `SortAggregateExec` operator without partial aggregates. */ - def supportsPartial: Boolean = true + def forceSortAggregate: Boolean = false /** * Result of the aggregate function when the input is empty. This is currently only used for the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index c0b453dccf5e9..6d0383826b38c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -431,7 +431,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def forceSortAggregate: Boolean = true override lazy val mergeExpressions = throw new UnsupportedOperationException("Window Functions do not support merging.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8e2f2ed4f86b9..61e9d439ec34d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -261,7 +261,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (aggregateExpressions.map(_.aggregateFunction).exists(_.forceSortAggregate)) { if (functionsWithDistinct.nonEmpty) { sys.error("Distinct columns cannot exist in Aggregate operator containing " + "aggregate functions which don't support partial aggregation.") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index c53675694f620..59b1ff5dcff61 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -336,7 +336,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override def supportsPartial: Boolean = false + override def forceSortAggregate: Boolean = true override lazy val dataType: DataType = inspectorToDataType(returnInspector)