From c7c74cea150290492308f63967c563cfded39f81 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Thu, 6 Mar 2025 12:36:40 -0800 Subject: [PATCH 1/5] Initial prototype --- arrow-json/src/reader/boolean_array.rs | 12 +- arrow-json/src/reader/decimal_array.rs | 34 +-- arrow-json/src/reader/list_array.rs | 8 + arrow-json/src/reader/map_array.rs | 9 + arrow-json/src/reader/mod.rs | 257 ++++++++++++++++++++--- arrow-json/src/reader/null_array.rs | 13 +- arrow-json/src/reader/primitive_array.rs | 36 ++-- arrow-json/src/reader/string_array.rs | 6 +- arrow-json/src/reader/struct_array.rs | 12 ++ arrow-json/src/reader/timestamp_array.rs | 31 +-- 10 files changed, 339 insertions(+), 79 deletions(-) diff --git a/arrow-json/src/reader/boolean_array.rs b/arrow-json/src/reader/boolean_array.rs index 9094391cd7dd..4e933fd877be 100644 --- a/arrow-json/src/reader/boolean_array.rs +++ b/arrow-json/src/reader/boolean_array.rs @@ -24,7 +24,16 @@ use crate::reader::tape::{Tape, TapeElement}; use crate::reader::ArrayDecoder; #[derive(Default)] -pub struct BooleanArrayDecoder {} +pub struct BooleanArrayDecoder { + ignore_type_conflicts: bool, +} +impl BooleanArrayDecoder { + pub fn new(ignore_type_conflicts: bool) -> Self { + Self { + ignore_type_conflicts, + } + } +} impl ArrayDecoder for BooleanArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { @@ -34,6 +43,7 @@ impl ArrayDecoder for BooleanArrayDecoder { TapeElement::Null => builder.append_null(), TapeElement::True => builder.append_value(true), TapeElement::False => builder.append_value(false), + _ if self.ignore_type_conflicts => builder.append_null(), _ => return Err(tape.error(*p, "boolean")), } } diff --git a/arrow-json/src/reader/decimal_array.rs b/arrow-json/src/reader/decimal_array.rs index d56afcfe807a..e3c4fb3c1827 100644 --- a/arrow-json/src/reader/decimal_array.rs +++ b/arrow-json/src/reader/decimal_array.rs @@ -30,15 +30,17 @@ use crate::reader::ArrayDecoder; pub struct DecimalArrayDecoder { 
precision: u8, scale: i8, + ignore_type_conflicts: bool, // Invariant and Send phantom: PhantomData D>, } impl DecimalArrayDecoder { - pub fn new(precision: u8, scale: i8) -> Self { + pub fn new(precision: u8, scale: i8, ignore_type_conflicts: bool) -> Self { Self { precision, scale, + ignore_type_conflicts, phantom: PhantomData, } } @@ -50,46 +52,52 @@ where { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::::with_capacity(pos.len()); + let append = if self.ignore_type_conflicts { + super::append_value_or_null + } else { + super::try_append_value + }; for p in pos { match tape.get(*p) { TapeElement::Null => builder.append_null(), TapeElement::String(idx) => { let s = tape.get_string(idx); - let value = parse_decimal::(s, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(s, self.precision, self.scale); + append(&mut builder, value)? } TapeElement::Number(idx) => { let s = tape.get_string(idx); - let value = parse_decimal::(s, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(s, self.precision, self.scale); + append(&mut builder, value)? } TapeElement::I64(high) => match tape.get(*p + 1) { TapeElement::I32(low) => { let val = (((high as i64) << 32) | (low as u32) as i64).to_string(); - let value = parse_decimal::(&val, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(&val, self.precision, self.scale); + append(&mut builder, value)? } _ => unreachable!(), }, TapeElement::I32(val) => { let s = val.to_string(); - let value = parse_decimal::(&s, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(&s, self.precision, self.scale); + append(&mut builder, value)? 
} TapeElement::F64(high) => match tape.get(*p + 1) { TapeElement::F32(low) => { let val = f64::from_bits(((high as u64) << 32) | low as u64).to_string(); - let value = parse_decimal::(&val, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(&val, self.precision, self.scale); + append(&mut builder, value)? } _ => unreachable!(), }, TapeElement::F32(val) => { let s = f32::from_bits(val).to_string(); - let value = parse_decimal::(&s, self.precision, self.scale)?; - builder.append_value(value) + let value = parse_decimal::(&s, self.precision, self.scale); + append(&mut builder, value)? } + _ if self.ignore_type_conflicts => builder.append_null(), _ => return Err(tape.error(*p, "decimal")), } } diff --git a/arrow-json/src/reader/list_array.rs b/arrow-json/src/reader/list_array.rs index 1a1dee6a23d4..7c05cffa6c76 100644 --- a/arrow-json/src/reader/list_array.rs +++ b/arrow-json/src/reader/list_array.rs @@ -29,6 +29,7 @@ pub struct ListArrayDecoder { data_type: DataType, decoder: Box, phantom: PhantomData, + ignore_type_conflicts: bool, is_nullable: bool, } @@ -37,6 +38,7 @@ impl ListArrayDecoder { data_type: DataType, coerce_primitive: bool, strict_mode: bool, + ignore_type_conflicts: bool, is_nullable: bool, struct_mode: StructMode, ) -> Result { @@ -49,6 +51,7 @@ impl ListArrayDecoder { field.data_type().clone(), coerce_primitive, strict_mode, + ignore_type_conflicts, field.is_nullable(), struct_mode, )?; @@ -57,6 +60,7 @@ impl ListArrayDecoder { data_type, decoder, phantom: Default::default(), + ignore_type_conflicts, is_nullable, }) } @@ -83,6 +87,10 @@ impl ArrayDecoder for ListArrayDecoder { nulls.append(false); *p + 1 } + (_, Some(nulls)) if self.ignore_type_conflicts => { + nulls.append(false); + *p + 1 + } _ => return Err(tape.error(*p, "[")), }; diff --git a/arrow-json/src/reader/map_array.rs b/arrow-json/src/reader/map_array.rs index ee78373a551e..b9b131ac134f 100644 --- a/arrow-json/src/reader/map_array.rs +++ 
b/arrow-json/src/reader/map_array.rs @@ -28,6 +28,7 @@ pub struct MapArrayDecoder { data_type: DataType, keys: Box, values: Box, + ignore_type_conflicts: bool, is_nullable: bool, } @@ -36,6 +37,7 @@ impl MapArrayDecoder { data_type: DataType, coerce_primitive: bool, strict_mode: bool, + ignore_type_conflicts: bool, is_nullable: bool, struct_mode: StructMode, ) -> Result { @@ -60,6 +62,7 @@ impl MapArrayDecoder { fields[0].data_type().clone(), coerce_primitive, strict_mode, + ignore_type_conflicts, fields[0].is_nullable(), struct_mode, )?; @@ -67,6 +70,7 @@ impl MapArrayDecoder { fields[1].data_type().clone(), coerce_primitive, strict_mode, + ignore_type_conflicts, fields[1].is_nullable(), struct_mode, )?; @@ -75,6 +79,7 @@ impl MapArrayDecoder { data_type, keys, values, + ignore_type_conflicts, is_nullable, }) } @@ -111,6 +116,10 @@ impl ArrayDecoder for MapArrayDecoder { nulls.append(false); p + 1 } + (_, Some(nulls)) if self.ignore_type_conflicts => { + nulls.append(false); + p + 1 + } _ => return Err(tape.error(p, "{")), }; diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 14a8f6809f70..bb8afe93f6fb 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -143,6 +143,7 @@ use serde::Serialize; use arrow_array::timezone::Tz; use arrow_array::types::*; use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader, StructArray}; +use arrow_array::builder::PrimitiveBuilder; use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; @@ -168,7 +169,7 @@ mod schema; mod serializer; mod string_array; mod struct_array; -mod tape; +pub mod tape; mod timestamp_array; /// A builder for [`Reader`] and [`Decoder`] @@ -176,6 +177,7 @@ pub struct ReaderBuilder { batch_size: usize, coerce_primitive: bool, strict_mode: bool, + ignore_type_conflicts: bool, is_field: bool, struct_mode: StructMode, @@ -196,6 +198,7 @@ impl ReaderBuilder { 
batch_size: 1024, coerce_primitive: false, strict_mode: false, + ignore_type_conflicts: false, is_field: false, struct_mode: Default::default(), schema, @@ -237,6 +240,7 @@ impl ReaderBuilder { batch_size: 1024, coerce_primitive: false, strict_mode: false, + ignore_type_conflicts: false, is_field: true, struct_mode: Default::default(), schema: Arc::new(Schema::new([field.into()])), @@ -279,6 +283,15 @@ impl ReaderBuilder { } } + /// Sets if the decoder should produce NULL instead of returning an error if it encounters a + /// type conflict on a nullable column. + pub fn with_ignore_type_conflicts(self, ignore_type_conflicts: bool) -> Self { + Self { + ignore_type_conflicts, + ..self + } + } + /// Create a [`Reader`] with the provided [`BufRead`] pub fn build(self, reader: R) -> Result, ArrowError> { Ok(Reader { @@ -301,6 +314,7 @@ impl ReaderBuilder { data_type, self.coerce_primitive, self.strict_mode, + self.ignore_type_conflicts, nullable, self.struct_mode, )?; @@ -672,8 +686,11 @@ trait ArrayDecoder: Send { } macro_rules! primitive_decoder { - ($t:ty, $data_type:expr) => { - Ok(Box::new(PrimitiveArrayDecoder::<$t>::new($data_type))) + ($t:ty, $data_type:expr, $ignore_type_conflicts:expr) => { + Ok(Box::new(PrimitiveArrayDecoder::<$t>::new( + $data_type, + $ignore_type_conflicts, + ))) }; } @@ -681,69 +698,91 @@ fn make_decoder( data_type: DataType, coerce_primitive: bool, strict_mode: bool, + ignore_type_conflicts: bool, is_nullable: bool, struct_mode: StructMode, ) -> Result, ArrowError> { downcast_integer! 
{ - data_type => (primitive_decoder, data_type), - DataType::Null => Ok(Box::::default()), - DataType::Float16 => primitive_decoder!(Float16Type, data_type), - DataType::Float32 => primitive_decoder!(Float32Type, data_type), - DataType::Float64 => primitive_decoder!(Float64Type, data_type), + data_type => (primitive_decoder, data_type, ignore_type_conflicts), + DataType::Null => Ok(Box::new(NullArrayDecoder::new(ignore_type_conflicts))), + DataType::Float16 => primitive_decoder!(Float16Type, data_type, ignore_type_conflicts), + DataType::Float32 => primitive_decoder!(Float32Type, data_type, ignore_type_conflicts), + DataType::Float64 => primitive_decoder!(Float64Type, data_type, ignore_type_conflicts), DataType::Timestamp(TimeUnit::Second, None) => { - Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Millisecond, None) => { - Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Microsecond, None) => { - Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Nanosecond, None) => { - Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, Utc, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Second, Some(ref tz)) => { let tz: Tz = tz.parse()?; - Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Millisecond, Some(ref tz)) => { let tz: Tz = tz.parse()?; - Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz, ignore_type_conflicts))) }, 
DataType::Timestamp(TimeUnit::Microsecond, Some(ref tz)) => { let tz: Tz = tz.parse()?; - Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz, ignore_type_conflicts))) }, DataType::Timestamp(TimeUnit::Nanosecond, Some(ref tz)) => { let tz: Tz = tz.parse()?; - Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz))) + Ok(Box::new(TimestampArrayDecoder::::new(data_type, tz, ignore_type_conflicts))) }, - DataType::Date32 => primitive_decoder!(Date32Type, data_type), - DataType::Date64 => primitive_decoder!(Date64Type, data_type), - DataType::Time32(TimeUnit::Second) => primitive_decoder!(Time32SecondType, data_type), - DataType::Time32(TimeUnit::Millisecond) => primitive_decoder!(Time32MillisecondType, data_type), - DataType::Time64(TimeUnit::Microsecond) => primitive_decoder!(Time64MicrosecondType, data_type), - DataType::Time64(TimeUnit::Nanosecond) => primitive_decoder!(Time64NanosecondType, data_type), - DataType::Duration(TimeUnit::Nanosecond) => primitive_decoder!(DurationNanosecondType, data_type), - DataType::Duration(TimeUnit::Microsecond) => primitive_decoder!(DurationMicrosecondType, data_type), - DataType::Duration(TimeUnit::Millisecond) => primitive_decoder!(DurationMillisecondType, data_type), - DataType::Duration(TimeUnit::Second) => primitive_decoder!(DurationSecondType, data_type), - DataType::Decimal128(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s))), - DataType::Decimal256(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s))), - DataType::Boolean => Ok(Box::::default()), - DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive))), - DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), - DataType::LargeList(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, is_nullable, 
struct_mode)?)), - DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::Date32 => primitive_decoder!(Date32Type, data_type, ignore_type_conflicts), + DataType::Date64 => primitive_decoder!(Date64Type, data_type, ignore_type_conflicts), + DataType::Time32(TimeUnit::Second) => primitive_decoder!(Time32SecondType, data_type, ignore_type_conflicts), + DataType::Time32(TimeUnit::Millisecond) => primitive_decoder!(Time32MillisecondType, data_type, ignore_type_conflicts), + DataType::Time64(TimeUnit::Microsecond) => primitive_decoder!(Time64MicrosecondType, data_type, ignore_type_conflicts), + DataType::Time64(TimeUnit::Nanosecond) => primitive_decoder!(Time64NanosecondType, data_type, ignore_type_conflicts), + DataType::Duration(TimeUnit::Nanosecond) => primitive_decoder!(DurationNanosecondType, data_type, ignore_type_conflicts), + DataType::Duration(TimeUnit::Microsecond) => primitive_decoder!(DurationMicrosecondType, data_type, ignore_type_conflicts), + DataType::Duration(TimeUnit::Millisecond) => primitive_decoder!(DurationMillisecondType, data_type, ignore_type_conflicts), + DataType::Duration(TimeUnit::Second) => primitive_decoder!(DurationSecondType, data_type, ignore_type_conflicts), + DataType::Decimal128(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s, ignore_type_conflicts))), + DataType::Decimal256(p, s) => Ok(Box::new(DecimalArrayDecoder::::new(p, s, ignore_type_conflicts))), + DataType::Boolean => Ok(Box::new(BooleanArrayDecoder::new(ignore_type_conflicts))), + DataType::Utf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive, ignore_type_conflicts))), + DataType::LargeUtf8 => Ok(Box::new(StringArrayDecoder::::new(coerce_primitive, ignore_type_conflicts))), + DataType::List(_) => Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, ignore_type_conflicts, is_nullable, struct_mode)?)), + DataType::LargeList(_) => 
Ok(Box::new(ListArrayDecoder::::new(data_type, coerce_primitive, strict_mode, ignore_type_conflicts, is_nullable, struct_mode)?)), + DataType::Struct(_) => Ok(Box::new(StructArrayDecoder::new(data_type, coerce_primitive, strict_mode, ignore_type_conflicts, is_nullable, struct_mode)?)), DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => { Err(ArrowError::JsonError(format!("{data_type} is not supported by JSON"))) } - DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, is_nullable, struct_mode)?)), + DataType::Map(_, _) => Ok(Box::new(MapArrayDecoder::new(data_type, coerce_primitive, strict_mode, ignore_type_conflicts, is_nullable, struct_mode)?)), d => Err(ArrowError::NotYetImplemented(format!("Support for {d} in JSON reader"))) } } +/// Attempts to append a value to the builder, if valid. Otherwise, returns the error. +fn try_append_value( + builder: &mut PrimitiveBuilder

