Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,25 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>

trait ConstraintHelper {

/**
* Infers an additional set of constraints from a given set of constraints.
*/
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
var inferred = inferEqualityConstraints(constraints)
var lastInequalityInferred = Set.empty[Expression]
do {
lastInequalityInferred = inferInequalityConstraints(constraints ++ inferred)
inferred ++= lastInequalityInferred
} while (lastInequalityInferred.nonEmpty)
Comment on lines +64 to +67
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you hit a infinite loop with non deterministic filters? As they are never semantically equal to any other expression (including themselves). I hit that problem in #29650, where I was also working on constraint inference , but from EqualNullSafe.

inferred
}

/**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*/
def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
def inferEqualityConstraints(constraints: Set[Expression]): Set[Expression] = {
var inferredConstraints = Set.empty[Expression]
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
Expand All @@ -78,6 +91,72 @@ trait ConstraintHelper {
inferredConstraints -- constraints
}

/**
* Infers an additional set of constraints from a given set of inequality constraints.
* For e.g., if an operator has constraints of the form (`a > b`, `b > 5`), this returns an
* additional constraint of the form `a > 5`.
*/
def inferInequalityConstraints(constraints: Set[Expression]): Set[Expression] = {
val binaryComparisons = constraints.filter {
case _: GreaterThan => true
case _: GreaterThanOrEqual => true
case _: LessThan => true
case _: LessThanOrEqual => true
case _: EqualTo => true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EqualTo should not be needed here, as the inferEqualityConstraints should cover all cases including it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inferEqualityConstraints can not handle all cases, such as constraint with cast.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example: cast(a as double) > cast(b as double) and cast(b as double) = 1

case _ => false
}

val greaterThans = binaryComparisons.map {
case EqualTo(l, r) if l.foldable => EqualTo(r, l)
case LessThan(l, r) => GreaterThan(r, l)
case LessThanOrEqual(l, r) => GreaterThanOrEqual(r, l)
case other => other
}

val lessThans = binaryComparisons.map {
case EqualTo(l, r) if l.foldable => EqualTo(r, l)
case GreaterThan(l, r) => LessThan(r, l)
case GreaterThanOrEqual(l, r) => LessThanOrEqual(r, l)
case other => other
}
Comment on lines +116 to +121
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this duplicate the greaterThans block?
Here you have a < b < c and in the other block you have c > b > a

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. for example:
a > b and 5 > a. we can not infer anything. but we can infer that b < 5 after rewriting a > b and 5 > a as b < a and a < 5.

Copy link
Contributor

@tanelk tanelk Sep 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it because of the foldable check? Without it, it should be inferable.


var inferredConstraints = Set.empty[Expression]
greaterThans.foreach {
case op @ BinaryComparison(source: Attribute, destination: Expression)
if destination.foldable =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the foldability is not needed here. The new constraints do not have to only involve constants, but also any attribute.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid generating too many constraints. For example: a > b > c > 1. The expected inferred constraints are: a > 1 and b > 1. a > c is useless.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a and c are in tihe same side of a join, then it can be pushed down.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How to push down a > c if both a and c are not foldable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, I used a wrong word. I meant pushed through the join into one of the sides.

inferredConstraints ++= (greaterThans - op).map {
case GreaterThan(l, r) if r.semanticEquals(source) =>
GreaterThan(l, destination)
case GreaterThanOrEqual(l, r)
if r.semanticEquals(source) && op.isInstanceOf[GreaterThan] =>
GreaterThan(l, destination)
case GreaterThanOrEqual(l, r) if r.semanticEquals(source) =>
GreaterThanOrEqual(l, destination)
case other => other
}
case _ => // No inference
}

lessThans.foreach {
case op @ BinaryComparison(source: Attribute, destination: Expression)
if destination.foldable =>
inferredConstraints ++= (lessThans - op).map {
case LessThan(l, r) if r.semanticEquals(source) =>
LessThan(l, destination)
case LessThanOrEqual(l, r)
if r.semanticEquals(source) && op.isInstanceOf[LessThan] =>
LessThan(l, destination)
case LessThanOrEqual(l, r) if r.semanticEquals(source) =>
LessThanOrEqual(l, destination)
case other => other
}
case _ => // No inference
}

(inferredConstraints -- constraints -- greaterThans -- lessThans)
.filterNot(i => constraints.exists(_.semanticEquals(i)))
}

private def replaceConstraints(
constraints: Set[Expression],
source: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,4 +316,126 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
condition)
}
}

