diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index e47852a5e3cd..cf4d3df35b28 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -590,6 +590,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::parse::parse_decimal; #[test] fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> { @@ -598,7 +599,20 @@ mod tests { 0_i128 ); assert_eq!( + parse_decimal::("0", 38, 0)?, + parse_string_to_decimal_native::("0", 0)?, + "value is {}", + 0_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("0", 5)?, + 0_i128 + ); + assert_eq!( + parse_decimal::("0", 38, 5)?, parse_string_to_decimal_native::("0", 5)?, + "value is {}", 0_i128 ); @@ -607,7 +621,20 @@ mod tests { 123_i128 ); assert_eq!( + parse_decimal::("123", 38, 0)?, + parse_string_to_decimal_native::("123", 0)?, + "value is {}", + 123_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123", 5)?, + 12300000_i128 + ); + assert_eq!( + parse_decimal::("123", 38, 5)?, parse_string_to_decimal_native::("123", 5)?, + "value is {}", 12300000_i128 ); @@ -616,7 +643,20 @@ mod tests { 123_i128 ); assert_eq!( + parse_decimal::("123.45", 38, 0)?, + parse_string_to_decimal_native::("123.45", 0)?, + "value is {}", + 123_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123.45", 5)?, + 12345000_i128 + ); + assert_eq!( + parse_decimal::("123.45", 38, 5)?, parse_string_to_decimal_native::("123.45", 5)?, + "value is {}", 12345000_i128 ); @@ -625,7 +665,20 @@ mod tests { 123_i128 ); assert_eq!( + parse_decimal::("123.4567891", 38, 0)?, + parse_string_to_decimal_native::("123.4567891", 0)?, + "value is {}", + 123_i128 + ); + + assert_eq!( + parse_string_to_decimal_native::("123.4567891", 5)?, + 12345679_i128 + ); + assert_eq!( + parse_decimal::("123.4567891", 38, 5)?, parse_string_to_decimal_native::("123.4567891", 5)?, + "value is {}", 12345679_i128 ); Ok(()) diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 55834ad92a01..8f27a4e1232e 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -850,7 +850,18 @@ fn parse_e_notation( } if exp < 0 { - result = result.div_wrapping(base.pow_wrapping(-exp as _)); + let result_with_scale = result.div_wrapping(base.pow_wrapping(-exp as _)); + let result_with_one_scale_up = + result.div_wrapping(base.pow_wrapping(-exp.add_wrapping(1) as _)); + let rounding_digit = + result_with_one_scale_up.sub_wrapping(result_with_scale.mul_wrapping(base)); + //rounding digit is the next digit after result with scale, it helps in rounding to nearest integer + // with scale 1 rounding digit for 247e-2 is 7, hence result is 2.5, whereas rounding digit for 244e-2 is 4, hence result is 2.4 + if rounding_digit >= T::Native::usize_as(5) { + result = result_with_scale.add_wrapping(T::Native::usize_as(1)); + } else { + result = result_with_scale; + } } else { result = result.mul_wrapping(base.pow_wrapping(exp as _)); } @@ -866,8 +877,9 @@ pub fn parse_decimal( scale: i8, ) -> Result { let mut result = T::Native::usize_as(0); - let mut fractionals: i8 = 0; - let mut digits: u8 = 0; + let mut fractionals: i16 = 0; + let mut digits: u16 = 0; + let mut rounding_digit = -1; // to store digit after the scale for rounding let base = T::Native::usize_as(10); let bs = s.as_bytes(); @@ -897,6 +909,13 @@ pub fn parse_decimal( // Ignore leading zeros. continue; } + if fractionals == scale as i16 && scale != 0 { + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } + continue; + } digits += 1; result = result.mul_wrapping(base); result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); @@ -909,8 +928,8 @@ pub fn parse_decimal( if *b == b'e' || *b == b'E' { result = parse_e_notation::( s, - digits as u16, - fractionals as i16, + digits, + fractionals, result, point_index, precision as u16, @@ -925,11 +944,17 @@ pub fn parse_decimal( "can't parse the string value {s} to decimal" ))); } - if fractionals == scale && scale != 0 { + if fractionals == scale as i16 { + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } // We have processed all the digits that we need. All that // is left is to validate that the rest of the string contains // valid digits. - continue; + if scale != 0 { + continue; + } } fractionals += 1; digits += 1; @@ -951,8 +976,8 @@ pub fn parse_decimal( b'e' | b'E' => { result = parse_e_notation::( s, - digits as u16, - fractionals as i16, + digits, + fractionals, result, index, precision as u16, @@ -972,20 +997,28 @@ pub fn parse_decimal( } if !is_e_notation { - if fractionals < scale { - let exp = scale - fractionals; - if exp as u8 + digits > precision { + if fractionals < scale as i16 { + let exp = scale as i16 - fractionals; + if exp + digits as i16 > precision as i16 { return Err(ArrowError::ParseError(format!( "parse decimal overflow ({s})" ))); } let mul = base.pow_wrapping(exp as _); result = result.mul_wrapping(mul); - } else if digits > precision { + } else if digits > precision as u16 { return Err(ArrowError::ParseError(format!( "parse decimal overflow ({s})" ))); } + if scale == 0 { + result = result.div_wrapping(base.pow_wrapping(fractionals as u32)) + } + //rounding digit is the next digit after result with scale, it is used to do rounding to nearest integer + // with scale 1 rounding digit for 2.47 is 7, hence result is 2.5, whereas rounding digit for 2.44 is 4,hence result is 2.4 + if rounding_digit >= 5 { + result = result.add_wrapping(T::Native::usize_as(1)); + } } Ok(if negative { @@ -2564,6 +2597,18 @@ mod tests { assert_eq!(i256::from_i128(i), result_256.unwrap()); } + let tests_with_varying_scale = [ + ("123.4567891", 12345679_i128, 5), + ("123.4567891", 123_i128, 0), + ("123.45", 12345000_i128, 5), + ("-2.5", -3_i128, 0), + ("-2.49", -2_i128, 0), + ]; + for (str, e, scale) in tests_with_varying_scale { + let result_128_a = parse_decimal::(str, 20, scale); + assert_eq!(result_128_a.unwrap(), e); + } + let e_notation_tests = [ ("1.23e3", "1230.0", 2), ("5.6714e+2", "567.14", 4), @@ -2599,6 +2644,9 @@ mod tests { ("000001.1034567002e0", "000001.1034567002", 3), ("1.234e16", "12340000000000000", 0), ("123.4e16", "1234000000000000000", 0), + ("4e+5", "400000", 4), + ("4e7", "40000000", 2), + ("1265E-4", ".1265", 3), ]; for (e, d, scale) in e_notation_tests { let result_128_e = parse_decimal::(e, 20, scale); @@ -2608,6 +2656,7 @@ mod tests { let result_256_d = parse_decimal::(d, 20, scale); assert_eq!(result_256_e.unwrap(), result_256_d.unwrap()); } + let can_not_parse_tests = [ "123,123", ".", @@ -2780,6 +2829,55 @@ mod tests { } } + #[test] + fn test_parse_decimal_rounding() { + let test_rounding_for_e_notation_varying_scale = [ + ("1.2345e4", "12345", 2), + ("12345e-5", "0.12", 2), + ("12345E-5", "0.123", 3), + ("12345e-5", "0.1235", 4), + ("1265E-4", ".127", 3), + ("12.345e3", "12345.000", 3), + ("1.2345e4", "12345", 0), + ("1.2345e3", "1235", 0), + ("1.23e-3", "0", 0), + ("123e-2", "1", 0), + ("-1e-15", "-0.0000000000", 10), + ("1e-15", "0.0000000000", 10), + ("1e15", "1000000000000000", 2), + ]; + + for (e, d, scale) in test_rounding_for_e_notation_varying_scale { + let result_128_e = parse_decimal::(e, 38, scale); + let result_128_d = parse_decimal::(d, 38, scale); + assert_eq!(result_128_e.unwrap(), result_128_d.unwrap()); + let result_256_e = parse_decimal::(e, 38, scale); + let result_256_d = parse_decimal::(d, 38, scale); + assert_eq!(result_256_e.unwrap(), result_256_d.unwrap()); + } + + let edge_tests_256_error = [ + (&f64::INFINITY.to_string(), 0), + (&f64::NEG_INFINITY.to_string(), 0), + ]; + for (s, scale) in edge_tests_256_error { + let result = parse_decimal::(s, 76, scale); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result.unwrap_err().to_string() + ); + } + + let edge_tests_256_overflow = [(&f64::MIN.to_string(), 0), (&f64::MAX.to_string(), 0)]; + for (s, scale) in edge_tests_256_overflow { + let result = parse_decimal::(s, 76, scale); + assert_eq!( + format!("Parser error: parse decimal overflow ({s})"), + result.unwrap_err().to_string() + ); + } + } + #[test] fn test_parse_empty() { assert_eq!(Int32Type::parse(""), None); diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index e3ab013a57c1..1b63fd293117 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -1286,7 +1286,7 @@ mod tests { assert_eq!("53.002666", lat.value_as_string(1)); assert_eq!("52.412811", lat.value_as_string(2)); assert_eq!("51.481583", lat.value_as_string(3)); - assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("12.123457", lat.value_as_string(4)); assert_eq!("50.760000", lat.value_as_string(5)); assert_eq!("0.123000", lat.value_as_string(6)); assert_eq!("123.000000", lat.value_as_string(7)); diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 14a8f6809f70..17c3ec012d7e 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -1181,7 +1181,7 @@ mod tests { assert!(col1.is_null(5)); assert_eq!( col1.values(), - &[100, 200, 204, 1103420, 0, 0].map(T::Native::usize_as) + &[100, 200, 205, 1103420, 0, 0].map(T::Native::usize_as) ); let col2 = batches[0].column(1).as_primitive::(); @@ -1201,7 +1201,7 @@ mod tests { assert!(col3.is_null(5)); assert_eq!( col3.values(), - &[3830, 12345, 0, 0, 0, 0].map(T::Native::usize_as) + &[3830, 12346, 0, 0, 0, 0].map(T::Native::usize_as) ); }