Skip to content

Commit d127ac4

Browse files
committed
address comments(3)
1 parent 1219626 commit d127ac4

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919
import java.util.Comparator
2020

2121
import scala.collection.mutable
22+
2223
import org.apache.spark.sql.catalyst.InternalRow
2324
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2425
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
@@ -2386,22 +2387,20 @@ case class ArrayDistinct(child: Expression)
23862387
val i = ctx.freshName("i")
23872388
val j = ctx.freshName("j")
23882389
val hs = ctx.freshName("hs")
2390+
val foundNullElement = ctx.freshName("foundNullElement")
23892391
val distinctArrayLen = ctx.freshName("distinctArrayLen")
23902392
val getValue = CodeGenerator.getValue(array, elementType, i)
23912393
val openHashSet = classOf[OpenHashSet[_]].getName
23922394
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
23932395
s"""
23942396
|int $distinctArrayLen = 0;
2397+
|boolean $foundNullElement = false;
23952398
|$openHashSet $hs = new $openHashSet($classTag);
23962399
|for (int $i = 0; $i < $array.numElements(); $i++) {
23972400
| if ($array.isNullAt($i)) {
2398-
| int $j;
2399-
| for ($j = 0; $j < $i; $j ++) {
2400-
| if ($array.isNullAt($j))
2401-
| break;
2402-
| }
2403-
| if ($i == $j) {
2401+
| if (!($foundNullElement)) {
24042402
| $distinctArrayLen = $distinctArrayLen + 1;
2403+
| $foundNullElement = true;
24052404
| }
24062405
| }
24072406
| else {
@@ -2427,6 +2426,7 @@ case class ArrayDistinct(child: Expression)
24272426
val i = ctx.freshName("i")
24282427
val j = ctx.freshName("j")
24292428
val pos = ctx.freshName("pos")
2429+
val foundNullElement = ctx.freshName("foundNullElement")
24302430
val genericArrayData = classOf[GenericArrayData].getName
24312431
val getValue = CodeGenerator.getValue(inputArray, elementType, i)
24322432

@@ -2436,17 +2436,14 @@ case class ArrayDistinct(child: Expression)
24362436
s"""
24372437
|Object[] $distinctArr = new Object[$newArraySize];
24382438
|int $pos = 0;
2439+
|boolean $foundNullElement = false;
24392440
|$openHashSet $hs = new $openHashSet($classTag);
24402441
|for (int $i = 0; $i < $inputArray.numElements(); $i++) {
24412442
| if ($inputArray.isNullAt($i)) {
2442-
| int $j;
2443-
| for ($j = 0; $j < $i; $j ++) {
2444-
| if ($inputArray.isNullAt($j))
2445-
| break;
2446-
| }
2447-
| if ($i == $j) {
2443+
| if (!($foundNullElement)) {
24482444
| $distinctArr[$pos] = null;
24492445
| $pos = $pos + 1;
2446+
| $foundNullElement = true;
24502447
| }
24512448
| }
24522449
| else {
@@ -2465,17 +2462,14 @@ case class ArrayDistinct(child: Expression)
24652462
s"""
24662463
|${ctx.createUnsafeArray(distinctArr, newArraySize, elementType, s" $prettyName failed.")}
24672464
|int $pos = 0;
2465+
|boolean $foundNullElement = false;
24682466
|$openHashSet $hs = new $openHashSet($classTag);
24692467
|for (int $i = 0; $i < $inputArray.numElements(); $i++) {
24702468
| if ($inputArray.isNullAt($i)) {
2471-
| int $j;
2472-
| for ($j = 0; $j < $i; $j ++) {
2473-
| if ($inputArray.isNullAt($j))
2474-
| break;
2475-
| }
2476-
| if ($i == $j) {
2469+
| if (!($foundNullElement)) {
24772470
| $distinctArr.setNullAt($pos);
24782471
| $pos = $pos + 1;
2472+
| $foundNullElement = true;
24792473
| }
24802474
| }
24812475
| else {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,13 +771,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
771771
val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType))
772772
val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
773773
val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
774-
val a3 = Literal.create(Seq("b", null, "a", "a"), ArrayType(StringType))
775-
val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
774+
val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType))
775+
val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType))
776+
val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
777+
val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234),
778+
ArrayType(DoubleType))
779+
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
780+
ArrayType(FloatType))
776781

777782
checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
778783
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
779784
checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c"))
780785
checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a"))
781786
checkEvaluation(new ArrayDistinct(a4), Seq(null))
787+
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
788+
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
789+
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
782790
}
783791
}

0 commit comments

Comments
 (0)