Skip to content

Commit eee790f

Browse files
authored
feat: support Decimal256 for the abs function (#7904)
* feat: support Decimal256 for the abs function * Remove useless comment * use wrapping_abs
1 parent ae85a67 commit eee790f

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

datafusion/physical-expr/src/math_expressions.rs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
2020
use arrow::array::ArrayRef;
2121
use arrow::array::{
22-
BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array,
23-
Int64Array, Int8Array,
22+
BooleanArray, Decimal128Array, Decimal256Array, Float32Array, Float64Array,
23+
Int16Array, Int32Array, Int64Array, Int8Array,
2424
};
2525
use arrow::datatypes::DataType;
2626
use arrow::error::ArrowError;
@@ -701,6 +701,18 @@ macro_rules! make_try_abs_function {
701701
}};
702702
}
703703

704+
macro_rules! make_decimal_abs_function {
705+
($ARRAY_TYPE:ident) => {{
706+
|args: &[ArrayRef]| {
707+
let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
708+
let res: $ARRAY_TYPE = array
709+
.unary(|x| x.wrapping_abs())
710+
.with_data_type(args[0].data_type().clone());
711+
Ok(Arc::new(res) as ArrayRef)
712+
}
713+
}};
714+
}
715+
704716
/// Abs SQL function
705717
/// Return different implementations based on input datatype to reduce branches during execution
706718
pub(super) fn create_abs_function(
@@ -723,15 +735,9 @@ pub(super) fn create_abs_function(
723735
| DataType::UInt32
724736
| DataType::UInt64 => Ok(|args: &[ArrayRef]| Ok(args[0].clone())),
725737

726-
// Decimal should keep the same precision and scale by using `with_data_type()`.
727-
// https://github.com/apache/arrow-rs/issues/4644
728-
DataType::Decimal128(_, _) => Ok(|args: &[ArrayRef]| {
729-
let array = downcast_arg!(&args[0], "abs arg", Decimal128Array);
730-
let res: Decimal128Array = array
731-
.unary(i128::abs)
732-
.with_data_type(args[0].data_type().clone());
733-
Ok(Arc::new(res) as ArrayRef)
734-
}),
738+
// Decimal types
739+
DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)),
740+
DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)),
735741

736742
other => not_impl_err!("Unsupported data type {other:?} for function abs"),
737743
}

datafusion/sqllogictest/test_files/math.slt

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ NaN NaN
395395

396396
# abs: return type
397397
query TT rowsort
398-
SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_float limit 1
398+
SELECT arrow_typeof(abs(c1)), arrow_typeof(abs(c2)) FROM test_nullable_float limit 1
399399
----
400400
Float32 Float64
401401

@@ -466,34 +466,48 @@ drop table test_non_nullable_float
466466

467467
statement ok
468468
CREATE TABLE test_nullable_decimal(
469-
c1 DECIMAL(10, 2),
470-
c2 DECIMAL(38, 10)
471-
) AS VALUES (0, 0), (NULL, NULL);
472-
473-
query RR
469+
c1 DECIMAL(10, 2), /* Decimal128 */
470+
c2 DECIMAL(38, 10), /* Decimal128 with max precision */
471+
c3 DECIMAL(40, 2), /* Decimal256 */
472+
c4 DECIMAL(76, 10) /* Decimal256 with max precision */
473+
) AS VALUES
474+
(0, 0, 0, 0),
475+
(NULL, NULL, NULL, NULL);
476+
477+
query RRRR
474478
INSERT into test_nullable_decimal values
475-
(-99999999.99, '-9999999999999999999999999999.9999999999'),
476-
(99999999.99, '9999999999999999999999999999.9999999999');
479+
(
480+
-99999999.99,
481+
'-9999999999999999999999999999.9999999999',
482+
'-99999999999999999999999999999999999999.99',
483+
'-999999999999999999999999999999999999999999999999999999999999999999.9999999999'
484+
),
485+
(
486+
99999999.99,
487+
'9999999999999999999999999999.9999999999',
488+
'99999999999999999999999999999999999999.99',
489+
'999999999999999999999999999999999999999999999999999999999999999999.9999999999'
490+
)
477491
----
478492
2
479493

480494

481-
query R rowsort
495+
query R
482496
SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL;
483497
----
484498
NULL
485499

486-
query R rowsort
500+
query R
487501
SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL;
488502
----
489503
NULL
490504

491-
query R rowsort
505+
query R
492506
SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL;
493507
----
494508
NULL
495509

496-
query R rowsort
510+
query R
497511
SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
498512
----
499513
0
@@ -507,19 +521,24 @@ query error DataFusion error: Arrow error: Divide by zero error
507521
SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
508522

509523
# abs: return type
510-
query TT rowsort
511-
SELECT arrow_typeof(c1), arrow_typeof(c2) FROM test_nullable_decimal limit 1
524+
query TTTT
525+
SELECT
526+
arrow_typeof(abs(c1)),
527+
arrow_typeof(abs(c2)),
528+
arrow_typeof(abs(c3)),
529+
arrow_typeof(abs(c4))
530+
FROM test_nullable_decimal limit 1
512531
----
513-
Decimal128(10, 2) Decimal128(38, 10)
532+
Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10)
514533

515-
# abs: Decimal128
516-
query RR rowsort
517-
SELECT abs(c1), abs(c2) FROM test_nullable_decimal
534+
# abs: decimals
535+
query RRRR rowsort
536+
SELECT abs(c1), abs(c2), abs(c3), abs(c4) FROM test_nullable_decimal
518537
----
519-
0 0
520-
99999999.99 9999999999999999999999999999.9999999999
521-
99999999.99 9999999999999999999999999999.9999999999
522-
NULL NULL
538+
0 0 0 0
539+
99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
540+
99999999.99 9999999999999999999999999999.9999999999 99999999999999999999999999999999999999.99 999999999999999999999999999999999999999999999999999999999999999999.9999999999
541+
NULL NULL NULL NULL
523542

524543
statement ok
525544
drop table test_nullable_decimal

0 commit comments

Comments
 (0)