Skip to content

Commit 8e4d15f

Browse files
cloud-fan authored and
marmbrus committed
[SPARK-13101][SQL] nullability of array type element should not fail analysis of encoder
Nullability should only be considered an optimization rather than part of the type system, so instead of failing analysis on mismatched nullability, we should pass analysis and add a runtime null check. Author: Wenchen Fan <[email protected]> Closes #11035 from cloud-fan/ignore-nullability.
1 parent 06f0df6 commit 8e4d15f

File tree

7 files changed

+64
-104
lines changed

7 files changed

+64
-104
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ object JavaTypeInference {
292292
val setter = if (nullable) {
293293
constructor
294294
} else {
295-
AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
295+
AssertNotNull(constructor, Seq("currently no type path record in java"))
296296
}
297297
p.getWriteMethod.getName -> setter
298298
}.toMap

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {
249249

250250
case t if t <:< localTypeOf[Array[_]] =>
251251
val TypeRef(_, _, Seq(elementType)) = t
252+
253+
// TODO: add runtime null check for primitive array
252254
val primitiveMethod = elementType match {
253255
case t if t <:< definitions.IntTpe => Some("toIntArray")
254256
case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {
276278

277279
case t if t <:< localTypeOf[Seq[_]] =>
278280
val TypeRef(_, _, Seq(elementType)) = t
281+
val Schema(dataType, nullable) = schemaFor(elementType)
279282
val className = getClassNameFromType(elementType)
280283
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
281-
val arrayData =
282-
Invoke(
283-
MapObjects(
284-
p => constructorFor(elementType, Some(p), newTypePath),
285-
getPath,
286-
schemaFor(elementType).dataType),
287-
"array",
288-
ObjectType(classOf[Array[Any]]))
284+
285+
val mapFunction: Expression => Expression = p => {
286+
val converter = constructorFor(elementType, Some(p), newTypePath)
287+
if (nullable) {
288+
converter
289+
} else {
290+
AssertNotNull(converter, newTypePath)
291+
}
292+
}
293+
294+
val array = Invoke(
295+
MapObjects(mapFunction, getPath, dataType),
296+
"array",
297+
ObjectType(classOf[Array[Any]]))
289298

290299
StaticInvoke(
291300
scala.collection.mutable.WrappedArray.getClass,
292301
ObjectType(classOf[Seq[_]]),
293302
"make",
294-
arrayData :: Nil)
303+
array :: Nil)
295304

296305
case t if t <:< localTypeOf[Map[_, _]] =>
297306
// TODO: add walked type path for map
@@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
343352
newTypePath)
344353

345354
if (!nullable) {
346-
AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
355+
AssertNotNull(constructor, newTypePath)
347356
} else {
348357
constructor
349358
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1426,7 +1426,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
14261426
fail(child, DateType, walkedTypePath)
14271427
case (StringType, to: NumericType) =>
14281428
fail(child, to, walkedTypePath)
1429-
case _ => Cast(child, dataType)
1429+
case _ => Cast(child, dataType.asNullable)
14301430
}
14311431
}
14321432
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ object MapObjects {
365365
* to handle collection elements.
366366
* @param inputData An expression that when evaluated returns a collection object.
367367
*/
368-
case class MapObjects(
368+
case class MapObjects private(
369369
loopVar: LambdaVariable,
370370
lambdaFunction: Expression,
371371
inputData: Expression) extends Expression {
@@ -637,8 +637,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
637637
* `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
638638
* non-null `s`, `s.i` can't be null.
639639
*/
640-
case class AssertNotNull(
641-
child: Expression, parentType: String, fieldName: String, fieldType: String)
640+
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
642641
extends UnaryExpression {
643642

644643
override def dataType: DataType = child.dataType
@@ -651,19 +650,22 @@ case class AssertNotNull(
651650
override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = {
652651
val childGen = child.gen(ctx)
653652

653+
val errMsg = "Null value appeared in non-nullable field:" +
654+
walkedTypePath.mkString("\n", "\n", "\n") +
655+
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
656+
"please try to use scala.Option[_] or other nullable types " +
657+
"(e.g. java.lang.Integer instead of int/scala.Int)."
658+
val idx = ctx.references.length
659+
ctx.references += errMsg
660+
654661
ev.isNull = "false"
655662
ev.value = childGen.value
656663

657664
s"""
658665
${childGen.code}
659666

660667
if (${childGen.isNull}) {
661-
throw new RuntimeException(
662-
"Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
663-
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
664-
"please try to use scala.Option[_] or other nullable types " +
665-
"(e.g. java.lang.Integer instead of int/scala.Int)."
666-
);
668+
throw new RuntimeException((String) references[$idx]);
667669
}
668670
"""
669671
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala

Lines changed: 30 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag
2121

2222
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.dsl.expressions._
24-
import org.apache.spark.sql.catalyst.expressions._
2524
import org.apache.spark.sql.catalyst.plans.PlanTest
25+
import org.apache.spark.sql.catalyst.util.GenericArrayData
26+
import org.apache.spark.sql.catalyst.InternalRow
2627
import org.apache.spark.sql.types._
28+
import org.apache.spark.unsafe.types.UTF8String
2729

2830
case class StringLongClass(a: String, b: Long)
2931

@@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
3234
case class ComplexClass(a: Long, b: StringLongClass)
3335

3436
class EncoderResolutionSuite extends PlanTest {
37+
private val str = UTF8String.fromString("hello")
38+
3539
test("real type doesn't match encoder schema but they are compatible: product") {
3640
val encoder = ExpressionEncoder[StringLongClass]
37-
val cls = classOf[StringLongClass]
38-
3941

40-
{
41-
val attrs = Seq('a.string, 'b.int)
42-
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
43-
val expected: Expression = NewInstance(
44-
cls,
45-
Seq(
46-
toExternalString('a.string),
47-
AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
48-
),
49-
ObjectType(cls),
50-
propagateNull = false)
51-
compareExpressions(fromRowExpr, expected)
52-
}
42+
// int type can be up cast to long type
43+
val attrs1 = Seq('a.string, 'b.int)
44+
encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
5345

54-
{
55-
val attrs = Seq('a.int, 'b.long)
56-
val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
57-
val expected = NewInstance(
58-
cls,
59-
Seq(
60-
toExternalString('a.int.cast(StringType)),
61-
AssertNotNull('b.long, cls.getName, "b", "Long")
62-
),
63-
ObjectType(cls),
64-
propagateNull = false)
65-
compareExpressions(fromRowExpr, expected)
66-
}
46+
// int type can be up cast to string type
47+
val attrs2 = Seq('a.int, 'b.long)
48+
encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
6749
}
6850

6951
test("real type doesn't match encoder schema but they are compatible: nested product") {
7052
val encoder = ExpressionEncoder[ComplexClass]
71-
val innerCls = classOf[StringLongClass]
72-
val cls = classOf[ComplexClass]
73-
7453
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
75-
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
76-
val expected: Expression = NewInstance(
77-
cls,
78-
Seq(
79-
AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
80-
If(
81-
'b.struct('a.int, 'b.long).isNull,
82-
Literal.create(null, ObjectType(innerCls)),
83-
NewInstance(
84-
innerCls,
85-
Seq(
86-
toExternalString(
87-
GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
88-
AssertNotNull(
89-
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
90-
innerCls.getName, "b", "Long")),
91-
ObjectType(innerCls),
92-
propagateNull = false)
93-
)),
94-
ObjectType(cls),
95-
propagateNull = false)
96-
compareExpressions(fromRowExpr, expected)
54+
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
9755
}
9856

9957
test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
10058
val encoder = ExpressionEncoder.tuple(
10159
ExpressionEncoder[StringLongClass],
10260
ExpressionEncoder[Long])
103-
val cls = classOf[StringLongClass]
104-
10561
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
106-
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
107-
val expected: Expression = NewInstance(
108-
classOf[Tuple2[_, _]],
109-
Seq(
110-
NewInstance(
111-
cls,
112-
Seq(
113-
toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
114-
AssertNotNull(
115-
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
116-
cls.getName, "b", "Long")),
117-
ObjectType(cls),
118-
propagateNull = false),
119-
'b.int.cast(LongType)),
120-
ObjectType(classOf[Tuple2[_, _]]),
121-
propagateNull = false)
122-
compareExpressions(fromRowExpr, expected)
62+
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
63+
}
64+
65+
test("nullability of array type element should not fail analysis") {
66+
val encoder = ExpressionEncoder[Seq[Int]]
67+
val attrs = 'a.array(IntegerType) :: Nil
68+
69+
// It should pass analysis
70+
val bound = encoder.resolve(attrs, null).bind(attrs)
71+
72+
// If no null values appear, it should work fine
73+
bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
74+
75+
// If there is a null value, it should throw a runtime exception
76+
val e = intercept[RuntimeException] {
77+
bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
78+
}
79+
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
12380
}
12481

12582
test("the real number of fields doesn't match encoder schema: tuple encoder") {
@@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
166123
}
167124
}
168125

169-
private def toExternalString(e: Expression): Expression = {
170-
Invoke(e, "toString", ObjectType(classOf[String]), Nil)
171-
}
172-
173126
test("throw exception if real type is not compatible with encoder schema") {
174127
val msg1 = intercept[AnalysisException] {
175128
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -850,9 +850,7 @@ public void testRuntimeNullabilityCheck() {
850850
}
851851

852852
nullabilityCheck.expect(RuntimeException.class);
853-
nullabilityCheck.expectMessage(
854-
"Null value appeared in non-nullable field " +
855-
"test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
853+
nullabilityCheck.expectMessage("Null value appeared in non-nullable field");
856854

857855
{
858856
Row row = new GenericRow(new Object[] {

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,9 +553,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
553553
buildDataset(Row(Row("hello", null))).collect()
554554
}.getMessage
555555

556-
assert(message.contains(
557-
"Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
558-
))
556+
assert(message.contains("Null value appeared in non-nullable field"))
559557
}
560558

561559
test("SPARK-12478: top level null field") {

0 commit comments

Comments
 (0)