From 0a37b9611e542e57183c66c7b713f1bf3bc415a7 Mon Sep 17 00:00:00 2001
From: Jane Wang
Date: Wed, 26 Apr 2017 11:28:44 -0700
Subject: [PATCH] Add array_unique UDF

---
 .../catalyst/analysis/FunctionRegistry.scala       |  1 +
 .../expressions/collectionOperations.scala         | 39 +++++++++++++++++++
 .../spark/sql/catalyst/util/ArrayData.scala        |  4 ++
 .../CollectionExpressionsSuite.scala               | 14 +++++++
 4 files changed, 58 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e1d83a86f99dc..b33a3e3de6eb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -361,6 +361,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayUnique]("array_unique"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c863ba434120d..2ee5dd9c7fd10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -287,3 +287,42 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def prettyName: String = "array_contains"
 }
+
+/**
+ * Returns an array with all duplicate elements removed from the input array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array) - Returns an array with all duplicate elements removed from the input array.",
+  extended =
+    """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 2, 3, 4, 3, 6));
+       [1,2,3,4,6]
+    """)
+case class ArrayUnique(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def nullSafeEval(array: Any): Any = {
+
+    val elementType = child.dataType.asInstanceOf[ArrayType].elementType
+    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
+    new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      val elementType = child.dataType.asInstanceOf[ArrayType].elementType
+      val dataTypeClass = elementType.getClass.getName.stripSuffix("$")
+      val arrayDataClass = classOf[GenericArrayData].getName.stripSuffix("$")
+      s"${ev.value} = new ${arrayDataClass}(($c).distinct(new ${dataTypeClass}()));"
+    })
+  }
+
+  override def prettyName: String = "array_unique"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 9beef41d639f3..8b50765ab1f69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -163,4 +163,8 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
       i += 1
     }
   }
+
+  def distinct(elementType: DataType): Array[AnyRef] = {
+    toObjectArray(elementType).distinct
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e4b3a27..62bcfa24c7c98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -105,4 +105,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
+
+  test("Array Unique") {
+    val a0 = Literal.create(Seq(2, 1, 3, 4, 1, 2, 6), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq("b", "a", "c", "b", "a", "d"), ArrayType(StringType))
+    val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
+    val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
+
+    checkEvaluation(ArrayUnique(a0), Seq(2, 1, 3, 4, 6))
+    checkEvaluation(ArrayUnique(a1), Seq())
+    checkEvaluation(ArrayUnique(a2), Seq("b", "a", "c", "d"))
+    checkEvaluation(ArrayUnique(a3), Seq("b", null, "a"))
+    checkEvaluation(ArrayUnique(a4), Seq(null))
+  }
 }
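
For reference, below is a minimal end-to-end usage sketch. It is not part of the patch itself: it assumes a Spark build that already includes this change, and the object name and local-mode session setup are illustrative only.

import org.apache.spark.sql.SparkSession

// Hypothetical driver program exercising the new SQL function (not part of this patch).
object ArrayUniqueExample {
  def main(args: Array[String]): Unit = {
    // Local-mode session purely for demonstration; any SparkSession would do.
    val spark = SparkSession.builder()
      .appName("array-unique-example")
      .master("local[*]")
      .getOrCreate()

    // array_unique is registered in FunctionRegistry above, so it resolves directly in SQL.
    spark.sql("SELECT array_unique(array(1, 2, 2, 3, 4, 3, 6)) AS deduped").show(false)
    // Per the ExpressionDescription example, this should print [1, 2, 3, 4, 6].

    spark.stop()
  }
}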