Commit bea8845

Merge pull request #148 from mbautin/csd-1.6_SPARK-12213

Backport: [SPARK-12213][SQL] use multiple partitions for single distinct query

2 parents edc1192 + b6db2cc
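
For context, the query shape this backport replans is a single distinct aggregate with grouping expressions. A minimal sketch against a Spark 1.6-era SQLContext; the table and column names ("t", "key", "value") are invented for illustration:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SingleDistinctExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("single-distinct").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    sc.parallelize(Seq(("a", 1), ("a", 1), ("b", 2)))
      .toDF("key", "value")
      .registerTempTable("t")

    // One distinct column set: after this patch the regular Aggregation
    // strategy plans it with partial aggregation across multiple partitions,
    // instead of the Expand-based DistinctAggregationRewriter path.
    sqlContext.sql("SELECT key, COUNT(DISTINCT value) FROM t GROUP BY key").show()
    sc.stop()
  }
}
```

Queries with more than one distinct column set still go through DistinctAggregationRewriter, as the diffs below show.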

File tree

10 files changed: +422 additions, −990 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala

Lines changed: 0 additions & 7 deletions

@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst
 
 private[spark] trait CatalystConf {
   def caseSensitiveAnalysis: Boolean
-
-  protected[spark] def specializeSingleDistinctAggPlanning: Boolean
 }
 
 /**
@@ -31,13 +29,8 @@ object EmptyConf extends CatalystConf {
   override def caseSensitiveAnalysis: Boolean = {
     throw new UnsupportedOperationException
   }
-
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = {
-    throw new UnsupportedOperationException
-  }
 }
 
 /** A CatalystConf that can be used for local testing. */
 case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf {
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala

Lines changed: 2 additions & 9 deletions

@@ -123,15 +123,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan]
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
 
-    val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) {
-      // When the flag is set to specialize single distinct agg planning,
-      // we will rely on our Aggregation strategy to handle queries with a single
-      // distinct column.
-      distinctAggGroups.size > 1
-    } else {
-      distinctAggGroups.size >= 1
-    }
-    if (shouldRewrite) {
+    // Aggregation strategy can handle the query with single distinct
+    if (distinctAggGroups.size > 1) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = new AttributeReference("gid", IntegerType, false)()
       val groupByMap = a.groupingExpressions.collect {
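
To make the new `distinctAggGroups.size > 1` condition concrete, here is a minimal, self-contained sketch in plain Scala; `Agg` is a made-up stand-in for Spark's AggregateExpression:

```scala
// Made-up stand-in: just the distinct flag and the aggregate function's
// argument set.
case class Agg(args: Set[String], isDistinct: Boolean)

val aggs = Seq(
  Agg(Set("a"), isDistinct = true), // COUNT(DISTINCT a)
  Agg(Set("a"), isDistinct = true), // SUM(DISTINCT a): same argument set
  Agg(Set("b"), isDistinct = true)) // COUNT(DISTINCT b)

// Mirrors .filter(_.isDistinct).groupBy(_.aggregateFunction.children.toSet)
val distinctAggGroups = aggs.filter(_.isDistinct).groupBy(_.args)

// Two argument sets ({a} and {b}): the Expand-based rewrite still fires.
// With a single set, planning is now left to the Aggregation strategy.
println(distinctAggGroups.size > 1) // true
```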

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 0 additions & 15 deletions

@@ -449,18 +449,6 @@ private[spark] object SQLConf
     doc = "When true, we could use `datasource`.`path` as table in SQL query"
   )
 
-  val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING =
-    booleanConf("spark.sql.specializeSingleDistinctAggPlanning",
-      defaultValue = Some(false),
-      isPublic = false,
-      doc = "When true, if a query only has a single distinct column and it has " +
-        "grouping expressions, we will use our planner rule to handle this distinct " +
-        "column (other cases are handled by DistinctAggregationRewriter). " +
-        "When false, we will always use DistinctAggregationRewriter to plan " +
-        "aggregation queries with DISTINCT keyword. This is an internal flag that is " +
-        "used to benchmark the performance impact of using DistinctAggregationRewriter to " +
-        "plan aggregation queries with a single distinct column.")
-
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
@@ -579,9 +567,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf
 
   private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  protected[spark] override def specializeSingleDistinctAggPlanning: Boolean =
-    getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING)
-
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
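
The removed `spark.sql.specializeSingleDistinctAggPlanning` flag gated exactly the rewriter condition shown in the previous file. A hedged before/after model of that decision; this is not Spark's API, just the boolean logic from the two diffs above:

```scala
// Old behavior, gated by the (default-false) internal flag:
def shouldRewriteBefore(distinctGroups: Int, specializeSingleDistinct: Boolean): Boolean =
  if (specializeSingleDistinct) distinctGroups > 1 else distinctGroups >= 1

// New behavior, unconditional:
def shouldRewriteAfter(distinctGroups: Int): Boolean = distinctGroups > 1

assert(shouldRewriteBefore(1, specializeSingleDistinct = false)) // old default: always rewrite
assert(!shouldRewriteAfter(1)) // now: single distinct goes to the Aggregation strategy
```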

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala

Lines changed: 132 additions & 285 deletions
Large diff not rendered.

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala

Lines changed: 9 additions & 20 deletions

@@ -29,10 +29,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 case class SortBasedAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
@@ -42,10 +40,8 @@ case class SortBasedAggregate(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def outputsUnsafeRows: Boolean = false
-
+  override def outputsUnsafeRows: Boolean = true
   override def canProcessUnsafeRows: Boolean = false
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -76,31 +72,24 @@
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input iterator is empty,
         // so return an empty iterator.
-        Iterator[InternalRow]()
+        Iterator[UnsafeRow]()
       } else {
-        val groupingKeyProjection =
-          UnsafeProjection.create(groupingExpressions, child.output)
-
         val outputIter = new SortBasedAggregationIterator(
-          groupingKeyProjection,
-          groupingExpressions.map(_.toAttribute),
+          groupingExpressions,
           child.output,
           iter,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
+          aggregateExpressions,
+          aggregateAttributes,
           initialInputBufferOffset,
           resultExpressions,
           newMutableProjection,
-          outputsUnsafeRows,
          numInputRows,
          numOutputRows)
        if (!hasInput && groupingExpressions.isEmpty) {
          // There is no input and there is no grouping expressions.
          // We need to output a single row as the output.
          numOutputRows += 1
-          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+          Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
        } else {
          outputIter
        }
@@ -109,7 +98,7 @@
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     val keyString = groupingExpressions.mkString("[", ",", "]")
     val functionString = allAggregateExpressions.mkString("[", ",", "]")
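
One subtlety visible in the doExecute hunk above: the two empty-input branches differ on purpose. A tiny stand-alone model in plain Scala (simplified to a count, no Spark types) of why a grouped aggregate over empty input emits nothing while a global aggregate still emits one row:

```scala
// Grouped aggregation over no rows produces no output rows, while a global
// aggregate must still produce one (COUNT(*) of an empty table is 0).
def aggregateEmptyInput(hasGrouping: Boolean): Seq[Long] =
  if (hasGrouping) Seq.empty // like Iterator[UnsafeRow]()
  else Seq(0L)               // like outputForEmptyGroupingKeyWithoutInput()

println(aggregateEmptyInput(hasGrouping = true))  // List()
println(aggregateEmptyInput(hasGrouping = false)) // List(0)
```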

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 22 additions & 25 deletions

@@ -24,37 +24,34 @@ import org.apache.spark.sql.execution.metric.LongSQLMetric
 
 /**
  * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
- * sorted by values of [[groupingKeyAttributes]].
+ * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
-    groupingKeyProjection: InternalRow => InternalRow,
-    groupingKeyAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean,
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric)
   extends AggregationIterator(
-    groupingKeyAttributes,
+    groupingExpressions,
     valueAttributes,
-    nonCompleteAggregateExpressions,
-    nonCompleteAggregateAttributes,
-    completeAggregateExpressions,
-    completeAggregateAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
     initialInputBufferOffset,
     resultExpressions,
-    newMutableProjection,
-    outputsUnsafeRows) {
-
-  override protected def newBuffer: MutableRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.aggBufferAttributes)
+    newMutableProjection) {
+
+  /**
+   * Creates a new aggregation buffer and initializes buffer values
+   * for all aggregate functions.
+   */
+  private def newBuffer: MutableRow = {
+    val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
     val bufferRowSize: Int = bufferSchema.length
 
     val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
@@ -76,10 +73,10 @@ class SortBasedAggregationIterator(
   ///////////////////////////////////////////////////////////////////////////
 
   // The partition key of the current partition.
-  private[this] var currentGroupingKey: InternalRow = _
+  private[this] var currentGroupingKey: UnsafeRow = _
 
   // The partition key of next partition.
-  private[this] var nextGroupingKey: InternalRow = _
+  private[this] var nextGroupingKey: UnsafeRow = _
 
   // The first row of next partition.
   private[this] var firstRowInNextGroup: InternalRow = _
@@ -94,7 +91,7 @@
     if (inputIterator.hasNext) {
       initializeBuffer(sortBasedAggregationBuffer)
       val inputRow = inputIterator.next()
-      nextGroupingKey = groupingKeyProjection(inputRow).copy()
+      nextGroupingKey = groupingProjection(inputRow).copy()
       firstRowInNextGroup = inputRow.copy()
       numInputRows += 1
       sortedInputHasNewGroup = true
@@ -120,7 +117,7 @@
     while (!findNextPartition && inputIterator.hasNext) {
       // Get the grouping key.
       val currentRow = inputIterator.next()
-      val groupingKey = groupingKeyProjection(currentRow)
+      val groupingKey = groupingProjection(currentRow)
       numInputRows += 1
 
       // Check if the current row belongs the current input row.
@@ -146,7 +143,7 @@
 
   override final def hasNext: Boolean = sortedInputHasNewGroup
 
-  override final def next(): InternalRow = {
+  override final def next(): UnsafeRow = {
     if (hasNext) {
       // Process the current group.
       processCurrentSortedGroup()
@@ -162,8 +159,8 @@
     }
   }
 
-  def outputForEmptyGroupingKeyWithoutInput(): InternalRow = {
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
     initializeBuffer(sortBasedAggregationBuffer)
-    generateOutput(new GenericInternalRow(0), sortBasedAggregationBuffer)
+    generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer)
   }
 }
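
The iterator's core loop, consuming sorted rows until the grouping key changes and then emitting one row per group, can be modeled without any Spark types. A minimal sketch; `aggregateSorted` is a hypothetical helper, not part of Spark:

```scala
// Walk input that is pre-sorted by key: accumulate into a buffer while the
// key stays the same, then emit (key, buffer) and start the next group.
def aggregateSorted[K, V, B](
    rows: Iterator[(K, V)], zero: B, update: (B, V) => B): Iterator[(K, B)] =
  new Iterator[(K, B)] {
    private val buffered = rows.buffered
    def hasNext: Boolean = buffered.hasNext
    def next(): (K, B) = {
      val key = buffered.head._1
      var buf = zero // like initializeBuffer(sortBasedAggregationBuffer)
      while (buffered.hasNext && buffered.head._1 == key) {
        buf = update(buf, buffered.next()._2)
      }
      (key, buf)
    }
  }

// Example: counts per key over sorted input.
val counts = aggregateSorted(
  Iterator(("a", 1), ("a", 2), ("b", 3)), 0, (acc: Int, _: Int) => acc + 1)
println(counts.toList) // List((a,2), (b,1))
```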

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala

Lines changed: 8 additions & 17 deletions

@@ -30,21 +30,18 @@ import org.apache.spark.sql.types.StructType
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression],
-    completeAggregateAttributes: Seq[Attribute],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends UnaryNode {
 
   private[this] val aggregateBufferAttributes = {
-    (nonCompleteAggregateExpressions ++ completeAggregateExpressions)
-      .flatMap(_.aggregateFunction.aggBufferAttributes)
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
   }
 
-  require(TungstenAggregate.supportsAggregate(groupingExpressions, aggregateBufferAttributes))
+  require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
 
   override private[sql] lazy val metrics = Map(
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
@@ -53,9 +50,7 @@ case class TungstenAggregate(
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
   override def outputsUnsafeRows: Boolean = true
-
   override def canProcessUnsafeRows: Boolean = true
-
   override def canProcessSafeRows: Boolean = true
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
@@ -94,10 +89,8 @@ case class TungstenAggregate(
     val aggregationIterator =
       new TungstenAggregationIterator(
         groupingExpressions,
-        nonCompleteAggregateExpressions,
-        nonCompleteAggregateAttributes,
-        completeAggregateExpressions,
-        completeAggregateAttributes,
+        aggregateExpressions,
+        aggregateAttributes,
         initialInputBufferOffset,
         resultExpressions,
         newMutableProjection,
@@ -119,7 +112,7 @@
   }
 
   override def simpleString: String = {
-    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    val allAggregateExpressions = aggregateExpressions
 
     testFallbackStartsAt match {
       case None =>
@@ -135,9 +128,7 @@
 }
 
 object TungstenAggregate {
-  def supportsAggregate(
-      groupingExpressions: Seq[Expression],
-      aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+  def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
     val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
     UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
   }
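
`supportsAggregate` now looks only at the aggregation buffer's schema. A hedged stand-alone model of that check; the `DataType` objects and the width rule here are simplified stand-ins for `StructType` and `UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema`, which accepts mutable fixed-width field types:

```scala
sealed trait DataType
case object IntegerType extends DataType
case object LongType extends DataType
case object DoubleType extends DataType
case object StringType extends DataType

// Simplified rule: the Tungsten hash map needs every buffer field to be a
// mutable fixed-width type; anything else falls back to sort-based aggregation.
def supportsAggregate(bufferTypes: Seq[DataType]): Boolean =
  bufferTypes.forall {
    case IntegerType | LongType | DoubleType => true
    case _ => false
  }

println(supportsAggregate(Seq(LongType, DoubleType))) // true: e.g. count + sum buffers
println(supportsAggregate(Seq(StringType)))           // false: e.g. a max(string) buffer
```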
