Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanEqualization ::
BooleanEquality ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
Expand Down Expand Up @@ -445,10 +445,10 @@ trait HiveTypeCoercion {
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
val resultType = DecimalType(max(p1, p2), max(s1, s2))
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2)
case b @ BinaryComparison(e1 @ DecimalType.Expression(_, _), e2)
if e2.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2))
case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _))
case b @ BinaryComparison(e1, e2 @ DecimalType.Expression(_, _))
if e1.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited)))

Expand Down Expand Up @@ -479,9 +479,9 @@ trait HiveTypeCoercion {
/**
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
object BooleanEqualization extends Rule[LogicalPlan] {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
object BooleanEquality extends Rule[LogicalPlan] {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))

private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
Expand Down Expand Up @@ -512,22 +512,22 @@ trait HiveTypeCoercion {
// all other cases are considered as false.

// We may simplify the expression if one side is literal numeric values
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => left
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(left)
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
if trueValues.contains(value) => right
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
if falseValues.contains(value) => Not(right)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(left), left)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(left), Not(left))
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(right), right)
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(right), Not(right))
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only care about boolean type and literal type here, not left and right, so I use bool instead.

if trueValues.contains(value) => bool
case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(bool)
case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
if trueValues.contains(value) => bool
case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
if falseValues.contains(value) => Not(bool)
case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(bool), bool)
case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(bool), bool)
case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))

case EqualTo(left @ BooleanType(), right @ NumericType()) =>
transform(left , right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest {
}

test("type coercion simplification for equal to") {
val be = new HiveTypeCoercion {}.BooleanEqualization
val be = new HiveTypeCoercion {}.BooleanEquality

ruleTest(be,
EqualTo(Literal(true), Literal(1)),
Literal(true)
Expand All @@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest {
EqualNullSafe(Literal(true), Literal(0)),
And(IsNotNull(Literal(true)), Not(Literal(true)))
)

ruleTest(be,
EqualTo(Literal(true), Literal(1L)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(BigDecimal(0)), Literal(true)),
Not(Literal(true))
)
ruleTest(be,
EqualTo(Literal(Decimal(1)), Literal(true)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
Literal(true)
)
}
}