test("Constraints inferred from inequality constraints: basic") {
Seq(('a < 'b && 'b < 3, 'a < 'b && 'b < 3 && 'a < 3), // a < b && b < 3 => a < 3
('a < 'b && 'b <= 3, 'a < 'b && 'b <= 3 && 'a < 3), // a < b && b <= 3 => a < 3
('a < 'b && 'b === 3, 'a < 'b && 'b === 3 && 'a < 3), // a < b && b = 3 => a < 3
('a <= 'b && 'b < 3, 'a <= 'b && 'b < 3 && 'a < 3), // a <= b && b < 3 => a < 3
('a <= 'b && 'b <= 3, 'a <= 'b && 'b <= 3 && 'a <= 3), // a <= b && b <= 3 => a <= 3
('a <= 'b && 'b === 3, 'a <= 'b && 'b === 3 && 'a <= 3), // a <= b && b = 3 => a <= 3
('a > 'b && 'b > 3, 'a > 'b && 'b > 3 && 'a > 3), // a > b && b > 3 => a > 3
('a > 'b && 'b >= 3, 'a > 'b && 'b >= 3 && 'a > 3), // a > b && b >= 3 => a > 3
('a > 'b && 'b === 3, 'a > 'b && 'b === 3 && 'a > 3), // a > b && b = 3 => a > 3
('a >= 'b && 'b > 3, 'a >= 'b && 'b > 3 && 'a > 3), // a >= b && b > 3 => a > 3
('a >= 'b && 'b >= 3, 'a >= 'b && 'b >= 3 && 'a >= 3), // a >= b && b >= 3 => a >= 3
('a >= 'b && 'b === 3, 'a >= 'b && 'b === 3 && 'a >= 3) // a >= b && b = 3 => a >= 3
).foreach {
case (filter, inferred) =>
val original = testRelation.where(filter)
val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred)
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
}
}

test("Constraints inferred from inequality constraints: join") {
Seq(("left.b".attr < "right.b".attr, 'b < 1, 'b < 1),
("left.b".attr < "right.b".attr, 'b === 1, 'b < 1),
("left.b".attr < "right.b".attr, 'b <= 1, 'b < 1),
("left.b".attr <= "right.b".attr, 'b <= 1, 'b <= 1),
("left.b".attr <= "right.b".attr, 'b === 1, 'b <= 1),
("left.b".attr > "right.b".attr, 'b > 1, 'b > 1),
("left.b".attr > "right.b".attr, 'b === 1, 'b > 1),
("left.b".attr > "right.b".attr, 'b >= 1, 'b > 1),
("left.b".attr >= "right.b".attr, 'b >= 1, 'b >= 1),
("left.b".attr >= "right.b".attr, 'b === 1, 'b >= 1)
).foreach {
case (cond, filter, inferred) =>
val originalLeft = testRelation.subquery('left)
val originalRight = testRelation.where(filter).subquery('right)

val left = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred).subquery('left)
val right = testRelation.where(IsNotNull('a) && IsNotNull('b) && filter).subquery('right)
val condition = Some("left.a".attr === "right.a".attr && cond)
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
}
}

test("Constraints inferred from inequality constraints with cast") {
Seq(('a < 'b && 'b < 3L, 'a.cast(LongType) < 'b && 'b < 3L && 'a.cast(LongType) < 3L),
('a < 'b && 'b <= 3L, 'a.cast(LongType) < 'b && 'b <= 3L && 'a.cast(LongType) < 3L),
('a < 'b && 'b === 3L, 'a.cast(LongType) < 'b && 'b === 3L && 'a.cast(LongType) < 3L),
('a <= 'b && 'b < 3L, 'a.cast(LongType) <= 'b && 'b < 3L && 'a.cast(LongType) < 3L),
('a <= 'b && 'b <= 3L, 'a.cast(LongType) <= 'b && 'b <= 3L && 'a.cast(LongType) <= 3L),
('a <= 'b && 'b === 3L, 'a.cast(LongType) <= 'b && 'b === 3L && 'a.cast(LongType) <= 3L),
('a < 'b && 'b < 3, 'a.cast(LongType) < 'b && 'b < Literal(3).cast(LongType)
&& 'a.cast(LongType) < Literal(3).cast(LongType)),
('a > 'b && 'b > 3L, 'a.cast(LongType) > 'b && 'b > 3L && 'a.cast(LongType) > 3L),
('a > 'b && 'b >= 3L, 'a.cast(LongType) > 'b && 'b >= 3L && 'a.cast(LongType) > 3L),
('a > 'b && 'b === 3L, 'a.cast(LongType) > 'b && 'b === 3L && 'a.cast(LongType) > 3L),
('a >= 'b && 'b > 3L, 'a.cast(LongType) >= 'b && 'b > 3L && 'a.cast(LongType) > 3L),
('a >= 'b && 'b >= 3L, 'a.cast(LongType) >= 'b && 'b >= 3L && 'a.cast(LongType) >= 3L),
('a >= 'b && 'b === 3L, 'a.cast(LongType) >= 'b && 'b === 3L && 'a.cast(LongType) >= 3L),
('a > 'b && 'b > 3, 'a.cast(LongType) > 'b && 'b > Literal(3).cast(LongType)
&& 'a.cast(LongType) > Literal(3).cast(LongType))
).foreach {
case (filter, inferred) =>
val testRelation = LocalRelation('a.int, 'b.long)
val original = testRelation.where(filter)
val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && inferred)
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
}
}

test("Constraints inferred from inequality attributes: case1") {
val condition = Some("x.a".attr > "y.a".attr)
val optimizedLeft = testRelation.where(IsNotNull('a) && 'a === 1).as("x")
val optimizedRight = testRelation.where('a < 1 && IsNotNull('a) ).as("y")
val correct = optimizedLeft.join(optimizedRight, Inner, condition)

Seq(Literal(1) === 'a, 'a === Literal(1)).foreach { filter =>
val original = testRelation.where(filter).as("x").join(testRelation.as("y"), Inner, condition)
comparePlans(Optimize.execute(original.analyze), correct.analyze)
}
}

test("Constraints inferred from inequality attributes: case2") {
val original = testRelation.where('a < 'b && 'b < 'c && 'c < 5)
val optimized = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c)
&& 'a < 'b && 'b < 'c && 'a < 5 && 'b < 5 && 'c < 5)
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
}

test("Constraints inferred from inequality attributes: case3") {
val left = testRelation.where('b >= 3 && 'b <= 13).as("x")
val right = testRelation.as("y")

val optimizedLeft = testRelation.where(IsNotNull('a) && IsNotNull('b)
&& 'b >= 3 && 'b <= 13).as("x")
val optimizedRight = testRelation.where(IsNotNull('a) && IsNotNull('b) && IsNotNull('c)
&& 'c > 3 && 'b <= 13).as("y")
val condition = Some("x.a".attr === "y.a".attr
&& "x.b".attr >= "y.b".attr && "x.b".attr < "y.c".attr)
val original = left.join(right, Inner, condition)
val optimized = optimizedLeft.join(optimizedRight, Inner, condition)
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
}

test("Constraints inferred from inequality attributes: case4") {
val testRelation1 = LocalRelation('a.long, 'b.long, 'c.long).as("x")
val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int).as("y")

// y.b < 13 inferred from y.b < x.b && x.b <= 13
val left = testRelation1.where('b <= 13L).as("x")
val right = testRelation2.as("y")

val optimizedLeft = testRelation1.where(IsNotNull('a) && IsNotNull('b) && 'b <= 13L).as("x")
val optimizedRight = testRelation2.where(IsNotNull('a) && IsNotNull('b)
&& 'b.cast(LongType) < 13L).as("y")

val condition = Some("x.a".attr === "y.a".attr && "y.b".attr < "x.b".attr)
val original = left.join(right, Inner, condition)
val optimized = optimizedLeft.join(optimizedRight, Inner, condition)
comparePlans(Optimize.execute(original.analyze), optimized.analyze)
}
}