Skip to content

Commit cac6506

Browse files
cloud-fangatorsmile
authored andcommitted
[SPARK-19727][SQL][FOLLOWUP] Fix for round function that modifies original column
## What changes were proposed in this pull request? This is a followup of #17075 , to fix the bug in codegen path. ## How was this patch tested? new regression test Author: Wenchen Fan <[email protected]> Closes #19576 from cloud-fan/bug. (cherry picked from commit 7fdacbc) Signed-off-by: gatorsmile <[email protected]>
1 parent cb54f29 commit cac6506

File tree

7 files changed

+36
-26
lines changed

7 files changed

+36
-26
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ object CatalystTypeConverters {
310310
case d: JavaBigInteger => Decimal(d)
311311
case d: Decimal => d
312312
}
313-
decimal.toPrecision(dataType.precision, dataType.scale).orNull
313+
decimal.toPrecision(dataType.precision, dataType.scale)
314314
}
315315
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
316316
if (catalystValue == null) null

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
387387
/**
388388
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
389389
* returning null if it overflows or creating a new `value` and returning it if successful.
390-
*
391390
*/
392391
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
393-
value.toPrecision(decimalType.precision, decimalType.scale).orNull
392+
value.toPrecision(decimalType.precision, decimalType.scale)
394393

395394

396395
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
8585
override def nullable: Boolean = true
8686

8787
override def nullSafeEval(input: Any): Any =
88-
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
88+
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)
8989

9090
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
9191
nullSafeCodeGen(ctx, ev, eval => {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
10301030
dataType match {
10311031
case DecimalType.Fixed(_, s) =>
10321032
val decimal = input1.asInstanceOf[Decimal]
1033-
decimal.toPrecision(decimal.precision, s, mode).orNull
1033+
decimal.toPrecision(decimal.precision, s, mode)
10341034
case ByteType =>
10351035
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
10361036
case ShortType =>
@@ -1062,12 +1062,8 @@ abstract class RoundBase(child: Expression, scale: Expression,
10621062
val evaluationCode = dataType match {
10631063
case DecimalType.Fixed(_, s) =>
10641064
s"""
1065-
if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
1066-
java.math.BigDecimal.${modeStr})) {
1067-
${ev.value} = ${ce.value};
1068-
} else {
1069-
${ev.isNull} = true;
1070-
}"""
1065+
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
1066+
${ev.isNull} = ${ev.value} == null;"""
10711067
case ByteType =>
10721068
if (_scale < 0) {
10731069
s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -234,31 +234,28 @@ final class Decimal extends Ordered[Decimal] with Serializable {
234234
changePrecision(precision, scale, ROUND_HALF_UP)
235235
}
236236

237-
def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
238-
case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
239-
case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
240-
}
241-
242237
/**
243238
* Create new `Decimal` with given precision and scale.
244239
*
245-
* @return `Some(decimal)` if successful or `None` if overflow would occur
240+
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
246241
*/
247242
private[sql] def toPrecision(
248243
precision: Int,
249244
scale: Int,
250-
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
245+
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
251246
val copy = clone()
252-
if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
247+
if (copy.changePrecision(precision, scale, roundMode)) copy else null
253248
}
254249

255250
/**
256251
* Update precision and scale while keeping our value the same, and return true if successful.
257252
*
258253
* @return true if successful, false if overflow would occur
259254
*/
260-
private[sql] def changePrecision(precision: Int, scale: Int,
261-
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
255+
private[sql] def changePrecision(
256+
precision: Int,
257+
scale: Int,
258+
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
262259
// fast path for UnsafeProjection
263260
if (precision == this.precision && scale == this.scale) {
264261
return true
@@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable {
393390

394391
def floor: Decimal = if (scale == 0) this else {
395392
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
396-
toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
397-
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
393+
val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
394+
if (res == null) {
395+
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
396+
}
397+
res
398398
}
399399

400400
def ceil: Decimal = if (scale == 0) this else {
401401
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
402-
toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
403-
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
402+
val res = toPrecision(newPrecision, 0, ROUND_CEILING)
403+
if (res == null) {
404+
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
405+
}
406+
res
404407
}
405408
}
406409

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
213213
assert(d.changePrecision(10, 0, mode))
214214
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
215215

216-
val copy = d.toPrecision(10, 0, mode).orNull
216+
val copy = d.toPrecision(10, 0, mode)
217217
assert(copy !== null)
218218
assert(d.ne(copy))
219219
assert(d === copy)

sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
258258
)
259259
}
260260

261+
test("round/bround with table columns") {
262+
withTable("t") {
263+
Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
264+
checkAnswer(
265+
sql("select i, round(i) from t"),
266+
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
267+
checkAnswer(
268+
sql("select i, bround(i) from t"),
269+
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
270+
}
271+
}
272+
261273
test("exp") {
262274
testOneToOneMathFunction(exp, math.exp)
263275
}

0 commit comments

Comments
 (0)