Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def _(col):
return _


def _create_binary_mathfunction(name, doc=""):
""" Create a binary mathfunction by name"""
def _(col1, col2):
sc = SparkContext._active_spark_context
# users might write ints for simplicity. This would throw an error on the JVM side.
jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
col2._jc if isinstance(col2, Column) else float(col2))
return Column(jc)
_.__name__ = name
_.__doc__ = doc
return _


_functions = {
'lit': 'Creates a :class:`Column` of literal value.',
'col': 'Returns a :class:`Column` based on the given column name.',
Expand All @@ -63,6 +76,34 @@ def _(col):
'sqrt': 'Computes the square root of the specified float value.',
'abs': 'Computes the absolute value.',

# unary math functions
'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
'0.0 through pi.',
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
'-pi/2 through pi/2.',
'atan': 'Computes the tangent inverse of the given value.',
'cbrt': 'Computes the cube-root of the given value.',
'ceil': 'Computes the ceiling of the given value.',
'cos': 'Computes the cosine of the given value.',
'cosh': 'Computes the hyperbolic cosine of the given value.',
'exp': 'Computes the exponential of the given value.',
'expm1': 'Computes the exponential of the given value minus one.',
'floor': 'Computes the floor of the given value.',
'log': 'Computes the natural logarithm of the given value.',
'log10': 'Computes the logarithm of the given value in Base 10.',
'log1p': 'Computes the natural logarithm of the given value plus one.',
'rint': 'Returns the double value that is closest in value to the argument and' +
' is equal to a mathematical integer.',
'signum': 'Computes the signum of the given value.',
'sin': 'Computes the sine of the given value.',
'sinh': 'Computes the hyperbolic sine of the given value.',
'tan': 'Computes the tangent of the given value.',
'tanh': 'Computes the hyperbolic tangent of the given value.',
'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' +
'measured in degrees.',
'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
'measured in radians.',

'max': 'Aggregate function: returns the maximum value of the expression in a group.',
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
'first': 'Aggregate function: returns the first value in a group.',
Expand All @@ -74,10 +115,21 @@ def _(col):
'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
}

# math functions that take two arguments as input
_binary_mathfunctions = {
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
'polar coordinates (r, theta).',
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
}

# Generate a module-level wrapper for every documented function name and export
# it: unary/aggregate functions first, then the two-argument math functions.
for _factory, _table in ((_create_function, _functions),
                         (_create_binary_mathfunction, _binary_mathfunctions)):
    for _name, _doc in _table.items():
        globals()[_name] = _factory(_name, _doc)
    __all__.extend(_table)
# Drop the loop temporaries so they do not linger in the module namespace.
del _factory, _table, _name, _doc
__all__.sort()


Expand Down
101 changes: 0 additions & 101 deletions python/pyspark/sql/mathfunctions.py

This file was deleted.

2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_crosstab(self):

def test_math_functions(self):
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
from pyspark.sql import mathfunctions as functions
from pyspark.sql import functions
import math

def get_values(l):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1217,11 +1217,11 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
unaryMathFunctionEvaluation(Tanh, math.tanh)
}

test("toDeg") {
test("toDegrees") {
unaryMathFunctionEvaluation(ToDegrees, math.toDegrees)
}

test("toRad") {
test("toRadians") {
unaryMathFunctionEvaluation(ToRadians, math.toRadians)
}

Expand Down
Loading