-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-7952][SPARK-7984][SQL] equality check between boolean type and numeric type is broken. #6505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2846a04
fc0d741
9ba2130
625973c
ebc8c61
b6401ba
77f0f39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,7 +76,7 @@ trait HiveTypeCoercion { | |
| WidenTypes :: | ||
| PromoteStrings :: | ||
| DecimalPrecision :: | ||
| BooleanComparisons :: | ||
| BooleanEqualization :: | ||
| StringToIntegralCasts :: | ||
| FunctionArgumentConversion :: | ||
| CaseWhenCoercion :: | ||
|
|
@@ -119,7 +119,7 @@ trait HiveTypeCoercion { | |
| * the appropriate numeric equivalent. | ||
| */ | ||
| object ConvertNaNs extends Rule[LogicalPlan] { | ||
| val stringNaN = Literal("NaN") | ||
| private val stringNaN = Literal("NaN") | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
| case q: LogicalPlan => q transformExpressions { | ||
|
|
@@ -349,17 +349,17 @@ trait HiveTypeCoercion { | |
| import scala.math.{max, min} | ||
|
|
||
| // Conversion rules for integer types into fixed-precision decimals | ||
| val intTypeToFixed: Map[DataType, DecimalType] = Map( | ||
| private val intTypeToFixed: Map[DataType, DecimalType] = Map( | ||
| ByteType -> DecimalType(3, 0), | ||
| ShortType -> DecimalType(5, 0), | ||
| IntegerType -> DecimalType(10, 0), | ||
| LongType -> DecimalType(20, 0) | ||
| ) | ||
|
|
||
| def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType | ||
| private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType | ||
|
|
||
| // Conversion rules for float and double into fixed-precision decimals | ||
| val floatTypeToFixed: Map[DataType, DecimalType] = Map( | ||
| private val floatTypeToFixed: Map[DataType, DecimalType] = Map( | ||
| FloatType -> DecimalType(7, 7), | ||
| DoubleType -> DecimalType(15, 15) | ||
| ) | ||
|
|
@@ -482,30 +482,66 @@ trait HiveTypeCoercion { | |
| } | ||
|
|
||
| /** | ||
| * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. | ||
| * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. | ||
| */ | ||
| object BooleanComparisons extends Rule[LogicalPlan] { | ||
| val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) | ||
| val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) | ||
| object BooleanEqualization extends Rule[LogicalPlan] { | ||
| private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) | ||
| private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) | ||
|
|
||
| private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { | ||
| CaseKeyWhen(numericExpr, Seq( | ||
| Literal(trueValues.head), booleanExpr, | ||
| Literal(falseValues.head), Not(booleanExpr), | ||
| Literal(false))) | ||
| } | ||
|
|
||
| private def transform(booleanExpr: Expression, numericExpr: Expression) = { | ||
| If(Or(IsNull(booleanExpr), IsNull(numericExpr)), | ||
| Literal.create(null, BooleanType), | ||
| buildCaseKeyWhen(booleanExpr, numericExpr)) | ||
| } | ||
|
|
||
| private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { | ||
| CaseWhen(Seq( | ||
| And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), | ||
| Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), | ||
| buildCaseKeyWhen(booleanExpr, numericExpr) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Boolean is already comparable, so we don't need to cast it to byte. |
||
| )) | ||
| } | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
| // Skip nodes whose children have not been resolved yet. | ||
| case e if !e.childrenResolved => e | ||
|
|
||
| // Hive treats (true = 1) as true and (false = 0) as true. | ||
| case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l | ||
| case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r | ||
| case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) | ||
| case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) | ||
|
|
||
| // No need to change other EqualTo operators as that actually makes sense for boolean types. | ||
| case e: EqualTo => e | ||
| // No need to change the EqualNullSafe operators, too | ||
| case e: EqualNullSafe => e | ||
| // Otherwise turn them to Byte types so that there exists an ordering. | ||
| case p: BinaryComparison if p.left.dataType == BooleanType && | ||
| p.right.dataType == BooleanType => | ||
| p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) | ||
| // Hive treats (true = 1) as true and (false = 0) as true, | ||
| // all other cases are considered as false. | ||
|
|
||
| // We may simplify the expression if one side is literal numeric values | ||
| case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) | ||
| if trueValues.contains(value) => l | ||
| case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) | ||
| if falseValues.contains(value) => Not(l) | ||
| case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) | ||
| if trueValues.contains(value) => r | ||
| case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) | ||
| if falseValues.contains(value) => Not(r) | ||
| case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) | ||
| if trueValues.contains(value) => And(IsNotNull(l), l) | ||
| case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) | ||
| if falseValues.contains(value) => And(IsNotNull(l), Not(l)) | ||
| case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) | ||
| if trueValues.contains(value) => And(IsNotNull(r), r) | ||
| case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if it's not worthwhile to have a custom extractor for equality checking. It seems there might be more cases where either There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We may need a custom extractor for equality checking in the future, but in this case, we handle |
||
| if falseValues.contains(value) => And(IsNotNull(r), Not(r)) | ||
|
|
||
| case EqualTo(l @ BooleanType(), r @ NumericType()) => | ||
| transform(l , r) | ||
| case EqualTo(l @ NumericType(), r @ BooleanType()) => | ||
| transform(r, l) | ||
| case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => | ||
| transformNullSafe(l, r) | ||
| case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => | ||
| transformNullSafe(r, l) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -606,7 +642,7 @@ trait HiveTypeCoercion { | |
| import HiveTypeCoercion._ | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
| case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => | ||
| case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual => | ||
| logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") | ||
| val commonType = cw.valueTypes.reduce { (v1, v2) => | ||
| findTightestCommonType(v1, v2).getOrElse(sys.error( | ||
|
|
@@ -625,6 +661,23 @@ trait HiveTypeCoercion { | |
| case CaseKeyWhen(key, _) => | ||
| CaseKeyWhen(key, transformedBranches) | ||
| } | ||
|
|
||
| case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved => | ||
| val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) => | ||
| findTightestCommonType(v1, v2).getOrElse(sys.error( | ||
| s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) | ||
| } | ||
| val transformedBranches = ckw.branches.sliding(2, 2).map { | ||
| case Seq(when, then) if when.dataType != commonType => | ||
| Seq(Cast(when, commonType), then) | ||
| case s => s | ||
| }.reduce(_ ++ _) | ||
| val transformedKey = if (ckw.key.dataType != commonType) { | ||
| Cast(ckw.key, commonType) | ||
| } else { | ||
| ckw.key | ||
| } | ||
| CaseKeyWhen(transformedKey, transformedBranches) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,20 +24,20 @@ import org.apache.spark.sql.catalyst.errors.DialectException | |
| import org.apache.spark.sql.execution.GeneratedAggregate | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.TestData._ | ||
| import org.apache.spark.sql.test.TestSQLContext | ||
| import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} | ||
| import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} | ||
|
|
||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** A SQL dialect for testing purposes; it cannot be a nested type. */ | ||
| class MyDialect extends DefaultParserDialect | ||
|
|
||
| class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | ||
| class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { | ||
| // Make sure the tables are loaded. | ||
| TestData | ||
|
|
||
| import org.apache.spark.sql.test.TestSQLContext.implicits._ | ||
| val sqlCtx = TestSQLContext | ||
| val sqlContext = TestSQLContext | ||
| import sqlContext.implicits._ | ||
|
|
||
| test("SPARK-6743: no columns from cache") { | ||
| Seq( | ||
|
|
@@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | |
| Row(values(0).toInt, values(1), values(2).toBoolean, v4) | ||
| } | ||
|
|
||
| val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) | ||
| val df1 = createDataFrame(rowRDD1, schema1) | ||
| df1.registerTempTable("applySchema1") | ||
| checkAnswer( | ||
| sql("SELECT * FROM applySchema1"), | ||
|
|
@@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | |
| Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) | ||
| } | ||
|
|
||
| val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) | ||
| val df2 = createDataFrame(rowRDD2, schema2) | ||
| df2.registerTempTable("applySchema2") | ||
| checkAnswer( | ||
| sql("SELECT * FROM applySchema2"), | ||
|
|
@@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | |
| Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) | ||
| } | ||
|
|
||
| val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) | ||
| val df3 = createDataFrame(rowRDD3, schema2) | ||
| df3.registerTempTable("applySchema3") | ||
|
|
||
| checkAnswer( | ||
|
|
@@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | |
| .build() | ||
| val schemaWithMeta = new StructType(Array( | ||
| schema("id"), schema("name").copy(metadata = metadata), schema("age"))) | ||
| val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) | ||
| val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) | ||
| def validateMetadata(rdd: DataFrame): Unit = { | ||
| assert(rdd.schema("name").metadata.getString(docKey) == docValue) | ||
| } | ||
|
|
@@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | |
|
|
||
| checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) | ||
| } | ||
|
|
||
| test("SPARK-7952: fix the equality check between boolean and numeric types") { | ||
| withTempTable("t") { | ||
| // numeric field i, boolean field j, result of i = j, result of i <=> j | ||
| Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( | ||
| (1, true, true, true), | ||
| (0, false, true, true), | ||
| (2, true, false, false), | ||
| (2, false, false, false), | ||
| (null, true, null, false), | ||
| (null, false, null, false), | ||
| (0, null, null, false), | ||
| (1, null, null, false), | ||
| (null, null, null, true) | ||
| ).toDF("i", "b", "r1", "r2").registerTempTable("t") | ||
|
|
||
| checkAnswer(sql("select i = b from t"), sql("select r1 from t")) | ||
| checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test case can be simplified a bit; also, please drop the temp table at the end: val data = Seq(
(1, true),
(0, false),
(2, true),
(2, false),
(null, true),
(null, false),
(0, null),
(1, null),
(null, null)
).map { case (i, b) =>
(i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
}
data.toDF("i", "b").registerTempTable("t")
try {
// checkAnswer calls
} finally {
sqlCtx.dropTempTable("t")
} |
||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be a SQL internal Decimal type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably also add a test for this case, since it seems broken at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. This breaks for Decimal. We should probably compare Literal itself, instead of comparing the wrapped value for literals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cloud-fan can you fix this and add a test case? thanks.