sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -601,8 +601,8 @@ object TypeCoercion {
   object CaseWhenCoercion extends TypeCoercionRule {
     override protected def coerceTypes(
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
-        val maybeCommonType = findWiderCommonType(c.valueTypes)
+      case c: CaseWhen if c.childrenResolved && !c.areInputTypesForMergingEqual =>
+        val maybeCommonType = findWiderCommonType(c.inputTypesForMerging)
         maybeCommonType.map { commonType =>
           var changed = false
           val newBranches = c.branches.map { case (condition, value) =>
@@ -634,10 +634,10 @@ object TypeCoercion {
         plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
-      case i @ If(pred, left, right) if left.dataType != right.dataType =>
+      case i @ If(pred, left, right) if !i.areInputTypesForMergingEqual =>
         findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
-          val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
-          val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+          val newLeft = if (left.dataType.sameType(widestType)) left else Cast(left, widestType)
+          val newRight = if (right.dataType.sameType(widestType)) right else Cast(right, widestType)
           If(pred, newLeft, newRight)
         }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
       case If(Literal(null, NullType), left, right) =>
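
The guard changes above hinge on the difference between strict type equality and Catalyst's
sameType comparison. A minimal sketch of that distinction follows; it is illustrative only,
not part of the patch, and assumes it compiles inside Spark's own package tree, since
sameType is declared private[spark].

package org.apache.spark.demo // hypothetical package, only so private[spark] members resolve

import org.apache.spark.sql.types._

object SameTypeDemo extends App {
  val withNulls = ArrayType(StringType, containsNull = true)
  val withoutNulls = ArrayType(StringType, containsNull = false)

  // Strict equality distinguishes types that differ only in null flags, so the old
  // guard `left.dataType != right.dataType` fired and wrapped a branch in a Cast.
  println(withNulls == withoutNulls) // false

  // sameType ignores nullable/containsNull/valueContainsNull, so the new guard skips
  // coercion here and lets ComplexTypeMergingExpression merge the flags instead.
  println(withNulls.sameType(withoutNulls)) // true
}
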
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Locale

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
@@ -668,6 +668,57 @@ abstract class TernaryExpression extends Expression {
   }
 }

+/**
+ * A trait that resolves the nullable, containsNull and valueContainsNull flags of the output
+ * data type. This logic is usually utilized by expressions combining data from multiple child
+ * expressions of non-primitive types (e.g. [[CaseWhen]]).
+ */
+trait ComplexTypeMergingExpression extends Expression {
+
+  /**
+   * A collection of data types used for resolving the output type of the expression. By
+   * default, the data types of all child expressions. The collection must not be empty.
+   */
+  @transient lazy val inputTypesForMerging: Seq[DataType] = children.map(_.dataType)
+
+  /**
+   * Determines whether the input types are equal, ignoring the nullable, containsNull and
+   * valueContainsNull flags, and are thus suitable for resolving the final data type.
+   */
+  def areInputTypesForMergingEqual: Boolean = {
+    inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
+      case Seq(dt1, dt2) => dt1.sameType(dt2)
+    }
+  }
+
+  private def mergeTwoDataTypes(dt1: DataType, dt2: DataType): DataType = (dt1, dt2) match {
+    case (t1, t2) if t1 == t2 => t1
+    case (ArrayType(et1, cn1), ArrayType(et2, cn2)) =>
+      ArrayType(mergeTwoDataTypes(et1, et2), cn1 || cn2)
+    case (MapType(kt1, vt1, vcn1), MapType(kt2, vt2, vcn2)) =>
+      MapType(mergeTwoDataTypes(kt1, kt2), mergeTwoDataTypes(vt1, vt2), vcn1 || vcn2)
+    case (StructType(fields1), StructType(fields2)) =>
+      val newFields = fields1.zip(fields2).map { // t1.sameType(t2) == true => same length
+        case (f1, f2) if f1 == f2 => f1
+        case (StructField(name, fdt1, nl1, _), StructField(_, fdt2, nl2, _)) =>
+          StructField(name, mergeTwoDataTypes(fdt1, fdt2), nl1 || nl2)
+      }
+      StructType(newFields)
+  }
+
+  @transient override lazy val dataType: DataType = {
+    require(
+      inputTypesForMerging.nonEmpty,
+      "The collection of input data types must not be empty.")
+    require(
+      inputTypesForMerging.length <= 1 || inputTypesForMerging.sliding(2, 1).forall {
+        case Seq(dt1, dt2) => DataType.equalsIgnoreCaseAndNullability(dt1, dt2)
+      },
+      "All input types must be the same except nullable, containsNull, valueContainsNull flags.")
+    inputTypesForMerging.reduceLeft(mergeTwoDataTypes)
+  }
+}
+
 /**
  * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
  * and Hive function wrappers.
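
To make the merge semantics concrete, here is a small sketch, not part of the patch, of what
the trait's dataType resolution produces; CaseWhen mixes the trait in via the
conditionalExpressions.scala changes below. For structurally equal inputs, every null flag in
the result is the OR of the corresponding input flags.

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}
import org.apache.spark.sql.types._

object MergeFlagsDemo extends App {
  val left = StructType(Seq(
    StructField("id", IntegerType, nullable = false),
    StructField("tags", ArrayType(StringType, containsNull = true))))
  val right = StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("tags", ArrayType(StringType, containsNull = false))))

  // Null literals stand in for arbitrary child expressions of these two types.
  val merged = CaseWhen(
    Seq((Literal(true), Literal.create(null, left))),
    Literal.create(null, right)).dataType

  // Each flag is OR-ed: `id` becomes nullable, `tags` keeps containsNull = true.
  assert(merged == StructType(Seq(
    StructField("id", IntegerType, nullable = true),
    StructField("tags", ArrayType(StringType, containsNull = true)))))
}
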
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._

@@ -32,7 +32,11 @@ import org.apache.spark.sql.types._
   """)
 // scalastyle:on line.size.limit
 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
-  extends Expression {
+  extends ComplexTypeMergingExpression {
+
+  @transient override lazy val inputTypesForMerging: Seq[DataType] = {
+    Seq(trueValue.dataType, falseValue.dataType)
+  }

   override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
   override def nullable: Boolean = trueValue.nullable || falseValue.nullable
@@ -42,16 +46,14 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
       TypeCheckResult.TypeCheckFailure(
         "type of predicate expression in If should be boolean, " +
           s"not ${predicate.dataType.simpleString}")
-    } else if (!trueValue.dataType.sameType(falseValue.dataType)) {
+    } else if (!areInputTypesForMergingEqual) {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
         s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }
   }

-  override def dataType: DataType = trueValue.dataType
-
   override def eval(input: InternalRow): Any = {
     if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
       trueValue.eval(input)
@@ -117,27 +119,23 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
 case class CaseWhen(
     branches: Seq[(Expression, Expression)],
     elseValue: Option[Expression] = None)
-  extends Expression with Serializable {
+  extends ComplexTypeMergingExpression with Serializable {

   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

   // both then and else expressions should be considered.
-  def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
-
-  def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
-    case Seq(dt1, dt2) => dt1.sameType(dt2)
+  @transient override lazy val inputTypesForMerging: Seq[DataType] = {
+    branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
   }

-  override def dataType: DataType = branches.head._2.dataType
-
   override def nullable: Boolean = {
     // Result is nullable if any of the branches is nullable, or if the else value is nullable
     branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
   }

   override def checkInputDataTypes(): TypeCheckResult = {
-    // Make sure all branch conditions are boolean types.
-    if (valueTypesEqual) {
+    if (areInputTypesForMergingEqual) {
+      // Make sure all branch conditions are boolean types.
       if (branches.forall(_._1.dataType == BooleanType)) {
         TypeCheckResult.TypeCheckSuccess
       } else {
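
A quick usage sketch, not part of the patch, that mirrors the new test expectations below:
If.dataType used to be taken from trueValue alone, so null flags carried only by the false
branch were dropped; with ComplexTypeMergingExpression mixed in, the resolved type now covers
both branches.

import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.types._

object IfDataTypeDemo extends App {
  val withNulls = Literal.create(Seq("a", null), ArrayType(StringType, containsNull = true))
  val withoutNulls = Literal.create(Seq("b"), ArrayType(StringType, containsNull = false))

  // Before this patch the result type was trueValue.dataType, i.e. containsNull = false
  // here, even though the false branch can produce null elements.
  val expr = If(Literal(false), withoutNulls, withNulls)
  assert(expr.dataType == ArrayType(StringType, containsNull = true))
}
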
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -113,6 +113,76 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
   }

+  test("if/case when - null flags of non-primitive types") {
+    val arrayWithNulls = Literal.create(Seq("a", null, "b"), ArrayType(StringType, true))
+    val arrayWithoutNulls = Literal.create(Seq("c", "d"), ArrayType(StringType, false))
+    val structWithNulls = Literal.create(
+      create_row(null, null),
+      StructType(Seq(StructField("a", IntegerType, true), StructField("b", StringType, true))))
+    val structWithoutNulls = Literal.create(
+      create_row(1, "a"),
+      StructType(Seq(StructField("a", IntegerType, false), StructField("b", StringType, false))))
+    val mapWithNulls = Literal.create(Map(1 -> null), MapType(IntegerType, StringType, true))
+    val mapWithoutNulls = Literal.create(Map(1 -> "a"), MapType(IntegerType, StringType, false))
+
+    val arrayIf1 = If(Literal.FalseLiteral, arrayWithNulls, arrayWithoutNulls)
+    val arrayIf2 = If(Literal.FalseLiteral, arrayWithoutNulls, arrayWithNulls)
+    val arrayIf3 = If(Literal.TrueLiteral, arrayWithNulls, arrayWithoutNulls)
+    val arrayIf4 = If(Literal.TrueLiteral, arrayWithoutNulls, arrayWithNulls)
+    val structIf1 = If(Literal.FalseLiteral, structWithNulls, structWithoutNulls)
+    val structIf2 = If(Literal.FalseLiteral, structWithoutNulls, structWithNulls)
+    val structIf3 = If(Literal.TrueLiteral, structWithNulls, structWithoutNulls)
+    val structIf4 = If(Literal.TrueLiteral, structWithoutNulls, structWithNulls)
+    val mapIf1 = If(Literal.FalseLiteral, mapWithNulls, mapWithoutNulls)
+    val mapIf2 = If(Literal.FalseLiteral, mapWithoutNulls, mapWithNulls)
+    val mapIf3 = If(Literal.TrueLiteral, mapWithNulls, mapWithoutNulls)
+    val mapIf4 = If(Literal.TrueLiteral, mapWithoutNulls, mapWithNulls)
+
+    val arrayCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithNulls)), arrayWithoutNulls)
+    val arrayCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, arrayWithoutNulls)), arrayWithNulls)
+    val arrayCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithNulls)), arrayWithoutNulls)
+    val arrayCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, arrayWithoutNulls)), arrayWithNulls)
+    val structCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, structWithNulls)), structWithoutNulls)
+    val structCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, structWithoutNulls)), structWithNulls)
+    val structCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, structWithNulls)), structWithoutNulls)
+    val structCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, structWithoutNulls)), structWithNulls)
+    val mapCaseWhen1 = CaseWhen(Seq((Literal.FalseLiteral, mapWithNulls)), mapWithoutNulls)
+    val mapCaseWhen2 = CaseWhen(Seq((Literal.FalseLiteral, mapWithoutNulls)), mapWithNulls)
+    val mapCaseWhen3 = CaseWhen(Seq((Literal.TrueLiteral, mapWithNulls)), mapWithoutNulls)
+    val mapCaseWhen4 = CaseWhen(Seq((Literal.TrueLiteral, mapWithoutNulls)), mapWithNulls)
+
+    def checkResult(expectedType: DataType, expectedValue: Any, result: Expression): Unit = {
+      assert(expectedType == result.dataType)
+      checkEvaluation(result, expectedValue)
+    }
+
+    checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf1)
+    checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf2)
+    checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayIf3)
+    checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayIf4)
+    checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf1)
+    checkResult(structWithNulls.dataType, structWithNulls.value, structIf2)
+    checkResult(structWithNulls.dataType, structWithNulls.value, structIf3)
+    checkResult(structWithNulls.dataType, structWithoutNulls.value, structIf4)
+    checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf1)
+    checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf2)
+    checkResult(mapWithNulls.dataType, mapWithNulls.value, mapIf3)
+    checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapIf4)
+
+    checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen1)
+    checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen2)
+    checkResult(arrayWithNulls.dataType, arrayWithNulls.value, arrayCaseWhen3)
+    checkResult(arrayWithNulls.dataType, arrayWithoutNulls.value, arrayCaseWhen4)
+    checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen1)
+    checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen2)
+    checkResult(structWithNulls.dataType, structWithNulls.value, structCaseWhen3)
+    checkResult(structWithNulls.dataType, structWithoutNulls.value, structCaseWhen4)
+    checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen1)
+    checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen2)
+    checkResult(mapWithNulls.dataType, mapWithNulls.value, mapCaseWhen3)
+    checkResult(mapWithNulls.dataType, mapWithoutNulls.value, mapCaseWhen4)
+  }
+
   test("case key when") {
     val row = create_row(null, 1, 2, "a", "b", "c")
     val c1 = 'a.int.at(0)
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (58 additions, 0 deletions)
@@ -2266,6 +2266,64 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
   }

+  test("SPARK-24165: CaseWhen/If - nullability of nested types") {
+    val rows = new java.util.ArrayList[Row]()
+    rows.add(Row(true, ("x", 1), Seq("x", "y"), Map(0 -> "x")))
+    rows.add(Row(false, (null, 2), Seq(null, "z"), Map(0 -> null)))
+    val schema = StructType(Seq(
+      StructField("cond", BooleanType, true),
+      StructField("s", StructType(Seq(
+        StructField("val1", StringType, true),
+        StructField("val2", IntegerType, false)
+      )), false),
+      StructField("a", ArrayType(StringType, true)),
+      StructField("m", MapType(IntegerType, StringType, true))
+    ))
+
+    val sourceDF = spark.createDataFrame(rows, schema)
+
+    val structWhenDF = sourceDF
+      .select(when('cond, struct(lit("a").as("val1"), lit(10).as("val2"))).otherwise('s) as "res")
+      .select('res.getField("val1"))
+    val arrayWhenDF = sourceDF
+      .select(when('cond, array(lit("a"), lit("b"))).otherwise('a) as "res")
+      .select('res.getItem(0))
+    val mapWhenDF = sourceDF
+      .select(when('cond, map(lit(0), lit("a"))).otherwise('m) as "res")
+      .select('res.getItem(0))
+
+    val structIfDF = sourceDF
+      .select(expr("if(cond, struct('a' as val1, 10 as val2), s)") as "res")
+      .select('res.getField("val1"))
+    val arrayIfDF = sourceDF
+      .select(expr("if(cond, array('a', 'b'), a)") as "res")
+      .select('res.getItem(0))
+    val mapIfDF = sourceDF
+      .select(expr("if(cond, map(0, 'a'), m)") as "res")
+      .select('res.getItem(0))
+
+    def checkResult(df: DataFrame, codegenExpected: Boolean): Unit = {
+      assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec] == codegenExpected)
+      checkAnswer(df, Seq(Row("a"), Row(null)))
+    }
+
+    // without codegen
+    checkResult(structWhenDF, false)
+    checkResult(arrayWhenDF, false)
+    checkResult(mapWhenDF, false)
+    checkResult(structIfDF, false)
+    checkResult(arrayIfDF, false)
+    checkResult(mapIfDF, false)
+
+    // with codegen
+    checkResult(structWhenDF.filter('cond.isNotNull), true)
+    checkResult(arrayWhenDF.filter('cond.isNotNull), true)
+    checkResult(mapWhenDF.filter('cond.isNotNull), true)
+    checkResult(structIfDF.filter('cond.isNotNull), true)
+    checkResult(arrayIfDF.filter('cond.isNotNull), true)
+    checkResult(mapIfDF.filter('cond.isNotNull), true)
+  }
+
   test("Uuid expressions should produce same results at retries in the same DataFrame") {
     val df = spark.range(1).select($"id", new Column(Uuid()))
     checkAnswer(df, df.collect())
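
Finally, a hedged end-to-end sketch of the user-visible behaviour that the DataFrameSuite
test above pins down; the object name and session setup here are illustrative and not part
of the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CaseWhenNullFlagsDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
  import spark.implicits._

  // Column `a` allows null elements; `array(lit("a"), lit("b"))` does not.
  val df = Seq(
    (true, Seq("x", "y")),
    (false, Seq(null.asInstanceOf[String], "z"))).toDF("cond", "a")

  val res = df.select(when($"cond", array(lit("a"), lit("b"))).otherwise($"a") as "res")

  // With this patch the merged result type keeps containsNull = true, so the null
  // element survives instead of the type being mislabelled as non-nullable.
  res.printSchema()
  res.show()

  spark.stop()
}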