Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 71 additions & 1 deletion parquet-variant-compute/src/shred_variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ mod tests {
use crate::VariantArrayBuilder;
use arrow::array::{Array, Float64Array, Int64Array};
use arrow::datatypes::{DataType, Field, Fields};
use parquet_variant::{ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder};
use parquet_variant::{
ObjectBuilder, ReadOnlyMetadataBuilder, Variant, VariantBuilder, VariantDecimal4,
VariantDecimal8, VariantDecimal16,
};
use std::sync::Arc;

fn create_test_variant_array(values: Vec<Option<Variant<'_, '_>>>) -> VariantArray {
Expand Down Expand Up @@ -446,6 +449,73 @@ mod tests {
assert_eq!(typed_value_field.value(5), 3);
}

#[test]
fn test_decimal_shedding() {
// Test mixed scenarios in a single array
let input = create_test_variant_array(vec![
Some(Variant::from(4200i64)), // successful shred
Some(Variant::from("hello")), // failed shred (string)
None, // array-level null
Some(Variant::Null), // variant null
Some(Variant::from(VariantDecimal4::try_new(314, 2).unwrap())), // successful shred
Some(Variant::from(VariantDecimal8::try_new(271828, 3).unwrap())), // successful shred
Some(Variant::from(
VariantDecimal16::try_new(123456789012345678901234567890, 2).unwrap(),
)), // successful shred
]);

// Test Decimal32 target
let result_decimal32 = shred_variant(&input, &DataType::Decimal32(10, 2)).unwrap();
let typed_value_decimal32 = result_decimal32
.typed_value_field()
.unwrap()
.as_any()
.downcast_ref::<arrow::array::Decimal32Array>()
.unwrap();
assert_eq!(typed_value_decimal32.value(0), 42); // 42 with scale 2
assert!(typed_value_decimal32.is_null(1)); // string doesn't convert to decimal
assert!(typed_value_decimal32.is_null(2)); // array null
assert!(typed_value_decimal32.is_null(3)); // variant null
assert_eq!(typed_value_decimal32.value(4), 314); // 3.14 with scale 2
assert_eq!(typed_value_decimal32.value(5), 2718280); // 271.828 with scale 3
assert!(typed_value_decimal32.is_null(6)); // too large for Decimal32

// Test Decimal64 target
let result_decimal64 = shred_variant(&input, &DataType::Decimal64(20, 2)).unwrap();
let typed_value_decimal64 = result_decimal64
.typed_value_field()
.unwrap()
.as_any()
.downcast_ref::<arrow::array::Decimal64Array>()
.unwrap();
assert_eq!(typed_value_decimal64.value(0), 42); // 42 with
assert!(typed_value_decimal64.is_null(1)); // string doesn't convert to decimal
assert!(typed_value_decimal64.is_null(2)); // array null
assert!(typed_value_decimal64.is_null(3)); // variant null
assert_eq!(typed_value_decimal64.value(4), 314); // 3.
assert_eq!(typed_value_decimal64.value(5), 2718280); // 271.828 with scale 3
assert!(typed_value_decimal64.is_null(6)); // too large for Decimal64

// Test Decimal128 target
let result_decimal128 = shred_variant(&input, &DataType::Decimal128(38, 2)).unwrap();
let typed_value_decimal128 = result_decimal128
.typed_value_field()
.unwrap()
.as_any()
.downcast_ref::<arrow::array::Decimal128Array>()
.unwrap();
assert_eq!(typed_value_decimal128.value(0), 42); // 42 with
assert!(typed_value_decimal128.is_null(1)); // string doesn't convert to decimal
assert!(typed_value_decimal128.is_null(2)); // array null
assert!(typed_value_decimal128.is_null(3)); // variant null
assert_eq!(typed_value_decimal128.value(4), 314); // 3.
assert_eq!(typed_value_decimal128.value(5), 2718280); // 271.828 with scale 3
assert_eq!(
typed_value_decimal128.value(6),
123456789012345678901234567890
);
}

#[test]
fn test_primitive_different_target_types() {
let input = create_test_variant_array(vec![
Expand Down
23 changes: 23 additions & 0 deletions parquet-variant-compute/src/type_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,26 @@ macro_rules! decimal_to_variant_decimal {
}};
}
pub(crate) use decimal_to_variant_decimal;

/// Convert a `VariantDecimal` back to a decimal value with the target scale
macro_rules! variant_decimal_to_decimal {
($variant_decimal:expr, $target_scale:expr, $value_type:ty) => {{
let value = $variant_decimal.integer();
let variant_scale = $variant_decimal.scale();

let scale_factor = $target_scale as i32 - variant_scale as i32;

if scale_factor == 0 {
Some(value)
} else if scale_factor > 0 {
// Variant scale is greater than target scale, need to downscale
let divisor = <$value_type>::pow(10, scale_factor as u32);
<$value_type>::checked_div(value, divisor)
} else {
// Variant scale is less than target scale, need to upscale
let multiplier = <$value_type>::pow(10, (-scale_factor) as u32);
<$value_type>::checked_mul(value, multiplier)
}
}};
}
pub(crate) use variant_decimal_to_decimal;
217 changes: 159 additions & 58 deletions parquet-variant-compute/src/variant_to_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

use arrow::array::{ArrayRef, BinaryViewArray, NullBufferBuilder, PrimitiveBuilder};
use arrow::compute::CastOptions;
use arrow::datatypes::{self, ArrowPrimitiveType, DataType};
use arrow::datatypes::{self, ArrowPrimitiveType, DataType, DecimalType};
use arrow::error::{ArrowError, Result};
use parquet_variant::{Variant, VariantPath};

use crate::type_conversion::PrimitiveFromVariant;
use crate::type_conversion::{PrimitiveFromVariant, variant_decimal_to_decimal};
use crate::{VariantArray, VariantValueArrayBuilder};

use std::sync::Arc;
Expand All @@ -41,6 +41,9 @@ pub(crate) enum PrimitiveVariantToArrowRowBuilder<'a> {
Float16(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float16Type>),
Float32(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float32Type>),
Float64(VariantToPrimitiveArrowRowBuilder<'a, datatypes::Float64Type>),
Decimal32(VariantToDecimal32ArrowRowBuilder<'a>),
Decimal64(VariantToDecimal64ArrowRowBuilder<'a>),
Decimal128(VariantToDecimal128ArrowRowBuilder<'a>),
}

/// Builder for converting variant values into strongly typed Arrow arrays.
Expand Down Expand Up @@ -70,6 +73,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> {
Float16(b) => b.append_null(),
Float32(b) => b.append_null(),
Float64(b) => b.append_null(),
Decimal32(b) => b.append_null(),
Decimal64(b) => b.append_null(),
Decimal128(b) => b.append_null(),
}
}

Expand All @@ -87,6 +93,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> {
Float16(b) => b.append_value(value),
Float32(b) => b.append_value(value),
Float64(b) => b.append_value(value),
Decimal32(b) => b.append_value(value),
Decimal64(b) => b.append_value(value),
Decimal128(b) => b.append_value(value),
}
}

Expand All @@ -104,6 +113,9 @@ impl<'a> PrimitiveVariantToArrowRowBuilder<'a> {
Float16(b) => b.finish(),
Float32(b) => b.finish(),
Float64(b) => b.finish(),
Decimal32(b) => b.finish(),
Decimal64(b) => b.finish(),
Decimal128(b) => b.finish(),
}
}
}
Expand Down Expand Up @@ -145,62 +157,72 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>(
) -> Result<PrimitiveVariantToArrowRowBuilder<'a>> {
use PrimitiveVariantToArrowRowBuilder::*;

let builder = match data_type {
DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
_ if data_type.is_primitive() => {
return Err(ArrowError::NotYetImplemented(format!(
"Primitive data_type {data_type:?} not yet implemented"
)));
}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Not a primitive type: {data_type:?}"
)));
}
};
let builder =
match data_type {
DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
cast_options,
capacity,
)),
DataType::Decimal32(precision, scale) => Decimal32(
VariantToDecimal32ArrowRowBuilder::new(cast_options, capacity, *precision, *scale),
),
DataType::Decimal64(precision, scale) => Decimal64(
VariantToDecimal64ArrowRowBuilder::new(cast_options, capacity, *precision, *scale),
),
DataType::Decimal128(precision, scale) => Decimal128(
VariantToDecimal128ArrowRowBuilder::new(cast_options, capacity, *precision, *scale),
),
_ if data_type.is_primitive() => {
return Err(ArrowError::NotYetImplemented(format!(
"Primitive data_type {data_type:?} not yet implemented"
)));
}
_ => {
return Err(ArrowError::InvalidArgumentError(format!(
"Not a primitive type: {data_type:?}"
)));
}
};
Ok(builder)
}