, + value: Result, +) -> Result<(), ArrowError> { + builder.append_value(value?); + Ok(()) +} + +/// Attempts to append a value to the builder, if valid. Otherwise, appends NULL. +fn append_value_or_null( + builder: &mut PrimitiveBuilder

, + value: Result, +) -> Result<(), ArrowError> { + match value { + Ok(value) => builder.append_value(value), + Err(_) => builder.append_null(), + } + Ok(()) +} + #[cfg(test)] mod tests { use serde_json::json; @@ -2666,4 +2705,152 @@ mod tests { "Json error: whilst decoding field 'a': failed to parse \"a\" as Int32".to_owned() ); } + + #[test] + fn test_type_conflict_nulls() { + use arrow_array::{BooleanArray, Int32Array, MapArray, NullArray}; + use arrow_buffer::NullBuffer; + use arrow_schema::Fields; + let json = vec![ + json!({"null": null, "bool": true, "numeric": 1.234, "string": "hi", "array": [1, "hi", 3], "map": {"k": "value"}, "struct": {"a": 1}}), + json!({"bool": null, "numeric": true, "string": 1.234, "array": "hi", "map": [1, "hi", 3], "struct": {"k": "value"}, "null": {"a": 1}}), + json!({"numeric": null, "string": true, "array": 1.234, "map": "hi", "struct": [1, "hi", 3], "null": {"k": "value"}, "bool": {"a": 1}}), + json!({"string": null, "array": true, "map": 1.234, "struct": "hi", "null": [1, "hi", 3], "bool": {"k": "value"}, "numeric": {"a": 1}}), + json!({"array": null, "map": true, "struct": 1.234, "null": "hi", "bool": [1, "hi", 3], "numeric": {"k": "value"}, "string": {"a": 1}}), + json!({"map": null, "struct": true, "null": 1.234, "bool": "hi", "numeric": [1, "hi", 3], "string": {"k": "value"}, "array": {"a": 1}}), + json!({"struct": null, "null": true, "bool": 1.234, "numeric": "hi", "string": [1, "hi", 3], "array": {"k": "value"}, "map": {"a": 1}}), + ]; + let schema = Schema::new(vec![ + Field::new("null", DataType::Null, true), + Field::new("bool", DataType::Boolean, true), + Field::new("numeric", DataType::Decimal128(10, 3), true), + Field::new("string", DataType::Utf8, true), + Field::new( + "array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, 
false), + Field::new("values", DataType::Utf8, true), + ])), + false, // not nullable + )), + false, // not sorted + ), + true, // nullable + ), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, true)])), + true, + ), + ]); + let mut decoder = ReaderBuilder::new(Arc::new(schema)) + .with_ignore_type_conflicts(true) + .with_coerce_primitive(true) + .build_decoder() + .unwrap(); + decoder.serialize(&json).unwrap(); + let batch = decoder.flush().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 7); + assert_eq!(batch.num_columns(), 7); + + let _ = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + // NOTE: NullArray doesn't materialize any values (they're all NULL by definition) + + let bools = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(bools + .iter() + .eq([Some(true), None, None, None, None, None, None])); + + let numbers = batch.column(2).as_primitive::(); + assert!(numbers + .iter() + .eq([Some(1234), None, None, None, None, None, None])); + + let strings = batch + .column(3) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(strings.iter().eq([ + Some("hi"), + Some("1.234"), + Some("true"), + None, + None, + None, + None + ])); + + let arrays = batch + .column(4) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + arrays.nulls(), + Some(&NullBuffer::from( + &[true, false, false, false, false, false, false][..] + )) + ); + assert_eq!(arrays.offsets()[1], 3); + let array_values = arrays + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert!( + array_values.iter().eq([Some(1), None, Some(3)]), + "{array_values:?}" + ); + + let maps = batch.column(5).as_any().downcast_ref::().unwrap(); + assert_eq!( + maps.nulls(), + Some(&NullBuffer::from( + &[true, false, false, false, false, false, true][..] 
+ )) + ); + let map_keys = maps.keys().as_any().downcast_ref::().unwrap(); + assert!(map_keys.iter().eq([Some("k"), Some("a")])); + let map_values = maps + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(map_values.iter().eq([Some("value"), Some("1")])); + + let structs = batch + .column(6) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + structs.nulls(), + Some(&NullBuffer::from( + &[true, true, false, false, false, false, false][..] + )) + ); + let struct_fields = structs + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(struct_fields.slice(0, 2).iter().eq([Some(1), None])); + } } diff --git a/arrow-json/src/reader/null_array.rs b/arrow-json/src/reader/null_array.rs index 4270045fb3c2..2b90da3389df 100644 --- a/arrow-json/src/reader/null_array.rs +++ b/arrow-json/src/reader/null_array.rs @@ -21,12 +21,21 @@ use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; #[derive(Default)] -pub struct NullArrayDecoder {} +pub struct NullArrayDecoder { + ignore_type_conflicts: bool, +} +impl NullArrayDecoder { + pub fn new(ignore_type_conflicts: bool) -> Self { + Self { + ignore_type_conflicts, + } + } +} impl ArrayDecoder for NullArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { for p in pos { - if !matches!(tape.get(*p), TapeElement::Null) { + if !matches!(tape.get(*p), TapeElement::Null) && !self.ignore_type_conflicts { return Err(tape.error(*p, "null")); } } diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index 257c216cf5f6..8c54ce13f69b 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -75,14 +75,16 @@ impl ParseJsonNumber for f64 { pub struct PrimitiveArrayDecoder { data_type: DataType, + ignore_type_conflicts: bool, // Invariant and Send phantom: PhantomData P>, } impl PrimitiveArrayDecoder

