From b77a4d638d29d8d9c939a653d6e79e8edc5203fd Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sat, 1 Oct 2016 12:33:30 +0800
Subject: [PATCH 01/13] modify function inferAdditionalConstraints to avoid
 producing a non-converging set of constraints

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 20 ++++++++-
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 41 +++++++++++++++++++
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0fb6e7d2e795a..fc3e3dbef0a68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -74,20 +74,36 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * additional constraint of the form `b = 5`
    */
   private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
+    // Collect alias from expressions to avoid producing non-converging set of constraints
+    // for recursive functions.
+    // For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
+    val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => (a.toAttribute, a.child)
+    })
+
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(l) => r
+          case a: Attribute if a.semanticEquals(l) && !isRecursiveDeduction(a, r, aliasMap) => r
         })
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(r) => l
+          case a: Attribute if a.semanticEquals(r) && !isRecursiveDeduction(l, a, aliasMap) => l
         })
       case _ => // No inference
     }
     inferredConstraints -- constraints
   }
 
+  private def isRecursiveDeduction(
+      left: Attribute,
+      right: Attribute,
+      aliasMap: AttributeMap[Expression]): Boolean = {
+    val leftExpression = aliasMap.getOrElse(left, left)
+    val rightExpression = aliasMap.getOrElse(right, right)
+    leftExpression.containsChild(rightExpression) || rightExpression.containsChild(leftExpression)
+  }
+
   /**
    * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 0ee8c959eeb4d..43826ca18caa9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2678,4 +2678,45 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
+    withTempView("tmpv") {
+      spark.range(10).toDF("a").createTempView("tmpv")
+
+      // Just ensure the following query executes successfully to completion.
+      assert(sql(
+        """
+          |SELECT
+          |  *
+          |FROM (
+          |  SELECT
+          |    COALESCE(t1.a, t2.a) AS int_col,
+          |    t1.a,
+          |    t2.a AS b
+          |  FROM tmpv t1
+          |  CROSS JOIN tmpv t2
+          |) t1
+          |INNER JOIN tmpv t2
+          |ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
+        """.stripMargin).count() > 0
+      )
+
+      //sql("CREATE TEMPORARY VIEW foo(a) AS VALUES (CAST(-993 AS BIGINT))")
+
+      /*sql(
+        """
+          |SELECT
+          |*
+          |FROM (
+          |  SELECT
+          |    COALESCE(t1.a, t2.a) AS int_col,
+          |    t1.a,
+          |    t2.a AS b
+          |  FROM foo t1
+          |  CROSS JOIN foo t2
+          |) t1
+          |INNER JOIN foo t2 ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
+        """.stripMargin).collect()*/
+    }
+  }
 }
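A note on the divergence this patch targets: the rule blows up because a constraint can equate an attribute with an expression built from that same attribute (here, int_col = COALESCE(a, b) together with int_col = a); each substitution pass then nests the expression one level deeper, so the fixed point is never reached. The following self-contained sketch reproduces the blow-up outside Catalyst -- the Expr encoding is invented for illustration and is not Spark's expression tree:

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Coalesce(children: Seq[Expr]) extends Expr
    case class EqualTo(left: Expr, right: Expr) extends Expr

    // Replace every occurrence of `from` inside `e` with `to`.
    def substitute(e: Expr, from: Expr, to: Expr): Expr = e match {
      case `from`        => to
      case Coalesce(cs)  => Coalesce(cs.map(substitute(_, from, to)))
      case EqualTo(l, r) => EqualTo(substitute(l, from, to), substitute(r, from, to))
      case other         => other
    }

    // With the fact `a = Coalesce(a, b)` in play, rewriting `a` on each pass
    // produces a strictly larger constraint every time -- no convergence.
    val a = Attr("a")
    val b = Attr("b")
    var constraint: Expr = EqualTo(Attr("int_col"), Coalesce(Seq(a, b)))
    for (_ <- 1 to 3) {
      constraint = substitute(constraint, a, Coalesce(Seq(a, b)))
      println(constraint) // nests one Coalesce level deeper per iteration
    }

The patch's guard, isRecursiveDeduction, cuts exactly this loop: it refuses a substitution when either side's alias-expanded expression directly contains the other.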
From ebba4468c6707e35656d122f99ec074f2785195e Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sat, 1 Oct 2016 13:04:52 +0800
Subject: [PATCH 02/13] remove commented code.

---
 .../org/apache/spark/sql/SQLQuerySuite.scala | 17 -----------------
 1 file changed, 17 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 43826ca18caa9..9e92ef2879715 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2700,23 +2700,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         |ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
       """.stripMargin).count() > 0
     )
-
-      //sql("CREATE TEMPORARY VIEW foo(a) AS VALUES (CAST(-993 AS BIGINT))")
-
-      /*sql(
-        """
-          |SELECT
-          |*
-          |FROM (
-          |  SELECT
-          |    COALESCE(t1.a, t2.a) AS int_col,
-          |    t1.a,
-          |    t2.a AS b
-          |  FROM foo t1
-          |  CROSS JOIN foo t2
-          |) t1
-          |INNER JOIN foo t2 ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
-        """.stripMargin).collect()*/
   }
 }
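Before the next patch: it raises the test batches from FixedPoint(5) to FixedPoint(100) precisely so that a diverging rule set fails loudly inside the suite instead of stopping after a handful of iterations. A toy model of the fixed-point driver contract (not Catalyst's RuleExecutor; names are illustrative):

    // Apply `rule` repeatedly until the value stops changing or the iteration
    // budget is exhausted -- the same contract as Catalyst's FixedPoint strategy.
    def fixedPoint[A](start: A, maxIterations: Int)(rule: A => A): A = {
      var current = start
      var iteration = 0
      while (iteration < maxIterations) {
        val next = rule(current)
        if (next == current) return current
        current = next
        iteration += 1
      }
      current
    }

    assert(fixedPoint(0, 100)(n => math.min(n + 1, 42)) == 42)

With a non-converging rule, a budget of 5 quietly returns an unstable plan; a budget of 100 gives the runaway inference enough rope for the test to observe it.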
From 7d9e2b0580f274df68d3de4c867e1527924d1b24 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sat, 1 Oct 2016 16:29:25 +0800
Subject: [PATCH 03/13] add new testcase.

---
 .../InferFiltersFromConstraintsSuite.scala    | 32 +++++++++++++++++--
 .../spark/sql/catalyst/plans/PlanTest.scala   | 25 +++++++++++++--
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index e7fdd5a6202b6..6fc031085d907 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -27,9 +27,11 @@ import org.apache.spark.sql.catalyst.rules._
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) ::
-      Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) ::
-      Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil
+    val batches = Batch("InferFilters", FixedPoint(100), InferFiltersFromConstraints) ::
+      Batch("PredicatePushdown", FixedPoint(100),
+        PushPredicateThroughJoin,
+        PushDownPredicate) ::
+      Batch("CombineFilters", FixedPoint(100), CombineFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -120,4 +122,28 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("don't generate constraints for recursive functions") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+    val t3 = testRelation.subquery('t3)
+
+    val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2, Inner,
+        Some("t.a".attr === "t2.a".attr
+          && "t.d".attr === "t2.a".attr
+          && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val correctAnswer = t1.where(IsNotNull('a) && 'a === Coalesce(Seq('a, 'b))
+      && IsNotNull('b) && 'b === Coalesce(Seq('a, 'b))
+      && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b)
+      .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2.where(IsNotNull('a)), Inner,
+        Some("t.a".attr === "t2.a".attr
+          && "t.d".attr === "t2.a".attr
+          && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 6310f0c2bc0ed..64e268703bf5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
 
 /**
@@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
    *   ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
    *   etc., will all now be equivalent.
    * - Sample the seed will be replaced by 0L.
+   * - Join conditions will be resorted by hashCode.
    */
   private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
     plan transform {
       case filter @ Filter(condition: Expression, child: LogicalPlan) =>
-        Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+        Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode())
+          .reduce(And), child)
       case sample: Sample =>
         sample.copy(seed = 0L)(true)
+      case join @ Join(left, right, joinType, condition) if condition.isDefined =>
+        val newCondition =
+          splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode())
+            .reduce(And)
+        Join(left, right, joinType, Some(newCondition))
     }
   }
 
+  /**
+   * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be
+   * equivalent:
+   * 1. (a = b), (b = a);
+   * 2. (a <=> b), (b <=> a).
+   */
+  private def rewriteEqual(condition: Expression): Expression = condition match {
+    case eq @ EqualTo(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo)
+    case eq @ EqualNullSafe(l: Expression, r: Expression) =>
+      Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe)
+    case _ => condition // Don't reorder.
+  }
+
   /** Fails the test if the two plans do not match */
   protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
     val normalized1 = normalizePlan(normalizeExprIds(plan1))

From 3b93209b7192461a1390e205f88d51bbee780ff0 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sat, 1 Oct 2016 16:32:02 +0800
Subject: [PATCH 04/13] remove unused code.

---
 .../catalyst/optimizer/InferFiltersFromConstraintsSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 6fc031085d907..70c81f8706e30 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -126,7 +126,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   test("don't generate constraints for recursive functions") {
     val t1 = testRelation.subquery('t1)
     val t2 = testRelation.subquery('t2)
-    val t3 = testRelation.subquery('t3)
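The PlanTest normalization added in PATCH 03 exists so that plans differing only in the orientation of equalities or the order of conjuncts compare equal. A rough standalone model of the idea, using string pairs in place of Catalyst expressions (the hashCode ordering mirrors rewriteEqual and normalizePlan above):

    // Orient each equality by hashCode, then sort the conjuncts, so that
    // (a = b AND c = d) and (d = c AND b = a) normalize to the same sequence.
    def normalizeConjuncts(conjuncts: Seq[(String, String)]): Seq[(String, String)] =
      conjuncts
        .map { case (l, r) => if (l.hashCode <= r.hashCode) (l, r) else (r, l) }
        .sortBy(_.hashCode)

    assert(normalizeConjuncts(Seq(("a", "b"), ("c", "d"))) ==
      normalizeConjuncts(Seq(("d", "c"), ("b", "a"))))

hashCode ordering is arbitrary but deterministic, which is all a plan-equality test needs.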
From 5b25fce87bb1c33bee4437936ed49e008db6a427 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sun, 2 Oct 2016 00:49:06 +0800
Subject: [PATCH 05/13] another approach and more testcases.

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 18 +++---
 .../InferFiltersFromConstraintsSuite.scala    | 56 ++++++++++++++++---
 2 files changed, 54 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index fc3e3dbef0a68..d464281e0ceb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -76,6 +76,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
     // Collect alias from expressions to avoid producing non-converging set of constraints
     // for recursive functions.
+    //
+    // Don't apply transform on constraints if the attribute used to replace is an alias,
+    // because then both `QueryPlan.inferAdditionalConstraints` and
+    // `UnaryNode.getAliasedConstraints` applies and may produce a non-converging set of
+    // constraints.
     // For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
     val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
       case a: Alias => (a.toAttribute, a.child)
@@ -85,25 +90,16 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(l) && !isRecursiveDeduction(a, r, aliasMap) => r
+          case a: Attribute if a.semanticEquals(l) && !aliasMap.contains(r) => r
         })
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(r) && !isRecursiveDeduction(l, a, aliasMap) => l
+          case a: Attribute if a.semanticEquals(r) && !aliasMap.contains(l) => l
         })
       case _ => // No inference
     }
     inferredConstraints -- constraints
   }
 
-  private def isRecursiveDeduction(
-      left: Attribute,
-      right: Attribute,
-      aliasMap: AttributeMap[Expression]): Boolean = {
-    val leftExpression = aliasMap.getOrElse(left, left)
-    val rightExpression = aliasMap.getOrElse(right, right)
-    leftExpression.containsChild(rightExpression) || rightExpression.containsChild(leftExpression)
-  }
-
   /**
    * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 70c81f8706e30..30111ad74da5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -27,11 +27,12 @@ import org.apache.spark.sql.catalyst.rules._
 class InferFiltersFromConstraintsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("InferFilters", FixedPoint(100), InferFiltersFromConstraints) ::
-      Batch("PredicatePushdown", FixedPoint(100),
+    val batches =
+      Batch("InferAndPushDownFilters", FixedPoint(100),
         PushPredicateThroughJoin,
-        PushDownPredicate) ::
-      Batch("CombineFilters", FixedPoint(100), CombineFilters) :: Nil
+        PushDownPredicate,
+        InferFiltersFromConstraints,
+        CombineFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -123,7 +124,41 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("don't generate constraints for recursive functions") {
+  test("inner join with alias: alias contains multiple attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b)))
+      &&'a === Coalesce(Seq('a, 'b)))
+      .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
+      .join(t2.where(IsNotNull('a)), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, currectAnswer)
+  }
+
+  test("inner join with alias: alias contains single attributes") {
+    val t1 = testRelation.subquery('t1)
+    val t2 = testRelation.subquery('t2)
+
+    val originalQuery = t1.select('a, 'b.as('d)).as("t")
+      .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull('b)
+      && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
+      .select('a, 'b.as('d)).as("t")
+      .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
+        Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, currectAnswer)
+  }
+
+  test("inner join with alias: don't generate constraints for recursive functions") {
     val t1 = testRelation.subquery('t1)
     val t2 = testRelation.subquery('t2)
 
@@ -133,11 +168,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         && "t.d".attr === "t2.a".attr
         && "t.int_col".attr === "t2.a".attr))
       .analyze
-    val correctAnswer = t1.where(IsNotNull('a) && 'a === Coalesce(Seq('a, 'b))
-      && IsNotNull('b) && 'b === Coalesce(Seq('a, 'b))
-      && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b)
+    val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+      && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
+      && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
+      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
-      .join(t2.where(IsNotNull('a)), Inner,
+      .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
           && "t.int_col".attr === "t2.a".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
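PATCH 05's guard is deliberately coarse: whenever the attribute that would be substituted in is backed by an Alias anywhere in the plan, the substitution is skipped, because getAliasedConstraints would re-expand that alias and the two rules could feed each other forever. A minimal model of the check, with simplified stand-ins rather than Catalyst's AttributeMap:

    case class Attr(name: String)

    // Attributes that some Alias in this operator (or its children) produces.
    val aliasBacked: Set[Attr] = Set(Attr("int_col"), Attr("d"))

    // For a constraint l = r, only rewrite occurrences of l into r when r is
    // not alias-backed (and symmetrically for l).
    def safeReplacement(replacement: Attr): Boolean = !aliasBacked.contains(replacement)

    assert(!safeReplacement(Attr("int_col")))
    assert(safeReplacement(Attr("a")))

The cost of the coarseness shows up in the new expected plans above: some useful constraints on aliased columns are no longer inferred, which later patches refine.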
From e5912f86c94ff4d6303c9d0a9b80a30d30b99e3d Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sun, 2 Oct 2016 01:30:21 +0800
Subject: [PATCH 06/13] simplify inferred filters.

---
 .../optimizer/InferFiltersFromConstraintsSuite.scala | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 30111ad74da5e..942ba0c289c9a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -32,7 +32,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       PushPredicateThroughJoin,
       PushDownPredicate,
       InferFiltersFromConstraints,
-      CombineFilters) :: Nil
+      CombineFilters,
+      SimplifyBinaryComparison,
+      BooleanSimplification) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -169,13 +171,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         && "t.int_col".attr === "t2.a".attr))
       .analyze
     val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
       && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
       && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
-      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
+      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
           && "t.int_col".attr === "t2.a".attr))
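PATCH 06 briefly adds SimplifyBinaryComparison and BooleanSimplification so that trivially-true comparisons such as 'a <=> 'a drop out of the expected plans (PATCH 07 reverts this). The essential simplification, in a toy encoding invented here for illustration:

    // Fold x <=> x (null-safe equality of an expression with itself) to `true`,
    // which is the piece of SimplifyBinaryComparison this suite relies on.
    sealed trait BoolExpr
    case class NullSafeEq(l: String, r: String) extends BoolExpr
    case object True extends BoolExpr

    def simplify(e: BoolExpr): BoolExpr = e match {
      case NullSafeEq(l, r) if l == r => True
      case other                      => other
    }

    assert(simplify(NullSafeEq("a", "a")) == True)

Note this rewrite is only safe for deterministic expressions, which is why the real rule checks determinism before folding.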
From 9639c71862d1e7783bc3ca4d750d68e7aa35be92 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sun, 2 Oct 2016 09:51:38 +0800
Subject: [PATCH 07/13] revert testcase change.

---
 .../optimizer/InferFiltersFromConstraintsSuite.scala | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 942ba0c289c9a..30111ad74da5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -32,9 +32,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       PushPredicateThroughJoin,
       PushDownPredicate,
       InferFiltersFromConstraints,
-      CombineFilters,
-      SimplifyBinaryComparison,
-      BooleanSimplification) :: Nil
+      CombineFilters) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -171,13 +169,13 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         && "t.int_col".attr === "t2.a".attr))
       .analyze
     val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
+      && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
       && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
       && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
-      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
+      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))), Inner,
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
           && "t.int_col".attr === "t2.a".attr))

From 1558d4c2f9190691239e9b27e9517714c2af2bcc Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sun, 9 Oct 2016 14:17:52 +0800
Subject: [PATCH 08/13] change alias collection structure.

---
 .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d464281e0ceb2..64c9086f2b674 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -81,19 +81,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     // because then both `QueryPlan.inferAdditionalConstraints` and
     // `UnaryNode.getAliasedConstraints` applies and may produce a non-converging set of
     // constraints.
-    // For more details, infer https://issues.apache.org/jira/browse/SPARK-17733
-    val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
-      case a: Alias => (a.toAttribute, a.child)
-    })
+    // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
+    val aliasSet = AttributeSet((expressions ++ children.flatMap(_.expressions))
+      .filter(_.isInstanceOf[Alias]))
 
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(l) && !aliasMap.contains(r) => r
+          case a: Attribute if a.semanticEquals(l) && !aliasSet.contains(r) => r
         })
         inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(r) && !aliasMap.contains(l) => l
+          case a: Attribute if a.semanticEquals(r) && !aliasSet.contains(l) => l
         })
       case _ => // No inference
     }

From 388443d2886d09fec6a25b8400c6eb9631373135 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Sun, 16 Oct 2016 00:54:25 +0800
Subject: [PATCH 09/13] bugfix

---
 .../spark/sql/catalyst/plans/QueryPlan.scala |  5 +++--
 .../org/apache/spark/sql/SQLQuerySuite.scala | 18 +++++++++++-------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 64c9086f2b674..a97c47bfe0306 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -82,8 +82,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     // `UnaryNode.getAliasedConstraints` applies and may produce a non-converging set of
     // constraints.
     // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
-    val aliasSet = AttributeSet((expressions ++ children.flatMap(_.expressions))
-      .filter(_.isInstanceOf[Alias]))
+    val aliasSet = AttributeSet((expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => a.toAttribute
+    })
 
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 9e92ef2879715..845d743b00e9e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql
 
 import java.io.File
 import java.math.MathContext
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
+
+import scala.concurrent.duration._
+
+import org.scalatest.concurrent.Eventually._
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.expressions.SortOrder
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
@@ -2684,7 +2685,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       spark.range(10).toDF("a").createTempView("tmpv")
 
       // Just ensure the following query executes successfully to completion.
-      assert(sql(
+      val query =
         """
           |SELECT
           |  *
          |FROM (
          |  SELECT
          |    COALESCE(t1.a, t2.a) AS int_col,
          |    t1.a,
          |    t2.a AS b
          |  FROM tmpv t1
          |  CROSS JOIN tmpv t2
          |) t1
          |INNER JOIN tmpv t2
          |ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
-        """.stripMargin).count() > 0
-      )
+        """.stripMargin
+
+      eventually(timeout(60 seconds)) {
+        assert(sql(query).count() > 0)
+      }
     }
   }
 }
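PATCHes 08 and 09 only change how the alias-produced attributes are collected: first an AttributeSet filtered by type (which PATCH 09 identifies as wrong, since an AttributeSet must be built from the attributes an Alias exposes, not from the Alias nodes themselves), then a collect over Alias nodes. The collect pattern is the interesting bit; a standalone sketch with simplified stand-ins for Catalyst's Alias and attribute types:

    case class Attribute(name: String)
    case class Alias(child: String, name: String) {
      def toAttribute: Attribute = Attribute(name)
    }

    val expressions: Seq[Any] =
      Seq(Alias("coalesce(a, b)", "int_col"), Attribute("a"), Alias("b", "d"))

    // Keep only Alias nodes, keyed by the attribute they expose -- the shape of
    // `AttributeMap(... collect { case a: Alias => (a.toAttribute, a.child) })`.
    val aliasMap: Map[Attribute, String] =
      expressions.collect { case a: Alias => a.toAttribute -> a.child }.toMap

    assert(aliasMap(Attribute("int_col")) == "coalesce(a, b)")
    assert(!aliasMap.contains(Attribute("a")))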
From 52ef1d8a0bb39f6d5b1833c0c7d2815d69074b60 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Wed, 19 Oct 2016 20:08:47 +0800
Subject: [PATCH 10/13] update rule for recursive constraints.

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 86 ++++++++++++++++---
 .../InferFiltersFromConstraintsSuite.scala    | 11 ++-
 2 files changed, 84 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index a97c47bfe0306..991d2c0b94b63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -77,29 +77,95 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     // Collect alias from expressions to avoid producing non-converging set of constraints
     // for recursive functions.
     //
-    // Don't apply transform on constraints if the attribute used to replace is an alias,
-    // because then both `QueryPlan.inferAdditionalConstraints` and
-    // `UnaryNode.getAliasedConstraints` applies and may produce a non-converging set of
-    // constraints.
+    // Don't apply transform on constraints if the replacement will cause an recursive deduction,
+    // when that happens a non-converging set of constraints will be created and finally throw
+    // OOM Exception.
     // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
-    val aliasSet = AttributeSet((expressions ++ children.flatMap(_.expressions)).collect {
-      case a: Alias => a.toAttribute
+    val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => (a.toAttribute, a.child)
     })
 
+    val equalExprSets = generateEqualExpressionSets(constraints, aliasMap)
+
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
+        val candidateConstraints = constraints - eq
-        inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(l) && !aliasSet.contains(r) => r
+        inferredConstraints ++= candidateConstraints.map(_ transform {
+          case a: Attribute if a.semanticEquals(l) &&
+            !isRecursiveDeduction(r, aliasMap, equalExprSets) => r
         })
-        inferredConstraints ++= (constraints - eq).map(_ transform {
-          case a: Attribute if a.semanticEquals(r) && !aliasSet.contains(l) => l
+        inferredConstraints ++= candidateConstraints.map(_ transform {
+          case a: Attribute if a.semanticEquals(r) &&
+            !isRecursiveDeduction(l, aliasMap, equalExprSets) => l
         })
       case _ => // No inference
     }
     inferredConstraints -- constraints
   }
 
+  /*
+   * Generate a sequence of expression sets from constraints, where each set stores an equivalence
+   * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
+   * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
+   * to a selected attribute.
+ */ + private def generateEqualExpressionSets( + constraints: Set[Expression], + aliasMap: AttributeMap[Expression]): Seq[Set[Expression]] = { + var equalExprSets = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for equivalence class of expressions. + val leftEqualSet = getEqualExprSet(left, equalExprSets).getOrElse(Set.empty[Expression]) + val rightEqualSet = getEqualExprSet(right, equalExprSets).getOrElse(Set.empty[Expression]) + if (!leftEqualSet.isEmpty && !rightEqualSet.isEmpty) { + // Combine the two sets. + equalExprSets = equalExprSets.diff(leftEqualSet :: rightEqualSet :: Nil) :+ + (leftEqualSet ++ rightEqualSet) + } else if (!leftEqualSet.isEmpty) { // && rightEqualSet.isEmpty + // Update equivalence class of `left` expression. + equalExprSets = equalExprSets.diff(leftEqualSet :: Nil) :+ (leftEqualSet + right) + } else if (!rightEqualSet.isEmpty) { // && leftEqualSet.isEmpty + // Update equivalence class of `right` expression. + equalExprSets = equalExprSets.diff(rightEqualSet :: Nil) :+ (rightEqualSet + left) + } else { // leftEqualSet.isEmpty && rightEqualSet.isEmpty + // Create new equivalence class since both expression don't present in any classes. + equalExprSets = equalExprSets :+ Set(left, right) + } + case _ => // Skip + } + + equalExprSets + } + + /* + * Get all expressions equivalent to the selected expression. + */ + private def getEqualExprSet( + expr: Expression, + equalExprSets: Seq[Set[Expression]]): Option[Set[Expression]] = + equalExprSets.filter(_.contains(expr)).headOption + + /* + * Check whether replace by an [[Attribute]] will cause an recursive deduction. Generally it + * has an form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is an function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is child of the referenced expression. + */ + private def isRecursiveDeduction( + attr: Attribute, + aliasMap: AttributeMap[Expression], + equalExprSets: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getEqualExprSet(expr, equalExprSets).getOrElse(Set.empty[Expression]).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. 
    * example, if this set contains the expression `a = 2` then that expression is guaranteed to
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 30111ad74da5e..3eca0b205d6b4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -170,15 +170,20 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
       .analyze
     val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
       && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+      && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
       && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
+      && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
       && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
-      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
+      && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
+      && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
-        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+        && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
-          && "t.int_col".attr === "t2.a".attr))
+          && "t.int_col".attr === "t2.a".attr
+          && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }

From 909d2cde961bc61ff5d6a20b135803fb7eee3eb9 Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Wed, 19 Oct 2016 20:26:01 +0800
Subject: [PATCH 11/13] add testcase

---
 .../optimizer/InferFiltersFromConstraintsSuite.scala | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 3eca0b205d6b4..c1f5076fe3448 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -188,4 +188,15 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
+
+  test("generate correct filters for aliases that don't produce recursive constraints") {
+    val t1 = testRelation.subquery('t1)
+
+    val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze
+    val correctAnswer =
+      t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b))
+        .select('a.as('x), 'b.as('y)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }
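The core of PATCH 10 is generateEqualExpressionSets, which partitions expressions into equivalence classes by folding the `=` facts, merging two classes whenever an equality bridges them. A condensed standalone version of that union logic, using plain strings instead of Catalyst expressions (the real code also rewrites aliases to their child expressions first):

    // Fold a list of equalities into disjoint equivalence classes, e.g.
    // a = b, b = c, e = f  ==>  Seq(Set(a, b, c), Set(e, f)).
    def equivalenceClasses(equalities: Seq[(String, String)]): Seq[Set[String]] =
      equalities.foldLeft(Seq.empty[Set[String]]) { case (classes, (l, r)) =>
        val leftClass  = classes.find(_.contains(l)).getOrElse(Set.empty[String])
        val rightClass = classes.find(_.contains(r)).getOrElse(Set.empty[String])
        if (leftClass.nonEmpty && rightClass.nonEmpty) {
          // An equality bridging two existing classes merges them.
          classes.diff(Seq(leftClass, rightClass)) :+ (leftClass ++ rightClass)
        } else if (leftClass.nonEmpty) {
          classes.diff(Seq(leftClass)) :+ (leftClass + r)
        } else if (rightClass.nonEmpty) {
          classes.diff(Seq(rightClass)) :+ (rightClass + l)
        } else {
          classes :+ Set(l, r)
        }
      }

    assert(equivalenceClasses(Seq(("a", "b"), ("b", "c"), ("e", "f"))) ==
      Seq(Set("a", "b", "c"), Set("e", "f")))

A classic union-find would do the same job asymptotically faster; the Seq-of-Set shape here simply mirrors the patch, where constraint sets are small.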
From 905eaa12925fe8638df01a738b977e2f32a5fc9c Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Tue, 25 Oct 2016 15:07:55 +0800
Subject: [PATCH 12/13] improve structure and styling.

---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 84 +++++++++----------
 .../InferFiltersFromConstraintsSuite.scala    | 36 ++++----
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 27 ------
 3 files changed, 61 insertions(+), 86 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 991d2c0b94b63..45ee2964d4db0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -68,24 +68,22 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT
     case _ => Seq.empty[Attribute]
   }
 
+  // Collect aliases from expressions, so we may avoid producing recursive constraints.
+  private lazy val aliasMap = AttributeMap(
+    (expressions ++ children.flatMap(_.expressions)).collect {
+      case a: Alias => (a.toAttribute, a.child)
+    })
+
   /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
-   * additional constraint of the form `b = 5`
+   * additional constraint of the form `b = 5`.
+   *
+   * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
+   * as they are often useless and can lead to a non-converging set of constraints.
    */
   private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
-    // Collect alias from expressions to avoid producing non-converging set of constraints
-    // for recursive functions.
-    //
-    // Don't apply transform on constraints if the replacement will cause an recursive deduction,
-    // when that happens a non-converging set of constraints will be created and finally throw
-    // OOM Exception.
-    // For more details, refer to https://issues.apache.org/jira/browse/SPARK-17733
-    val aliasMap = AttributeMap((expressions ++ children.flatMap(_.expressions)).collect {
-      case a: Alias => (a.toAttribute, a.child)
-    })
-
-    val equalExprSets = generateEqualExpressionSets(constraints, aliasMap)
+    val constraintClasses = generateEquivalentConstraintClasses(constraints)
 
     var inferredConstraints = Set.empty[Expression]
     constraints.foreach {
@@ -93,11 +91,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT
         val candidateConstraints = constraints - eq
         inferredConstraints ++= candidateConstraints.map(_ transform {
           case a: Attribute if a.semanticEquals(l) &&
-            !isRecursiveDeduction(r, aliasMap, equalExprSets) => r
+            !isRecursiveDeduction(r, constraintClasses) => r
         })
         inferredConstraints ++= candidateConstraints.map(_ transform {
           case a: Attribute if a.semanticEquals(r) &&
-            !isRecursiveDeduction(l, aliasMap, equalExprSets) => l
+            !isRecursiveDeduction(l, constraintClasses) => l
         })
       case _ => // No inference
     }
@@ -110,58 +108,60 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT
    * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
    * to a selected attribute.
    */
-  private def generateEqualExpressionSets(
-      constraints: Set[Expression],
-      aliasMap: AttributeMap[Expression]): Seq[Set[Expression]] = {
-    var equalExprSets = Seq.empty[Set[Expression]]
+  private def generateEquivalentConstraintClasses(
+      constraints: Set[Expression]): Seq[Set[Expression]] = {
+    var constraintClasses = Seq.empty[Set[Expression]]
     constraints.foreach {
       case eq @ EqualTo(l: Attribute, r: Attribute) =>
         // Transform [[Alias]] to its child.
         val left = aliasMap.getOrElse(l, l)
         val right = aliasMap.getOrElse(r, r)
-        // Get the expression set for equivalence class of expressions.
-        val leftEqualSet = getEqualExprSet(left, equalExprSets).getOrElse(Set.empty[Expression])
-        val rightEqualSet = getEqualExprSet(right, equalExprSets).getOrElse(Set.empty[Expression])
-        if (!leftEqualSet.isEmpty && !rightEqualSet.isEmpty) {
+        // Get the expression set for an equivalence constraint class.
+        val leftConstraintClass = getConstraintClass(left, constraintClasses)
+        val rightConstraintClass = getConstraintClass(right, constraintClasses)
+        if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
           // Combine the two sets.
-          equalExprSets = equalExprSets.diff(leftEqualSet :: rightEqualSet :: Nil) :+
-            (leftEqualSet ++ rightEqualSet)
-        } else if (!leftEqualSet.isEmpty) { // && rightEqualSet.isEmpty
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
+            (leftConstraintClass ++ rightConstraintClass)
+        } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
           // Update equivalence class of `left` expression.
-          equalExprSets = equalExprSets.diff(leftEqualSet :: Nil) :+ (leftEqualSet + right)
-        } else if (!rightEqualSet.isEmpty) { // && leftEqualSet.isEmpty
+          constraintClasses = constraintClasses
+            .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
+        } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
           // Update equivalence class of `right` expression.
-          equalExprSets = equalExprSets.diff(rightEqualSet :: Nil) :+ (rightEqualSet + left)
-        } else { // leftEqualSet.isEmpty && rightEqualSet.isEmpty
-          // Create new equivalence class since both expression don't present in any classes.
-          equalExprSets = equalExprSets :+ Set(left, right)
+          constraintClasses = constraintClasses
+            .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
+        } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
+          // Create a new equivalence constraint class since neither expression is present
+          // in any class.
+          constraintClasses = constraintClasses :+ Set(left, right)
         }
       case _ => // Skip
     }
 
-    equalExprSets
+    constraintClasses
   }
 
   /*
    * Get all expressions equivalent to the selected expression.
    */
-  private def getEqualExprSet(
+  private def getConstraintClass(
       expr: Expression,
-      equalExprSets: Seq[Set[Expression]]): Option[Set[Expression]] =
-    equalExprSets.filter(_.contains(expr)).headOption
+      constraintClasses: Seq[Set[Expression]]): Set[Expression] =
+    constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
 
   /*
-   * Check whether replace by an [[Attribute]] will cause an recursive deduction. Generally it
-   * has an form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is an function.
+   * Check whether replacement by an [[Attribute]] will cause a recursive deduction. Generally it
+   * has a form like `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
    * Here we first get all expressions equal to `attr` and then check whether at least one of them
-   * is child of the referenced expression.
+   * is a child of the referenced expression.
    */
   private def isRecursiveDeduction(
       attr: Attribute,
-      aliasMap: AttributeMap[Expression],
-      equalExprSets: Seq[Set[Expression]]): Boolean = {
+      constraintClasses: Seq[Set[Expression]]): Boolean = {
     val expr = aliasMap.getOrElse(attr, attr)
-    getEqualExprSet(expr, equalExprSets).getOrElse(Set.empty[Expression]).exists { e =>
+    getConstraintClass(expr, constraintClasses).exists { e =>
       expr.children.exists(_.semanticEquals(e))
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index c1f5076fe3448..9f57f66a2ea20 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -131,14 +131,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
       .analyze
-    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b)))
-      &&'a === Coalesce(Seq('a, 'b)))
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)))
       .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t")
       .join(t2.where(IsNotNull('a)), Inner,
         Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
-    comparePlans(optimized, currectAnswer)
+    comparePlans(optimized, correctAnswer)
   }
 
   test("inner join with alias: alias contains single attributes") {
@@ -148,14 +148,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
     val originalQuery = t1.select('a, 'b.as('d)).as("t")
       .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
       .analyze
-    val currectAnswer = t1.where(IsNotNull('a) && IsNotNull('b)
-      && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
       .select('a, 'b.as('d)).as("t")
       .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
         Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
-    comparePlans(optimized, currectAnswer)
+    comparePlans(optimized, correctAnswer)
   }
 
   test("inner join with alias: don't generate constraints for recursive functions") {
@@ -168,18 +168,20 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         && "t.d".attr === "t2.a".attr
         && "t.int_col".attr === "t2.a".attr))
       .analyze
-    val correctAnswer = t1.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+    val correctAnswer = t1
+      .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+        && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+        && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
+        && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
+        && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
+        && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
+        && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
+        && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
       .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
-      .join(t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+      .join(t2
+        .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
+          && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
+          && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
         Some("t.a".attr === "t2.a".attr
           && "t.d".attr === "t2.a".attr
           && "t.int_col".attr === "t2.a".attr
           && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
       .analyze
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 845d743b00e9e..5fb62ab508443 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2679,31 +2679,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 }
-
-  test("SPARK-17733 InferFiltersFromConstraints rule never terminates for query") {
-    withTempView("tmpv") {
-      spark.range(10).toDF("a").createTempView("tmpv")
-
-      // Just ensure the following query executes successfully to completion.
-      val query =
-        """
-          |SELECT
-          |  *
-          |FROM (
-          |  SELECT
-          |    COALESCE(t1.a, t2.a) AS int_col,
-          |    t1.a,
-          |    t2.a AS b
-          |  FROM tmpv t1
-          |  CROSS JOIN tmpv t2
-          |) t1
-          |INNER JOIN tmpv t2
-          |ON (((t2.a) = (t1.a)) AND ((t2.a) = (t1.int_col))) AND ((t2.a) = (t1.b))
-        """.stripMargin
-
-      eventually(timeout(60 seconds)) {
-        assert(sql(query).count() > 0)
-      }
-    }
-  }

From 45308d50de27fb4b6c2c9ee88e55ba808ef5ef6f Mon Sep 17 00:00:00 2001
From: jiangxingbo
Date: Tue, 25 Oct 2016 15:14:47 +0800
Subject: [PATCH 13/13] remove unused import.

---
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5fb62ab508443..302bd6c80f3a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,6 @@ import java.io.File
 import java.math.MathContext
 import java.sql.Timestamp
 
-import scala.concurrent.duration._
-
-import org.scalatest.concurrent.Eventually._
-
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
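With PATCH 12 the series settles on its final shape: build the constraint classes once (with aliases pre-expanded in the memoized aliasMap), then refuse any substitution whose target expression would contain a member of its own class as a direct child. A last standalone model of that final guard -- the Expr encoding is invented again; the real check uses semanticEquals and lives in QueryPlan.scala above:

    sealed trait Expr { def children: Seq[Expr] = Nil }
    case class Attr(name: String) extends Expr
    case class Coalesce(cs: Seq[Expr]) extends Expr { override def children: Seq[Expr] = cs }

    // One class per connected component of `=` facts, aliases already expanded.
    val constraintClasses: Seq[Set[Expr]] =
      Seq(Set(Attr("a"), Attr("b"), Coalesce(Seq(Attr("a"), Attr("b")))))

    def constraintClass(e: Expr): Set[Expr] =
      constraintClasses.find(_.contains(e)).getOrElse(Set.empty[Expr])

    // `a -> f(a, b)`-style replacements are recursive: some member of the
    // target's class appears among the target expression's direct children.
    def isRecursiveDeduction(target: Expr): Boolean =
      constraintClass(target).exists(member => target.children.contains(member))

    assert(isRecursiveDeduction(Coalesce(Seq(Attr("a"), Attr("b")))))  // a = Coalesce(a, b)
    assert(!isRecursiveDeduction(Attr("a")))                           // bare attribute is safe

Blocking exactly these substitutions is what lets InferFiltersFromConstraints reach a fixed point on the SPARK-17733 query while still inferring every non-recursive filter.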