Skip to content

Commit 193b5b2

Browse files
beliefercloud-fan
authored andcommitted
[SPARK-40387][SQL] Improve the implementation of Spark Decimal
### What changes were proposed in this pull request? This PR used to improve the implementation of Spark `Decimal`. The improvement points are as follows: 1. Use `toJavaBigDecimal` instead of `toBigDecimal.bigDecimal` 2. Extract `longVal / POW_10(_scale)` as a new method `def actualLongVal: Long` 3. Remove `BIG_DEC_ZERO` and use `decimalVal.signum` to judge whether or not equals zero. 4. Use `<` instead of `compare`. 5. Correct some code style. ### Why are the changes needed? Improve the implementation of Spark Decimal ### Does this PR introduce _any_ user-facing change? 'No'. Just update the inner implementation. ### How was this patch tested? N/A Closes apache#37830 from beliefer/SPARK-40387. Authored-by: Jiaan Geng <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 5496d99 commit 193b5b2

File tree

1 file changed

+13
-15
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/types

1 file changed

+13
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
204204
if (decimalVal.ne(null)) {
205205
decimalVal.toBigInt
206206
} else {
207-
BigInt(toLong)
207+
BigInt(actualLongVal)
208208
}
209209
}
210210

211211
def toJavaBigInteger: java.math.BigInteger = {
212212
if (decimalVal.ne(null)) {
213213
decimalVal.underlying().toBigInteger()
214214
} else {
215-
java.math.BigInteger.valueOf(toLong)
215+
java.math.BigInteger.valueOf(actualLongVal)
216216
}
217217
}
218218

@@ -226,7 +226,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
226226

227227
override def toString: String = toBigDecimal.toString()
228228

229-
def toPlainString: String = toBigDecimal.bigDecimal.toPlainString
229+
def toPlainString: String = toJavaBigDecimal.toPlainString
230230

231231
def toDebugString: String = {
232232
if (decimalVal.ne(null)) {
@@ -240,9 +240,11 @@ final class Decimal extends Ordered[Decimal] with Serializable {
240240

241241
def toFloat: Float = toBigDecimal.floatValue
242242

243+
private def actualLongVal: Long = longVal / POW_10(_scale)
244+
243245
def toLong: Long = {
244246
if (decimalVal.eq(null)) {
245-
longVal / POW_10(_scale)
247+
actualLongVal
246248
} else {
247249
decimalVal.longValue
248250
}
@@ -278,7 +280,6 @@ final class Decimal extends Ordered[Decimal] with Serializable {
278280
private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int)
279281
(f1: Long => T) (f2: Double => T): T = {
280282
if (decimalVal.eq(null)) {
281-
val actualLongVal = longVal / POW_10(_scale)
282283
val numericVal = f1(actualLongVal)
283284
if (actualLongVal == numericVal) {
284285
numericVal
@@ -303,7 +304,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
303304
*/
304305
private[sql] def roundToLong(): Long = {
305306
if (decimalVal.eq(null)) {
306-
longVal / POW_10(_scale)
307+
actualLongVal
307308
} else {
308309
try {
309310
// We cannot store Long.MAX_VALUE as a Double without losing precision.
@@ -455,7 +456,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
455456

456457
override def hashCode(): Int = toBigDecimal.hashCode()
457458

458-
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
459+
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal.signum == 0 else longVal == 0
459460

460461
// We should follow DecimalPrecision promote if use longVal for add and subtract:
461462
// Operation Result Precision Result Scale
@@ -466,15 +467,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
466467
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
467468
Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale)
468469
} else {
469-
Decimal(toBigDecimal.bigDecimal.add(that.toBigDecimal.bigDecimal))
470+
Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal))
470471
}
471472
}
472473

473474
def - (that: Decimal): Decimal = {
474475
if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) {
475476
Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale)
476477
} else {
477-
Decimal(toBigDecimal.bigDecimal.subtract(that.toBigDecimal.bigDecimal))
478+
Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal))
478479
}
479480
}
480481

@@ -504,7 +505,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
504505
}
505506
}
506507

507-
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
508+
def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this
508509

509510
def floor: Decimal = if (scale == 0) this else {
510511
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
@@ -532,8 +533,6 @@ object Decimal {
532533

533534
val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
534535

535-
private val BIG_DEC_ZERO = BigDecimal(0)
536-
537536
private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
538537

539538
private[sql] val ZERO = Decimal(0)
@@ -575,9 +574,8 @@ object Decimal {
575574
}
576575
}
577576

578-
private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = {
579-
bigDecimal.precision - bigDecimal.scale
580-
}
577+
private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int =
578+
bigDecimal.precision - bigDecimal.scale
581579

582580
private def stringToJavaBigDecimal(str: UTF8String): JavaBigDecimal = {
583581
// According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.

0 commit comments

Comments
 (0)