Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ case class Logarithm(left: Expression, right: Expression)
case class Round(child: Expression, scale: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

import BigDecimal.RoundingMode.HALF_UP
import BigDecimal.RoundingMode.HALF_EVEN

def this(child: Expression) = this(child, Literal(0))

Expand Down Expand Up @@ -727,26 +727,26 @@ case class Round(child: Expression, scale: Expression)
val decimal = input1.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_EVEN).toByte
case ShortType =>
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_EVEN).toShort
case IntegerType =>
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_EVEN).toInt
case LongType =>
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_EVEN).toLong
case FloatType =>
val f = input1.asInstanceOf[Float]
if (f.isNaN || f.isInfinite) {
f
} else {
BigDecimal(f).setScale(_scale, HALF_UP).toFloat
BigDecimal(f).setScale(_scale, HALF_EVEN).toFloat
}
case DoubleType =>
val d = input1.asInstanceOf[Double]
if (d.isNaN || d.isInfinite) {
d
} else {
BigDecimal(d).setScale(_scale, HALF_UP).toDouble
BigDecimal(d).setScale(_scale, HALF_EVEN).toDouble
}
}
}
Expand All @@ -766,31 +766,31 @@ case class Round(child: Expression, scale: Expression)
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).byteValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
case ShortType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).shortValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
case IntegerType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).intValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
case LongType =>
if (_scale < 0) {
s"""
${ev.value} = new java.math.BigDecimal(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).longValue();"""
} else {
s"${ev.value} = ${ce.value};"
}
Expand All @@ -808,24 +808,54 @@ case class Round(child: Expression, scale: Expression)
${ev.value} = ${ce.value};
} else {
${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).floatValue();
}"""
}
case DoubleType => // if child eval to NaN or Infinity, just return it.
// The logic for rounding half-integers to even values is exemplified by the following
// table:
//
// x | x rounded to half-even | x * 2 | (x rounded to half-even) * 2 | (x * 2) & 3
// ----------------------------------------------------------------------------------------
// -4.5 | -4 | -9 | -8 | 3
// -3.5 | -4 | -7 | -8 | 1
// -2.5 | -2 | -5 | -6 | 3
// -1.5 | -2 | -3 | -6 | 1
// -0.5 | 0 | -1 | 0 | 3
// 0.5 | 0 | 1 | 0 | 1
// 1.5 | 2 | 3 | 4 | 3
// 2.5 | 2 | 5 | 4 | 1
// 3.5 | 4 | 7 | 8 | 3
// 4.5 | 4 | 9 | 8 | 1
//
// Therefore, looking at the last three columns above, if x has the form of "<integer>.5",
// then
// (x rounded to half-even) * 2 = (x * 2) + ((x * 2) & 3) - 2

if (_scale == 0) {
s"""
if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){
${ev.value} = ${ce.value};
} else {
${ev.value} = Math.round(${ce.value});
double timesTwo = ${ce.value} * 2;
long timesTwoRounded = Math.round(timesTwo);
if (timesTwo == timesTwoRounded) {
if ((timesTwoRounded & 1) == 0) {
${ev.value} = timesTwoRounded >> 1;
} else {
${ev.value} = (timesTwoRounded + (timesTwoRounded & 3) - 2) >> 1;
}
} else {
${ev.value} = Math.round(${ce.value});
}
}"""
} else {
s"""
if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})){
${ev.value} = ${ce.value};
} else {
${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_EVEN).doubleValue();
}"""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Seq.fill[Short](7)(31415)

val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300,
314159270) ++ Seq.fill(7)(314159265)
314159260) ++ Seq.fill(7)(314159265)

val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L,
31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++
Expand Down