Skip to content

Commit 36d235a

Browse files
committed
1. fix GetArrayItem nullablity issue.
2. move DataFrameSuite code into array.sql and ansi/array.sql 3. add numElements to exception message. 4. other code refine. Change-Id: Ieb322ed7b036fc3322fd3b814c8508bfef266378
1 parent 6729d7b commit 36d235a

File tree

12 files changed

+364
-135
lines changed

12 files changed

+364
-135
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,24 +1986,9 @@ case class ElementAt(
19861986
}
19871987
}
19881988

1989-
override def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
1990-
if (ordinal.foldable && !ordinal.nullable) {
1991-
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
1992-
child match {
1993-
case CreateArray(ar, _) =>
1994-
nullability(ar, intOrdinal)
1995-
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
1996-
nullability(elements, intOrdinal) || field.nullable
1997-
case _ =>
1998-
true
1999-
}
2000-
} else {
2001-
if (failOnError) arrayContainsNull else true
2002-
}
2003-
}
2004-
20051989
override def nullable: Boolean = left.dataType match {
2006-
case _: ArrayType => computeNullabilityFromArray(left, right)
1990+
case _: ArrayType =>
1991+
computeNullabilityFromArray(left, right, failOnError, nullability)
20071992
case _: MapType => true
20081993
}
20091994