Expand Down Expand Up @@ -383,3 +405,82 @@ impl VariantToBinaryVariantArrowRowBuilder {
Ok(ArrayRef::from(variant_array))
}
}

macro_rules! define_variant_decimal_builder {
($name:ident, $arrow_type:ty, $variant_method:ident, $native:ty) => {
pub(crate) struct $name<'a> {
builder: PrimitiveBuilder<$arrow_type>,
cast_options: &'a CastOptions<'a>,
precision: u8,
scale: i8,
}

impl<'a> $name<'a> {
fn new(
cast_options: &'a CastOptions<'a>,
capacity: usize,
precision: u8,
scale: i8,
) -> Self {
let builder = PrimitiveBuilder::<$arrow_type>::with_capacity(capacity)
.with_data_type(<$arrow_type>::TYPE_CONSTRUCTOR(precision, scale));
Self {
builder,
cast_options,
precision,
scale,
}
}

fn append_null(&mut self) -> Result<()> {
self.builder.append_null();
Ok(())
}

fn append_value(&mut self, value: &Variant<'_, '_>) -> Result<bool> {
if let Some(decimal) = value.$variant_method() {
dbg!(&decimal);
dbg!(&self.scale);
let value = variant_decimal_to_decimal!(decimal, self.scale, $native);
if let Some(v) = value {

self.builder.append_value(v);
return Ok(true);
}
}
if !self.cast_options.safe {
return Err(ArrowError::CastError(format!(
"Failed to extract decimal of type {:?} from variant {:?} at path VariantPath([])",
<$arrow_type>::TYPE_CONSTRUCTOR(self.precision, self.scale),
value
)));
}
self.builder.append_null();
Ok(false)
}

fn finish(mut self) -> Result<ArrayRef> {
Ok(Arc::new(self.builder.finish()))
}
}
};
}

define_variant_decimal_builder!(
VariantToDecimal32ArrowRowBuilder,
datatypes::Decimal32Type,
as_decimal4,
i32
);
define_variant_decimal_builder!(
VariantToDecimal64ArrowRowBuilder,
datatypes::Decimal64Type,
as_decimal8,
i64
);
define_variant_decimal_builder!(
VariantToDecimal128ArrowRowBuilder,
datatypes::Decimal128Type,
as_decimal16,
i128
);
Loading