Skip to content

Commit cf7a7bb

Browse files
committed
[SPARK-7358] Move DataFrame mathfunctions into functions
1 parent fec7b29 commit cf7a7bb

File tree

7 files changed

+543
-488
lines changed

7 files changed

+543
-488
lines changed

python/pyspark/sql/functions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ def _(col):
5151
return _
5252

5353

54+
def _create_binary_mathfunction(name, doc=""):
55+
""" Create a binary mathfunction by name"""
56+
def _(col1, col2):
57+
sc = SparkContext._active_spark_context
58+
# users might write ints for simplicity. This would throw an error on the JVM side.
59+
if type(col1) is int:
60+
col1 = col1 * 1.0
61+
if type(col2) is int:
62+
col2 = col2 * 1.0
63+
jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else col1,
64+
col2._jc if isinstance(col2, Column) else col2)
65+
return Column(jc)
66+
_.__name__ = name
67+
_.__doc__ = doc
68+
return _
69+
70+
5471
_functions = {
5572
'lit': 'Creates a :class:`Column` of literal value.',
5673
'col': 'Returns a :class:`Column` based on the given column name.',
@@ -63,6 +80,34 @@ def _(col):
6380
'sqrt': 'Computes the square root of the specified float value.',
6481
'abs': 'Computes the absolute value.',
6582

83+
# unary math functions
84+
'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
85+
'0.0 through pi.',
86+
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
87+
'-pi/2 through pi/2.',
88+
'atan': 'Computes the tangent inverse of the given value.',
89+
'cbrt': 'Computes the cube-root of the given value.',
90+
'ceil': 'Computes the ceiling of the given value.',
91+
'cos': 'Computes the cosine of the given value.',
92+
'cosh': 'Computes the hyperbolic cosine of the given value.',
93+
'exp': 'Computes the exponential of the given value.',
94+
'expm1': 'Computes the exponential of the given value minus one.',
95+
'floor': 'Computes the floor of the given value.',
96+
'log': 'Computes the natural logarithm of the given value.',
97+
'log10': 'Computes the logarithm of the given value in Base 10.',
98+
'log1p': 'Computes the natural logarithm of the given value plus one.',
99+
'rint': 'Returns the double value that is closest in value to the argument and' +
100+
' is equal to a mathematical integer.',
101+
'signum': 'Computes the signum of the given value.',
102+
'sin': 'Computes the sine of the given value.',
103+
'sinh': 'Computes the hyperbolic sine of the given value.',
104+
'tan': 'Computes the tangent of the given value.',
105+
'tanh': 'Computes the hyperbolic tangent of the given value.',
106+
'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' +
107+
'measured in degrees.',
108+
'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
109+
'measured in radians.',
110+
66111
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
67112
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
68113
'first': 'Aggregate function: returns the first value in a group.',
@@ -74,10 +119,21 @@ def _(col):
74119
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
75120
}
76121

122+
# math functions that take two arguments as input
123+
_binary_mathfunctions = {
124+
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
125+
'polar coordinates (r, theta).',
126+
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
127+
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
128+
}
129+
77130
for _name, _doc in _functions.items():
78131
globals()[_name] = _create_function(_name, _doc)
132+
for _name, _doc in _binary_mathfunctions.items():
133+
globals()[_name] = _create_binary_mathfunction(_name, _doc)
79134
del _name, _doc
80135
__all__ += _functions.keys()
136+
__all__ += _binary_mathfunctions.keys()
81137
__all__.sort()
82138

83139

python/pyspark/sql/mathfunctions.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

python/pyspark/sql/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def test_crosstab(self):
416416

417417
def test_math_functions(self):
418418
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
419-
from pyspark.sql import mathfunctions as functions
419+
from pyspark.sql import functions
420420
import math
421421

422422
def get_values(l):

0 commit comments

Comments
 (0)