{ - pub fn new(data_type: DataType) -> Self { + pub fn new(data_type: DataType, ignore_type_conflicts: bool) -> Self { Self { data_type, + ignore_type_conflicts, phantom: Default::default(), } } @@ -97,6 +99,11 @@ where let mut builder = PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); let d = &self.data_type; + let append = if self.ignore_type_conflicts { + super::append_value_or_null + } else { + super::try_append_value + }; for p in pos { match tape.get(*p) { @@ -105,38 +112,36 @@ where let s = tape.get_string(idx); let value = P::parse(s).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse \"{s}\" as {d}",)) - })?; - - builder.append_value(value) + }); + append(&mut builder, value)? } TapeElement::Number(idx) => { let s = tape.get_string(idx); let value = ParseJsonNumber::parse(s.as_bytes()).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse {s} as {d}",)) - })?; - - builder.append_value(value) + }); + append(&mut builder, value)? } TapeElement::F32(v) => { let v = f32::from_bits(v); let value = NumCast::from(v).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse {v} as {d}",)) - })?; - builder.append_value(value) + }); + append(&mut builder, value)? } TapeElement::I32(v) => { let value = NumCast::from(v).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse {v} as {d}",)) - })?; - builder.append_value(value) + }); + append(&mut builder, value)? } TapeElement::F64(high) => match tape.get(p + 1) { TapeElement::F32(low) => { let v = f64::from_bits(((high as u64) << 32) | low as u64); let value = NumCast::from(v).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse {v} as {d}",)) - })?; - builder.append_value(value) + }); + append(&mut builder, value)? } _ => unreachable!(), }, @@ -145,11 +150,12 @@ where let v = ((high as i64) << 32) | (low as u32) as i64; let value = NumCast::from(v).ok_or_else(|| { ArrowError::JsonError(format!("failed to parse {v} as {d}",)) - })?; - builder.append_value(value) + }); + append(&mut builder, value)? 
} _ => unreachable!(), }, + _ if self.ignore_type_conflicts => builder.append_null(), _ => return Err(tape.error(*p, "primitive")), } } diff --git a/arrow-json/src/reader/string_array.rs b/arrow-json/src/reader/string_array.rs index 03d07ad8c8b3..476f7b2696c9 100644 --- a/arrow-json/src/reader/string_array.rs +++ b/arrow-json/src/reader/string_array.rs @@ -29,13 +29,15 @@ const FALSE: &str = "false"; pub struct StringArrayDecoder { coerce_primitive: bool, + ignore_type_conflicts: bool, phantom: PhantomData, } impl StringArrayDecoder { - pub fn new(coerce_primitive: bool) -> Self { + pub fn new(coerce_primitive: bool, ignore_type_conflicts: bool) -> Self { Self { coerce_primitive, + ignore_type_conflicts, phantom: Default::default(), } } @@ -70,6 +72,7 @@ impl ArrayDecoder for StringArrayDecoder { // An arbitrary estimate data_capacity += 10; } + _ if self.ignore_type_conflicts => {} _ => { return Err(tape.error(*p, "string")); } @@ -120,6 +123,7 @@ impl ArrayDecoder for StringArrayDecoder { } _ => unreachable!(), }, + _ if self.ignore_type_conflicts => builder.append_null(), _ => unreachable!(), } } diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index b9408df77a43..c168ad699efa 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -26,6 +26,7 @@ pub struct StructArrayDecoder { data_type: DataType, decoders: Vec>, strict_mode: bool, + ignore_type_conflicts: bool, is_nullable: bool, struct_mode: StructMode, } @@ -35,6 +36,7 @@ impl StructArrayDecoder { data_type: DataType, coerce_primitive: bool, strict_mode: bool, + ignore_type_conflicts: bool, is_nullable: bool, struct_mode: StructMode, ) -> Result { @@ -49,6 +51,7 @@ impl StructArrayDecoder { f.data_type().clone(), coerce_primitive, strict_mode, + ignore_type_conflicts, nullable, struct_mode, ) @@ -59,6 +62,7 @@ impl StructArrayDecoder { data_type, decoders, strict_mode, + ignore_type_conflicts, is_nullable, struct_mode, }) @@ 
-89,6 +93,10 @@ impl ArrayDecoder for StructArrayDecoder { nulls.append(false); continue; } + (_, Some(nulls)) if self.ignore_type_conflicts => { + nulls.append(false); + continue; + } (_, _) => return Err(tape.error(*p, "{")), }; @@ -129,6 +137,10 @@ impl ArrayDecoder for StructArrayDecoder { nulls.append(false); continue; } + (_, Some(nulls)) if self.ignore_type_conflicts => { + nulls.append(false); + continue; + } (_, _) => return Err(tape.error(*p, "[")), }; diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index ee9018702920..af955d53d114 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use chrono::TimeZone; +use chrono::{DateTime, TimeZone}; use std::marker::PhantomData; use arrow_array::builder::PrimitiveBuilder; @@ -32,15 +32,17 @@ use crate::reader::ArrayDecoder; pub struct TimestampArrayDecoder { data_type: DataType, timezone: Tz, + ignore_type_conflicts: bool, // Invariant and Send phantom: PhantomData P>, } impl TimestampArrayDecoder { - pub fn new(data_type: DataType, timezone: Tz) -> Self { + pub fn new(data_type: DataType, timezone: Tz, ignore_type_conflicts: bool) -> Self { Self { data_type, timezone, + ignore_type_conflicts, phantom: Default::default(), } } @@ -54,6 +56,11 @@ where fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); + let append = if self.ignore_type_conflicts { + super::append_value_or_null + } else { + super::try_append_value + }; for p in pos { match tape.get(*p) { @@ -65,20 +72,20 @@ where "failed to parse \"{s}\" as {}: {}", self.data_type, e )) - })?; + }); - let value = match P::UNIT { - TimeUnit::Second => date.timestamp(), - TimeUnit::Millisecond => date.timestamp_millis(), - TimeUnit::Microsecond => date.timestamp_micros(), + let date_to_value = |date: DateTime| match P::UNIT { + TimeUnit::Second => Ok(date.timestamp()), + TimeUnit::Millisecond => Ok(date.timestamp_millis()), + TimeUnit::Microsecond => Ok(date.timestamp_micros()), TimeUnit::Nanosecond => date.timestamp_nanos_opt().ok_or_else(|| { ArrowError::ParseError(format!( "{} would overflow 64-bit signed nanoseconds", date.to_rfc3339(), )) - })?, + }), }; - builder.append_value(value) + append(&mut builder, date.and_then(date_to_value))? } TapeElement::Number(idx) => { let s = tape.get_string(idx); @@ -90,9 +97,8 @@ where "failed to parse {s} as {}", self.data_type )) - })?; - - builder.append_value(value) + }); + append(&mut builder, value)? 
} TapeElement::I32(v) => builder.append_value(v as i64), TapeElement::I64(high) => match tape.get(p + 1) { @@ -101,6 +107,7 @@ where } _ => unreachable!(), }, + _ if self.ignore_type_conflicts => builder.append_null(), _ => return Err(tape.error(*p, "primitive")), } } From 5e87e5b4d13844ee7fcd76d41a457e3103f33115 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 12 Mar 2025 11:08:05 -0700 Subject: [PATCH 2/5] test coverage --- arrow-json/src/reader/decimal_array.rs | 5 + arrow-json/src/reader/mod.rs | 210 ++++++++++++++++++----- arrow-json/src/reader/primitive_array.rs | 5 + arrow-json/src/reader/timestamp_array.rs | 5 + 4 files changed, 186 insertions(+), 39 deletions(-) diff --git a/arrow-json/src/reader/decimal_array.rs b/arrow-json/src/reader/decimal_array.rs index e3c4fb3c1827..6eb3121908e7 100644 --- a/arrow-json/src/reader/decimal_array.rs +++ b/arrow-json/src/reader/decimal_array.rs @@ -52,6 +52,11 @@ where { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::::with_capacity(pos.len()); + + // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler + // optimizations each call site will either be a predictable function pointer invocation or + // a predictable branch. Either way, the cost should be trivial compared to the expensive + // and unpredictably branchy string parse that immediately precedes each call. 
let append = if self.ignore_type_conflicts { super::append_value_or_null } else { diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index bb8afe93f6fb..64dc6d3d38c9 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -140,10 +140,10 @@ use std::sync::Arc; use chrono::Utc; use serde::Serialize; +use arrow_array::builder::PrimitiveBuilder; use arrow_array::timezone::Tz; use arrow_array::types::*; use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader, StructArray}; -use arrow_array::builder::PrimitiveBuilder; use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; @@ -169,7 +169,7 @@ mod schema; mod serializer; mod string_array; mod struct_array; -pub mod tape; +mod tape; mod timestamp_array; /// A builder for [`Reader`] and [`Decoder`] @@ -283,8 +283,11 @@ impl ReaderBuilder { } } - /// Sets if the decoder should produce NULL instead of returning an error if it encounters a - /// type conflict on a nullable column. + /// Sets whether the decoder should produce NULL instead of returning an error if it encounters + /// a type conflict on a nullable column (effectively treating it as a non-existent column). + /// + /// NOTE: The inferred NULL on type conflict will still produce errors for non-nullable columns, + /// the same as any other NULL or missing value. 
pub fn with_ignore_type_conflicts(self, ignore_type_conflicts: bool) -> Self { Self { ignore_type_conflicts, @@ -785,16 +788,17 @@ fn append_value_or_null( #[cfg(test)] mod tests { - use serde_json::json; - use std::fs::File; - use std::io::{BufReader, Cursor, Seek}; - use arrow_array::cast::AsArray; - use arrow_array::{Array, BooleanArray, Float64Array, ListArray, StringArray}; - use arrow_buffer::{ArrowNativeType, Buffer}; + use arrow_array::{ + Array, BooleanArray, Float64Array, Int32Array, ListArray, MapArray, NullArray, StringArray, + }; + use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer}; use arrow_cast::display::{ArrayFormatter, FormatOptions}; use arrow_data::ArrayDataBuilder; use arrow_schema::{Field, Fields}; + use serde_json::json; + use std::fs::File; + use std::io::{BufReader, Cursor, Seek}; use super::*; @@ -2708,23 +2712,16 @@ mod tests { #[test] fn test_type_conflict_nulls() { - use arrow_array::{BooleanArray, Int32Array, MapArray, NullArray}; - use arrow_buffer::NullBuffer; - use arrow_schema::Fields; - let json = vec![ - json!({"null": null, "bool": true, "numeric": 1.234, "string": "hi", "array": [1, "hi", 3], "map": {"k": "value"}, "struct": {"a": 1}}), - json!({"bool": null, "numeric": true, "string": 1.234, "array": "hi", "map": [1, "hi", 3], "struct": {"k": "value"}, "null": {"a": 1}}), - json!({"numeric": null, "string": true, "array": 1.234, "map": "hi", "struct": [1, "hi", 3], "null": {"k": "value"}, "bool": {"a": 1}}), - json!({"string": null, "array": true, "map": 1.234, "struct": "hi", "null": [1, "hi", 3], "bool": {"k": "value"}, "numeric": {"a": 1}}), - json!({"array": null, "map": true, "struct": 1.234, "null": "hi", "bool": [1, "hi", 3], "numeric": {"k": "value"}, "string": {"a": 1}}), - json!({"map": null, "struct": true, "null": 1.234, "bool": "hi", "numeric": [1, "hi", 3], "string": {"k": "value"}, "array": {"a": 1}}), - json!({"struct": null, "null": true, "bool": 1.234, "numeric": "hi", "string": [1, "hi", 3], "array": 
{"k": "value"}, "map": {"a": 1}}), - ]; let schema = Schema::new(vec![ Field::new("null", DataType::Null, true), Field::new("bool", DataType::Boolean, true), Field::new("numeric", DataType::Decimal128(10, 3), true), Field::new("string", DataType::Utf8, true), + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Second, None), + true, + ), Field::new( "array", DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), @@ -2751,6 +2748,31 @@ mod tests { true, ), ]); + + // A compatible value for each schema field above, in schema order + let json_values = vec![ + json!(null), + json!(true), + json!(1.234), + json!("hi"), + json!("1970-01-01T00:00:00+02:00"), + json!([1, "ho", 3]), + json!({"k": "value"}), + json!({"a": 1}), + ]; + + // Create a set of JSON rows that rotates each value past every field + let json: Vec<_> = (0..json_values.len()) + .map(|i| { + let pairs = json_values[i..] + .iter() + .chain(json_values[..i].iter()) + .zip(&schema.fields) + .map(|(v, f)| (f.name().to_string(), v.clone())) + .collect(); + serde_json::Value::Object(pairs) + }) + .collect(); let mut decoder = ReaderBuilder::new(Arc::new(schema)) .with_ignore_type_conflicts(true) .with_coerce_primitive(true) @@ -2758,8 +2780,8 @@ mod tests { .unwrap(); decoder.serialize(&json).unwrap(); let batch = decoder.flush().unwrap().unwrap(); - assert_eq!(batch.num_rows(), 7); - assert_eq!(batch.num_columns(), 7); + assert_eq!(batch.num_rows(), 8); + assert_eq!(batch.num_columns(), 8); let _ = batch .column(0) @@ -2775,12 +2797,12 @@ mod tests { .unwrap(); assert!(bools .iter() - .eq([Some(true), None, None, None, None, None, None])); + .eq([Some(true), None, None, None, None, None, None, None])); let numbers = batch.column(2).as_primitive::(); assert!(numbers .iter() - .eq([Some(1234), None, None, None, None, None, None])); + .eq([Some(1234), None, None, None, None, None, None, None])); let strings = batch .column(3) @@ -2789,23 +2811,29 @@ mod tests { .unwrap(); 
assert!(strings.iter().eq([ Some("hi"), - Some("1.234"), - Some("true"), + Some("1970-01-01T00:00:00+02:00"), None, None, None, - None + None, + Some("true"), + Some("1.234"), ])); + let timestamps = batch.column(4).as_primitive::(); + assert!(timestamps + .iter() + .eq([Some(-7200), None, None, None, None, None, None, None,])); + let arrays = batch - .column(4) + .column(5) .as_any() .downcast_ref::() .unwrap(); assert_eq!( arrays.nulls(), Some(&NullBuffer::from( - &[true, false, false, false, false, false, false][..] + &[true, false, false, false, false, false, false, false][..] )) ); assert_eq!(arrays.offsets()[1], 3); @@ -2814,16 +2842,14 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert!( - array_values.iter().eq([Some(1), None, Some(3)]), - "{array_values:?}" - ); + assert!(array_values.iter().eq([Some(1), None, Some(3)])); - let maps = batch.column(5).as_any().downcast_ref::().unwrap(); + let maps = batch.column(6).as_any().downcast_ref::().unwrap(); assert_eq!( maps.nulls(), Some(&NullBuffer::from( - &[true, false, false, false, false, false, true][..] + // Both map and struct can parse + &[true, true, false, false, false, false, false, false][..] )) ); let map_keys = maps.keys().as_any().downcast_ref::().unwrap(); @@ -2836,14 +2862,15 @@ mod tests { assert!(map_values.iter().eq([Some("value"), Some("1")])); let structs = batch - .column(6) + .column(7) .as_any() .downcast_ref::() .unwrap(); assert_eq!( structs.nulls(), Some(&NullBuffer::from( - &[true, true, false, false, false, false, false][..] + // Both map and struct can parse + &[true, false, false, false, false, false, false, true][..] 
)) ); let struct_fields = structs @@ -2853,4 +2880,109 @@ mod tests { .unwrap(); assert!(struct_fields.slice(0, 2).iter().eq([Some(1), None])); } + + #[test] + fn test_type_conflict_non_nullable() { + let fields = [ + Field::new("bool", DataType::Boolean, false), + Field::new("numeric", DataType::Decimal128(10, 3), false), + Field::new("string", DataType::Utf8, false), + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Second, None), + false, + ), + Field::new( + "array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new( + "map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Utf8, true), + ])), + false, // not nullable + )), + false, // not sorted + ), + false, // not nullable + ), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, true)])), + false, + ), + ]; + + // Every field above will have a type conflict with at least one of these values + let json_values = vec![json!(true), json!({"a": 1})]; + + for field in fields { + let mut decoder = ReaderBuilder::new_with_field(field) + .with_ignore_type_conflicts(true) + .build_decoder() + .unwrap(); + decoder.serialize(&json_values).unwrap(); + decoder + .flush() + .expect_err("type conflict on non-nullable type"); + } + } + + #[test] + fn test_ignore_type_conflicts_disabled() { + let fields = [ + Field::new("bool", DataType::Boolean, true), + Field::new("numeric", DataType::Decimal128(10, 3), true), + Field::new("string", DataType::Utf8, true), + Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Second, None), + true, + ), + Field::new( + "array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, false), 
+ Field::new("values", DataType::Utf8, true), + ])), + false, // not nullable + )), + false, // not sorted + ), + true, // nullable + ), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, true)])), + true, + ), + ]; + + // Every field above will have a type conflict with at least one of these values + let json_values = vec![json!(true), json!({"a": 1})]; + + for field in fields { + let mut decoder = ReaderBuilder::new_with_field(field) + .build_decoder() + .unwrap(); + decoder.serialize(&json_values).unwrap(); + decoder + .flush() + .expect_err("type conflict on non-nullable type"); + } + } } diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index 8c54ce13f69b..534f690ec771 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -99,6 +99,11 @@ where let mut builder = PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); let d = &self.data_type; + + // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler + // optimizations each call site will either be a predictable function pointer invocation or + // a predictable branch. Either way, the cost should be trivial compared to the expensive + // and unpredictably branchy string parse that immediately precedes each call. let append = if self.ignore_type_conflicts { super::append_value_or_null } else { diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index af955d53d114..111c2aecea49 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -56,6 +56,11 @@ where fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); + + // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler + // optimizations each call site will either be a predictable function pointer invocation or + // a predictable branch. Either way, the cost should be trivial compared to the expensive + // and unpredictably branchy string parse that immediately precedes each call. let append = if self.ignore_type_conflicts { super::append_value_or_null } else { From 9f9d7c583f9188fb00a7adad44ec5683020b82c8 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 12 Mar 2025 11:26:44 -0700 Subject: [PATCH 3/5] better null decoding --- arrow-json/src/reader/null_array.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/arrow-json/src/reader/null_array.rs b/arrow-json/src/reader/null_array.rs index 2b90da3389df..b931d7ec519f 100644 --- a/arrow-json/src/reader/null_array.rs +++ b/arrow-json/src/reader/null_array.rs @@ -34,9 +34,11 @@ impl NullArrayDecoder { impl ArrayDecoder for NullArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { - for p in pos { - if !matches!(tape.get(*p), TapeElement::Null) && !self.ignore_type_conflicts { - return Err(tape.error(*p, "null")); + if !self.ignore_type_conflicts { + for p in pos { + if !matches!(tape.get(*p), TapeElement::Null) { + return Err(tape.error(*p, "null")); + } } } ArrayDataBuilder::new(DataType::Null).len(pos.len()).build() From 0936aa5a637e48a8d986f6b617ac1786d5bc9f37 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 12 Mar 2025 16:41:55 -0700 Subject: [PATCH 4/5] More efficient append helpers --- arrow-json/src/reader/decimal_array.rs | 48 ++++++++---------------- arrow-json/src/reader/mod.rs | 22 ----------- arrow-json/src/reader/primitive_array.rs | 18 +++++---- arrow-json/src/reader/timestamp_array.rs | 18 +++++---- 4 files changed, 36 insertions(+), 70 deletions(-) diff --git a/arrow-json/src/reader/decimal_array.rs 
b/arrow-json/src/reader/decimal_array.rs index 6eb3121908e7..83b12365e64a 100644 --- a/arrow-json/src/reader/decimal_array.rs +++ b/arrow-json/src/reader/decimal_array.rs @@ -53,55 +53,39 @@ where fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let mut builder = PrimitiveBuilder::::with_capacity(pos.len()); - // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler - // optimizations each call site will either be a predictable function pointer invocation or - // a predictable branch. Either way, the cost should be trivial compared to the expensive - // and unpredictably branchy string parse that immediately precedes each call. - let append = if self.ignore_type_conflicts { - super::append_value_or_null - } else { - super::try_append_value + // Factor out this logic to simplify call sites below; the compiler will inline it, + // producing a highly predictable branch whose cost should be trivial compared to the + // expensive and unpredictably branchy string parse that immediately precedes each call. + let append = |builder: &mut PrimitiveBuilder, value: &str| { + match parse_decimal::(value, self.precision, self.scale) { + Ok(value) => builder.append_value(value), + Err(_) if self.ignore_type_conflicts => builder.append_null(), + Err(e) => return Err(e), + } + Ok(()) }; for p in pos { match tape.get(*p) { TapeElement::Null => builder.append_null(), - TapeElement::String(idx) => { - let s = tape.get_string(idx); - let value = parse_decimal::(s, self.precision, self.scale); - append(&mut builder, value)? - } - TapeElement::Number(idx) => { - let s = tape.get_string(idx); - let value = parse_decimal::(s, self.precision, self.scale); - append(&mut builder, value)? 
- } + TapeElement::String(idx) => append(&mut builder, tape.get_string(idx))?, + TapeElement::Number(idx) => append(&mut builder, tape.get_string(idx))?, TapeElement::I64(high) => match tape.get(*p + 1) { TapeElement::I32(low) => { let val = (((high as i64) << 32) | (low as u32) as i64).to_string(); - let value = parse_decimal::(&val, self.precision, self.scale); - append(&mut builder, value)? + append(&mut builder, &val)? } _ => unreachable!(), }, - TapeElement::I32(val) => { - let s = val.to_string(); - let value = parse_decimal::(&s, self.precision, self.scale); - append(&mut builder, value)? - } + TapeElement::I32(val) => append(&mut builder, &val.to_string())?, TapeElement::F64(high) => match tape.get(*p + 1) { TapeElement::F32(low) => { let val = f64::from_bits(((high as u64) << 32) | low as u64).to_string(); - let value = parse_decimal::(&val, self.precision, self.scale); - append(&mut builder, value)? + append(&mut builder, &val)? } _ => unreachable!(), }, - TapeElement::F32(val) => { - let s = f32::from_bits(val).to_string(); - let value = parse_decimal::(&s, self.precision, self.scale); - append(&mut builder, value)? - } + TapeElement::F32(val) => append(&mut builder, &f32::from_bits(val).to_string())?, _ if self.ignore_type_conflicts => builder.append_null(), _ => return Err(tape.error(*p, "decimal")), } diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 64dc6d3d38c9..46970ef5f90e 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -140,7 +140,6 @@ use std::sync::Arc; use chrono::Utc; use serde::Serialize; -use arrow_array::builder::PrimitiveBuilder; use arrow_array::timezone::Tz; use arrow_array::types::*; use arrow_array::{downcast_integer, make_array, RecordBatch, RecordBatchReader, StructArray}; @@ -765,27 +764,6 @@ fn make_decoder( } } -/// Attempts to append a value to the builder, if valid. Otherwise, returns the error. -fn try_append_value( - builder: &mut PrimitiveBuilder

