@@ -76,7 +76,7 @@ trait HiveTypeCoercion {
7676 WidenTypes ::
7777 PromoteStrings ::
7878 DecimalPrecision ::
79- BooleanComparisons ::
79+ BooleanEqualization ::
8080 StringToIntegralCasts ::
8181 FunctionArgumentConversion ::
8282 CaseWhenCoercion ::
@@ -119,7 +119,7 @@ trait HiveTypeCoercion {
119119 * the appropriate numeric equivalent.
120120 */
121121 object ConvertNaNs extends Rule [LogicalPlan ] {
122- val stringNaN = Literal (" NaN" )
122+ private val stringNaN = Literal (" NaN" )
123123
124124 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
125125 case q : LogicalPlan => q transformExpressions {
@@ -349,17 +349,17 @@ trait HiveTypeCoercion {
349349 import scala .math .{max , min }
350350
// Fixed-precision decimal type that each integral type is converted to
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
  (ByteType, DecimalType(3, 0)),
  (ShortType, DecimalType(5, 0)),
  (IntegerType, DecimalType(10, 0)),
  (LongType, DecimalType(20, 0))
)
358358
// True only for the two floating-point types, FloatType and DoubleType
private def isFloat(t: DataType): Boolean = t match {
  case FloatType | DoubleType => true
  case _ => false
}
360360
// Fixed-precision decimal type that float and double are converted to
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
  (FloatType, DecimalType(7, 7)),
  (DoubleType, DecimalType(15, 15))
)
@@ -482,30 +482,66 @@ trait HiveTypeCoercion {
482482 }
483483
/**
 * Rewrites equality comparisons between a boolean expression and a numeric expression so
 * that expressions such as `true = 1` can be evaluated.
 *
 * Following Hive, `true = 1` and `false = 0` hold; a boolean compared with any other
 * numeric value is considered false.
 */
object BooleanEqualization extends Rule[LogicalPlan] {
  // Numeric spellings of `true` and `false`, one entry per numeric representation.
  private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
  private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))

  private def isTrueValue(v: Any): Boolean = trueValues.contains(v)

  private def isFalseValue(v: Any): Boolean = falseValues.contains(v)

  /** Condition that holds when at least one of the two operands is NULL. */
  private def anyNull(bool: Expression, num: Expression): Expression =
    Or(IsNull(bool), IsNull(num))

  /**
   * Keyed CASE over the numeric operand: the "true" key yields the boolean operand, the
   * "false" key yields its negation, and every other key yields `false`.
   */
  private def compareAsCase(bool: Expression, num: Expression): Expression = {
    val whenTrueKey = Seq(Literal(trueValues.head), bool)
    val whenFalseKey = Seq(Literal(falseValues.head), Not(bool))
    CaseKeyWhen(num, (whenTrueKey ++ whenFalseKey) :+ Literal(false))
  }

  /** Plain equality: a NULL boolean when either operand is NULL, else the keyed CASE. */
  private def nullPropagatingEquals(bool: Expression, num: Expression): Expression =
    If(anyNull(bool, num), Literal.create(null, BooleanType), compareAsCase(bool, num))

  /**
   * Null-safe equality: `true` when both operands are NULL, `false` when exactly one of
   * them is, and the keyed CASE otherwise.
   */
  private def nullSafeEquals(bool: Expression, num: Expression): Expression = {
    val bothNull = And(IsNull(bool), IsNull(num))
    CaseWhen(Seq(
      bothNull, Literal(true),
      anyNull(bool, num), Literal(false),
      compareAsCase(bool, num)))
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    // Skip nodes whose children have not been resolved yet.
    case unresolved if !unresolved.childrenResolved => unresolved

    // Equality against a numeric literal that spells true/false collapses to the boolean
    // operand or its negation. Literals with any other value fall through to the general
    // cases at the bottom.
    case EqualTo(b @ BooleanType(), Literal(v, _: NumericType)) if isTrueValue(v) => b
    case EqualTo(b @ BooleanType(), Literal(v, _: NumericType)) if isFalseValue(v) => Not(b)
    case EqualTo(Literal(v, _: NumericType), b @ BooleanType()) if isTrueValue(v) => b
    case EqualTo(Literal(v, _: NumericType), b @ BooleanType()) if isFalseValue(v) => Not(b)

    // Null-safe equality against such a literal additionally requires a non-NULL boolean.
    case EqualNullSafe(b @ BooleanType(), Literal(v, _: NumericType)) if isTrueValue(v) =>
      And(IsNotNull(b), b)
    case EqualNullSafe(b @ BooleanType(), Literal(v, _: NumericType)) if isFalseValue(v) =>
      And(IsNotNull(b), Not(b))
    case EqualNullSafe(Literal(v, _: NumericType), b @ BooleanType()) if isTrueValue(v) =>
      And(IsNotNull(b), b)
    case EqualNullSafe(Literal(v, _: NumericType), b @ BooleanType()) if isFalseValue(v) =>
      And(IsNotNull(b), Not(b))

    // General case: boolean against an arbitrary numeric expression, in either order.
    case EqualTo(b @ BooleanType(), n @ NumericType()) => nullPropagatingEquals(b, n)
    case EqualTo(n @ NumericType(), b @ BooleanType()) => nullPropagatingEquals(b, n)
    case EqualNullSafe(b @ BooleanType(), n @ NumericType()) => nullSafeEquals(b, n)
    case EqualNullSafe(n @ NumericType(), b @ BooleanType()) => nullSafeEquals(b, n)
  }
}
511547
@@ -606,7 +642,7 @@ trait HiveTypeCoercion {
606642 import HiveTypeCoercion ._
607643
608644 def apply (plan : LogicalPlan ): LogicalPlan = plan transformAllExpressions {
609- case cw : CaseWhenLike if ! cw.resolved && cw.childrenResolved && ! cw.valueTypesEqual =>
645+ case cw : CaseWhenLike if cw.childrenResolved && ! cw.valueTypesEqual =>
610646 logDebug(s " Input values for null casting ${cw.valueTypes.mkString(" ," )}" )
611647 val commonType = cw.valueTypes.reduce { (v1, v2) =>
612648 findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -625,6 +661,23 @@ trait HiveTypeCoercion {
625661 case CaseKeyWhen (key, _) =>
626662 CaseKeyWhen (key, transformedBranches)
627663 }
664+
665+ case ckw : CaseKeyWhen if ckw.childrenResolved && ! ckw.resolved =>
666+ val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
667+ findTightestCommonType(v1, v2).getOrElse(sys.error(
668+ s " Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2" ))
669+ }
670+ val transformedBranches = ckw.branches.sliding(2 , 2 ).map {
671+ case Seq (when, then ) if when.dataType != commonType =>
672+ Seq (Cast (when, commonType), then )
673+ case s => s
674+ }.reduce(_ ++ _)
675+ val transformedKey = if (ckw.key.dataType != commonType) {
676+ Cast (ckw.key, commonType)
677+ } else {
678+ ckw.key
679+ }
680+ CaseKeyWhen (transformedKey, transformedBranches)
628681 }
629682 }
630683
0 commit comments