@@ -2016,7 +2001,8 @@ case class ElementAt(
20162001
val index = ordinal.asInstanceOf[Int]
20172002
if (array.numElements() < math.abs(index)) {
20182003
if (failOnError) {
2019-
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
2004+
throw new ArrayIndexOutOfBoundsException(
2005+
s"Invalid index: $index, numElements: ${array.numElements()}")
20202006
} else {
20212007
null
20222008
}
@@ -2055,7 +2041,10 @@ case class ElementAt(
20552041
}
20562042

20572043
val failOnErrorBranch = if (failOnError) {
2058-
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
2044+
s"""throw new ArrayIndexOutOfBoundsException(
2045+
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
2046+
|);
2047+
""".stripMargin
20592048
} else {
20602049
s"${ev.isNull} = true;"
20612050
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,15 +240,25 @@ case class GetArrayItem(
240240

241241
override def left: Expression = child
242242
override def right: Expression = ordinal
243-
override def nullable: Boolean = computeNullabilityFromArray(left, right)
243+
override def nullable: Boolean =
244+
computeNullabilityFromArray(left, right, failOnError, nullability)
244245
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
245246

247+
private def nullability(elements: Seq[Expression], ordinal: Int): Boolean = {
248+
if (ordinal >= 0 && ordinal < elements.length) {
249+
elements(ordinal).nullable
250+
} else {
251+
!failOnError
252+
}
253+
}
254+
246255
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
247256
val baseValue = value.asInstanceOf[ArrayData]
248257
val index = ordinal.asInstanceOf[Number].intValue()
249258
if (index >= baseValue.numElements() || index < 0) {
250259
if (failOnError) {
251-
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
260+
throw new ArrayIndexOutOfBoundsException(
261+
s"Invalid index: $index, numElements: ${baseValue.numElements()}")
252262
} else {
253263
null
254264
}
@@ -272,7 +282,10 @@ case class GetArrayItem(
272282
}
273283

274284
val failOnErrorBranch = if (failOnError) {
275-
s"""throw new ArrayIndexOutOfBoundsException("Invalid index: " + $index);""".stripMargin
285+
s"""throw new ArrayIndexOutOfBoundsException(
286+
| "Invalid index: " + $index + ", numElements: " + $eval1.numElements()
287+
|);
288+
""".stripMargin
276289
} else {
277290
s"${ev.isNull} = true;"
278291
}
@@ -295,20 +308,24 @@ case class GetArrayItem(
295308
trait GetArrayItemUtil {
296309

297310
/** `Null` is returned for invalid ordinals. */
298-
protected def computeNullabilityFromArray(child: Expression, ordinal: Expression): Boolean = {
311+
protected def computeNullabilityFromArray(
312+
child: Expression,
313+
ordinal: Expression,
314+
failOnError: Boolean,
315+
nullability: (Seq[Expression], Int) => Boolean): Boolean = {
316+
val arrayContainsNull = child.dataType.asInstanceOf[ArrayType].containsNull
299317
if (ordinal.foldable && !ordinal.nullable) {
300318
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
301319
child match {
302-
case CreateArray(ar, _) if intOrdinal < ar.length =>
303-
ar(intOrdinal).nullable
304-
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
305-
if intOrdinal < elements.length =>
306-
elements(intOrdinal).nullable || field.nullable
320+
case CreateArray(ar, _) =>
321+
nullability(ar, intOrdinal)
322+
case GetArrayStructFields(CreateArray(elements, _), field, _, _, _) =>
323+
nullability(elements, intOrdinal) || field.nullable
307324
case _ =>
308325
true
309326
}
310327
} else {
311-
true
328+
if (failOnError) arrayContainsNull else true
312329
}
313330
}
314331
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ case class Elt(
286286
val index = indexObj.asInstanceOf[Int]
287287
if (index <= 0 || index > inputExprs.length) {
288288
if (failOnError) {
289-
throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
289+
throw new ArrayIndexOutOfBoundsException(
290+
s"Invalid index: $index, numElements: ${inputExprs.length}")
290291
} else {
291292
null
292293
}
@@ -340,7 +341,8 @@ case class Elt(
340341
val failOnErrorBranch = if (failOnError) {
341342
s"""
342343
|if (!$indexMatched) {
343-
| throw new ArrayIndexOutOfBoundsException("Invalid index: " + ${index.value});
344+
| throw new ArrayIndexOutOfBoundsException(
345+
| "Invalid index: " + ${index.value} + ", numElements: " + ${inputExprs.length});
344346
|}
345347
""".stripMargin
346348
} else {

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2147,7 +2147,7 @@ object SQLConf {
21472147
"throw an exception at runtime if the inputs to a SQL operator/function are invalid, " +
21482148
"e.g. overflow in arithmetic operations, out-of-range index when accessing array elements. " +
21492149
"2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in " +
2150-
"the SQL parser. 3. Spark will returns null for null input for function `size`.")
2150+
"the SQL parser. 3. Spark will return NULL for null input for function `size`.")
21512151
.version("3.0.0")
21522152
.booleanConf
21532153
.createWithDefault(false)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,21 +1888,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
18881888
Seq(Date.valueOf("2018-01-01")))
18891889
}
18901890

1891-
test("SPARK-33391: element_at ArrayIndexOutOfBoundsException") {
1891+
test("SPARK-33386: element_at ArrayIndexOutOfBoundsException") {
18921892
Seq(true, false).foreach { ansiEnabled =>
18931893
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
18941894
val array = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
18951895
var expr: Expression = ElementAt(array, Literal(5))
18961896
if (ansiEnabled) {
1897-
val errMsg = "Invalid index: 5"
1897+
val errMsg = "Invalid index: 5, numElements: 3"
18981898
checkExceptionInExpression[Exception](expr, errMsg)
18991899
} else {
19001900
checkEvaluation(expr, null)
19011901
}
19021902

19031903
expr = ElementAt(array, Literal(-5))
19041904
if (ansiEnabled) {
1905-
val errMsg = "Invalid index: -5"
1905+
val errMsg = "Invalid index: -5, numElements: 3"
19061906
checkExceptionInExpression[Exception](expr, errMsg)
19071907
} else {
19081908
checkEvaluation(expr, null)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
6262
checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1))
6363
}
6464

65-
test("SPARK-33391: GetArrayItem ArrayIndexOutOfBoundsException") {
65+
test("SPARK-33386: GetArrayItem ArrayIndexOutOfBoundsException") {
6666
Seq(true, false).foreach { ansiEnabled =>
6767
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
6868
val array = Literal.create(Seq("a", "b"), ArrayType(StringType))
6969

7070
if (ansiEnabled) {
7171
checkExceptionInExpression[Exception](
7272
GetArrayItem(array, Literal(5)),
73-
"Invalid index: 5"
73+
"Invalid index: 5, numElements: 2"
7474
)
7575

7676
checkExceptionInExpression[Exception](
7777
GetArrayItem(array, Literal(-1)),
78-
"Invalid index: -1"
78+
"Invalid index: -1, numElements: 2"
7979
)
8080
} else {
8181
checkEvaluation(GetArrayItem(array, Literal(5)), null)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -969,28 +969,28 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
969969
Sentences(Literal("\"quote"), Literal("\"quote"), Literal("\"quote")) :: Nil)
970970
}
971971

972-
test("SPARK-33391: elt ArrayIndexOutOfBoundsException") {
972+
test("SPARK-33386: elt ArrayIndexOutOfBoundsException") {
973973
Seq(true, false).foreach { ansiEnabled =>
974974
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
975975
var expr: Expression = Elt(Seq(Literal(4), Literal("123"), Literal("456")))
976976
if (ansiEnabled) {
977-
val errMsg = "Invalid index: 4"
977+
val errMsg = "Invalid index: 4, numElements: 2"
978978
checkExceptionInExpression[Exception](expr, errMsg)
979979
} else {
980980
checkEvaluation(expr, null)
981981
}
982982

983983
expr = Elt(Seq(Literal(0), Literal("123"), Literal("456")))
984984
if (ansiEnabled) {
985-
val errMsg = "Invalid index: 0"
985+
val errMsg = "Invalid index: 0, numElements: 2"
986986
checkExceptionInExpression[Exception](expr, errMsg)
987987
} else {
988988
checkEvaluation(expr, null)
989989
}
990990

991991
expr = Elt(Seq(Literal(-1), Literal("123"), Literal("456")))
992992
if (ansiEnabled) {
993-
val errMsg = "Invalid index: -1"
993+
val errMsg = "Invalid index: -1, numElements: 2"
994994
checkExceptionInExpression[Exception](expr, errMsg)
995995
} else {
996996
checkEvaluation(expr, null)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
--IMPORT array.sql

sql/core/src/test/resources/sql-tests/inputs/array.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,15 @@ select
9090
size(date_array),
9191
size(timestamp_array)
9292
from primitive_arrays;
93+
94+
-- index out of range for array elements
95+
select element_at(array(1, 2, 3), 5);
96+
select element_at(array(1, 2, 3), -5);
97+
select element_at(array(1, 2, 3), 0);
98+
99+
select elt(4, '123', '456');
100+
select elt(0, '123', '456');
101+
select elt(-1, '123', '456');
102+
103+
select array(1, 2, 3)[5];
104+
select array(1, 2, 3)[-1];

0 commit comments

Comments
 (0)