, - value: Result, -) -> Result<(), ArrowError> { - builder.append_value(value?); - Ok(()) -} - -/// Attempts to append a value to the builder, if valid. Otherwise, appends NULL. -fn append_value_or_null( - builder: &mut PrimitiveBuilder

, - value: Result, -) -> Result<(), ArrowError> { - match value { - Ok(value) => builder.append_value(value), - Err(_) => builder.append_null(), - } - Ok(()) -} - #[cfg(test)] mod tests { use arrow_array::cast::AsArray; diff --git a/arrow-json/src/reader/primitive_array.rs b/arrow-json/src/reader/primitive_array.rs index 534f690ec771..afbb46dbbc11 100644 --- a/arrow-json/src/reader/primitive_array.rs +++ b/arrow-json/src/reader/primitive_array.rs @@ -100,14 +100,16 @@ where PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); let d = &self.data_type; - // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler - // optimizations each call site will either be a predictable function pointer invocation or - // a predictable branch. Either way, the cost should be trivial compared to the expensive - // and unpredictably branchy string parse that immediately precedes each call. - let append = if self.ignore_type_conflicts { - super::append_value_or_null - } else { - super::try_append_value + // Factor out this logic to simplify call sites below; the compiler will inline it, + // producing a highly predictable branch whose cost should be trivial compared to the + // expensive and unpredictably branchy string parse that immediately precedes each call. + let append = |builder: &mut PrimitiveBuilder

