@@ -259,17 +259,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

val aggregateOperator =
if (functionsWithDistinct.isEmpty) {
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
if (functionsWithDistinct.nonEmpty) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.AggUtils.planAggregateWithoutPartial(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.AggUtils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
resultExpressions,
planLater(child))
} else {
if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
sys.error("Distinct columns cannot exist in Aggregate operator containing " +
"aggregate functions which don't support partial aggregation.")
}
aggregate.AggUtils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
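The hunk above restores the earlier three-way dispatch in SparkStrategies: aggregate functions that cannot be partially aggregated go through planAggregateWithoutPartial, plain aggregates through planAggregateWithoutDistinct, and queries with DISTINCT functions through planAggregateWithOneDistinct. The minimal sketch below mirrors only that dispatch; AggFunc and chooseStrategy are illustrative stand-ins, not Spark's internal AggregateExpression or planner API.

// Stand-alone sketch of the restored dispatch; AggFunc is a hypothetical stand-in type.
object AggregatePlanningSketch {
  final case class AggFunc(name: String, supportsPartial: Boolean, isDistinct: Boolean)

  def chooseStrategy(aggregates: Seq[AggFunc]): String = {
    val functionsWithDistinct = aggregates.filter(_.isDistinct)
    if (aggregates.exists(!_.supportsPartial)) {
      if (functionsWithDistinct.nonEmpty) {
        // Same restriction as the restored branch: DISTINCT cannot be combined
        // with functions that do not support partial aggregation.
        sys.error("Distinct columns cannot exist in Aggregate operator containing " +
          "aggregate functions which don't support partial aggregation.")
      } else {
        "planAggregateWithoutPartial"   // single, non-partial aggregation
      }
    } else if (functionsWithDistinct.isEmpty) {
      "planAggregateWithoutDistinct"    // partial + final aggregation pair
    } else {
      "planAggregateWithOneDistinct"    // multi-stage plan for one DISTINCT group
    }
  }

  def main(args: Array[String]): Unit = {
    println(chooseStrategy(Seq(AggFunc("sum", supportsPartial = true, isDistinct = false))))
    println(chooseStrategy(Seq(AggFunc("sum", supportsPartial = true, isDistinct = true))))
    println(chooseStrategy(Seq(AggFunc("someUdaf", supportsPartial = false, isDistinct = false))))
  }
}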

Large diffs are not rendered by default.

This file was deleted.

@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
@@ -41,7 +42,11 @@ case class HashAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends AggregateExec with CodegenSupport {
extends UnaryExecNode with CodegenSupport {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))

@@ -55,6 +60,21 @@
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

// This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash
// map and/or the sort-based aggregation once it has processed a given number of input rows.
private val testFallbackStartsAt: Option[(Int, Int)] = {
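The requiredChildDistribution override restored in HashAggregateExec maps the planner-supplied requiredChildDistributionExpressions onto a Distribution: an empty list means a global aggregation over all tuples, a non-empty list means rows must be clustered by the grouping expressions, and None leaves the requirement unspecified (as for a map-side partial aggregate). A plain-Scala sketch of that mapping follows; the sealed trait here is a simplified stand-in, not Spark's Distribution hierarchy.

// Stand-in types; not the real org.apache.spark.sql.catalyst.plans.physical classes.
sealed trait Distribution
case object AllTuples extends Distribution
final case class ClusteredDistribution(exprs: Seq[String]) extends Distribution
case object UnspecifiedDistribution extends Distribution

object RequiredDistributionSketch {
  // Mirrors the restored requiredChildDistribution:
  //  - Some(Nil): global aggregation, all rows must end up in one partition
  //  - Some(exprs): rows must be clustered by the grouping expressions
  //  - None: no requirement is imposed on the child
  def requiredChildDistribution(
      requiredChildDistributionExpressions: Option[Seq[String]]): List[Distribution] = {
    requiredChildDistributionExpressions match {
      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
      case None => UnspecifiedDistribution :: Nil
    }
  }

  def main(args: Array[String]): Unit = {
    println(requiredChildDistribution(Some(Nil)))          // AllTuples
    println(requiredChildDistribution(Some(Seq("key"))))   // ClusteredDistribution
    println(requiredChildDistribution(None))               // UnspecifiedDistribution
  }
}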
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.Utils

@@ -37,11 +38,30 @@ case class SortAggregateExec(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends AggregateExec {
extends UnaryExecNode {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
}

override def producedAttributes: AttributeSet =
AttributeSet(aggregateAttributes) ++
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
AttributeSet(aggregateBufferAttributes)

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
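SortAggregateExec keeps the same distribution requirement and additionally requires its input sorted ascending on every grouping expression, so equal keys arrive adjacent and the aggregate can be computed in one streaming pass. A small stand-in sketch of that ordering requirement, using strings in place of Spark's Expression and SortOrder types:

// Illustrative stand-in; the real operator builds SortOrder(expr, Ascending) values.
object RequiredOrderingSketch {
  final case class SortOrder(expr: String, direction: String)

  // One Seq[SortOrder] per child; SortAggregateExec has a single child, and every
  // grouping expression must arrive sorted ascending.
  def requiredChildOrdering(groupingExpressions: Seq[String]): Seq[Seq[SortOrder]] =
    groupingExpressions.map(SortOrder(_, "Ascending")) :: Nil

  def main(args: Array[String]): Unit =
    println(requiredChildOrdering(Seq("key", "value")))
}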
@@ -21,8 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.aggregate.PartialAggregate
import org.apache.spark.sql.internal.SQLConf

/**
@@ -153,31 +151,18 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
assert(requiredChildDistributions.length == operator.children.length)
assert(requiredChildOrderings.length == operator.children.length)
var children: Seq[SparkPlan] = operator.children
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)

def createShuffleExchange(dist: Distribution, child: SparkPlan) =
ShuffleExchange(createPartitioning(dist, defaultNumPreShufflePartitions), child)

var (parent, children) = operator match {
case PartialAggregate(childDist) if !operator.outputPartitioning.satisfies(childDist) =>
// If an aggregation needs a shuffle and support partial aggregations, a map-side partial
// aggregation and a shuffle are added as children.
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
(mergeAgg, createShuffleExchange(
requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
case _ =>
// Ensure that the operator's children satisfy their output distribution requirements:
val childrenWithDist = operator.children.zip(requiredChildDistributions)
val newChildren = childrenWithDist.map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
createShuffleExchange(distribution, child)
}
(operator, newChildren)
// Ensure that the operator's children satisfy their output distribution requirements:
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
@@ -270,7 +255,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}

parent.withNewChildren(children)
operator.withNewChildren(children)
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
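With the PartialAggregate special case removed, EnsureRequirements is back to one uniform rule: a child whose output partitioning already satisfies the required distribution is left alone, a BroadcastDistribution requirement inserts a BroadcastExchangeExec, and anything else gets a ShuffleExchange. The sketch below imitates that loop with simplified stand-in types; Plan, Dist, and satisfies are illustrative, not Spark's physical-plan classes.

// Stand-in model of "wrap unsatisfied children in an exchange".
sealed trait Dist
case object Unspecified extends Dist
final case class Clustered(keys: Seq[String]) extends Dist
final case class Broadcast(mode: String) extends Dist

final case class Plan(name: String, partitionedBy: Seq[String] = Nil, children: Seq[Plan] = Nil) {
  // Very rough stand-in for outputPartitioning.satisfies(distribution).
  def satisfies(d: Dist): Boolean = d match {
    case Unspecified     => true
    case Clustered(keys) => partitionedBy.nonEmpty && keys.forall(partitionedBy.contains)
    case Broadcast(_)    => false
  }
}

object EnsureRequirementsSketch {
  def ensureDistribution(operator: Plan, required: Seq[Dist]): Plan = {
    val newChildren = operator.children.zip(required).map {
      case (child, d) if child.satisfies(d) => child
      case (child, Broadcast(mode))         => Plan(s"BroadcastExchange($mode)", children = Seq(child))
      case (child, d)                       => Plan(s"ShuffleExchange($d)", children = Seq(child))
    }
    operator.copy(children = newChildren)
  }

  def main(args: Array[String]): Unit = {
    val scan = Plan("Scan")
    val agg = Plan("HashAggregate", children = Seq(scan))
    // The child is not clustered by "key", so a shuffle is inserted below the aggregate.
    println(ensureDistribution(agg, Seq(Clustered(Seq("key")))))
  }
}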
@@ -1248,17 +1248,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}

/**
* Verifies that there is a single Aggregation for `df`
* Verifies that there is no Exchange between the Aggregations for `df`
*/
private def verifyNonExchangingSingleAgg(df: DataFrame) = {
private def verifyNonExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
case agg: HashAggregateExec =>
atFirstAgg = !atFirstAgg
case _ =>
if (atFirstAgg) {
fail("Should not have back to back Aggregates")
fail("Should not have operators between the two aggregations")
}
atFirstAgg = true
case _ =>
}
}

@@ -1292,10 +1292,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Group by the column we are distributed by. This should generate a plan with no exchange
// between the aggregates
val df3 = testData.repartition($"key").groupBy("key").count()
verifyNonExchangingSingleAgg(df3)
verifyNonExchangingSingleAgg(testData.repartition($"key", $"value")
verifyNonExchangingAgg(df3)
verifyNonExchangingAgg(testData.repartition($"key", $"value")
.groupBy("key", "value").count())
verifyNonExchangingSingleAgg(testData.repartition($"key").groupBy("key", "value").count())

// Grouping by just the first distributeBy expr, need to exchange.
verifyExchangingAgg(testData.repartition($"key", $"value")
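The renamed verifyNonExchangingAgg helper asserts that no Exchange sits between the two aggregation stages when the data has already been repartitioned by the grouping key. A hedged end-to-end sketch of that scenario follows; the local SparkSession, object name, and sample data are illustrative, not part of the test suite.

import org.apache.spark.sql.SparkSession

object NoExchangeAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("no-exchange-agg").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("key", "value")

    // Data is already clustered by "key", so EnsureRequirements should not need to
    // add a ShuffleExchange between the partial and final HashAggregateExec nodes.
    val counts = df.repartition($"key").groupBy("key").count()
    counts.explain()   // inspect the physical plan for Exchange operators
    counts.show()

    spark.stop()
  }
}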
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -38,84 +37,36 @@ class PlannerSuite extends SharedSQLContext {

setupTestData()

private def testPartialAggregationPlan(query: LogicalPlan): Seq[SparkPlan] = {
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
val planner = spark.sessionState.planner
import planner._
val ensureRequirements = EnsureRequirements(spark.sessionState.conf)
val planned = Aggregation(query).headOption.map(ensureRequirements(_))
.getOrElse(fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
planned.collect { case n if n.nodeName contains "Aggregate" => n }
val plannedOption = Aggregation(query).headOption
val planned =
plannedOption.getOrElse(
fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }

// For the new aggregation code path, there will be four aggregate operator for
// distinct aggregations.
assert(
aggregations.size == 2 || aggregations.size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("count is partially aggregated") {
val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
assert(testPartialAggregationPlan(query).size == 2,
s"The plan of query $query does not have partial aggregations.")
testPartialAggregationPlan(query)
}

test("count distinct is partially aggregated") {
val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
testPartialAggregationPlan(query)
// For the new aggregation code path, there will be four aggregate operator for distinct
// aggregations.
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("mixed aggregates are partially aggregated") {
val query =
testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
// For the new aggregation code path, there will be four aggregate operator for distinct
// aggregations.
assert(testPartialAggregationPlan(query).size == 4,
s"The plan of query $query does not have partial aggregations.")
}

test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
withTempView("testSortBasedPartialAggregation") {
val schema = StructType(
StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
spark.createDataFrame(rowRDD, schema)
.createOrReplaceTempView("testSortBasedPartialAggregation")

// This test assumes a query below uses sort-based aggregations
val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
.queryExecution.executedPlan
// This line extracts both SortAggregate and Sort operators
val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
assert(extractedOps.size == 4 && aggOps.size == 2,
s"The plan $planned does not have correct sort-based partial aggregate pairs.")
}
}

test("non-partial aggregation for aggregates") {
withTempView("testNonPartialAggregation") {
val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)
val row = Row.fromSeq(Seq.fill(1)(null))
val rowRDD = sparkContext.parallelize(row :: Nil)
spark.createDataFrame(rowRDD, schema).repartition($"value")
.createOrReplaceTempView("testNonPartialAggregation")

val planned1 = sql("SELECT SUM(value) FROM testNonPartialAggregation GROUP BY value")
.queryExecution.executedPlan

// If input data are already partitioned and the same columns are used in grouping keys and
// aggregation values, no partial aggregation exist in query plans.
val aggOps1 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps1.size == 1, s"The plan $planned1 has partial aggregations.")

val planned2 = sql(
"""
|SELECT t.value, SUM(DISTINCT t.value)
|FROM (SELECT * FROM testNonPartialAggregation ORDER BY value) t
|GROUP BY t.value
""".stripMargin).queryExecution.executedPlan

val aggOps2 = planned1.collect { case n if n.nodeName contains "Aggregate" => n }
assert(aggOps2.size == 1, s"The plan $planned2 has partial aggregations.")
}
testPartialAggregationPlan(query)
}

test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
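The simplified testPartialAggregationPlan now only asserts that a planned aggregation contains either two aggregate operators (partial plus final) or four (the rewrite used when a DISTINCT aggregate is present). A hedged sketch of an equivalent count against an executed plan follows; the session setup, sample data, and expected counts in the comments are illustrative assumptions rather than the suite's actual fixtures.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{count, countDistinct}

object PartialAggregationCheck {
  // Count physical operators whose node name mentions "Aggregate".
  def countAggregateOperators(df: DataFrame): Int =
    df.queryExecution.executedPlan.collect {
      case n if n.nodeName.contains("Aggregate") => n
    }.size

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("partial-agg-check").getOrCreate()
    import spark.implicits._

    val data = Seq((1, 10), (1, 20), (2, 30)).toDF("value", "key")

    val simple = data.groupBy($"value").agg(count($"key"))
    val withDistinct = data.groupBy($"value").agg(countDistinct($"key"))

    println(countAggregateOperators(simple))        // expected: 2 (partial + final)
    println(countAggregateOperators(withDistinct))  // expected: 4 (distinct rewrite)

    spark.stop()
  }
}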