Skip to content

Commit a0e46a0

Browse files
cloud-fanrxin
authored andcommitted
[SPARK-7952][SPARK-7984][SQL] equality check between boolean type and numeric type is broken.
The origin code has several problems: * `true <=> 1` will return false as we didn't set a rule to handle it. * `true = a` where `a` is not `Literal` and its value is 1, will return false as we only handle literal values. Author: Wenchen Fan <[email protected]> Closes #6505 from cloud-fan/tmp1 and squashes the following commits: 77f0f39 [Wenchen Fan] minor fix b6401ba [Wenchen Fan] add type coercion for CaseKeyWhen and address comments ebc8c61 [Wenchen Fan] use SQLTestUtils and If 625973c [Wenchen Fan] improve 9ba2130 [Wenchen Fan] address comments fc0d741 [Wenchen Fan] fix style 2846a04 [Wenchen Fan] fix 7952
1 parent 91777a1 commit a0e46a0

File tree

5 files changed

+158
-47
lines changed

5 files changed

+158
-47
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ trait HiveTypeCoercion {
7676
WidenTypes ::
7777
PromoteStrings ::
7878
DecimalPrecision ::
79-
BooleanComparisons ::
79+
BooleanEqualization ::
8080
StringToIntegralCasts ::
8181
FunctionArgumentConversion ::
8282
CaseWhenCoercion ::
@@ -119,7 +119,7 @@ trait HiveTypeCoercion {
119119
* the appropriate numeric equivalent.
120120
*/
121121
object ConvertNaNs extends Rule[LogicalPlan] {
122-
val stringNaN = Literal("NaN")
122+
private val stringNaN = Literal("NaN")
123123

124124
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
125125
case q: LogicalPlan => q transformExpressions {
@@ -349,17 +349,17 @@ trait HiveTypeCoercion {
349349
import scala.math.{max, min}
350350

351351
// Conversion rules for integer types into fixed-precision decimals
352-
val intTypeToFixed: Map[DataType, DecimalType] = Map(
352+
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
353353
ByteType -> DecimalType(3, 0),
354354
ShortType -> DecimalType(5, 0),
355355
IntegerType -> DecimalType(10, 0),
356356
LongType -> DecimalType(20, 0)
357357
)
358358

359-
def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
359+
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
360360

361361
// Conversion rules for float and double into fixed-precision decimals
362-
val floatTypeToFixed: Map[DataType, DecimalType] = Map(
362+
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
363363
FloatType -> DecimalType(7, 7),
364364
DoubleType -> DecimalType(15, 15)
365365
)
@@ -482,30 +482,66 @@ trait HiveTypeCoercion {
482482
}
483483

484484
/**
485-
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
485+
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
486486
*/
487-
object BooleanComparisons extends Rule[LogicalPlan] {
488-
val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
489-
val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
487+
object BooleanEqualization extends Rule[LogicalPlan] {
488+
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
489+
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))
490+
491+
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
492+
CaseKeyWhen(numericExpr, Seq(
493+
Literal(trueValues.head), booleanExpr,
494+
Literal(falseValues.head), Not(booleanExpr),
495+
Literal(false)))
496+
}
497+
498+
private def transform(booleanExpr: Expression, numericExpr: Expression) = {
499+
If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
500+
Literal.create(null, BooleanType),
501+
buildCaseKeyWhen(booleanExpr, numericExpr))
502+
}
503+
504+
private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
505+
CaseWhen(Seq(
506+
And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
507+
Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
508+
buildCaseKeyWhen(booleanExpr, numericExpr)
509+
))
510+
}
490511

491512
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
492513
// Skip nodes who's children have not been resolved yet.
493514
case e if !e.childrenResolved => e
494515

495-
// Hive treats (true = 1) as true and (false = 0) as true.
496-
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
497-
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
498-
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
499-
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)
500-
501-
// No need to change other EqualTo operators as that actually makes sense for boolean types.
502-
case e: EqualTo => e
503-
// No need to change the EqualNullSafe operators, too
504-
case e: EqualNullSafe => e
505-
// Otherwise turn them to Byte types so that there exists and ordering.
506-
case p: BinaryComparison if p.left.dataType == BooleanType &&
507-
p.right.dataType == BooleanType =>
508-
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
516+
// Hive treats (true = 1) as true and (false = 0) as true,
517+
// all other cases are considered as false.
518+
519+
// We may simplify the expression if one side is literal numeric values
520+
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
521+
if trueValues.contains(value) => l
522+
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
523+
if falseValues.contains(value) => Not(l)
524+
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
525+
if trueValues.contains(value) => r
526+
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
527+
if falseValues.contains(value) => Not(r)
528+
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
529+
if trueValues.contains(value) => And(IsNotNull(l), l)
530+
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
531+
if falseValues.contains(value) => And(IsNotNull(l), Not(l))
532+
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
533+
if trueValues.contains(value) => And(IsNotNull(r), r)
534+
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
535+
if falseValues.contains(value) => And(IsNotNull(r), Not(r))
536+
537+
case EqualTo(l @ BooleanType(), r @ NumericType()) =>
538+
transform(l , r)
539+
case EqualTo(l @ NumericType(), r @ BooleanType()) =>
540+
transform(r, l)
541+
case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
542+
transformNullSafe(l, r)
543+
case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
544+
transformNullSafe(r, l)
509545
}
510546
}
511547

