Commit da5685b

holdenk authored and BryanCutler committed
[SPARK-23672][PYTHON] Document support for nested return types in scalar with arrow udfs
## What changes were proposed in this pull request?

Clarify the docstring for scalar pandas UDFs.

## How was this patch tested?

Adds a unit test showing use similar to wordcount; an existing unit test already covers arrays of floats.

Closes #20908 from holdenk/SPARK-23672-document-support-for-nested-return-types-in-scalar-with-arrow-udfs.

Authored-by: Holden Karau <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
1 parent: 12e3e9f

2 files changed: +21 −1 lines

python/pyspark/sql/functions.py

Lines changed: 2 additions & 1 deletion
@@ -2720,9 +2720,10 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     1. SCALAR
 
        A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
-       The returnType should be a primitive data type, e.g., :class:`DoubleType`.
        The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
 
+       :class:`MapType`, :class:`StructType` are currently not supported as output types.
+
        Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
        :meth:`pyspark.sql.DataFrame.select`.
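
For context, here is a minimal standalone sketch of the behavior the updated docstring describes: a scalar pandas UDF returning a nested type (an array of strings), mirroring the tokenize test added below. The local SparkSession setup is illustrative and not part of this change; it assumes pandas and pyarrow are installed.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, StringType

# Illustrative local session, not part of the commit.
spark = SparkSession.builder.master("local[2]").getOrCreate()

# Each element of the input pandas.Series is a string; each element of the
# output pandas.Series is a list of strings, matching ArrayType(StringType()).
tokenize = pandas_udf(lambda s: s.apply(lambda v: v.split(' ')),
                      ArrayType(StringType()))

df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
result = df.select(tokenize("vals").alias("hi"))
# Expected rows, per the test added in this commit:
# [Row(hi=['hi', 'boo']), Row(hi=['bye', 'boo'])]
print(result.collect())
```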

python/pyspark/sql/tests.py

Lines changed: 19 additions & 0 deletions
@@ -4443,6 +4443,7 @@ def test_timestamp_dst(self):
     not _have_pandas or not _have_pyarrow,
     _pandas_requirement_message or _pyarrow_requirement_message)
 class PandasUDFTests(ReusedSQLTestCase):
+
     def test_pandas_udf_basic(self):
         from pyspark.rdd import PythonEvalType
         from pyspark.sql.functions import pandas_udf, PandasUDFType

@@ -4658,6 +4659,24 @@ def random_udf(v):
         random_udf = random_udf.asNondeterministic()
         return random_udf
 
+    def test_pandas_udf_tokenize(self):
+        from pyspark.sql.functions import pandas_udf
+        tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')),
+                              ArrayType(StringType()))
+        self.assertEqual(tokenize.returnType, ArrayType(StringType()))
+        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
+        result = df.select(tokenize("vals").alias("hi"))
+        self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())
+
+    def test_pandas_udf_nested_arrays(self):
+        from pyspark.sql.functions import pandas_udf
+        tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]),
+                              ArrayType(ArrayType(StringType())))
+        self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
+        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
+        result = df.select(tokenize("vals").alias("hi"))
+        self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())
+
     def test_vectorized_udf_basic(self):
         from pyspark.sql.functions import pandas_udf, col, array
         df = self.spark.range(10).select(
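
The commit message describes the new test as "use similar to wordcount". As a hedged illustration of why array return types matter for that pattern, the sketch below combines the tokenize UDF from the test with the standard explode/groupBy/count DataFrame operations; the session setup and data are illustrative and not part of this commit.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, pandas_udf
from pyspark.sql.types import ArrayType, StringType

# Illustrative local session, not part of the commit.
spark = SparkSession.builder.master("local[2]").getOrCreate()

# Array-returning scalar pandas UDF, as exercised by test_pandas_udf_tokenize.
tokenize = pandas_udf(lambda s: s.apply(lambda v: v.split(' ')),
                      ArrayType(StringType()))

df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])

# Explode the token arrays into one row per word, then count occurrences.
counts = (df.select(explode(tokenize("vals")).alias("word"))
            .groupBy("word")
            .count())
counts.show()  # e.g. boo -> 2, hi -> 1, bye -> 1
```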
