Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
CombineUnions,
// Constant folding and strength reduction
NullPropagation(conf),
ConstantPropagation,
FoldablePropagation,
OptimizeIn(conf),
ConstantFolding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,62 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
}

/**
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
* value in conjunctive [[Expression Expressions]]
* eg.
* {{{
* SELECT * FROM table WHERE i = 5 AND j = i + 3
* ==> SELECT * FROM table WHERE i = 5 AND j = 8
* }}}
*
* Approach used:
* - Start from AND operator as the root
* - Get all the children conjunctive predicates which are EqualTo / EqualNullSafe such that they
* don't have a `NOT` or `OR` operator in them
* - Populate a mapping of attribute => constant value by looking at all the equals predicates
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
* in the AND node.
*/
object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
private def containsNonConjunctionPredicates(expression: Expression): Boolean = expression.find {
case _: Not | _: Or => true
case _ => false
}.isDefined

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case f: Filter => f transformExpressionsUp {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was initially doing this for the entire logical plan but now switched to do only for filter operator.
Reason: Doing this for the entire logical plan will mess up with JOIN predicates. eg.

SELECT * FROM a JOIN b ON a.i = 1 AND b.i = a.i
=>
 SELECT * FROM a JOIN b ON a.i = 1 AND b.i = 1

.. the result is a cartesian product and Spark fails (asking to set a config). In case of OUTER JOINs, changing the join predicates might cause regression.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am being myopic here but the result should be the same right? The only way this regresses is when we plan a CartesianProduct instead of an BroadcastNestedLoopJoin... I am fine with not optimizing this for now, it would be nice if these constraints are at least generated here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the result should be the same. I don't have any theoretical proof if doing this over joins will be safe so want to be cautious here ... any bad rules might lead to correctness bugs which is super bad for end users.

it would be nice if these constraints are at least generated here

Sorry I am not able to get you here and want to make sure if I am not ignoring your comment. Are you suggesting any changes over the existing version ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We currently infer is not null constraints up and down the plan. This could be easily extended to other constraints. Your PR has some overlap with this. However, lets focus on getting this merged first, and then we might take a stab at extending this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also cc @sameeragarwal

case and: And =>
val conjunctivePredicates =
splitConjunctivePredicates(and)
.filter(expr => expr.isInstanceOf[EqualTo] || expr.isInstanceOf[EqualNullSafe])
.filterNot(expr => containsNonConjunctionPredicates(expr))

val equalityPredicates = conjunctivePredicates.collect {
case e @ EqualTo(left: AttributeReference, right: Literal) => ((left, right), e)
case e @ EqualTo(left: Literal, right: AttributeReference) => ((right, left), e)
case e @ EqualNullSafe(left: AttributeReference, right: Literal) => ((left, right), e)
case e @ EqualNullSafe(left: Literal, right: AttributeReference) => ((right, left), e)
}

val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it's safe when we have both a = 1 and a = 2 at the same time?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current impl will pick the last one (ie. a = 2) and propagate it. Given that its one of the equality predicates user provided, there is nothing wrong in propagating it. When the query is evaluated, it would return empty result given that a = 1 and a = 2 cannot be true at the same time.

scala> hc.sql(" SELECT * FROM table1 a WHERE a.j = 1 AND a.j = 2 AND a.k = (a.j + 3)").explain(true)

== Physical Plan ==
*Project [i#51, j#52, k#53]
+- *Filter ((((isnotnull(k#53) && isnotnull(j#52)) && (j#52 = 1)) && (j#52 = 2)) && (cast(k#53 as int) = 5))
   +- *FileScan orc default.table1[i#51,j#52,k#53] Batched: false, Format: ORC, Location: InMemoryFileIndex[file:/Users/tejasp/warehouse/table1], PartitionFilters: [], PushedFilters: [IsNotNull(k), IsNotNull(j), EqualTo(j,1), EqualTo(j,2)], ReadSchema: struct<i:int,j:int,k:string>


def replaceConstants(expression: Expression) = expression transform {
case a: AttributeReference =>
constantsMap.get(a) match {
case Some(literal) => literal
case None => a
}
}

and transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants(e)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check for identity instead of equality? I think you are doing the latter. What will happen in the following example: select * from bla where (a = 1 or b = 2) and a = 1

Copy link
Contributor Author

@tejasapatil tejasapatil May 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the behavior with this PR. Seems reasonable because a = 1 has to be true so (a = 1 or b = 2) would always be true and can be eliminated.

scala> hc.sql(" select * from bla where (a = 1 or b = 2) and a = 1 ").explain(true)

== Physical Plan ==
*Project [a#34, b#35]
+- *Filter (isnotnull(a#34) && (a#34 = 1))
   +- *FileScan ....

case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants(e)
}
}
}
}

/**
* Reorder associative integral-type operators and fold all constants into one.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
* Unit tests for constant propagation in expressions.
*/
class ConstantPropagationSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
EliminateSubqueryAliases) ::
Batch("ConstantPropagation", FixedPoint(10),
ColumnPruning,
ConstantPropagation,
ConstantFolding,
BooleanSimplification) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

private val columnA = 'a.int
private val columnB = 'b.int
private val columnC = 'c.int

test("basic test") {
val query = testRelation
.select(columnA)
.where(columnA === Add(columnB, Literal(1)) && columnB === Literal(10))

val correctAnswer =
testRelation
.select(columnA)
.where(columnA === Literal(11) && columnB === Literal(10)).analyze

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a negative test case like SELECT * FROM t WHERE a=1 and a=2 and b=a+3?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


test("with combination of AND and OR predicates") {
val query = testRelation
.select(columnA)
.where(
columnA === Add(columnB, Literal(1)) &&
columnB === Literal(10) &&
(columnA === Add(columnC, Literal(3)) || columnB === columnC))
.analyze

val correctAnswer =
testRelation
.select(columnA)
.where(
columnA === Literal(11) &&
columnB === Literal(10) &&
(Literal(11) === Add(columnC, Literal(3)) || Literal(10) === columnC))
.analyze

comparePlans(Optimize.execute(query), correctAnswer)
}

test("equality predicates outside a `NOT` can be propagated within a `NOT`") {
val query = testRelation
.select(columnA)
.where(Not(columnA === Add(columnB, Literal(1))) && columnB === Literal(10))
.analyze

val correctAnswer =
testRelation
.select(columnA)
.where(Not(columnA === Literal(11)) && columnB === Literal(10))
.analyze

comparePlans(Optimize.execute(query), correctAnswer)
}

test("equality predicates inside a `NOT` should not be picked for propagation") {
val query = testRelation
.select(columnA)
.where(Not(columnB === Literal(10)) && columnA === Add(columnB, Literal(1)))
.analyze

comparePlans(Optimize.execute(query), query)
}

test("equality predicates outside a `OR` can be propagated within a `OR`") {
val query = testRelation
.select(columnA)
.where(
columnA === Literal(2) &&
(columnA === Add(columnB, Literal(3)) || columnB === Literal(9)))
.analyze

val correctAnswer = testRelation
.select(columnA)
.where(
columnA === Literal(2) &&
(Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9)))
.analyze

comparePlans(Optimize.execute(query), correctAnswer)
}

test("equality predicates inside a `OR` should not be picked for propagation") {
val query = testRelation
.select(columnA)
.where(
columnA === Add(columnB, Literal(2)) &&
(columnA === Add(columnB, Literal(3)) || columnB === Literal(9)))
.analyze

comparePlans(Optimize.execute(query), query)
}

test("equality operator not immediate child of root `AND` should not be used for propagation") {
val query = testRelation
.select(columnA)
.where(
columnA === Literal(0) &&
((columnB === columnA) === (columnB === Literal(0))))
.analyze

val correctAnswer = testRelation
.select(columnA)
.where(
columnA === Literal(0) &&
((columnB === Literal(0)) === (columnB === Literal(0))))
.analyze

comparePlans(Optimize.execute(query), correctAnswer)
}

test("conflicting equality predicates") {
val query = testRelation
.select(columnA)
.where(
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))

val correctAnswer = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5))

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
checkDataFilters(Set.empty)

// Only one file should be read.
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions =>
checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 2")) { partitions =>
assert(partitions.size == 1, "when checking partitions")
assert(partitions.head.files.size == 1, "when checking files in partition 1")
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
Expand All @@ -217,7 +217,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
checkDataFilters(Set.empty)

// Only one file should be read.
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions =>
checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 2")) { partitions =>
assert(partitions.size == 1, "when checking partitions")
assert(partitions.head.files.size == 1, "when checking files in partition 1")
assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
Expand All @@ -235,13 +235,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
"p1=1/file1" -> 10,
"p1=2/file2" -> 10))

val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
val df1 = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
// Filter on data only are advisory so we have to reevaluate.
assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1"))
// Need to evalaute filters that are not pushed down.
assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2"))
assert(getPhysicalFilters(df1) contains resolve(df1, "c1 = 1"))
// Don't reevaluate partition only filters.
assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1")))
assert(!(getPhysicalFilters(df1) contains resolve(df1, "p1 = 1")))

val df2 = table.where("(p1 + c2) = 2 AND c1 = 1")
// Filter on data only are advisory so we have to reevaluate.
assert(getPhysicalFilters(df2) contains resolve(df2, "c1 = 1"))
// Need to evalaute filters that are not pushed down.
assert(getPhysicalFilters(df2) contains resolve(df2, "(p1 + c2) = 2"))
}

test("bucketed table") {
Expand Down