Skip to content

Commit b6401ba

Browse files
committed
add type coercion for CaseKeyWhen and address comments
1 parent ebc8c61 commit b6401ba

File tree

5 files changed

+116
-46
lines changed

5 files changed

+116
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -485,15 +485,14 @@ trait HiveTypeCoercion {
485485
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
486486
*/
487487
object BooleanEqualization extends Rule[LogicalPlan] {
488-
val trueValue = Literal(new java.math.BigDecimal(1))
489-
val falseValue = Literal(new java.math.BigDecimal(0))
488+
val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
489+
val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
490490

491491
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
492-
CaseKeyWhen(Cast(numericExpr, DecimalType.Unlimited),
493-
Seq(
494-
trueValue, booleanExpr,
495-
falseValue, Not(booleanExpr),
496-
Literal(false)))
492+
CaseKeyWhen(numericExpr, Seq(
493+
Literal(trueValues.head), booleanExpr,
494+
Literal(falseValues.head), Not(booleanExpr),
495+
Literal(false)))
497496
}
498497

499498
private def transform(booleanExpr: Expression, numericExpr: Expression) = {
@@ -516,13 +515,32 @@ trait HiveTypeCoercion {
516515

517516
// Hive treats (true = 1) as true and (false = 0) as true,
518517
// all other cases are considered as false.
519-
case EqualTo(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
520-
transform(l, r)
521-
case EqualTo(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
518+
519+
// We may simplify the expression if one side is literal numeric values
520+
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
521+
if trueValues.contains(value) => l
522+
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
523+
if falseValues.contains(value) => Not(l)
524+
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
525+
if trueValues.contains(value) => r
526+
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
527+
if falseValues.contains(value) => Not(r)
528+
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
529+
if trueValues.contains(value) => And(IsNotNull(l), l)
530+
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
531+
if falseValues.contains(value) => Or(IsNull(l), l)
532+
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
533+
if trueValues.contains(value) => And(IsNotNull(r), r)
534+
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
535+
if falseValues.contains(value) => Or(IsNull(r), r)
536+
537+
case EqualTo(l @ BooleanType(), r @ NumericType()) =>
538+
transform(l , r)
539+
case EqualTo(l @ NumericType(), r @ BooleanType()) =>
522540
transform(r, l)
523-
case EqualNullSafe(l @ BooleanType(), r) if r.dataType.isInstanceOf[NumericType] =>
541+
case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
524542
transformNullSafe(l, r)
525-
case EqualNullSafe(l, r @ BooleanType()) if l.dataType.isInstanceOf[NumericType] =>
543+
case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
526544
transformNullSafe(r, l)
527545
}
528546
}
@@ -624,7 +642,7 @@ trait HiveTypeCoercion {
624642
import HiveTypeCoercion._
625643

626644
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
627-
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
645+
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
628646
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
629647
val commonType = cw.valueTypes.reduce { (v1, v2) =>
630648
findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -643,6 +661,23 @@ trait HiveTypeCoercion {
643661
case CaseKeyWhen(key, _) =>
644662
CaseKeyWhen(key, transformedBranches)
645663
}
664+
665+
case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
666+
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
667+
findTightestCommonType(v1, v2).getOrElse(sys.error(
668+
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
669+
}
670+
val transformedBranches = ckw.branches.sliding(2, 2).map {
671+
case Seq(when, then) if when.dataType != commonType =>
672+
Seq(Cast(when, commonType), then)
673+
case s => s
674+
}.reduce(_ ++ _)
675+
val transformedKey = if (ckw.key.dataType != commonType) {
676+
Cast(ckw.key, commonType)
677+
} else {
678+
ckw.key
679+
}
680+
CaseKeyWhen(transformedKey, transformedBranches)
646681
}
647682
}
648683

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
366366

367367
// both then and else val should be considered.
368368
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
369-
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
369+
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
370370

371371
override def dataType: DataType = {
372372
if (!resolved) {
@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
442442
override def children: Seq[Expression] = key +: branches
443443

444444
override lazy val resolved: Boolean =
445-
childrenResolved && valueTypesEqual
445+
childrenResolved && valueTypesEqual &&
446+
(key +: whenList).map(_.dataType).distinct.size == 1
446447

447448
/** Written in imperative fashion for performance considerations. */
448449
override def eval(input: Row): Any = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
2020
import org.apache.spark.sql.catalyst.plans.PlanTest
2121

2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
23+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
24+
import org.apache.spark.sql.catalyst.rules.Rule
2425
import org.apache.spark.sql.types._
2526

2627
class HiveTypeCoercionSuite extends PlanTest {
@@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
104105
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
105106
}
106107

108+
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
109+
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
110+
comparePlans(
111+
rule(Project(Seq(Alias(initial, "a")()), testRelation)),
112+
Project(Seq(Alias(transformed, "a")()), testRelation))
113+
}
114+
107115
test("coalesce casts") {
108116
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
109-
def ruleTest(initial: Expression, transformed: Expression) {
110-
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
111-
comparePlans(
112-
fac(Project(Seq(Alias(initial, "a")()), testRelation)),
113-
Project(Seq(Alias(transformed, "a")()), testRelation))
114-
}
115-
ruleTest(
117+
ruleTest(fac,
116118
Coalesce(Literal(1.0)
117119
:: Literal(1)
118120
:: Literal.create(1.0, FloatType)
@@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
121123
:: Cast(Literal(1), DoubleType)
122124
:: Cast(Literal.create(1.0, FloatType), DoubleType)
123125
:: Nil))
124-
ruleTest(
126+
ruleTest(fac,
125127
Coalesce(Literal(1L)
126128
:: Literal(1)
127129
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
131133
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
132134
:: Nil))
133135
}
136+
137+
test("type coercion for CaseKeyWhen") {
138+
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
139+
ruleTest(cwc,
140+
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
141+
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
142+
)
143+
// Will remove exception expectation in PR#6405
144+
intercept[RuntimeException] {
145+
ruleTest(cwc,
146+
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
147+
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
148+
)
149+
}
150+
}
151+
152+
test("type coercion simplification for equal to") {
153+
val be = new HiveTypeCoercion {}.BooleanEqualization
154+
ruleTest(be,
155+
EqualTo(Literal(true), Literal(1)),
156+
Literal(true)
157+
)
158+
ruleTest(be,
159+
EqualTo(Literal(true), Literal(0)),
160+
Not(Literal(true))
161+
)
162+
ruleTest(be,
163+
EqualNullSafe(Literal(true), Literal(1)),
164+
And(IsNotNull(Literal(true)), Literal(true))
165+
)
166+
ruleTest(be,
167+
EqualNullSafe(Literal(true), Literal(0)),
168+
Or(IsNull(Literal(true)), Literal(true))
169+
)
170+
}
134171
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
862862
val c5 = 'a.string.at(4)
863863
val c6 = 'a.string.at(5)
864864

865-
val literalNull = Literal.create(null, BooleanType)
865+
val literalNull = Literal.create(null, IntegerType)
866866
val literalInt = Literal(1)
867867
val literalString = Literal("a")
868868

@@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
871871
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
872872
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
873873
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
874-
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
874+
checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
875875

876876
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
877877
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
878-
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
879-
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
878+
checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
879+
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
880880
}
881881

882882
test("complex type") {

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,24 +1334,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
13341334

13351335
test("SPARK-7952: fix the equality check between boolean and numeric types") {
13361336
withTempTable("t") {
1337-
Seq(
1338-
(1, true),
1339-
(0, false),
1340-
(2, true),
1341-
(2, false),
1342-
(null, true),
1343-
(null, false),
1344-
(0, null),
1345-
(1, null),
1346-
(null, null)
1347-
).map { case (i, b) =>
1348-
(i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
1349-
}.toDF("i", "b").registerTempTable("t")
1350-
1351-
checkAnswer(sql("select i = b from t"),
1352-
Seq(true, true, false, false, null, null, null, null, null).map(Row(_)))
1353-
checkAnswer(sql("select i <=> b from t"),
1354-
Seq(true, true, false, false, false, false, false, false, true).map(Row(_)))
1337+
// numeric field i, boolean field j, result of i = j, result of i <=> j
1338+
Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
1339+
(1, true, true, true),
1340+
(0, false, true, true),
1341+
(2, true, false, false),
1342+
(2, false, false, false),
1343+
(null, true, null, false),
1344+
(null, false, null, false),
1345+
(0, null, null, false),
1346+
(1, null, null, false),
1347+
(null, null, null, true)
1348+
).toDF("i", "b", "r1", "r2").registerTempTable("t")
1349+
1350+
checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
1351+
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
13551352
}
13561353
}
13571354
}

0 commit comments

Comments
 (0)