
Commit a7af520

BryanCutler authored and hagerf committed
[SPARK-29367][DOC] Add compatibility note for Arrow 0.15.0 to SQL guide
Add documentation to the SQL programming guide on using PyArrow >= 0.15.0 with current versions of Spark. Arrow 0.15.0 introduced a change in the binary IPC format which requires an environment variable to maintain compatibility.

No user-facing change. Tested by running pandas_udf tests using PyArrow 0.15.0 with the environment variable set.

Closes apache#26045 from BryanCutler/arrow-document-legacy-IPC-fix-SPARK-29367.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 2b3c379 · commit a7af520

File tree: 5 files changed, +140 −2 lines


docs/sql-pyspark-pandas-with-arrow.md

Lines changed: 17 additions & 0 deletions
````diff
@@ -219,3 +219,20 @@ Note that a standard UDF (non-Pandas) will load timestamp data as Python datetime
 different than a Pandas timestamp. It is recommended to use Pandas time series functionality when
 working with timestamps in `pandas_udf`s to get the best performance, see
 [here](https://pandas.pydata.org/pandas-docs/stable/timeseries.html) for details.
+
+### Compatibility Setting for PyArrow >= 0.15.0 and Spark 2.3.x, 2.4.x
+
+Since Arrow 0.15.0, a change in the binary IPC format requires an environment variable to be
+compatible with previous versions of Arrow <= 0.14.1. This is only necessary for PySpark
+users on versions 2.3.x and 2.4.x who have manually upgraded PyArrow to 0.15.0. The following
+can be added to `conf/spark-env.sh` to use the legacy Arrow IPC format:
+
+```
+ARROW_PRE_0_15_IPC_FORMAT=1
+```
+
+This will instruct PyArrow >= 0.15.0 to use the legacy IPC format with the older Arrow Java that
+is in Spark 2.3.x and 2.4.x. Not setting this environment variable will lead to an error like the
+one described in [SPARK-29367](https://issues.apache.org/jira/browse/SPARK-29367) when running
+`pandas_udf`s or `toPandas()` with Arrow enabled. More information about the Arrow IPC change can
+be found in the Arrow 0.15.0 release [blog](http://arrow.apache.org/blog/2019/10/06/0.15.0-release/#columnar-streaming-protocol-change-since-0140).
````
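As a practical aside (not part of this diff), the variable can also be set programmatically when editing `conf/spark-env.sh` is not an option. Below is a minimal PySpark sketch under the assumption that the variable is exported before the session, and any Arrow IPC, is created; `spark.executorEnv.*` forwards it to the Python worker processes on executors:

```python
import os

# Must be set in the driver process before any Arrow IPC takes place.
os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    # Forward the variable to the Python worker processes on executors.
    .config("spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT", "1")
    # Arrow-based transfer is what triggers the format mismatch.
    .config("spark.sql.execution.arrow.enabled", "true")
    .getOrCreate()
)

# With the legacy IPC format forced, toPandas() interoperates with the
# older Arrow Java shipped in Spark 2.3.x / 2.4.x.
pdf = spark.range(10).toPandas()
```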

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -446,6 +446,7 @@ object FunctionRegistry {
     expression[Shuffle]("shuffle"),
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
+    expression[ArrayMedian]("array_median"),
     expression[Reverse]("reverse"),
     expression[Concat]("concat"),
     expression[Flatten]("flatten"),
```
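Registering the expression in `FunctionRegistry` binds the SQL name `array_median` to the new `ArrayMedian` expression, making it callable from SQL in any language binding, not just the Scala DSL. A quick smoke test from PySpark, assuming a build that includes this commit:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# NULL elements are skipped, so this computes the median of [1, 2, 3].
spark.sql("SELECT array_median(array(1, 2, NULL, 3))").show()
# Expected single value: 2.0
```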

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 83 additions & 2 deletions
```diff
@@ -18,10 +18,8 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.time.ZoneId
 import java.util.Comparator
-
 import scala.collection.mutable
 import scala.reflect.ClassTag
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
@@ -37,6 +35,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.collection.OpenHashSet
+import scala.reflect.runtime.universe
 
 /**
  * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
@@ -900,6 +899,88 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
   override def prettyName: String = "sort_array"
 }
 
+/**
+ * Returns the median value as double of an array of numeric values.
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(array) - Returns the median value in the array; only accepts arrays with numeric values.
+      NULL elements are skipped, and NULL is returned if the array is empty.""",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, null, 3));
+       2
+  """,
+  since = "3.0.0")
+case class ArrayMedian(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case ArrayType(dt, _) => dt match {
+      case _: NumericType => TypeCheckResult.TypeCheckSuccess
+      case _ => TypeCheckResult.TypeCheckFailure(
+        s"$prettyName does not support arrays of type ${dt.catalogString}, which is not numeric.")
+    }
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  private def containsNulls: Boolean = child.dataType.asInstanceOf[ArrayType].containsNull
+
+  // Generates Java that copies the input ArrayData into a primitive array named `array`,
+  // filtering out nulls first when the element type is nullable.
+  private def assignArrayCodeGen(array: String, ctx: CodegenContext, c: String): String = {
+    val javaType = CodeGenerator.javaType(arrayType)
+    val primitiveTypeName = CodeGenerator.primitiveTypeName(arrayType)
+
+    if (containsNulls) {
+      val numElements = ctx.freshName("numElements")
+      val tempArray = ctx.freshName("tempArray")
+      val count = ctx.freshName("count")
+      val i = ctx.freshName("i")
+
+      s"""
+         |int $numElements = $c.numElements();
+         |$javaType[] $tempArray = new $javaType[$numElements];
+         |int $count = -1;
+         |for (int $i = 0; $i < $numElements; $i++) {
+         |  if (!$c.isNullAt($i)) {
+         |    $tempArray[++$count] = $c.get$primitiveTypeName($i);
+         |  }
+         |}
+         |$javaType[] $array = java.util.Arrays.copyOf($tempArray, $count + 1);
+       """.stripMargin
+    } else {
+      s"""
+         |$javaType[] $array = $c.to${primitiveTypeName}Array();
+       """.stripMargin
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val size = ctx.freshName("size")
+    val array = ctx.freshName("array")
+
+    nullSafeCodeGen(ctx, ev, c =>
+      s"""
+         |${assignArrayCodeGen(array, ctx, c)}
+         |java.util.Arrays.sort($array);
+         |final int $size = $array.length;
+         |if ($size == 0) {
+         |  ${ev.isNull} = true;
+         |} else if ($size % 2 == 0) {
+         |  ${ev.value} = ($array[$size / 2] + $array[$size / 2 - 1]) / 2d;
+         |} else {
+         |  ${ev.value} = $array[$size / 2] / 1d;
+         |}
+       """.stripMargin)
+  }
+
+  @transient override val dataType: DataType = DoubleType
+
+  private val arrayType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
+
+  override def prettyName: String = "array_median"
+}
+
 
 /**
  * Sorts the input array in ascending order according to the natural ordering of
```
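Stripped of the codegen plumbing, the generated Java does three things: drop NULLs, sort, then return the middle element, or the mean of the two middle elements, always as a double. A plain-Python rendering of those semantics for reference; `array_median` below is an illustrative helper, not a Spark API:

```python
from typing import Optional, Sequence

def array_median(values: Sequence[Optional[float]]) -> Optional[float]:
    """Reference semantics of the ArrayMedian codegen: skip NULLs, sort,
    take the middle (or the mean of the two middles) as a double;
    an empty input yields NULL."""
    xs = sorted(v for v in values if v is not None)
    n = len(xs)
    if n == 0:
        return None                                 # ev.isNull = true
    if n % 2 == 0:
        return (xs[n // 2] + xs[n // 2 - 1]) / 2.0  # "/ 2d" in the generated code
    return xs[n // 2] / 1.0                         # "/ 1d" forces a double

assert array_median([1, 2, None, 3]) == 2.0
assert array_median([6.0, 2.0, 3.0, 5.0, 4.0, 1.0]) == 3.5
assert array_median([]) is None
```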

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 0 deletions
```diff
@@ -3902,6 +3902,15 @@ object functions {
    */
   def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) }
 
+  /**
+   * Returns the median value in the array as a double. The array must contain numeric
+   * values; NULL elements are skipped. Returns null for empty arrays.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def array_median(e: Column): Column = withExpr { ArrayMedian(e.expr) }
+
   /**
    * Returns the maximum value in the array.
    *
```
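This commit adds no PySpark wrapper, but because the SQL function is registered, `expr()` reaches it from Python as well; a usage sketch with an illustrative DataFrame:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [([1.0, 3.0, 2.0],), ([1.2, -100.0, 2.5, None],)], ["a"])

# Equivalent to the Scala DSL call array_median($"a").
df.select(expr("array_median(a)").alias("median")).show()
# Expected: 2.0 for the first row, 1.2 for the second (NULL skipped)
```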

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 30 additions & 0 deletions
```diff
@@ -884,6 +884,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df.selectExpr("array_min(a)"), answer)
   }
 
+  test("array_median function") {
+    val doubles = Seq(
+      Seq(1.0, 3.0, 2.0).map(Option.apply),
+      Seq(Some(1.2), Some(-100.0), Some(2.5), Option.empty[Double]),
+      Seq(6.0, 2.0, 3.0, 5.0, 4.0, 1.0).map(Option.apply),
+      Seq.empty[Option[Double]]
+    ).toDF("a")
+
+    val answerDoubles = Seq(Row(2.0), Row(1.2), Row(3.5), Row(null))
+
+    val ints = Seq(
+      Seq(1, 3, 2),
+      Seq(1, -100, 2)
+    ).toDF("a")
+
+    val longs = Seq(
+      Seq(1L, 3L, 2L),
+      Seq(1L, -100L, 2L)
+    ).toDF("a")
+
+    val answerLongAndInt = Seq(Row(2.0), Row(1.0))
+
+    checkAnswer(doubles.select(array_median(doubles("a"))), answerDoubles)
+    checkAnswer(doubles.selectExpr("array_median(a)"), answerDoubles)
+    checkAnswer(ints.select(array_median(ints("a"))), answerLongAndInt)
+    checkAnswer(ints.selectExpr("array_median(a)"), answerLongAndInt)
+    checkAnswer(longs.select(array_median(longs("a"))), answerLongAndInt)
+    checkAnswer(longs.selectExpr("array_median(a)"), answerLongAndInt)
+  }
+
   test("array_max function") {
     val df = Seq(
       Seq[Option[Int]](Some(1), Some(3), Some(2)),
```
