@@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.expressions
1818
1919import java .time .ZoneId
2020import java .util .Comparator
21-
2221import scala .collection .mutable
2322import scala .reflect .ClassTag
24-
2523import org .apache .spark .sql .catalyst .InternalRow
2624import org .apache .spark .sql .catalyst .analysis .{TypeCheckResult , TypeCoercion }
2725import org .apache .spark .sql .catalyst .expressions .ArraySortLike .NullOrder
@@ -37,6 +35,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
3735import org .apache .spark .unsafe .types .{ByteArray , UTF8String }
3836import org .apache .spark .unsafe .types .CalendarInterval
3937import org .apache .spark .util .collection .OpenHashSet
38+ import scala .reflect .runtime .universe
4039
4140/**
4241 * Base trait for [[BinaryExpression ]]s with two arrays of the same element type and implicit
@@ -900,6 +899,88 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
900899 override def prettyName : String = " sort_array"
901900}
902901
/**
 * Returns the median, as a double, of the non-null elements of a numeric array.
 *
 * NULL elements are ignored. The result is NULL when the input array is NULL, is empty, or
 * contains only NULL elements. For an even number of non-null elements the result is the
 * arithmetic mean of the two middle values.
 *
 * Elements are converted to double before they are sorted and averaged. The conversion is
 * order-preserving, so the selected middle element(s) are the same as when sorting in the
 * native type, while the averaging step can no longer overflow for int/long inputs and decimal
 * inputs are handled uniformly. Values that are not exactly representable as a double
 * (large longs, high-precision decimals) are rounded by that conversion.
 */
@ExpressionDescription(
  usage = """
    _FUNC_(array) - Returns the median of the numeric values in the array as a double.
      NULL elements are skipped. Returns NULL if the array has no non-NULL element.""",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, null, 3));
       2.0
  """, since = "3.0.0")
case class ArrayMedian(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  override def dataType: DataType = DoubleType

  // An empty or all-NULL array yields NULL even when the child itself can never be NULL, so the
  // result must always be declared nullable. Otherwise `nullSafeCodeGen` does not declare the
  // isNull variable that the generated code assigns to.
  override def nullable: Boolean = true

  override def prettyName: String = "array_median"

  // Lazy on purpose: `child.dataType` must not be touched while the expression is being
  // constructed, because the child may still be unresolved or may not be an array at all
  // (which is reported by `checkInputDataTypes`, not by a ClassCastException).
  @transient private lazy val elementType: DataType =
    child.dataType.asInstanceOf[ArrayType].elementType

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: NumericType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case ArrayType(dt, _) =>
      TypeCheckResult.TypeCheckFailure(
        s"$prettyName does not support arrays of type ${dt.catalogString} which is not numeric.")
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
  }

  /**
   * Interpreted evaluation path.
   *
   * @param input the non-null array value produced by `child`
   * @return the median as a boxed double, or null if there is no non-null element
   */
  override protected def nullSafeEval(input: Any): Any = {
    val data = input.asInstanceOf[ArrayData]
    val numeric = elementType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
    val values = (0 until data.numElements()).iterator
      .filterNot(data.isNullAt)
      .map(i => numeric.toDouble(data.get(i, elementType)))
      .toArray

    if (values.isEmpty) {
      // Catalyst represents SQL NULL as a null reference.
      null
    } else {
      java.util.Arrays.sort(values)
      val mid = values.length / 2
      if (values.length % 2 == 0) (values(mid - 1) + values(mid)) / 2d else values(mid)
    }
  }

  /**
   * Builds the Java expression that reads element `index` of the array `arrayData` as a double.
   * Decimals need their precision/scale to be read and are converted explicitly; every other
   * numeric type is read through its primitive getter and widened with a cast.
   */
  private def elementAsDoubleCode(arrayData: String, index: String): String = elementType match {
    case d: DecimalType =>
      s"$arrayData.getDecimal($index, ${d.precision}, ${d.scale}).toDouble()"
    case other =>
      val primitiveTypeName = CodeGenerator.primitiveTypeName(other)
      s"(double) $arrayData.get$primitiveTypeName($index)"
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val numElements = ctx.freshName("numElements")
    val values = ctx.freshName("values")
    val count = ctx.freshName("count")
    val i = ctx.freshName("i")
    val mid = ctx.freshName("mid")

    nullSafeCodeGen(ctx, ev, c => {
      val element = elementAsDoubleCode(c, i)
      // Non-null elements are packed at the front of the buffer; only that prefix is sorted.
      s"""
         |final int $numElements = $c.numElements();
         |double[] $values = new double[$numElements];
         |int $count = 0;
         |for (int $i = 0; $i < $numElements; $i++) {
         |  if (!$c.isNullAt($i)) {
         |    $values[$count++] = $element;
         |  }
         |}
         |if ($count == 0) {
         |  ${ev.isNull} = true;
         |} else {
         |  java.util.Arrays.sort($values, 0, $count);
         |  final int $mid = $count / 2;
         |  if ($count % 2 == 0) {
         |    ${ev.value} = ($values[$mid - 1] + $values[$mid]) / 2d;
         |  } else {
         |    ${ev.value} = $values[$mid];
         |  }
         |}
       """.stripMargin
    })
  }
}
983+
903984
904985/**
905986 * Sorts the input array in ascending order according to the natural ordering of
0 commit comments