@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919import java .util .Comparator
2020
2121import scala .collection .mutable
22+
2223import org .apache .spark .sql .catalyst .InternalRow
2324import org .apache .spark .sql .catalyst .analysis .{TypeCheckResult , TypeCoercion }
2425import org .apache .spark .sql .catalyst .expressions .ArraySortLike .NullOrder
@@ -2386,22 +2387,20 @@ case class ArrayDistinct(child: Expression)
23862387 val i = ctx.freshName(" i" )
23872388 val j = ctx.freshName(" j" )
23882389 val hs = ctx.freshName(" hs" )
2390+ val foundNullElement = ctx.freshName(" foundNullElement" )
23892391 val distinctArrayLen = ctx.freshName(" distinctArrayLen" )
23902392 val getValue = CodeGenerator .getValue(array, elementType, i)
23912393 val openHashSet = classOf [OpenHashSet [_]].getName
23922394 val classTag = s " scala.reflect.ClassTag $$ .MODULE $$ .Object() "
23932395 s """
23942396 |int $distinctArrayLen = 0;
2397+ |boolean $foundNullElement = false;
23952398 | $openHashSet $hs = new $openHashSet( $classTag);
23962399 |for (int $i = 0; $i < $array.numElements(); $i++) {
23972400 | if ( $array.isNullAt( $i)) {
2398- | int $j;
2399- | for ( $j = 0; $j < $i; $j ++) {
2400- | if ( $array.isNullAt( $j))
2401- | break;
2402- | }
2403- | if ( $i == $j) {
2401+ | if (!( $foundNullElement)) {
24042402 | $distinctArrayLen = $distinctArrayLen + 1;
2403+ | $foundNullElement = true;
24052404 | }
24062405 | }
24072406 | else {
@@ -2427,6 +2426,7 @@ case class ArrayDistinct(child: Expression)
24272426 val i = ctx.freshName(" i" )
24282427 val j = ctx.freshName(" j" )
24292428 val pos = ctx.freshName(" pos" )
2429+ val foundNullElement = ctx.freshName(" foundNullElement" )
24302430 val genericArrayData = classOf [GenericArrayData ].getName
24312431 val getValue = CodeGenerator .getValue(inputArray, elementType, i)
24322432
@@ -2436,17 +2436,14 @@ case class ArrayDistinct(child: Expression)
24362436 s """
24372437 |Object[] $distinctArr = new Object[ $newArraySize];
24382438 |int $pos = 0;
2439+ |boolean $foundNullElement = false;
24392440 | $openHashSet $hs = new $openHashSet( $classTag);
24402441 |for (int $i = 0; $i < $inputArray.numElements(); $i++) {
24412442 | if ( $inputArray.isNullAt( $i)) {
2442- | int $j;
2443- | for ( $j = 0; $j < $i; $j ++) {
2444- | if ( $inputArray.isNullAt( $j))
2445- | break;
2446- | }
2447- | if ( $i == $j) {
2443+ | if (!( $foundNullElement)) {
24482444 | $distinctArr[ $pos] = null;
24492445 | $pos = $pos + 1;
2446+ | $foundNullElement = true;
24502447 | }
24512448 | }
24522449 | else {
@@ -2465,17 +2462,14 @@ case class ArrayDistinct(child: Expression)
24652462 s """
24662463 | ${ctx.createUnsafeArray(distinctArr, newArraySize, elementType, s " $prettyName failed. " )}
24672464 |int $pos = 0;
2465+ |boolean $foundNullElement = false;
24682466 | $openHashSet $hs = new $openHashSet( $classTag);
24692467 |for (int $i = 0; $i < $inputArray.numElements(); $i++) {
24702468 | if ( $inputArray.isNullAt( $i)) {
2471- | int $j;
2472- | for ( $j = 0; $j < $i; $j ++) {
2473- | if ( $inputArray.isNullAt( $j))
2474- | break;
2475- | }
2476- | if ( $i == $j) {
2469+ | if (!( $foundNullElement)) {
24772470 | $distinctArr.setNullAt( $pos);
24782471 | $pos = $pos + 1;
2472+ | $foundNullElement = true;
24792473 | }
24802474 | }
24812475 | else {
0 commit comments