Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

import java.{lang => jl}

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -52,28 +54,29 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param f The math function.
* @param name The short name of the function
*/
abstract class UnaryMathExpression(f: Double => Double, name: String)
abstract class UnaryMathExpression[T <: NumericType](
f: Double => Double, name: String)(implicit ttag: TypeTag[T])
extends UnaryExpression with Serializable with ImplicitCastInputTypes {

override def inputTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override val dataType: NumericType = NumericType.toType(ttag)
override def nullable: Boolean = true
override def toString: String = s"$name($child)"

protected override def nullSafeEval(input: Any): Any = {
f(input.asInstanceOf[Double])
dataType.cast(f(input.asInstanceOf[Double]))
}

// name of function in java.lang.Math
def funcName: String = name.toLowerCase

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
defineCodeGen(ctx, ev, c => s"(${dataType.typeName})java.lang.Math.${funcName}($c)")
}
}

abstract class UnaryLogExpression(f: Double => Double, name: String)
extends UnaryMathExpression(f, name) {
extends UnaryMathExpression[DoubleType](f, name) {

// values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
protected val yAsymptote: Double = 0.0
Expand Down Expand Up @@ -144,19 +147,19 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
case class Acos(child: Expression) extends UnaryMathExpression[DoubleType](math.acos, "ACOS")

case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
case class Asin(child: Expression) extends UnaryMathExpression[DoubleType](math.asin, "ASIN")

case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
case class Atan(child: Expression) extends UnaryMathExpression[DoubleType](math.atan, "ATAN")

case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
case class Cbrt(child: Expression) extends UnaryMathExpression[DoubleType](math.cbrt, "CBRT")

case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL")
case class Ceil(child: Expression) extends UnaryMathExpression[LongType](math.ceil, "CEIL")

case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
case class Cos(child: Expression) extends UnaryMathExpression[DoubleType](math.cos, "COS")

case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
case class Cosh(child: Expression) extends UnaryMathExpression[DoubleType](math.cosh, "COSH")

/**
* Convert a num from one base to another
Expand Down Expand Up @@ -191,11 +194,11 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
}
}

case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
case class Exp(child: Expression) extends UnaryMathExpression[DoubleType](math.exp, "EXP")

case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
case class Expm1(child: Expression) extends UnaryMathExpression[DoubleType](math.expm1, "EXPM1")

case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR")
case class Floor(child: Expression) extends UnaryMathExpression[LongType](math.floor, "FLOOR")

object Factorial {

Expand Down Expand Up @@ -283,27 +286,29 @@ case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1
protected override val yAsymptote: Double = -1.0
}

case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
case class Rint(child: Expression) extends UnaryMathExpression[DoubleType](math.rint, "ROUND") {
override def funcName: String = "rint"
}

case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
case class Signum(child: Expression) extends UnaryMathExpression[DoubleType](math.signum, "SIGNUM")

case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
case class Sin(child: Expression) extends UnaryMathExpression[DoubleType](math.sin, "SIN")

case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
case class Sinh(child: Expression) extends UnaryMathExpression[DoubleType](math.sinh, "SINH")

case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
case class Sqrt(child: Expression) extends UnaryMathExpression[DoubleType](math.sqrt, "SQRT")

case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
case class Tan(child: Expression) extends UnaryMathExpression[DoubleType](math.tan, "TAN")

case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
case class Tanh(child: Expression) extends UnaryMathExpression[DoubleType](math.tanh, "TANH")

case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
case class ToDegrees(child: Expression)
extends UnaryMathExpression[DoubleType](math.toDegrees, "DEGREES") {
override def funcName: String = "toDegrees"
}

case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
case class ToRadians(child: Expression)
extends UnaryMathExpression[DoubleType](math.toRadians, "RADIANS") {
override def funcName: String = "toRadians"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.types

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror}
import scala.reflect.runtime.universe._

import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.Expression
Expand Down Expand Up @@ -147,6 +147,18 @@ abstract class NumericType extends AtomicType {
// desugared by the compiler into an argument to the objects constructor. This means there is no
// longer an no argument constructor and thus the JVM cannot serialize the object anymore.
private[sql] val numeric: Numeric[InternalType]

def cast(d: Double): InternalType

def cast(f: Float): InternalType

def cast(l: Long): InternalType

def cast(i: Int): InternalType

def cast(s: Short): InternalType

def cast(b: Byte): InternalType
}


Expand All @@ -165,6 +177,24 @@ private[sql] object NumericType extends AbstractDataType {
override private[sql] def simpleString: String = "numeric"

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]

def toType(ttag: TypeTag[_]): NumericType = {
ttag match {
case Byte => ByteType
case Short => ShortType
case Int => IntegerType
case Long => LongType
case Float => FloatType
case Double => DoubleType
}
}

val Byte : TypeTag[ByteType] = typeTag[ByteType]
val Short : TypeTag[ShortType] = typeTag[ShortType]
val Int : TypeTag[IntegerType] = typeTag[IntegerType]
val Long : TypeTag[LongType] = typeTag[LongType]
val Float : TypeTag[FloatType] = typeTag[FloatType]
val Double : TypeTag[DoubleType] = typeTag[DoubleType]
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class ByteType private() extends IntegralType {
override def simpleString: String = "tinyint"

private[spark] override def asNullable: ByteType = this

override def cast(d: Double): InternalType = d.toByte

override def cast(f: Float): InternalType = f.toByte

override def cast(l: Long): InternalType = l.toByte

override def cast(i: Int): InternalType = i.toByte

override def cast(s: Short): InternalType = s.toByte

override def cast(b: Byte): InternalType = b
}

case object ByteType extends ByteType
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
override def simpleString: String = s"decimal($precision,$scale)"

private[spark] override def asNullable: DecimalType = this

override def cast(d: Double): InternalType = throw new UnsupportedOperationException

override def cast(f: Float): InternalType = throw new UnsupportedOperationException

override def cast(l: Long): InternalType = throw new UnsupportedOperationException

override def cast(i: Int): InternalType = throw new UnsupportedOperationException

override def cast(s: Short): InternalType = throw new UnsupportedOperationException

override def cast(b: Byte): InternalType = throw new UnsupportedOperationException
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ class DoubleType private() extends FractionalType {
override def defaultSize: Int = 8

private[spark] override def asNullable: DoubleType = this

override def cast(d: InternalType): InternalType = d

override def cast(f: Float): InternalType = f.toDouble

override def cast(l: Long): InternalType = l.toDouble

override def cast(i: Int): InternalType = i.toDouble

override def cast(s: Short): InternalType = s.toDouble

override def cast(b: Byte): InternalType = b.toDouble
}

case object DoubleType extends DoubleType
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ class FloatType private() extends FractionalType {
override def defaultSize: Int = 4

private[spark] override def asNullable: FloatType = this

override def cast(d: Double): InternalType = d.toFloat

override def cast(f: Float): InternalType = f

override def cast(l: Long): InternalType = l.toFloat

override def cast(i: Int): InternalType = i.toFloat

override def cast(s: Short): InternalType = s.toFloat

override def cast(b: Byte): InternalType = b.toFloat
}

case object FloatType extends FloatType
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ class IntegerType private() extends IntegralType {
override def simpleString: String = "int"

private[spark] override def asNullable: IntegerType = this

override def cast(d: Double): InternalType = d.toInt

override def cast(f: Float): InternalType = f.toInt

override def cast(l: Long): InternalType = l.toInt

override def cast(i: Int): InternalType = i

override def cast(s: Short): InternalType = s.toInt

override def cast(b: Byte): InternalType = b.toInt
}

case object IntegerType extends IntegerType
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ class LongType private() extends IntegralType {
override def simpleString: String = "bigint"

private[spark] override def asNullable: LongType = this

override def cast(d: Double): InternalType = d.toLong

override def cast(f: Float): InternalType = f.toLong

override def cast(l: Long): InternalType = l

override def cast(i: Int): InternalType = i.toLong

override def cast(s: Short): InternalType = s.toLong

override def cast(b: Byte): InternalType = b.toLong

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ class ShortType private() extends IntegralType {
override def simpleString: String = "smallint"

private[spark] override def asNullable: ShortType = this

override def cast(d: Double): InternalType = d.toShort

override def cast(f: Float): InternalType = f.toShort

override def cast(l: Long): InternalType = l.toShort

override def cast(i: Int): InternalType = i.toShort

override def cast(s: Short): InternalType = s

override def cast(b: Byte): InternalType = b.toShort
}

case object ShortType extends ShortType
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("ceil") {
testUnary(Ceil, math.ceil)
testUnary(Ceil, (x: Double) => math.ceil(x).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
}

test("floor") {
testUnary(Floor, math.floor)
testUnary(Floor, (x: Double) => math.floor(x).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {

private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T](
c: Column => Column,
f: T => T): Unit = {
f: T => Any): Unit = {
checkAnswer(
doubleData.select(c('a)),
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
Expand Down Expand Up @@ -165,7 +165,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
}

test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
testOneToOneMathFunction(ceil, (x: Double) => math.ceil(x).toLong)
checkAnswer(
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0.0, 1.0, 2.0))
Expand All @@ -184,7 +184,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
}

test("floor") {
testOneToOneMathFunction(floor, math.floor)
testOneToOneMathFunction(floor, (x: Double) => math.floor(x).toLong)
}

test("factorial") {
Expand Down