diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 6fc154f8debcf..b3751a3de63b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -363,6 +363,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayIntersect]("array_intersect"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c863ba434120d..7dc6295d48055 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
+import java.util.{ArrayList, Arrays, List => JList}
 import java.util.Comparator
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -287,3 +288,126 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def prettyName: String = "array_contains"
 }
+
+@ExpressionDescription(
+  usage = "_FUNC_(array, array, ...) - Returns the intersection of multiple arrays.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(3, 4), array(0, 1, 3));
+       array(3)
+  """)
+case class ArrayIntersect(children: Seq[Expression]) extends Expression {
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val types = children.map(_.dataType)
+    if (types.forall(t => t.isInstanceOf[NullType] || t.isInstanceOf[ArrayType])) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $prettyName should be an array type, but it's " +
+          types.map(_.simpleString).mkString("[", ", ", "]"))
+    }
+  }
+
+  override def dataType: DataType = {
+    children.headOption.map(_.dataType).getOrElse(NullType)
+  }
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def eval(input: InternalRow): Any = {
+    if (nullable) {
+      null
+    } else {
+      val arrays = children.map(_.eval(input).asInstanceOf[ArrayData].array)
+      var results = arrays.head
+      arrays.tail.foreach {
+        array => results = results.filter(elem => array.contains(elem))
+      }
+      new GenericArrayData(results)
+    }
+  }
+
+  private def doGenJavaArray(
+      ctx: CodegenContext,
+      arrayCodeType: (ExprCode, DataType)): (String, String) = {
+    val objArrayName = ctx.freshName("array")
+    val tmpIndex = ctx.freshName("index")
+
+    val (ev, arrayDataType) = arrayCodeType
+    val elemDataType = arrayDataType.asInstanceOf[ArrayType].elementType
+    val boxedJavaDataType = ctx.boxedType(elemDataType)
+    val getValueCode = ctx.getValue(ev.value, elemDataType, tmpIndex)
+
+    (objArrayName,
+      s"""
+        ${ev.code}
+        ${boxedJavaDataType}[] ${objArrayName} = new ${boxedJavaDataType}[${ev.value}.numElements()];
+        for (int ${tmpIndex} = 0; ${tmpIndex} < ${ev.value}.numElements(); ${tmpIndex}++) {
+          ${objArrayName}[${tmpIndex}] = ${getValueCode};
+        }
+      """.stripMargin
+    )
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
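+    // Generate each child's evaluation code once and pair it with the child's data type,
+    // so the element type is known when the ArrayData is copied into a boxed Java array.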
+    val arraysCode = children.map(e => (
+      e.genCode(ctx),
+      e.dataType))
+
+    val arrayDataName = ctx.freshName("arrayData")
+    val resultsArrayListName = ctx.freshName("resultArrayList")
+
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayListClass = classOf[ArrayList[Any]].getName
+    val listClass = classOf[JList[Any]].getName
+    val arraysClass = classOf[Arrays].getName
+
+    if (nullable) {
+      ev.copy(code = s"""
+        boolean ${ev.isNull} = true;
+        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      """)
+    } else {
+      val (resultsName, genResultsCode) = doGenJavaArray(ctx, arraysCode.head)
+      val setupResultsCode =
+        s"""
+          ${arrayListClass} ${resultsArrayListName} = new ${arrayListClass}();
+          ${genResultsCode}
+          ${resultsArrayListName}.addAll(${arraysClass}.asList(${resultsName}));
+        """.stripMargin
+
+      val intersectArraysCode = arraysCode.tail.map {
+        arrayCode => {
+          val tmpListName = ctx.freshName("array")
+          val (arrayTmpName, genArrayTmpCode) = doGenJavaArray(ctx, arrayCode)
+          s"""
+            ${genArrayTmpCode}
+            ${listClass} ${tmpListName} = ${arraysClass}.asList(${arrayTmpName});
+            ${resultsArrayListName}.retainAll(${tmpListName});
+          """.stripMargin
+        }
+      }
+
+      val resultsAsArrayDataCode =
+        s"""
+          final ArrayData ${arrayDataName} = new ${genericArrayClass}(${resultsArrayListName});
+        """.stripMargin
+
+      ev.copy(
+        code = setupResultsCode +
+          ctx.splitExpressions(
+            intersectArraysCode, "apply",
+            ("InternalRow", ctx.INPUT_ROW) :: (arrayListClass, resultsArrayListName) :: Nil) +
+          resultsAsArrayDataCode,
+        value = arrayDataName,
+        isNull = "false")
+    }
+  }
+
+  override def prettyName: String = "array_intersect"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 020687e4b3a27..e36bf311ecccc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -105,4 +105,117 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
+
+  test("Array intersect") {
+    val a0 = Literal.create(1, IntegerType)
+    val a1 = Literal.create(2, IntegerType)
+    val a2 = Literal.create(3, IntegerType)
+    val a3 = Literal.create(4, IntegerType)
+
+    val b0 = Literal.create(1L, LongType)
+    val b2 = Literal.create(3L, LongType)
+
+    val c0 = Literal.create(1.0, DoubleType)
+    val d0 = Literal.create("1", StringType)
+
+    val nullLiteral = Literal.create(null)
+
+    checkEvaluation(ArrayIntersect(Seq(nullLiteral)), null)
+
+    checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq()))), Seq())
+
+    checkEvaluation(ArrayIntersect(Seq(
+      CreateArray(Seq(a0, a1)), CreateArray(Seq(a2)), CreateArray(Seq(a3)))), Seq())
+
+    checkEvaluation(ArrayIntersect(Seq(
+      CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)))), Seq(1))
+
+    checkEvaluation(ArrayIntersect(Seq(
+      CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)), CreateArray(Seq(a0)))), Seq(1))
+
+    checkEvaluation(ArrayIntersect(Seq(
+      ArrayIntersect(Seq(CreateArray(Seq(a0, a1)), CreateArray(Seq(a2)))),
+      CreateArray(Seq(a0)))), Seq())
+
+    checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)),
+      CreateArray(Seq(Cast(b0, IntegerType), Cast(b2, IntegerType))))), Seq(1))
+
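+    // Duplicates from the first array are preserved when the element also appears in every
+    // other array, matching the java.util.List.retainAll semantics used by the codegen path.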
+ checkEvaluation(ArrayIntersect( + Seq(CreateArray(Seq(a0, a0, a1, a3)), CreateArray(Seq(a0, a2, a3, a3)), + CreateArray(Seq(a0, a1, a3)))), Seq(1, 1, 4)) + + checkEvaluation(ArrayIntersect(Seq(nullLiteral, CreateArray(Seq(a0, a2, a3, a3)), + CreateArray(Seq(a0, a1, a3)))), null) + + checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)), CreateArray(Seq(a0)))), Seq(1)) + + checkEvaluation(If(LessThan(Rand(0L), c0), a0, a0), 1) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a0, a1, a3)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2, a3, a3)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2, a3)))), Seq(1, 1, 4)) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0))))), Seq(1)) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), Seq()) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(a0, a1)), CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0))))), Seq(1)) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(a0, a1)), CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), Seq()) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1)), CreateArray(Seq(a0)))), Seq(1)) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(a0, a1, a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1, a2)), + CreateArray(Seq(a0, a1, a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)))), + Seq(1, 3)) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(d0, Cast(a1, StringType), Cast(a2, StringType))), + CreateArray(Seq(a1, a2)))), + Seq()) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(a0, a1, a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a1, a2)), + CreateArray(Seq(a0, a1, a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0), Cast(a2, StringType))))), + Seq()) + + checkEvaluation(ArrayIntersect(Seq( + nullLiteral, + CreateArray(Seq(a0, a1, a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), + null) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))), + CreateArray(Seq(If(LessThan(Rand(0L), c0), a0, a0), a2)), + CreateArray(Seq(If(LessThan(Rand(0L), c0), d0, d0))))), + Seq()) + + checkEvaluation(ArrayIntersect(Seq(CreateArray(Seq(a0, a1)), nullLiteral)), null) + + checkEvaluation(ArrayIntersect(Seq(nullLiteral, CreateArray(Seq(a0, a1)))), null) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))), + nullLiteral, nullLiteral, CreateArray(Seq(a0)))), null) + + checkEvaluation(ArrayIntersect(Seq( + CreateArray(Seq(If(LessThan(Rand(0L), c0), nullLiteral, nullLiteral))), + CreateArray(Seq(a0, a1)), + CreateArray(Seq(a0)))), + Seq()) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index b6399edb68dd6..4275defb4ac1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -163,6 +163,7 @@ trait ExpressionEvalHelper 
extends GeneratorDrivenPropertyChecks { Alias(expression, s"Optimized($expression)2")() :: Nil), expression) + plan.initialize(0) val unsafeRow = plan(inputRow) val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"