From b73c1eaf0d4cec5dfbc043600d7a8710d12bfa95 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Fri, 3 Oct 2025 16:22:16 +0300 Subject: [PATCH 1/2] feat: Support Decimal32, 64, 128 in variant to arrow --- parquet-variant-compute/src/shred_variant.rs | 72 +++++- .../src/type_conversion.rs | 23 ++ .../src/variant_to_arrow.rs | 217 +++++++++++++----- 3 files changed, 253 insertions(+), 59 deletions(-) diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index 4c5a3c3f8a45..281723e5b42d 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -328,7 +328,10 @@ mod tests { use crate::VariantArrayBuilder; use arrow::array::{Array, Float64Array, Int64Array}; use arrow::datatypes::{DataType, Field, Fields}; - use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder}; + use parquet_variant::{ + ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder, VariantDecimal16, + VariantDecimal4, VariantDecimal8, + }; use std::sync::Arc; fn create_test_variant_array(values: Vec>>) -> VariantArray { @@ -446,6 +449,73 @@ mod tests { assert_eq!(typed_value_field.value(5), 3); } + #[test] + fn test_decimal_shedding() { + // Test mixed scenarios in a single array + let input = create_test_variant_array(vec![ + Some(Variant::from(4200i64)), // successful shred + Some(Variant::from("hello")), // failed shred (string) + None, // array-level null + Some(Variant::Null), // variant null + Some(Variant::from(VariantDecimal4::try_new(314, 2).unwrap())), // successful shred + Some(Variant::from(VariantDecimal8::try_new(271828, 3).unwrap())), // successful shred + Some(Variant::from( + VariantDecimal16::try_new(123456789012345678901234567890, 2).unwrap(), + )), // successful shred + ]); + + // Test Decimal32 target + let result_decimal32 = shred_variant(&input, &DataType::Decimal32(10, 2)).unwrap(); + let typed_value_decimal32 = result_decimal32 + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(typed_value_decimal32.value(0), 42); // 42 with scale 2 + assert!(typed_value_decimal32.is_null(1)); // string doesn't convert to decimal + assert!(typed_value_decimal32.is_null(2)); // array null + assert!(typed_value_decimal32.is_null(3)); // variant null + assert_eq!(typed_value_decimal32.value(4), 314); // 3.14 with scale 2 + assert_eq!(typed_value_decimal32.value(5), 2718280); // 271.828 with scale 3 + assert!(typed_value_decimal32.is_null(6)); // too large for Decimal32 + + // Test Decimal64 target + let result_decimal64 = shred_variant(&input, &DataType::Decimal64(20, 2)).unwrap(); + let typed_value_decimal64 = result_decimal64 + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(typed_value_decimal64.value(0), 42); // 42 with + assert!(typed_value_decimal64.is_null(1)); // string doesn't convert to decimal + assert!(typed_value_decimal64.is_null(2)); // array null + assert!(typed_value_decimal64.is_null(3)); // variant null + assert_eq!(typed_value_decimal64.value(4), 314); // 3. + assert_eq!(typed_value_decimal64.value(5), 2718280); // 271.828 with scale 3 + assert!(typed_value_decimal64.is_null(6)); // too large for Decimal64 + + // Test Decimal128 target + let result_decimal128 = shred_variant(&input, &DataType::Decimal128(38, 2)).unwrap(); + let typed_value_decimal128 = result_decimal128 + .typed_value_field() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(typed_value_decimal128.value(0), 42); // 42 with + assert!(typed_value_decimal128.is_null(1)); // string doesn't convert to decimal + assert!(typed_value_decimal128.is_null(2)); // array null + assert!(typed_value_decimal128.is_null(3)); // variant null + assert_eq!(typed_value_decimal128.value(4), 314); // 3. + assert_eq!(typed_value_decimal128.value(5), 2718280); // 271.828 with scale 3 + assert_eq!( + typed_value_decimal128.value(6), + 123456789012345678901234567890 + ); + } + #[test] fn test_primitive_different_target_types() { let input = create_test_variant_array(vec![ diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 5dda1855297a..905491c58b9d 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -121,3 +121,26 @@ macro_rules! decimal_to_variant_decimal { }}; } pub(crate) use decimal_to_variant_decimal; + +/// Convert a `VariantDecimal` back to a decimal value with the target scale +macro_rules! variant_decimal_to_decimal { + ($variant_decimal:expr, $target_scale:expr, $value_type:ty) => {{ + let value = $variant_decimal.integer(); + let variant_scale = $variant_decimal.scale(); + + let scale_factor = $target_scale as i32 - variant_scale as i32; + + if scale_factor == 0 { + Some(value) + } else if scale_factor > 0 { + // Variant scale is greater than target scale, need to downscale + let divisor = <$value_type>::pow(10, scale_factor as u32); + <$value_type>::checked_div(value, divisor) + } else { + // Variant scale is less than target scale, need to upscale + let multiplier = <$value_type>::pow(10, (-scale_factor) as u32); + <$value_type>::checked_mul(value, multiplier) + } + }}; +} +pub(crate) use variant_decimal_to_decimal; diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index 50249aa63d20..a63a24de14ad 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -17,11 +17,11 @@ use arrow::array::{ArrayRef, BinaryViewArray, NullBufferBuilder, PrimitiveBuilder}; use arrow::compute::CastOptions; -use arrow::datatypes::{self, ArrowPrimitiveType, DataType}; +use arrow::datatypes::{self, ArrowPrimitiveType, DataType, DecimalType}; use arrow::error::{ArrowError, Result}; use parquet_variant::{Variant, VariantPath}; -use crate::type_conversion::PrimitiveFromVariant; +use crate::type_conversion::{variant_decimal_to_decimal, PrimitiveFromVariant}; use crate::{VariantArray, VariantValueArrayBuilder}; use std::sync::Arc; @@ -41,6 +41,9 @@ pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> { Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>), Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>), Float64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float64Type>), + Decimal32(VariantToDecimal32ArrowRowBuilder<'a>), + Decimal64(VariantToDecimal64ArrowRowBuilder<'a>), + Decimal128(VariantToDecimal128ArrowRowBuilder<'a>), } /// Builder for converting variant values into strongly typed Arrow arrays. @@ -70,6 +73,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.append_null(), Float32(b) => b.append_null(), Float64(b) => b.append_null(), + Decimal32(b) => b.append_null(), + Decimal64(b) => b.append_null(), + Decimal128(b) => b.append_null(), } } @@ -87,6 +93,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.append_value(value), Float32(b) => b.append_value(value), Float64(b) => b.append_value(value), + Decimal32(b) => b.append_value(value), + Decimal64(b) => b.append_value(value), + Decimal128(b) => b.append_value(value), } } @@ -104,6 +113,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> { Float16(b) => b.finish(), Float32(b) => b.finish(), Float64(b) => b.finish(), + Decimal32(b) => b.finish(), + Decimal64(b) => b.finish(), + Decimal128(b) => b.finish(), } } } @@ -145,62 +157,72 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>( ) -> Result> { use PrimitiveVariantToArrowRowBuilder::*; - let builder = match data_type { - DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( - cast_options, - capacity, - )), - _ if data_type.is_primitive() => { - return Err(ArrowError::NotYetImplemented(format!( - "Primitive data_type {data_type:?} not yet implemented" - ))); - } - _ => { - return Err(ArrowError::InvalidArgumentError(format!( - "Not a primitive type: {data_type:?}" - ))); - } - }; + let builder = + match data_type { + DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new( + cast_options, + capacity, + )), + DataType::Decimal32(precision, scale) => Decimal32( + VariantToDecimal32ArrowRowBuilder::new(cast_options, capacity, *precision, *scale), + ), + DataType::Decimal64(precision, scale) => Decimal64( + VariantToDecimal64ArrowRowBuilder::new(cast_options, capacity, *precision, *scale), + ), + DataType::Decimal128(precision, scale) => Decimal128( + VariantToDecimal128ArrowRowBuilder::new(cast_options, capacity, *precision, *scale), + ), + _ if data_type.is_primitive() => { + return Err(ArrowError::NotYetImplemented(format!( + "Primitive data_type {data_type:?} not yet implemented" + ))); + } + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Not a primitive type: {data_type:?}" + ))); + } + }; Ok(builder) } @@ -383,3 +405,82 @@ impl VariantToBinaryVariantArrowRowBuilder { Ok(ArrayRef::from(variant_array)) } } + +macro_rules! define_variant_decimal_builder { + ($name:ident, $arrow_type:ty, $variant_method:ident, $native:ty) => { + pub(crate) struct $name<'a> { + builder: PrimitiveBuilder<$arrow_type>, + cast_options: &'a CastOptions<'a>, + precision: u8, + scale: i8, + } + + impl<'a> $name<'a> { + fn new( + cast_options: &'a CastOptions<'a>, + capacity: usize, + precision: u8, + scale: i8, + ) -> Self { + let builder = PrimitiveBuilder::<$arrow_type>::with_capacity(capacity) + .with_data_type(<$arrow_type>::TYPE_CONSTRUCTOR(precision, scale)); + Self { + builder, + cast_options, + precision, + scale, + } + } + + fn append_null(&mut self) -> Result<()> { + self.builder.append_null(); + Ok(()) + } + + fn append_value(&mut self, value: &Variant<'_, '_>) -> Result { + if let Some(decimal) = value.$variant_method() { + dbg!(&decimal); + dbg!(&self.scale); + let value = variant_decimal_to_decimal!(decimal, self.scale, $native); + if let Some(v) = value { + + self.builder.append_value(v); + return Ok(true); + } + } + if !self.cast_options.safe { + return Err(ArrowError::CastError(format!( + "Failed to extract decimal of type {:?} from variant {:?} at path VariantPath([])", + <$arrow_type>::TYPE_CONSTRUCTOR(self.precision, self.scale), + value + ))); + } + self.builder.append_null(); + Ok(false) + } + + fn finish(mut self) -> Result { + Ok(Arc::new(self.builder.finish())) + } + } + }; +} + +define_variant_decimal_builder!( + VariantToDecimal32ArrowRowBuilder, + datatypes::Decimal32Type, + as_decimal4, + i32 +); +define_variant_decimal_builder!( + VariantToDecimal64ArrowRowBuilder, + datatypes::Decimal64Type, + as_decimal8, + i64 +); +define_variant_decimal_builder!( + VariantToDecimal128ArrowRowBuilder, + datatypes::Decimal128Type, + as_decimal16, + i128 +); From df0387f0d09f9d1686deb644628105e1ed65dd3d Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Fri, 3 Oct 2025 17:40:31 +0300 Subject: [PATCH 2/2] chore --- parquet-variant-compute/src/shred_variant.rs | 4 ++-- parquet-variant-compute/src/variant_to_arrow.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/parquet-variant-compute/src/shred_variant.rs b/parquet-variant-compute/src/shred_variant.rs index 281723e5b42d..9987e9376700 100644 --- a/parquet-variant-compute/src/shred_variant.rs +++ b/parquet-variant-compute/src/shred_variant.rs @@ -329,8 +329,8 @@ mod tests { use arrow::array::{Array, Float64Array, Int64Array}; use arrow::datatypes::{DataType, Field, Fields}; use parquet_variant::{ - ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder, VariantDecimal16, - VariantDecimal4, VariantDecimal8, + ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder, VariantDecimal4, + VariantDecimal8, VariantDecimal16, }; use std::sync::Arc; diff --git a/parquet-variant-compute/src/variant_to_arrow.rs b/parquet-variant-compute/src/variant_to_arrow.rs index a63a24de14ad..8d9886c826b7 100644 --- a/parquet-variant-compute/src/variant_to_arrow.rs +++ b/parquet-variant-compute/src/variant_to_arrow.rs @@ -21,7 +21,7 @@ use arrow::datatypes::{self, ArrowPrimitiveType, DataType, DecimalType}; use arrow::error::{ArrowError, Result}; use parquet_variant::{Variant, VariantPath}; -use crate::type_conversion::{variant_decimal_to_decimal, PrimitiveFromVariant}; +use crate::type_conversion::{PrimitiveFromVariant, variant_decimal_to_decimal}; use crate::{VariantArray, VariantValueArrayBuilder}; use std::sync::Arc;