Skip to content

Commit 2125a1b

Browse files
author
Davies Liu
committed
fix bug in handling result (null)
1 parent c96b512 commit 2125a1b

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,24 +1029,27 @@ case class ScalaUDF(
10291029
// such as IntegerType, its javaType is `int` and the returned type of user-defined
10301030
// function is Object. Trying to convert an Object to `int` will cause casting exception.
10311031
val evalCode = evals.map(_.code).mkString
1032-
val funcArguments = converterTerms.zipWithIndex.map {
1033-
case (converter, i) =>
1034-
val eval = evals(i)
1035-
val dt = children(i).dataType
1036-
s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)})(${eval.value}))"
1037-
}.mkString(",")
1038-
val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
1039-
s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
1040-
s".apply($funcTerm.apply($funcArguments));"
1032+
val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
1033+
val eval = evals(i)
1034+
val argTerm = ctx.freshName("arg")
1035+
val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
1036+
(convert, argTerm)
1037+
}.unzip
10411038

1042-
evalCode + s"""
1043-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1044-
Boolean ${ev.isNull};
1039+
val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
1040+
s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
1041+
s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
10451042

1043+
s"""
1044+
$evalCode
1045+
${converters.mkString("\n")}
10461046
$callFunc
10471047

1048-
${ev.value} = $resultTerm;
1049-
${ev.isNull} = $resultTerm == null;
1048+
boolean ${ev.isNull} = $resultTerm == null;
1049+
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1050+
if (!${ev.isNull}) {
1051+
${ev.value} = $resultTerm;
1052+
}
10501053
"""
10511054
}
10521055

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,12 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
11441144

11451145
// passing null into the UDF that could handle it
11461146
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
1147-
(i: java.lang.Integer) => if (i == null) -10 else i * 2
1147+
(i: java.lang.Integer) => if (i == null) -10 else null
11481148
}
1149-
checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
1149+
checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
11501150

1151-
sqlContext.udf.register("boxedUDF", (i: java.lang.Integer) => if (i == null) -10 else i * 2)
1152-
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, -2) :: Nil)
1151+
sqlContext.udf.register("boxedUDF",
1152+
(i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
1153+
checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
11531154

11541155
val primitiveUDF = udf((i: Int) => i * 2)
11551156
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)

0 commit comments

Comments
 (0)