Skip to content

Commit 34c7001

Browse files
committed
Support filter at lower data type
1 parent 048a0ec commit 34c7001

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,46 @@ trait ConstraintHelper {
6868
val candidateConstraints = binaryComparisons - eq
6969
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
7070
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
71-
case eq @ EqualTo(l: Cast, r: Attribute) =>
72-
inferredConstraints ++= replaceConstraints(binaryComparisons - eq, r, l)
73-
case eq @ EqualTo(l: Attribute, r: Cast) =>
74-
inferredConstraints ++= replaceConstraints(binaryComparisons - eq, l, r)
71+
case eq @ EqualTo(l @ Cast(lc: Attribute, _, tz), r: Attribute) =>
72+
val candidateConstraints = binaryComparisons - eq
73+
val bridge = Cast(r, lc.dataType, tz)
74+
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
75+
inferredConstraints ++= replaceConstraints(candidateConstraints, lc, bridge)
76+
case eq @ EqualTo(l: Attribute, r @ Cast(rc: Attribute, _, tz)) =>
77+
val candidateConstraints = binaryComparisons - eq
78+
val bridge = Cast(l, rc.dataType, tz)
79+
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
80+
inferredConstraints ++= replaceConstraints(candidateConstraints, rc, bridge)
7581
case _ => // No inference
7682
}
7783
inferredConstraints -- constraints
7884
}
7985

86+
private def replaceConstraint(
87+
constraint: Expression,
88+
source: Expression,
89+
destination: Expression): Expression = constraint transform {
90+
case e: Expression if e.semanticEquals(source) => destination
91+
}
92+
8093
private def replaceConstraints(
8194
constraints: Set[Expression],
8295
source: Expression,
83-
destination: Expression): Set[Expression] = constraints.map(_ transform {
84-
case e: Expression if e.semanticEquals(source) => destination
85-
})
96+
dest: Expression): Set[Expression] = {
97+
constraints.map {
98+
case b @ BinaryComparison(left, right) =>
99+
(replaceConstraint(left, source, dest), replaceConstraint(right, source, dest)) match {
100+
case (Cast(Cast(child, _, _), dt, _), replacedRight)
101+
if dt == child.dataType && child.dataType == replacedRight.dataType =>
102+
b.makeCopy(Array(child, replacedRight))
103+
case (replacedLeft, Cast(Cast(child, _, _), dt, _))
104+
if dt == child.dataType && child.dataType == replacedLeft.dataType =>
105+
b.makeCopy(Array(replacedLeft, child))
106+
case (replacedLeft, replacedRight) =>
107+
b.makeCopy(Array(replacedLeft, replacedRight))
108+
}
109+
}
110+
}
86111

87112
/**
88113
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
2424
import org.apache.spark.sql.catalyst.plans.logical._
2525
import org.apache.spark.sql.catalyst.rules._
2626
import org.apache.spark.sql.internal.SQLConf
27-
import org.apache.spark.sql.types.LongType
27+
import org.apache.spark.sql.types.{IntegerType, LongType}
2828

2929
class InferFiltersFromConstraintsSuite extends PlanTest {
3030

@@ -47,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
4747
y: LogicalPlan,
4848
expectedLeft: LogicalPlan,
4949
expectedRight: LogicalPlan,
50-
joinType: JoinType) = {
51-
val condition = Some("x.a".attr === "y.a".attr)
50+
joinType: JoinType,
51+
condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = {
5252
val originalQuery = x.join(y, joinType, condition).analyze
5353
val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze
5454
val optimized = Optimize.execute(originalQuery)
@@ -265,7 +265,22 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
265265
testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter)
266266
}
267267

268-
test("Constraints should be inferred from cast equality constraint") {
268+
test("Constraints should be inferred from cast equality constraint(filter at lower data type)") {
269+
val testRelation1 = LocalRelation('a.int)
270+
val testRelation2 = LocalRelation('b.long)
271+
val originalLeft = testRelation1.where('a === 1).subquery('left)
272+
val originalRight = testRelation2.subquery('right)
273+
274+
val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left)
275+
val right = testRelation2.where(IsNotNull('b) && 'b.cast(IntegerType) === 1).subquery('right)
276+
277+
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
278+
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
279+
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
280+
}
281+
}
282+
283+
test("Constraints should be inferred from cast equality constraint(filter at higher data type)") {
269284
val testRelation1 = LocalRelation('a.int)
270285
val testRelation2 = LocalRelation('b.long)
271286
val originalLeft = testRelation1.subquery('left)
@@ -276,9 +291,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
276291

277292
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
278293
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
279-
val optimized = Optimize.execute(originalLeft.join(originalRight, Inner, condition).analyze)
280-
val correctAnswer = left.join(right, Inner, condition).analyze
281-
comparePlans(optimized, correctAnswer)
294+
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
282295
}
283296
}
284297
}

0 commit comments

Comments
 (0)