Commit f43a7f9

Add interpretedOrdering to ArrayType.
1 parent f5f074d · commit f43a7f9

4 files changed: 85 additions & 69 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   private lazy val lt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+      case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
     }

@@ -90,6 +91,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   private lazy val gt: Comparator[Any] = {
     val ordering = base.dataType match {
       case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
+      case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
     }
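
Before this change, sort_array on an array-of-arrays column failed analysis with "does not support sorting array of type array<int>" (see the assertion removed from DataFrameFunctionsSuite below). A minimal sketch of what the new match arm enables, assuming a Spark 2.x-style SparkSession bound to `spark`; the DataFrame and column names are illustrative:

// Sketch only: exercises the new ArrayType branch of SortArray's comparator.
import spark.implicits._

val df = Seq((Array(Array(2), Array(1, 3)), "x")).toDF("a", "b")
df.selectExpr("sort_array(a, true)").show(false)
// Expected under the ordering added to ArrayType in this commit: [[1, 3], [2]]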

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala

Lines changed: 27 additions & 65 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
 
 
@@ -30,76 +29,39 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
     this(ordering.map(BindReferences.bindReference(_, inputSchema)))
 
-  private def compareValue(
-      left: Any,
-      right: Any,
-      dataType: DataType,
-      direction: SortDirection): Int = {
-    if (left == null && right == null) {
-      return 0
-    } else if (left == null) {
-      return if (direction == Ascending) -1 else 1
-    } else if (right == null) {
-      return if (direction == Ascending) 1 else -1
-    } else {
-      dataType match {
-        case dt: AtomicType if direction == Ascending =>
-          return dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
-        case dt: AtomicType if direction == Descending =>
-          return dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-        case s: StructType if direction == Ascending =>
-          return s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
-        case s: StructType if direction == Descending =>
-          return s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
-        case a: ArrayType =>
-          val leftArray = left.asInstanceOf[ArrayData]
-          val rightArray = right.asInstanceOf[ArrayData]
-          val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
-          var i = 0
-          while (i < minLength) {
-            val isNullLeft = leftArray.isNullAt(i)
-            val isNullRight = rightArray.isNullAt(i)
-            if (isNullLeft && isNullRight) {
-              // Do nothing.
-            } else if (isNullLeft) {
-              return if (direction == Ascending) -1 else 1
-            } else if (isNullRight) {
-              return if (direction == Ascending) 1 else -1
-            } else {
-              val comp =
-                compareValue(
-                  leftArray.get(i, a.elementType),
-                  rightArray.get(i, a.elementType),
-                  a.elementType,
-                  direction)
-              if (comp != 0) {
-                return comp
-              }
-            }
-            i += 1
-          }
-          if (leftArray.numElements() < rightArray.numElements()) {
-            return if (direction == Ascending) -1 else 1
-          } else if (leftArray.numElements() > rightArray.numElements()) {
-            return if (direction == Ascending) 1 else -1
-          } else {
-            return 0
-          }
-        case other =>
-          throw new IllegalArgumentException(s"Type $other does not support ordered operations")
-      }
-    }
-  }
-
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
     while (i < ordering.size) {
       val order = ordering(i)
       val left = order.child.eval(a)
       val right = order.child.eval(b)
-      val comparison = compareValue(left, right, order.dataType, order.direction)
-      if (comparison != 0) {
-        return comparison
+
+      if (left == null && right == null) {
+        // Both null, continue looking.
+      } else if (left == null) {
+        return if (order.direction == Ascending) -1 else 1
+      } else if (right == null) {
+        return if (order.direction == Ascending) 1 else -1
+      } else {
+        val comparison = order.dataType match {
+          case dt: AtomicType if order.direction == Ascending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case dt: AtomicType if order.direction == Descending =>
+            dt.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case a: ArrayType if order.direction == Ascending =>
+            a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case a: ArrayType if order.direction == Descending =>
+            a.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case s: StructType if order.direction == Ascending =>
+            s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
+          case s: StructType if order.direction == Descending =>
            s.interpretedOrdering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
+          case other =>
+            throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+        }
+        if (comparison != 0) {
+          return comparison
+        }
       }
       i += 1
     }
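
The array-walking logic that used to live in compareValue is gone from this class: compare now matches on the sort key's data type once and delegates array comparison to ArrayType.interpretedOrdering, obtaining the descending variant via Ordering.reverse. A rough sketch of that delegation, assuming code compiled inside the org.apache.spark.sql package (interpretedOrdering is private[sql]) and using GenericArrayData to build test values:

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Ascending ordering for array<int>; descending is simply its reverse.
val asc = ArrayType(IntegerType).interpretedOrdering
val desc = asc.reverse

val shorter = new GenericArrayData(Array[Any](1, 2))
val longer = new GenericArrayData(Array[Any](1, 2, 3))

assert(asc.compare(shorter, longer) < 0)   // a common-prefix array sorts first ascending
assert(desc.compare(shorter, longer) > 0)  // and last descending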

sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala

Lines changed: 48 additions & 0 deletions
@@ -17,10 +17,13 @@
 
 package org.apache.spark.sql.types
 
+import org.apache.spark.sql.catalyst.util.ArrayData
 import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.DeveloperApi
 
+import scala.math.Ordering
+
 
 object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */

@@ -81,4 +84,49 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
   override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
     f(this) || elementType.existsRecursively(f)
   }
+
+  @transient
+  private[sql] lazy val interpretedOrdering: Ordering[ArrayData] = new Ordering[ArrayData] {
+    private[this] val elementOrdering: Ordering[Any] = elementType match {
+      case dt: AtomicType => dt.ordering.asInstanceOf[Ordering[Any]]
+      case a : ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case other =>
+        throw new IllegalArgumentException(s"Type $other does not support ordered operations")
+    }
+
+    def compare(x: ArrayData, y: ArrayData): Int = {
+      val leftArray = x
+      val rightArray = y
+      val minLength = scala.math.min(leftArray.numElements(), rightArray.numElements())
+      var i = 0
+      while (i < minLength) {
+        val isNullLeft = leftArray.isNullAt(i)
+        val isNullRight = rightArray.isNullAt(i)
+        if (isNullLeft && isNullRight) {
+          // Do nothing.
+        } else if (isNullLeft) {
+          return -1
+        } else if (isNullRight) {
+          return 1
+        } else {
+          val comp =
+            elementOrdering.compare(
+              leftArray.get(i, elementType),
+              rightArray.get(i, elementType))
+          if (comp != 0) {
+            return comp
+          }
+        }
+        i += 1
+      }
+      if (leftArray.numElements() < rightArray.numElements()) {
+        return -1
+      } else if (leftArray.numElements() > rightArray.numElements()) {
+        return 1
+      } else {
+        return 0
+      }
+    }
+  }
 }
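
Because elementOrdering recurses into interpretedOrdering for nested ArrayType elements (and into StructType.interpretedOrdering for struct elements), arbitrarily nested arrays become comparable: null elements sort first, elements are compared left to right, and a shorter array that is a prefix of a longer one sorts before it. A sketch of the nested case, with the same private[sql] visibility caveat as above:

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Ordering for array<array<int>>: the outer ordering hands each element
// to the inner ArrayType's interpretedOrdering.
val nested = ArrayType(ArrayType(IntegerType)).interpretedOrdering

val left = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](1)), null))
val right = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](2))))

// The first elements already differ ([1] < [2]), so lengths and nulls never come into play here.
assert(nested.compare(left, right) < 0)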

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 8 additions & 4 deletions
@@ -308,10 +308,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
         Row(null, null))
     )
 
-    val df2 = Seq((Array[Array[Int]](Array(2)), "x")).toDF("a", "b")
-    assert(intercept[AnalysisException] {
-      df2.selectExpr("sort_array(a)").collect()
-    }.getMessage().contains("does not support sorting array of type array<int>"))
+    val df2 = Seq((Array[Array[Int]](Array(2), Array(1), Array(2, 4), null), "x")).toDF("a", "b")
+    checkAnswer(
+      df2.selectExpr("sort_array(a, true)", "sort_array(a, false)"),
+      Seq(
+        Row(
+          Seq[Seq[Int]](null, Seq(1), Seq(2), Seq(2, 4)),
+          Seq[Seq[Int]](Seq(2, 4), Seq(2), Seq(1), null)))
+    )
 
     val df3 = Seq(("xxx", "x")).toDF("a", "b")
     assert(intercept[AnalysisException] {
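
The same behaviour is reachable from plain SQL without building a DataFrame first; an illustrative one-liner, again assuming a SparkSession bound to `spark`:

// sort_array defaults to ascending; pass false as the second argument for descending.
spark.sql("SELECT sort_array(array(array(2), array(1, 3))) AS sorted").show(false)
// Expected: [[1, 3], [2]]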
