Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,6 +1999,20 @@ def array_remove(col, element):
return Column(sc._jvm.functions.array_remove(_to_java_column(col), element))


@since(2.4)
def array_distinct(col):
"""
Collection function: removes duplicate values from the array.
:param col: name of column or expression

>>> df = spark.createDataFrame([([1, 2, 3, 2],), ([4, 5, 5, 4],)], ['data'])
>>> df.select(array_distinct(df.data)).collect()
[Row(array_distinct(data)=[1, 2, 3]), Row(array_distinct(data)=[4, 5])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.array_distinct(_to_java_column(col)))


@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ object FunctionRegistry {
expression[Flatten]("flatten"),
expression[ArrayRepeat]("array_repeat"),
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
CreateStruct.registryEntry,

// mask functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -2355,3 +2356,281 @@ case class ArrayRemove(left: Expression, right: Expression)

override def prettyName: String = "array_remove"
}

/**
* Removes duplicate values from the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Removes duplicate values from the array.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, null, 3));
[1,2,3,null]
""", since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def dataType: DataType = child.dataType

@transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

@transient private lazy val ordering: Ordering[Any] =
TypeUtils.getInterpretedOrdering(elementType)

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case f: TypeCheckResult.TypeCheckFailure => f
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
}
}

@transient private lazy val elementTypeSupportEquals = elementType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

override def nullSafeEval(array: Any): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementTypeSupportEquals) {
new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
var foundNullElement = false
var pos = 0
for (i <- 0 until data.length) {
if (data(i) == null) {
if (!foundNullElement) {
foundNullElement = true
pos = pos + 1
}
} else {
var j = 0
var done = false
while (j <= i && !done) {
if (data(j) != null && ordering.equiv(data(j), data(i))) {
done = true
}
j = j + 1
}
if (i == j - 1) {
pos = pos + 1
}
}
}
new GenericArrayData(data.slice(0, pos))
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (array) => {
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
val getValue1 = CodeGenerator.getValue(array, elementType, i)
val getValue2 = CodeGenerator.getValue(array, elementType, j)
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val hs = ctx.freshName("hs")
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
if (elementTypeSupportEquals) {
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $array.numElements(); $i ++) {
| if ($array.isNullAt($i)) {
| $foundNullElement = true;
| } else {
| $hs.add($getValue1);
| }
|}
|$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
} else {
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $array.numElements(); $i ++) {
| if ($array.isNullAt($i)) {
| if (!($foundNullElement)) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| $foundNullElement = true;
| }
| } else {
| int $j;
| for ($j = 0; $j < $i; $j ++) {
| if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) {
| break;
| }
| }
| if ($i == $j) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| }
| }
|}
|
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
}
})
}

private def setNull(
isPrimitive: Boolean,
foundNullElement: String,
distinctArray: String,
pos: String): String = {
val setNullValue =
if (!isPrimitive) {
s"$distinctArray[$pos] = null";
} else {
s"$distinctArray.setNullAt($pos)";
}

s"""
|if (!($foundNullElement)) {
| $setNullValue;
| $pos = $pos + 1;
| $foundNullElement = true;
|}
""".stripMargin
}

private def setNotNullValue(isPrimitive: Boolean,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
if (!isPrimitive) {
s"$distinctArray[$pos] = $getValue1";
} else {
s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)";
}
}

private def setValueForFastEval(
isPrimitive: Boolean,
hs: String,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|if (!($hs.contains($getValue1))) {
| $hs.add($getValue1);
| $setValue;
| $pos = $pos + 1;
|}
""".stripMargin
}

private def setValueForBruteForceEval(
isPrimitive: Boolean,
i: String,
j: String,
inputArray: String,
distinctArray: String,
pos: String,
getValue1: String,
isEqual: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|int $j;
|for ($j = 0; $j < $i; $j ++) {
| if (!$inputArray.isNullAt($j) && $isEqual) {
| break;
| }
|}
|if ($i == $j) {
| $setValue;
| $pos = $pos + 1;
|}
""".stripMargin
}

def genCodeForResult(
ctx: CodegenContext,
ev: ExprCode,
inputArray: String,
size: String): String = {
val distinctArray = ctx.freshName("distinctArray")
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val pos = ctx.freshName("pos")
val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
val foundNullElement = ctx.freshName("foundNullElement")
val hs = ctx.freshName("hs")
val openHashSet = classOf[OpenHashSet[_]].getName
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayClass = classOf[GenericArrayData].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
val setNullForNonPrimitive =
setNull(false, foundNullElement, distinctArray, pos)
if (elementTypeSupportEquals) {
val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
} else {
val setValueForBruteForce = setValueForBruteForceEval(
false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForBruteForce;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
}
} else {
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos)
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
val setValueForFast =
setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
|int $pos = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = $distinctArray;
""".stripMargin
}
}

override def prettyName: String = "array_distinct"
}
Original file line number Diff line number Diff line change
Expand Up @@ -766,4 +766,49 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayRemove(c1, dataToRemove2), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayRemove(c2, dataToRemove2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}

test("Array Distinct") {
val a0 = Literal.create(Seq(2, 1, 2, 3, 4, 4, 5), ArrayType(IntegerType))
val a1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a", "a", "c", "b"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a", null, "a", null), ArrayType(StringType))
val a4 = Literal.create(Seq(null, null, null), ArrayType(NullType))
val a5 = Literal.create(Seq(true, false, false, true), ArrayType(BooleanType))
val a6 = Literal.create(Seq(1.123, 0.1234, 1.121, 1.123, 1.1230, 1.121, 0.1234),
ArrayType(DoubleType))
val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f),
ArrayType(FloatType))

checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5))
checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer])
checkEvaluation(new ArrayDistinct(a2), Seq("b", "a", "c"))
checkEvaluation(new ArrayDistinct(a3), Seq("b", null, "a"))
checkEvaluation(new ArrayDistinct(a4), Seq(null))
checkEvaluation(new ArrayDistinct(a5), Seq(true, false))
checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121))
checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f))
Copy link
Member

@kiszk kiszk May 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add test cases with complex types (e.g. Array[Binary] or others)? See #21361.
cc @ueshin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do. Thanks!


// complex data types
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2),
Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
ArrayType(BinaryType))
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), null, Array[Byte](1, 2),
null, Array[Byte](5, 6), null), ArrayType(BinaryType))

checkEvaluation(ArrayDistinct(b0), Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)))
checkEvaluation(ArrayDistinct(b1), Seq[Array[Byte]](Array[Byte](2, 1), null))
checkEvaluation(ArrayDistinct(b2), Seq[Array[Byte]](Array[Byte](5, 6), null,
Array[Byte](1, 2)))

val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](1, 2),
Seq[Int](3, 4), Seq[Int](1, 2)), ArrayType(ArrayType(IntegerType)))
val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
ArrayType(ArrayType(IntegerType)))
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
}
}
7 changes: 7 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3189,6 +3189,13 @@ object functions {
ArrayRemove(column.expr, Literal(element))
}

/**
* Removes duplicate values from the array.
* @group collection_funcs
* @since 2.4.0
*/
def array_distinct(e: Column): Column = withExpr { ArrayDistinct(e.expr) }

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Loading