Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ trait HiveTypeCoercion {
WidenTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanComparisons ::
BooleanEqualization ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
Expand Down Expand Up @@ -119,7 +119,7 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
val stringNaN = Literal("NaN")
private val stringNaN = Literal("NaN")

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
Expand Down Expand Up @@ -349,17 +349,17 @@ trait HiveTypeCoercion {
import scala.math.{max, min}

// Conversion rules for integer types into fixed-precision decimals
val intTypeToFixed: Map[DataType, DecimalType] = Map(
private val intTypeToFixed: Map[DataType, DecimalType] = Map(
ByteType -> DecimalType(3, 0),
ShortType -> DecimalType(5, 0),
IntegerType -> DecimalType(10, 0),
LongType -> DecimalType(20, 0)
)

def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType

// Conversion rules for float and double into fixed-precision decimals
val floatTypeToFixed: Map[DataType, DecimalType] = Map(
private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
FloatType -> DecimalType(7, 7),
DoubleType -> DecimalType(15, 15)
)
Expand Down Expand Up @@ -482,30 +482,66 @@ trait HiveTypeCoercion {
}

/**
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/
object BooleanComparisons extends Rule[LogicalPlan] {
val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_))
val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_))
object BooleanEqualization extends Rule[LogicalPlan] {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be a SQL internal Decimal type?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably also add a test for this case since it seem broken ATM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. This breaks for Decimal. We should probably compare Literal itself, instead of comparing the wrapped value for literals.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan can you fix this and add a test case? thanks.

private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0))

private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq(
Literal(trueValues.head), booleanExpr,
Literal(falseValues.head), Not(booleanExpr),
Literal(false)))
}

private def transform(booleanExpr: Expression, numericExpr: Expression) = {
If(Or(IsNull(booleanExpr), IsNull(numericExpr)),
Literal.create(null, BooleanType),
buildCaseKeyWhen(booleanExpr, numericExpr))
}

private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = {
CaseWhen(Seq(
And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true),
Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false),
buildCaseKeyWhen(booleanExpr, numericExpr)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

boolean is already comparable, we don't need to cast it to byte.

))
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

// Hive treats (true = 1) as true and (false = 0) as true.
case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l
case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r
case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l)
case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r)

// No need to change other EqualTo operators as that actually makes sense for boolean types.
case e: EqualTo => e
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison if p.left.dataType == BooleanType &&
p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
// Hive treats (true = 1) as true and (false = 0) as true,
// all other cases are considered as false.

// We may simplify the expression if one side is literal numeric values
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => l
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(l)
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => r
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if falseValues.contains(value) => Not(r)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(l), l)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(l), Not(l))
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(r), r)
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if its not worthwhile to have a custom extractor for equality checking. It seems there might be more cases where either == or <=> should match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need a custom extractor for equality checking in the future, but in this case, we handle EqualTo and EqualNullSafe differently.

if falseValues.contains(value) => And(IsNotNull(r), Not(r))

case EqualTo(l @ BooleanType(), r @ NumericType()) =>
transform(l , r)
case EqualTo(l @ NumericType(), r @ BooleanType()) =>
transform(r, l)
case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
transformNullSafe(l, r)
case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
transformNullSafe(r, l)
}
}

Expand Down Expand Up @@ -606,7 +642,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
Expand All @@ -625,6 +661,23 @@ trait HiveTypeCoercion {
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}

case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = ckw.branches.sliding(2, 2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case s => s
}.reduce(_ ++ _)
val transformedKey = if (ckw.key.dataType != commonType) {
Cast(ckw.key, commonType)
} else {
ckw.key
}
CaseKeyWhen(transformedKey, transformedBranches)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ trait CaseWhenLike extends Expression {

// both then and else val should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1

override def dataType: DataType = {
if (!resolved) {
Expand Down Expand Up @@ -442,7 +442,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
override def children: Seq[Expression] = key +: branches

override lazy val resolved: Boolean =
childrenResolved && valueTypesEqual
childrenResolved && valueTypesEqual &&
(key +: whenList).map(_.dataType).distinct.size == 1

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._

class HiveTypeCoercionSuite extends PlanTest {
Expand Down Expand Up @@ -104,15 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest {
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}

private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
comparePlans(
rule(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}

test("coalesce casts") {
val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
comparePlans(
fac(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
ruleTest(
ruleTest(fac,
Coalesce(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
Expand All @@ -121,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
ruleTest(
ruleTest(fac,
Coalesce(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
Expand All @@ -131,4 +133,39 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
:: Nil))
}

test("type coercion for CaseKeyWhen") {
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
ruleTest(cwc,
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
// Will remove exception expectation in PR#6405
intercept[RuntimeException] {
ruleTest(cwc,
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
}
}

test("type coercion simplification for equal to") {
val be = new HiveTypeCoercion {}.BooleanEqualization
ruleTest(be,
EqualTo(Literal(true), Literal(1)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(true), Literal(0)),
Not(Literal(true))
)
ruleTest(be,
EqualNullSafe(Literal(true), Literal(1)),
And(IsNotNull(Literal(true)), Literal(true))
)
ruleTest(be,
EqualNullSafe(Literal(true), Literal(0)),
And(IsNotNull(Literal(true)), Not(Literal(true)))
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)

val literalNull = Literal.create(null, BooleanType)
val literalNull = Literal.create(null, IntegerType)
val literalInt = Literal(1)
val literalString = Literal("a")

Expand All @@ -871,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row)

checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}

test("complex type") {
Expand Down
36 changes: 28 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}

import org.apache.spark.sql.types._

/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect

class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Make sure the tables are loaded.
TestData

import org.apache.spark.sql.test.TestSQLContext.implicits._
val sqlCtx = TestSQLContext
val sqlContext = TestSQLContext
import sqlContext.implicits._

test("SPARK-6743: no columns from cache") {
Seq(
Expand Down Expand Up @@ -915,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}

val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
Expand Down Expand Up @@ -945,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}

val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
val df2 = createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
Expand All @@ -970,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}

val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
val df3 = createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")

checkAnswer(
Expand Down Expand Up @@ -1015,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
val personWithMeta = createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
Expand Down Expand Up @@ -1331,4 +1331,24 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {

checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}

test("SPARK-7952: fix the equality check between boolean and numeric types") {
withTempTable("t") {
// numeric field i, boolean field j, result of i = j, result of i <=> j
Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)](
(1, true, true, true),
(0, false, true, true),
(2, true, false, false),
(2, false, false, false),
(null, true, null, false),
(null, false, null, false),
(0, null, null, false),
(1, null, null, false),
(null, null, null, true)
).toDF("i", "b", "r1", "r2").registerTempTable("t")

checkAnswer(sql("select i = b from t"), sql("select r1 from t"))
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case can be simplified a bit, also please drop the temp table at last:

  val data = Seq(
    (1, true),
    (0, false),
    (2, true),
    (2, false),
    (null, true),
    (null, false),
    (0, null),
    (1, null),
    (null, null)
  ).map { case (i, b) =>
    (i.asInstanceOf[Integer], b.asInstanceOf[java.lang.Boolean])
  }

  data.toDF("i", "b").registerTempTable("t")

  try {
    // checkAnswer calls
  } finally {
    sqlCtx.dropTempTable("t")
  }

}