, value| { + match value { + Ok(value) => builder.append_value(value), + Err(_) if self.ignore_type_conflicts => builder.append_null(), + Err(e) => return Err(e), + }; + Ok(()) }; for p in pos { diff --git a/arrow-json/src/reader/timestamp_array.rs b/arrow-json/src/reader/timestamp_array.rs index 111c2aecea49..b9298f9c23f4 100644 --- a/arrow-json/src/reader/timestamp_array.rs +++ b/arrow-json/src/reader/timestamp_array.rs @@ -57,14 +57,16 @@ where let mut builder = PrimitiveBuilder::

::with_capacity(pos.len()).with_data_type(self.data_type.clone()); - // Simplify call sites below by hoisting the branch out of the loop. Depending on compiler - // optimizations each call site will either be a predictable function pointer invocation or - // a predictable branch. Either way, the cost should be trivial compared to the expensive - // and unpredictably branchy string parse that immediately precedes each call. - let append = if self.ignore_type_conflicts { - super::append_value_or_null - } else { - super::try_append_value + // Factor out this logic to simplify call sites below; the compiler will inline it, + // producing a highly predictable branch whose cost should be trivial compared to the + // expensive and unpredictably branchy string parse that immediately precedes each call. + let append = |builder: &mut PrimitiveBuilder

, value| { + match value { + Ok(value) => builder.append_value(value), + Err(_) if self.ignore_type_conflicts => builder.append_null(), + Err(e) => return Err(e), + }; + Ok(()) }; for p in pos { From bb4e8d4d6a3953dc2ebe80e3f91b6f085636721c Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Thu, 13 Mar 2025 09:16:25 -0700 Subject: [PATCH 5/5] missing unit tests --- arrow-json/src/reader/mod.rs | 134 +++++++++++++++++++++++++++-------- 1 file changed, 104 insertions(+), 30 deletions(-) diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 3ef303ba4e8f..1938bd7e818d 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -2836,8 +2836,10 @@ mod tests { let schema = Schema::new(vec![ Field::new("null", DataType::Null, true), Field::new("bool", DataType::Boolean, true), + Field::new("primitive", DataType::Int32, true), Field::new("numeric", DataType::Decimal128(10, 3), true), Field::new("string", DataType::Utf8, true), + Field::new("string_view", DataType::Utf8View, true), Field::new( "timestamp", DataType::Timestamp(TimeUnit::Second, None), @@ -2874,8 +2876,10 @@ mod tests { let json_values = vec![ json!(null), json!(true), + json!(42), json!(1.234), json!("hi"), + json!("ho"), json!("1970-01-01T00:00:00+02:00"), json!([1, "ho", 3]), json!({"k": "value"}), @@ -2901,60 +2905,125 @@ mod tests { .unwrap(); decoder.serialize(&json).unwrap(); let batch = decoder.flush().unwrap().unwrap(); - assert_eq!(batch.num_rows(), 8); - assert_eq!(batch.num_columns(), 8); + assert_eq!(batch.num_rows(), 10); + assert_eq!(batch.num_columns(), 10); + // NOTE: NullArray doesn't materialize any values (they're all NULL by definition) let _ = batch .column(0) .as_any() .downcast_ref::() .unwrap(); - // NOTE: NullArray doesn't materialize any values (they're all NULL by definition) - let bools = batch + assert!(batch .column(1) .as_any() .downcast_ref::() - .unwrap(); - assert!(bools + .unwrap() .iter() - .eq([Some(true), None, None, None, 
None, None, None, None])); + .eq([ + Some(true), + None, + None, + None, + None, + None, + None, + None, + None, + None + ])); - let numbers = batch.column(2).as_primitive::(); - assert!(numbers - .iter() - .eq([Some(1234), None, None, None, None, None, None, None])); + assert!(batch.column(2).as_primitive::().iter().eq([ + Some(42), + Some(1), + None, + None, + None, + None, + None, + None, + None, + None + ])); - let strings = batch - .column(3) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(strings.iter().eq([ - Some("hi"), - Some("1970-01-01T00:00:00+02:00"), + assert!(batch.column(3).as_primitive::().iter().eq([ + Some(1234), + None, + None, None, None, None, None, - Some("true"), - Some("1.234"), + None, + None, + Some(42000) ])); - let timestamps = batch.column(4).as_primitive::(); - assert!(timestamps + assert!(batch + .column(4) + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .eq([ + Some("hi"), + Some("ho"), + Some("1970-01-01T00:00:00+02:00"), + None, + None, + None, + None, + Some("true"), + Some("42"), + Some("1.234"), + ])); + + assert!(batch + .column(5) + .as_any() + .downcast_ref::() + .unwrap() .iter() - .eq([Some(-7200), None, None, None, None, None, None, None,])); + .eq([ + Some("ho"), + Some("1970-01-01T00:00:00+02:00"), + None, + None, + None, + None, + Some("true"), + Some("42"), + Some("1.234"), + Some("hi"), + ])); + + assert!(batch + .column(6) + .as_primitive::() + .iter() + .eq([ + Some(-7200), + None, + None, + None, + None, + None, + Some(42), + None, + None, + None, + ])); let arrays = batch - .column(5) + .column(7) .as_any() .downcast_ref::() .unwrap(); assert_eq!( arrays.nulls(), Some(&NullBuffer::from( - &[true, false, false, false, false, false, false, false][..] + &[true, false, false, false, false, false, false, false, false, false][..] 
)) ); assert_eq!(arrays.offsets()[1], 3); @@ -2965,12 +3034,12 @@ mod tests { .unwrap(); assert!(array_values.iter().eq([Some(1), None, Some(3)])); - let maps = batch.column(6).as_any().downcast_ref::().unwrap(); + let maps = batch.column(8).as_any().downcast_ref::().unwrap(); assert_eq!( maps.nulls(), Some(&NullBuffer::from( // Both map and struct can parse - &[true, true, false, false, false, false, false, false][..] + &[true, true, false, false, false, false, false, false, false, false][..] )) ); let map_keys = maps.keys().as_any().downcast_ref::().unwrap(); @@ -2983,7 +3052,7 @@ mod tests { assert!(map_values.iter().eq([Some("value"), Some("1")])); let structs = batch - .column(7) + .column(9) .as_any() .downcast_ref::() .unwrap(); @@ -2991,7 +3060,7 @@ mod tests { structs.nulls(), Some(&NullBuffer::from( // Both map and struct can parse - &[true, false, false, false, false, false, false, true][..] + &[true, false, false, false, false, false, false, false, false, true][..] )) ); let struct_fields = structs @@ -3006,8 +3075,10 @@ mod tests { fn test_type_conflict_non_nullable() { let fields = [ Field::new("bool", DataType::Boolean, false), + Field::new("primitive", DataType::Int32, false), Field::new("numeric", DataType::Decimal128(10, 3), false), Field::new("string", DataType::Utf8, false), + Field::new("string_view", DataType::Utf8View, false), Field::new( "timestamp", DataType::Timestamp(TimeUnit::Second, None), @@ -3058,9 +3129,12 @@ mod tests { #[test] fn test_ignore_type_conflicts_disabled() { let fields = [ + Field::new("null", DataType::Null, true), Field::new("bool", DataType::Boolean, true), + Field::new("primitive", DataType::Int32, true), Field::new("numeric", DataType::Decimal128(10, 3), true), Field::new("string", DataType::Utf8, true), + Field::new("string_view", DataType::Utf8View, true), Field::new( "timestamp", DataType::Timestamp(TimeUnit::Second, None),