41 changes: 34 additions & 7 deletions python/pyspark/sql/functions.py

@@ -20,6 +20,7 @@
 """
 import math
 import sys
+import functools
 
 if sys.version < "3":
     from itertools import imap as map
@@ -1908,22 +1909,48 @@ def __call__(self, *cols):
 
 
 @since(1.3)
-def udf(f, returnType=StringType()):
+def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).
 
     .. note:: The user-defined functions must be deterministic. Due to optimization,
         duplicate invocations may be eliminated or the function may even be invoked more times than
         it is present in the query.
 
-    :param f: python function
-    :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
+    :param f: python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.DataType` object
 
     >>> from pyspark.sql.types import IntegerType
     >>> slen = udf(lambda s: len(s), IntegerType())
-    >>> df.select(slen(df.name).alias('slen')).collect()
-    [Row(slen=5), Row(slen=3)]
-    """
-    return UserDefinedFunction(f, returnType)
+    >>> @udf
+    ... def to_upper(s):
+    ...     if s is not None:
+    ...         return s.upper()
+    ...
+    >>> @udf(returnType=IntegerType())
+    ... def add_one(x):
+    ...     if x is not None:
+    ...         return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")).show()
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
+    """
+    def _udf(f, returnType=StringType()):
+        return UserDefinedFunction(f, returnType)
+
+    # decorator @udf, @udf() or @udf(dataType())
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(_udf, returnType=return_type)
+    else:
+        return _udf(f=f, returnType=returnType)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
 __all__ = [k for k, v in globals().items()
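The guard at the bottom of `udf` is the usual optional-argument decorator idiom: a bare `@udf` receives the decorated function directly, while `@udf()` or `@udf(IntegerType())` is a call that must first hand back a decorator, which `functools.partial` supplies by pre-binding the return type. A minimal, Spark-free sketch of the same dispatch (the names `tag` and `label` are illustrative only, not part of the patch):

import functools

def tag(f=None, label="default"):
    """Usable as @tag, @tag(), or @tag("some label"), like udf above."""
    def _tag(f, label="default"):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            return (label, f(*args, **kwargs))
        return wrapper

    if f is None or isinstance(f, str):
        # Called as @tag() or @tag("..."): f holds the label (or None),
        # so pre-bind it and return a one-argument decorator,
        # exactly as udf does with functools.partial.
        return functools.partial(_tag, label=f or label)
    else:
        # Called as @tag or tag(func): decorate immediately.
        return _tag(f, label=label)

@tag
def one():
    return 1

@tag("math")
def two():
    return 2

print(one())  # ('default', 1)
print(two())  # ('math', 2)

Note that the `isinstance(f, (str, DataType))` branch is what lets a positional data type (or a data type string such as "long") stand in for `returnType`.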
57 changes: 57 additions & 0 deletions python/pyspark/sql/tests.py

@@ -511,6 +511,63 @@ def test_udf_shouldnt_accept_noncallable_object(self):
         non_callable = None
         self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
 
+    def test_udf_with_decorator(self):
+        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.types import IntegerType, DoubleType
+
+        @udf(IntegerType())
+        def add_one(x):
+            if x is not None:
+                return x + 1
+
+        @udf(returnType=DoubleType())
+        def add_two(x):
+            if x is not None:
+                return float(x + 2)
+
+        @udf
+        def to_upper(x):
+            if x is not None:
+                return x.upper()
+
+        @udf()
+        def to_lower(x):
+            if x is not None:
+                return x.lower()
+
+        @udf
+        def substr(x, start, end):
+            if x is not None:
+                return x[start:end]
+
+        @udf("long")
+        def trunc(x):
+            return int(x)
+
+        @udf(returnType="double")
+        def as_double(x):
+            return float(x)
+
+        df = (
+            self.spark
+            .createDataFrame(
+                [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
+            .select(
+                add_one("one"), add_two("one"),
+                to_upper("Foo"), to_lower("Foo"),
+                substr("foobar", lit(0), lit(3)),
+                trunc("float"), as_double("one")))
+
+        self.assertListEqual(
+            [tpe for _, tpe in df.dtypes],
+            ["int", "double", "string", "string", "string", "bigint", "double"]
+        )
+
+        self.assertListEqual(
+            list(df.first()),
+            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
+        )
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
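End to end, the patch leaves three equivalent ways to build a UDF: the plain call, the bare decorator, and the parameterized decorator. A usage sketch mirroring the doctest above, assuming a local SparkSession (the session setup and sample data are illustrative, not part of the patch):

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("udf-decorator-demo").getOrCreate()

# 1. Plain call, unchanged from before the patch.
slen = udf(lambda s: len(s), IntegerType())

# 2. Bare decorator: returnType falls back to StringType().
@udf
def to_upper(s):
    if s is not None:
        return s.upper()

# 3. Parameterized decorator: a DataType (or data type string) as returnType.
@udf(IntegerType())
def add_one(x):
    if x is not None:
        return x + 1

df = spark.createDataFrame([("John Doe", 21)], ("name", "age"))
df.select(slen("name"), to_upper("name"), add_one("age")).show()
spark.stop()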