@@ -606,7 +642,7 @@ trait HiveTypeCoercion {
606642
import HiveTypeCoercion._
607643

608644
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
609-
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
645+
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
610646
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
611647
val commonType = cw.valueTypes.reduce { (v1, v2) =>
612648
findTightestCommonType(v1, v2).getOrElse(sys.error(
@@ -625,6 +661,23 @@ trait HiveTypeCoercion {
625661
case CaseKeyWhen(key, _) =>
626662
CaseKeyWhen(key, transformedBranches)
627663
}
664+
665+
case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
666+
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
667+
findTightestCommonType(v1, v2).getOrElse(sys.error(
668+
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
669+
}
670+
val transformedBranches = ckw.branches.sliding(2, 2).map {
671+
case Seq(when, then) if when.dataType != commonType =>
672+
Seq(Cast(when, commonType), then)
673+
case s => s
674+
}.reduce(_ ++ _)
675+
val transformedKey = if (ckw.key.dataType != commonType) {
676+
Cast(ckw.key, commonType)
677+
} else {
678+
ckw.key
679+
}
680+
CaseKeyWhen(transformedKey, transformedBranches)
628681
}
629682
}
630683

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {
366366

367367
// both then and else val should be considered.
368368
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
369-
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
369+
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
370370

371371
override def dataType: DataType = {
372372
if (!resolved) {
@@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
442442
override def children: Seq[Expression] = key +: branches
443443

444444
override lazy val resolved: Boolean =
445-
childrenResolved && valueTypesEqual
445+
childrenResolved && valueTypesEqual &&
446+
(key +: whenList).map(_.dataType).distinct.size == 1
446447

447448
/** Written in imperative fashion for performance considerations. */
448449
override def eval(input: Row): Any = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
2020
import org.apache.spark.sql.catalyst.plans.PlanTest
2121

2222
import org.apache.spark.sql.catalyst.expressions._
23-
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
23+
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
24+
import org.apache.spark.sql.catalyst.rules.Rule
2425
import org.apache.spark.sql.types._
2526

2627
class HiveTypeCoercionSuite extends PlanTest {
@@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
104105
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
105106
}
106107

108+
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
109+
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
110+
comparePlans(
111+
rule(Project(Seq(Alias(initial, "a")()), testRelation)),
112+
Project(Seq(Alias(transformed, "a")()), testRelation))
113+
}
114+
107115
test("coalesce casts") {
108116
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
109-
def ruleTest(initial: Expression, transformed: Expression) {
110-
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
111-
comparePlans(
112-
fac(Project(Seq(Alias(initial, "a")()), testRelation)),
113-
Project(Seq(Alias(transformed, "a")()), testRelation))
114-
}
115-
ruleTest(
117+
ruleTest(fac,
116118
Coalesce(Literal(1.0)
117119
:: Literal(1)
118120
:: Literal.create(1.0, FloatType)
@@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
121123
:: Cast(Literal(1), DoubleType)
122124
:: Cast(Literal.create(1.0, FloatType), DoubleType)
123125
:: Nil))
124-
ruleTest(
126+
ruleTest(fac,
125127
Coalesce(Literal(1L)
126128
:: Literal(1)
127129
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
131133
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
132134
:: Nil))
133135
}
136+
137+
test("type coercion for CaseKeyWhen") {
138+
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
139+
ruleTest(cwc,
140+
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
141+
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
142+
)
143+
// Will remove exception expectation in PR#6405
144+
intercept[RuntimeException] {
145+
ruleTest(cwc,
146+
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
147+
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
148+
)
149+
}
150+
}
151+
152+
test("type coercion simplification for equal to") {
153+
val be = new HiveTypeCoercion {}.BooleanEqualization
154+
ruleTest(be,
155+
EqualTo(Literal(true), Literal(1)),
156+
Literal(true)
157+
)
158+
ruleTest(be,
159+
EqualTo(Literal(true), Literal(0)),
160+
Not(Literal(true))
161+
)
162+
ruleTest(be,
163+
EqualNullSafe(Literal(true), Literal(1)),
164+
And(IsNotNull(Literal(true)), Literal(true))
165+
)
166+
ruleTest(be,
167+
EqualNullSafe(Literal(true), Literal(0)),
168+
And(IsNotNull(Literal(true)), Not(Literal(true)))
169+
)
170+
}
134171
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
862862
val c5 = 'a.string.at(4)
863863
val c6 = 'a.string.at(5)
864864

865-
val literalNull = Literal.create(null, BooleanType)
865+
val literalNull = Literal.create(null, IntegerType)
866866
val literalInt = Literal(1)
867867
val literalString = Literal("a")
868868

@@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
871871
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
872872
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
873873
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
874-
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
874+
checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)
875875

876876
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
877877
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
878-
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
879-
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
878+
checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
879+
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
880880
}
881881

882882
test("complex type") {

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ import org.apache.spark.sql.catalyst.errors.DialectException
2424
import org.apache.spark.sql.execution.GeneratedAggregate
2525
import org.apache.spark.sql.functions._
2626
import org.apache.spark.sql.TestData._
27-
import org.apache.spark.sql.test.TestSQLContext
27+
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
2828
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
2929

3030
import org.apache.spark.sql.types._
3131

3232
/** A SQL Dialect for testing purpose, and it can not be nested type */
3333
class MyDialect extends DefaultParserDialect
3434

35-
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
35+
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
3636
// Make sure the tables are loaded.
3737
TestData
3838

39-
import org.apache.spark.sql.test.TestSQLContext.implicits._
40-
val sqlCtx = TestSQLContext
39+
val sqlContext = TestSQLContext
40+
import sqlContext.implicits._
4141

4242
test("SPARK-6743: no columns from cache") {
4343
Seq(
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
915915
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
916916
}
917917

918-
val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
918+
val df1 = createDataFrame(rowRDD1, schema1)
919919
df1.registerTempTable("applySchema1")
920920
checkAnswer(
921921
sql("SELECT * FROM applySchema1"),
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
945945
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
946946
}
947947

948-
val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
948+
val df2 = createDataFrame(rowRDD2, schema2)
949949
df2.registerTempTable("applySchema2")
950950
checkAnswer(
951951
sql("SELECT * FROM applySchema2"),
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
970970
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
971971
}
972972

973-
val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
973+
val df3 = createDataFrame(rowRDD3, schema2)
974974
df3.registerTempTable("applySchema3")
975975

976976
checkAnswer(
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
10151015
.build()
10161016
val schemaWithMeta = new StructType(Array(
10171017
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
1018-
val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
1018+
val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
10191019
def validateMetadata(rdd: DataFrame): Unit = {
10201020
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
10211021
}
@@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
13311331

13321332
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
13331333
}
1334+
1335+
test("SPARK-7952: fix the equality check between boolean and numeric types") {
1336+
withTempTable("t") {
1337+
// numeric field i, boolean field j, result of i = j, result of i <=> j
1338+
Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
1339+
(1, true, true, true),
1340+
(0, false, true, true),
1341+
(2, true, false, false),
1342+
(2, false, false, false),
1343+
(null, true, null, false),
1344+
(null, false, null, false),
1345+
(0, null, null, false),
1346+
(1, null, null, false),
1347+
(null, null, null, true)
1348+
).toDF("i", "b", "r1", "r2").registerTempTable("t")
1349+
1350+
checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
1351+
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
1352+
}
1353+
}
13341354
}

0 commit comments

Comments
 (0)