From 7d4e6fd2f8ae14109271ac348aaa87661c770dfc Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 28 Dec 2024 12:44:19 -0600 Subject: [PATCH 01/38] Added basic support for arrow -> avro codec along with beginnings of avro writer. --- arrow-avro/src/codec.rs | 228 ++++++++++++++++++++++++- arrow-avro/src/lib.rs | 1 + arrow-avro/src/schema.rs | 22 +-- arrow-avro/src/writer/mod.rs | 15 ++ arrow-avro/src/writer/schema.rs | 288 ++++++++++++++++++++++++++++++++ arrow-avro/src/writer/vlq.rs | 98 +++++++++++ 6 files changed, 640 insertions(+), 12 deletions(-) create mode 100644 arrow-avro/src/writer/mod.rs create mode 100644 arrow-avro/src/writer/schema.rs create mode 100644 arrow-avro/src/writer/vlq.rs diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 2ac1ad038bd7..25a790fa476a 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,10 +15,14 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use crate::schema::{ + Attributes, ComplexType, PrimitiveType, Schema, TypeName, Array, Fixed, Map, Record, + Field as AvroFieldDef +}; use arrow_schema::{ ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, }; +use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -45,6 +49,25 @@ pub struct AvroDataType { } impl AvroDataType { + + /// Create a new AvroDataType with the given parts. + /// This helps you construct it from outside `codec.rs` without exposing internals. + pub fn new( + codec: Codec, + nullability: Option, + metadata: HashMap, + ) -> Self { + AvroDataType { + codec, + nullability, + metadata, + } + } + + pub fn from_codec(codec: Codec) -> Self { + Self::new(codec, None, Default::default()) + } + /// Returns an arrow [`Field`] with the given name pub fn field_with_name(&self, name: &str) -> Field { let d = self.codec.data_type(); @@ -58,6 +81,23 @@ impl AvroDataType { pub fn nullability(&self) -> Option { self.nullability } + + /// Convert this `AvroDataType`, which encapsulates an Arrow data type (`codec`) + /// plus nullability, back into an Avro `Schema<'a>`. + pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { + let inner_schema = self.codec.to_avro_schema(name); + + // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. + // Otherwise, return the schema as-is. + if let Some(_) = self.nullability { + Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + inner_schema, + ]) + } else { + inner_schema + } + } } /// A named [`AvroDataType`] @@ -157,6 +197,128 @@ impl Codec { Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), } } + + /// Convert this `Codec` variant to an Avro `Schema<'a>`. + /// More work needed to handle `decimal`, `enum`, `map`, etc. 
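+    ///
+    /// For example, going by the arms below, `Codec::Int32` maps to the Avro
+    /// primitive `"int"`, while `Codec::Date32` wraps `"int"` with a logical
+    /// type attribute (illustrative sketch only; `Codec` is crate-private):
+    ///
+    /// ```ignore
+    /// let codec = Codec::Date32;
+    /// let schema = codec.to_avro_schema("unused_for_unnamed_types");
+    /// // serializes to: {"type": "int", "logicalType": "date"}
+    /// ```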
+ pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { + match self { + Codec::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Codec::Boolean => Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), + Codec::Int32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Codec::Int64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + Codec::Float32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)), + Codec::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Codec::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), + Codec::Utf8 => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + + // date32 => Avro int + logicalType=date + Codec::Date32 => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: Attributes { + logical_type: Some("date"), + additional: Default::default(), + }, + }), + + // time-millis => Avro int with logicalType=time-millis + Codec::TimeMillis => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: Attributes { + logical_type: Some("time-millis"), + additional: Default::default(), + }, + }), + + // time-micros => Avro long with logicalType=time-micros + Codec::TimeMicros => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("time-micros"), + additional: Default::default(), + }, + }), + + // timestamp-millis => Avro long with logicalType=timestamp-millis + Codec::TimestampMillis(is_utc) => { + // TODO `is_utc` or store it in metadata + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("timestamp-millis"), + additional: Default::default(), + }, + }) + } + + // timestamp-micros => Avro long with logicalType=timestamp-micros + Codec::TimestampMicros(is_utc) => { + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("timestamp-micros"), + additional: Default::default(), + }, + }) + } + + Codec::Interval => { + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("duration"), + additional: Default::default(), + }, + }) + } + + Codec::Fixed(size) => { + // Convert Arrow FixedSizeBinary => Avro fixed with a known name & size + // TODO namespace/aliases. 
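+                // e.g. Codec::Fixed(16) with name "f" serializes to:
+                //   {"type": "fixed", "name": "f", "size": 16}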
+ Schema::Complex(ComplexType::Fixed(Fixed { + name, + namespace: None, // TODO namespace implementation + aliases: vec![], // TODO alias implementation + size: *size as usize, + attributes: Attributes::default(), + })) + } + + Codec::List(item_type) => { + // Avro array with "items" recursively derived + let items_schema = item_type.to_avro_schema("items"); + Schema::Complex(ComplexType::Array(Array { + items: Box::new(items_schema), + attributes: Attributes::default(), + })) + } + + Codec::Struct(fields) => { + // Avro record with nested fields + let record_fields = fields + .iter() + .map(|f| { + // For each `AvroField`, get its Avro schema + let child_schema = f.data_type().to_avro_schema(f.name()); + AvroFieldDef { + name: f.name(), // Avro field name + doc: None, + r#type: child_schema, + default: None, + } + }) + .collect(); + + Schema::Complex(ComplexType::Record(Record { + name, + namespace: None, // TODO follow up for namespace implementation + doc: None, + aliases: vec![], // TODO follow up for alias implementation + fields: record_fields, + attributes: Attributes::default(), + })) + } + } + } } impl From for Codec { @@ -327,3 +489,67 @@ fn make_data_type<'a>( } } } + + +/// Convert an Arrow `Field` into an `AvroField`. +pub(crate) fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { + // TODO advanced metadata logic here + let codec = arrow_type_to_codec(arrow_field.data_type()); + // Set nullability if the Arrow field is nullable + let nullability = if arrow_field.is_nullable() { + Some(Nullability::NullFirst) + } else { + None + }; + let avro_data_type = AvroDataType { + nullability, + metadata: arrow_field.metadata().clone(), + codec, + }; + AvroField { + name: arrow_field.name().clone(), + data_type: avro_data_type, + } +} + +/// Maps an Arrow `DataType` to a `Codec`: +fn arrow_type_to_codec(dt: &DataType) -> Codec { + use arrow_schema::DataType::*; + match dt { + Null => Codec::Null, + Boolean => Codec::Boolean, + Int8 | Int16 | Int32 => Codec::Int32, + Int64 => Codec::Int64, + Float32 => Codec::Float32, + Float64 => Codec::Float64, + Utf8 => Codec::Utf8, + Binary | LargeBinary => Codec::Binary, + Date32 => Codec::Date32, + Time32(TimeUnit::Millisecond) => Codec::TimeMillis, + Time64(TimeUnit::Microsecond) => Codec::TimeMicros, + Timestamp(TimeUnit::Millisecond, _) => Codec::TimestampMillis(true), + Timestamp(TimeUnit::Microsecond, _) => Codec::TimestampMicros(true), + FixedSizeBinary(n) => Codec::Fixed(*n as i32), + + List(field) => { + // Recursively create Codec for the child item + let child_codec = arrow_type_to_codec(field.data_type()); + Codec::List(Arc::new(AvroDataType { + nullability: None, + metadata: Default::default(), + codec: child_codec, + })) + } + Struct(child_fields) => { + let avro_fields: Vec = child_fields + .iter() + .map(|fref| arrow_field_to_avro_field(fref.as_ref())) + .collect(); + Codec::Struct(Arc::from(avro_fields)) + } + _ => { + // TODO handle more arrow types (e.g. decimal, map, union, etc.) 
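+            // Until then, fall back to an Avro string; this is lossy but keeps
+            // the conversion from failing on unsupported types.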
+ Codec::Utf8 + } + } +} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index d01d681b7af0..ef3bd082d0e8 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -29,6 +29,7 @@ mod schema; mod compression; mod codec; +mod writer; #[cfg(test)] mod test_util { diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index a9d91e47948b..4895d24d76e4 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -123,7 +123,7 @@ pub enum ComplexType<'a> { pub struct Record<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub doc: Option<&'a str>, @@ -144,7 +144,7 @@ pub struct Field<'a> { pub doc: Option<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub default: Option<&'a str>, } @@ -155,7 +155,7 @@ pub struct Field<'a> { pub struct Enum<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub doc: Option<&'a str>, @@ -163,7 +163,7 @@ pub struct Enum<'a> { pub aliases: Vec<&'a str>, #[serde(borrow)] pub symbols: Vec<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub default: Option<&'a str>, #[serde(flatten)] pub attributes: Attributes<'a>, @@ -198,7 +198,7 @@ pub struct Map<'a> { pub struct Fixed<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, @@ -237,7 +237,7 @@ mod tests { "logicalType":"timestamp-micros" }"#, ) - .unwrap(); + .unwrap(); let timestamp = Type { r#type: TypeName::Primitive(PrimitiveType::Long), @@ -260,7 +260,7 @@ mod tests { "scale":2 }"#, ) - .unwrap(); + .unwrap(); let decimal = ComplexType::Fixed(Fixed { name: "fixed", @@ -300,7 +300,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -333,7 +333,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -392,7 +392,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -453,7 +453,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs new file mode 100644 index 000000000000..afb623162d9a --- /dev/null +++ b/arrow-avro/src/writer/mod.rs @@ -0,0 +1,15 @@ +mod schema; +mod vlq; + +#[cfg(test)] +mod test { + use std::fs::File; + use std::io::BufWriter; + use arrow_array::RecordBatch; + + fn write_file(file: &str, batch: &RecordBatch) { + let file = File::open(file).unwrap(); + let mut writer = BufWriter::new(file); + + } +} \ No newline at end of file diff --git a/arrow-avro/src/writer/schema.rs b/arrow-avro/src/writer/schema.rs new file mode 100644 index 000000000000..c8cc5a7f9ec2 --- /dev/null +++ b/arrow-avro/src/writer/schema.rs @@ -0,0 +1,288 @@ +use std::collections::HashMap; +use std::sync::Arc; +use arrow_array::RecordBatch; +use crate::codec::{AvroDataType, AvroField, Codec}; +use crate::schema::Schema; + +fn record_batch_to_avro_schema<'a>( + batch: &'a RecordBatch, + record_name: &'a str, + top_level_data_type: &'a AvroDataType, +) -> Schema<'a> { + top_level_data_type.to_avro_schema(record_name) +} + +pub fn 
to_avro_json_schema(
+    batch: &RecordBatch,
+    record_name: &str,
+) -> Result<String, serde_json::Error> {
+    let avro_fields: Vec<AvroField> = batch
+        .schema()
+        .fields()
+        .iter()
+        .map(|arrow_field| crate::codec::arrow_field_to_avro_field(arrow_field))
+        .collect();
+    let top_level_data_type = AvroDataType::from_codec(
+        Codec::Struct(Arc::from(avro_fields)),
+    );
+    let avro_schema = record_batch_to_avro_schema(batch, record_name, &top_level_data_type);
+    serde_json::to_string_pretty(&avro_schema)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::{Int32Array, StringArray, RecordBatch, ArrayRef, StructArray};
+    use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema};
+    use serde_json::{json, Value};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_record_batch_to_avro_schema_basic() {
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, true),
+        ]));
+
+        let col_id = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")]));
+        let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name])
+            .expect("Failed to create RecordBatch");
+
+        // Convert the batch -> Avro `Schema`
+        let avro_schema = to_avro_json_schema(&batch, "MyTestRecord")
+            .expect("Failed to convert RecordBatch to Avro JSON schema");
+        let actual_json: Value = serde_json::from_str(&avro_schema)
+            .expect("Invalid JSON returned by to_avro_json_schema");
+
+        let expected_json = json!({
+            "type": "record",
+            "name": "MyTestRecord",
+            "aliases": [],
+            "doc": null,
+            "logicalType": null,
+            "fields": [
+                {
+                    "name": "id",
+                    "doc": null,
+                    "type": "int"
+                },
+                {
+                    "name": "name",
+                    "doc": null,
+                    "type": ["null", "string"]
+                }
+            ]
+        });
+
+        // Compare the two JSON objects
+        assert_eq!(
+            actual_json, expected_json,
+            "Avro Schema JSON does not match expected"
+        );
+    }
+
+    #[test]
+    fn test_to_avro_json_schema_basic() {
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("desc", DataType::Utf8, true),
+        ]));
+
+        let col_id = Arc::new(Int32Array::from(vec![10, 20, 30]));
+        let col_desc = Arc::new(StringArray::from(vec![Some("a"), Some("b"), None]));
+        let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_desc])
+            .expect("Failed to create RecordBatch");
+
+        let json_schema_string = to_avro_json_schema(&batch, "AnotherTestRecord")
+            .expect("Failed to convert RecordBatch to Avro JSON schema");
+
+        let actual_json: Value = serde_json::from_str(&json_schema_string)
+            .expect("Invalid JSON returned by to_avro_json_schema");
+
+        let expected_json = json!({
+            "type": "record",
+            "name": "AnotherTestRecord",
+            "aliases": [],
+            "doc": null,
+            "logicalType": null,
+            "fields": [
+                {
+                    "name": "id",
+                    "type": "int",
+                    "doc": null,
+                },
+                {
+                    "name": "desc",
+                    "type": ["null", "string"],
+                    "doc": null,
+                }
+            ]
+        });
+
+        assert_eq!(
+            actual_json, expected_json,
+            "JSON schema mismatch for to_avro_json_schema"
+        );
+    }
+
+    #[test]
+    fn test_to_avro_json_schema_single_nonnull_int() {
+        let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let col_id = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let batch = RecordBatch::try_new(arrow_schema, vec![col_id])
+            .expect("Failed to create RecordBatch");
+
+        let avro_json_string = to_avro_json_schema(&batch, "MySingleIntRecord")
+            .expect("Failed to generate Avro JSON schema");
+
+        let actual_json: Value =
serde_json::from_str(&avro_json_string) + .expect("Failed to parse Avro JSON schema"); + + let expected_json = json!({ + "type": "record", + "name": "MySingleIntRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "type": "int", + "doc": null, + } + ] + }); + + // Compare + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } + + #[test] + fn test_to_avro_json_schema_two_fields_nullable_string() { + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); + let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); + let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) + .expect("Failed to create RecordBatch"); + + let avro_json_string = to_avro_json_schema(&batch, "MyRecord") + .expect("Failed to generate Avro JSON schema"); + + let actual_json: Value = serde_json::from_str(&avro_json_string) + .expect("Failed to parse Avro JSON schema"); + + let expected_json = json!({ + "type": "record", + "name": "MyRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "type": "int", + "doc": null, + }, + { + "name": "name", + "doc": null, + "type": [ + "null", + "string", + ] + } + ] + }); + + // Compare + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } + + #[test] + fn test_to_avro_json_schema_nested_struct() { + let inner_fields = Fields::from(vec![ + Field::new("inner_int", DataType::Int32, false), + Field::new("inner_str", DataType::Utf8, true), + ]); + + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![ + Field::new("my_struct", DataType::Struct(inner_fields), true) + ])); + + let inner_int_col = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef; + let inner_str_col = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; + + let fields_arrays = vec![ + ( + Arc::new(Field::new("inner_int", DataType::Int32, false)), + inner_int_col, + ), + ( + Arc::new(Field::new("inner_str", DataType::Utf8, true)), + inner_str_col, + ), + ]; + + let struct_array = StructArray::from(fields_arrays); + + let batch = RecordBatch::try_new( + arrow_schema, + vec![Arc::new(struct_array)], + ) + .expect("Failed to create RecordBatch"); + + let avro_json_string = to_avro_json_schema(&batch, "NestedRecord") + .expect("Failed to generate Avro JSON schema"); + + let actual_json: Value = serde_json::from_str(&avro_json_string) + .expect("Failed to parse Avro JSON schema"); + + let expected_json = json!({ + "type": "record", + "name": "NestedRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "my_struct", + "doc": null, + "type": [ + "null", + { + "type": "record", + "name": "my_struct", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "inner_int", + "type": "int", + "doc": null, + }, + { + "name": "inner_str", + "doc": null, + "type": [ + "null", + "string", + ] + } + ] + } + ] + } + ] + }); + + // Compare + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } +} diff --git a/arrow-avro/src/writer/vlq.rs b/arrow-avro/src/writer/vlq.rs new file mode 100644 index 000000000000..765e6687abaa --- /dev/null +++ b/arrow-avro/src/writer/vlq.rs @@ -0,0 +1,98 @@ +/// Encoder for zig-zag encoded variable length integers +/// +/// This complements the VLQ decoding logic used by 
Avro. Zig-zag encoding maps signed integers +/// to unsigned integers so that small magnitudes (both positive and negative) produce smaller varints. +/// After zig-zag encoding, values are encoded as a series of bytes where the lower 7 bits are data +/// and the high bit indicates if another byte follows. +/// +/// See also: +/// +/// +#[derive(Debug, Default)] +pub struct VLQEncoder; + +impl VLQEncoder { + /// Encode a signed 64-bit integer `value` into `output` using Avro's zig-zag varint encoding. + /// + /// Zig-zag encoding: + /// ```text + /// encoded = (value << 1) ^ (value >> 63) + /// ``` + /// + /// Then `encoded` is written as a variable-length integer (varint): + /// - Extract 7 bits at a time + /// - If more bits remain, set the MSB of the current byte to 1 and continue + /// - Otherwise, MSB is 0 and this is the last byte + pub fn long(&mut self, value: i64, output: &mut Vec) { + let zigzag = ((value << 1) ^ (value >> 63)) as u64; + self.encode_varint(zigzag, output); + } + + fn encode_varint(&self, mut val: u64, output: &mut Vec) { + while (val & !0x7F) != 0 { + output.push(((val & 0x7F) as u8) | 0x80); + val >>= 7; + } + output.push(val as u8); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn decode_varint(buf: &mut &[u8]) -> Option { + let mut value = 0_u64; + for i in 0..10 { + let b = buf.get(i).copied()?; + let lower_7 = (b & 0x7F) as u64; + value |= lower_7 << (7 * i); + if b & 0x80 == 0 { + *buf = &buf[i + 1..]; + return Some(value); + } + } + None // more than 10 bytes or not terminated properly + } + + fn decode_zigzag(val: u64) -> i64 { + ((val >> 1) as i64) ^ -((val & 1) as i64) + } + + fn decode_long(buf: &mut &[u8]) -> Option { + let val = decode_varint(buf)?; + Some(decode_zigzag(val)) + } + + fn round_trip(value: i64) { + let mut encoder = VLQEncoder::default(); + let mut buf = Vec::new(); + encoder.long(value, &mut buf); + + let mut slice = buf.as_slice(); + let decoded = decode_long(&mut slice).expect("Failed to decode value"); + assert_eq!(decoded, value, "Round-trip mismatch for value {}", value); + assert!(slice.is_empty(), "Not all bytes consumed"); + } + + #[test] + fn test_round_trip() { + round_trip(0); + round_trip(1); + round_trip(-1); + round_trip(12345678); + round_trip(-12345678); + round_trip(i64::MAX); + round_trip(i64::MIN); + } + + #[test] + fn test_random_values() { + use rand::Rng; + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + let val: i64 = rng.gen(); + round_trip(val); + } + } +} From 36d56a9b36bef27f6a1fc67a8cb34ded459f36ca Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 04:18:32 -0600 Subject: [PATCH 02/38] Added codec support + tests for: 1. Namespaces 2. Enums 3. Maps 4. 
Decimals --- arrow-avro/src/codec.rs | 777 ++++++++++++++++++++++++++++---- arrow-avro/src/reader/record.rs | 3 + 2 files changed, 692 insertions(+), 88 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 25a790fa476a..aab38a45e444 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -17,11 +17,14 @@ use crate::schema::{ Attributes, ComplexType, PrimitiveType, Schema, TypeName, Array, Fixed, Map, Record, - Field as AvroFieldDef -}; -use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + Field as AvroFieldDef, + Fixed as AvroFixed, + Enum as AvroEnum, + Map as AvroMap }; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, + SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE}; use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; @@ -49,7 +52,6 @@ pub struct AvroDataType { } impl AvroDataType { - /// Create a new AvroDataType with the given parts. /// This helps you construct it from outside `codec.rs` without exposing internals. pub fn new( @@ -64,6 +66,7 @@ impl AvroDataType { } } + /// Create a new AvroDataType from a `Codec`, with default (no) nullability and empty metadata. pub fn from_codec(codec: Codec) -> Self { Self::new(codec, None, Default::default()) } @@ -74,30 +77,57 @@ impl AvroDataType { Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) } + /// Return a reference to the inner `Codec`. pub fn codec(&self) -> &Codec { &self.codec } + /// Return the nullability for this Avro type, if any. pub fn nullability(&self) -> Option { self.nullability } /// Convert this `AvroDataType`, which encapsulates an Arrow data type (`codec`) - /// plus nullability, back into an Avro `Schema<'a>`. + /// plus nullability and metadata, back into an Avro `Schema<'a>`. + /// + /// - If `metadata["namespace"]` is present, we'll store it in the resulting schema for named types + /// (record, enum, fixed). pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { let inner_schema = self.codec.to_avro_schema(name); - // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. - // Otherwise, return the schema as-is. if let Some(_) = self.nullability { Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - inner_schema, + maybe_add_namespace(inner_schema, self), ]) } else { - inner_schema + maybe_add_namespace(inner_schema, self) + } + } +} + +/// If this is a named complex type (Record, Enum, Fixed), attach `namespace` +/// from `dt.metadata["namespace"]` if present. Otherwise, return as-is. 
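+/// For example, with `metadata["namespace"] = "com.example"`, a record schema
+/// `{"type": "record", "name": "TopRecord", ...}` is serialized with an added
+/// `"namespace": "com.example"` entry.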
+fn maybe_add_namespace<'a>(mut schema: Schema<'a>, dt: &'a AvroDataType) -> Schema<'a> { + let ns = dt.metadata.get("namespace"); + if let Some(ns_str) = ns { + if let Schema::Complex(ref mut c) = schema { + match c { + ComplexType::Record(r) => { + r.namespace = Some(ns_str); + } + ComplexType::Enum(e) => { + e.namespace = Some(ns_str); + } + ComplexType::Fixed(f) => { + f.namespace = Some(ns_str); + } + // Arrays and Maps do not have a namespace field, so do nothing + _ => {} + } } } + schema } /// A named [`AvroDataType`] @@ -118,6 +148,7 @@ impl AvroField { &self.data_type } + /// Returns the name of this field pub fn name(&self) -> &str { &self.name } @@ -167,9 +198,14 @@ pub enum Codec { List(Arc), Struct(Arc<[AvroField]>), Interval, + /// In Arrow, use Dictionary(Int32, Utf8) for Enum. + Enum(Vec), + Map(Arc), + Decimal(usize, Option, Option), } impl Codec { + /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { Self::Null => DataType::Null, @@ -195,11 +231,50 @@ impl Codec { DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) } Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + Self::Enum(_symbols) => { + // Produce a Dictionary type with index = Int32, value = Utf8 + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ) + } + Self::Map(values) => { + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + values.field_with_name("value"), + ]) + ), + false, + )), + false, + ) + } + Self::Decimal(precision, scale, size) => match size { + Some(s) if *s > 16 && *s <= 32 => { + DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + }, + Some(s) if *s <= 16 => { + DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + }, + _ => { + // Infer based on precision when size is None + if *precision <= DECIMAL128_MAX_PRECISION as usize + && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize + { + DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + } else { + DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + } + } + }, } } /// Convert this `Codec` variant to an Avro `Schema<'a>`. - /// More work needed to handle `decimal`, `enum`, `map`, etc. 
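+    /// Covers primitive, logical, decimal, enum, and map codecs.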
pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { match self { Codec::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -210,7 +285,6 @@ impl Codec { Codec::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), Codec::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), Codec::Utf8 => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), - // date32 => Avro int + logicalType=date Codec::Date32 => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Int), @@ -219,7 +293,6 @@ impl Codec { additional: Default::default(), }, }), - // time-millis => Avro int with logicalType=time-millis Codec::TimeMillis => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Int), @@ -228,7 +301,6 @@ impl Codec { additional: Default::default(), }, }), - // time-micros => Avro long with logicalType=time-micros Codec::TimeMicros => Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), @@ -237,52 +309,53 @@ impl Codec { additional: Default::default(), }, }), - - // timestamp-millis => Avro long with logicalType=timestamp-millis + // timestamp-millis => Avro long with logicalType=timestamp-millis or local-timestamp-millis Codec::TimestampMillis(is_utc) => { - // TODO `is_utc` or store it in metadata + let lt = if *is_utc { + Some("timestamp-millis") + } else { + Some("local-timestamp-millis") + }; Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: Some("timestamp-millis"), + logical_type: lt, additional: Default::default(), }, }) } - - // timestamp-micros => Avro long with logicalType=timestamp-micros + // timestamp-micros => Avro long with logicalType=timestamp-micros or local-timestamp-micros Codec::TimestampMicros(is_utc) => { + let lt = if *is_utc { + Some("timestamp-micros") + } else { + Some("local-timestamp-micros") + }; Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: Some("timestamp-micros"), - additional: Default::default(), - }, - }) - } - - Codec::Interval => { - Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Bytes), - attributes: Attributes { - logical_type: Some("duration"), + logical_type: lt, additional: Default::default(), }, }) } - + Codec::Interval => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("duration"), + additional: Default::default(), + }, + }), Codec::Fixed(size) => { - // Convert Arrow FixedSizeBinary => Avro fixed with a known name & size - // TODO namespace/aliases. 
+ // Convert Arrow FixedSizeBinary => Avro fixed with name & size Schema::Complex(ComplexType::Fixed(Fixed { name, - namespace: None, // TODO namespace implementation - aliases: vec![], // TODO alias implementation + namespace: None, + aliases: vec![], size: *size as usize, attributes: Attributes::default(), })) } - Codec::List(item_type) => { // Avro array with "items" recursively derived let items_schema = item_type.to_avro_schema("items"); @@ -291,32 +364,80 @@ impl Codec { attributes: Attributes::default(), })) } - Codec::Struct(fields) => { // Avro record with nested fields let record_fields = fields .iter() .map(|f| { - // For each `AvroField`, get its Avro schema let child_schema = f.data_type().to_avro_schema(f.name()); AvroFieldDef { - name: f.name(), // Avro field name + name: f.name(), doc: None, r#type: child_schema, default: None, } }) .collect(); - Schema::Complex(ComplexType::Record(Record { name, - namespace: None, // TODO follow up for namespace implementation + namespace: None, doc: None, - aliases: vec![], // TODO follow up for alias implementation + aliases: vec![], fields: record_fields, attributes: Attributes::default(), })) } + Codec::Enum(symbols) => { + // If there's a namespace in metadata, we will apply it later in maybe_add_namespace. + Schema::Complex(ComplexType::Enum(AvroEnum { + name, + namespace: None, + doc: None, + aliases: vec![], + symbols: symbols.iter().map(|s| s.as_str()).collect(), + default: None, + attributes: Attributes::default(), + })) + } + Codec::Map(values) => { + let val_schema = values.to_avro_schema("values"); + Schema::Complex(ComplexType::Map(AvroMap { + values: Box::new(val_schema), + attributes: Attributes::default(), + })) + } + Codec::Decimal(precision, scale, size) => { + // If size is Some(n), produce Avro "fixed", else "bytes". 
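+                // e.g. Decimal(6, Some(2), Some(4)) serializes to
+                //   {"type": "fixed", "name": ..., "size": 4,
+                //    "logicalType": "decimal", "precision": 6, "scale": 2}
+                // while Decimal(6, Some(2), None) serializes to
+                //   {"type": "bytes", "logicalType": "decimal", "precision": 6, "scale": 2}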
+ if let Some(n) = size { + // fixed with logicalType=decimal, plus precision/scale + Schema::Complex(ComplexType::Fixed(AvroFixed { + name, + namespace: None, + aliases: vec![], + size: *n, + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ("size", serde_json::json!(*n)), + ]), + }, + })) + } else { + // "type":"bytes", "logicalType":"decimal" + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ]), + }, + }) + } + } } } } @@ -365,8 +486,6 @@ impl<'a> Resolver<'a> { /// /// `name`: is name used to refer to `schema` in its parent /// `namespace`: an optional qualifier used as part of a type hierarchy -/// -/// See [`Resolver`] for more information fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, @@ -380,7 +499,7 @@ fn make_data_type<'a>( }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), Schema::Union(f) => { - // Special case the common case of nullable primitives + // Special case the common case of nullable primitives or single-type let null = f .iter() .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); @@ -431,50 +550,132 @@ fn make_data_type<'a>( }) } ComplexType::Fixed(f) => { + // Possibly decimal with logicalType=decimal let size = f.size.try_into().map_err(|e| { ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; + if let Some("decimal") = f.attributes.logical_type { + let precision = f + .attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires precision".to_string()) + })?; + let size_val = f + .attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires size".to_string()) + })?; + let scale = f + .attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .or_else(|| Some(0)); + + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Decimal( + precision as usize, + Some(scale.unwrap_or(0) as usize), + Some(size_val as usize), + ), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } else { + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + } + ComplexType::Enum(e) => { + let symbols = e.symbols.iter().map(|sym| sym.to_string()).collect::>(); let field = AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), + metadata: e.attributes.field_metadata(), + codec: Codec::Enum(symbols), + }; + resolver.register(e.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Map(m) => { + let values_data_type = make_data_type(m.values.as_ref(), namespace, resolver)?; + let field = AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(values_data_type)), }; - resolver.register(f.name, namespace, field.clone()); Ok(field) } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - 
))), - ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( - "Map of {m:?} not currently supported" - ))), }, Schema::Type(t) => { + // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - // https://avro.apache.org/docs/1.11.1/specification/#logical-types match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + Some( + t.attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + ); + } + (Some("decimal"), c @ Codec::Binary) => { + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + None, + ); } (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) - } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) - } + (Some("local-timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(false), + (Some("local-timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(false), (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { - // Insert unrecognized logical type into metadata map + // Insert unrecognized logical type into metadata field.metadata.insert("logicalType".into(), logical.into()); } (None, _) => {} @@ -490,20 +691,20 @@ fn make_data_type<'a>( } } - /// Convert an Arrow `Field` into an `AvroField`. -pub(crate) fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { - // TODO advanced metadata logic here +pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { + // Basic metadata logic: + // If arrow_field.metadata().get("namespace") is present, we store it below in AvroDataType let codec = arrow_type_to_codec(arrow_field.data_type()); - // Set nullability if the Arrow field is nullable let nullability = if arrow_field.is_nullable() { Some(Nullability::NullFirst) } else { None }; + let mut metadata = arrow_field.metadata().clone(); let avro_data_type = AvroDataType { nullability, - metadata: arrow_field.metadata().clone(), + metadata, codec, }; AvroField { @@ -512,7 +713,7 @@ pub(crate) fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { } } -/// Maps an Arrow `DataType` to a `Codec`: +/// Maps an Arrow `DataType` to a `Codec`. 
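+///
+/// Narrower integers widen (`Int8`/`Int16`/`Int32` => `Codec::Int32`),
+/// `Decimal128(p, s)` maps to `Codec::Decimal(p, Some(s), Some(16))`, and any
+/// still-unsupported type falls back to `Codec::Utf8`.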
fn arrow_type_to_codec(dt: &DataType) -> Codec { use arrow_schema::DataType::*; match dt { @@ -527,29 +728,429 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Date32 => Codec::Date32, Time32(TimeUnit::Millisecond) => Codec::TimeMillis, Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, _) => Codec::TimestampMillis(true), - Timestamp(TimeUnit::Microsecond, _) => Codec::TimestampMicros(true), - FixedSizeBinary(n) => Codec::Fixed(*n as i32), - - List(field) => { - // Recursively create Codec for the child item - let child_codec = arrow_type_to_codec(field.data_type()); - Codec::List(Arc::new(AvroDataType { - nullability: None, - metadata: Default::default(), - codec: child_codec, - })) + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), + Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMillis(true) + } + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMicros(true) + } + FixedSizeBinary(n) => Codec::Fixed(*n), + Decimal128(prec, scale) => Codec::Decimal( + *prec as usize, + Some(*scale as usize), + Some(16), + ), + Decimal256(prec, scale) => Codec::Decimal( + *prec as usize, + Some(*scale as usize), + Some(32), + ),Dictionary(index_type, value_type) => { + let mut md = HashMap::new(); + md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); + if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { + let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); + dt.metadata.extend(md); + Codec::Enum(vec![]) + } else { + // fallback + Codec::Utf8 + } + } + // For map => "type":"map" => in Arrow: DataType::Map + Map(field, _keys_sorted) => { + if let Struct(child_fields) = field.data_type() { + let value_field = &child_fields[1]; // name="value" + let sub_codec = arrow_type_to_codec(value_field.data_type()); + Codec::Map(Arc::new(AvroDataType { + nullability: value_field.is_nullable().then(|| Nullability::NullFirst), + metadata: value_field.metadata().clone(), + codec: sub_codec, + })) + } else { + Codec::Map(Arc::new(AvroDataType::from_codec(Codec::Utf8))) + } } Struct(child_fields) => { let avro_fields: Vec = child_fields .iter() - .map(|fref| arrow_field_to_avro_field(fref.as_ref())) + .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) .collect(); Codec::Struct(Arc::from(avro_fields)) } - _ => { - // TODO handle more arrow types (e.g. decimal, map, union, etc.) 
- Codec::Utf8 + _ => Codec::Utf8, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{DataType, Field}; + use std::sync::Arc; + use serde_json::json; + + #[test] + fn test_decimal256_tuple_variant_fixed() { + // Arrow decimal(60,3) => Codec::Decimal(60,3,Some(32)) + let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); + match c { + Codec::Decimal(p, s, Some(32)) => { + assert_eq!(p, 60); + assert_eq!(s, Some(3)); + } + _ => panic!("Expected decimal(60,3,Some(32))"), + } + let avro_dt = AvroDataType::from_codec(c); + let avro_schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "FixedDec", + "aliases": [], + "size": 32, + "logicalType": "decimal", + "precision": 60, + "scale": 3 + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal128_tuple_variant_fixed() { + // Avro "fixed" => decimal(6,2,Some(4)) + // arrow => decimal(6,2) + let c = Codec::Decimal(6, Some(2), Some(4)); + let dt = c.data_type(); + match dt { + DataType::Decimal128(p, s) => { + assert_eq!(p, 6); + assert_eq!(s, 2); + } + _ => panic!("Expected decimal(6,2) arrow type"), + } + + // Convert back to Avro schema => "fixed" + let avro_dt = AvroDataType::from_codec(c); + let schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "FixedDec", + "aliases": [], + "size": 4, + "logicalType": "decimal", + "precision": 6, + "scale": 2, + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal_size_decision() { + // Decimal128 (size <= 16) + let codec = Codec::Decimal(10, Some(3), Some(16)); + let dt = codec.data_type(); + match dt { + DataType::Decimal128(precision, scale) => { + assert_eq!(precision, 10); + assert_eq!(scale, 3); + } + _ => panic!("Expected Decimal128"), + } + + // Decimal256 (size > 16) + let codec = Codec::Decimal(18, Some(4), Some(32)); + let dt = codec.data_type(); + match dt { + DataType::Decimal256(precision, scale) => { + assert_eq!(precision, 18); + assert_eq!(scale, 4); + } + _ => panic!("Expected Decimal256"), + } + + // Default to Decimal128 (size not specified) + let codec = Codec::Decimal(8, Some(2), None); + let dt = codec.data_type(); + match dt { + DataType::Decimal128(precision, scale) => { + assert_eq!(precision, 8); + assert_eq!(scale, 2); + } + _ => panic!("Expected Decimal128"), + } + } + + #[test] + fn test_avro_data_type_new_and_from_codec() { + let dt1 = AvroDataType::new( + Codec::Int32, + Some(Nullability::NullFirst), + HashMap::from([("namespace".into(), "my.ns".into())]), + ); + + let actual_str = format!("{:?}", dt1.nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + + let actual_str2 = format!("{:?}", dt1.codec()); + let expected_str2 = format!("{:?}", &Codec::Int32); + assert_eq!(actual_str2, expected_str2); + assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); + + let dt2 = AvroDataType::from_codec(Codec::Float64); + let actual_str4 = format!("{:?}", dt2.codec()); + let expected_str4 = format!("{:?}", &Codec::Float64); + assert_eq!(actual_str4, expected_str4); + assert!(dt2.metadata.is_empty()); + } + + #[test] + fn test_avro_data_type_field_with_name() { + let dt = AvroDataType::new( + Codec::Binary, + None, + HashMap::from([("something".into(), "else".into())]), + ); + let f = dt.field_with_name("bin_col"); + assert_eq!(f.name(), "bin_col"); + 
assert_eq!(f.data_type(), &DataType::Binary); + assert!(!f.is_nullable()); + assert_eq!(f.metadata().get("something"), Some(&"else".to_string())); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_record() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example".to_string()); + let fields = Arc::from(vec![ + AvroField { + name: "id".to_string(), + data_type: AvroDataType::from_codec(Codec::Int32), + }, + AvroField { + name: "label".to_string(), + data_type: AvroDataType::new(Codec::Utf8, Some(Nullability::NullFirst), Default::default()), + } + ]); + let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); + let avro_schema = top_level.to_avro_schema("TopRecord"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + + let expected = json!({ + "type": "record", + "name": "TopRecord", + "namespace": "com.example", + "doc": null, + "logicalType": null, + "aliases": [], + "fields": [ + { "name": "id", "doc": null, "type": "int" }, + { "name": "label", "doc": null, "type": ["null","string"] } + ], + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_enum() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.enum".to_string()); + + let enum_dt = AvroDataType::new( + Codec::Enum(vec!["A".to_string(), "B".to_string(), "C".to_string()]), + None, + meta, + ); + let avro_schema = enum_dt.to_avro_schema("MyEnum"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "enum", + "name": "MyEnum", + "logicalType": null, + "namespace": "com.example.enum", + "doc": null, + "aliases": [], + "symbols": ["A","B","C"] + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.fixed".to_string()); + + let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); + + let avro_schema = fixed_dt.to_avro_schema("MyFixed"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + + let expected = json!({ + "type": "fixed", + "name": "MyFixed", + "logicalType": null, + "namespace": "com.example.fixed", + "aliases": [], + "size": 8 + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_field() { + let field_codec = AvroDataType::from_codec(Codec::Int64); + let avro_field = AvroField { + name: "long_col".to_string(), + data_type: field_codec.clone(), + }; + + assert_eq!(avro_field.name(), "long_col"); + + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Int64); + assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); + + let arrow_field = avro_field.field(); + assert_eq!(arrow_field.name(), "long_col"); + assert_eq!(arrow_field.data_type(), &DataType::Int64); + assert!(!arrow_field.is_nullable()); + } + + #[test] + fn test_arrow_field_to_avro_field() { + let arrow_field = Field::new( + "test_meta", + DataType::Utf8, + true, + ) + .with_metadata(HashMap::from([ + ("namespace".to_string(), "arrow_meta_ns".to_string()) + ])); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert_eq!(avro_field.name(), "test_meta"); + + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Utf8); + assert_eq!(actual_str, expected_str); + + let actual_str = format!("{:?}", avro_field.data_type().nullability()); + let 
expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + + // Confirm we kept the metadata + assert_eq!( + avro_field.data_type().metadata.get("namespace"), + Some(&"arrow_meta_ns".to_string()) + ); + } + + #[test] + fn test_codec_struct() { + let fields = Arc::from(vec![ + AvroField { + name: "a".to_string(), + data_type: AvroDataType::from_codec(Codec::Boolean), + }, + AvroField { + name: "b".to_string(), + data_type: AvroDataType::from_codec(Codec::Float64), + }, + ]); + let codec = Codec::Struct(fields); + let dt = codec.data_type(); + match dt { + DataType::Struct(fields) => { + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "a"); + assert_eq!(fields[0].data_type(), &DataType::Boolean); + assert_eq!(fields[1].name(), "b"); + assert_eq!(fields[1].data_type(), &DataType::Float64); + } + _ => panic!("Expected Struct data type"), + } + } + + #[test] + fn test_codec_fixedsizebinary() { + let codec = Codec::Fixed(12); + let dt = codec.data_type(); + match dt { + DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + _ => panic!("Expected FixedSizeBinary(12)"), } } + + #[test] + fn test_utc_timestamp_millis() { + let arrow_field = Field::new( + "utc_ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMillis(true)), + "Expected Codec::TimestampMillis(true), got: {:?}", + codec + ); + } + + #[test] + fn test_utc_timestamp_micros() { + let arrow_field = Field::new( + "utc_ts_us", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMicros(true)), + "Expected Codec::TimestampMicros(true), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_millis() { + let arrow_field = Field::new( + "local_ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMillis(false)), + "Expected Codec::TimestampMillis(false), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_micros() { + let arrow_field = Field::new( + "local_ts_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ); + + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + + assert!( + matches!(codec, Codec::TimestampMicros(false)), + "Expected Codec::TimestampMicros(false), got: {:?}", + codec + ); + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 52a58cf63303..97ccc1032b76 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -144,6 +144,9 @@ impl Decoder { } Self::Record(arrow_fields.into(), encodings) } + _ => { + Self::Null(0) // TODO: Add decoders for Enum, Map, and Decimal + } }; Ok(match data_type.nullability() { From 36b4b734cea5d9a8b7c28044c76ee076988acebd Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 22:48:29 -0600 Subject: [PATCH 03/38] Added reader record decoder support for non-null Enum, Map, and Decimal types. 
Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 47 +-- arrow-avro/src/reader/cursor.rs | 28 +- arrow-avro/src/reader/record.rs | 500 ++++++++++++++++++++++++++++++-- 3 files changed, 500 insertions(+), 75 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index aab38a45e444..01a2732e99bc 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -53,7 +53,6 @@ pub struct AvroDataType { impl AvroDataType { /// Create a new AvroDataType with the given parts. - /// This helps you construct it from outside `codec.rs` without exposing internals. pub fn new( codec: Codec, nullability: Option, @@ -261,7 +260,7 @@ impl Codec { DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) }, _ => { - // Infer based on precision when size is None + // Note: Infer based on precision when size is None if *precision <= DECIMAL128_MAX_PRECISION as usize && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize { @@ -409,7 +408,6 @@ impl Codec { Codec::Decimal(precision, scale, size) => { // If size is Some(n), produce Avro "fixed", else "bytes". if let Some(n) = size { - // fixed with logicalType=decimal, plus precision/scale Schema::Complex(ComplexType::Fixed(AvroFixed { name, namespace: None, @@ -532,7 +530,6 @@ fn make_data_type<'a>( }) }) .collect::>()?; - let field = AvroDataType { nullability: None, codec: Codec::Struct(fields), @@ -554,7 +551,6 @@ fn make_data_type<'a>( let size = f.size.try_into().map_err(|e| { ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; - if let Some("decimal") = f.attributes.logical_type { let precision = f .attributes @@ -578,7 +574,6 @@ fn make_data_type<'a>( .get("scale") .and_then(|v| v.as_u64()) .or_else(|| Some(0)); - let field = AvroDataType { nullability: None, metadata: f.attributes.field_metadata(), @@ -624,7 +619,6 @@ fn make_data_type<'a>( // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { *c = Codec::Decimal( @@ -792,7 +786,6 @@ mod tests { #[test] fn test_decimal256_tuple_variant_fixed() { - // Arrow decimal(60,3) => Codec::Decimal(60,3,Some(32)) let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); match c { Codec::Decimal(p, s, Some(32)) => { @@ -818,8 +811,6 @@ mod tests { #[test] fn test_decimal128_tuple_variant_fixed() { - // Avro "fixed" => decimal(6,2,Some(4)) - // arrow => decimal(6,2) let c = Codec::Decimal(6, Some(2), Some(4)); let dt = c.data_type(); match dt { @@ -829,8 +820,6 @@ mod tests { } _ => panic!("Expected decimal(6,2) arrow type"), } - - // Convert back to Avro schema => "fixed" let avro_dt = AvroDataType::from_codec(c); let schema = avro_dt.to_avro_schema("FixedDec"); let j = serde_json::to_value(&schema).unwrap(); @@ -848,7 +837,6 @@ mod tests { #[test] fn test_decimal_size_decision() { - // Decimal128 (size <= 16) let codec = Codec::Decimal(10, Some(3), Some(16)); let dt = codec.data_type(); match dt { @@ -858,8 +846,6 @@ mod tests { } _ => panic!("Expected Decimal128"), } - - // Decimal256 (size > 16) let codec = Codec::Decimal(18, Some(4), Some(32)); let dt = codec.data_type(); match dt { @@ -869,8 +855,6 @@ mod tests { } _ => panic!("Expected Decimal256"), } - - // Default to Decimal128 (size not specified) let codec = Codec::Decimal(8, Some(2), None); let dt = codec.data_type(); match dt { @@ -889,16 +873,13 @@ mod tests { Some(Nullability::NullFirst), 
HashMap::from([("namespace".into(), "my.ns".into())]), ); - let actual_str = format!("{:?}", dt1.nullability()); let expected_str = format!("{:?}", Some(Nullability::NullFirst)); assert_eq!(actual_str, expected_str); - let actual_str2 = format!("{:?}", dt1.codec()); let expected_str2 = format!("{:?}", &Codec::Int32); assert_eq!(actual_str2, expected_str2); assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); - let dt2 = AvroDataType::from_codec(Codec::Float64); let actual_str4 = format!("{:?}", dt2.codec()); let expected_str4 = format!("{:?}", &Codec::Float64); @@ -937,7 +918,6 @@ mod tests { let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); let avro_schema = top_level.to_avro_schema("TopRecord"); let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ "type": "record", "name": "TopRecord", @@ -981,12 +961,9 @@ mod tests { fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { let mut meta = HashMap::new(); meta.insert("namespace".to_string(), "com.example.fixed".to_string()); - let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); - let avro_schema = fixed_dt.to_avro_schema("MyFixed"); let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ "type": "fixed", "name": "MyFixed", @@ -1005,13 +982,10 @@ mod tests { name: "long_col".to_string(), data_type: field_codec.clone(), }; - assert_eq!(avro_field.name(), "long_col"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); let expected_str = format!("{:?}", &Codec::Int64); assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); - let arrow_field = avro_field.field(); assert_eq!(arrow_field.name(), "long_col"); assert_eq!(arrow_field.data_type(), &DataType::Int64); @@ -1024,22 +998,17 @@ mod tests { "test_meta", DataType::Utf8, true, - ) - .with_metadata(HashMap::from([ - ("namespace".to_string(), "arrow_meta_ns".to_string()) - ])); + ).with_metadata(HashMap::from([ + ("namespace".to_string(), "arrow_meta_ns".to_string()) + ])); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); let expected_str = format!("{:?}", &Codec::Utf8); assert_eq!(actual_str, expected_str); - let actual_str = format!("{:?}", avro_field.data_type().nullability()); let expected_str = format!("{:?}", Some(Nullability::NullFirst)); assert_eq!(actual_str, expected_str); - - // Confirm we kept the metadata assert_eq!( avro_field.data_type().metadata.get("namespace"), Some(&"arrow_meta_ns".to_string()) @@ -1089,10 +1058,8 @@ mod tests { DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMillis(true)), "Expected Codec::TimestampMillis(true), got: {:?}", @@ -1107,10 +1074,8 @@ mod tests { DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMicros(true)), "Expected Codec::TimestampMicros(true), got: {:?}", @@ -1125,10 +1090,8 @@ mod tests { DataType::Timestamp(TimeUnit::Millisecond, None), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMillis(false)), "Expected 
Codec::TimestampMillis(false), got: {:?}", @@ -1143,10 +1106,8 @@ mod tests { DataType::Timestamp(TimeUnit::Microsecond, None), false, ); - let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); - assert!( matches!(codec, Codec::TimestampMicros(false)), "Expected Codec::TimestampMicros(false), got: {:?}", diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..ba1d01f72d7e 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; @@ -65,27 +64,32 @@ impl<'a> AvroCursor<'a> { Ok(val) } + /// Decode a zig-zag encoded Avro int (32-bit). #[inline] pub(crate) fn get_int(&mut self) -> Result { let varint = self.read_vlq()?; let val: u32 = varint .try_into() .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; + // Zig-zag decode Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } + /// Decode a zig-zag encoded Avro long (64-bit). #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; + // Zig-zag decode Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } + /// Read a variable-length byte array from Avro (where the length is stored as an Avro long). pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { let len: usize = self.get_long()?.try_into().map_err(|_| { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -95,9 +99,10 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 32-bit float #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -107,15 +112,28 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 64-bit float #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( - "Unexpected EOF reading float".to_string(), + "Unexpected EOF reading double".to_string(), )); } let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). 
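For reference, the zig-zag step in `get_int`/`get_long` above undoes Avro's mapping of signed integers onto unsigned varints (0, -1, 1, -2, ... become 0, 1, 2, 3, ...). A minimal standalone round-trip sketch of that transform; the function names here are illustrative and not part of the crate:

fn zigzag_encode(v: i64) -> u64 {
    // v >= 0 maps to 2*v, v < 0 maps to -2*v - 1
    ((v << 1) ^ (v >> 63)) as u64
}

fn zigzag_decode(n: u64) -> i64 {
    // Mirrors the decode expression used by `get_int` and `get_long`
    (n >> 1) as i64 ^ -((n & 1) as i64)
}

fn main() {
    for v in [0i64, -1, 1, -2, 2, i64::MIN, i64::MAX] {
        assert_eq!(zigzag_decode(zigzag_encode(v)), v);
    }
}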
+ pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 97ccc1032b76..4c57a3426bd6 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -23,12 +23,12 @@ use crate::schema::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; -use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, -}; +use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit}; use std::collections::HashMap; use std::io::Read; +use std::ptr::null; use std::sync::Arc; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; /// Decodes avro encoded data into [`RecordBatch`] pub struct RecordDecoder { @@ -94,9 +94,17 @@ enum Decoder { List(FieldRef, OffsetBufferBuilder, Box), Record(Fields, Vec), Nullable(Nullability, NullBufferBuilder, Box), + Enum(Vec, Vec), + Map(FieldRef, OffsetBufferBuilder, OffsetBufferBuilder, Vec, Box, usize), + Decimal(usize, usize, Option, Vec>), } impl Decoder { + /// Checks if the Decoder is nullable + fn is_nullable(&self) -> bool { + matches!(self, Decoder::Nullable(_, _, _)) + } + fn try_new(data_type: &AvroDataType) -> Result { let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); @@ -144,11 +152,39 @@ impl Decoder { } Self::Record(arrow_fields.into(), encodings) } - _ => { - Self::Null(0) // TODO: Add decoders for Enum, Map, and Decimal + Codec::Enum(symbols) => { + Decoder::Enum( + symbols.clone(), + Vec::with_capacity(DEFAULT_CAPACITY), + ) + } + Codec::Map(value_type) => { + let map_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(value_type.field_with_name("value")), + ])), + false, + )); + Decoder::Map( + map_field, + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets + Vec::with_capacity(DEFAULT_CAPACITY), // key_data + Box::new(Self::try_new(value_type)?), // values_decoder_inner + 0, // current_entry_count + ) + } + Codec::Decimal(precision, scale, size) => { + Decoder::Decimal( + *precision, + scale.unwrap_or(0), + *size, + Vec::with_capacity(DEFAULT_CAPACITY), + ) } }; - Ok(match data_type.nullability() { Some(nullability) => Self::Nullable( nullability, @@ -178,6 +214,23 @@ impl Decoder { } Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + Self::Enum(_, _) => { + // For Enum, appending a null is not straightforward. Handle accordingly if needed. 
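For context on the Map variant introduced above: an Avro map is written as a sequence of blocks, each prefixed with a zig-zag-encoded entry count and terminated by an empty block, with string keys and length-prefixed values. A hand-rolled encoding sketch assuming string values; `zigzag_long` and `encode_string_map` are illustrative helpers, and the decoder in this patch consumes one block per `decode` call:

fn zigzag_long(v: i64) -> Vec<u8> {
    let mut n = ((v << 1) ^ (v >> 63)) as u64;
    let mut out = Vec::new();
    while n & !0x7F != 0 {
        out.push(((n & 0x7F) | 0x80) as u8);
        n >>= 7;
    }
    out.push(n as u8);
    out
}

fn encode_string_map(entries: &[(&str, &str)]) -> Vec<u8> {
    let mut buf = zigzag_long(entries.len() as i64); // block entry count
    for (k, v) in entries {
        buf.extend(zigzag_long(k.len() as i64)); // key length prefix
        buf.extend(k.as_bytes());
        buf.extend(zigzag_long(v.len() as i64)); // string value length prefix
        buf.extend(v.as_bytes());
    }
    buf.extend(zigzag_long(0)); // zero-count block terminates the map
    buf
}

fn main() {
    let bytes = encode_string_map(&[("hello", "world")]);
    assert_eq!(bytes[0], 2); // zig-zag(1) == 2
}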
+ } + Self::Map( + _, + key_offsets, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + key_offsets.push_length(0); + map_offsets_builder.push_length(*current_entry_count); + } + Self::Decimal(_, _, _, _) => { + // For Decimal, appending a null doesn't make sense as per current implementation + } } } @@ -218,59 +271,256 @@ impl Decoder { false => e.append_null(), } } + Self::Enum(symbols, indices) => { + // Encodes enum by writing its zero-based index as an int + let index = buf.get_int()?; + indices.push(index); + } + Self::Map( + field, + key_offsets, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + let block_count = buf.get_long()?; + if block_count <= 0 { + // Push the current_entry_count without changes + map_offsets_builder.push_length(*current_entry_count); + } else { + let n = block_count as usize; + for _ in 0..n { + let key_bytes = buf.get_bytes()?; + key_offsets.push_length(key_bytes.len()); + key_data.extend_from_slice(key_bytes); + values_decoder_inner.decode(buf)?; + } + // Update the current_entry_count and push to map_offsets_builder + *current_entry_count += n; + map_offsets_builder.push_length(*current_entry_count); + } + } + Self::Decimal( + precision, + scale, + size, + data + ) => { + let raw = if let Some(fixed_len) = size { + // get_fixed used to get exactly fixed_len bytes + buf.get_fixed(*fixed_len)? + } else { + // get_bytes used for variable-length + buf.get_bytes()? + }; + data.push(raw.to_vec()); + } } Ok(()) } /// Flush decoded records to an [`ArrayRef`] fn flush(&mut self, nulls: Option) -> Result { - Ok(match self { - Self::Nullable(_, n, e) => e.flush(n.finish())?, - Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), - Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), - Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), + match self { + Self::Nullable(_, n, e) => e.flush(n.finish()), + Self::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), + Self::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), + Self::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), Self::TimeMillis(values) => { - Arc::new(flush_primitive::(values, nulls)) + Ok(Arc::new(flush_primitive::(values, nulls))) } Self::TimeMicros(values) => { - Arc::new(flush_primitive::(values, nulls)) + Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimestampMillis(is_utc, values) => Arc::new( + Self::TimestampMillis(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::TimestampMicros(is_utc, values) => Arc::new( + )), + Self::TimestampMicros(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - + )), + Self::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Self::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), Self::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = 
flush_values(values).into(); - Arc::new(BinaryArray::new(offsets, values, nulls)) + Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } Self::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); - Arc::new(StringArray::new(offsets, values, nulls)) + Ok(Arc::new(StringArray::new(offsets, values, nulls))) } Self::List(field, offsets, values) => { let values = values.flush(None)?; let offsets = flush_offsets(offsets); - Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) } Self::Record(fields, encodings) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } - }) + Self::Enum(symbols, indices) => { + let dict_values = StringArray::from_iter_values(symbols.iter()); + let flushed_indices = flush_values(indices); // Vec + let indices_array: Int32Array = match nulls { + Some(buf) => { + let buffer = Buffer::from_slice_ref(&flushed_indices); + PrimitiveArray::::try_new(ScalarBuffer::from(buffer), Some(buf.clone()))? + }, + None => { + Int32Array::from_iter_values(flushed_indices) + } + }; + let dict_array = DictionaryArray::::try_new( + indices_array, + Arc::new(dict_values), + )?; + Ok(Arc::new(dict_array)) + } + Self::Map( + field, + key_offsets_builder, + map_offsets_builder, + key_data, + values_decoder_inner, + current_entry_count, + ) => { + let map_offsets = flush_offsets(map_offsets_builder); + let key_offsets = flush_offsets(key_offsets_builder); + let key_data = flush_values(key_data).into(); + let key_array = StringArray::new(key_offsets, key_data, None); + let val_array = values_decoder_inner.flush(None)?; + let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); + let struct_fields = vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(ArrowField::new("value", val_array.data_type().clone(), is_nullable)), + ]; + let struct_array = StructArray::new( + Fields::from(struct_fields), + vec![Arc::new(key_array), val_array], + None, + ); + let map_array = MapArray::new(field.clone(), map_offsets.clone(), struct_array.clone(), nulls, false); + Ok(Arc::new(map_array)) + } + Self::Decimal( + precision, + scale, + size, + data, + ) => { + let mut array_builder = DecimalBuilder::new(*precision, *scale, *size)?; + for raw in data.drain(..) 
{ + if let Some(s) = size { + if raw.len() < *s { + let extended = sign_extend(&raw, *s); + array_builder.append_bytes(&extended)?; + continue; + } + } + array_builder.append_bytes(&raw)?; + } + let arr = array_builder.finish()?; + Ok(Arc::new(arr)) + } + } + } +} + +/// Helper to build a field with a given type +fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(ArrowField::new(name, dt, nullable)) +} + +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut extended = Vec::with_capacity(target_len); + if sign_bit != 0 { // negative + extended.resize(target_len - raw.len(), 0xFF); + } else { // positive + extended.resize(target_len - raw.len(), 0x00); + } + extended.extend_from_slice(raw); + extended +} + +/// Extend raw bytes to 16 bytes (for Decimal128) +fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let extended = sign_extend(raw, 16); + if extended.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend bytes to 16 bytes: got {} bytes", + extended.len() + ))); + } + Ok(extended.try_into().unwrap()) +} + +/// Extend raw bytes to 32 bytes (for Decimal256) +fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let extended = sign_extend(raw, 32); + if extended.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend bytes to 32 bytes: got {} bytes", + extended.len() + ))); + } + Ok(extended.try_into().unwrap()) +} + +/// Trait for building decimal arrays +enum DecimalBuilder { + Decimal128(Decimal128Builder), + Decimal256(Decimal256Builder), +} + +impl DecimalBuilder { + + fn new(precision: usize, scale: usize, size: Option) -> Result { + match size { + Some(s) if s > 16 => { + // decimal256 + Ok(Self::Decimal256(Decimal256Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + } + _ => { + // decimal128 + Ok(Self::Decimal128(Decimal128Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + } + } + } + + fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { + match self { + DecimalBuilder::Decimal128(b) => { + let padded = extend_to_16_bytes(bytes)?; + let value = i128::from_be_bytes(padded); + b.append_value(value); + } + DecimalBuilder::Decimal256(b) => { + let padded = extend_to_32_bytes(bytes)?; + let value = i256::from_be_bytes(padded); + b.append_value(value); + } + } + Ok(()) + } + + fn finish(self) -> Result { + match self { + DecimalBuilder::Decimal128(mut b) => Ok(Arc::new(b.finish())), + DecimalBuilder::Decimal256(mut b) => Ok(Arc::new(b.finish())), + } } } @@ -293,3 +543,199 @@ fn flush_primitive( } const DEFAULT_CAPACITY: usize = 1024; + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ + Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, + Decimal128Array, Decimal256Array, DictionaryArray, + }; + use arrow_array::cast::AsArray; + use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + + /// Helper functions for encoding test data + fn encode_avro_int(value: i32) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 31); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_long(value: i64) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 63); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf 
+ } + + fn encode_avro_bytes(bytes: &[u8]) -> Vec { + let mut buf = encode_avro_long(bytes.len() as i64); + buf.extend_from_slice(bytes); + buf + } + + #[test] + fn test_enum_decoding() { + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + // Encode the indices [1, 0, 2] using zigzag encoding + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] + data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] + data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_map_decoding_one_entry() { + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Avro encoding for a map: + // - block_count: 1 (number of entries) + // - keys: "hello" (5 bytes) + // - values: "world" (5 bytes) + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" + data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // Verify 1 map + assert_eq!(map_arr.value_length(0), 1); // Verify 1 entry in the map + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 1); // Verify 1 entry in StructArray + let key = struct_entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key.value(0), "hello"); // Verify Key + assert_eq!(value.value(0), "world"); // Verify Value + } + + #[test] + fn test_map_decoding_empty() { + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Avro encoding for an empty map: + // - block_count: 0 (no entries) + let data = encode_avro_long(0); // block_count = 0 + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // Verify 1 map + assert_eq!(map_arr.value_length(0), 0); // Verify 0 entries in the map + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 0); // // Verify 0 entries StructArray + let key = struct_entries + 
.column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let value = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key.len(), 0); + assert_eq!(value.len(), 0); + } + + #[test] + fn test_decimal_decoding_fixed128() { + let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + // Row1: 123.45 => unscaled: 12345 => i128: 0x00000000000000000000000000003039 + // Row2: -1.23 => unscaled: -123 => i128: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF85 + let row1 = [ + 0x00, 0x00, 0x00, 0x00, // First 8 bytes + 0x00, 0x00, 0x00, 0x00, // Next 8 bytes + 0x00, 0x00, 0x00, 0x00, // Next 8 bytes + 0x00, 0x00, 0x30, 0x39, // Last 8 bytes: 0x3039 = 12345 + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, // First 8 bytes (two's complement) + 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes + 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes + 0xFF, 0xFF, 0xFF, 0x85, // Last 8 bytes: 0xFFFFFF85 = -123 + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + decoder.decode(&mut AvroCursor::new(&data[16..])).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 2); + assert_eq!(dec_arr.value_as_string(0), "123.45"); + assert_eq!(dec_arr.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_bytes() { + let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let unscaled_row1: i128 = 1234; // 123.4 + let unscaled_row2: i128 = -1234; // -123.4 + // Note: convert unscaled values to big-endian bytes + let bytes_row1 = unscaled_row1.to_be_bytes(); + let bytes_row2 = unscaled_row2.to_be_bytes(); + // Row1: 1234 => 0x04D2 (2 bytes) + // Row2: -1234 => two's complement of 0x04D2 = 0xFB2E (2 bytes) + let row1_bytes = &bytes_row1[14..16]; // Last 2 bytes + let row2_bytes = &bytes_row2[14..16]; // Last 2 bytes + let mut data = Vec::new(); + // Encode row1 + data.extend_from_slice(&encode_avro_long(2)); // Length=2 + data.extend_from_slice(row1_bytes); // 0x04, 0xD2 + // Encode row2 + data.extend_from_slice(&encode_avro_long(2)); // Length=2 + data.extend_from_slice(row2_bytes); // 0xFB, 0x2E + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 2); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(1), "-123.4"); + } +} From 9d0bf4cc7e490e34aa409d62e8a4d3eee0783e70 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 23:01:18 -0600 Subject: [PATCH 04/38] Added reader record decoder support for non-null Enum, Map, and Decimal types. 
Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 01a2732e99bc..92274a167de6 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -740,7 +740,8 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { *prec as usize, Some(*scale as usize), Some(32), - ),Dictionary(index_type, value_type) => { + ), + Dictionary(index_type, value_type) => { let mut md = HashMap::new(); md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { From 082581a4d98a6931718a8491ee524f6b86fc7dd5 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 30 Dec 2024 23:02:15 -0600 Subject: [PATCH 05/38] Added reader record decoder support for non-null Enum, Map, and Decimal types. Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 92274a167de6..d58390a57bf2 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -748,15 +748,13 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); dt.metadata.extend(md); Codec::Enum(vec![]) - } else { - // fallback + } else { // fallback Codec::Utf8 } } - // For map => "type":"map" => in Arrow: DataType::Map Map(field, _keys_sorted) => { if let Struct(child_fields) = field.data_type() { - let value_field = &child_fields[1]; // name="value" + let value_field = &child_fields[1]; let sub_codec = arrow_type_to_codec(value_field.data_type()); Codec::Map(Arc::new(AvroDataType { nullability: value_field.is_nullable().then(|| Nullability::NullFirst), From 6647b250054b6a9681c976e265d643ffcbc1d85c Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 12:39:55 -0600 Subject: [PATCH 06/38] Added null support Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 86 ++--- arrow-avro/src/reader/record.rs | 651 ++++++++++++++++++++++---------- 2 files changed, 493 insertions(+), 244 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index d58390a57bf2..4e57d4d186bc 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -29,6 +29,7 @@ use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; +use arrow_schema::DataType::*; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. 
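Concretely, a union such as ["null", "long"] is encoded as a zig-zag long selecting the branch, followed by the payload of that branch (null carries no payload). A framing sketch assuming the null branch sits at index 0; `zigzag_long` and `encode_nullable_long` are illustrative names only:

fn zigzag_long(v: i64) -> Vec<u8> {
    let mut n = ((v << 1) ^ (v >> 63)) as u64;
    let mut out = Vec::new();
    while n & !0x7F != 0 {
        out.push(((n & 0x7F) | 0x80) as u8);
        n >>= 7;
    }
    out.push(n as u8);
    out
}

fn encode_nullable_long(value: Option<i64>) -> Vec<u8> {
    match value {
        // Branch 0 selects "null": just the index, no payload
        None => zigzag_long(0),
        // Branch 1 selects "long": the value follows the index
        Some(v) => {
            let mut buf = zigzag_long(1);
            buf.extend(zigzag_long(v));
            buf
        }
    }
}

fn main() {
    assert_eq!(encode_nullable_long(None), vec![0]);
    assert_eq!(encode_nullable_long(Some(1)), vec![2, 2]);
}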
@@ -207,43 +208,43 @@ impl Codec { /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { - Self::Null => DataType::Null, - Self::Boolean => DataType::Boolean, - Self::Int32 => DataType::Int32, - Self::Int64 => DataType::Int64, - Self::Float32 => DataType::Float32, - Self::Float64 => DataType::Float64, - Self::Binary => DataType::Binary, - Self::Utf8 => DataType::Utf8, - Self::Date32 => DataType::Date32, - Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), - Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::Null => Null, + Self::Boolean => Boolean, + Self::Int32 => Int32, + Self::Int64 => Int64, + Self::Float32 => Float32, + Self::Float64 => Float64, + Self::Binary => Binary, + Self::Utf8 => Utf8, + Self::Date32 => Date32, + Self::TimeMillis => Time32(TimeUnit::Millisecond), + Self::TimeMicros => Time64(TimeUnit::Microsecond), Self::TimestampMillis(is_utc) => { - DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) } Self::TimestampMicros(is_utc) => { - DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => DataType::FixedSizeBinary(*size), + Self::Interval => Interval(IntervalUnit::MonthDayNano), + Self::Fixed(size) => FixedSizeBinary(*size), Self::List(f) => { - DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) } - Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), Self::Enum(_symbols) => { // Produce a Dictionary type with index = Int32, value = Utf8 - DataType::Dictionary( + Dictionary( Box::new(DataType::Int32), Box::new(DataType::Utf8), ) } Self::Map(values) => { - DataType::Map( + Map( Arc::new(Field::new( "entries", - DataType::Struct( + Struct( Fields::from(vec![ - Field::new("key", DataType::Utf8, false), + Field::new("key", Utf8, false), values.field_with_name("value"), ]) ), @@ -254,19 +255,19 @@ impl Codec { } Self::Decimal(precision, scale, size) => match size { Some(s) if *s > 16 && *s <= 32 => { - DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + Decimal256(*precision as u8, scale.unwrap_or(0) as i8) }, Some(s) if *s <= 16 => { - DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) }, _ => { // Note: Infer based on precision when size is None if *precision <= DECIMAL128_MAX_PRECISION as usize && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize { - DataType::Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) } else { - DataType::Decimal256(*precision as u8, scale.unwrap_or(0) as i8) + Decimal256(*precision as u8, scale.unwrap_or(0) as i8) } } }, @@ -687,8 +688,6 @@ fn make_data_type<'a>( /// Convert an Arrow `Field` into an `AvroField`. 
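The Decimal arms above reduce to a simple width rule: an explicit fixed size of at most 16 bytes yields Decimal128, sizes in (16, 32] yield Decimal256, and with no usable size the precision decides (38 being `DECIMAL128_MAX_PRECISION` in arrow-schema). A compact restatement of that rule; `decimal_width` is an illustrative name, not crate API:

// Mirrors the match in `Codec::data_type` for Codec::Decimal; note the
// fallthrough arm also catches sizes above 32, as the original does.
fn decimal_width(precision: usize, size: Option<usize>) -> u16 {
    match size {
        Some(s) if s > 16 && s <= 32 => 256,
        Some(s) if s <= 16 => 128,
        // No usable size: fall back to what the precision can hold
        _ => {
            if precision <= 38 {
                128
            } else {
                256
            }
        }
    }
}

fn main() {
    assert_eq!(decimal_width(10, Some(16)), 128);
    assert_eq!(decimal_width(60, Some(32)), 256);
    assert_eq!(decimal_width(8, None), 128);
}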
pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { - // Basic metadata logic: - // If arrow_field.metadata().get("namespace") is present, we store it below in AvroDataType let codec = arrow_type_to_codec(arrow_field.data_type()); let nullability = if arrow_field.is_nullable() { Some(Nullability::NullFirst) @@ -709,7 +708,6 @@ pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { /// Maps an Arrow `DataType` to a `Codec`. fn arrow_type_to_codec(dt: &DataType) -> Codec { - use arrow_schema::DataType::*; match dt { Null => Codec::Null, Boolean => Codec::Boolean, @@ -742,13 +740,9 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Some(32), ), Dictionary(index_type, value_type) => { - let mut md = HashMap::new(); - md.insert("dictionary_index_type".to_string(), format!("{:?}", index_type)); - if matches!(value_type.as_ref(), Utf8 | LargeUtf8) { - let mut dt = AvroDataType::from_codec(Codec::Enum(vec![])); - dt.metadata.extend(md); + if let Utf8 = **value_type { Codec::Enum(vec![]) - } else { // fallback + } else { // Fallback to Utf8 Codec::Utf8 } } @@ -785,7 +779,7 @@ mod tests { #[test] fn test_decimal256_tuple_variant_fixed() { - let c = arrow_type_to_codec(&DataType::Decimal256(60, 3)); + let c = arrow_type_to_codec(&Decimal256(60, 3)); match c { Codec::Decimal(p, s, Some(32)) => { assert_eq!(p, 60); @@ -813,7 +807,7 @@ mod tests { let c = Codec::Decimal(6, Some(2), Some(4)); let dt = c.data_type(); match dt { - DataType::Decimal128(p, s) => { + Decimal128(p, s) => { assert_eq!(p, 6); assert_eq!(s, 2); } @@ -839,7 +833,7 @@ mod tests { let codec = Codec::Decimal(10, Some(3), Some(16)); let dt = codec.data_type(); match dt { - DataType::Decimal128(precision, scale) => { + Decimal128(precision, scale) => { assert_eq!(precision, 10); assert_eq!(scale, 3); } @@ -848,7 +842,7 @@ mod tests { let codec = Codec::Decimal(18, Some(4), Some(32)); let dt = codec.data_type(); match dt { - DataType::Decimal256(precision, scale) => { + Decimal256(precision, scale) => { assert_eq!(precision, 18); assert_eq!(scale, 4); } @@ -857,7 +851,7 @@ mod tests { let codec = Codec::Decimal(8, Some(2), None); let dt = codec.data_type(); match dt { - DataType::Decimal128(precision, scale) => { + Decimal128(precision, scale) => { assert_eq!(precision, 8); assert_eq!(scale, 2); } @@ -995,7 +989,7 @@ mod tests { fn test_arrow_field_to_avro_field() { let arrow_field = Field::new( "test_meta", - DataType::Utf8, + Utf8, true, ).with_metadata(HashMap::from([ ("namespace".to_string(), "arrow_meta_ns".to_string()) @@ -1029,7 +1023,7 @@ mod tests { let codec = Codec::Struct(fields); let dt = codec.data_type(); match dt { - DataType::Struct(fields) => { + Struct(fields) => { assert_eq!(fields.len(), 2); assert_eq!(fields[0].name(), "a"); assert_eq!(fields[0].data_type(), &DataType::Boolean); @@ -1045,7 +1039,7 @@ mod tests { let codec = Codec::Fixed(12); let dt = codec.data_type(); match dt { - DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + FixedSizeBinary(n) => assert_eq!(n, 12), _ => panic!("Expected FixedSizeBinary(12)"), } } @@ -1054,7 +1048,7 @@ mod tests { fn test_utc_timestamp_millis() { let arrow_field = Field::new( "utc_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1070,7 +1064,7 @@ mod tests { fn test_utc_timestamp_micros() { let arrow_field = Field::new( "utc_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, 
Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1086,7 +1080,7 @@ mod tests { fn test_local_timestamp_millis() { let arrow_field = Field::new( "local_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, None), + Timestamp(TimeUnit::Millisecond, None), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1102,7 +1096,7 @@ mod tests { fn test_local_timestamp_micros() { let arrow_field = Field::new( "local_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, None), + Timestamp(TimeUnit::Microsecond, None), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 4c57a3426bd6..a29f293107c0 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -23,7 +23,7 @@ use crate::schema::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; -use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit}; +use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION}; use std::collections::HashMap; use std::io::Read; use std::ptr::null; @@ -76,6 +76,7 @@ impl RecordDecoder { } } +/// Enum representing different decoders for various data types. #[derive(Debug)] enum Decoder { Null(usize), @@ -95,51 +96,59 @@ enum Decoder { Record(Fields, Vec), Nullable(Nullability, NullBufferBuilder, Box), Enum(Vec, Vec), - Map(FieldRef, OffsetBufferBuilder, OffsetBufferBuilder, Vec, Box, usize), - Decimal(usize, usize, Option, Vec>), + Map( + FieldRef, + OffsetBufferBuilder, // key_offsets + OffsetBufferBuilder, // map_offsets + Vec, // key_data + Box, // values_decoder_inner + usize, // current_entry_count + ), + Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable + /// Checks if the Decoder is nullable. fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } + /// Creates a new `Decoder` based on the provided `AvroDataType`. 
fn try_new(data_type: &AvroDataType) -> Result { let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { - Codec::Null => Self::Null(0), - Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Self::Binary( + Codec::Null => Decoder::Null(0), + Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + Codec::Int32 => Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Int64 => Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float32 => Decoder::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float64 => Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Binary => Decoder::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + Codec::Utf8 => Decoder::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Date32 => Decoder::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Decoder::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Decoder::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimestampMillis(is_utc) => { - Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Decoder::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::TimestampMicros(is_utc) => { - Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::Fixed(_) => return nyi("decoding fixed"), Codec::Interval => return nyi("decoding interval"), Codec::List(item) => { - let decoder = Self::try_new(item)?; - Self::List( + let decoder = Box::new(Self::try_new(item)?); + Decoder::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + decoder, ) } Codec::Struct(fields) => { @@ -150,14 +159,9 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - Self::Record(arrow_fields.into(), encodings) - } - Codec::Enum(symbols) => { - Decoder::Enum( - symbols.clone(), - Vec::with_capacity(DEFAULT_CAPACITY), - ) + Decoder::Record(arrow_fields.into(), encodings) } + Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( "entries", @@ -170,54 +174,58 @@ impl Decoder { Decoder::Map( map_field, OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets - OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets - Vec::with_capacity(DEFAULT_CAPACITY), // key_data - Box::new(Self::try_new(value_type)?), // values_decoder_inner + OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets + Vec::with_capacity(DEFAULT_CAPACITY), // key_data + Box::new(Self::try_new(value_type)?), // values_decoder_inner 0, // current_entry_count ) } Codec::Decimal(precision, scale, size) => { - Decoder::Decimal( 
- *precision, - scale.unwrap_or(0), - *size, - Vec::with_capacity(DEFAULT_CAPACITY), - ) + let builder = DecimalBuilder::new(*precision, *scale, *size)?; + Decoder::Decimal(*precision, *scale, *size, builder) } }; - Ok(match data_type.nullability() { - Some(nullability) => Self::Nullable( + + // Wrap the decoder in Nullable if necessary + match data_type.nullability() { + Some(nullability) => Ok(Decoder::Nullable( nullability, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), - ), - None => decoder, - }) + )), + None => Ok(decoder), + } } - /// Append a null record + /// Appends a null value to the decoder. fn append_null(&mut self) { match self { - Self::Null(count) => *count += 1, - Self::Boolean(b) => b.append(false), - Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), - Self::Int64(v) - | Self::TimeMicros(v) - | Self::TimestampMillis(_, v) - | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), - Self::List(_, offsets, e) => { + Decoder::Null(count) => *count += 1, + Decoder::Boolean(b) => b.append(false), + Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => v.push(0), + Decoder::Int64(v) + | Decoder::TimeMicros(v) + | Decoder::TimestampMillis(_, v) + | Decoder::TimestampMicros(_, v) => v.push(0), + Decoder::Float32(v) => v.push(0.0), + Decoder::Float64(v) => v.push(0.0), + Decoder::Binary(offsets, _) | Decoder::String(offsets, _) => { + offsets.push_length(0); + } + Decoder::List(_, offsets, e) => { offsets.push_length(0); e.append_null(); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), - Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), - Self::Enum(_, _) => { - // For Enum, appending a null is not straightforward. Handle accordingly if needed. + Decoder::Record(_, encodings) => { + for encoding in encodings.iter_mut() { + encoding.append_null(); + } + } + Decoder::Enum(_, indices) => { + // Append a placeholder index for null entries + indices.push(0); } - Self::Map( + Decoder::Map( _, key_offsets, map_offsets_builder, @@ -228,55 +236,60 @@ impl Decoder { key_offsets.push_length(0); map_offsets_builder.push_length(*current_entry_count); } - Self::Decimal(_, _, _, _) => { - // For Decimal, appending a null doesn't make sense as per current implementation + Decoder::Decimal(_, _, _, builder) => { + builder.append_null(); } + Decoder::Nullable(_, _, _) => { /* Nulls are handled by the Nullable variant */ } } } - /// Decode a single record from `buf` + /// Decodes a single record from the provided buffer `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Self::Null(x) => *x += 1, - Self::Boolean(values) => values.append(buf.get_bool()?), - Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { - values.push(buf.get_int()?) 
- } - Self::Int64(values) - | Self::TimeMicros(values) - | Self::TimestampMillis(_, values) - | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), - Self::Float32(values) => values.push(buf.get_float()?), - Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) | Self::String(offsets, values) => { + Decoder::Null(x) => *x += 1, + Decoder::Boolean(values) => values.append(buf.get_bool()?), + Decoder::Int32(values) => values.push(buf.get_int()?), + Decoder::Date32(values) => values.push(buf.get_int()?), + Decoder::Int64(values) => values.push(buf.get_long()?), + Decoder::TimeMillis(values) => values.push(buf.get_int()?), + Decoder::TimeMicros(values) => values.push(buf.get_long()?), + Decoder::TimestampMillis(is_utc, values) => { + values.push(buf.get_long()?); + } + Decoder::TimestampMicros(is_utc, values) => { + values.push(buf.get_long()?); + } + Decoder::Float32(values) => values.push(buf.get_float()?), + Decoder::Float64(values) => values.push(buf.get_double()?), + Decoder::Binary(offsets, values) | Decoder::String(offsets, values) => { let data = buf.get_bytes()?; offsets.push_length(data.len()); values.extend_from_slice(data); } - Self::List(_, _, _) => { + Decoder::List(_, _, _) => { return Err(ArrowError::NotYetImplemented( "Decoding ListArray".to_string(), - )) + )); } - Self::Record(_, encodings) => { - for encoding in encodings { + Decoder::Record(fields, encodings) => { + for encoding in encodings.iter_mut() { encoding.decode(buf)?; } } - Self::Nullable(nullability, nulls, e) => { - let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst); + Decoder::Nullable(_, nulls, e) => { + let is_valid = buf.get_bool()?; nulls.append(is_valid); match is_valid { true => e.decode(buf)?, false => e.append_null(), } } - Self::Enum(symbols, indices) => { - // Encodes enum by writing its zero-based index as an int + Decoder::Enum(symbols, indices) => { + // Enums are encoded as zero-based indices using zigzag encoding let index = buf.get_int()?; indices.push(index); } - Self::Map( + Decoder::Map( field, key_offsets, map_offsets_builder, @@ -301,83 +314,79 @@ impl Decoder { map_offsets_builder.push_length(*current_entry_count); } } - Self::Decimal( - precision, - scale, - size, - data - ) => { - let raw = if let Some(fixed_len) = size { - // get_fixed used to get exactly fixed_len bytes - buf.get_fixed(*fixed_len)? + Decoder::Decimal(_precision, _scale, _size, builder) => { + if let Some(size) = _size { + // Fixed-size decimal + let raw = buf.get_fixed(*size)?; + builder.append_bytes(raw)?; } else { - // get_bytes used for variable-length - buf.get_bytes()? - }; - data.push(raw.to_vec()); + // Variable-size decimal + let bytes = buf.get_bytes()?; + builder.append_bytes(bytes)?; + } } } Ok(()) } - /// Flush decoded records to an [`ArrayRef`] + /// Flushes decoded records to an [`ArrayRef`]. 
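One subtlety in the Nullable decode arm above: the branch read from the stream is interpreted so that `true` means a value follows, which lines up with a ["null", T] union where null occupies index 0; a [T, "null"] union inverts the test. A sketch of the general rule, using a local stand-in for the crate's Nullability enum:

#[derive(Clone, Copy)]
enum Nullability {
    NullFirst,  // union is ["null", T]
    NullSecond, // union is [T, "null"]
}

// The branch index selects a value when it does not point at null
fn branch_selects_value(branch: u32, nullability: Nullability) -> bool {
    match nullability {
        Nullability::NullFirst => branch == 1,
        Nullability::NullSecond => branch == 0,
    }
}

fn main() {
    assert!(branch_selects_value(1, Nullability::NullFirst));
    assert!(branch_selects_value(0, Nullability::NullSecond));
}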
fn flush(&mut self, nulls: Option) -> Result { match self { - Self::Nullable(_, n, e) => e.flush(n.finish()), - Self::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), - Self::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), - Self::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::TimeMillis(values) => { + Decoder::Nullable(_, n, e) => e.flush(n.finish()), + Decoder::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), + Decoder::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), + Decoder::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::TimeMillis(values) => { Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimeMicros(values) => { + Decoder::TimeMicros(values) => { Ok(Arc::new(flush_primitive::(values, nulls))) } - Self::TimestampMillis(is_utc, values) => Ok(Arc::new( + Decoder::TimestampMillis(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), )), - Self::TimestampMicros(is_utc, values) => Ok(Arc::new( + Decoder::TimestampMicros(is_utc, values) => Ok(Arc::new( flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), )), - Self::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Self::Binary(offsets, values) => { + Decoder::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), + Decoder::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } - Self::String(offsets, values) => { + Decoder::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - Self::List(field, offsets, values) => { + Decoder::List(field, offsets, values) => { let values = values.flush(None)?; let offsets = flush_offsets(offsets); Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) } - Self::Record(fields, encodings) => { + Decoder::Record(fields, encodings) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } - Self::Enum(symbols, indices) => { + Decoder::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); - let flushed_indices = flush_values(indices); // Vec let indices_array: Int32Array = match nulls { Some(buf) => { - let buffer = Buffer::from_slice_ref(&flushed_indices); - PrimitiveArray::::try_new(ScalarBuffer::from(buffer), Some(buf.clone()))? - }, - None => { - Int32Array::from_iter_values(flushed_indices) + let buffer = arrow_buffer::Buffer::from_slice_ref(&indices); + PrimitiveArray::::try_new( + arrow_buffer::ScalarBuffer::from(buffer), + Some(buf.clone()), + )? 
} + None => Int32Array::from_iter_values(indices.iter().cloned()), }; let dict_array = DictionaryArray::::try_new( indices_array, @@ -385,7 +394,7 @@ impl Decoder { )?; Ok(Arc::new(dict_array)) } - Self::Map( + Decoder::Map( field, key_offsets_builder, map_offsets_builder, @@ -401,36 +410,37 @@ impl Decoder { let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new("value", val_array.data_type().clone(), is_nullable)), + Arc::new(ArrowField::new( + "value", + val_array.data_type().clone(), + is_nullable, + )), ]; let struct_array = StructArray::new( Fields::from(struct_fields), vec![Arc::new(key_array), val_array], None, ); - let map_array = MapArray::new(field.clone(), map_offsets.clone(), struct_array.clone(), nulls, false); + let map_array = MapArray::new( + field.clone(), + map_offsets.clone(), + struct_array.clone(), + nulls, + false, + ); Ok(Arc::new(map_array)) } - Self::Decimal( - precision, - scale, - size, - data, - ) => { - let mut array_builder = DecimalBuilder::new(*precision, *scale, *size)?; - for raw in data.drain(..) { - if let Some(s) = size { - if raw.len() < *s { - let extended = sign_extend(&raw, *s); - array_builder.append_bytes(&extended)?; - continue; - } - } - array_builder.append_bytes(&raw)?; - } - let arr = array_builder.finish()?; - Ok(Arc::new(arr)) + Decoder::Decimal(_precision, _scale, _size, builder) => { + let precision = *_precision; + let scale = _scale.unwrap_or(0); // Default scale if None + let size = _size.clone(); + let builder = std::mem::replace( + builder, + DecimalBuilder::new(precision, *_scale, *_size)?, + ); + Ok(builder.finish(nulls, precision, scale)?) // Pass precision and scale } + } } } @@ -440,22 +450,23 @@ fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { Arc::new(ArrowField::new(name, dt, nullable)) } +/// Extends raw bytes to the target length with sign extension. fn sign_extend(raw: &[u8], target_len: usize) -> Vec { if raw.is_empty() { return vec![0; target_len]; } let sign_bit = raw[0] & 0x80; let mut extended = Vec::with_capacity(target_len); - if sign_bit != 0 { // negative + if sign_bit != 0 { extended.resize(target_len - raw.len(), 0xFF); - } else { // positive + } else { extended.resize(target_len - raw.len(), 0x00); } extended.extend_from_slice(raw); extended } -/// Extend raw bytes to 16 bytes (for Decimal128) +/// Extends raw bytes to 16 bytes (for Decimal128). fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { let extended = sign_extend(raw, 16); if extended.len() != 16 { @@ -464,10 +475,12 @@ fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { extended.len() ))); } - Ok(extended.try_into().unwrap()) + let mut arr = [0u8; 16]; + arr.copy_from_slice(&extended); + Ok(arr) } -/// Extend raw bytes to 32 bytes (for Decimal256) +/// Extends raw bytes to 32 bytes (for Decimal256). fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { let extended = sign_extend(raw, 32); if extended.len() != 32 { @@ -476,30 +489,67 @@ fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { extended.len() ))); } - Ok(extended.try_into().unwrap()) + let mut arr = [0u8; 32]; + arr.copy_from_slice(&extended); + Ok(arr) } -/// Trait for building decimal arrays +/// Enum representing the builder for Decimal arrays. 
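A worked example of the sign extension above: bytes-backed Avro decimals are minimal-length big-endian two's-complement, so the two-byte value 0xFB2E (-1234) must be padded with 0xFF bytes before it can be read as an i128. A self-contained check using the same padding rule:

// Same rule as `sign_extend` above: replicate the sign bit on the left
fn sign_extend(raw: &[u8], target_len: usize) -> Vec<u8> {
    let pad = if raw.first().map_or(false, |b| b & 0x80 != 0) {
        0xFF
    } else {
        0x00
    };
    let mut out = vec![pad; target_len.saturating_sub(raw.len())];
    out.extend_from_slice(raw);
    out
}

fn main() {
    let neg = sign_extend(&[0xFB, 0x2E], 16);
    assert_eq!(i128::from_be_bytes(neg.try_into().unwrap()), -1234);
    let pos = sign_extend(&[0x04, 0xD2], 16);
    assert_eq!(i128::from_be_bytes(pos.try_into().unwrap()), 1234);
}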
+#[derive(Debug)] enum DecimalBuilder { Decimal128(Decimal128Builder), Decimal256(Decimal256Builder), } impl DecimalBuilder { - - fn new(precision: usize, scale: usize, size: Option) -> Result { + /// Initializes a new `DecimalBuilder` based on precision, scale, and size. + fn new( + precision: usize, + scale: Option, + size: Option, + ) -> Result { match size { - Some(s) if s > 16 => { - // decimal256 - Ok(Self::Decimal256(Decimal256Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + Some(s) if s > 16 && s <= 32 => { + // Decimal256 + Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) } - _ => { - // decimal128 - Ok(Self::Decimal128(Decimal128Builder::new().with_precision_and_scale(precision as u8, scale as i8)?)) + Some(s) if s <= 16 => { + // Decimal128 + Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) } + None => { + // Infer based on precision + if precision <= DECIMAL128_MAX_PRECISION as usize { + Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else if precision <= DECIMAL256_MAX_PRECISION as usize { + Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else { + Err(ArrowError::ParseError(format!( + "Decimal precision {} exceeds maximum supported", + precision + ))) + } + } + _ => Err(ArrowError::ParseError(format!( + "Unsupported decimal size: {:?}", + size + ))), } } + /// Appends bytes to the decimal builder. fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { match self { DecimalBuilder::Decimal128(b) => { @@ -516,10 +566,46 @@ impl DecimalBuilder { Ok(()) } - fn finish(self) -> Result { + /// Appends a null value to the decimal builder by appending placeholder bytes. + fn append_null(&mut self) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(mut b) => Ok(Arc::new(b.finish())), - DecimalBuilder::Decimal256(mut b) => Ok(Arc::new(b.finish())), + DecimalBuilder::Decimal128(b) => { + // Append zeroed bytes as placeholder + let placeholder = [0u8; 16]; + let value = i128::from_be_bytes(placeholder); + b.append_value(value); + } + DecimalBuilder::Decimal256(b) => { + // Append zeroed bytes as placeholder + let placeholder = [0u8; 32]; + let value = i256::from_be_bytes(placeholder); + b.append_value(value); + } + } + Ok(()) + } + + /// Finalizes the decimal array and returns it as an `ArrayRef`. 
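For reference, the arrow builders wrapped here take unscaled integers and treat precision/scale purely as metadata; nothing is rescaled at append time. A small usage sketch of `Decimal128Builder` on its own:

use arrow_array::builder::Decimal128Builder;

fn main() -> Result<(), arrow_schema::ArrowError> {
    // precision 5, scale 2: the unscaled value 12345 reads back as 123.45
    let mut builder = Decimal128Builder::new().with_precision_and_scale(5, 2)?;
    builder.append_value(12345);
    builder.append_value(-123);
    let array = builder.finish();
    assert_eq!(array.value_as_string(0), "123.45");
    assert_eq!(array.value_as_string(1), "-1.23");
    Ok(())
}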
+ fn finish(self, nulls: Option, precision: usize, scale: usize) -> Result { + match self { + DecimalBuilder::Decimal128(mut b) => { + let array = b.finish(); + let values = array.values().clone(); + let decimal_array = Decimal128Array::new( + values, + nulls, + ).with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(decimal_array)) + } + DecimalBuilder::Decimal256(mut b) => { + let array = b.finish(); + let values = array.values().clone(); + let decimal_array = Decimal256Array::new( + values, + nulls, + ).with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(decimal_array)) + } } } } @@ -551,10 +637,12 @@ mod tests { Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, Decimal128Array, Decimal256Array, DictionaryArray, }; - use arrow_array::cast::AsArray; + use arrow_buffer::Buffer; use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + use serde_json::json; + use arrow_array::cast::AsArray; - /// Helper functions for encoding test data + /// Helper functions for encoding test data. fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -615,22 +703,23 @@ mod tests { let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode a single map with one entry: {"hello": "world"} // Avro encoding for a map: - // - block_count: 1 (number of entries) - // - keys: "hello" (5 bytes) - // - values: "world" (5 bytes) + // - block_count: 1 (encoded as [2] due to ZigZag) + // - keys: "hello" (encoded with length prefix) + // - values: "world" (encoded with length prefix) let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 + data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // Verify 1 map - assert_eq!(map_arr.value_length(0), 1); // Verify 1 entry in the map + assert_eq!(map_arr.len(), 1); // One map + assert_eq!(map_arr.value_length(0), 1); // One entry in the map let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); // Verify 1 entry in StructArray + assert_eq!(struct_entries.len(), 1); // One entry in StructArray let key = struct_entries .column_by_name("key") .unwrap() @@ -643,7 +732,7 @@ mod tests { .as_any() .downcast_ref::() .unwrap(); - assert_eq!(key.value(0), "hello"); // Verify Key + assert_eq!(key.value(0), "hello"); // Verify Key assert_eq!(value.value(0), "world"); // Verify Value } @@ -652,17 +741,18 @@ mod tests { let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode an empty map // Avro encoding for an empty map: - // - block_count: 0 (no entries) + // - block_count: 0 (encoded as [0] due to ZigZag) let data = encode_avro_long(0); // block_count = 0 decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // 
Verify 1 map - assert_eq!(map_arr.value_length(0), 0); // Verify 0 entries in the map + assert_eq!(map_arr.len(), 1); // One map + assert_eq!(map_arr.value_length(0), 0); // Zero entries in the map let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 0); // // Verify 0 entries StructArray + assert_eq!(struct_entries.len(), 0); // Zero entries in StructArray let key = struct_entries .column_by_name("key") .unwrap() @@ -710,32 +800,197 @@ mod tests { } #[test] - fn test_decimal_decoding_bytes() { + fn test_decimal_decoding_bytes_with_nulls() { let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); let mut decoder = Decoder::try_new(&dt).unwrap(); - let unscaled_row1: i128 = 1234; // 123.4 - let unscaled_row2: i128 = -1234; // -123.4 - // Note: convert unscaled values to big-endian bytes - let bytes_row1 = unscaled_row1.to_be_bytes(); - let bytes_row2 = unscaled_row2.to_be_bytes(); - // Row1: 1234 => 0x04D2 (2 bytes) - // Row2: -1234 => two's complement of 0x04D2 = 0xFB2E (2 bytes) - let row1_bytes = &bytes_row1[14..16]; // Last 2 bytes - let row2_bytes = &bytes_row2[14..16]; // Last 2 bytes + // Wrap the decimal in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + // Row1: 123.4 => unscaled: 1234 => bytes: [0x04, 0xD2] + // Row2: null + // Row3: -123.4 => unscaled: -1234 => bytes: [0xFB, 0x2E] let mut data = Vec::new(); - // Encode row1 - data.extend_from_slice(&encode_avro_long(2)); // Length=2 - data.extend_from_slice(row1_bytes); // 0x04, 0xD2 - // Encode row2 - data.extend_from_slice(&encode_avro_long(2)); // Length=2 - data.extend_from_slice(row2_bytes); // 0xFB, 0x2E + // Row1: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // 0x04D2 = 1234 + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // 0xFB2E = -1234 let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 2); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 123.4 + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -123.4 + let array = nullable_decoder.flush(None).unwrap(); + let dec_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); assert_eq!(dec_arr.value_as_string(0), "123.4"); - assert_eq!(dec_arr.value_as_string(1), "-123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); } -} + + #[test] + fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + // Wrap the decimal in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + // Correct Byte Encoding: + // Row1: 1234.56 => unscaled: 123456 => bytes: [0x00; 12] + [0x00, 0x01, 0xE2, 0x40] + // Row2: null + // Row3: -1234.56 => 
unscaled: -123456 => bytes: [0xFF; 12] + [0xFE, 0x1D, 0xC0, 0x00] + let row1_bytes = &[ + 0x00, 0x00, 0x00, 0x00, // First 4 bytes + 0x00, 0x00, 0x00, 0x00, // Next 4 bytes + 0x00, 0x00, 0x00, 0x01, // Next 4 bytes + 0xE2, 0x40, 0x00, 0x00, // Last 4 bytes + ]; + let row3_bytes = &[ + 0xFF, 0xFF, 0xFF, 0xFF, // First 4 bytes (two's complement) + 0xFF, 0xFF, 0xFF, 0xFF, // Next 4 bytes + 0xFF, 0xFF, 0xFE, 0x1D, // Next 4 bytes + 0xC0, 0x00, 0x00, 0x00, // Last 4 bytes + ]; + + let mut data = Vec::new(); + // Row1: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(row1_bytes); // 1234.56 + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(row3_bytes); // -1234.56 + + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 1234.56 + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -1234.56 + + let array = nullable_decoder.flush(None).unwrap(); + let dec_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + + #[test] + fn test_enum_decoding_with_nulls() { + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + + // Wrap the enum in a Nullable decoder + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ); + + // Encode the indices [1, null, 2] using ZigZag encoding + // Indices: 1 -> [2], null -> no index, 2 -> [4] + let mut data = Vec::new(); + // Row1: valid (1) + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] + // Row2: null + data.extend_from_slice(&[0u8]); // is_valid = false + // Row3: valid (2) + data.extend_from_slice(&[1u8]); // is_valid = true + data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // Row1: RED + nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null + nullable_decoder.decode(&mut cursor).unwrap(); // Row3: BLUE + + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + let validity = dict_arr.is_valid(0); // Correctly access the null buffer + + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); // Placeholder index for null + assert_eq!(keys.value(2), 2); + + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null + assert!(dict_arr.is_valid(2)); + + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_enum_with_nullable_entries() { + let symbols = vec!["APPLE".to_string(), "BANANA".to_string(), "CHERRY".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = 
Decoder::try_new(&enum_dt).unwrap();
+
+        // Wrap the enum in a Nullable decoder
+        let mut nullable_decoder = Decoder::Nullable(
+            Nullability::NullFirst,
+            NullBufferBuilder::new(DEFAULT_CAPACITY),
+            Box::new(decoder),
+        );
+
+        // Encode the indices [0, null, 2, 1] using ZigZag encoding
+        let mut data = Vec::new();
+        // Row1: valid (0) -> "APPLE"
+        data.extend_from_slice(&[1u8]); // is_valid = true
+        data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0]
+        // Row2: null
+        data.extend_from_slice(&[0u8]); // is_valid = false
+        // Row3: valid (2) -> "CHERRY"
+        data.extend_from_slice(&[1u8]); // is_valid = true
+        data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4]
+        // Row4: valid (1) -> "BANANA"
+        data.extend_from_slice(&[1u8]); // is_valid = true
+        data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2]
+
+        let mut cursor = AvroCursor::new(&data);
+        nullable_decoder.decode(&mut cursor).unwrap(); // Row1: APPLE
+        nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null
+        nullable_decoder.decode(&mut cursor).unwrap(); // Row3: CHERRY
+        nullable_decoder.decode(&mut cursor).unwrap(); // Row4: BANANA
+
+        let array = nullable_decoder.flush(None).unwrap();
+        let dict_arr = array.as_any().downcast_ref::>().unwrap();
+
+        assert_eq!(dict_arr.len(), 4);
+        let keys = dict_arr.keys();
+        let validity = dict_arr.is_valid(0); // Correctly access the null buffer
+
+        assert_eq!(keys.value(0), 0);
+        assert_eq!(keys.value(1), 0); // Placeholder index for null
+        assert_eq!(keys.value(2), 2);
+        assert_eq!(keys.value(3), 1);
+
+        assert!(dict_arr.is_valid(0));
+        assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null
+        assert!(dict_arr.is_valid(2));
+        assert!(dict_arr.is_valid(3));
+
+        let dict_values = dict_arr.values().as_string::();
+        assert_eq!(dict_values.value(0), "APPLE");
+        assert_eq!(dict_values.value(1), "BANANA");
+        assert_eq!(dict_values.value(2), "CHERRY");
+    }
+}
\ No newline at end of file

From 9cfda09a48a73b0c3ae8fa540bb552eaf02da8b1 Mon Sep 17 00:00:00 2001
From: Connor Sanders
Date: Tue, 31 Dec 2024 15:05:40 -0600
Subject: [PATCH 07/38] * Reader decoder support for nullable types.
 * Implemented reader decoder for Avro Lists
 * Cleaned up reader/record.rs and added comments for readability.
Signed-off-by: Connor Sanders --- arrow-avro/src/reader/record.rs | 1240 +++++++++++++++++-------------- 1 file changed, 691 insertions(+), 549 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index a29f293107c0..500fe27fd53b 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -22,38 +22,48 @@ use crate::reader::header::Header; use crate::schema::*; use arrow_array::types::*; use arrow_array::*; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_buffer::*; -use arrow_schema::{ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION}; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; use std::collections::HashMap; use std::io::Read; -use std::ptr::null; use std::sync::Arc; -use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; -/// Decodes avro encoded data into [`RecordBatch`] +/// The default capacity used for internal buffers +const DEFAULT_CAPACITY: usize = 1024; + +/// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`]. pub struct RecordDecoder { schema: SchemaRef, fields: Vec, } impl RecordDecoder { + /// Create a new [`RecordDecoder`] from an [`AvroDataType`] expected to be a `Record`. pub fn try_new(data_type: &AvroDataType) -> Result { match Decoder::try_new(data_type)? { Decoder::Record(fields, encodings) => Ok(Self { schema: Arc::new(ArrowSchema::new(fields)), fields: encodings, }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } + /// Return the [`SchemaRef`] describing the Arrow schema of rows produced by this decoder. pub fn schema(&self) -> &SchemaRef { &self.schema } - /// Decode `count` records from `buf` + /// Decode `count` Avro records from `buf`. + /// + /// This accumulates data in internal buffers. Once done reading, call + /// [`Self::flush`] to yield an Arrow [`RecordBatch`]. pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); for _ in 0..count { @@ -64,7 +74,7 @@ impl RecordDecoder { Ok(cursor.position()) } - /// Flush the decoded records into a [`RecordBatch`] + /// Flush the accumulated data into a [`RecordBatch`], clearing internal state. pub fn flush(&mut self) -> Result { let arrays = self .fields @@ -76,47 +86,78 @@ impl RecordDecoder { } } -/// Enum representing different decoders for various data types. +/// Decoder for Avro data of various shapes. +/// +/// This is the “internal” representation used by [`RecordDecoder`]. #[derive(Debug)] enum Decoder { + /// Avro `null` Null(usize), + /// Avro `boolean` Boolean(BooleanBufferBuilder), + /// Avro `int` => i32 Int32(Vec), + /// Avro `long` => i64 Int64(Vec), + /// Avro `float` => f32 Float32(Vec), + /// Avro `double` => f64 Float64(Vec), + /// Avro `date` => Date32 Date32(Vec), + /// Avro `time-millis` => Time32(Millisecond) TimeMillis(Vec), + /// Avro `time-micros` => Time64(Microsecond) TimeMicros(Vec), + /// Avro `timestamp-millis` (bool = UTC?) TimestampMillis(bool, Vec), + /// Avro `timestamp-micros` (bool = UTC?) 
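    // The `bool` carried by the two timestamp variants records whether the Avro
    // logical type was the UTC-adjusted one; `flush` below uses it to attach a
    // "+00:00" timezone to the resulting timestamp array, and otherwise leaves
    // the timezone unset.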
TimestampMicros(bool, Vec), + /// Avro `bytes` => Arrow Binary Binary(OffsetBufferBuilder, Vec), + /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), + /// Avro `array` + /// * `FieldRef` is the arrow field for the list + /// * `OffsetBufferBuilder` holds offsets into the child array + /// * The boxed `Decoder` decodes T itself List(FieldRef, OffsetBufferBuilder, Box), + /// Avro `record` + /// * `Fields` is the Arrow schema of the record + /// * The `Vec` is one decoder per child field Record(Fields, Vec), + /// Avro union that includes `null` => decodes as a single arrow field + a null bit mask Nullable(Nullability, NullBufferBuilder, Box), + /// Avro `enum` => Dictionary(int32 -> string) Enum(Vec, Vec), + /// Avro `map` + /// * The `FieldRef` is the arrow field for the map + /// * `key_offsets`, `map_offsets`: offset builders + /// * `key_data` accumulates the raw UTF8 for keys + /// * `values_decoder_inner` decodes the map’s value type + /// * `current_entry_count` how many (key,value) pairs total seen so far Map( FieldRef, - OffsetBufferBuilder, // key_offsets - OffsetBufferBuilder, // map_offsets - Vec, // key_data - Box, // values_decoder_inner - usize, // current_entry_count + OffsetBufferBuilder, + OffsetBufferBuilder, + Vec, + Box, + usize, ), + /// Avro decimal => Arrow decimal + /// (precision, scale, size, builder) Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable. + /// Checks if the Decoder is nullable, i.e. wrapped in [`Decoder::Nullable`]. fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } - /// Creates a new `Decoder` based on the provided `AvroDataType`. + /// Create a `Decoder` from an [`AvroDataType`]. fn try_new(data_type: &AvroDataType) -> Result { - let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); - + let not_implemented = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { Codec::Null => Decoder::Null(0), Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -141,25 +182,25 @@ impl Decoder { Codec::TimestampMicros(is_utc) => { Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return nyi("decoding fixed"), - Codec::Interval => return nyi("decoding interval"), + Codec::Fixed(_) => return not_implemented("decoding Avro fixed-typed data"), + Codec::Interval => return not_implemented("decoding Avro interval"), Codec::List(item) => { - let decoder = Box::new(Self::try_new(item)?); + let item_decoder = Box::new(Self::try_new(item)?); Decoder::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - decoder, + item_decoder, ) } Codec::Struct(fields) => { let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut encodings = Vec::with_capacity(fields.len()); + let mut decoders = Vec::with_capacity(fields.len()); for avro_field in fields.iter() { - let encoding = Self::try_new(avro_field.data_type())?; + let d = Self::try_new(avro_field.data_type())?; arrow_fields.push(avro_field.field()); - encodings.push(encoding); + decoders.push(d); } - Decoder::Record(arrow_fields.into(), encodings) + Decoder::Record(arrow_fields.into(), decoders) } Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Map(value_type) => { @@ -173,11 +214,11 @@ impl Decoder { )); Decoder::Map( map_field, - OffsetBufferBuilder::new(DEFAULT_CAPACITY), // key_offsets - 
OffsetBufferBuilder::new(DEFAULT_CAPACITY), // map_offsets - Vec::with_capacity(DEFAULT_CAPACITY), // key_data - Box::new(Self::try_new(value_type)?), // values_decoder_inner - 0, // current_entry_count + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + Box::new(Self::try_new(value_type)?), + 0, ) } Codec::Decimal(precision, scale, size) => { @@ -185,11 +226,9 @@ impl Decoder { Decoder::Decimal(*precision, *scale, *size, builder) } }; - - // Wrap the decoder in Nullable if necessary match data_type.nullability() { - Some(nullability) => Ok(Decoder::Nullable( - nullability, + Some(nb) => Ok(Decoder::Nullable( + nb, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), )), @@ -197,304 +236,375 @@ impl Decoder { } } - /// Appends a null value to the decoder. + /// Append a null to this decoder. + /// + /// This must keep the “row counts” in sync across child buffers, etc. fn append_null(&mut self) { match self { - Decoder::Null(count) => *count += 1, - Decoder::Boolean(b) => b.append(false), - Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => v.push(0), + Decoder::Null(n) => { + *n += 1; + } + Decoder::Boolean(b) => { + b.append(false); + } + Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => { + v.push(0); + } Decoder::Int64(v) | Decoder::TimeMicros(v) | Decoder::TimestampMillis(_, v) - | Decoder::TimestampMicros(_, v) => v.push(0), - Decoder::Float32(v) => v.push(0.0), - Decoder::Float64(v) => v.push(0.0), - Decoder::Binary(offsets, _) | Decoder::String(offsets, _) => { - offsets.push_length(0); - } - Decoder::List(_, offsets, e) => { - offsets.push_length(0); - e.append_null(); - } - Decoder::Record(_, encodings) => { - for encoding in encodings.iter_mut() { - encoding.append_null(); + | Decoder::TimestampMicros(_, v) => { + v.push(0); + } + Decoder::Float32(v) => { + v.push(0.0); + } + Decoder::Float64(v) => { + v.push(0.0); + } + Decoder::Binary(off, _) | Decoder::String(off, _) => { + off.push_length(0); + } + Decoder::List(_, off, child) => { + off.push_length(0); + child.append_null(); + } + Decoder::Record(_, children) => { + for c in children.iter_mut() { + c.append_null(); } } Decoder::Enum(_, indices) => { - // Append a placeholder index for null entries indices.push(0); } - Decoder::Map( - _, - key_offsets, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - key_offsets.push_length(0); - map_offsets_builder.push_length(*current_entry_count); + Decoder::Map(_, key_off, map_off, _, _, entry_count) => { + key_off.push_length(0); + map_off.push_length(*entry_count); } Decoder::Decimal(_, _, _, builder) => { - builder.append_null(); + let _ = builder.append_null(); } - Decoder::Nullable(_, _, _) => { /* Nulls are handled by the Nullable variant */ } + Decoder::Nullable(_, _, _) => { /* The null mask is handled by the outer decoder */ } } } - /// Decodes a single record from the provided buffer `buf`. + /// Decode a single “row” of data from `buf`. 
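    // Worked example of the wire format this expects for a Nullable(Int32)
    // column (matching the tests at the bottom of this file): each row starts
    // with a zig-zag encoded union branch index, where 0 selects the value
    // branch and 1 selects null. A row holding 7 is the two bytes
    // [0x00, 0x0E] (branch 0, then zig-zag(7) = 14); a null row is the single
    // byte [0x02].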
fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Decoder::Null(x) => *x += 1, - Decoder::Boolean(values) => values.append(buf.get_bool()?), - Decoder::Int32(values) => values.push(buf.get_int()?), - Decoder::Date32(values) => values.push(buf.get_int()?), - Decoder::Int64(values) => values.push(buf.get_long()?), - Decoder::TimeMillis(values) => values.push(buf.get_int()?), - Decoder::TimeMicros(values) => values.push(buf.get_long()?), - Decoder::TimestampMillis(is_utc, values) => { - values.push(buf.get_long()?); - } - Decoder::TimestampMicros(is_utc, values) => { - values.push(buf.get_long()?); - } - Decoder::Float32(values) => values.push(buf.get_float()?), - Decoder::Float64(values) => values.push(buf.get_double()?), - Decoder::Binary(offsets, values) | Decoder::String(offsets, values) => { - let data = buf.get_bytes()?; - offsets.push_length(data.len()); - values.extend_from_slice(data); - } - Decoder::List(_, _, _) => { - return Err(ArrowError::NotYetImplemented( - "Decoding ListArray".to_string(), - )); + Decoder::Null(n) => { + *n += 1; } - Decoder::Record(fields, encodings) => { - for encoding in encodings.iter_mut() { - encoding.decode(buf)?; - } + Decoder::Boolean(vals) => { + vals.append(buf.get_bool()?); } - Decoder::Nullable(_, nulls, e) => { - let is_valid = buf.get_bool()?; - nulls.append(is_valid); - match is_valid { - true => e.decode(buf)?, - false => e.append_null(), + Decoder::Int32(vals) => { + vals.push(buf.get_int()?); + } + Decoder::Date32(vals) => { + vals.push(buf.get_int()?); + } + Decoder::Int64(vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimeMillis(vals) => { + vals.push(buf.get_int()?); + } + Decoder::TimeMicros(vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimestampMillis(_, vals) => { + vals.push(buf.get_long()?); + } + Decoder::TimestampMicros(_, vals) => { + vals.push(buf.get_long()?); + } + Decoder::Float32(vals) => { + vals.push(buf.get_float()?); + } + Decoder::Float64(vals) => { + vals.push(buf.get_double()?); + } + Decoder::Binary(off, data) | Decoder::String(off, data) => { + let bytes = buf.get_bytes()?; + off.push_length(bytes.len()); + data.extend_from_slice(bytes); + } + Decoder::List(_, off, child) => { + let total_items = read_array_blocks(buf, |b| child.decode(b))?; + off.push_length(total_items); + } + Decoder::Record(_, children) => { + for c in children.iter_mut() { + c.decode(buf)?; } } - Decoder::Enum(symbols, indices) => { - // Enums are encoded as zero-based indices using zigzag encoding - let index = buf.get_int()?; - indices.push(index); - } - Decoder::Map( - field, - key_offsets, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - let block_count = buf.get_long()?; - if block_count <= 0 { - // Push the current_entry_count without changes - map_offsets_builder.push_length(*current_entry_count); - } else { - let n = block_count as usize; - for _ in 0..n { - let key_bytes = buf.get_bytes()?; - key_offsets.push_length(key_bytes.len()); - key_data.extend_from_slice(key_bytes); - values_decoder_inner.decode(buf)?; + Decoder::Nullable(_, null_buf, child) => { + let branch_index = buf.get_int()?; + match branch_index { + 0 => { + // child + null_buf.append(true); + child.decode(buf)?; + } + 1 => { + // null + null_buf.append(false); + child.append_null(); + } + other => { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {other} for Nullable" + ))); } - // Update the current_entry_count and push to 
map_offsets_builder - *current_entry_count += n; - map_offsets_builder.push_length(*current_entry_count); } } - Decoder::Decimal(_precision, _scale, _size, builder) => { - if let Some(size) = _size { - // Fixed-size decimal - let raw = buf.get_fixed(*size)?; + Decoder::Enum(_, indices) => { + let idx = buf.get_int()?; + indices.push(idx); + } + Decoder::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { + let newly_added = read_map_blocks(buf, |b| { + let kb = b.get_bytes()?; + key_off.push_length(kb.len()); + key_data.extend_from_slice(kb); + val_decoder.decode(b) + })?; + *entry_count += newly_added; + map_off.push_length(*entry_count); + } + Decoder::Decimal(_, _, size, builder) => { + if let Some(sz) = *size { + let raw = buf.get_fixed(sz)?; builder.append_bytes(raw)?; } else { - // Variable-size decimal - let bytes = buf.get_bytes()?; - builder.append_bytes(bytes)?; + let variable = buf.get_bytes()?; + builder.append_bytes(variable)?; } } } Ok(()) } - /// Flushes decoded records to an [`ArrayRef`]. + /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. fn flush(&mut self, nulls: Option) -> Result { match self { - Decoder::Nullable(_, n, e) => e.flush(n.finish()), - Decoder::Null(size) => Ok(Arc::new(NullArray::new(std::mem::replace(size, 0)))), - Decoder::Boolean(b) => Ok(Arc::new(BooleanArray::new(b.finish(), nulls))), - Decoder::Int32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Date32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Int64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::TimeMillis(values) => { - Ok(Arc::new(flush_primitive::(values, nulls))) - } - Decoder::TimeMicros(values) => { - Ok(Arc::new(flush_primitive::(values, nulls))) - } - Decoder::TimestampMillis(is_utc, values) => Ok(Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), - )), - Decoder::TimestampMicros(is_utc, values) => Ok(Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())), - )), - Decoder::Float32(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Float64(values) => Ok(Arc::new(flush_primitive::(values, nulls))), - Decoder::Binary(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); + Decoder::Nullable(_, nb, child) => { + let mask = nb.finish(); + child.flush(mask) + } + // Null => produce NullArray + Decoder::Null(len) => { + let count = std::mem::replace(len, 0); + Ok(Arc::new(NullArray::new(count))) + } + // boolean => flush to BooleanArray + Decoder::Boolean(b) => { + let bits = b.finish(); + Ok(Arc::new(BooleanArray::new(bits, nulls))) + } + // int32 => flush to Int32Array + Decoder::Int32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // date32 => flush to Date32Array + Decoder::Date32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // int64 => flush to Int64Array + Decoder::Int64(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-millis => Time32Millisecond + Decoder::TimeMillis(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-micros => Time64Microsecond + Decoder::TimeMicros(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // timestamp-millis => TimestampMillisecond + Decoder::TimestampMillis(is_utc, vals) => { + let arr = flush_primitive::(vals, 
nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // timestamp-micros => TimestampMicrosecond + Decoder::TimestampMicros(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // float32 => flush to Float32Array + Decoder::Float32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // float64 => flush to Float64Array + Decoder::Float64(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // Avro bytes => BinaryArray + Decoder::Binary(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } - Decoder::String(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); + // Avro string => StringArray + Decoder::String(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - Decoder::List(field, offsets, values) => { - let values = values.flush(None)?; - let offsets = flush_offsets(offsets); - Ok(Arc::new(ListArray::new(field.clone(), offsets, values, nulls))) + // Avro array => ListArray + Decoder::List(field, off, item_dec) => { + let child_arr = item_dec.flush(None)?; + let offsets = flush_offsets(off); + let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); + Ok(Arc::new(arr)) } - Decoder::Record(fields, encodings) => { - let arrays = encodings - .iter_mut() - .map(|x| x.flush(None)) - .collect::, _>>()?; + // Avro record => StructArray + Decoder::Record(fields, children) => { + let mut arrays = Vec::with_capacity(children.len()); + for c in children.iter_mut() { + let a = c.flush(None)?; + arrays.push(a); + } Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } + // Avro enum => DictionaryArray utf8> Decoder::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); - let indices_array: Int32Array = match nulls { - Some(buf) => { - let buffer = arrow_buffer::Buffer::from_slice_ref(&indices); + let idxs: Int32Array = match nulls { + Some(b) => { + let buff = Buffer::from_slice_ref(&indices); PrimitiveArray::::try_new( - arrow_buffer::ScalarBuffer::from(buffer), - Some(buf.clone()), + arrow_buffer::ScalarBuffer::from(buff), + Some(b), )? 
} None => Int32Array::from_iter_values(indices.iter().cloned()), }; - let dict_array = DictionaryArray::::try_new( - indices_array, - Arc::new(dict_values), - )?; - Ok(Arc::new(dict_array)) - } - Decoder::Map( - field, - key_offsets_builder, - map_offsets_builder, - key_data, - values_decoder_inner, - current_entry_count, - ) => { - let map_offsets = flush_offsets(map_offsets_builder); - let key_offsets = flush_offsets(key_offsets_builder); - let key_data = flush_values(key_data).into(); - let key_array = StringArray::new(key_offsets, key_data, None); - let val_array = values_decoder_inner.flush(None)?; - let is_nullable = matches!(**values_decoder_inner, Decoder::Nullable(_, _, _)); + let dict = DictionaryArray::::try_new(idxs, Arc::new(dict_values))?; + indices.clear(); // reset + Ok(Arc::new(dict)) + } + // Avro map => MapArray + Decoder::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { + let moff = flush_offsets(map_off); + let koff = flush_offsets(key_off); + let kd = flush_values(key_data).into(); + let val_arr = val_dec.flush(None)?; + let is_nullable = matches!(**val_dec, Decoder::Nullable(_, _, _)); + let key_arr = StringArray::new(koff, kd, None); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), Arc::new(ArrowField::new( "value", - val_array.data_type().clone(), + val_arr.data_type().clone(), is_nullable, )), ]; - let struct_array = StructArray::new( + let entries = StructArray::new( Fields::from(struct_fields), - vec![Arc::new(key_array), val_array], + vec![Arc::new(key_arr), val_arr], None, ); - let map_array = MapArray::new( - field.clone(), - map_offsets.clone(), - struct_array.clone(), - nulls, - false, - ); - Ok(Arc::new(map_array)) - } - Decoder::Decimal(_precision, _scale, _size, builder) => { - let precision = *_precision; - let scale = _scale.unwrap_or(0); // Default scale if None - let size = _size.clone(); - let builder = std::mem::replace( - builder, - DecimalBuilder::new(precision, *_scale, *_size)?, - ); - Ok(builder.finish(nulls, precision, scale)?) // Pass precision and scale + let map_arr = MapArray::new(field.clone(), moff, entries, nulls, false); + *entry_count = 0; + Ok(Arc::new(map_arr)) + } + // Avro decimal => Arrow decimal + Decoder::Decimal(prec, sc, sz, builder) => { + let precision = *prec; + let scale = sc.unwrap_or(0); + let new_builder = DecimalBuilder::new(precision, *sc, *sz)?; + let old_builder = std::mem::replace(builder, new_builder); + let arr = old_builder.finish(nulls, precision, scale)?; + Ok(arr) } - } } } -/// Helper to build a field with a given type -fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { - Arc::new(ArrowField::new(name, dt, nullable)) +/// Helper to decode an Avro array in blocks until a 0 block_count signals end. +/// +/// Each block may be negative, in which case we read an extra “block size” `long`, +/// but typically ignore it unless we want to skip. This function invokes `decode_item` once per item. 
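// Worked example: the two-element array [10, 20] arrives as the four bytes
// [0x04, 0x14, 0x28, 0x00]: block count 2 (zig-zag => 4), items 10 and 20
// (zig-zag => 20 and 40), then a zero block count terminating the array.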
+fn read_array_blocks( + buf: &mut AvroCursor, + mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let mut total_items = 0usize; + loop { + let block_count = buf.get_long()?; + if block_count == 0 { + break; + } else if block_count < 0 { + let item_count = (-block_count) as usize; + let _block_size = buf.get_long()?; // read but ignore + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } else { + let item_count = block_count as usize; + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } + } + Ok(total_items) } -/// Extends raw bytes to the target length with sign extension. -fn sign_extend(raw: &[u8], target_len: usize) -> Vec { - if raw.is_empty() { - return vec![0; target_len]; - } - let sign_bit = raw[0] & 0x80; - let mut extended = Vec::with_capacity(target_len); - if sign_bit != 0 { - extended.resize(target_len - raw.len(), 0xFF); +/// Helper to decode an Avro map in blocks until a 0 block_count signals end. +/// +/// For each entry in a block, we decode a key (bytes) + a value (`decode_value`). +/// Returns how many map entries were decoded. +fn read_map_blocks( + buf: &mut AvroCursor, + mut decode_value: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let block_count = buf.get_long()?; + if block_count <= 0 { + Ok(0) } else { - extended.resize(target_len - raw.len(), 0x00); + let n = block_count as usize; + for _ in 0..n { + decode_value(buf)?; + } + Ok(n) } - extended.extend_from_slice(raw); - extended } -/// Extends raw bytes to 16 bytes (for Decimal128). -fn extend_to_16_bytes(raw: &[u8]) -> Result<[u8; 16], ArrowError> { - let extended = sign_extend(raw, 16); - if extended.len() != 16 { - return Err(ArrowError::ParseError(format!( - "Failed to extend bytes to 16 bytes: got {} bytes", - extended.len() - ))); - } - let mut arr = [0u8; 16]; - arr.copy_from_slice(&extended); - Ok(arr) +/// Flush a [`Vec`] of primitive values to a [`PrimitiveArray`], applying optional `nulls`. +#[inline] +fn flush_primitive( + values: &mut Vec, + nulls: Option, +) -> PrimitiveArray { + PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Extends raw bytes to 32 bytes (for Decimal256). -fn extend_to_32_bytes(raw: &[u8]) -> Result<[u8; 32], ArrowError> { - let extended = sign_extend(raw, 32); - if extended.len() != 32 { - return Err(ArrowError::ParseError(format!( - "Failed to extend bytes to 32 bytes: got {} bytes", - extended.len() - ))); - } - let mut arr = [0u8; 32]; - arr.copy_from_slice(&extended); - Ok(arr) +/// Flush an [`OffsetBufferBuilder`], returning its completed offsets. +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +/// Remove and return the contents of `values`, replacing it with an empty buffer. +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) } -/// Enum representing the builder for Decimal arrays. +/// A builder for Avro decimal, either 128-bit or 256-bit. #[derive(Debug)] enum DecimalBuilder { Decimal128(Decimal128Builder), @@ -502,7 +612,7 @@ enum DecimalBuilder { } impl DecimalBuilder { - /// Initializes a new `DecimalBuilder` based on precision, scale, and size. + /// Create a new DecimalBuilder given precision, scale, and optional byte-size (`fixed`). 
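    // Width selection at a glance: an explicit fixed size of at most 16 bytes,
    // or no size with precision <= DECIMAL128_MAX_PRECISION (38), selects
    // Decimal128; sizes of 17..=32 bytes or precisions up to
    // DECIMAL256_MAX_PRECISION (76) select Decimal256; anything larger is
    // rejected with a parse error.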
fn new( precision: usize, scale: Option, @@ -510,30 +620,38 @@ impl DecimalBuilder { ) -> Result { match size { Some(s) if s > 16 && s <= 32 => { - // Decimal256 + // decimal256 Ok(Self::Decimal256( - Decimal256Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal256Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } Some(s) if s <= 16 => { - // Decimal128 + // decimal128 Ok(Self::Decimal128( - Decimal128Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal128Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } None => { - // Infer based on precision + // infer from precision when fixed size is None if precision <= DECIMAL128_MAX_PRECISION as usize { Ok(Self::Decimal128( - Decimal128Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal128Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } else if precision <= DECIMAL256_MAX_PRECISION as usize { Ok(Self::Decimal256( - Decimal256Builder::new() - .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + Decimal256Builder::new().with_precision_and_scale( + precision as u8, + scale.unwrap_or(0) as i8, + )?, )) } else { Err(ArrowError::ParseError(format!( @@ -549,100 +667,127 @@ impl DecimalBuilder { } } - /// Appends bytes to the decimal builder. - fn append_bytes(&mut self, bytes: &[u8]) -> Result<(), ArrowError> { + /// Append sign-extended bytes to this decimal builder + fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(b) => { - let padded = extend_to_16_bytes(bytes)?; - let value = i128::from_be_bytes(padded); - b.append_value(value); + Self::Decimal128(b) => { + let padded = sign_extend_to_16(raw)?; + let val = i128::from_be_bytes(padded); + b.append_value(val); } - DecimalBuilder::Decimal256(b) => { - let padded = extend_to_32_bytes(bytes)?; - let value = i256::from_be_bytes(padded); - b.append_value(value); + Self::Decimal256(b) => { + let padded = sign_extend_to_32(raw)?; + let val = i256::from_be_bytes(padded); + b.append_value(val); } } Ok(()) } - /// Appends a null value to the decimal builder by appending placeholder bytes. + /// Append a null decimal value (0) fn append_null(&mut self) -> Result<(), ArrowError> { match self { - DecimalBuilder::Decimal128(b) => { - // Append zeroed bytes as placeholder - let placeholder = [0u8; 16]; - let value = i128::from_be_bytes(placeholder); - b.append_value(value); - } - DecimalBuilder::Decimal256(b) => { - // Append zeroed bytes as placeholder - let placeholder = [0u8; 32]; - let value = i256::from_be_bytes(placeholder); - b.append_value(value); + Self::Decimal128(b) => { + let zero = [0u8; 16]; + b.append_value(i128::from_be_bytes(zero)); + } + Self::Decimal256(b) => { + let zero = [0u8; 32]; + b.append_value(i256::from_be_bytes(zero)); } } Ok(()) } - /// Finalizes the decimal array and returns it as an `ArrayRef`. - fn finish(self, nulls: Option, precision: usize, scale: usize) -> Result { + /// Finish building this decimal array, returning an [`ArrayRef`]. 
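    // Note: nulls are never appended into the inner builders; `append_null`
    // pushes a placeholder zero value instead, and the validity mask from the
    // wrapping `Nullable` decoder is applied only here, when `finish` rebuilds
    // the final array via `with_precision_and_scale`.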
+ fn finish( + self, + nulls: Option, + precision: usize, + scale: usize, + ) -> Result { match self { - DecimalBuilder::Decimal128(mut b) => { - let array = b.finish(); - let values = array.values().clone(); - let decimal_array = Decimal128Array::new( - values, - nulls, - ).with_precision_and_scale(precision as u8, scale as i8)?; - Ok(Arc::new(decimal_array)) - } - DecimalBuilder::Decimal256(mut b) => { - let array = b.finish(); - let values = array.values().clone(); - let decimal_array = Decimal256Array::new( - values, - nulls, - ).with_precision_and_scale(precision as u8, scale as i8)?; - Ok(Arc::new(decimal_array)) + Self::Decimal128(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal128Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + Self::Decimal256(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal256Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) } } } } -#[inline] -fn flush_values(values: &mut Vec) -> Vec { - std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +/// Sign-extend `raw` to 16 bytes. +fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let extended = sign_extend(raw, 16); + if extended.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 16 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 16]; + arr.copy_from_slice(&extended); + Ok(arr) } -#[inline] -fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { - std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +/// Sign-extend `raw` to 32 bytes. +fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let extended = sign_extend(raw, 32); + if extended.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 32 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&extended); + Ok(arr) } -#[inline] -fn flush_primitive( - values: &mut Vec, - nulls: Option, -) -> PrimitiveArray { - PrimitiveArray::new(flush_values(values).into(), nulls) +/// Sign-extend the first byte to produce `target_len` bytes total. +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut out = Vec::with_capacity(target_len); + if sign_bit != 0 { + out.resize(target_len - raw.len(), 0xFF); + } else { + out.resize(target_len - raw.len(), 0x00); + } + out.extend_from_slice(raw); + out } -const DEFAULT_CAPACITY: usize = 1024; +/// Convenience helper to build a field with `name`, `DataType` and `nullable`. +fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(ArrowField::new(name, dt, nullable)) +} #[cfg(test)] mod tests { use super::*; use arrow_array::{ - Array, ArrayRef, Int32Array, MapArray, StringArray, StructArray, - Decimal128Array, Decimal256Array, DictionaryArray, + cast::AsArray, Array, ArrayRef, Decimal128Array, Decimal256Array, DictionaryArray, + Int32Array, ListArray, MapArray, StringArray, StructArray, }; use arrow_buffer::Buffer; - use arrow_schema::{Field as ArrowField, DataType as ArrowDataType}; + use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; use serde_json::json; - use arrow_array::cast::AsArray; - /// Helper functions for encoding test data. 
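    // Refresher on the encoding these helpers implement: Avro zig-zag maps a
    // signed value n to (n << 1) ^ (n >> 31) (or >> 63 for longs), so
    // 0 => 0, -1 => 1, 1 => 2, -2 => 3, 2 => 4, and the result is then
    // written out as a base-128 varint.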
+ // ------------------- + // Zig-Zag Encoding Helper Functions + // ------------------- fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -671,20 +816,23 @@ mod tests { buf } + // ------------------- + // Tests for Enum + // ------------------- #[test] fn test_enum_decoding() { let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - // Encode the indices [1, 0, 2] using zigzag encoding + // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] - data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] - data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] + data.extend_from_slice(&encode_avro_int(1)); // => [2] + data.extend_from_slice(&encode_avro_int(0)); // => [0] + data.extend_from_slice(&encode_avro_int(2)); // => [4] let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); // => GREEN + decoder.decode(&mut cursor).unwrap(); // => RED + decoder.decode(&mut cursor).unwrap(); // => BLUE let array = decoder.flush(None).unwrap(); let dict_arr = array.as_any().downcast_ref::>().unwrap(); assert_eq!(dict_arr.len(), 3); @@ -698,187 +846,208 @@ mod tests { assert_eq!(dict_values.value(2), "BLUE"); } + #[test] + fn test_enum_decoding_with_nulls() { + // Union => [Enum(...), null] + // "child" => branch_index=0 => [0x00], "null" => 1 => [0x02] + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner_decoder), + ); + // Indices: [1, null, 2] => in Avro union + let mut data = Vec::new(); + // Row1 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=1 => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row2 => union branch=1 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=2 => [0x04] + data.extend_from_slice(&encode_avro_int(2)); + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // => GREEN + nullable_decoder.decode(&mut cursor).unwrap(); // => null + nullable_decoder.decode(&mut cursor).unwrap(); // => BLUE + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array.as_any().downcast_ref::>().unwrap(); + assert_eq!(dict_arr.len(), 3); + // [GREEN, null, BLUE] + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); + assert!(dict_arr.is_valid(2)); + let keys = dict_arr.keys(); + // keys.value(0) => 1 => GREEN + // keys.value(2) => 2 => BLUE + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + // ------------------- + // Tests for Map + // ------------------- #[test] fn test_map_decoding_one_entry() { let value_type = 
AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); // Encode a single map with one entry: {"hello": "world"} - // Avro encoding for a map: - // - block_count: 1 (encoded as [2] due to ZigZag) - // - keys: "hello" (encoded with length prefix) - // - values: "world" (encoded with length prefix) let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); // block_count = 1 - data.extend_from_slice(&encode_avro_bytes(b"hello")); // key = "hello" - data.extend_from_slice(&encode_avro_bytes(b"world")); // value = "world" - decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + // block_count=1 => zigzag => [0x02] + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key + data.extend_from_slice(&encode_avro_bytes(b"world")); // value + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // One map - assert_eq!(map_arr.value_length(0), 1); // One entry in the map + assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.value_length(0), 1); let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); // One entry in StructArray - let key = struct_entries + assert_eq!(struct_entries.len(), 1); + let key_arr = struct_entries .column_by_name("key") .unwrap() .as_any() .downcast_ref::() .unwrap(); - let value = struct_entries + let val_arr = struct_entries .column_by_name("value") .unwrap() .as_any() .downcast_ref::() .unwrap(); - assert_eq!(key.value(0), "hello"); // Verify Key - assert_eq!(value.value(0), "world"); // Verify Value + assert_eq!(key_arr.value(0), "hello"); + assert_eq!(val_arr.value(0), "world"); } #[test] fn test_map_decoding_empty() { + // block_count=0 => empty map let value_type = AvroDataType::from_codec(Codec::Utf8); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode an empty map - // Avro encoding for an empty map: - // - block_count: 0 (encoded as [0] due to ZigZag) - let data = encode_avro_long(0); // block_count = 0 + // Encode an empty map => block_count=0 => [0x00] + let data = encode_avro_long(0); decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // One map - assert_eq!(map_arr.value_length(0), 0); // Zero entries in the map - let entries = map_arr.value(0); - let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 0); // Zero entries in StructArray - let key = struct_entries - .column_by_name("key") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let value = struct_entries - .column_by_name("value") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(key.len(), 0); - assert_eq!(value.len(), 0); + assert_eq!(map_arr.len(), 1); + assert_eq!(map_arr.value_length(0), 0); } + // ------------------- + // Tests for Decimal + // ------------------- #[test] fn test_decimal_decoding_fixed128() { let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); let mut decoder = Decoder::try_new(&dt).unwrap(); - // Row1: 123.45 => unscaled: 12345 => i128: 
0x00000000000000000000000000003039 - // Row2: -1.23 => unscaled: -123 => i128: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF85 + // Row1 => 123.45 => unscaled=12345 => i128 0x000...3039 + // Row2 => -1.23 => unscaled=-123 => i128 0xFFFF...FF85 let row1 = [ - 0x00, 0x00, 0x00, 0x00, // First 8 bytes - 0x00, 0x00, 0x00, 0x00, // Next 8 bytes - 0x00, 0x00, 0x00, 0x00, // Next 8 bytes - 0x00, 0x00, 0x30, 0x39, // Last 8 bytes: 0x3039 = 12345 + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, ]; let row2 = [ - 0xFF, 0xFF, 0xFF, 0xFF, // First 8 bytes (two's complement) - 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes - 0xFF, 0xFF, 0xFF, 0xFF, // Next 8 bytes - 0xFF, 0xFF, 0xFF, 0x85, // Last 8 bytes: 0xFFFFFF85 = -123 + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, ]; + let mut data = Vec::new(); data.extend_from_slice(&row1); data.extend_from_slice(&row2); - decoder.decode(&mut AvroCursor::new(&data)).unwrap(); - decoder.decode(&mut AvroCursor::new(&data[16..])).unwrap(); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 2); - assert_eq!(dec_arr.value_as_string(0), "123.45"); - assert_eq!(dec_arr.value_as_string(1), "-1.23"); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); } #[test] fn test_decimal_decoding_bytes_with_nulls() { + // Avro union => [ Decimal(4,1), null ] + // child => index=0 => [0x00], null => index=1 => [0x02] let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); - let mut decoder = Decoder::try_new(&dt).unwrap(); - // Wrap the decimal in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( Nullability::NullFirst, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(inner), ); - // Row1: 123.4 => unscaled: 1234 => bytes: [0x04, 0xD2] - // Row2: null - // Row3: -123.4 => unscaled: -1234 => bytes: [0xFB, 0x2E] + // Decode three rows: [123.4, null, -123.4] let mut data = Vec::new(); - // Row1: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // 0x04D2 = 1234 - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // 0xFB2E = -1234 + // Row1 => child => [0x00], then decimal => e.g. 
0x04D2 => 1234 => "123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00], then decimal => 0xFB2E => -1234 => "-123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 123.4 - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: -123.4 - let array = nullable_decoder.flush(None).unwrap(); - let dec_arr = array.as_any().downcast_ref::().unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.is_valid(0), true); + assert_eq!(dec_arr.is_valid(1), false); + assert_eq!(dec_arr.is_valid(2), true); assert_eq!(dec_arr.value_as_string(0), "123.4"); assert_eq!(dec_arr.value_as_string(2), "-123.4"); } #[test] fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + // Avro union => [Decimal(6,2,16), null] let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); - let mut decoder = Decoder::try_new(&dt).unwrap(); - // Wrap the decimal in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( Nullability::NullFirst, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(inner), ); - // Correct Byte Encoding: - // Row1: 1234.56 => unscaled: 123456 => bytes: [0x00; 12] + [0x00, 0x01, 0xE2, 0x40] - // Row2: null - // Row3: -1234.56 => unscaled: -123456 => bytes: [0xFF; 12] + [0xFE, 0x1D, 0xC0, 0x00] - let row1_bytes = &[ - 0x00, 0x00, 0x00, 0x00, // First 4 bytes - 0x00, 0x00, 0x00, 0x00, // Next 4 bytes - 0x00, 0x00, 0x00, 0x01, // Next 4 bytes - 0xE2, 0x40, 0x00, 0x00, // Last 4 bytes + // Decode [1234.56, null, -1234.56] + let row1 = [ + 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, + 0x00,0x00,0x00,0x00, 0x00,0x01,0xE2,0x40 ]; - let row3_bytes = &[ - 0xFF, 0xFF, 0xFF, 0xFF, // First 4 bytes (two's complement) - 0xFF, 0xFF, 0xFF, 0xFF, // Next 4 bytes - 0xFF, 0xFF, 0xFE, 0x1D, // Next 4 bytes - 0xC0, 0x00, 0x00, 0x00, // Last 4 bytes + let row3 = [ + 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF, + 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFE,0x1D,0xC0 ]; - let mut data = Vec::new(); - // Row1: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(row1_bytes); // 1234.56 - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(row3_bytes); // -1234.56 - + // Row1 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: 1234.56 - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // 
Row3: -1234.56 - - let array = nullable_decoder.flush(None).unwrap(); - let dec_arr = array.as_any().downcast_ref::().unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); assert_eq!(dec_arr.len(), 3); assert!(dec_arr.is_valid(0)); assert!(!dec_arr.is_valid(1)); @@ -887,110 +1056,83 @@ mod tests { assert_eq!(dec_arr.value_as_string(2), "-1234.56"); } + // ------------------- + // Tests for List + // ------------------- #[test] - fn test_enum_decoding_with_nulls() { - let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); - let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - - // Wrap the enum in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( - Nullability::NullFirst, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), - ); - - // Encode the indices [1, null, 2] using ZigZag encoding - // Indices: 1 -> [2], null -> no index, 2 -> [4] - let mut data = Vec::new(); - // Row1: valid (1) - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid (2) - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] - - let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: RED - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: BLUE - - let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); - - assert_eq!(dict_arr.len(), 3); - let keys = dict_arr.keys(); - let validity = dict_arr.is_valid(0); // Correctly access the null buffer - - assert_eq!(keys.value(0), 1); - assert_eq!(keys.value(1), 0); // Placeholder index for null - assert_eq!(keys.value(2), 2); - - assert!(dict_arr.is_valid(0)); - assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null - assert!(dict_arr.is_valid(2)); - - let dict_values = dict_arr.values().as_string::(); - assert_eq!(dict_values.value(0), "RED"); - assert_eq!(dict_values.value(1), "GREEN"); - assert_eq!(dict_values.value(2), "BLUE"); + fn test_list_decoding() { + // Avro array => block1(count=2), item1, item2, block2(count=0 => end) + // + // 1. Create 2 rows: + // Row1 => [10, 20] + // Row2 => [ ] + // + // 2. 
flush => should yield 2-element array => first row has 2 items, second row has 0 items + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt).unwrap(); + // Row1 => block_count=2 => item=10 => item=20 => block_count=0 => end + // - 2 => zigzag => [0x04] + // - item=10 => zigzag => [0x14] + // - item=20 => zigzag => [0x28] + // - 0 => [0x00] + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(2)); // block_count=2 + row1.extend_from_slice(&encode_avro_int(10)); // item=10 + row1.extend_from_slice(&encode_avro_int(20)); // item=20 + row1.extend_from_slice(&encode_avro_long(0)); // end of array + // Row2 => block_count=0 => empty array + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&row1); + decoder.decode(&mut cursor).unwrap(); + let mut cursor2 = AvroCursor::new(&row2); + decoder.decode(&mut cursor2).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 2); + // row0 => 2 items => [10, 20] + // row1 => 0 items + let offsets = list_arr.value_offsets(); + assert_eq!(offsets, &[0, 2, 2]); + let values = list_arr.values(); + let int_arr = values.as_primitive::(); + assert_eq!(int_arr.len(), 2); + assert_eq!(int_arr.value(0), 10); + assert_eq!(int_arr.value(1), 20); } #[test] - fn test_enum_with_nullable_entries() { - let symbols = vec!["APPLE".to_string(), "BANANA".to_string(), "CHERRY".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); - let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - - // Wrap the enum in a Nullable decoder - let mut nullable_decoder = Decoder::Nullable( - Nullability::NullFirst, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), - ); - - // Encode the indices [0, null, 2, 1] using ZigZag encoding - let mut data = Vec::new(); - // Row1: valid (0) -> "APPLE" - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(0)); // Encodes to [0] - // Row2: null - data.extend_from_slice(&[0u8]); // is_valid = false - // Row3: valid (2) -> "CHERRY" - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(2)); // Encodes to [4] - // Row4: valid (1) -> "BANANA" - data.extend_from_slice(&[1u8]); // is_valid = true - data.extend_from_slice(&encode_avro_int(1)); // Encodes to [2] - + fn test_list_decoding_with_negative_block_count() { + // Start with single row => [1, 2, 3] + // We'll store them in a single negative block => block_count=-3 => #items=3 + // Then read block_size => let's pretend it's 9 bytes, etc. Then the items. + // Then a block_count=0 => done + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt).unwrap(); + // block_count=-3 => zigzag => (-3 << 1) ^ (-3 >> 63) + // => -6 ^ -1 => ... + // Encode directly with `encode_avro_long(-3)`. 
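        // For the curious: (-3 << 1) ^ (-3 >> 63) is -6 ^ -1 == 5, so
        // `encode_avro_long(-3)` produces the single varint byte [0x05].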
+ let mut data = encode_avro_long(-3); + // Next => block_size => let's pretend 12 => encode_avro_long(12) + data.extend_from_slice(&encode_avro_long(12)); + // Then 3 items => [1, 2, 3] + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(3)); + // Then block_count=0 => done + data.extend_from_slice(&encode_avro_long(0)); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // Row1: APPLE - nullable_decoder.decode(&mut cursor).unwrap(); // Row2: null - nullable_decoder.decode(&mut cursor).unwrap(); // Row3: CHERRY - nullable_decoder.decode(&mut cursor).unwrap(); // Row4: BANANA - - let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); - - assert_eq!(dict_arr.len(), 4); - let keys = dict_arr.keys(); - let validity = dict_arr.is_valid(0); // Correctly access the null buffer - - assert_eq!(keys.value(0), 0); - assert_eq!(keys.value(1), 0); // Placeholder index for null - assert_eq!(keys.value(2), 2); - assert_eq!(keys.value(3), 1); - - assert!(dict_arr.is_valid(0)); - assert!(!dict_arr.is_valid(1)); // Ensure the second entry is null - assert!(dict_arr.is_valid(2)); - assert!(dict_arr.is_valid(3)); - - let dict_values = dict_arr.values().as_string::(); - assert_eq!(dict_values.value(0), "APPLE"); - assert_eq!(dict_values.value(1), "BANANA"); - assert_eq!(dict_values.value(2), "CHERRY"); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 1); + assert_eq!(list_arr.value_length(0), 3); + let values = list_arr.values().as_primitive::(); + assert_eq!(values.len(), 3); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + assert_eq!(values.value(2), 3); } -} \ No newline at end of file +} From 84ffb62c6333479effe2f08ad92c5a14103a24f2 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 15:23:35 -0600 Subject: [PATCH 08/38] * Minor Cleanup Signed-off-by: Connor Sanders --- arrow-avro/src/reader/cursor.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index ba1d01f72d7e..9e38a78c63ec 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -71,7 +71,6 @@ impl<'a> AvroCursor<'a> { let val: u32 = varint .try_into() .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; - // Zig-zag decode Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } @@ -79,7 +78,6 @@ impl<'a> AvroCursor<'a> { #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; - // Zig-zag decode Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } From 8600680df4f75b5a937a5fc0ffae2f743705379b Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 31 Dec 2024 17:26:36 -0600 Subject: [PATCH 09/38] * Added record decoder support for the following types: - Fixed - Interval Signed-off-by: Connor Sanders --- arrow-avro/src/reader/record.rs | 381 +++++++++++++++++++++++--------- 1 file changed, 278 insertions(+), 103 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 500fe27fd53b..87ae7e2426a5 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -20,13 +20,13 @@ use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::reader::header::Header; use crate::schema::*; +use 
arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; -use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, - TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, Schema as ArrowSchema, + SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use std::collections::HashMap; use std::io::Read; @@ -87,8 +87,6 @@ impl RecordDecoder { } /// Decoder for Avro data of various shapes. -/// -/// This is the “internal” representation used by [`RecordDecoder`]. #[derive(Debug)] enum Decoder { /// Avro `null` @@ -117,25 +115,19 @@ enum Decoder { Binary(OffsetBufferBuilder, Vec), /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), + /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` + Fixed(i32, Vec), + /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) + Interval(Vec), /// Avro `array` - /// * `FieldRef` is the arrow field for the list - /// * `OffsetBufferBuilder` holds offsets into the child array - /// * The boxed `Decoder` decodes T itself List(FieldRef, OffsetBufferBuilder, Box), /// Avro `record` - /// * `Fields` is the Arrow schema of the record - /// * The `Vec` is one decoder per child field Record(Fields, Vec), - /// Avro union that includes `null` => decodes as a single arrow field + a null bit mask + /// Avro union that includes `null` Nullable(Nullability, NullBufferBuilder, Box), /// Avro `enum` => Dictionary(int32 -> string) Enum(Vec, Vec), /// Avro `map` - /// * The `FieldRef` is the arrow field for the map - /// * `key_offsets`, `map_offsets`: offset builders - /// * `key_data` accumulates the raw UTF8 for keys - /// * `values_decoder_inner` decodes the map’s value type - /// * `current_entry_count` how many (key,value) pairs total seen so far Map( FieldRef, OffsetBufferBuilder, @@ -145,19 +137,17 @@ enum Decoder { usize, ), /// Avro decimal => Arrow decimal - /// (precision, scale, size, builder) Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - /// Checks if the Decoder is nullable, i.e. wrapped in [`Decoder::Nullable`]. + /// Checks if the Decoder is nullable, i.e. wrapped in `Nullable`. fn is_nullable(&self) -> bool { matches!(self, Decoder::Nullable(_, _, _)) } /// Create a `Decoder` from an [`AvroDataType`]. 
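    ///
    /// A minimal usage sketch, mirroring the tests below (using the
    /// `AvroDataType::from_codec` helper from `codec.rs`):
    ///
    /// ```ignore
    /// let dt = AvroDataType::from_codec(Codec::Int32);
    /// let mut decoder = Decoder::try_new(&dt).unwrap();
    /// // `decoder` now buffers decoded i32 rows until `flush` builds an Int32Array
    /// ```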
fn try_new(data_type: &AvroDataType) -> Result { - let not_implemented = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); let decoder = match data_type.codec() { Codec::Null => Decoder::Null(0), Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -182,8 +172,8 @@ impl Decoder { Codec::TimestampMicros(is_utc) => { Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return not_implemented("decoding Avro fixed-typed data"), - Codec::Interval => return not_implemented("decoding Avro interval"), + Codec::Fixed(n) => Decoder::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Interval => Decoder::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::List(item) => { let item_decoder = Box::new(Self::try_new(item)?); Decoder::List( @@ -192,17 +182,19 @@ impl Decoder { item_decoder, ) } - Codec::Struct(fields) => { - let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut decoders = Vec::with_capacity(fields.len()); - for avro_field in fields.iter() { + Codec::Struct(avro_fields) => { + let mut arrow_fields = Vec::with_capacity(avro_fields.len()); + let mut decoders = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { let d = Self::try_new(avro_field.data_type())?; arrow_fields.push(avro_field.field()); decoders.push(d); } Decoder::Record(arrow_fields.into(), decoders) } - Codec::Enum(symbols) => Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Enum(symbols) => { + Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) + } Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( "entries", @@ -226,6 +218,8 @@ impl Decoder { Decoder::Decimal(*precision, *scale, *size, builder) } }; + + // Wrap in Nullable if needed match data_type.nullability() { Some(nb) => Ok(Decoder::Nullable( nb, @@ -237,8 +231,6 @@ impl Decoder { } /// Append a null to this decoder. - /// - /// This must keep the “row counts” in sync across child buffers, etc. fn append_null(&mut self) { match self { Decoder::Null(n) => { @@ -265,6 +257,19 @@ impl Decoder { Decoder::Binary(off, _) | Decoder::String(off, _) => { off.push_length(0); } + Decoder::Fixed(fsize, buf) => { + // For a null, push `fsize` zeroed bytes + let n = *fsize as usize; + buf.extend(std::iter::repeat(0u8).take(n)); + } + Decoder::Interval(intervals) => { + // null => store a 12-byte zero => months=0, days=0, nanos=0 + intervals.push(IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 0, + }); + } Decoder::List(_, off, child) => { off.push_length(0); child.append_null(); @@ -277,58 +282,82 @@ impl Decoder { Decoder::Enum(_, indices) => { indices.push(0); } - Decoder::Map(_, key_off, map_off, _, _, entry_count) => { + Decoder::Map( + _, + key_off, + map_off, + _, + _, + entry_count, + ) => { key_off.push_length(0); map_off.push_length(*entry_count); } Decoder::Decimal(_, _, _, builder) => { let _ = builder.append_null(); } - Decoder::Nullable(_, _, _) => { /* The null mask is handled by the outer decoder */ } + Decoder::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } } } - /// Decode a single “row” of data from `buf`. + /// Decode a single row of data from `buf`. 
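+    ///
+    /// As implemented here, a nullable union is decoded from a zig-zag encoded
+    /// branch index read before the value: index 0 selects the value child and
+    /// index 1 selects null. A sketch of the wire layout, using the zig-zag
+    /// helpers from the tests below:
+    ///
+    /// ```ignore
+    /// let mut data = encode_avro_int(0);            // branch 0 => value follows
+    /// data.extend_from_slice(&encode_avro_int(42)); // the int payload
+    /// data.extend_from_slice(&encode_avro_int(1));  // branch 1 => null, no payload
+    /// ```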
fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Decoder::Null(n) => { - *n += 1; + Decoder::Null(count) => { + *count += 1; } - Decoder::Boolean(vals) => { - vals.append(buf.get_bool()?); + Decoder::Boolean(values) => { + values.append(buf.get_bool()?); } - Decoder::Int32(vals) => { - vals.push(buf.get_int()?); + Decoder::Int32(values) => { + values.push(buf.get_int()?); } - Decoder::Date32(vals) => { - vals.push(buf.get_int()?); + Decoder::Date32(values) => { + values.push(buf.get_int()?); } - Decoder::Int64(vals) => { - vals.push(buf.get_long()?); + Decoder::Int64(values) => { + values.push(buf.get_long()?); } - Decoder::TimeMillis(vals) => { - vals.push(buf.get_int()?); + Decoder::TimeMillis(values) => { + values.push(buf.get_int()?); } - Decoder::TimeMicros(vals) => { - vals.push(buf.get_long()?); + Decoder::TimeMicros(values) => { + values.push(buf.get_long()?); } - Decoder::TimestampMillis(_, vals) => { - vals.push(buf.get_long()?); + Decoder::TimestampMillis(_, values) => { + values.push(buf.get_long()?); } - Decoder::TimestampMicros(_, vals) => { - vals.push(buf.get_long()?); + Decoder::TimestampMicros(_, values) => { + values.push(buf.get_long()?); } - Decoder::Float32(vals) => { - vals.push(buf.get_float()?); + Decoder::Float32(values) => { + values.push(buf.get_float()?); } - Decoder::Float64(vals) => { - vals.push(buf.get_double()?); + Decoder::Float64(values) => { + values.push(buf.get_double()?); } Decoder::Binary(off, data) | Decoder::String(off, data) => { let bytes = buf.get_bytes()?; off.push_length(bytes.len()); data.extend_from_slice(bytes); } + Decoder::Fixed(fsize, accum) => { + let raw = buf.get_fixed(*fsize as usize)?; + accum.extend_from_slice(raw); + } + Decoder::Interval(intervals) => { + let raw = buf.get_fixed(12)?; + let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); + let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap()); + let nanos = millis as i64 * 1_000_000; + let val = IntervalMonthDayNano { + months, + days, + nanoseconds: nanos, + }; + intervals.push(val); + } Decoder::List(_, off, child) => { let total_items = read_array_blocks(buf, |b| child.decode(b))?; off.push_length(total_items); @@ -338,17 +367,15 @@ impl Decoder { c.decode(buf)?; } } - Decoder::Nullable(_, null_buf, child) => { + Decoder::Nullable(_, nulls, child) => { let branch_index = buf.get_int()?; match branch_index { 0 => { - // child - null_buf.append(true); + nulls.append(true); child.decode(buf)?; } 1 => { - // null - null_buf.append(false); + nulls.append(false); child.append_null(); } other => { @@ -388,6 +415,7 @@ impl Decoder { /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. 
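    ///
    /// For example, the 12-byte Avro `duration` rows consumed above (little-endian
    /// `months`, `days`, `millis`) flush to an `IntervalMonthDayNanoArray`; one
    /// encoded row could be built like this sketch:
    ///
    /// ```ignore
    /// let mut row = Vec::new();
    /// row.extend_from_slice(&1i32.to_le_bytes());   // months = 1
    /// row.extend_from_slice(&2i32.to_le_bytes());   // days = 2
    /// row.extend_from_slice(&100i32.to_le_bytes()); // millis = 100 => 100_000_000 ns
    /// ```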
fn flush(&mut self, nulls: Option) -> Result { match self { + // For a nullable wrapper => flush the child with the built null buffer Decoder::Nullable(_, nb, child) => { let mask = nb.finish(); child.flush(mask) @@ -461,6 +489,32 @@ impl Decoder { let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } + // Avro fixed => FixedSizeBinaryArray + Decoder::Fixed(fsize, raw) => { + let size = *fsize; + let buf: Buffer = flush_values(raw).into(); + let total_len = buf.len() / (size as usize); + let array = FixedSizeBinaryArray::try_new(size, buf, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(array)) + } + // Avro interval => IntervalMonthDayNanoType + Decoder::Interval(vals) => { + let data_len = vals.len(); + let mut builder = PrimitiveBuilder::::with_capacity(data_len); + for v in vals.drain(..) { + builder.append_value(v); + } + let arr = builder.finish().with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + if let Some(nb) = nulls { + // "merge" the newly built array with the nulls + let arr_data = arr.into_data().into_builder().nulls(Some(nb)); + let arr_data = unsafe { arr_data.build_unchecked() }; + Ok(Arc::new(PrimitiveArray::::from(arr_data))) + } else { + Ok(Arc::new(arr)) + } + } // Avro array => ListArray Decoder::List(field, off, item_dec) => { let child_arr = item_dec.flush(None)?; @@ -532,10 +586,7 @@ impl Decoder { } } -/// Helper to decode an Avro array in blocks until a 0 block_count signals end. -/// -/// Each block may be negative, in which case we read an extra “block size” `long`, -/// but typically ignore it unless we want to skip. This function invokes `decode_item` once per item. +/// Decode an Avro array in blocks until a 0 block_count signals end. fn read_array_blocks( buf: &mut AvroCursor, mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, @@ -547,7 +598,7 @@ fn read_array_blocks( break; } else if block_count < 0 { let item_count = (-block_count) as usize; - let _block_size = buf.get_long()?; // read but ignore + let _block_size = buf.get_long()?; // “block size” is read but not used for _ in 0..item_count { decode_item(buf)?; } @@ -563,13 +614,10 @@ fn read_array_blocks( Ok(total_items) } -/// Helper to decode an Avro map in blocks until a 0 block_count signals end. -/// -/// For each entry in a block, we decode a key (bytes) + a value (`decode_value`). -/// Returns how many map entries were decoded. +/// Decode an Avro map in blocks until 0 block_count => end. fn read_map_blocks( buf: &mut AvroCursor, - mut decode_value: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { let block_count = buf.get_long()?; if block_count <= 0 { @@ -577,7 +625,7 @@ fn read_map_blocks( } else { let n = block_count as usize; for _ in 0..n { - decode_value(buf)?; + decode_entry(buf)?; } Ok(n) } @@ -592,13 +640,13 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Flush an [`OffsetBufferBuilder`], returning its completed offsets. +/// Flush an [`OffsetBufferBuilder`]. #[inline] fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() } -/// Remove and return the contents of `values`, replacing it with an empty buffer. +/// Take ownership of `values`. 
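+///
+/// This is the `std::mem::replace` idiom: swap in a fresh pre-allocated `Vec`
+/// and hand back the old contents, so decoding can continue after a flush.
+///
+/// ```ignore
+/// let mut values = vec![1i32, 2, 3];
+/// let taken = flush_values(&mut values);
+/// assert_eq!(taken, vec![1, 2, 3]);
+/// assert!(values.is_empty()); // replaced with Vec::with_capacity(DEFAULT_CAPACITY)
+/// ```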
#[inline] fn flush_values(values: &mut Vec) -> Vec { std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) @@ -619,39 +667,25 @@ impl DecimalBuilder { size: Option, ) -> Result { match size { - Some(s) if s > 16 && s <= 32 => { - // decimal256 - Ok(Self::Decimal256( - Decimal256Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, - )) - } - Some(s) if s <= 16 => { - // decimal128 - Ok(Self::Decimal128( - Decimal128Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, - )) - } + Some(s) if s > 16 && s <= 32 => Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), + Some(s) if s <= 16 => Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), None => { - // infer from precision when fixed size is None + // infer from precision if precision <= DECIMAL128_MAX_PRECISION as usize { Ok(Self::Decimal128( - Decimal128Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, )) } else if precision <= DECIMAL256_MAX_PRECISION as usize { Ok(Self::Decimal256( - Decimal256Builder::new().with_precision_and_scale( - precision as u8, - scale.unwrap_or(0) as i8, - )?, + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, )) } else { Err(ArrowError::ParseError(format!( @@ -699,7 +733,7 @@ impl DecimalBuilder { Ok(()) } - /// Finish building this decimal array, returning an [`ArrayRef`]. + /// Finish building the decimal array, returning an [`ArrayRef`]. fn finish( self, nulls: Option, @@ -779,15 +813,17 @@ mod tests { use super::*; use arrow_array::{ cast::AsArray, Array, ArrayRef, Decimal128Array, Decimal256Array, DictionaryArray, - Int32Array, ListArray, MapArray, StringArray, StructArray, + FixedSizeBinaryArray, Int32Array, IntervalMonthDayNanoArray, ListArray, MapArray, + StringArray, StructArray, }; use arrow_buffer::Buffer; use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; use serde_json::json; + use std::iter; - // ------------------- - // Zig-Zag Encoding Helper Functions - // ------------------- + // --------------- + // Zig-Zag Helpers + // --------------- fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -816,6 +852,145 @@ mod tests { buf } + // ----------------- + // Test Fixed + // ----------------- + #[test] + fn test_fixed_decoding() { + // `fixed(4)` => Arrow FixedSizeBinary(4) + let dt = AvroDataType::from_codec(Codec::Fixed(4)); + let mut dec = Decoder::try_new(&dt).unwrap(); + // 2 rows, each row => 4 bytes + let row1 = [0xDE, 0xAD, 0xBE, 0xEF]; + let row2 = [0x01, 0x23, 0x45, 0x67]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 2); + assert_eq!(fsb.value_length(), 4); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(1), row2); + } + + #[test] + fn test_fixed_with_nulls() { + // Avro union => [ fixed(2), null] + let dt = AvroDataType::from_codec(Codec::Fixed(2)); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec 
= Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // Decode 3 rows: row1 => branch=0 => [0x00], then 2 bytes + // row2 => branch=1 => null => [0x02] + // row3 => branch=0 => 2 bytes + let row1 = [0x11, 0x22]; + let row3 = [0x55, 0x66]; + let mut data = Vec::new(); + // row1 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // row2 => union=1 => null + data.extend_from_slice(&encode_avro_int(1)); + // row3 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + dec.decode(&mut cursor).unwrap(); // row3 + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 3); + assert!(fsb.is_valid(0)); + assert!(!fsb.is_valid(1)); + assert!(fsb.is_valid(2)); + assert_eq!(fsb.value_length(), 2); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(2), row3); + } + + // ----------------- + // Test Interval + // ----------------- + #[test] + fn test_interval_decoding() { + // Avro interval => 12 bytes => [ months i32, days i32, ms i32 ] + // decode 2 rows => row1 => months=1, days=2, ms=100 => row2 => months=-1, days=10, ms=9999 + let dt = AvroDataType::from_codec(Codec::Interval); + let mut dec = Decoder::try_new(&dt).unwrap(); + // row1 => months=1 => 01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 + // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 + let row1 = [0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x64, 0x00, 0x00, 0x00]; + let row2 = [0xFF, 0xFF, 0xFF, 0xFF, + 0x0A, 0x00, 0x00, 0x00, + 0x0F, 0x27, 0x00, 0x00]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + // row0 => months=1, days=2, ms=100 => nanos=100_000_000 + // row1 => months=-1, days=10, ms=9999 => nanos=9999_000_000 + let val0 = intervals.value(0); + assert_eq!(val0.months, 1); + assert_eq!(val0.days, 2); + assert_eq!(val0.nanoseconds, 100_000_000); + let val1 = intervals.value(1); + assert_eq!(val1.months, -1); + assert_eq!(val1.days, 10); + assert_eq!(val1.nanoseconds, 9_999_000_000); + } + + #[test] + fn test_interval_decoding_with_nulls() { + // Avro union => [ interval, null] + let dt = AvroDataType::from_codec(Codec::Interval); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // We'll decode 2 rows: row1 => interval => months=2, days=3, ms=500 => row2 => null + // row1 => union=0 => child => 12 bytes + // row2 => union=1 => null => no data + let row1 = [0x02, 0x00, 0x00, 0x00, // months=2 + 0x03, 0x00, 0x00, 0x00, // days=3 + 0xF4, 0x01, 0x00, 0x00]; // ms=500 => nanos=500_000_000 + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); // union=0 => child + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); // union=1 => null + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // 
row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + let arr = dec.flush(None).unwrap(); + let intervals = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(intervals.len(), 2); + assert!(intervals.is_valid(0)); + assert!(!intervals.is_valid(1)); + let val0 = intervals.value(0); + assert_eq!(val0.months, 2); + assert_eq!(val0.days, 3); + assert_eq!(val0.nanoseconds, 500_000_000); + } + // ------------------- // Tests for Enum // ------------------- From 83310583e6c0e5039b11abddc94a4b9fbd465fc3 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 6 Jan 2025 13:28:36 -0600 Subject: [PATCH 10/38] Minor cleanup Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 12 ++++++------ arrow-avro/src/writer/schema.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 4e57d4d186bc..38ef36827781 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -234,8 +234,8 @@ impl Codec { Self::Enum(_symbols) => { // Produce a Dictionary type with index = Int32, value = Utf8 Dictionary( - Box::new(DataType::Int32), - Box::new(DataType::Utf8), + Box::new(Int32), + Box::new(Utf8), ) } Self::Map(values) => { @@ -889,7 +889,7 @@ mod tests { ); let f = dt.field_with_name("bin_col"); assert_eq!(f.name(), "bin_col"); - assert_eq!(f.data_type(), &DataType::Binary); + assert_eq!(f.data_type(), &Binary); assert!(!f.is_nullable()); assert_eq!(f.metadata().get("something"), Some(&"else".to_string())); } @@ -981,7 +981,7 @@ mod tests { assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); let arrow_field = avro_field.field(); assert_eq!(arrow_field.name(), "long_col"); - assert_eq!(arrow_field.data_type(), &DataType::Int64); + assert_eq!(arrow_field.data_type(), &Int64); assert!(!arrow_field.is_nullable()); } @@ -1026,9 +1026,9 @@ mod tests { Struct(fields) => { assert_eq!(fields.len(), 2); assert_eq!(fields[0].name(), "a"); - assert_eq!(fields[0].data_type(), &DataType::Boolean); + assert_eq!(fields[0].data_type(), &Boolean); assert_eq!(fields[1].name(), "b"); - assert_eq!(fields[1].data_type(), &DataType::Float64); + assert_eq!(fields[1].data_type(), &Float64); } _ => panic!("Expected Struct data type"), } diff --git a/arrow-avro/src/writer/schema.rs b/arrow-avro/src/writer/schema.rs index c8cc5a7f9ec2..a858e1e5f3d0 100644 --- a/arrow-avro/src/writer/schema.rs +++ b/arrow-avro/src/writer/schema.rs @@ -51,7 +51,7 @@ mod tests { // Convert the batch -> Avro `Schema` let avro_schema = to_avro_json_schema(&batch, "MyTestRecord") - .expect("Failed to convert RecordBatch to Avro JSON schema");; + .expect("Failed to convert RecordBatch to Avro JSON schema"); let actual_json: Value = serde_json::from_str(&avro_schema) .expect("Invalid JSON returned by to_avro_json_schema"); From c54de48929eb8314635f6ced098cc2424a5ccbad Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Tue, 7 Jan 2025 11:33:50 -0800 Subject: [PATCH 11/38] chore: import cleanup and formatting --- arrow-avro/src/codec.rs | 124 +++++++++++++++++---------------------- arrow-avro/src/schema.rs | 22 +++---- 2 files changed, 64 insertions(+), 82 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 38ef36827781..258b3677f15f 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,20 +16,19 @@ // under the License. 
use crate::schema::{ - Attributes, ComplexType, PrimitiveType, Schema, TypeName, Array, Fixed, Map, Record, - Field as AvroFieldDef, - Fixed as AvroFixed, - Enum as AvroEnum, - Map as AvroMap + Array, Attributes, ComplexType, Enum, Fixed, Map, PrimitiveType, Record, RecordField, Schema, + TypeName, +}; +use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; +use arrow_schema::DataType::*; +use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, + TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_SCALE, }; -use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, - SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE}; -use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray, RecordBatch}; use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; -use arrow_schema::DataType::*; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. @@ -227,39 +226,28 @@ impl Codec { } Self::Interval => Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => FixedSizeBinary(*size), - Self::List(f) => { - List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) - } + Self::List(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), Self::Enum(_symbols) => { // Produce a Dictionary type with index = Int32, value = Utf8 - Dictionary( - Box::new(Int32), - Box::new(Utf8), - ) + Dictionary(Box::new(Int32), Box::new(Utf8)) } - Self::Map(values) => { - Map( - Arc::new(Field::new( - "entries", - Struct( - Fields::from(vec![ - Field::new("key", Utf8, false), - values.field_with_name("value"), - ]) - ), - false, - )), + Self::Map(values) => Map( + Arc::new(Field::new( + "entries", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + values.field_with_name("value"), + ])), false, - ) - } + )), + false, + ), Self::Decimal(precision, scale, size) => match size { Some(s) if *s > 16 && *s <= 32 => { Decimal256(*precision as u8, scale.unwrap_or(0) as i8) - }, - Some(s) if *s <= 16 => { - Decimal128(*precision as u8, scale.unwrap_or(0) as i8) - }, + } + Some(s) if *s <= 16 => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), _ => { // Note: Infer based on precision when size is None if *precision <= DECIMAL128_MAX_PRECISION as usize @@ -389,7 +377,7 @@ impl Codec { } Codec::Enum(symbols) => { // If there's a namespace in metadata, we will apply it later in maybe_add_namespace. - Schema::Complex(ComplexType::Enum(AvroEnum { + Schema::Complex(ComplexType::Enum(Enum { name, namespace: None, doc: None, @@ -409,7 +397,7 @@ impl Codec { Codec::Decimal(precision, scale, size) => { // If size is Some(n), produce Avro "fixed", else "bytes". 
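                // As a sketch, `Codec::Decimal(10, Some(2), Some(16))` is expected to
                // serialize as the Avro spec's decimal logical type over a named fixed
                // (the exact attribute plumbing below may differ in detail):
                //   {"type": "fixed", "name": "...", "size": 16,
                //    "logicalType": "decimal", "precision": 10, "scale": 2}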
if let Some(n) = size { - Schema::Complex(ComplexType::Fixed(AvroFixed { + Schema::Complex(ComplexType::Fixed(Fixed { name, namespace: None, aliases: vec![], @@ -597,7 +585,11 @@ fn make_data_type<'a>( } } ComplexType::Enum(e) => { - let symbols = e.symbols.iter().map(|sym| sym.to_string()).collect::>(); + let symbols = e + .symbols + .iter() + .map(|sym| sym.to_string()) + .collect::>(); let field = AvroDataType { nullability: None, metadata: e.attributes.field_metadata(), @@ -666,8 +658,12 @@ fn make_data_type<'a>( (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(false), - (Some("local-timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(false), + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { // Insert unrecognized logical type into metadata @@ -729,20 +725,13 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Codec::TimestampMicros(true) } FixedSizeBinary(n) => Codec::Fixed(*n), - Decimal128(prec, scale) => Codec::Decimal( - *prec as usize, - Some(*scale as usize), - Some(16), - ), - Decimal256(prec, scale) => Codec::Decimal( - *prec as usize, - Some(*scale as usize), - Some(32), - ), + Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), + Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), Dictionary(index_type, value_type) => { if let Utf8 = **value_type { Codec::Enum(vec![]) - } else { // Fallback to Utf8 + } else { + // Fallback to Utf8 Codec::Utf8 } } @@ -774,8 +763,8 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { mod tests { use super::*; use arrow_schema::{DataType, Field}; - use std::sync::Arc; use serde_json::json; + use std::sync::Arc; #[test] fn test_decimal256_tuple_variant_fixed() { @@ -905,8 +894,12 @@ mod tests { }, AvroField { name: "label".to_string(), - data_type: AvroDataType::new(Codec::Utf8, Some(Nullability::NullFirst), Default::default()), - } + data_type: AvroDataType::new( + Codec::Utf8, + Some(Nullability::NullFirst), + Default::default(), + ), + }, ]); let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); let avro_schema = top_level.to_avro_schema("TopRecord"); @@ -987,13 +980,10 @@ mod tests { #[test] fn test_arrow_field_to_avro_field() { - let arrow_field = Field::new( - "test_meta", - Utf8, - true, - ).with_metadata(HashMap::from([ - ("namespace".to_string(), "arrow_meta_ns".to_string()) - ])); + let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( + "namespace".to_string(), + "arrow_meta_ns".to_string(), + )])); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); let actual_str = format!("{:?}", avro_field.data_type().codec()); @@ -1078,11 +1068,7 @@ mod tests { #[test] fn test_local_timestamp_millis() { - let arrow_field = Field::new( - "local_ts_ms", - Timestamp(TimeUnit::Millisecond, None), - false, - ); + let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); let 
codec = avro_field.data_type().codec(); assert!( @@ -1094,11 +1080,7 @@ mod tests { #[test] fn test_local_timestamp_micros() { - let arrow_field = Field::new( - "local_ts_us", - Timestamp(TimeUnit::Microsecond, None), - false, - ); + let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = avro_field.data_type().codec(); assert!( diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 4895d24d76e4..8e3f23ffbb5e 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -130,14 +130,14 @@ pub struct Record<'a> { #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] - pub fields: Vec>, + pub fields: Vec>, #[serde(flatten)] pub attributes: Attributes<'a>, } /// A field within a [`Record`] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Field<'a> { +pub struct RecordField<'a> { #[serde(borrow)] pub name: &'a str, #[serde(borrow, default)] @@ -309,7 +309,7 @@ mod tests { namespace: None, doc: None, aliases: vec![], - fields: vec![Field { + fields: vec![RecordField { name: "value", doc: None, r#type: Schema::Union(vec![ @@ -343,13 +343,13 @@ mod tests { doc: None, aliases: vec!["LinkedLongs"], fields: vec![ - Field { + RecordField { name: "value", doc: None, r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, }, - Field { + RecordField { name: "next", doc: None, r#type: Schema::Union(vec![ @@ -402,7 +402,7 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "id", doc: None, r#type: Schema::Union(vec![ @@ -411,7 +411,7 @@ mod tests { ]), default: None, }, - Field { + RecordField { name: "timestamp_col", doc: None, r#type: Schema::Union(vec![ @@ -463,7 +463,7 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "clientHash", doc: None, r#type: Schema::Complex(ComplexType::Fixed(Fixed { @@ -475,7 +475,7 @@ mod tests { })), default: None, }, - Field { + RecordField { name: "clientProtocol", doc: None, r#type: Schema::Union(vec![ @@ -484,13 +484,13 @@ mod tests { ]), default: None, }, - Field { + RecordField { name: "serverHash", doc: None, r#type: Schema::TypeName(TypeName::Ref("MD5")), default: None, }, - Field { + RecordField { name: "meta", doc: None, r#type: Schema::Union(vec![ From fc696c8d414e796a11998df9e1eed8bfa63825d0 Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Tue, 7 Jan 2025 12:27:08 -0800 Subject: [PATCH 12/38] chore: simplifies and cleans code --- arrow-avro/src/codec.rs | 75 +++---- arrow-avro/src/reader/record.rs | 343 ++++++++++++++------------------ 2 files changed, 175 insertions(+), 243 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 258b3677f15f..1a2718be293c 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -19,14 +19,11 @@ use crate::schema::{ Array, Attributes, ComplexType, Enum, Fixed, Map, PrimitiveType, Record, RecordField, Schema, TypeName, }; -use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; use arrow_schema::DataType::*; use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, - TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, }; -use std::borrow::Cow; use std::collections::HashMap; use 
std::sync::Arc; @@ -93,14 +90,15 @@ impl AvroDataType { /// (record, enum, fixed). pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { let inner_schema = self.codec.to_avro_schema(name); + let schema_with_namespace = maybe_add_namespace(inner_schema, self); // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. - if let Some(_) = self.nullability { + if self.nullability.is_some() { Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - maybe_add_namespace(inner_schema, self), + schema_with_namespace, ]) } else { - maybe_add_namespace(inner_schema, self) + schema_with_namespace } } } @@ -108,19 +106,12 @@ impl AvroDataType { /// If this is a named complex type (Record, Enum, Fixed), attach `namespace` /// from `dt.metadata["namespace"]` if present. Otherwise, return as-is. fn maybe_add_namespace<'a>(mut schema: Schema<'a>, dt: &'a AvroDataType) -> Schema<'a> { - let ns = dt.metadata.get("namespace"); - if let Some(ns_str) = ns { + if let Some(ns_str) = dt.metadata.get("namespace") { if let Schema::Complex(ref mut c) = schema { match c { - ComplexType::Record(r) => { - r.namespace = Some(ns_str); - } - ComplexType::Enum(e) => { - e.namespace = Some(ns_str); - } - ComplexType::Fixed(f) => { - f.namespace = Some(ns_str); - } + ComplexType::Record(r) => r.namespace = Some(ns_str), + ComplexType::Enum(e) => e.namespace = Some(ns_str), + ComplexType::Fixed(f) => f.namespace = Some(ns_str), // Arrays and Maps do not have a namespace field, so do nothing _ => {} } @@ -244,20 +235,14 @@ impl Codec { false, ), Self::Decimal(precision, scale, size) => match size { - Some(s) if *s > 16 && *s <= 32 => { - Decimal256(*precision as u8, scale.unwrap_or(0) as i8) - } - Some(s) if *s <= 16 => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), - _ => { - // Note: Infer based on precision when size is None - if *precision <= DECIMAL128_MAX_PRECISION as usize - && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize - { - Decimal128(*precision as u8, scale.unwrap_or(0) as i8) - } else { - Decimal256(*precision as u8, scale.unwrap_or(0) as i8) - } + Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), + Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), + None if *precision <= DECIMAL128_MAX_PRECISION as usize + && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize => + { + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) } + _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), }, } } @@ -299,30 +284,30 @@ impl Codec { }), // timestamp-millis => Avro long with logicalType=timestamp-millis or local-timestamp-millis Codec::TimestampMillis(is_utc) => { - let lt = if *is_utc { - Some("timestamp-millis") + let logical_type = Some(if *is_utc { + "timestamp-millis" } else { - Some("local-timestamp-millis") - }; + "local-timestamp-millis" + }); Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: lt, + logical_type, additional: Default::default(), }, }) } // timestamp-micros => Avro long with logicalType=timestamp-micros or local-timestamp-micros Codec::TimestampMicros(is_utc) => { - let lt = if *is_utc { - Some("timestamp-micros") + let logical_type = Some(if *is_utc { + "timestamp-micros" } else { - Some("local-timestamp-micros") - }; + "local-timestamp-micros" + }); Schema::Type(crate::schema::Type { r#type: TypeName::Primitive(PrimitiveType::Long), attributes: Attributes { - logical_type: lt, + logical_type, additional: 
Default::default(), }, }) @@ -358,7 +343,7 @@ impl Codec { .iter() .map(|f| { let child_schema = f.data_type().to_avro_schema(f.name()); - AvroFieldDef { + RecordField { name: f.name(), doc: None, r#type: child_schema, @@ -389,7 +374,7 @@ impl Codec { } Codec::Map(values) => { let val_schema = values.to_avro_schema("values"); - Schema::Complex(ComplexType::Map(AvroMap { + Schema::Complex(ComplexType::Map(Map { values: Box::new(val_schema), attributes: Attributes::default(), })) @@ -762,7 +747,7 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { #[cfg(test)] mod tests { use super::*; - use arrow_schema::{DataType, Field}; + use arrow_schema::Field; use serde_json::json; use std::sync::Arc; diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 87ae7e2426a5..e5eb01df3322 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -16,19 +16,15 @@ // under the License. use crate::codec::{AvroDataType, Codec, Nullability}; -use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::schema::*; use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, Schema as ArrowSchema, - SchemaRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, + Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use std::collections::HashMap; use std::io::Read; use std::sync::Arc; @@ -143,40 +139,40 @@ enum Decoder { impl Decoder { /// Checks if the Decoder is nullable, i.e. wrapped in `Nullable`. fn is_nullable(&self) -> bool { - matches!(self, Decoder::Nullable(_, _, _)) + matches!(self, Self::Nullable(_, _, _)) } /// Create a `Decoder` from an [`AvroDataType`]. 
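    ///
    /// e.g. an Avro enum becomes `Self::Enum`, which later flushes to a
    /// `DictionaryArray<Int32Type>` keyed by the decoded symbol indices (sketch):
    ///
    /// ```ignore
    /// let symbols = vec!["RED".to_string(), "GREEN".to_string()];
    /// let dt = AvroDataType::from_codec(Codec::Enum(symbols));
    /// let mut dec = Decoder::try_new(&dt).unwrap(); // Self::Enum(symbols, indices)
    /// ```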
fn try_new(data_type: &AvroDataType) -> Result { let decoder = match data_type.codec() { - Codec::Null => Decoder::Null(0), - Codec::Boolean => Decoder::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Decoder::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Decoder::Binary( + Codec::Null => Self::Null(0), + Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Binary => Self::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Decoder::String( + Codec::Utf8 => Self::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Decoder::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Decoder::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Decoder::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimestampMillis(is_utc) => { - Decoder::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::TimestampMicros(is_utc) => { - Decoder::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(n) => Decoder::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Interval => Decoder::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Interval => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::List(item) => { let item_decoder = Box::new(Self::try_new(item)?); - Decoder::List( + Self::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), item_decoder, @@ -190,10 +186,10 @@ impl Decoder { arrow_fields.push(avro_field.field()); decoders.push(d); } - Decoder::Record(arrow_fields.into(), decoders) + Self::Record(arrow_fields.into(), decoders) } Codec::Enum(symbols) => { - Decoder::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) + Self::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( @@ -204,7 +200,7 @@ impl Decoder { ])), false, )); - Decoder::Map( + Self::Map( map_field, OffsetBufferBuilder::new(DEFAULT_CAPACITY), OffsetBufferBuilder::new(DEFAULT_CAPACITY), @@ -215,13 +211,13 @@ impl Decoder { } Codec::Decimal(precision, scale, size) => { let builder = DecimalBuilder::new(*precision, *scale, *size)?; - Decoder::Decimal(*precision, *scale, *size, builder) + Self::Decimal(*precision, *scale, *size, builder) } }; // Wrap in Nullable if needed match data_type.nullability() { - Some(nb) => Ok(Decoder::Nullable( + Some(nb) => Ok(Self::Nullable( nb, 
NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), @@ -233,36 +229,21 @@ impl Decoder { /// Append a null to this decoder. fn append_null(&mut self) { match self { - Decoder::Null(n) => { - *n += 1; - } - Decoder::Boolean(b) => { - b.append(false); - } - Decoder::Int32(v) | Decoder::Date32(v) | Decoder::TimeMillis(v) => { - v.push(0); - } - Decoder::Int64(v) - | Decoder::TimeMicros(v) - | Decoder::TimestampMillis(_, v) - | Decoder::TimestampMicros(_, v) => { - v.push(0); - } - Decoder::Float32(v) => { - v.push(0.0); - } - Decoder::Float64(v) => { - v.push(0.0); - } - Decoder::Binary(off, _) | Decoder::String(off, _) => { - off.push_length(0); - } - Decoder::Fixed(fsize, buf) => { + Self::Null(n) => *n += 1, + Self::Boolean(b) => b.append(false), + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), + Self::Int64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | Self::TimestampMicros(_, v) => v.push(0), + Self::Float32(v) => v.push(0.0), + Self::Float64(v) => v.push(0.0), + Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), + Self::Fixed(fsize, buf) => { // For a null, push `fsize` zeroed bytes - let n = *fsize as usize; - buf.extend(std::iter::repeat(0u8).take(n)); + buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); } - Decoder::Interval(intervals) => { + Self::Interval(intervals) => { // null => store a 12-byte zero => months=0, days=0, nanos=0 intervals.push(IntervalMonthDayNano { months: 0, @@ -270,82 +251,48 @@ impl Decoder { nanoseconds: 0, }); } - Decoder::List(_, off, child) => { + Self::List(_, off, child) => { off.push_length(0); child.append_null(); } - Decoder::Record(_, children) => { + Self::Record(_, children) => { for c in children.iter_mut() { c.append_null(); } } - Decoder::Enum(_, indices) => { - indices.push(0); - } - Decoder::Map( - _, - key_off, - map_off, - _, - _, - entry_count, - ) => { + Self::Enum(_, indices) => indices.push(0), + Self::Map(_, key_off, map_off, _, _, entry_count) => { key_off.push_length(0); map_off.push_length(*entry_count); } - Decoder::Decimal(_, _, _, builder) => { + Self::Decimal(_, _, _, builder) => { let _ = builder.append_null(); } - Decoder::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } + Self::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } } } /// Decode a single row of data from `buf`. 
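    ///
    /// Avro maps arrive as blocks of key/value pairs; per the spec a stream is
    /// terminated by a zero block count (note that `read_map_blocks` below only
    /// consumes the first block at this stage). A one-entry map, sketched with
    /// the zig-zag helpers from the tests:
    ///
    /// ```ignore
    /// let mut data = encode_avro_long(1);            // block_count = 1
    /// data.extend_from_slice(&encode_avro_long(5));  // key byte length
    /// data.extend_from_slice(b"hello");              // key bytes
    /// data.extend_from_slice(&encode_avro_int(7));   // value (here an int)
    /// data.extend_from_slice(&encode_avro_long(0));  // end-of-map marker (spec)
    /// ```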
fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Decoder::Null(count) => { - *count += 1; - } - Decoder::Boolean(values) => { - values.append(buf.get_bool()?); - } - Decoder::Int32(values) => { - values.push(buf.get_int()?); - } - Decoder::Date32(values) => { - values.push(buf.get_int()?); - } - Decoder::Int64(values) => { - values.push(buf.get_long()?); - } - Decoder::TimeMillis(values) => { - values.push(buf.get_int()?); - } - Decoder::TimeMicros(values) => { - values.push(buf.get_long()?); - } - Decoder::TimestampMillis(_, values) => { - values.push(buf.get_long()?); - } - Decoder::TimestampMicros(_, values) => { - values.push(buf.get_long()?); - } - Decoder::Float32(values) => { - values.push(buf.get_float()?); - } - Decoder::Float64(values) => { - values.push(buf.get_double()?); - } - Decoder::Binary(off, data) | Decoder::String(off, data) => { + Self::Null(count) => *count += 1, + Self::Boolean(values) => values.append(buf.get_bool()?), + Self::Int32(values) => values.push(buf.get_int()?), + Self::Date32(values) => values.push(buf.get_int()?), + Self::Int64(values) => values.push(buf.get_long()?), + Self::TimeMillis(values) => values.push(buf.get_int()?), + Self::TimeMicros(values) => values.push(buf.get_long()?), + Self::TimestampMillis(_, values) => values.push(buf.get_long()?), + Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + Self::Float32(values) => values.push(buf.get_float()?), + Self::Float64(values) => values.push(buf.get_double()?), + Self::Binary(off, data) | Self::String(off, data) => { let bytes = buf.get_bytes()?; off.push_length(bytes.len()); data.extend_from_slice(bytes); } - Decoder::Fixed(fsize, accum) => { - let raw = buf.get_fixed(*fsize as usize)?; - accum.extend_from_slice(raw); - } - Decoder::Interval(intervals) => { + Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), + Self::Interval(intervals) => { let raw = buf.get_fixed(12)?; let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); @@ -358,38 +305,32 @@ impl Decoder { }; intervals.push(val); } - Decoder::List(_, off, child) => { + Self::List(_, off, child) => { let total_items = read_array_blocks(buf, |b| child.decode(b))?; off.push_length(total_items); } - Decoder::Record(_, children) => { + Self::Record(_, children) => { for c in children.iter_mut() { c.decode(buf)?; } } - Decoder::Nullable(_, nulls, child) => { - let branch_index = buf.get_int()?; - match branch_index { - 0 => { - nulls.append(true); - child.decode(buf)?; - } - 1 => { - nulls.append(false); - child.append_null(); - } - other => { - return Err(ArrowError::ParseError(format!( - "Unsupported union branch index {other} for Nullable" - ))); - } + Self::Nullable(_, nulls, child) => match buf.get_int()? 
{ + 0 => { + nulls.append(true); + child.decode(buf)?; } - } - Decoder::Enum(_, indices) => { - let idx = buf.get_int()?; - indices.push(idx); - } - Decoder::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { + 1 => { + nulls.append(false); + child.append_null(); + } + other => { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {other} for Nullable" + ))); + } + }, + Self::Enum(_, indices) => indices.push(buf.get_int()?), + Self::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { let newly_added = read_map_blocks(buf, |b| { let kb = b.get_bytes()?; key_off.push_length(kb.len()); @@ -399,14 +340,12 @@ impl Decoder { *entry_count += newly_added; map_off.push_length(*entry_count); } - Decoder::Decimal(_, _, size, builder) => { - if let Some(sz) = *size { - let raw = buf.get_fixed(sz)?; - builder.append_bytes(raw)?; - } else { - let variable = buf.get_bytes()?; - builder.append_bytes(variable)?; - } + Self::Decimal(_, _, size, builder) => { + let bytes = match *size { + Some(sz) => buf.get_fixed(sz)?, + None => buf.get_bytes()?, + }; + builder.append_bytes(bytes)?; } } Ok(()) @@ -416,81 +355,81 @@ impl Decoder { fn flush(&mut self, nulls: Option) -> Result { match self { // For a nullable wrapper => flush the child with the built null buffer - Decoder::Nullable(_, nb, child) => { + Self::Nullable(_, nb, child) => { let mask = nb.finish(); child.flush(mask) } // Null => produce NullArray - Decoder::Null(len) => { + Self::Null(len) => { let count = std::mem::replace(len, 0); Ok(Arc::new(NullArray::new(count))) } // boolean => flush to BooleanArray - Decoder::Boolean(b) => { + Self::Boolean(b) => { let bits = b.finish(); Ok(Arc::new(BooleanArray::new(bits, nulls))) } // int32 => flush to Int32Array - Decoder::Int32(vals) => { + Self::Int32(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // date32 => flush to Date32Array - Decoder::Date32(vals) => { + Self::Date32(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // int64 => flush to Int64Array - Decoder::Int64(vals) => { + Self::Int64(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // time-millis => Time32Millisecond - Decoder::TimeMillis(vals) => { + Self::TimeMillis(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // time-micros => Time64Microsecond - Decoder::TimeMicros(vals) => { + Self::TimeMicros(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // timestamp-millis => TimestampMillisecond - Decoder::TimestampMillis(is_utc, vals) => { + Self::TimestampMillis(is_utc, vals) => { let arr = flush_primitive::(vals, nulls) .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); Ok(Arc::new(arr)) } // timestamp-micros => TimestampMicrosecond - Decoder::TimestampMicros(is_utc, vals) => { + Self::TimestampMicros(is_utc, vals) => { let arr = flush_primitive::(vals, nulls) .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); Ok(Arc::new(arr)) } // float32 => flush to Float32Array - Decoder::Float32(vals) => { + Self::Float32(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // float64 => flush to Float64Array - Decoder::Float64(vals) => { + Self::Float64(vals) => { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } // Avro bytes => BinaryArray - Decoder::Binary(off, data) => { + Self::Binary(off, data) => { let offsets = flush_offsets(off); let values = flush_values(data).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } // 
Avro string => StringArray - Decoder::String(off, data) => { + Self::String(off, data) => { let offsets = flush_offsets(off); let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } // Avro fixed => FixedSizeBinaryArray - Decoder::Fixed(fsize, raw) => { + Self::Fixed(fsize, raw) => { let size = *fsize; let buf: Buffer = flush_values(raw).into(); let total_len = buf.len() / (size as usize); @@ -499,31 +438,36 @@ impl Decoder { Ok(Arc::new(array)) } // Avro interval => IntervalMonthDayNanoType - Decoder::Interval(vals) => { + Self::Interval(vals) => { let data_len = vals.len(); - let mut builder = PrimitiveBuilder::::with_capacity(data_len); + let mut builder = + PrimitiveBuilder::::with_capacity(data_len); for v in vals.drain(..) { builder.append_value(v); } - let arr = builder.finish().with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + let arr = builder + .finish() + .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); if let Some(nb) = nulls { // "merge" the newly built array with the nulls let arr_data = arr.into_data().into_builder().nulls(Some(nb)); let arr_data = unsafe { arr_data.build_unchecked() }; - Ok(Arc::new(PrimitiveArray::::from(arr_data))) + Ok(Arc::new(PrimitiveArray::::from( + arr_data, + ))) } else { Ok(Arc::new(arr)) } } // Avro array => ListArray - Decoder::List(field, off, item_dec) => { + Self::List(field, off, item_dec) => { let child_arr = item_dec.flush(None)?; let offsets = flush_offsets(off); let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); Ok(Arc::new(arr)) } // Avro record => StructArray - Decoder::Record(fields, children) => { + Self::Record(fields, children) => { let mut arrays = Vec::with_capacity(children.len()); for c in children.iter_mut() { let a = c.flush(None)?; @@ -532,7 +476,7 @@ impl Decoder { Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } // Avro enum => DictionaryArray utf8> - Decoder::Enum(symbols, indices) => { + Self::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); let idxs: Int32Array = match nulls { Some(b) => { @@ -549,12 +493,12 @@ impl Decoder { Ok(Arc::new(dict)) } // Avro map => MapArray - Decoder::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { + Self::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { let moff = flush_offsets(map_off); let koff = flush_offsets(key_off); let kd = flush_values(key_data).into(); let val_arr = val_dec.flush(None)?; - let is_nullable = matches!(**val_dec, Decoder::Nullable(_, _, _)); + let is_nullable = matches!(**val_dec, Self::Nullable(_, _, _)); let key_arr = StringArray::new(koff, kd, None); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), @@ -574,7 +518,7 @@ impl Decoder { Ok(Arc::new(map_arr)) } // Avro decimal => Arrow decimal - Decoder::Decimal(prec, sc, sz, builder) => { + Self::Decimal(prec, sc, sz, builder) => { let precision = *prec; let scale = sc.unwrap_or(0); let new_builder = DecimalBuilder::new(precision, *sc, *sz)?; @@ -812,14 +756,9 @@ fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { mod tests { use super::*; use arrow_array::{ - cast::AsArray, Array, ArrayRef, Decimal128Array, Decimal256Array, DictionaryArray, - FixedSizeBinaryArray, Int32Array, IntervalMonthDayNanoArray, ListArray, MapArray, - StringArray, StructArray, + cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, + IntervalMonthDayNanoArray, ListArray, MapArray, 
StringArray, StructArray, }; - use arrow_buffer::Buffer; - use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; - use serde_json::json; - use std::iter; // --------------- // Zig-Zag Helpers @@ -927,12 +866,12 @@ mod tests { let mut dec = Decoder::try_new(&dt).unwrap(); // row1 => months=1 => 01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 - let row1 = [0x01, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, - 0x64, 0x00, 0x00, 0x00]; - let row2 = [0xFF, 0xFF, 0xFF, 0xFF, - 0x0A, 0x00, 0x00, 0x00, - 0x0F, 0x27, 0x00, 0x00]; + let row1 = [ + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0x0A, 0x00, 0x00, 0x00, 0x0F, 0x27, 0x00, 0x00, + ]; let mut data = Vec::new(); data.extend_from_slice(&row1); data.extend_from_slice(&row2); @@ -970,9 +909,11 @@ mod tests { // We'll decode 2 rows: row1 => interval => months=2, days=3, ms=500 => row2 => null // row1 => union=0 => child => 12 bytes // row2 => union=1 => null => no data - let row1 = [0x02, 0x00, 0x00, 0x00, // months=2 - 0x03, 0x00, 0x00, 0x00, // days=3 - 0xF4, 0x01, 0x00, 0x00]; // ms=500 => nanos=500_000_000 + let row1 = [ + 0x02, 0x00, 0x00, 0x00, // months=2 + 0x03, 0x00, 0x00, 0x00, // days=3 + 0xF4, 0x01, 0x00, 0x00, + ]; // ms=500 => nanos=500_000_000 let mut data = Vec::new(); data.extend_from_slice(&encode_avro_int(0)); // union=0 => child data.extend_from_slice(&row1); @@ -981,7 +922,10 @@ mod tests { dec.decode(&mut cursor).unwrap(); // row1 dec.decode(&mut cursor).unwrap(); // row2 => null let arr = dec.flush(None).unwrap(); - let intervals = arr.as_any().downcast_ref::().unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(intervals.len(), 2); assert!(intervals.is_valid(0)); assert!(!intervals.is_valid(1)); @@ -1009,7 +953,10 @@ mod tests { decoder.decode(&mut cursor).unwrap(); // => RED decoder.decode(&mut cursor).unwrap(); // => BLUE let array = decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(dict_arr.len(), 3); let keys = dict_arr.keys(); assert_eq!(keys.value(0), 1); @@ -1050,7 +997,10 @@ mod tests { nullable_decoder.decode(&mut cursor).unwrap(); // => null nullable_decoder.decode(&mut cursor).unwrap(); // => BLUE let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array.as_any().downcast_ref::>().unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); assert_eq!(dict_arr.len(), 3); // [GREEN, null, BLUE] assert!(dict_arr.is_valid(0)); @@ -1129,16 +1079,12 @@ mod tests { // Row1 => 123.45 => unscaled=12345 => i128 0x000...3039 // Row2 => -1.23 => unscaled=-123 => i128 0xFFFF...FF85 let row1 = [ - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x30, 0x39, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, ]; let row2 = [ - 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0x85, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, ]; let mut data = Vec::new(); @@ -1201,12 +1147,12 @@ mod tests { ); // Decode [1234.56, null, -1234.56] let row1 = [ - 0x00,0x00,0x00,0x00, 0x00,0x00,0x00,0x00, - 0x00,0x00,0x00,0x00, 0x00,0x01,0xE2,0x40 + 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0xE2, 0x40, ]; let row3 = [ - 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFF,0xFF,0xFF, - 0xFF,0xFF,0xFF,0xFF, 0xFF,0xFE,0x1D,0xC0 + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, + 0x1D, 0xC0, ]; let mut data = Vec::new(); // Row1 => child => [0x00] @@ -1256,6 +1202,7 @@ mod tests { row1.extend_from_slice(&encode_avro_int(10)); // item=10 row1.extend_from_slice(&encode_avro_int(20)); // item=20 row1.extend_from_slice(&encode_avro_long(0)); // end of array + // Row2 => block_count=0 => empty array let mut row2 = Vec::new(); row2.extend_from_slice(&encode_avro_long(0)); From 81d7bbabfac0b5d21259437b99f228cf81130d89 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Fri, 10 Jan 2025 13:02:34 -0600 Subject: [PATCH 13/38] ran linter Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 4 +- arrow-avro/src/reader/record.rs | 36 ++++++------ arrow-avro/src/writer/mod.rs | 22 +++++++- arrow-avro/src/writer/schema.rs | 99 +++++++++++++++------------------ arrow-avro/src/writer/vlq.rs | 20 ++++++- 5 files changed, 103 insertions(+), 78 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 1a2718be293c..681691ec8c22 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -547,7 +547,7 @@ fn make_data_type<'a>( .additional .get("scale") .and_then(|v| v.as_u64()) - .or_else(|| Some(0)); + .or(Some(0)); let field = AvroDataType { nullability: None, metadata: f.attributes.field_metadata(), @@ -725,7 +725,7 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { let value_field = &child_fields[1]; let sub_codec = arrow_type_to_codec(value_field.data_type()); Codec::Map(Arc::new(AvroDataType { - nullability: value_field.is_nullable().then(|| Nullability::NullFirst), + nullability: value_field.is_nullable().then_some(Nullability::NullFirst), metadata: value_field.metadata().clone(), codec: sub_codec, })) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index e5eb01df3322..6fe4ae87bef3 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -538,21 +538,25 @@ fn read_array_blocks( let mut total_items = 0usize; loop { let block_count = buf.get_long()?; - if block_count == 0 { - break; - } else if block_count < 0 { - let item_count = (-block_count) as usize; - let _block_size = buf.get_long()?; // “block size” is read but not used - for _ in 0..item_count { - decode_item(buf)?; + match block_count { + 0 => break, // If block_count is 0, exit the loop + n if n < 0 => { + // If block_count is negative + let item_count = (-n) as usize; + let _block_size = buf.get_long()?; // Read but ignore block size + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; } - total_items += item_count; - } else { - let item_count = block_count as usize; - for _ in 0..item_count { - decode_item(buf)?; + n => { + // If block_count is positive + let item_count = n as usize; + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; } - total_items += item_count; } } Ok(total_items) @@ -1128,9 +1132,9 @@ mod tests { let arr = decoder.flush(None).unwrap(); let dec_arr = arr.as_any().downcast_ref::().unwrap(); assert_eq!(dec_arr.len(), 3); - assert_eq!(dec_arr.is_valid(0), true); - assert_eq!(dec_arr.is_valid(1), false); - assert_eq!(dec_arr.is_valid(2), true); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); assert_eq!(dec_arr.value_as_string(0), 
"123.4"); assert_eq!(dec_arr.value_as_string(2), "-123.4"); } diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs index afb623162d9a..635333718ac7 100644 --- a/arrow-avro/src/writer/mod.rs +++ b/arrow-avro/src/writer/mod.rs @@ -1,15 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + mod schema; mod vlq; #[cfg(test)] mod test { + use arrow_array::RecordBatch; use std::fs::File; use std::io::BufWriter; - use arrow_array::RecordBatch; fn write_file(file: &str, batch: &RecordBatch) { let file = File::open(file).unwrap(); let mut writer = BufWriter::new(file); - } -} \ No newline at end of file +} diff --git a/arrow-avro/src/writer/schema.rs b/arrow-avro/src/writer/schema.rs index a858e1e5f3d0..521ea9e6b107 100644 --- a/arrow-avro/src/writer/schema.rs +++ b/arrow-avro/src/writer/schema.rs @@ -1,8 +1,24 @@ -use std::collections::HashMap; -use std::sync::Arc; -use arrow_array::RecordBatch; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ use crate::codec::{AvroDataType, AvroField, Codec}; use crate::schema::Schema; +use arrow_array::RecordBatch; +use std::sync::Arc; fn record_batch_to_avro_schema<'a>( batch: &'a RecordBatch, @@ -22,9 +38,7 @@ pub fn to_avro_json_schema( .iter() .map(|arrow_field| crate::codec::arrow_field_to_avro_field(arrow_field)) .collect(); - let top_level_data_type = AvroDataType::from_codec( - Codec::Struct(Arc::from(avro_fields)), - ); + let top_level_data_type = AvroDataType::from_codec(Codec::Struct(Arc::from(avro_fields))); let avro_schema = record_batch_to_avro_schema(batch, record_name, &top_level_data_type); serde_json::to_string_pretty(&avro_schema) } @@ -32,7 +46,7 @@ pub fn to_avro_json_schema( #[cfg(test)] mod tests { use super::*; - use arrow_array::{Int32Array, StringArray, RecordBatch, ArrayRef, StructArray}; + use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; use serde_json::{json, Value}; use std::sync::Arc; @@ -74,8 +88,6 @@ mod tests { } ] }); - - // Compare the two JSON objects assert_eq!( actual_json, expected_json, "Avro Schema JSON does not match expected" @@ -88,18 +100,14 @@ mod tests { Field::new("id", DataType::Int32, false), Field::new("desc", DataType::Utf8, true), ])); - let col_id = Arc::new(Int32Array::from(vec![10, 20, 30])); let col_desc = Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])); let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_desc]) .expect("Failed to create RecordBatch"); - let json_schema_string = to_avro_json_schema(&batch, "AnotherTestRecord") .expect("Failed to convert RecordBatch to Avro JSON schema"); - let actual_json: Value = serde_json::from_str(&json_schema_string) .expect("Invalid JSON returned by to_avro_json_schema"); - let expected_json = json!({ "type": "record", "name": "AnotherTestRecord", @@ -119,7 +127,6 @@ mod tests { } ] }); - assert_eq!( actual_json, expected_json, "JSON schema mismatch for to_avro_json_schema" @@ -128,18 +135,18 @@ mod tests { #[test] fn test_to_avro_json_schema_single_nonnull_int() { - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new("id", DataType::Int32, false)])); - + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); - let batch = RecordBatch::try_new(arrow_schema, vec![col_id]) - .expect("Failed to create RecordBatch"); - + let batch = + RecordBatch::try_new(arrow_schema, vec![col_id]).expect("Failed to create RecordBatch"); let avro_json_string = to_avro_json_schema(&batch, "MySingleIntRecord") .expect("Failed to generate Avro JSON schema"); - - let actual_json: Value = serde_json::from_str(&avro_json_string) - .expect("Failed to parse Avro JSON schema"); - + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); let expected_json = json!({ "type": "record", "name": "MySingleIntRecord", @@ -154,8 +161,6 @@ mod tests { } ] }); - - // Compare assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); } @@ -165,18 +170,14 @@ mod tests { Field::new("id", DataType::Int32, false), Field::new("name", DataType::Utf8, true), ])); - let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) .expect("Failed to create RecordBatch"); - - let 
avro_json_string = to_avro_json_schema(&batch, "MyRecord") - .expect("Failed to generate Avro JSON schema"); - - let actual_json: Value = serde_json::from_str(&avro_json_string) - .expect("Failed to parse Avro JSON schema"); - + let avro_json_string = + to_avro_json_schema(&batch, "MyRecord").expect("Failed to generate Avro JSON schema"); + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); let expected_json = json!({ "type": "record", "name": "MyRecord", @@ -199,8 +200,6 @@ mod tests { } ] }); - - // Compare assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); } @@ -210,14 +209,14 @@ mod tests { Field::new("inner_int", DataType::Int32, false), Field::new("inner_str", DataType::Utf8, true), ]); - - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![ - Field::new("my_struct", DataType::Struct(inner_fields), true) - ])); - + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "my_struct", + DataType::Struct(inner_fields), + true, + )])); let inner_int_col = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef; - let inner_str_col = Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; - + let inner_str_col = + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; let fields_arrays = vec![ ( Arc::new(Field::new("inner_int", DataType::Int32, false)), @@ -228,21 +227,13 @@ mod tests { inner_str_col, ), ]; - let struct_array = StructArray::from(fields_arrays); - - let batch = RecordBatch::try_new( - arrow_schema, - vec![Arc::new(struct_array)], - ) + let batch = RecordBatch::try_new(arrow_schema, vec![Arc::new(struct_array)]) .expect("Failed to create RecordBatch"); - let avro_json_string = to_avro_json_schema(&batch, "NestedRecord") .expect("Failed to generate Avro JSON schema"); - - let actual_json: Value = serde_json::from_str(&avro_json_string) - .expect("Failed to parse Avro JSON schema"); - + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); let expected_json = json!({ "type": "record", "name": "NestedRecord", @@ -281,8 +272,6 @@ mod tests { } ] }); - - // Compare assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); } } diff --git a/arrow-avro/src/writer/vlq.rs b/arrow-avro/src/writer/vlq.rs index 765e6687abaa..4cf26e23856d 100644 --- a/arrow-avro/src/writer/vlq.rs +++ b/arrow-avro/src/writer/vlq.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + /// Encoder for zig-zag encoded variable length integers /// /// This complements the VLQ decoding logic used by Avro. 
Zig-zag encoding maps signed integers @@ -65,10 +82,9 @@ mod tests { } fn round_trip(value: i64) { - let mut encoder = VLQEncoder::default(); + let mut encoder = VLQEncoder; let mut buf = Vec::new(); encoder.long(value, &mut buf); - let mut slice = buf.as_slice(); let decoded = decode_long(&mut slice).expect("Failed to decode value"); assert_eq!(decoded, value, "Round-trip mismatch for value {}", value); From 1bc9a51bc447e3a33c12ab44f5bdb15a4c8f1780 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Mon, 13 Jan 2025 15:55:08 -0600 Subject: [PATCH 14/38] Removed Avro writer module and Avro writer related logic from codec.rs Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 386 +------------------------------- arrow-avro/src/lib.rs | 1 - arrow-avro/src/writer/mod.rs | 31 --- arrow-avro/src/writer/schema.rs | 277 ----------------------- arrow-avro/src/writer/vlq.rs | 114 ---------- 5 files changed, 1 insertion(+), 808 deletions(-) delete mode 100644 arrow-avro/src/writer/mod.rs delete mode 100644 arrow-avro/src/writer/schema.rs delete mode 100644 arrow-avro/src/writer/vlq.rs diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 681691ec8c22..91bb413b36ba 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{ - Array, Attributes, ComplexType, Enum, Fixed, Map, PrimitiveType, Record, RecordField, Schema, - TypeName, -}; +use crate::schema::{ComplexType, PrimitiveType, Schema, TypeName}; use arrow_schema::DataType::*; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, @@ -82,25 +79,6 @@ impl AvroDataType { pub fn nullability(&self) -> Option { self.nullability } - - /// Convert this `AvroDataType`, which encapsulates an Arrow data type (`codec`) - /// plus nullability and metadata, back into an Avro `Schema<'a>`. - /// - /// - If `metadata["namespace"]` is present, we'll store it in the resulting schema for named types - /// (record, enum, fixed). - pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { - let inner_schema = self.codec.to_avro_schema(name); - let schema_with_namespace = maybe_add_namespace(inner_schema, self); - // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. - if self.nullability.is_some() { - Schema::Union(vec![ - Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - schema_with_namespace, - ]) - } else { - schema_with_namespace - } - } } /// If this is a named complex type (Record, Enum, Fixed), attach `namespace` @@ -246,172 +224,6 @@ impl Codec { }, } } - - /// Convert this `Codec` variant to an Avro `Schema<'a>`. 
- pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { - match self { - Codec::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - Codec::Boolean => Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), - Codec::Int32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), - Codec::Int64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), - Codec::Float32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)), - Codec::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), - Codec::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), - Codec::Utf8 => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), - // date32 => Avro int + logicalType=date - Codec::Date32 => Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Int), - attributes: Attributes { - logical_type: Some("date"), - additional: Default::default(), - }, - }), - // time-millis => Avro int with logicalType=time-millis - Codec::TimeMillis => Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Int), - attributes: Attributes { - logical_type: Some("time-millis"), - additional: Default::default(), - }, - }), - // time-micros => Avro long with logicalType=time-micros - Codec::TimeMicros => Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Long), - attributes: Attributes { - logical_type: Some("time-micros"), - additional: Default::default(), - }, - }), - // timestamp-millis => Avro long with logicalType=timestamp-millis or local-timestamp-millis - Codec::TimestampMillis(is_utc) => { - let logical_type = Some(if *is_utc { - "timestamp-millis" - } else { - "local-timestamp-millis" - }); - Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Long), - attributes: Attributes { - logical_type, - additional: Default::default(), - }, - }) - } - // timestamp-micros => Avro long with logicalType=timestamp-micros or local-timestamp-micros - Codec::TimestampMicros(is_utc) => { - let logical_type = Some(if *is_utc { - "timestamp-micros" - } else { - "local-timestamp-micros" - }); - Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Long), - attributes: Attributes { - logical_type, - additional: Default::default(), - }, - }) - } - Codec::Interval => Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Bytes), - attributes: Attributes { - logical_type: Some("duration"), - additional: Default::default(), - }, - }), - Codec::Fixed(size) => { - // Convert Arrow FixedSizeBinary => Avro fixed with name & size - Schema::Complex(ComplexType::Fixed(Fixed { - name, - namespace: None, - aliases: vec![], - size: *size as usize, - attributes: Attributes::default(), - })) - } - Codec::List(item_type) => { - // Avro array with "items" recursively derived - let items_schema = item_type.to_avro_schema("items"); - Schema::Complex(ComplexType::Array(Array { - items: Box::new(items_schema), - attributes: Attributes::default(), - })) - } - Codec::Struct(fields) => { - // Avro record with nested fields - let record_fields = fields - .iter() - .map(|f| { - let child_schema = f.data_type().to_avro_schema(f.name()); - RecordField { - name: f.name(), - doc: None, - r#type: child_schema, - default: None, - } - }) - .collect(); - Schema::Complex(ComplexType::Record(Record { - name, - namespace: None, - doc: None, - aliases: vec![], - fields: record_fields, - attributes: Attributes::default(), - })) - } 
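// Illustrative note (not in the original source): for an Arrow struct
// field with children `id: Int32` and `name: nullable Utf8`, the record
// arm above emits Avro JSON along these lines, the nullable child having
// been wrapped in a ["null", ...] union by `AvroDataType::to_avro_schema`:
//
//     {"type": "record", "name": "item", "fields": [
//         {"name": "id", "type": "int"},
//         {"name": "name", "type": ["null", "string"]}]}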
- Codec::Enum(symbols) => { - // If there's a namespace in metadata, we will apply it later in maybe_add_namespace. - Schema::Complex(ComplexType::Enum(Enum { - name, - namespace: None, - doc: None, - aliases: vec![], - symbols: symbols.iter().map(|s| s.as_str()).collect(), - default: None, - attributes: Attributes::default(), - })) - } - Codec::Map(values) => { - let val_schema = values.to_avro_schema("values"); - Schema::Complex(ComplexType::Map(Map { - values: Box::new(val_schema), - attributes: Attributes::default(), - })) - } - Codec::Decimal(precision, scale, size) => { - // If size is Some(n), produce Avro "fixed", else "bytes". - if let Some(n) = size { - Schema::Complex(ComplexType::Fixed(Fixed { - name, - namespace: None, - aliases: vec![], - size: *n, - attributes: Attributes { - logical_type: Some("decimal"), - additional: HashMap::from([ - ("precision", serde_json::json!(*precision)), - ("scale", serde_json::json!(scale.unwrap_or(0))), - ("size", serde_json::json!(*n)), - ]), - }, - })) - } else { - // "type":"bytes", "logicalType":"decimal" - Schema::Type(crate::schema::Type { - r#type: TypeName::Primitive(PrimitiveType::Bytes), - attributes: Attributes { - logical_type: Some("decimal"), - additional: HashMap::from([ - ("precision", serde_json::json!(*precision)), - ("scale", serde_json::json!(scale.unwrap_or(0))), - ]), - }, - }) - } - } - } - } } impl From for Codec { @@ -748,204 +560,8 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { mod tests { use super::*; use arrow_schema::Field; - use serde_json::json; use std::sync::Arc; - #[test] - fn test_decimal256_tuple_variant_fixed() { - let c = arrow_type_to_codec(&Decimal256(60, 3)); - match c { - Codec::Decimal(p, s, Some(32)) => { - assert_eq!(p, 60); - assert_eq!(s, Some(3)); - } - _ => panic!("Expected decimal(60,3,Some(32))"), - } - let avro_dt = AvroDataType::from_codec(c); - let avro_schema = avro_dt.to_avro_schema("FixedDec"); - let j = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ - "type": "fixed", - "name": "FixedDec", - "aliases": [], - "size": 32, - "logicalType": "decimal", - "precision": 60, - "scale": 3 - }); - assert_eq!(j, expected); - } - - #[test] - fn test_decimal128_tuple_variant_fixed() { - let c = Codec::Decimal(6, Some(2), Some(4)); - let dt = c.data_type(); - match dt { - Decimal128(p, s) => { - assert_eq!(p, 6); - assert_eq!(s, 2); - } - _ => panic!("Expected decimal(6,2) arrow type"), - } - let avro_dt = AvroDataType::from_codec(c); - let schema = avro_dt.to_avro_schema("FixedDec"); - let j = serde_json::to_value(&schema).unwrap(); - let expected = json!({ - "type": "fixed", - "name": "FixedDec", - "aliases": [], - "size": 4, - "logicalType": "decimal", - "precision": 6, - "scale": 2, - }); - assert_eq!(j, expected); - } - - #[test] - fn test_decimal_size_decision() { - let codec = Codec::Decimal(10, Some(3), Some(16)); - let dt = codec.data_type(); - match dt { - Decimal128(precision, scale) => { - assert_eq!(precision, 10); - assert_eq!(scale, 3); - } - _ => panic!("Expected Decimal128"), - } - let codec = Codec::Decimal(18, Some(4), Some(32)); - let dt = codec.data_type(); - match dt { - Decimal256(precision, scale) => { - assert_eq!(precision, 18); - assert_eq!(scale, 4); - } - _ => panic!("Expected Decimal256"), - } - let codec = Codec::Decimal(8, Some(2), None); - let dt = codec.data_type(); - match dt { - Decimal128(precision, scale) => { - assert_eq!(precision, 8); - assert_eq!(scale, 2); - } - _ => panic!("Expected Decimal128"), - } - } - - #[test] - fn 
test_avro_data_type_new_and_from_codec() { - let dt1 = AvroDataType::new( - Codec::Int32, - Some(Nullability::NullFirst), - HashMap::from([("namespace".into(), "my.ns".into())]), - ); - let actual_str = format!("{:?}", dt1.nullability()); - let expected_str = format!("{:?}", Some(Nullability::NullFirst)); - assert_eq!(actual_str, expected_str); - let actual_str2 = format!("{:?}", dt1.codec()); - let expected_str2 = format!("{:?}", &Codec::Int32); - assert_eq!(actual_str2, expected_str2); - assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); - let dt2 = AvroDataType::from_codec(Codec::Float64); - let actual_str4 = format!("{:?}", dt2.codec()); - let expected_str4 = format!("{:?}", &Codec::Float64); - assert_eq!(actual_str4, expected_str4); - assert!(dt2.metadata.is_empty()); - } - - #[test] - fn test_avro_data_type_field_with_name() { - let dt = AvroDataType::new( - Codec::Binary, - None, - HashMap::from([("something".into(), "else".into())]), - ); - let f = dt.field_with_name("bin_col"); - assert_eq!(f.name(), "bin_col"); - assert_eq!(f.data_type(), &Binary); - assert!(!f.is_nullable()); - assert_eq!(f.metadata().get("something"), Some(&"else".to_string())); - } - - #[test] - fn test_avro_data_type_to_avro_schema_with_namespace_record() { - let mut meta = HashMap::new(); - meta.insert("namespace".to_string(), "com.example".to_string()); - let fields = Arc::from(vec![ - AvroField { - name: "id".to_string(), - data_type: AvroDataType::from_codec(Codec::Int32), - }, - AvroField { - name: "label".to_string(), - data_type: AvroDataType::new( - Codec::Utf8, - Some(Nullability::NullFirst), - Default::default(), - ), - }, - ]); - let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); - let avro_schema = top_level.to_avro_schema("TopRecord"); - let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ - "type": "record", - "name": "TopRecord", - "namespace": "com.example", - "doc": null, - "logicalType": null, - "aliases": [], - "fields": [ - { "name": "id", "doc": null, "type": "int" }, - { "name": "label", "doc": null, "type": ["null","string"] } - ], - }); - assert_eq!(json_val, expected); - } - - #[test] - fn test_avro_data_type_to_avro_schema_with_namespace_enum() { - let mut meta = HashMap::new(); - meta.insert("namespace".to_string(), "com.example.enum".to_string()); - - let enum_dt = AvroDataType::new( - Codec::Enum(vec!["A".to_string(), "B".to_string(), "C".to_string()]), - None, - meta, - ); - let avro_schema = enum_dt.to_avro_schema("MyEnum"); - let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ - "type": "enum", - "name": "MyEnum", - "logicalType": null, - "namespace": "com.example.enum", - "doc": null, - "aliases": [], - "symbols": ["A","B","C"] - }); - assert_eq!(json_val, expected); - } - - #[test] - fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { - let mut meta = HashMap::new(); - meta.insert("namespace".to_string(), "com.example.fixed".to_string()); - let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); - let avro_schema = fixed_dt.to_avro_schema("MyFixed"); - let json_val = serde_json::to_value(&avro_schema).unwrap(); - let expected = json!({ - "type": "fixed", - "name": "MyFixed", - "logicalType": null, - "namespace": "com.example.fixed", - "aliases": [], - "size": 8 - }); - assert_eq!(json_val, expected); - } - #[test] fn test_avro_field() { let field_codec = AvroDataType::from_codec(Codec::Int64); diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs 
index ef3bd082d0e8..d01d681b7af0 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -29,7 +29,6 @@ mod schema; mod compression; mod codec; -mod writer; #[cfg(test)] mod test_util { diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs deleted file mode 100644 index 635333718ac7..000000000000 --- a/arrow-avro/src/writer/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod schema; -mod vlq; - -#[cfg(test)] -mod test { - use arrow_array::RecordBatch; - use std::fs::File; - use std::io::BufWriter; - - fn write_file(file: &str, batch: &RecordBatch) { - let file = File::open(file).unwrap(); - let mut writer = BufWriter::new(file); - } -} diff --git a/arrow-avro/src/writer/schema.rs b/arrow-avro/src/writer/schema.rs deleted file mode 100644 index 521ea9e6b107..000000000000 --- a/arrow-avro/src/writer/schema.rs +++ /dev/null @@ -1,277 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::codec::{AvroDataType, AvroField, Codec}; -use crate::schema::Schema; -use arrow_array::RecordBatch; -use std::sync::Arc; - -fn record_batch_to_avro_schema<'a>( - batch: &'a RecordBatch, - record_name: &'a str, - top_level_data_type: &'a AvroDataType, -) -> Schema<'a> { - top_level_data_type.to_avro_schema(record_name) -} - -pub fn to_avro_json_schema( - batch: &RecordBatch, - record_name: &str, -) -> Result { - let avro_fields: Vec = batch - .schema() - .fields() - .iter() - .map(|arrow_field| crate::codec::arrow_field_to_avro_field(arrow_field)) - .collect(); - let top_level_data_type = AvroDataType::from_codec(Codec::Struct(Arc::from(avro_fields))); - let avro_schema = record_batch_to_avro_schema(batch, record_name, &top_level_data_type); - serde_json::to_string_pretty(&avro_schema) -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; - use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; - use serde_json::{json, Value}; - use std::sync::Arc; - - #[test] - fn test_record_batch_to_avro_schema_basic() { - let arrow_schema = Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - - let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); - let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); - let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) - .expect("Failed to create RecordBatch"); - - // Convert the batch -> Avro `Schema` - let avro_schema = to_avro_json_schema(&batch, "MyTestRecord") - .expect("Failed to convert RecordBatch to Avro JSON schema"); - let actual_json: Value = serde_json::from_str(&avro_schema) - .expect("Invalid JSON returned by to_avro_json_schema"); - - let expected_json = json!({ - "type": "record", - "name": "MyTestRecord", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "id", - "doc": null, - "type": "int" - }, - { - "name": "name", - "doc": null, - "type": ["null", "string"] - } - ] - }); - assert_eq!( - actual_json, expected_json, - "Avro Schema JSON does not match expected" - ); - } - - #[test] - fn test_to_avro_json_schema_basic() { - let arrow_schema = Arc::new(ArrowSchema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("desc", DataType::Utf8, true), - ])); - let col_id = Arc::new(Int32Array::from(vec![10, 20, 30])); - let col_desc = Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])); - let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_desc]) - .expect("Failed to create RecordBatch"); - let json_schema_string = to_avro_json_schema(&batch, "AnotherTestRecord") - .expect("Failed to convert RecordBatch to Avro JSON schema"); - let actual_json: Value = serde_json::from_str(&json_schema_string) - .expect("Invalid JSON returned by to_avro_json_schema"); - let expected_json = json!({ - "type": "record", - "name": "AnotherTestRecord", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "id", - "type": "int", - "doc": null, - }, - { - "name": "desc", - "type": ["null", "string"], - "doc": null, - } - ] - }); - assert_eq!( - actual_json, expected_json, - "JSON schema mismatch for to_avro_json_schema" - ); - } - - #[test] - fn test_to_avro_json_schema_single_nonnull_int() { - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "id", - DataType::Int32, - false, - )])); - let col_id = Arc::new(Int32Array::from(vec![1, 2, 
3])); - let batch = - RecordBatch::try_new(arrow_schema, vec![col_id]).expect("Failed to create RecordBatch"); - let avro_json_string = to_avro_json_schema(&batch, "MySingleIntRecord") - .expect("Failed to generate Avro JSON schema"); - let actual_json: Value = - serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); - let expected_json = json!({ - "type": "record", - "name": "MySingleIntRecord", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "id", - "type": "int", - "doc": null, - } - ] - }); - assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); - } - - #[test] - fn test_to_avro_json_schema_two_fields_nullable_string() { - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ])); - let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); - let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); - let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) - .expect("Failed to create RecordBatch"); - let avro_json_string = - to_avro_json_schema(&batch, "MyRecord").expect("Failed to generate Avro JSON schema"); - let actual_json: Value = - serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); - let expected_json = json!({ - "type": "record", - "name": "MyRecord", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "id", - "type": "int", - "doc": null, - }, - { - "name": "name", - "doc": null, - "type": [ - "null", - "string", - ] - } - ] - }); - assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); - } - - #[test] - fn test_to_avro_json_schema_nested_struct() { - let inner_fields = Fields::from(vec![ - Field::new("inner_int", DataType::Int32, false), - Field::new("inner_str", DataType::Utf8, true), - ]); - let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( - "my_struct", - DataType::Struct(inner_fields), - true, - )])); - let inner_int_col = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef; - let inner_str_col = - Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; - let fields_arrays = vec![ - ( - Arc::new(Field::new("inner_int", DataType::Int32, false)), - inner_int_col, - ), - ( - Arc::new(Field::new("inner_str", DataType::Utf8, true)), - inner_str_col, - ), - ]; - let struct_array = StructArray::from(fields_arrays); - let batch = RecordBatch::try_new(arrow_schema, vec![Arc::new(struct_array)]) - .expect("Failed to create RecordBatch"); - let avro_json_string = to_avro_json_schema(&batch, "NestedRecord") - .expect("Failed to generate Avro JSON schema"); - let actual_json: Value = - serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); - let expected_json = json!({ - "type": "record", - "name": "NestedRecord", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "my_struct", - "doc": null, - "type": [ - "null", - { - "type": "record", - "name": "my_struct", - "aliases": [], - "doc": null, - "logicalType": null, - "fields": [ - { - "name": "inner_int", - "type": "int", - "doc": null, - }, - { - "name": "inner_str", - "doc": null, - "type": [ - "null", - "string", - ] - } - ] - } - ] - } - ] - }); - assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); - } -} diff --git a/arrow-avro/src/writer/vlq.rs b/arrow-avro/src/writer/vlq.rs deleted file mode 100644 index 
4cf26e23856d..000000000000 --- a/arrow-avro/src/writer/vlq.rs +++ /dev/null @@ -1,114 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/// Encoder for zig-zag encoded variable length integers -/// -/// This complements the VLQ decoding logic used by Avro. Zig-zag encoding maps signed integers -/// to unsigned integers so that small magnitudes (both positive and negative) produce smaller varints. -/// After zig-zag encoding, values are encoded as a series of bytes where the lower 7 bits are data -/// and the high bit indicates if another byte follows. -/// -/// See also: -/// -/// -#[derive(Debug, Default)] -pub struct VLQEncoder; - -impl VLQEncoder { - /// Encode a signed 64-bit integer `value` into `output` using Avro's zig-zag varint encoding. - /// - /// Zig-zag encoding: - /// ```text - /// encoded = (value << 1) ^ (value >> 63) - /// ``` - /// - /// Then `encoded` is written as a variable-length integer (varint): - /// - Extract 7 bits at a time - /// - If more bits remain, set the MSB of the current byte to 1 and continue - /// - Otherwise, MSB is 0 and this is the last byte - pub fn long(&mut self, value: i64, output: &mut Vec) { - let zigzag = ((value << 1) ^ (value >> 63)) as u64; - self.encode_varint(zigzag, output); - } - - fn encode_varint(&self, mut val: u64, output: &mut Vec) { - while (val & !0x7F) != 0 { - output.push(((val & 0x7F) as u8) | 0x80); - val >>= 7; - } - output.push(val as u8); - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn decode_varint(buf: &mut &[u8]) -> Option { - let mut value = 0_u64; - for i in 0..10 { - let b = buf.get(i).copied()?; - let lower_7 = (b & 0x7F) as u64; - value |= lower_7 << (7 * i); - if b & 0x80 == 0 { - *buf = &buf[i + 1..]; - return Some(value); - } - } - None // more than 10 bytes or not terminated properly - } - - fn decode_zigzag(val: u64) -> i64 { - ((val >> 1) as i64) ^ -((val & 1) as i64) - } - - fn decode_long(buf: &mut &[u8]) -> Option { - let val = decode_varint(buf)?; - Some(decode_zigzag(val)) - } - - fn round_trip(value: i64) { - let mut encoder = VLQEncoder; - let mut buf = Vec::new(); - encoder.long(value, &mut buf); - let mut slice = buf.as_slice(); - let decoded = decode_long(&mut slice).expect("Failed to decode value"); - assert_eq!(decoded, value, "Round-trip mismatch for value {}", value); - assert!(slice.is_empty(), "Not all bytes consumed"); - } - - #[test] - fn test_round_trip() { - round_trip(0); - round_trip(1); - round_trip(-1); - round_trip(12345678); - round_trip(-12345678); - round_trip(i64::MAX); - round_trip(i64::MIN); - } - - #[test] - fn test_random_values() { - use rand::Rng; - let mut rng = rand::thread_rng(); - for _ in 0..1000 { - let val: i64 = rng.gen(); - round_trip(val); - } - } -} From b1575847a95b88bf12025eea9762822db7a2c7cf 
Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Tue, 14 Jan 2025 10:18:52 -0800 Subject: [PATCH 15/38] chore: clean up UUIDs and Durations --- arrow-avro/src/codec.rs | 35 ++++++++++++++++++++------------- arrow-avro/src/reader/record.rs | 6 +++--- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 91bb413b36ba..fb8c765794c1 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -155,6 +155,8 @@ pub enum Codec { Float64, Binary, Utf8, + Decimal(usize, Option, Option), + Uuid, Date32, TimeMillis, TimeMicros, @@ -165,11 +167,10 @@ pub enum Codec { Fixed(i32), List(Arc), Struct(Arc<[AvroField]>), - Interval, + Duration, /// In Arrow, use Dictionary(Int32, Utf8) for Enum. Enum(Vec), Map(Arc), - Decimal(usize, Option, Option), } impl Codec { @@ -184,6 +185,18 @@ impl Codec { Self::Float64 => Float64, Self::Binary => Binary, Self::Utf8 => Utf8, + Self::Decimal(precision, scale, size) => match size { + Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), + Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), + None if *precision <= DECIMAL128_MAX_PRECISION as usize + && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize => + { + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + } + _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), + }, + // arrow-rs does not support the UUID Canonical Extension Type yet, so this is a temporary workaround. + Self::Uuid => FixedSizeBinary(16), Self::Date32 => Date32, Self::TimeMillis => Time32(TimeUnit::Millisecond), Self::TimeMicros => Time64(TimeUnit::Microsecond), @@ -193,7 +206,7 @@ impl Codec { Self::TimestampMicros(is_utc) => { Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Interval => Interval(IntervalUnit::MonthDayNano), + Self::Duration => Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => FixedSizeBinary(*size), Self::List(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), @@ -212,16 +225,6 @@ impl Codec { )), false, ), - Self::Decimal(precision, scale, size) => match size { - Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), - Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), - None if *precision <= DECIMAL128_MAX_PRECISION as usize - && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize => - { - Decimal128(*precision as u8, scale.unwrap_or(0) as i8) - } - _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), - }, } } } @@ -450,6 +453,7 @@ fn make_data_type<'a>( None, ); } + (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, @@ -461,7 +465,7 @@ fn make_data_type<'a>( (Some("local-timestamp-micros"), c @ Codec::Int64) => { *c = Codec::TimestampMicros(false) } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, + (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, (Some(logical), _) => { // Insert unrecognized logical type into metadata field.metadata.insert("logicalType".into(), logical.into()); @@ -510,6 +514,9 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Float64 => Codec::Float64, Utf8 => Codec::Utf8, Binary | LargeBinary => Codec::Binary, + // arrow-rs does not support the UUID Canonical Extension Type yet, so 
this mapping is not possible. + // It is unsafe to assume all FixedSizeBinary(16) are UUIDs. + // Uuid => Codec::Uuid, Date32 => Codec::Date32, Time32(TimeUnit::Millisecond) => Codec::TimeMillis, Time64(TimeUnit::Microsecond) => Codec::TimeMicros, diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 6fe4ae87bef3..ab673577537f 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -169,7 +169,7 @@ impl Decoder { Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Interval => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::List(item) => { let item_decoder = Box::new(Self::try_new(item)?); Self::List( @@ -866,7 +866,7 @@ mod tests { fn test_interval_decoding() { // Avro interval => 12 bytes => [ months i32, days i32, ms i32 ] // decode 2 rows => row1 => months=1, days=2, ms=100 => row2 => months=-1, days=10, ms=9999 - let dt = AvroDataType::from_codec(Codec::Interval); + let dt = AvroDataType::from_codec(Codec::Duration); let mut dec = Decoder::try_new(&dt).unwrap(); // row1 => months=1 => 01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 @@ -903,7 +903,7 @@ mod tests { #[test] fn test_interval_decoding_with_nulls() { // Avro union => [ interval, null] - let dt = AvroDataType::from_codec(Codec::Interval); + let dt = AvroDataType::from_codec(Codec::Duration); let child = Decoder::try_new(&dt).unwrap(); let mut dec = Decoder::Nullable( Nullability::NullFirst, From 8d0fe771406fb8e4f1e06d10ae08bf368d104644 Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Tue, 14 Jan 2025 10:18:52 -0800 Subject: [PATCH 16/38] chore: clean up UUIDs Durations --- arrow-avro/src/reader/record.rs | 9 +++++---- arrow-avro/src/schema.rs | 32 +++++++++++++++++++++++++------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index ab673577537f..e4c2f372d25e 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -159,6 +159,11 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), + Codec::Decimal(precision, scale, size) => { + let builder = DecimalBuilder::new(*precision, *scale, *size)?; + Self::Decimal(*precision, *scale, *size, builder) + } + Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), @@ -209,10 +214,6 @@ impl Decoder { 0, ) } - Codec::Decimal(precision, scale, size) => { - let builder = DecimalBuilder::new(*precision, *scale, *size)?; - Self::Decimal(*precision, *scale, *size, builder) - } }; // Wrap in Nullable if needed diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 8e3f23ffbb5e..c675c33ccf76 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -210,7 +210,7 @@ pub struct Fixed<'a> { #[cfg(test)] mod tests { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use arrow_schema::{DataType, Fields, TimeUnit}; use serde_json::json; @@ -237,7 +237,7 @@ mod tests { 
"logicalType":"timestamp-micros" }"#, ) - .unwrap(); + .unwrap(); let timestamp = Type { r#type: TypeName::Primitive(PrimitiveType::Long), @@ -260,7 +260,7 @@ mod tests { "scale":2 }"#, ) - .unwrap(); + .unwrap(); let decimal = ComplexType::Fixed(Fixed { name: "fixed", @@ -300,7 +300,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -333,7 +333,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -392,7 +392,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -453,7 +453,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -508,5 +508,23 @@ mod tests { attributes: Default::default(), })) ); + + let t: Type = serde_json::from_str( + r#"{ + "type":"string", + "logicalType":"uuid" + }"#, + ) + .unwrap(); + + let uuid = Type { + r#type: TypeName::Primitive(PrimitiveType::String), + attributes: Attributes { + logical_type: Some("uuid"), + additional: Default::default(), + }, + }; + + assert_eq!(t, uuid); } } From 0496dcfa8a22ad81abb2fa0fadfc5336d20caea9 Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Tue, 14 Jan 2025 11:00:43 -0800 Subject: [PATCH 17/38] add aliases to record fields --- arrow-avro/src/schema.rs | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index c675c33ccf76..d8722048a463 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -142,6 +142,8 @@ pub struct RecordField<'a> { pub name: &'a str, #[serde(borrow, default)] pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, #[serde(borrow, default, skip_serializing_if = "Option::is_none")] @@ -254,6 +256,7 @@ mod tests { "type":"fixed", "name":"fixed", "namespace":"topLevelRecord.value", + "aliases":[], "size":11, "logicalType":"decimal", "precision":25, @@ -312,6 +315,7 @@ mod tests { fields: vec![RecordField { name: "value", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::Complex(decimal), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -346,12 +350,14 @@ mod tests { RecordField { name: "value", doc: None, + aliases: vec![], r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, }, RecordField { name: "next", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Ref("LongList")), @@ -405,6 +411,7 @@ mod tests { RecordField { name: "id", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -414,6 +421,7 @@ mod tests { RecordField { name: "timestamp_col", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::Type(timestamp), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -466,6 +474,7 @@ mod tests { RecordField { name: "clientHash", doc: None, + aliases: vec![], r#type: Schema::Complex(ComplexType::Fixed(Fixed { name: "MD5", namespace: None, @@ -478,6 +487,7 @@ mod tests { RecordField { name: "clientProtocol", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), @@ -487,12 +497,14 @@ mod tests { RecordField { name: "serverHash", doc: None, + aliases: vec![], r#type: Schema::TypeName(TypeName::Ref("MD5")), default: None, }, RecordField 
{ name: "meta", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::Complex(ComplexType::Map(Map { @@ -526,5 +538,35 @@ mod tests { }; assert_eq!(t, uuid); + + // Ensure aliases are parsed + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "Foo", + "aliases": ["Bar"], + "fields" : [ + {"name":"id","aliases":["uid"],"type":"int"} + ] + }"#, + ) + .unwrap(); + + let with_aliases = Schema::Complex(ComplexType::Record(Record { + name: "Foo", + namespace: None, + doc: None, + aliases: vec!["Bar"], + fields: vec![RecordField { + name: "id", + aliases: vec!["uid"], + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + default: None, + }], + attributes: Default::default(), + })); + + assert_eq!(schema, with_aliases); } } From 92e105ffda634c3732c27b6a529169847e209881 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Wed, 15 Jan 2025 14:32:44 -0600 Subject: [PATCH 18/38] * Fixed size issue in codec.rs fixed type decimal * Add test_fixed_length_decimal to ensure fix works Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 10 +--------- arrow-avro/src/reader/mod.rs | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index fb8c765794c1..bcd11bb56c31 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -349,14 +349,6 @@ fn make_data_type<'a>( .ok_or_else(|| { ArrowError::ParseError("Decimal requires precision".to_string()) })?; - let size_val = f - .attributes - .additional - .get("size") - .and_then(|v| v.as_u64()) - .ok_or_else(|| { - ArrowError::ParseError("Decimal requires size".to_string()) - })?; let scale = f .attributes .additional @@ -369,7 +361,7 @@ fn make_data_type<'a>( codec: Codec::Decimal( precision as usize, Some(scale.unwrap_or(0) as usize), - Some(size_val as usize), + Some(size as usize), ), }; resolver.register(f.name, namespace, field.clone()); diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 12fa67d9c8e3..20ab0ad88a29 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -76,11 +76,12 @@ fn read_blocks(mut reader: R) -> impl Iterator = (1..=24).map(|n| n as i128 * 100).collect(); + let array = Decimal128Array::from_iter_values(decimal_values) + .with_precision_and_scale(25, 2) + .unwrap(); + let mut meta = HashMap::new(); + meta.insert("precision".to_string(), "25".to_string()); + meta.insert("scale".to_string(), "2".to_string()); + let field_with_meta = + Field::new("value", DataType::Decimal128(25, 2), true).with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let expected_batch = RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(array)]) + .expect("Failed to build expected RecordBatch"); + assert_eq!( + actual_batch, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data" + ); + } } From 51586d8f39e751b9e696b4654c507aad875c628f Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Thu, 16 Jan 2025 18:36:13 -0600 Subject: [PATCH 19/38] * Removed Avro writer specific `maybe_add_namespace` function from codec.rs Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index bcd11bb56c31..8e9f74494d73 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -81,23 +81,6 @@ 
impl AvroDataType { } } -/// If this is a named complex type (Record, Enum, Fixed), attach `namespace` -/// from `dt.metadata["namespace"]` if present. Otherwise, return as-is. -fn maybe_add_namespace<'a>(mut schema: Schema<'a>, dt: &'a AvroDataType) -> Schema<'a> { - if let Some(ns_str) = dt.metadata.get("namespace") { - if let Schema::Complex(ref mut c) = schema { - match c { - ComplexType::Record(r) => r.namespace = Some(ns_str), - ComplexType::Enum(e) => e.namespace = Some(ns_str), - ComplexType::Fixed(f) => f.namespace = Some(ns_str), - // Arrays and Maps do not have a namespace field, so do nothing - _ => {} - } - } - } - schema -} - /// A named [`AvroDataType`] #[derive(Debug, Clone)] pub struct AvroField { From a51b2025401343eaa3ad5b3317fa8771950b0751 Mon Sep 17 00:00:00 2001 From: Sven Cowart Date: Thu, 16 Jan 2025 21:33:26 -0500 Subject: [PATCH 20/38] Aligns mapping between avro and codec and tests all paths of arrow_field_to_avro_field --- arrow-avro/src/codec.rs | 385 ++++++++++++++++++++++---------- arrow-avro/src/reader/record.rs | 26 +-- 2 files changed, 281 insertions(+), 130 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 8e9f74494d73..d42840e323fd 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,6 +16,7 @@ // under the License. use crate::schema::{ComplexType, PrimitiveType, Schema, TypeName}; +use arrow_array::Array; use arrow_schema::DataType::*; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, @@ -137,7 +138,7 @@ pub enum Codec { Float32, Float64, Binary, - Utf8, + String, Decimal(usize, Option, Option), Uuid, Date32, @@ -147,13 +148,13 @@ pub enum Codec { TimestampMillis(bool), /// TimestampMicros(is_utc) TimestampMicros(bool), - Fixed(i32), - List(Arc), - Struct(Arc<[AvroField]>), Duration, - /// In Arrow, use Dictionary(Int32, Utf8) for Enum. - Enum(Vec), + Record(Arc<[AvroField]>), + /// In Arrow, use Dictionary(Utf8, Int32) for Enum. 
+ Enum(Arc<[String]>, Arc<[i32]>), + Array(Arc), Map(Arc), + Fixed(i32), } impl Codec { @@ -167,7 +168,7 @@ impl Codec { Self::Float32 => Float32, Self::Float64 => Float64, Self::Binary => Binary, - Self::Utf8 => Utf8, + Self::String => Utf8, Self::Decimal(precision, scale, size) => match size { Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), @@ -190,13 +191,12 @@ impl Codec { Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } Self::Duration => Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => FixedSizeBinary(*size), - Self::List(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), - Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), - Self::Enum(_symbols) => { - // Produce a Dictionary type with index = Int32, value = Utf8 - Dictionary(Box::new(Int32), Box::new(Utf8)) + Self::Record(f) => Struct(f.iter().map(|x| x.field()).collect()), + Self::Enum(symbols, values) => { + // Produce a Dictionary type with index = Utf8, value = Int32 + Dictionary(Box::new(Utf8), Box::new(Int32)) } + Self::Array(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), Self::Map(values) => Map( Arc::new(Field::new( "entries", @@ -208,6 +208,7 @@ impl Codec { )), false, ), + Self::Fixed(size) => FixedSizeBinary(*size), } } } @@ -222,7 +223,7 @@ impl From for Codec { PrimitiveType::Float => Self::Float32, PrimitiveType::Double => Self::Float64, PrimitiveType::Bytes => Self::Binary, - PrimitiveType::String => Self::Utf8, + PrimitiveType::String => Self::String, } } } @@ -304,7 +305,7 @@ fn make_data_type<'a>( .collect::>()?; let field = AvroDataType { nullability: None, - codec: Codec::Struct(fields), + codec: Codec::Record(fields), metadata: r.attributes.field_metadata(), }; resolver.register(r.name, namespace, field.clone()); @@ -315,7 +316,7 @@ fn make_data_type<'a>( Ok(AvroDataType { nullability: None, metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), + codec: Codec::Array(Arc::new(field)), }) } ComplexType::Fixed(f) => { @@ -360,15 +361,13 @@ fn make_data_type<'a>( } } ComplexType::Enum(e) => { - let symbols = e - .symbols - .iter() - .map(|sym| sym.to_string()) - .collect::>(); let field = AvroDataType { nullability: None, metadata: e.attributes.field_metadata(), - codec: Codec::Enum(symbols), + codec: Codec::Enum( + Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), + Arc::from(vec![]), + ), }; resolver.register(e.name, namespace, field.clone()); Ok(field) @@ -428,7 +427,7 @@ fn make_data_type<'a>( None, ); } - (Some("uuid"), c @ Codec::Utf8) => *c = Codec::Uuid, + (Some("uuid"), c @ Codec::String) => *c = Codec::Uuid, (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, @@ -487,8 +486,10 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Int64 => Codec::Int64, Float32 => Codec::Float32, Float64 => Codec::Float64, - Utf8 => Codec::Utf8, Binary | LargeBinary => Codec::Binary, + Utf8 => Codec::String, + Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), + Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. // It is unsafe to assume all FixedSizeBinary(16) are UUIDs. 
// Uuid => Codec::Uuid, @@ -503,17 +504,33 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { Codec::TimestampMicros(true) } - FixedSizeBinary(n) => Codec::Fixed(*n), - Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), - Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), - Dictionary(index_type, value_type) => { - if let Utf8 = **value_type { - Codec::Enum(vec![]) + Interval(IntervalUnit::MonthDayNano) => Codec::Duration, + Struct(child_fields) => { + let avro_fields: Vec = child_fields + .iter() + .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) + .collect(); + Codec::Record(Arc::from(avro_fields)) + } + Dictionary(symbol_type, value_type) => { + if let Utf8 = **symbol_type { + Codec::Enum( + Arc::from(Vec::::new()), + Arc::from(Vec::::new()), + ) } else { // Fallback to Utf8 - Codec::Utf8 + Codec::String } } + List(field) => { + let sub_codec = arrow_type_to_codec(field.data_type()); + Codec::Array(Arc::new(AvroDataType { + nullability: field.is_nullable().then_some(Nullability::NullFirst), + metadata: field.metadata().clone(), + codec: sub_codec, + })) + } Map(field, _keys_sorted) => { if let Struct(child_fields) = field.data_type() { let value_field = &child_fields[1]; @@ -524,17 +541,11 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { codec: sub_codec, })) } else { - Codec::Map(Arc::new(AvroDataType::from_codec(Codec::Utf8))) + Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) } } - Struct(child_fields) => { - let avro_fields: Vec = child_fields - .iter() - .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) - .collect(); - Codec::Struct(Arc::from(avro_fields)) - } - _ => Codec::Utf8, + FixedSizeBinary(n) => Codec::Fixed(*n), + _ => Codec::String, } } @@ -561,52 +572,6 @@ mod tests { assert!(!arrow_field.is_nullable()); } - #[test] - fn test_arrow_field_to_avro_field() { - let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( - "namespace".to_string(), - "arrow_meta_ns".to_string(), - )])); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert_eq!(avro_field.name(), "test_meta"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); - let expected_str = format!("{:?}", &Codec::Utf8); - assert_eq!(actual_str, expected_str); - let actual_str = format!("{:?}", avro_field.data_type().nullability()); - let expected_str = format!("{:?}", Some(Nullability::NullFirst)); - assert_eq!(actual_str, expected_str); - assert_eq!( - avro_field.data_type().metadata.get("namespace"), - Some(&"arrow_meta_ns".to_string()) - ); - } - - #[test] - fn test_codec_struct() { - let fields = Arc::from(vec![ - AvroField { - name: "a".to_string(), - data_type: AvroDataType::from_codec(Codec::Boolean), - }, - AvroField { - name: "b".to_string(), - data_type: AvroDataType::from_codec(Codec::Float64), - }, - ]); - let codec = Codec::Struct(fields); - let dt = codec.data_type(); - match dt { - Struct(fields) => { - assert_eq!(fields.len(), 2); - assert_eq!(fields[0].name(), "a"); - assert_eq!(fields[0].data_type(), &Boolean); - assert_eq!(fields[1].name(), "b"); - assert_eq!(fields[1].data_type(), &Float64); - } - _ => panic!("Expected Struct data type"), - } - } - #[test] fn test_codec_fixedsizebinary() { let codec = Codec::Fixed(12); @@ -618,58 +583,244 @@ mod tests { } #[test] - fn test_utc_timestamp_millis() { + fn test_arrow_field_to_avro_field() { + let arrow_field = 
Field::new("Null", Null, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Null)); + + let arrow_field = Field::new("Boolean", Boolean, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Boolean)); + + let arrow_field = Field::new("Int32", Int32, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Int32)); + + let arrow_field = Field::new("Int64", Int64, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Int64)); + + let arrow_field = Field::new("Float32", Float32, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Float32)); + + let arrow_field = Field::new("Float64", Float64, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Float64)); + + let arrow_field = Field::new("Binary", Binary, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Binary)); + + let arrow_field = Field::new("Utf8", Utf8, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::String)); + + let arrow_field = Field::new("Decimal128", Decimal128(1, 2), true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!( + avro_field.data_type().codec(), + Codec::Decimal(1, Some(2), Some(16)) + )); + + let arrow_field = Field::new("Decimal256", Decimal256(1, 2), true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!( + avro_field.data_type().codec(), + Codec::Decimal(1, Some(2), Some(32)) + )); + + // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. 
+ // let arrow_field = Field::new("Uuid", FixedSizeBinary(16), true); + // let avro_field = arrow_field_to_avro_field(&arrow_field); + // let codec = avro_field.data_type().codec(); + // assert!( + // matches!(codec, Codec::Uuid), + // ); + + let arrow_field = Field::new("Date32", Date32, true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Date32)); + + let arrow_field = Field::new("Time32", Time32(TimeUnit::Millisecond), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::TimeMillis)); + + let arrow_field = Field::new("Time32", Time64(TimeUnit::Microsecond), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::TimeMicros)); + let arrow_field = Field::new( "utc_ts_ms", Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = avro_field.data_type().codec(); - assert!( - matches!(codec, Codec::TimestampMillis(true)), - "Expected Codec::TimestampMillis(true), got: {:?}", - codec - ); - } + assert!(matches!( + avro_field.data_type().codec(), + Codec::TimestampMillis(true) + )); - #[test] - fn test_utc_timestamp_micros() { let arrow_field = Field::new( "utc_ts_us", Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = avro_field.data_type().codec(); - assert!( - matches!(codec, Codec::TimestampMicros(true)), - "Expected Codec::TimestampMicros(true), got: {:?}", - codec - ); - } + assert!(matches!( + avro_field.data_type().codec(), + Codec::TimestampMicros(true) + )); - #[test] - fn test_local_timestamp_millis() { let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = avro_field.data_type().codec(); - assert!( - matches!(codec, Codec::TimestampMillis(false)), - "Expected Codec::TimestampMillis(false), got: {:?}", - codec + assert!(matches!( + avro_field.data_type().codec(), + Codec::TimestampMillis(false) + )); + + let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!( + avro_field.data_type().codec(), + Codec::TimestampMicros(false) + )); + + let arrow_field = Field::new("Interval", Interval(IntervalUnit::MonthDayNano), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Duration)); + + let arrow_field = Field::new( + "Struct", + Struct(Fields::from(vec![ + Field::new("a", Boolean, false), + Field::new("b", Float64, false), + ])), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + match avro_field.data_type().codec() { + Codec::Record(fields) => { + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "a"); + assert!(matches!(fields[0].data_type().codec(), Codec::Boolean)); + assert_eq!(fields[1].name(), "b"); + assert!(matches!(fields[1].data_type().codec(), Codec::Float64)); + } + _ => panic!("Expected Record data type"), + } + + let arrow_field = Field::new( + "DictionaryEnum", + Dictionary(Box::new(Utf8), Box::new(Int32)), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::Enum(_, _))); + + let 
arrow_field = Field::new( + "DictionaryString", + Dictionary(Box::new(Int32), Box::new(Boolean)), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert!(matches!(avro_field.data_type().codec(), Codec::String)); + + let field = Field::new("Utf8", Utf8, true); + let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + if let Codec::Array(avro_data_type) = avro_field.data_type().codec() { + assert!(matches!( + avro_data_type.nullability(), + Some(Nullability::NullFirst) + )); + assert_eq!(avro_data_type.metadata.len(), 0); + assert!(matches!(avro_data_type.codec(), Codec::String)); + } else { + panic!("Expected Codec::Array"); + } + + let field = Field::new("Utf8", Utf8, false); + let arrow_field = Field::new( + "Array with non-nullable items", + List(Arc::new(field)), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + if let Codec::Array(avro_data_type) = avro_field.data_type().codec() { + assert!(matches!(avro_data_type.nullability(), None)); + assert_eq!(avro_data_type.metadata.len(), 0); + assert!(matches!(avro_data_type.codec(), Codec::String)); + } else { + panic!("Expected Codec::Array"); + } + + let field = Field::new( + "Utf8", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + Field::new("value", Utf8, true), + ])), + true, + ); + let arrow_field = Field::new("Map with nullable items", Map(Arc::new(field), true), true); + let avro_field = arrow_field_to_avro_field(&arrow_field); + if let Codec::Map(avro_data_type) = avro_field.data_type().codec() { + assert!(matches!( + avro_data_type.nullability(), + Some(Nullability::NullFirst) + )); + assert_eq!(avro_data_type.metadata.len(), 0); + assert!(matches!(avro_data_type.codec(), Codec::String)); + } else { + panic!("Expected Codec::Map"); + } + + let field = Field::new( + "Utf8", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + Field::new("value", Utf8, false), + ])), + false, ); + let arrow_field = Field::new( + "Map with non-nullable items", + Map(Arc::new(field), false), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + if let Codec::Map(avro_data_type) = avro_field.data_type().codec() { + assert!(matches!(avro_data_type.nullability(), None,)); + assert_eq!(avro_data_type.metadata.len(), 0); + assert!(matches!(avro_data_type.codec(), Codec::String)); + } else { + panic!("Expected Codec::Map"); + } + + let arrow_field = Field::new("FixedSizeBinary", FixedSizeBinary(8), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + assert!(matches!(codec, Codec::Fixed(8))); } #[test] - fn test_local_timestamp_micros() { - let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); + fn test_arrow_field_to_avro_field_meta_namespace() { + let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( + "namespace".to_string(), + "arrow_meta_ns".to_string(), + )])); let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = avro_field.data_type().codec(); - assert!( - matches!(codec, Codec::TimestampMicros(false)), - "Expected Codec::TimestampMicros(false), got: {:?}", - codec + assert_eq!(avro_field.name(), "test_meta"); + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::String); + assert_eq!(actual_str, expected_str); + let actual_str = format!("{:?}", 
avro_field.data_type().nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + assert_eq!( + avro_field.data_type().metadata.get("namespace"), + Some(&"arrow_meta_ns".to_string()) ); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index e4c2f372d25e..4154ed75ddc0 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -122,7 +122,7 @@ enum Decoder { /// Avro union that includes `null` Nullable(Nullability, NullBufferBuilder, Box), /// Avro `enum` => Dictionary(int32 -> string) - Enum(Vec, Vec), + Enum(Arc<[String]>, Vec), /// Avro `map` Map( FieldRef, @@ -155,7 +155,7 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + Codec::String => Self::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), @@ -175,7 +175,7 @@ impl Decoder { } Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::List(item) => { + Codec::Array(item) => { let item_decoder = Box::new(Self::try_new(item)?); Self::List( Arc::new(item.field_with_name("item")), @@ -183,7 +183,7 @@ impl Decoder { item_decoder, ) } - Codec::Struct(avro_fields) => { + Codec::Record(avro_fields) => { let mut arrow_fields = Vec::with_capacity(avro_fields.len()); let mut decoders = Vec::with_capacity(avro_fields.len()); for avro_field in avro_fields.iter() { @@ -193,8 +193,8 @@ impl Decoder { } Self::Record(arrow_fields.into(), decoders) } - Codec::Enum(symbols) => { - Self::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) + Codec::Enum(keys, values) => { + Self::Enum(Arc::clone(keys), Vec::with_capacity(values.len())) } Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( @@ -946,7 +946,7 @@ mod tests { #[test] fn test_enum_decoding() { let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new([]), Arc::new([]))); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 let mut data = Vec::new(); @@ -977,8 +977,8 @@ mod tests { fn test_enum_decoding_with_nulls() { // Union => [Enum(...), null] // "child" => branch_index=0 => [0x00], "null" => 1 => [0x02] - let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([]))); let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); let mut nullable_decoder = Decoder::Nullable( Nullability::NullFirst, @@ -1025,7 +1025,7 @@ mod tests { // ------------------- #[test] fn test_map_decoding_one_entry() { - let value_type = AvroDataType::from_codec(Codec::Utf8); + let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); // Encode a single map with one entry: {"hello": "world"} @@ -1062,7 +1062,7 @@ mod tests { #[test] fn test_map_decoding_empty() { // block_count=0 => empty map - let value_type = AvroDataType::from_codec(Codec::Utf8); 
+ let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); // Encode an empty map => block_count=0 => [0x00] @@ -1195,7 +1195,7 @@ mod tests { // // 2. flush => should yield 2-element array => first row has 2 items, second row has 0 items let item_dt = AvroDataType::from_codec(Codec::Int32); - let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); let mut decoder = Decoder::try_new(&list_dt).unwrap(); // Row1 => block_count=2 => item=10 => item=20 => block_count=0 => end // - 2 => zigzag => [0x04] @@ -1236,7 +1236,7 @@ mod tests { // Then read block_size => let's pretend it's 9 bytes, etc. Then the items. // Then a block_count=0 => done let item_dt = AvroDataType::from_codec(Codec::Int32); - let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); let mut decoder = Decoder::try_new(&list_dt).unwrap(); // block_count=-3 => zigzag => (-3 << 1) ^ (-3 >> 63) // => -6 ^ -1 => ... From 5585a87a7860522296018e6359b2e899d9ca8c26 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Fri, 17 Jan 2025 19:22:18 -0600 Subject: [PATCH 21/38] * Organized Types to match Avro spec ordering * Fixed lint issues Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 112 +++++----- arrow-avro/src/reader/record.rs | 348 ++++++++++++++++++-------------- 2 files changed, 256 insertions(+), 204 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index d42840e323fd..06ca4649c54f 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -131,6 +131,8 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { /// #[derive(Debug, Clone)] pub enum Codec { + /// Primitive Types + /// https://avro.apache.org/docs/1.11.1/specification/#primitive-types Null, Boolean, Int32, @@ -139,28 +141,31 @@ pub enum Codec { Float64, Binary, String, + /// Complex Types + /// https://avro.apache.org/docs/1.11.1/specification/#complex-types + Record(Arc<[AvroField]>), + Enum(Arc<[String]>, Arc<[i32]>), + Array(Arc), + Map(Arc), + Fixed(i32), + /// Logical Types + /// https://avro.apache.org/docs/1.11.1/specification/#logical-types Decimal(usize, Option, Option), Uuid, Date32, TimeMillis, TimeMicros, - /// TimestampMillis(is_utc) TimestampMillis(bool), - /// TimestampMicros(is_utc) TimestampMicros(bool), Duration, - Record(Arc<[AvroField]>), - /// In Arrow, use Dictionary(Utf8, Int32) for Enum. 
- Enum(Arc<[String]>, Arc<[i32]>), - Array(Arc), - Map(Arc), - Fixed(i32), } impl Codec { /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { + /// Primitive Types + /// Self::Null => Null, Self::Boolean => Boolean, Self::Int32 => Int32, @@ -169,6 +174,25 @@ impl Codec { Self::Float64 => Float64, Self::Binary => Binary, Self::String => Utf8, + /// Complex Types + /// + Self::Record(f) => Struct(f.iter().map(|x| x.field()).collect()), + Self::Enum(symbols, values) => Dictionary(Box::new(Utf8), Box::new(Int32)), + Self::Array(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), + Self::Map(values) => Map( + Arc::new(Field::new( + "entries", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + values.field_with_name("value"), + ])), + false, + )), + false, + ), + Self::Fixed(size) => FixedSizeBinary(*size), + /// Logical Types + /// Self::Decimal(precision, scale, size) => match size { Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), @@ -179,7 +203,8 @@ impl Codec { } _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), }, - // arrow-rs does not support the UUID Canonical Extension Type yet, so this is a temporary workaround. + // arrow-rs does not support the UUID Canonical Extension Type yet, + // so this is a temporary workaround. Self::Uuid => FixedSizeBinary(16), Self::Date32 => Date32, Self::TimeMillis => Time32(TimeUnit::Millisecond), @@ -191,24 +216,6 @@ impl Codec { Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } Self::Duration => Interval(IntervalUnit::MonthDayNano), - Self::Record(f) => Struct(f.iter().map(|x| x.field()).collect()), - Self::Enum(symbols, values) => { - // Produce a Dictionary type with index = Utf8, value = Int32 - Dictionary(Box::new(Utf8), Box::new(Int32)) - } - Self::Array(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), - Self::Map(values) => Map( - Arc::new(Field::new( - "entries", - Struct(Fields::from(vec![ - Field::new("key", Utf8, false), - values.field_with_name("value"), - ])), - false, - )), - false, - ), - Self::Fixed(size) => FixedSizeBinary(*size), } } } @@ -263,12 +270,16 @@ fn make_data_type<'a>( resolver: &mut Resolver<'a>, ) -> Result { match schema { + /// Primitive Types + /// Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { nullability: None, metadata: Default::default(), codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), + /// Complex Types + /// Schema::Union(f) => { // Special case the common case of nullable primitives or single-type let null = f @@ -382,8 +393,9 @@ fn make_data_type<'a>( Ok(field) } }, + /// Logical Types + /// Schema::Type(t) => { - // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; match (t.attributes.logical_type, &mut field.codec) { @@ -480,6 +492,8 @@ pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { /// Maps an Arrow `DataType` to a `Codec`. 
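// A condensed, standalone sketch of the Map arm above: Avro map keys are
// always non-null strings, so the Arrow shape is Map<entries: Struct<key, value>>
// with keys_sorted = false. Assumes only the arrow-schema crate; the name
// `avro_map_type` is illustrative and not part of this patch.
use std::sync::Arc;
use arrow_schema::{DataType, Field, Fields};

fn avro_map_type(value: Field) -> DataType {
    let entries = Field::new(
        "entries",
        DataType::Struct(Fields::from(vec![
            Field::new("key", DataType::Utf8, false), // Avro map keys are non-null strings
            value,                                    // field built from the map's value schema
        ])),
        false,
    );
    DataType::Map(Arc::new(entries), false) // keys_sorted = false
}

fn main() {
    let t = avro_map_type(Field::new("value", DataType::Int64, true));
    assert!(matches!(t, DataType::Map(_, false)));
}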
fn arrow_type_to_codec(dt: &DataType) -> Codec { match dt { + /// Primitive Types + /// Null => Codec::Null, Boolean => Codec::Boolean, Int8 | Int16 | Int32 => Codec::Int32, @@ -488,23 +502,8 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Float64 => Codec::Float64, Binary | LargeBinary => Codec::Binary, Utf8 => Codec::String, - Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), - Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), - // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. - // It is unsafe to assume all FixedSizeBinary(16) are UUIDs. - // Uuid => Codec::Uuid, - Date32 => Codec::Date32, - Time32(TimeUnit::Millisecond) => Codec::TimeMillis, - Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), - Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), - Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMillis(true) - } - Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMicros(true) - } - Interval(IntervalUnit::MonthDayNano) => Codec::Duration, + /// Complex Types + /// Struct(child_fields) => { let avro_fields: Vec = child_fields .iter() @@ -545,6 +544,25 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { } } FixedSizeBinary(n) => Codec::Fixed(*n), + /// Logical Types + /// + Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), + Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), + // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. + // It is unsafe to assume all FixedSizeBinary(16) are UUIDs. + // Uuid => Codec::Uuid, + Date32 => Codec::Date32, + Time32(TimeUnit::Millisecond) => Codec::TimeMillis, + Time64(TimeUnit::Microsecond) => Codec::TimeMicros, + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), + Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMillis(true) + } + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMicros(true) + } + Interval(IntervalUnit::MonthDayNano) => Codec::Duration, _ => Codec::String, } } @@ -748,7 +766,7 @@ mod tests { ); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Array(avro_data_type) = avro_field.data_type().codec() { - assert!(matches!(avro_data_type.nullability(), None)); + assert!(avro_data_type.nullability().is_none()); assert_eq!(avro_data_type.metadata.len(), 0); assert!(matches!(avro_data_type.codec(), Codec::String)); } else { @@ -791,7 +809,7 @@ mod tests { ); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Map(avro_data_type) = avro_field.data_type().codec() { - assert!(matches!(avro_data_type.nullability(), None,)); + assert!(avro_data_type.nullability().is_none()); assert_eq!(avro_data_type.metadata.len(), 0); assert!(matches!(avro_data_type.codec(), Codec::String)); } else { diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 4154ed75ddc0..85809e0dbc5e 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -85,6 +85,8 @@ impl RecordDecoder { /// Decoder for Avro data of various shapes. 
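// The Decoder below consumes Avro's variable-length zigzag integers; the test
// comments in this patch spell the mapping out (1 -> 2, 0 -> 0, 2 -> 4, and
// -3 => (-3 << 1) ^ (-3 >> 63) = 5). A minimal standalone sketch of that
// encoding, assuming i64 inputs:
fn zigzag_encode(n: i64) -> u64 {
    ((n << 1) ^ (n >> 63)) as u64
}

fn zigzag_decode(z: u64) -> i64 {
    ((z >> 1) as i64) ^ -((z & 1) as i64)
}

fn main() {
    assert_eq!(zigzag_encode(1), 2);
    assert_eq!(zigzag_encode(-3), 5);
    assert_eq!(zigzag_decode(4), 2);
    assert_eq!(zigzag_decode(5), -3);
}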
#[derive(Debug)] enum Decoder { + /// Primitive Types + /// /// Avro `null` Null(usize), /// Avro `boolean` @@ -97,32 +99,18 @@ enum Decoder { Float32(Vec), /// Avro `double` => f64 Float64(Vec), - /// Avro `date` => Date32 - Date32(Vec), - /// Avro `time-millis` => Time32(Millisecond) - TimeMillis(Vec), - /// Avro `time-micros` => Time64(Microsecond) - TimeMicros(Vec), - /// Avro `timestamp-millis` (bool = UTC?) - TimestampMillis(bool, Vec), - /// Avro `timestamp-micros` (bool = UTC?) - TimestampMicros(bool, Vec), /// Avro `bytes` => Arrow Binary Binary(OffsetBufferBuilder, Vec), /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), - /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` - Fixed(i32, Vec), - /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) - Interval(Vec), - /// Avro `array` - List(FieldRef, OffsetBufferBuilder, Box), + /// Complex Types + /// /// Avro `record` Record(Fields, Vec), - /// Avro union that includes `null` - Nullable(Nullability, NullBufferBuilder, Box), /// Avro `enum` => Dictionary(int32 -> string) Enum(Arc<[String]>, Vec), + /// Avro `array` + List(FieldRef, OffsetBufferBuilder, Box), /// Avro `map` Map( FieldRef, @@ -132,8 +120,26 @@ enum Decoder { Box, usize, ), + /// Avro union that includes `null` + Nullable(Nullability, NullBufferBuilder, Box), + /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` + Fixed(i32, Vec), + /// Logical Types + /// /// Avro decimal => Arrow decimal Decimal(usize, Option, Option, DecimalBuilder), + /// Avro `date` => Date32 + Date32(Vec), + /// Avro `time-millis` => Time32(Millisecond) + TimeMillis(Vec), + /// Avro `time-micros` => Time64(Microsecond) + TimeMicros(Vec), + /// Avro `timestamp-millis` (bool = UTC?) + TimestampMillis(bool, Vec), + /// Avro `timestamp-micros` (bool = UTC?) + TimestampMicros(bool, Vec), + /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) + Interval(Vec), } impl Decoder { @@ -145,6 +151,8 @@ impl Decoder { /// Create a `Decoder` from an [`AvroDataType`]. 
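// A minimal sketch of what the Enum variant above ultimately yields once
// flushed: the decoded i32 indices become dictionary keys over the symbol
// table, mirroring the RED/GREEN/BLUE test data in this patch. Assumes only
// the arrow-array crate.
use std::sync::Arc;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ArrayRef, DictionaryArray, Int32Array, StringArray};

fn main() {
    let keys = Int32Array::from(vec![1, 0, 2]);
    let symbols: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"]));
    let dict = DictionaryArray::<Int32Type>::try_new(keys, symbols).unwrap();
    assert_eq!(dict.len(), 3);
    assert_eq!(dict.keys().value(0), 1); // row 0 decodes to "GREEN"
}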
fn try_new(data_type: &AvroDataType) -> Result { let decoder = match data_type.codec() { + /// Primitive Types + /// Codec::Null => Self::Null(0), Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), @@ -159,30 +167,9 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Decimal(precision, scale, size) => { - let builder = DecimalBuilder::new(*precision, *scale, *size)?; - Self::Decimal(*precision, *scale, *size, builder) - } - Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { - Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) - } - Codec::TimestampMicros(is_utc) => { - Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) - } - Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Array(item) => { - let item_decoder = Box::new(Self::try_new(item)?); - Self::List( - Arc::new(item.field_with_name("item")), - OffsetBufferBuilder::new(DEFAULT_CAPACITY), - item_decoder, - ) - } + + /// Complex Types + /// Codec::Record(avro_fields) => { let mut arrow_fields = Vec::with_capacity(avro_fields.len()); let mut decoders = Vec::with_capacity(avro_fields.len()); @@ -196,6 +183,14 @@ impl Decoder { Codec::Enum(keys, values) => { Self::Enum(Arc::clone(keys), Vec::with_capacity(values.len())) } + Codec::Array(item) => { + let item_decoder = Box::new(Self::try_new(item)?); + Self::List( + Arc::new(item.field_with_name("item")), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + item_decoder, + ) + } Codec::Map(value_type) => { let map_field = Arc::new(ArrowField::new( "entries", @@ -214,6 +209,25 @@ impl Decoder { 0, ) } + Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), + + /// Logical Types + /// + Codec::Decimal(precision, scale, size) => { + let builder = DecimalBuilder::new(*precision, *scale, *size)?; + Self::Decimal(*precision, *scale, *size, builder) + } + Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimestampMillis(is_utc) => { + Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::TimestampMicros(is_utc) => { + Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), }; // Wrap in Nullable if needed @@ -230,6 +244,8 @@ impl Decoder { /// Append a null to this decoder. 
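// The Nullable wrapper produced above pairs a NullBufferBuilder with the
// child decoder: validity bits go into the builder while the child receives a
// placeholder value for each null. A minimal sketch of that bookkeeping,
// assuming only the arrow-buffer crate:
use arrow_buffer::NullBufferBuilder;

fn main() {
    let mut nulls = NullBufferBuilder::new(3);
    nulls.append_non_null(); // value row
    nulls.append_null();     // null row; the child decoder still gets a zeroed placeholder
    nulls.append_non_null();
    let mask = nulls.finish().expect("a null was appended");
    assert!(mask.is_valid(0) && !mask.is_valid(1) && mask.is_valid(2));
}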
fn append_null(&mut self) { match self { + /// Primitive & Date Logical Types + /// Self::Null(n) => *n += 1, Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), @@ -240,35 +256,39 @@ impl Decoder { Self::Float32(v) => v.push(0.0), Self::Float64(v) => v.push(0.0), Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), - Self::Fixed(fsize, buf) => { - // For a null, push `fsize` zeroed bytes - buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); - } - Self::Interval(intervals) => { - // null => store a 12-byte zero => months=0, days=0, nanos=0 - intervals.push(IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 0, - }); - } - Self::List(_, off, child) => { - off.push_length(0); - child.append_null(); - } + /// Complex Types + /// Self::Record(_, children) => { for c in children.iter_mut() { c.append_null(); } } Self::Enum(_, indices) => indices.push(0), + Self::List(_, off, child) => { + off.push_length(0); + child.append_null(); + } Self::Map(_, key_off, map_off, _, _, entry_count) => { key_off.push_length(0); map_off.push_length(*entry_count); } + Self::Fixed(fsize, buf) => { + // For a null, push `fsize` zeroed bytes + buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); + } + /// Non-Date Logical Types + /// Self::Decimal(_, _, _, builder) => { let _ = builder.append_null(); } + Self::Interval(intervals) => { + // null => store a 12-byte zero => months=0, days=0, nanos=0 + intervals.push(IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 0, + }); + } Self::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } } } @@ -276,15 +296,12 @@ impl Decoder { /// Decode a single row of data from `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { + /// Primitive Types + /// Self::Null(count) => *count += 1, Self::Boolean(values) => values.append(buf.get_bool()?), Self::Int32(values) => values.push(buf.get_int()?), - Self::Date32(values) => values.push(buf.get_int()?), Self::Int64(values) => values.push(buf.get_long()?), - Self::TimeMillis(values) => values.push(buf.get_int()?), - Self::TimeMicros(values) => values.push(buf.get_long()?), - Self::TimestampMillis(_, values) => values.push(buf.get_long()?), - Self::TimestampMicros(_, values) => values.push(buf.get_long()?), Self::Float32(values) => values.push(buf.get_float()?), Self::Float64(values) => values.push(buf.get_double()?), Self::Binary(off, data) | Self::String(off, data) => { @@ -292,28 +309,27 @@ impl Decoder { off.push_length(bytes.len()); data.extend_from_slice(bytes); } - Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), - Self::Interval(intervals) => { - let raw = buf.get_fixed(12)?; - let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); - let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); - let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap()); - let nanos = millis as i64 * 1_000_000; - let val = IntervalMonthDayNano { - months, - days, - nanoseconds: nanos, - }; - intervals.push(val); + /// Complex Types + /// + Self::Record(_, children) => { + for c in children.iter_mut() { + c.decode(buf)?; + } } + Self::Enum(_, indices) => indices.push(buf.get_int()?), Self::List(_, off, child) => { let total_items = read_array_blocks(buf, |b| child.decode(b))?; off.push_length(total_items); } - Self::Record(_, children) => { - for c in children.iter_mut() { - c.decode(buf)?; - } + Self::Map(_, key_off, map_off, 
key_data, val_decoder, entry_count) => { + let newly_added = read_map_blocks(buf, |b| { + let kb = b.get_bytes()?; + key_off.push_length(kb.len()); + key_data.extend_from_slice(kb); + val_decoder.decode(b) + })?; + *entry_count += newly_added; + map_off.push_length(*entry_count); } Self::Nullable(_, nulls, child) => match buf.get_int()? { 0 => { @@ -330,17 +346,9 @@ impl Decoder { ))); } }, - Self::Enum(_, indices) => indices.push(buf.get_int()?), - Self::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { - let newly_added = read_map_blocks(buf, |b| { - let kb = b.get_bytes()?; - key_off.push_length(kb.len()); - key_data.extend_from_slice(kb); - val_decoder.decode(b) - })?; - *entry_count += newly_added; - map_off.push_length(*entry_count); - } + Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), + /// Logical Types + /// Self::Decimal(_, _, size, builder) => { let bytes = match *size { Some(sz) => buf.get_fixed(sz)?, @@ -348,6 +356,24 @@ impl Decoder { }; builder.append_bytes(bytes)?; } + Self::Date32(values) => values.push(buf.get_int()?), + Self::TimeMillis(values) => values.push(buf.get_int()?), + Self::TimeMicros(values) => values.push(buf.get_long()?), + Self::TimestampMillis(_, values) => values.push(buf.get_long()?), + Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + Self::Interval(intervals) => { + let raw = buf.get_fixed(12)?; + let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); + let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap()); + let nanos = millis as i64 * 1_000_000; + let val = IntervalMonthDayNano { + months, + days, + nanoseconds: nanos, + }; + intervals.push(val); + } } Ok(()) } @@ -355,11 +381,8 @@ impl Decoder { /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. 
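// The Interval arm above unpacks Avro's `duration` logical type: a fixed(12)
// holding three little-endian 32-bit counts (months, days, milliseconds),
// with milliseconds widened to nanoseconds. A standalone sketch; the name
// `decode_duration` is illustrative:
fn decode_duration(raw: &[u8; 12]) -> (i32, i32, i64) {
    let months = i32::from_le_bytes(raw[0..4].try_into().unwrap());
    let days = i32::from_le_bytes(raw[4..8].try_into().unwrap());
    let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap());
    (months, days, millis as i64 * 1_000_000)
}

fn main() {
    let raw = [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0]; // 1 month, 2 days, 3 ms
    assert_eq!(decode_duration(&raw), (1, 2, 3_000_000));
}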
fn flush(&mut self, nulls: Option) -> Result { match self { - // For a nullable wrapper => flush the child with the built null buffer - Self::Nullable(_, nb, child) => { - let mask = nb.finish(); - child.flush(mask) - } + /// Primitive Types + /// // Null => produce NullArray Self::Null(len) => { let count = std::mem::replace(len, 0); @@ -385,28 +408,6 @@ impl Decoder { let arr = flush_primitive::(vals, nulls); Ok(Arc::new(arr)) } - // time-millis => Time32Millisecond - Self::TimeMillis(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // time-micros => Time64Microsecond - Self::TimeMicros(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // timestamp-millis => TimestampMillisecond - Self::TimestampMillis(is_utc, vals) => { - let arr = flush_primitive::(vals, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); - Ok(Arc::new(arr)) - } - // timestamp-micros => TimestampMicrosecond - Self::TimestampMicros(is_utc, vals) => { - let arr = flush_primitive::(vals, nulls) - .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); - Ok(Arc::new(arr)) - } // float32 => flush to Float32Array Self::Float32(vals) => { let arr = flush_primitive::(vals, nulls); @@ -429,44 +430,9 @@ impl Decoder { let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - // Avro fixed => FixedSizeBinaryArray - Self::Fixed(fsize, raw) => { - let size = *fsize; - let buf: Buffer = flush_values(raw).into(); - let total_len = buf.len() / (size as usize); - let array = FixedSizeBinaryArray::try_new(size, buf, nulls) - .map_err(|e| ArrowError::ParseError(e.to_string()))?; - Ok(Arc::new(array)) - } - // Avro interval => IntervalMonthDayNanoType - Self::Interval(vals) => { - let data_len = vals.len(); - let mut builder = - PrimitiveBuilder::::with_capacity(data_len); - for v in vals.drain(..) 
{ - builder.append_value(v); - } - let arr = builder - .finish() - .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); - if let Some(nb) = nulls { - // "merge" the newly built array with the nulls - let arr_data = arr.into_data().into_builder().nulls(Some(nb)); - let arr_data = unsafe { arr_data.build_unchecked() }; - Ok(Arc::new(PrimitiveArray::::from( - arr_data, - ))) - } else { - Ok(Arc::new(arr)) - } - } - // Avro array => ListArray - Self::List(field, off, item_dec) => { - let child_arr = item_dec.flush(None)?; - let offsets = flush_offsets(off); - let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); - Ok(Arc::new(arr)) - } + + /// Complex Types + /// // Avro record => StructArray Self::Record(fields, children) => { let mut arrays = Vec::with_capacity(children.len()); @@ -493,6 +459,13 @@ impl Decoder { indices.clear(); // reset Ok(Arc::new(dict)) } + // Avro array => ListArray + Self::List(field, off, item_dec) => { + let child_arr = item_dec.flush(None)?; + let offsets = flush_offsets(off); + let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); + Ok(Arc::new(arr)) + } // Avro map => MapArray Self::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { let moff = flush_offsets(map_off); @@ -518,6 +491,18 @@ impl Decoder { *entry_count = 0; Ok(Arc::new(map_arr)) } + + // Avro fixed => FixedSizeBinaryArray + Self::Fixed(fsize, raw) => { + let size = *fsize; + let buf: Buffer = flush_values(raw).into(); + let total_len = buf.len() / (size as usize); + let array = FixedSizeBinaryArray::try_new(size, buf, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(array)) + } + /// Logical Types + /// // Avro decimal => Arrow decimal Self::Decimal(prec, sc, sz, builder) => { let precision = *prec; @@ -527,6 +512,55 @@ impl Decoder { let arr = old_builder.finish(nulls, precision, scale)?; Ok(arr) } + // time-millis => Time32Millisecond + Self::TimeMillis(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-micros => Time64Microsecond + Self::TimeMicros(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // timestamp-millis => TimestampMillisecond + Self::TimestampMillis(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // timestamp-micros => TimestampMicrosecond + Self::TimestampMicros(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) + } + // Avro interval => IntervalMonthDayNanoType + Self::Interval(vals) => { + let data_len = vals.len(); + let mut builder = + PrimitiveBuilder::::with_capacity(data_len); + for v in vals.drain(..) 
{ + builder.append_value(v); + } + let arr = builder + .finish() + .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + if let Some(nb) = nulls { + // "merge" the newly built array with the nulls + let arr_data = arr.into_data().into_builder().nulls(Some(nb)); + let arr_data = unsafe { arr_data.build_unchecked() }; + Ok(Arc::new(PrimitiveArray::::from( + arr_data, + ))) + } else { + Ok(Arc::new(arr)) + } + } + // For a nullable wrapper => flush the child with the built null buffer + Self::Nullable(_, nb, child) => { + let mask = nb.finish(); + child.flush(mask) + } } } } @@ -945,7 +979,7 @@ mod tests { // ------------------- #[test] fn test_enum_decoding() { - let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new([]), Arc::new([]))); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 From 26faf5b11bcc41b57c98dcee96e518ca2a6e99b6 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Fri, 17 Jan 2025 19:30:14 -0600 Subject: [PATCH 22/38] * Fixed failing `test_enum_decoding` test due to constructing Codec::Enum with an empty list of symbols. Signed-off-by: Connor Sanders --- arrow-avro/src/reader/record.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 85809e0dbc5e..b6a60f01a6c7 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -979,8 +979,8 @@ mod tests { // ------------------- #[test] fn test_enum_decoding() { - let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; - let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new([]), Arc::new([]))); + let symbols = Arc::new(["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]); + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols, Arc::new([]))); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 let mut data = Vec::new(); From cf45fcf7e0a90011bb725659b1e867360f5979b3 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Fri, 17 Jan 2025 19:52:12 -0600 Subject: [PATCH 23/38] * Optimized avro codec.rs data_type Self::Decimal case: - Reduced repeated casting to improve efficiency. 
- Reduced cyclomatic complexity Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 06ca4649c54f..83121ad3a3bb 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -193,16 +193,22 @@ impl Codec { Self::Fixed(size) => FixedSizeBinary(*size), /// Logical Types /// - Self::Decimal(precision, scale, size) => match size { - Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), - Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), - None if *precision <= DECIMAL128_MAX_PRECISION as usize - && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize => - { - Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + Self::Decimal(precision, scale, size) => { + let scale = scale.unwrap_or(0) as i8; + let precision = *precision as u8; + let is_256 = match *size { + Some(s) => s > 16, + None => { + (precision as usize) > DECIMAL128_MAX_PRECISION as usize + || (scale as usize) > DECIMAL128_MAX_SCALE as usize + } + }; + if is_256 { + Decimal256(precision, scale) + } else { + Decimal128(precision, scale) } - _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), - }, + } // arrow-rs does not support the UUID Canonical Extension Type yet, // so this is a temporary workaround. Self::Uuid => FixedSizeBinary(16), @@ -518,26 +524,21 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Arc::from(Vec::::new()), ) } else { - // Fallback to Utf8 Codec::String } } - List(field) => { - let sub_codec = arrow_type_to_codec(field.data_type()); - Codec::Array(Arc::new(AvroDataType { - nullability: field.is_nullable().then_some(Nullability::NullFirst), - metadata: field.metadata().clone(), - codec: sub_codec, - })) - } + List(field) => Codec::Array(Arc::new(AvroDataType { + nullability: field.is_nullable().then_some(Nullability::NullFirst), + metadata: field.metadata().clone(), + codec: arrow_type_to_codec(field.data_type()), + })), Map(field, _keys_sorted) => { if let Struct(child_fields) = field.data_type() { let value_field = &child_fields[1]; - let sub_codec = arrow_type_to_codec(value_field.data_type()); Codec::Map(Arc::new(AvroDataType { nullability: value_field.is_nullable().then_some(Nullability::NullFirst), metadata: value_field.metadata().clone(), - codec: sub_codec, + codec: arrow_type_to_codec(value_field.data_type()), })) } else { Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) From 75e0ad1bc0f09fc328b10cb63e591ecdafd5b447 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Wed, 29 Jan 2025 17:21:03 -0600 Subject: [PATCH 24/38] Fixed hyperlink linter errors Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 83121ad3a3bb..62529a04d3ce 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -132,7 +132,8 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { #[derive(Debug, Clone)] pub enum Codec { /// Primitive Types - /// https://avro.apache.org/docs/1.11.1/specification/#primitive-types + /// + /// Null, Boolean, Int32, @@ -142,14 +143,16 @@ pub enum Codec { Binary, String, /// Complex Types - /// https://avro.apache.org/docs/1.11.1/specification/#complex-types + /// + /// Record(Arc<[AvroField]>), Enum(Arc<[String]>, Arc<[i32]>), Array(Arc), Map(Arc), Fixed(i32), /// Logical Types - /// 
https://avro.apache.org/docs/1.11.1/specification/#logical-types
+    ///
+    /// <https://avro.apache.org/docs/1.11.1/specification/#logical-types>
     Decimal(usize, Option<usize>, Option<usize>),
     Uuid,
     Date32,
     TimeMillis,
     TimeMicros,
     TimestampMillis(bool),
     TimestampMicros(bool),
     Duration,
 }

From f93331dcb26f8288faa8ef7e08bcab4d3d042677 Mon Sep 17 00:00:00 2001
From: Connor Sanders
Date: Sun, 2 Feb 2025 10:03:53 -0600
Subject: [PATCH 25/38] Implemented improved defaults, added bzip2 and xz
 compression, improved nullability handling, and added additional avro file
 tests to reader/mod.rs.

Signed-off-by: Connor Sanders
---
 arrow-avro/Cargo.toml           |   5 +-
 arrow-avro/src/codec.rs         | 633 ++++++++++++++++----------------
 arrow-avro/src/compression.rs   |  26 +-
 arrow-avro/src/reader/header.rs |  46 ++-
 arrow-avro/src/reader/mod.rs    | 586 +++++++++++++++++++++++++++--
 arrow-avro/src/reader/record.rs | 405 +++++++-------------
 arrow-avro/src/schema.rs        |  62 +++-
 7 files changed, 1126 insertions(+), 637 deletions(-)

diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml
index c103c2ecc0f3..433b16c3aa89 100644
--- a/arrow-avro/Cargo.toml
+++ b/arrow-avro/Cargo.toml
@@ -34,7 +34,7 @@ path = "src/lib.rs"
 bench = false

 [features]
-default = ["deflate", "snappy", "zstd"]
+default = ["deflate", "snappy", "zstd", "bzip2", "xz"]
 deflate = ["flate2"]
 snappy = ["snap", "crc"]

@@ -42,11 +42,14 @@ snappy = ["snap", "crc"]
 arrow-schema = { workspace = true }
 arrow-buffer = { workspace = true }
 arrow-array = { workspace = true }
+arrow-data = { workspace = true }
 serde_json = { version = "1.0", default-features = false, features = ["std"] }
 serde = { version = "1.0.188", features = ["derive"] }
 flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true }
 snap = { version = "1.0", default-features = false, optional = true }
 zstd = { version = "0.13", default-features = false, optional = true }
+bzip2 = { version = "0.4.4", default-features = false, optional = true }
+xz = { version = "0.1.0", default-features = false, optional = true }
 crc = { version = "3.0", optional = true }

diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs
index 62529a04d3ce..5ed4d58dd09c 100644
--- a/arrow-avro/src/codec.rs
+++ b/arrow-avro/src/codec.rs
@@ -27,9 +27,6 @@ use std::sync::Arc;

 /// Avro types are not nullable, with nullability instead encoded as a union
 /// where one of the variants is the null type.
-///
-/// To accommodate this we special case two-variant unions where one of the
-/// variants is the null type, and use this to derive arrow's notion of nullability
 #[derive(Debug, Copy, Clone)]
 pub enum Nullability {
     /// The nulls are encoded as the first union variant
@@ -41,9 +38,9 @@ pub enum Nullability {
 /// An Avro datatype mapped to the arrow data model
 #[derive(Debug, Clone)]
 pub struct AvroDataType {
-    nullability: Option<Nullability>,
-    metadata: HashMap<String, String>,
-    codec: Codec,
+    pub nullability: Option<Nullability>,
+    pub metadata: HashMap<String, String>,
+    pub codec: Codec,
 }

 impl AvroDataType {
@@ -65,20 +62,10 @@ impl AvroDataType {
         Self::new(codec, None, Default::default())
     }

-    /// Returns an arrow [`Field`] with the given name
+    /// Returns an arrow [`Field`] with the given name, applying `nullability` if present.
     pub fn field_with_name(&self, name: &str) -> Field {
-        let d = self.codec.data_type();
-        Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone())
-    }
-
-    /// Return a reference to the inner `Codec`.
-    pub fn codec(&self) -> &Codec {
-        &self.codec
-    }
-
-    /// Return the nullability for this Avro type, if any.
- pub fn nullability(&self) -> Option { - self.nullability + let is_nullable = self.nullability.is_some(); + Field::new(name, self.codec.data_type(), is_nullable).with_metadata(self.metadata.clone()) } } @@ -87,12 +74,19 @@ impl AvroDataType { pub struct AvroField { name: String, data_type: AvroDataType, + default: Option, } impl AvroField { /// Returns the arrow [`Field`] pub fn field(&self) -> Field { - self.data_type.field_with_name(&self.name) + let mut fld = self.data_type.field_with_name(&self.name); + if let Some(def_val) = &self.default { + let mut md = fld.metadata().clone(); + md.insert("avro.default".to_string(), def_val.to_string()); + fld = fld.with_metadata(md); + } + fld } /// Returns the [`AvroDataType`] @@ -114,9 +108,10 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { Schema::Complex(ComplexType::Record(r)) => { let mut resolver = Resolver::default(); let data_type = make_data_type(schema, None, &mut resolver)?; - Ok(AvroField { + Ok(Self { data_type, name: r.name.to_string(), + default: None, }) } _ => Err(ArrowError::ParseError(format!( @@ -127,13 +122,9 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { } /// An Avro encoding -/// -/// #[derive(Debug, Clone)] pub enum Codec { - /// Primitive Types - /// - /// + /// Primitive Null, Boolean, Int32, @@ -142,17 +133,13 @@ pub enum Codec { Float64, Binary, String, - /// Complex Types - /// - /// + /// Complex Record(Arc<[AvroField]>), Enum(Arc<[String]>, Arc<[i32]>), Array(Arc), Map(Arc), Fixed(i32), - /// Logical Types - /// - /// + /// Logical Decimal(usize, Option, Option), Uuid, Date32, @@ -165,10 +152,9 @@ pub enum Codec { impl Codec { /// Convert this to an Arrow `DataType` - fn data_type(&self) -> DataType { + pub(crate) fn data_type(&self) -> DataType { match self { - /// Primitive Types - /// + // Primitives Self::Null => Null, Self::Boolean => Boolean, Self::Int32 => Int32, @@ -177,43 +163,51 @@ impl Codec { Self::Float64 => Float64, Self::Binary => Binary, Self::String => Utf8, - /// Complex Types - /// - Self::Record(f) => Struct(f.iter().map(|x| x.field()).collect()), - Self::Enum(symbols, values) => Dictionary(Box::new(Utf8), Box::new(Int32)), - Self::Array(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), - Self::Map(values) => Map( - Arc::new(Field::new( - "entries", - Struct(Fields::from(vec![ - Field::new("key", Utf8, false), - values.field_with_name("value"), - ])), + Self::Record(fields) => { + let arrow_fields: Vec = fields.iter().map(|f| f.field()).collect(); + Struct(arrow_fields.into()) + } + Self::Enum(_, _) => Dictionary(Box::new(Utf8), Box::new(Int32)), + Self::Array(child_type) => { + let child_dt = child_type.codec.data_type(); + let child_md = child_type.metadata.clone(); + let child_field = Field::new(Field::LIST_FIELD_DEFAULT_NAME, child_dt, true) + .with_metadata(child_md); + List(Arc::new(child_field)) + } + Self::Map(value_type) => { + let val_dt = value_type.codec.data_type(); + let val_md = value_type.metadata.clone(); + let val_field = Field::new("value", val_dt, true).with_metadata(val_md); + Map( + Arc::new(Field::new( + "entries", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + val_field, + ])), + false, + )), false, - )), - false, - ), - Self::Fixed(size) => FixedSizeBinary(*size), - /// Logical Types - /// + ) + } + Self::Fixed(sz) => FixedSizeBinary(*sz), Self::Decimal(precision, scale, size) => { - let scale = scale.unwrap_or(0) as i8; - let precision = *precision as u8; - let is_256 = match *size { - Some(s) => s > 16, + let p = *precision as 
+                let s = scale.unwrap_or(0) as i8;
+                let too_large_for_128 = match *size {
+                    Some(sz) => sz > 16,
                     None => {
-                        (precision as usize) > DECIMAL128_MAX_PRECISION as usize
-                            || (scale as usize) > DECIMAL128_MAX_SCALE as usize
+                        (p as usize) > DECIMAL128_MAX_PRECISION as usize
+                            || (s as usize) > DECIMAL128_MAX_SCALE as usize
                     }
                 };
-                if is_256 {
-                    Decimal256(precision, scale)
+                if too_large_for_128 {
+                    Decimal256(p, s)
                 } else {
-                    Decimal128(precision, scale)
+                    Decimal128(p, s)
                 }
             }
-            // arrow-rs does not support the UUID Canonical Extension Type yet,
-            // so this is a temporary workaround.
             Self::Uuid => FixedSizeBinary(16),
             Self::Date32 => Date32,
             Self::TimeMillis => Time32(TimeUnit::Millisecond),
@@ -245,107 +239,122 @@ impl From<PrimitiveType> for Codec {
 }
 
 /// Resolves Avro type names to [`AvroDataType`]
-///
-/// See <https://avro.apache.org/docs/1.11.1/specification/#names>
-#[derive(Debug, Default)]
+#[derive(Default, Debug)]
 struct Resolver<'a> {
     map: HashMap<(&'a str, &'a str), AvroDataType>,
 }
 
 impl<'a> Resolver<'a> {
-    fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) {
-        self.map.insert((name, namespace.unwrap_or("")), schema);
+    fn register(&mut self, name: &'a str, namespace: Option<&'a str>, dt: AvroDataType) {
+        let ns = namespace.unwrap_or("");
+        self.map.insert((name, ns), dt);
     }
 
-    fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> {
-        let (namespace, name) = name
-            .rsplit_once('.')
-            .unwrap_or_else(|| (namespace.unwrap_or(""), name));
-
+    fn resolve(
+        &self,
+        full_name: &str,
+        namespace: Option<&'a str>,
+    ) -> Result<AvroDataType, ArrowError> {
+        let (ns, nm) = match full_name.rsplit_once('.') {
+            Some((a, b)) => (a, b),
+            None => (namespace.unwrap_or(""), full_name),
+        };
         self.map
-            .get(&(namespace, name))
-            .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}")))
+            .get(&(nm, ns))
             .cloned()
+            .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {ns}.{nm}")))
     }
 }
 
-/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace`
-///
-/// `name`: is name used to refer to `schema` in its parent
-/// `namespace`: an optional qualifier used as part of a type hierarchy
+/// Parses a [`AvroDataType`] from the provided [`Schema`], plus optional `namespace`.
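+///
+/// For example (illustrative note added for clarity, not part of the original
+/// patch text): the Avro union `["null", "long"]` parses to a nullable
+/// `Codec::Int64` with `Nullability::NullFirst`, `["long", "null"]` yields
+/// `Nullability::NullSecond`, and any other union shape is rejected as not
+/// yet implemented.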
fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, resolver: &mut Resolver<'a>, ) -> Result { match schema { - /// Primitive Types - /// Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { nullability: None, metadata: Default::default(), codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - /// Complex Types - /// - Schema::Union(f) => { - // Special case the common case of nullable primitives or single-type - let null = f + Schema::Union(u) => { + let null_idx = u .iter() .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { + match (u.len() == 2, null_idx) { (true, Some(0)) => { - let mut field = make_data_type(&f[1], namespace, resolver)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) + let mut dt = make_data_type(&u[1], namespace, resolver)?; + dt.nullability = Some(Nullability::NullFirst); + Ok(dt) } (true, Some(1)) => { - let mut field = make_data_type(&f[0], namespace, resolver)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) + let mut dt = make_data_type(&u[0], namespace, resolver)?; + dt.nullability = Some(Nullability::NullSecond); + Ok(dt) } _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" + "Union of {u:?} not currently supported" ))), } } + // complex Schema::Complex(c) => match c { ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); + let ns = r.namespace.or(namespace); let fields = r .fields .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type(&field.r#type, namespace, resolver)?, + .map(|f| { + let data_type = make_data_type(&f.r#type, ns, resolver)?; + Ok::(AvroField { + name: f.name.to_string(), + data_type, + default: f.default.clone(), }) }) - .collect::>()?; - let field = AvroDataType { + .collect::, ArrowError>>()?; + let rec = AvroDataType { nullability: None, - codec: Codec::Record(fields), metadata: r.attributes.field_metadata(), + codec: Codec::Record(Arc::from(fields)), }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) + resolver.register(r.name, ns, rec.clone()); + Ok(rec) + } + ComplexType::Enum(e) => { + let en = AvroDataType { + nullability: None, + metadata: e.attributes.field_metadata(), + codec: Codec::Enum( + Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), + Arc::from(vec![]), + ), + }; + resolver.register(e.name, namespace, en.clone()); + Ok(en) } ComplexType::Array(a) => { - let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + let child = make_data_type(&a.items, namespace, resolver)?; Ok(AvroDataType { nullability: None, metadata: a.attributes.field_metadata(), - codec: Codec::Array(Arc::new(field)), + codec: Codec::Array(Arc::new(child)), + }) + } + ComplexType::Map(m) => { + let val = make_data_type(&m.values, namespace, resolver)?; + Ok(AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(val)), }) } - ComplexType::Fixed(f) => { - // Possibly decimal with logicalType=decimal - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - if let Some("decimal") = f.attributes.logical_type { - let precision = f + ComplexType::Fixed(fx) => { + let size = fx.size as i32; + if let Some("decimal") = fx.attributes.logical_type { + let precision = fx .attributes .additional .get("precision") @@ -353,156 
+362,130 @@ fn make_data_type<'a>( .ok_or_else(|| { ArrowError::ParseError("Decimal requires precision".to_string()) })?; - let scale = f + let scale = fx .attributes .additional .get("scale") .and_then(|v| v.as_u64()) - .or(Some(0)); - let field = AvroDataType { + .unwrap_or(0); + let dec = AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), + metadata: fx.attributes.field_metadata(), codec: Codec::Decimal( precision as usize, - Some(scale.unwrap_or(0) as usize), + Some(scale as usize), Some(size as usize), ), }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) + resolver.register(fx.name, namespace, dec.clone()); + Ok(dec) } else { - let field = AvroDataType { + let fixed_dt = AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), + metadata: fx.attributes.field_metadata(), codec: Codec::Fixed(size), }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) + resolver.register(fx.name, namespace, fixed_dt.clone()); + Ok(fixed_dt) } } - ComplexType::Enum(e) => { - let field = AvroDataType { - nullability: None, - metadata: e.attributes.field_metadata(), - codec: Codec::Enum( - Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), - Arc::from(vec![]), - ), - }; - resolver.register(e.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Map(m) => { - let values_data_type = make_data_type(m.values.as_ref(), namespace, resolver)?; - let field = AvroDataType { - nullability: None, - metadata: m.attributes.field_metadata(), - codec: Codec::Map(Arc::new(values_data_type)), - }; - Ok(field) - } }, - /// Logical Types - /// Schema::Type(t) => { - let mut field = - make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Fixed(_)) => { - *c = Codec::Decimal( - t.attributes - .additional - .get("precision") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize, - Some( - t.attributes - .additional - .get("scale") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize, - ), - Some( - t.attributes - .additional - .get("size") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize, - ), - ); + let mut dt = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; + match (t.attributes.logical_type, &mut dt.codec) { + (Some("decimal"), Codec::Fixed(sz)) => { + let prec = t + .attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize; + let sc = t + .attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + *sz = t + .attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .unwrap_or(*sz as u64) as i32; + dt.codec = Codec::Decimal(prec, Some(sc), Some(*sz as usize)); } - (Some("decimal"), c @ Codec::Binary) => { - *c = Codec::Decimal( - t.attributes - .additional - .get("precision") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize, - Some( - t.attributes - .additional - .get("scale") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize, - ), - None, - ); + (Some("decimal"), Codec::Binary) => { + let prec = t + .attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize; + let sc = t + .attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + dt.codec = Codec::Decimal(prec, Some(sc), None); } - (Some("uuid"), c @ Codec::String) => *c = Codec::Uuid, - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - 
(Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) + (Some("uuid"), Codec::String) => { + dt.codec = Codec::Uuid; } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) + (Some("date"), Codec::Int32) => { + dt.codec = Codec::Date32; } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata - field.metadata.insert("logicalType".into(), logical.into()); + (Some("time-millis"), Codec::Int32) => { + dt.codec = Codec::TimeMillis; + } + (Some("time-micros"), Codec::Int64) => { + dt.codec = Codec::TimeMicros; + } + (Some("timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(true); + } + (Some("timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(true); + } + (Some("local-timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(false); + } + (Some("local-timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(false); + } + (Some("duration"), Codec::Fixed(12)) => { + dt.codec = Codec::Duration; + } + (Some(other), _) => { + dt.metadata.insert("logicalType".into(), other.into()); } (None, _) => {} } - - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); - } + for (k, v) in &t.attributes.additional { + dt.metadata.insert(k.to_string(), v.to_string()); } - Ok(field) + Ok(dt) } } } -/// Convert an Arrow `Field` into an `AvroField`. -pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { - let codec = arrow_type_to_codec(arrow_field.data_type()); - let nullability = if arrow_field.is_nullable() { - Some(Nullability::NullFirst) - } else { - None - }; - let mut metadata = arrow_field.metadata().clone(); - let avro_data_type = AvroDataType { - nullability, - metadata, +pub fn arrow_field_to_avro_field(field: &Field) -> AvroField { + let codec = arrow_type_to_codec(field.data_type()); + let top_null = field.is_nullable().then_some(Nullability::NullFirst); + let data_type = AvroDataType { + nullability: top_null, + metadata: field.metadata().clone(), codec, }; AvroField { - name: arrow_field.name().clone(), - data_type: avro_data_type, + name: field.name().to_string(), + data_type, + default: None, } } -/// Maps an Arrow `DataType` to a `Codec`. 
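+// Illustrative usage sketch (added for clarity, not part of the original
+// patch): a nullable Arrow field maps onto `Nullability::NullFirst`, i.e.
+// the Avro union ["null", T].
+#[cfg(test)]
+mod arrow_to_avro_field_example {
+    use super::*;
+
+    #[test]
+    fn nullable_int64_becomes_null_first_union() {
+        let f = Field::new("counter", DataType::Int64, true);
+        let avro = arrow_field_to_avro_field(&f);
+        assert_eq!(avro.name(), "counter");
+        assert!(matches!(avro.data_type().codec, Codec::Int64));
+        assert!(matches!(
+            avro.data_type().nullability,
+            Some(Nullability::NullFirst)
+        ));
+    }
+}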
fn arrow_type_to_codec(dt: &DataType) -> Codec { match dt { - /// Primitive Types - /// Null => Codec::Null, Boolean => Codec::Boolean, Int8 | Int16 | Int32 => Codec::Int32, @@ -511,63 +494,64 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { Float64 => Codec::Float64, Binary | LargeBinary => Codec::Binary, Utf8 => Codec::String, - /// Complex Types - /// - Struct(child_fields) => { - let avro_fields: Vec = child_fields + Struct(fields) => { + let avro_fields: Vec = fields .iter() - .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) + .map(|fref| arrow_field_to_avro_field(fref.as_ref())) .collect(); Codec::Record(Arc::from(avro_fields)) } - Dictionary(symbol_type, value_type) => { - if let Utf8 = **symbol_type { - Codec::Enum( - Arc::from(Vec::::new()), - Arc::from(Vec::::new()), - ) + Dictionary(dict_ty, _val_ty) => { + if let Utf8 = &**dict_ty { + Codec::Enum(Arc::from(Vec::new()), Arc::from(Vec::new())) } else { Codec::String } } - List(field) => Codec::Array(Arc::new(AvroDataType { - nullability: field.is_nullable().then_some(Nullability::NullFirst), - metadata: field.metadata().clone(), - codec: arrow_type_to_codec(field.data_type()), - })), - Map(field, _keys_sorted) => { - if let Struct(child_fields) = field.data_type() { - let value_field = &child_fields[1]; - Codec::Map(Arc::new(AvroDataType { - nullability: value_field.is_nullable().then_some(Nullability::NullFirst), - metadata: value_field.metadata().clone(), - codec: arrow_type_to_codec(value_field.data_type()), - })) + List(item_field) => { + let item_codec = arrow_type_to_codec(item_field.data_type()); + let child_nullability = item_field.is_nullable().then_some(Nullability::NullFirst); + let child_dt = AvroDataType { + codec: item_codec, + nullability: child_nullability, + metadata: item_field.metadata().clone(), + }; + Codec::Array(Arc::new(child_dt)) + } + Map(entries_field, _keys_sorted) => { + if let Struct(struct_fields) = entries_field.data_type() { + let val_field = &struct_fields[1]; + let val_codec = arrow_type_to_codec(val_field.data_type()); + let val_nullability = val_field.is_nullable().then_some(Nullability::NullFirst); + let val_dt = AvroDataType { + codec: val_codec, + nullability: val_nullability, + metadata: val_field.metadata().clone(), + }; + Codec::Map(Arc::new(val_dt)) } else { Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) } } FixedSizeBinary(n) => Codec::Fixed(*n), - /// Logical Types - /// - Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), - Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), - // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. - // It is unsafe to assume all FixedSizeBinary(16) are UUIDs. 
- // Uuid => Codec::Uuid, + Decimal128(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(16)), + Decimal256(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(32)), Date32 => Codec::Date32, Time32(TimeUnit::Millisecond) => Codec::TimeMillis, Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), - Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { Codec::TimestampMillis(true) } + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { Codec::TimestampMicros(true) } + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), Interval(IntervalUnit::MonthDayNano) => Codec::Duration, - _ => Codec::String, + other => { + let _ = other; + Codec::String + } } } @@ -583,9 +567,10 @@ mod tests { let avro_field = AvroField { name: "long_col".to_string(), data_type: field_codec.clone(), + default: None, }; assert_eq!(avro_field.name(), "long_col"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); + let actual_str = format!("{:?}", avro_field.data_type().codec); let expected_str = format!("{:?}", &Codec::Int64); assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); let arrow_field = avro_field.field(); @@ -594,6 +579,23 @@ mod tests { assert!(!arrow_field.is_nullable()); } + #[test] + fn test_avro_field_with_default() { + let field_codec = AvroDataType::from_codec(Codec::Int32); + let default_value = serde_json::json!(123); + let avro_field = AvroField { + name: "int_col".to_string(), + data_type: field_codec.clone(), + default: Some(default_value.clone()), + }; + let arrow_field = avro_field.field(); + let metadata = arrow_field.metadata(); + assert_eq!( + metadata.get("avro.default").unwrap(), + &default_value.to_string() + ); + } + #[test] fn test_codec_fixedsizebinary() { let codec = Codec::Fixed(12); @@ -608,69 +610,61 @@ mod tests { fn test_arrow_field_to_avro_field() { let arrow_field = Field::new("Null", Null, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Null)); + assert!(matches!(avro_field.data_type().codec, Codec::Null)); let arrow_field = Field::new("Boolean", Boolean, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Boolean)); + assert!(matches!(avro_field.data_type().codec, Codec::Boolean)); let arrow_field = Field::new("Int32", Int32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Int32)); + assert!(matches!(avro_field.data_type().codec, Codec::Int32)); let arrow_field = Field::new("Int64", Int64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Int64)); + assert!(matches!(avro_field.data_type().codec, Codec::Int64)); let arrow_field = Field::new("Float32", Float32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Float32)); + assert!(matches!(avro_field.data_type().codec, Codec::Float32)); let arrow_field = Field::new("Float64", Float64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Float64)); + 
assert!(matches!(avro_field.data_type().codec, Codec::Float64)); let arrow_field = Field::new("Binary", Binary, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Binary)); + assert!(matches!(avro_field.data_type().codec, Codec::Binary)); let arrow_field = Field::new("Utf8", Utf8, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::String)); + assert!(matches!(avro_field.data_type().codec, Codec::String)); let arrow_field = Field::new("Decimal128", Decimal128(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(16)) )); let arrow_field = Field::new("Decimal256", Decimal256(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(32)) )); - // arrow-rs does not support the UUID Canonical Extension Type yet, so this mapping is not possible. - // let arrow_field = Field::new("Uuid", FixedSizeBinary(16), true); - // let avro_field = arrow_field_to_avro_field(&arrow_field); - // let codec = avro_field.data_type().codec(); - // assert!( - // matches!(codec, Codec::Uuid), - // ); - let arrow_field = Field::new("Date32", Date32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Date32)); + assert!(matches!(avro_field.data_type().codec, Codec::Date32)); let arrow_field = Field::new("Time32", Time32(TimeUnit::Millisecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::TimeMillis)); + assert!(matches!(avro_field.data_type().codec, Codec::TimeMillis)); let arrow_field = Field::new("Time32", Time64(TimeUnit::Microsecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::TimeMicros)); + assert!(matches!(avro_field.data_type().codec, Codec::TimeMicros)); let arrow_field = Field::new( "utc_ts_ms", @@ -679,7 +673,7 @@ mod tests { ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::TimestampMillis(true) )); @@ -690,27 +684,27 @@ mod tests { ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::TimestampMicros(true) )); let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::TimestampMillis(false) )); let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( - avro_field.data_type().codec(), + avro_field.data_type().codec, Codec::TimestampMicros(false) )); let arrow_field = Field::new("Interval", Interval(IntervalUnit::MonthDayNano), false); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Duration)); + assert!(matches!(avro_field.data_type().codec, Codec::Duration)); let arrow_field = Field::new( "Struct", @@ -721,13 +715,13 @@ mod 
tests { false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - match avro_field.data_type().codec() { + match &avro_field.data_type().codec { Codec::Record(fields) => { assert_eq!(fields.len(), 2); assert_eq!(fields[0].name(), "a"); - assert!(matches!(fields[0].data_type().codec(), Codec::Boolean)); + assert!(matches!(fields[0].data_type().codec, Codec::Boolean)); assert_eq!(fields[1].name(), "b"); - assert!(matches!(fields[1].data_type().codec(), Codec::Float64)); + assert!(matches!(fields[1].data_type().codec, Codec::Float64)); } _ => panic!("Expected Record data type"), } @@ -738,7 +732,7 @@ mod tests { false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::Enum(_, _))); + assert!(matches!(avro_field.data_type().codec, Codec::Enum(_, _))); let arrow_field = Field::new( "DictionaryString", @@ -746,18 +740,18 @@ mod tests { false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec(), Codec::String)); + assert!(matches!(avro_field.data_type().codec, Codec::String)); let field = Field::new("Utf8", Utf8, true); let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Array(avro_data_type) = avro_field.data_type().codec() { + if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { assert!(matches!( - avro_data_type.nullability(), + avro_data_type.nullability, Some(Nullability::NullFirst) )); assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec(), Codec::String)); + assert!(matches!(avro_data_type.codec, Codec::String)); } else { panic!("Expected Codec::Array"); } @@ -769,36 +763,43 @@ mod tests { false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Array(avro_data_type) = avro_field.data_type().codec() { - assert!(avro_data_type.nullability().is_none()); + if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { + assert!(avro_data_type.nullability.is_none()); assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec(), Codec::String)); + assert!(matches!(avro_data_type.codec, Codec::String)); } else { panic!("Expected Codec::Array"); } - let field = Field::new( - "Utf8", - Struct(Fields::from(vec![ - Field::new("key", Utf8, false), - Field::new("value", Utf8, true), - ])), + let entries_field = Field::new( + "entries", + Struct( + vec![ + Field::new("key", Utf8, false), + Field::new("value", Utf8, true), + ] + .into(), + ), + false, + ); + let arrow_field = Field::new( + "Map with nullable items", + Map(Arc::new(entries_field), true), true, ); - let arrow_field = Field::new("Map with nullable items", Map(Arc::new(field), true), true); let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Map(avro_data_type) = avro_field.data_type().codec() { + if let Codec::Map(avro_data_type) = &avro_field.data_type().codec { assert!(matches!( - avro_data_type.nullability(), + avro_data_type.nullability, Some(Nullability::NullFirst) )); assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec(), Codec::String)); + assert!(matches!(avro_data_type.codec, Codec::String)); } else { panic!("Expected Codec::Map"); } - let field = Field::new( + let arrow_field = Field::new( "Utf8", Struct(Fields::from(vec![ Field::new("key", Utf8, false), @@ -808,21 +809,21 @@ mod tests { ); let arrow_field = Field::new( 
"Map with non-nullable items", - Map(Arc::new(field), false), + Map(Arc::new(arrow_field), false), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Map(avro_data_type) = avro_field.data_type().codec() { - assert!(avro_data_type.nullability().is_none()); + if let Codec::Map(avro_data_type) = &avro_field.data_type().codec { + assert!(avro_data_type.nullability.is_none()); assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec(), Codec::String)); + assert!(matches!(avro_data_type.codec, Codec::String)); } else { panic!("Expected Codec::Map"); } let arrow_field = Field::new("FixedSizeBinary", FixedSizeBinary(8), false); let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = avro_field.data_type().codec(); + let codec = &avro_field.data_type().codec; assert!(matches!(codec, Codec::Fixed(8))); } @@ -834,10 +835,10 @@ mod tests { )])); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); - let actual_str = format!("{:?}", avro_field.data_type().codec()); + let actual_str = format!("{:?}", avro_field.data_type().codec); let expected_str = format!("{:?}", &Codec::String); assert_eq!(actual_str, expected_str); - let actual_str = format!("{:?}", avro_field.data_type().nullability()); + let actual_str = format!("{:?}", avro_field.data_type().nullability); let expected_str = format!("{:?}", Some(Nullability::NullFirst)); assert_eq!(actual_str, expected_str); assert_eq!( diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index f29b8dd07606..5c4c988c899e 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -16,7 +16,6 @@ // under the License. use arrow_schema::ArrowError; -use std::io; use std::io::Read; /// The metadata key used for storing the JSON encoded [`CompressionCodec`] @@ -27,6 +26,8 @@ pub enum CompressionCodec { Deflate, Snappy, ZStandard, + Bzip2, + Xz, } impl CompressionCodec { @@ -65,7 +66,6 @@ impl CompressionCodec { CompressionCodec::Snappy => Err(ArrowError::ParseError( "Snappy codec requires snappy feature".to_string(), )), - #[cfg(feature = "zstd")] CompressionCodec::ZStandard => { let mut decoder = zstd::Decoder::new(block)?; @@ -77,6 +77,28 @@ impl CompressionCodec { CompressionCodec::ZStandard => Err(ArrowError::ParseError( "ZStandard codec requires zstd feature".to_string(), )), + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut decoder = bzip2::read::BzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut decoder = xz::read::XzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), } } } diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 98c285171bf3..f62b01922814 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -14,13 +14,13 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - //! 
Decoder for [`Header`] use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; use crate::reader::vlq::VLQDecoder; use crate::schema::{Schema, SCHEMA_METADATA_KEY}; use arrow_schema::ArrowError; +use std::io::BufRead; #[derive(Debug)] enum HeaderDecoderState { @@ -74,17 +74,18 @@ impl Header { self.sync } - /// Returns the [`CompressionCodec`] if any + /// Returns the [`CompressionCodec`] if any. pub fn compression(&self) -> Result, ArrowError> { let v = self.get(CODEC_METADATA_KEY); - match v { None | Some(b"null") => Ok(None), Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)), Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)), Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)), + Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)), + Some(b"xz") => Ok(Some(CompressionCodec::Xz)), Some(v) => Err(ArrowError::ParseError(format!( - "Unrecognized compression codec \'{}\'", + "Unrecognized compression codec '{}'", String::from_utf8_lossy(v) ))), } @@ -147,8 +148,6 @@ impl HeaderDecoder { /// This method can be called multiple times with consecutive chunks of data, allowing /// integration with chunked IO systems like [`BufRead::fill_buf`] /// - /// All errors should be considered fatal, and decoding aborted - /// /// Once the entire [`Header`] has been decoded this method will not read any further /// input bytes, and the header can be obtained with [`Self::flush`] /// @@ -264,13 +263,13 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; use crate::test_util::arrow_test_data; use arrow_schema::{DataType, Field, Fields, TimeUnit}; use std::fs::File; - use std::io::{BufRead, BufReader}; + use std::io::BufReader; #[test] fn test_header_decode() { @@ -353,4 +352,35 @@ mod test { 325166208089902833952788552656412487328 ); } + #[test] + fn test_header_schema_default() { + let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "a", "type": "int", "default": 10} + ] + } + "#; + let key = "avro.schema"; + let key_bytes = key.as_bytes(); + let value_bytes = json_schema.as_bytes(); + let mut meta_buf = Vec::new(); + meta_buf.extend_from_slice(key_bytes); + meta_buf.extend_from_slice(value_bytes); + let meta_offsets = vec![key_bytes.len(), key_bytes.len() + value_bytes.len()]; + let header = Header { + meta_offsets, + meta_buf, + sync: [0; 16], + }; + let schema = header.schema().unwrap().unwrap(); + if let crate::schema::Schema::Complex(crate::schema::ComplexType::Record(record)) = schema { + assert_eq!(record.fields.len(), 1); + assert_eq!(record.fields[0].default, Some(serde_json::json!(10))); + } else { + panic!("Expected record schema"); + } + } } diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 20ab0ad88a29..276b17475fb9 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -45,7 +45,6 @@ fn read_header(mut reader: R) -> Result { break; } } - decoder .flush() .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) @@ -54,7 +53,6 @@ fn read_header(mut reader: R) -> Result { /// Return an iterator of [`Block`] from the provided [`BufRead`] fn read_blocks(mut reader: R) -> impl Iterator> { let mut decoder = BlockDecoder::default(); - let mut try_next = move || { loop { let buf = reader.fill_buf()?; @@ -79,8 +77,14 @@ mod test { use crate::reader::record::RecordDecoder; use crate::reader::{read_blocks, 
read_header}; use crate::test_util::arrow_test_data; + use arrow_array::builder::{ + BooleanBuilder, Float32Builder, Int64Builder, ListBuilder, StringBuilder, StructBuilder, + }; use arrow_array::*; - use arrow_schema::{DataType, Field, Schema}; + use arrow_array::{Array, Float64Array, Int32Array, RecordBatch, StringArray, StructArray}; + use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, Fields, Schema}; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; @@ -94,25 +98,22 @@ mod test { let schema = header.schema().unwrap().unwrap(); let root = AvroField::try_from(&schema).unwrap(); let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); - for result in read_blocks(reader) { let block = result.unwrap(); assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - let decompressed = c.decompress(&block.data).unwrap(); - - let mut offset = 0; - let mut remaining = block.count; - while remaining > 0 { - let to_read = remaining.max(batch_size); - offset += decoder - .decode(&decompressed[offset..], block.count) - .unwrap(); - - remaining -= to_read; - } - assert_eq!(offset, decompressed.len()); + let block_data = if let Some(c) = compression { + c.decompress(&block.data).unwrap() + } else { + block.data + }; + let mut offset = 0; + let mut remaining = block.count; + while remaining > 0 { + let to_read = remaining.min(batch_size); + offset += decoder.decode(&block_data[offset..], to_read).unwrap(); + remaining -= to_read; } + assert_eq!(offset, block_data.len()); } decoder.flush().unwrap() } @@ -123,6 +124,8 @@ mod test { "avro/alltypes_plain.avro", "avro/alltypes_plain.snappy.avro", "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", ]; let expected = RecordBatch::try_from_iter_with_nullable([ @@ -208,34 +211,547 @@ mod test { ), ]) .unwrap(); - for file in files { let file = arrow_test_data(file); - assert_eq!(read_file(&file, 8), expected); assert_eq!(read_file(&file, 3), expected); } } #[test] - fn test_fixed_length_decimal() { - let file_path = arrow_test_data("avro/fixed_length_decimal.avro"); - let actual_batch = read_file(&file_path, 8); + fn test_alltypes_dictionary() { + let file = "avro/alltypes_dictionary.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![Some(true), Some(false)])) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![0, 10])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![0.0, 1.1])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![0.0, 10.1])) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([b"01/01/09", b"01/01/09"])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values([b"0", b"1"])) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); 
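+        // Note (added for clarity, not in the original patch): `read_file`
+        // above decodes `remaining.min(batch_size)` rows per call, so reading
+        // this two-row file with batch sizes 8 and 3 exercises both the
+        // single-batch and the partial-batch decode paths.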
+ let batch_large = read_file(&file_path, 8); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let batch_small = read_file(&file_path, 3); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {}", + file + ); + } + + #[test] + fn test_alltypes_nulls_plain() { + let file = "avro/alltypes_nulls_plain.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "string_col", + Arc::new(StringArray::from(vec![None::<&str>])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![None])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![None])) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![None])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![None])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![None])) as _, + true, + ), + ( + "bytes_col", + Arc::new(BinaryArray::from(vec![None::<&[u8]>])) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let batch_large = read_file(&file_path, 8); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let batch_small = read_file(&file_path, 3); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {}", + file + ); + } + + #[test] + fn test_binary() { + let file = arrow_test_data("avro/binary.avro"); + let batch = read_file(&file, 8); + let expected = RecordBatch::try_from_iter_with_nullable([( + "foo", + Arc::new(BinaryArray::from_iter_values(vec![ + b"\x00".as_ref(), + b"\x01".as_ref(), + b"\x02".as_ref(), + b"\x03".as_ref(), + b"\x04".as_ref(), + b"\x05".as_ref(), + b"\x06".as_ref(), + b"\x07".as_ref(), + b"\x08".as_ref(), + b"\t".as_ref(), + b"\n".as_ref(), + b"\x0b".as_ref(), + ])) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_decimal() { + let files = [ + ("avro/fixed_length_decimal.avro", 25, 2), + ("avro/fixed_length_decimal_legacy.avro", 13, 2), + ("avro/int32_decimal.avro", 4, 2), + ("avro/int64_decimal.avro", 10, 2), + ]; let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); - let array = Decimal128Array::from_iter_values(decimal_values) - .with_precision_and_scale(25, 2) + for (file, precision, scale) in files { + let file_path = arrow_test_data(file); + let actual_batch = read_file(&file_path, 8); + let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) + .with_precision_and_scale(precision, scale) + .unwrap(); + let mut meta = HashMap::new(); + meta.insert("precision".to_string(), precision.to_string()); + meta.insert("scale".to_string(), scale.to_string()); + let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) + .with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let expected_batch = + RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) + .expect("Failed to build expected RecordBatch"); + assert_eq!( + actual_batch, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {}", + file + ); + let actual_batch_small = read_file(&file_path, 3); + assert_eq!( + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", + file + ); + } + } + + #[test] + fn test_datapage_v2() { + let file = 
arrow_test_data("avro/datapage_v2.snappy.avro"); + let batch = read_file(&file, 8); + let a = StringArray::from(vec![ + Some("abc"), + Some("abc"), + Some("abc"), + None, + Some("abc"), + ]); + let b = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let c = Float64Array::from(vec![Some(2.0), Some(3.0), Some(4.0), Some(5.0), Some(2.0)]); + let d = BooleanArray::from(vec![ + Some(true), + Some(true), + Some(true), + Some(false), + Some(true), + ]); + let e_values = Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + ]); + let e_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, 3, 3, 3, 6, 8])); + let e_validity = Some(NullBuffer::from(vec![true, false, false, true, true])); + let field_e = Arc::new(Field::new("item", DataType::Int32, true)); + let e = ListArray::new(field_e, e_offsets, Arc::new(e_values), e_validity); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a) as Arc, true), + ("b", Arc::new(b) as Arc, true), + ("c", Arc::new(c) as Arc, true), + ("d", Arc::new(d) as Arc, true), + ("e", Arc::new(e) as Arc, true), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_dict_pages_offset_zero() { + let file = arrow_test_data("avro/dict-page-offset-zero.avro"); + let batch = read_file(&file, 32); + let num_rows = batch.num_rows(); + let expected_field = Int32Array::from(vec![Some(1552); num_rows]); + let expected = RecordBatch::try_from_iter_with_nullable([( + "l_partkey", + Arc::new(expected_field) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_list_columns() { + let file = arrow_test_data("avro/list_columns.avro"); + let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); + { + { + let values = int64_list_builder.values(); + values.append_value(1); + values.append_value(2); + values.append_value(3); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_null(); + values.append_value(1); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_value(4); + } + int64_list_builder.append(true); + } + let int64_list = int64_list_builder.finish(); + let mut utf8_list_builder = ListBuilder::new(StringBuilder::new()); + { + { + let values = utf8_list_builder.values(); + values.append_value("abc"); + values.append_value("efg"); + values.append_value("hij"); + } + utf8_list_builder.append(true); + } + { + utf8_list_builder.append(false); + } + { + { + let values = utf8_list_builder.values(); + values.append_value("efg"); + values.append_null(); + values.append_value("hij"); + values.append_value("xyz"); + } + utf8_list_builder.append(true); + } + let utf8_list = utf8_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("int64_list", Arc::new(int64_list) as Arc, true), + ("utf8_list", Arc::new(utf8_list) as Arc, true), + ]) + .unwrap(); + let batch = read_file(&file, 8); + assert_eq!(batch, expected); + } + + #[test] + fn test_nested_lists() { + let file = arrow_test_data("avro/nested_lists.snappy.avro"); + let inner_values = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("f"), + ]); + let inner_offsets = Buffer::from_slice_ref([0, 2, 3, 3, 4, 6, 8, 8, 9, 11, 13, 14, 14, 15]); + let inner_validity = 
[ + true, true, false, true, true, true, false, true, true, true, true, false, true, + ]; + let inner_null_buffer = Buffer::from_iter(inner_validity.iter().copied()); + let inner_field = Field::new("item", DataType::Utf8, true); + let inner_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(inner_field))) + .len(13) + .add_buffer(inner_offsets) + .add_child_data(inner_values.to_data()) + .null_bit_buffer(Some(inner_null_buffer)) + .build() .unwrap(); - let mut meta = HashMap::new(); - meta.insert("precision".to_string(), "25".to_string()); - meta.insert("scale".to_string(), "2".to_string()); - let field_with_meta = - Field::new("value", DataType::Decimal128(25, 2), true).with_metadata(meta); - let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); - let expected_batch = RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(array)]) - .expect("Failed to build expected RecordBatch"); + let inner_list_array = ListArray::from(inner_list_data); + let middle_offsets = Buffer::from_slice_ref([0, 2, 4, 6, 8, 11, 13]); + let middle_validity = [true; 6]; + let middle_null_buffer = Buffer::from_iter(middle_validity.iter().copied()); + let middle_field = Field::new("item", inner_list_array.data_type().clone(), true); + let middle_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(middle_field))) + .len(6) + .add_buffer(middle_offsets) + .add_child_data(inner_list_array.to_data()) + .null_bit_buffer(Some(middle_null_buffer)) + .build() + .unwrap(); + let middle_list_array = ListArray::from(middle_list_data); + let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all 3 rows valid + let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); + let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) + .len(3) + .add_buffer(outer_offsets) + .add_child_data(middle_list_array.to_data()) + .null_bit_buffer(Some(outer_null_buffer)) + .build() + .unwrap(); + let a_expected = ListArray::from(outer_list_data); + let b_expected = Int32Array::from(vec![1, 1, 1]); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a_expected) as Arc, true), + ("b", Arc::new(b_expected) as Arc, true), + ]) + .unwrap(); + let left = read_file(&file, 8); + assert_eq!(left, expected, "Mismatch for batch size=8"); + let left_small = read_file(&file, 3); + assert_eq!(left_small, expected, "Mismatch for batch size=3"); + } + + #[test] + fn test_nested_records() { + let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); + let f1_f1_2 = Int32Array::from(vec![10, 20]); + let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; + let f1_f1_3_1 = Float64Array::from(vec![rounded_pi, rounded_pi]); + let f1_f1_3 = StructArray::from(vec![( + Arc::new(Field::new("f1_3_1", DataType::Float64, false)), + Arc::new(f1_f1_3_1) as Arc, + )]); + let f1_expected = StructArray::from(vec![ + ( + Arc::new(Field::new("f1_1", DataType::Utf8, false)), + Arc::new(f1_f1_1) as Arc, + ), + ( + Arc::new(Field::new("f1_2", DataType::Int32, false)), + Arc::new(f1_f1_2) as Arc, + ), + ( + Arc::new(Field::new( + "f1_3", + DataType::Struct(Fields::from(vec![Field::new( + "f1_3_1", + DataType::Float64, + false, + )])), + false, + )), + Arc::new(f1_f1_3) as Arc, + ), + ]); + let f2_fields = vec![ + Field::new("f2_1", DataType::Boolean, false), + Field::new("f2_2", DataType::Float32, false), + ]; + let f2_struct_builder = StructBuilder::new( + f2_fields + .iter() + .map(|f| Arc::new(f.clone())) + 
.collect::>>(), + vec![ + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, + ], + ); + let mut f2_list_builder = ListBuilder::new(f2_struct_builder); + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(1.2_f32); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(2.2_f32); + } + f2_list_builder.append(true); + } + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(false); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(10.2_f32); + } + f2_list_builder.append(true); + } + let f2_expected = f2_list_builder.finish(); + let mut f3_struct_builder = StructBuilder::new( + vec![Arc::new(Field::new("f3_1", DataType::Utf8, false))], + vec![Box::new(StringBuilder::new()) as Box], + ); + f3_struct_builder.append(true); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_value("xyz"); + } + f3_struct_builder.append(false); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + let f3_expected = f3_struct_builder.finish(); + let f4_fields = [Field::new("f4_1", DataType::Int64, false)]; + let f4_struct_builder = StructBuilder::new( + f4_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![Box::new(Int64Builder::new()) as Box], + ); + let mut f4_list_builder = ListBuilder::new(f4_struct_builder); + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(200); + } + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + f4_list_builder.append(true); + } + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(300); + } + f4_list_builder.append(true); + } + let f4_expected = f4_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("f1", Arc::new(f1_expected) as Arc, false), + ("f2", Arc::new(f2_expected) as Arc, false), + ("f3", Arc::new(f3_expected) as Arc, true), + ("f4", Arc::new(f4_expected) as Arc, false), + ]) + .unwrap(); + let file = arrow_test_data("avro/nested_records.avro"); + let batch_large = read_file(&file, 8); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 8)" + ); + let batch_small = read_file(&file, 3); assert_eq!( - actual_batch, expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data" + batch_small, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 3)" ); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index b6a60f01a6c7..cd2ba24c759d 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -150,9 +150,7 @@ impl Decoder { /// Create a `Decoder` from an [`AvroDataType`]. 
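+    /// (Clarifying note, not part of the original patch: each `Codec` variant
+    /// maps to exactly one decoder variant, and a nullable [`AvroDataType`]
+    /// is wrapped in `Self::Nullable`, which tracks validity in a
+    /// `NullBufferBuilder` and forwards decoding of non-null rows to the
+    /// inner decoder.)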
fn try_new(data_type: &AvroDataType) -> Result { - let decoder = match data_type.codec() { - /// Primitive Types - /// + let decoder = match &data_type.codec { Codec::Null => Self::Null(0), Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), @@ -167,9 +165,6 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - - /// Complex Types - /// Codec::Record(avro_fields) => { let mut arrow_fields = Vec::with_capacity(avro_fields.len()); let mut decoders = Vec::with_capacity(avro_fields.len()); @@ -185,18 +180,20 @@ impl Decoder { } Codec::Array(item) => { let item_decoder = Box::new(Self::try_new(item)?); + let item_field = item.field_with_name("item").with_nullable(true); Self::List( - Arc::new(item.field_with_name("item")), + Arc::new(item_field), OffsetBufferBuilder::new(DEFAULT_CAPACITY), item_decoder, ) } Codec::Map(value_type) => { + let val_field = value_type.field_with_name("value").with_nullable(true); let map_field = Arc::new(ArrowField::new( "entries", DataType::Struct(Fields::from(vec![ - Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(value_type.field_with_name("value")), + ArrowField::new("key", DataType::Utf8, false), + val_field, ])), false, )); @@ -210,9 +207,6 @@ impl Decoder { ) } Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), - - /// Logical Types - /// Codec::Decimal(precision, scale, size) => { let builder = DecimalBuilder::new(*precision, *scale, *size)?; Self::Decimal(*precision, *scale, *size, builder) @@ -230,8 +224,7 @@ impl Decoder { Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), }; - // Wrap in Nullable if needed - match data_type.nullability() { + match data_type.nullability { Some(nb) => Ok(Self::Nullable( nb, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -244,8 +237,6 @@ impl Decoder { /// Append a null to this decoder. fn append_null(&mut self) { match self { - /// Primitive & Date Logical Types - /// Self::Null(n) => *n += 1, Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), @@ -256,8 +247,6 @@ impl Decoder { Self::Float32(v) => v.push(0.0), Self::Float64(v) => v.push(0.0), Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), - /// Complex Types - /// Self::Record(_, children) => { for c in children.iter_mut() { c.append_null(); @@ -273,31 +262,25 @@ impl Decoder { map_off.push_length(*entry_count); } Self::Fixed(fsize, buf) => { - // For a null, push `fsize` zeroed bytes buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); } - /// Non-Date Logical Types - /// Self::Decimal(_, _, _, builder) => { let _ = builder.append_null(); } Self::Interval(intervals) => { - // null => store a 12-byte zero => months=0, days=0, nanos=0 intervals.push(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 0, }); } - Self::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } + Self::Nullable(_, _, _) => {} } } /// Decode a single row of data from `buf`. 
fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - /// Primitive Types - /// Self::Null(count) => *count += 1, Self::Boolean(values) => values.append(buf.get_bool()?), Self::Int32(values) => values.push(buf.get_int()?), @@ -309,8 +292,6 @@ impl Decoder { off.push_length(bytes.len()); data.extend_from_slice(bytes); } - /// Complex Types - /// Self::Record(_, children) => { for c in children.iter_mut() { c.decode(buf)?; @@ -331,24 +312,38 @@ impl Decoder { *entry_count += newly_added; map_off.push_length(*entry_count); } - Self::Nullable(_, nulls, child) => match buf.get_int()? { - 0 => { - nulls.append(true); - child.decode(buf)?; - } - 1 => { - nulls.append(false); - child.append_null(); - } - other => { - return Err(ArrowError::ParseError(format!( - "Unsupported union branch index {other} for Nullable" - ))); + Self::Nullable(nb, nulls, child) => { + let branch = buf.get_int()?; + match nb { + Nullability::NullFirst => { + if branch == 0 { + nulls.append(false); + child.append_null(); + } else if branch == 1 { + nulls.append(true); + child.decode(buf)?; + } else { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {branch} for Nullable (NullFirst)" + ))); + } + } + Nullability::NullSecond => { + if branch == 0 { + nulls.append(true); + child.decode(buf)?; + } else if branch == 1 { + nulls.append(false); + child.append_null(); + } else { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {branch} for Nullable (NullSecond)" + ))); + } + } } - }, + } Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), - /// Logical Types - /// Self::Decimal(_, _, size, builder) => { let bytes = match *size { Some(sz) => buf.get_fixed(sz)?, @@ -367,12 +362,11 @@ impl Decoder { let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap()); let nanos = millis as i64 * 1_000_000; - let val = IntervalMonthDayNano { + intervals.push(IntervalMonthDayNano { months, days, nanoseconds: nanos, - }; - intervals.push(val); + }); } } Ok(()) @@ -381,68 +375,37 @@ impl Decoder { /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. 
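+    // Wire-format note for `decode` above (added for clarity, not part of the
+    // original patch): an Avro union prefixes the value with its zig-zag
+    // encoded branch index. For ["null", "int"] (NullFirst) a null row is the
+    // single byte 0x00 (branch 0) and the value 7 is 0x02 (branch 1) followed
+    // by 0x0E, the zig-zag encoding of 7; for ["int", "null"] (NullSecond)
+    // the two branch indices swap roles, which is exactly what the match on
+    // `Nullability` implements.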
fn flush(&mut self, nulls: Option) -> Result { match self { - /// Primitive Types - /// - // Null => produce NullArray Self::Null(len) => { let count = std::mem::replace(len, 0); Ok(Arc::new(NullArray::new(count))) } - // boolean => flush to BooleanArray Self::Boolean(b) => { let bits = b.finish(); Ok(Arc::new(BooleanArray::new(bits, nulls))) } - // int32 => flush to Int32Array - Self::Int32(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // date32 => flush to Date32Array - Self::Date32(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // int64 => flush to Int64Array - Self::Int64(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // float32 => flush to Float32Array - Self::Float32(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // float64 => flush to Float64Array - Self::Float64(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // Avro bytes => BinaryArray + Self::Int32(vals) => Ok(Arc::new(flush_primitive::(vals, nulls))), + Self::Date32(vals) => Ok(Arc::new(flush_primitive::(vals, nulls))), + Self::Int64(vals) => Ok(Arc::new(flush_primitive::(vals, nulls))), + Self::Float32(vals) => Ok(Arc::new(flush_primitive::(vals, nulls))), + Self::Float64(vals) => Ok(Arc::new(flush_primitive::(vals, nulls))), Self::Binary(off, data) => { let offsets = flush_offsets(off); let values = flush_values(data).into(); Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) } - // Avro string => StringArray Self::String(off, data) => { let offsets = flush_offsets(off); let values = flush_values(data).into(); Ok(Arc::new(StringArray::new(offsets, values, nulls))) } - - /// Complex Types - /// - // Avro record => StructArray Self::Record(fields, children) => { let mut arrays = Vec::with_capacity(children.len()); for c in children.iter_mut() { - let a = c.flush(None)?; + let a = c.flush(nulls.clone())?; arrays.push(a); } Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) } - // Avro enum => DictionaryArray utf8> Self::Enum(symbols, indices) => { let dict_values = StringArray::from_iter_values(symbols.iter()); let idxs: Int32Array = match nulls { @@ -456,31 +419,24 @@ impl Decoder { None => Int32Array::from_iter_values(indices.iter().cloned()), }; let dict = DictionaryArray::::try_new(idxs, Arc::new(dict_values))?; - indices.clear(); // reset + indices.clear(); Ok(Arc::new(dict)) } - // Avro array => ListArray Self::List(field, off, item_dec) => { let child_arr = item_dec.flush(None)?; let offsets = flush_offsets(off); let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); Ok(Arc::new(arr)) } - // Avro map => MapArray Self::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { let moff = flush_offsets(map_off); let koff = flush_offsets(key_off); let kd = flush_values(key_data).into(); let val_arr = val_dec.flush(None)?; - let is_nullable = matches!(**val_dec, Self::Nullable(_, _, _)); let key_arr = StringArray::new(koff, kd, None); let struct_fields = vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new( - "value", - val_arr.data_type().clone(), - is_nullable, - )), + Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), ]; let entries = StructArray::new( Fields::from(struct_fields), @@ -491,19 +447,13 @@ impl Decoder { *entry_count = 0; Ok(Arc::new(map_arr)) } - - // Avro fixed => FixedSizeBinaryArray Self::Fixed(fsize, raw) => { let size = *fsize; let buf: Buffer = 
flush_values(raw).into(); - let total_len = buf.len() / (size as usize); let array = FixedSizeBinaryArray::try_new(size, buf, nulls) .map_err(|e| ArrowError::ParseError(e.to_string()))?; Ok(Arc::new(array)) } - /// Logical Types - /// - // Avro decimal => Arrow decimal Self::Decimal(prec, sc, sz, builder) => { let precision = *prec; let scale = sc.unwrap_or(0); @@ -512,29 +462,22 @@ impl Decoder { let arr = old_builder.finish(nulls, precision, scale)?; Ok(arr) } - // time-millis => Time32Millisecond - Self::TimeMillis(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // time-micros => Time64Microsecond - Self::TimeMicros(vals) => { - let arr = flush_primitive::(vals, nulls); - Ok(Arc::new(arr)) - } - // timestamp-millis => TimestampMillisecond + Self::TimeMillis(vals) => Ok(Arc::new(flush_primitive::( + vals, nulls, + ))), + Self::TimeMicros(vals) => Ok(Arc::new(flush_primitive::( + vals, nulls, + ))), Self::TimestampMillis(is_utc, vals) => { let arr = flush_primitive::(vals, nulls) .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); Ok(Arc::new(arr)) } - // timestamp-micros => TimestampMicrosecond Self::TimestampMicros(is_utc, vals) => { let arr = flush_primitive::(vals, nulls) .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); Ok(Arc::new(arr)) } - // Avro interval => IntervalMonthDayNanoType Self::Interval(vals) => { let data_len = vals.len(); let mut builder = @@ -546,7 +489,6 @@ impl Decoder { .finish() .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); if let Some(nb) = nulls { - // "merge" the newly built array with the nulls let arr_data = arr.into_data().into_builder().nulls(Some(nb)); let arr_data = unsafe { arr_data.build_unchecked() }; Ok(Arc::new(PrimitiveArray::::from( @@ -556,8 +498,7 @@ impl Decoder { Ok(Arc::new(arr)) } } - // For a nullable wrapper => flush the child with the built null buffer - Self::Nullable(_, nb, child) => { + Self::Nullable(_, ref mut nb, ref mut child) => { let mask = nb.finish(); child.flush(mask) } @@ -574,18 +515,16 @@ fn read_array_blocks( loop { let block_count = buf.get_long()?; match block_count { - 0 => break, // If block_count is 0, exit the loop + 0 => break, n if n < 0 => { - // If block_count is negative let item_count = (-n) as usize; - let _block_size = buf.get_long()?; // Read but ignore block size + let _block_size = buf.get_long()?; // size (ignored) for _ in 0..item_count { decode_item(buf)?; } total_items += item_count; } n => { - // If block_count is positive let item_count = n as usize; for _ in 0..item_count { decode_item(buf)?; @@ -597,7 +536,7 @@ fn read_array_blocks( Ok(total_items) } -/// Decode an Avro map in blocks until 0 block_count => end. +/// Decode an Avro map in blocks until 0 block_count signals end. fn read_map_blocks( buf: &mut AvroCursor, mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, @@ -659,7 +598,6 @@ impl DecimalBuilder { .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, )), None => { - // infer from precision if precision <= DECIMAL128_MAX_PRECISION as usize { Ok(Self::Decimal128( Decimal128Builder::new() @@ -786,11 +724,6 @@ fn sign_extend(raw: &[u8], target_len: usize) -> Vec { out } -/// Convenience helper to build a field with `name`, `DataType` and `nullable`. 
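// Worked example of the sign extension above, using the `sign_extend`
// helper defined in this file; the function name here is illustrative.
// Avro decimals are big-endian two's complement, so -123 in two bytes is
// [0xFF, 0x85], and widening it to 16 bytes must replicate the 0xFF sign
// byte on the left.
fn sign_extend_example() {
    let raw = [0xFFu8, 0x85];
    let wide = sign_extend(&raw, 16);
    assert_eq!(&wide[..14], &[0xFFu8; 14]); // sign bytes fill the prefix
    assert_eq!(&wide[14..], &raw); // original bytes keep the low positions
    assert_eq!(i128::from_be_bytes(wide.try_into().unwrap()), -123);
}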
-fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { - Arc::new(ArrowField::new(name, dt, nullable)) -} - #[cfg(test)] mod tests { use super::*; @@ -798,10 +731,8 @@ mod tests { cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, }; + use std::sync::Arc; - // --------------- - // Zig-Zag Helpers - // --------------- fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); let mut v = (value << 1) ^ (value >> 31); @@ -830,9 +761,29 @@ mod tests { buf } - // ----------------- - // Test Fixed - // ----------------- + #[test] + fn test_record_decoder_default_metadata() { + use crate::codec::AvroField; + use crate::schema::Schema; + let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "default_int", "type": "int", "default": 42} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let arrow_schema = record_decoder.schema(); + assert_eq!(arrow_schema.fields().len(), 1); + let field = arrow_schema.field(0); + let metadata = field.metadata(); + assert_eq!(metadata.get("avro.default").unwrap(), "42"); + } + #[test] fn test_fixed_decoding() { // `fixed(4)` => Arrow FixedSizeBinary(4) @@ -857,32 +808,25 @@ mod tests { #[test] fn test_fixed_with_nulls() { - // Avro union => [ fixed(2), null] let dt = AvroDataType::from_codec(Codec::Fixed(2)); let child = Decoder::try_new(&dt).unwrap(); let mut dec = Decoder::Nullable( - Nullability::NullFirst, + Nullability::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(child), ); - // Decode 3 rows: row1 => branch=0 => [0x00], then 2 bytes - // row2 => branch=1 => null => [0x02] - // row3 => branch=0 => 2 bytes let row1 = [0x11, 0x22]; let row3 = [0x55, 0x66]; let mut data = Vec::new(); - // row1 => union=0 => child => 2 bytes data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&row1); - // row2 => union=1 => null data.extend_from_slice(&encode_avro_int(1)); - // row3 => union=0 => child => 2 bytes data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&row3); let mut cursor = AvroCursor::new(&data); - dec.decode(&mut cursor).unwrap(); // row1 - dec.decode(&mut cursor).unwrap(); // row2 => null - dec.decode(&mut cursor).unwrap(); // row3 + dec.decode(&mut cursor).unwrap(); // Row1 + dec.decode(&mut cursor).unwrap(); // Row2 (null) + dec.decode(&mut cursor).unwrap(); // Row3 let arr = dec.flush(None).unwrap(); let fsb = arr.as_any().downcast_ref::().unwrap(); assert_eq!(fsb.len(), 3); @@ -894,22 +838,19 @@ mod tests { assert_eq!(fsb.value(2), row3); } - // ----------------- - // Test Interval - // ----------------- #[test] fn test_interval_decoding() { - // Avro interval => 12 bytes => [ months i32, days i32, ms i32 ] - // decode 2 rows => row1 => months=1, days=2, ms=100 => row2 => months=-1, days=10, ms=9999 let dt = AvroDataType::from_codec(Codec::Duration); let mut dec = Decoder::try_new(&dt).unwrap(); - // row1 => months=1 => 01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 - // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 let row1 = [ - 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, // months=1 + 0x02, 0x00, 0x00, 0x00, // days=2 + 0x64, 0x00, 0x00, 0x00, // 
ms=100 ]; let row2 = [ - 0xFF, 0xFF, 0xFF, 0xFF, 0x0A, 0x00, 0x00, 0x00, 0x0F, 0x27, 0x00, 0x00, + 0xFF, 0xFF, 0xFF, 0xFF, // months=-1 + 0x0A, 0x00, 0x00, 0x00, // days=10 + 0x0F, 0x27, 0x00, 0x00, // ms=9999 ]; let mut data = Vec::new(); data.extend_from_slice(&row1); @@ -923,8 +864,6 @@ mod tests { .downcast_ref::() .unwrap(); assert_eq!(intervals.len(), 2); - // row0 => months=1, days=2, ms=100 => nanos=100_000_000 - // row1 => months=-1, days=10, ms=9999 => nanos=9999_000_000 let val0 = intervals.value(0); assert_eq!(val0.months, 1); assert_eq!(val0.days, 2); @@ -937,29 +876,26 @@ mod tests { #[test] fn test_interval_decoding_with_nulls() { - // Avro union => [ interval, null] + // Avro union => [ interval, null ] let dt = AvroDataType::from_codec(Codec::Duration); let child = Decoder::try_new(&dt).unwrap(); let mut dec = Decoder::Nullable( - Nullability::NullFirst, + Nullability::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(child), ); - // We'll decode 2 rows: row1 => interval => months=2, days=3, ms=500 => row2 => null - // row1 => union=0 => child => 12 bytes - // row2 => union=1 => null => no data let row1 = [ 0x02, 0x00, 0x00, 0x00, // months=2 0x03, 0x00, 0x00, 0x00, // days=3 - 0xF4, 0x01, 0x00, 0x00, - ]; // ms=500 => nanos=500_000_000 + 0xF4, 0x01, 0x00, 0x00, // ms=500 + ]; let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(0)); // union=0 => child + data.extend_from_slice(&encode_avro_int(0)); // branch=0: non-null data.extend_from_slice(&row1); - data.extend_from_slice(&encode_avro_int(1)); // union=1 => null + data.extend_from_slice(&encode_avro_int(1)); // branch=1: null let mut cursor = AvroCursor::new(&data); - dec.decode(&mut cursor).unwrap(); // row1 - dec.decode(&mut cursor).unwrap(); // row2 => null + dec.decode(&mut cursor).unwrap(); // Row1 + dec.decode(&mut cursor).unwrap(); // Row2 (null) let arr = dec.flush(None).unwrap(); let intervals = arr .as_any() @@ -974,23 +910,19 @@ mod tests { assert_eq!(val0.nanoseconds, 500_000_000); } - // ------------------- - // Tests for Enum - // ------------------- #[test] fn test_enum_decoding() { let symbols = Arc::new(["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]); let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols, Arc::new([]))); let mut decoder = Decoder::try_new(&enum_dt).unwrap(); - // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(1)); // => [2] - data.extend_from_slice(&encode_avro_int(0)); // => [0] - data.extend_from_slice(&encode_avro_int(2)); // => [4] + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(2)); let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); // => GREEN - decoder.decode(&mut cursor).unwrap(); // => RED - decoder.decode(&mut cursor).unwrap(); // => BLUE + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); let array = decoder.flush(None).unwrap(); let dict_arr = array .as_any() @@ -1010,69 +942,54 @@ mod tests { #[test] fn test_enum_decoding_with_nulls() { // Union => [Enum(...), null] - // "child" => branch_index=0 => [0x00], "null" => 1 => [0x02] let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([]))); let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); let mut 
nullable_decoder = Decoder::Nullable( - Nullability::NullFirst, + Nullability::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(inner_decoder), ); - // Indices: [1, null, 2] => in Avro union let mut data = Vec::new(); - // Row1 => union branch=0 => child => [0x00] data.extend_from_slice(&encode_avro_int(0)); - // Then child's enum index=1 => [0x02] data.extend_from_slice(&encode_avro_int(1)); - // Row2 => union branch=1 => null => [0x02] data.extend_from_slice(&encode_avro_int(1)); - // Row3 => union branch=0 => child => [0x00] data.extend_from_slice(&encode_avro_int(0)); - // Then child's enum index=2 => [0x04] - data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(0)); let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); // => GREEN - nullable_decoder.decode(&mut cursor).unwrap(); // => null - nullable_decoder.decode(&mut cursor).unwrap(); // => BLUE + nullable_decoder.decode(&mut cursor).unwrap(); + nullable_decoder.decode(&mut cursor).unwrap(); + nullable_decoder.decode(&mut cursor).unwrap(); let array = nullable_decoder.flush(None).unwrap(); let dict_arr = array .as_any() .downcast_ref::>() .unwrap(); assert_eq!(dict_arr.len(), 3); - // [GREEN, null, BLUE] assert!(dict_arr.is_valid(0)); assert!(!dict_arr.is_valid(1)); assert!(dict_arr.is_valid(2)); let keys = dict_arr.keys(); - // keys.value(0) => 1 => GREEN - // keys.value(2) => 2 => BLUE let dict_values = dict_arr.values().as_string::(); assert_eq!(dict_values.value(0), "RED"); assert_eq!(dict_values.value(1), "GREEN"); assert_eq!(dict_values.value(2), "BLUE"); } - // ------------------- - // Tests for Map - // ------------------- #[test] fn test_map_decoding_one_entry() { let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode a single map with one entry: {"hello": "world"} let mut data = Vec::new(); - // block_count=1 => zigzag => [0x02] - data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_long(1)); // block_count=1 data.extend_from_slice(&encode_avro_bytes(b"hello")); // key data.extend_from_slice(&encode_avro_bytes(b"world")); // value let mut cursor = AvroCursor::new(&data); decoder.decode(&mut cursor).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.len(), 1); assert_eq!(map_arr.value_length(0), 1); let entries = map_arr.value(0); let struct_entries = entries.as_any().downcast_ref::().unwrap(); @@ -1095,11 +1012,9 @@ mod tests { #[test] fn test_map_decoding_empty() { - // block_count=0 => empty map let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode an empty map => block_count=0 => [0x00] let data = encode_avro_long(0); decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); @@ -1108,15 +1023,10 @@ mod tests { assert_eq!(map_arr.value_length(0), 0); } - // ------------------- - // Tests for Decimal - // ------------------- #[test] fn test_decimal_decoding_fixed128() { let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); let mut decoder = Decoder::try_new(&dt).unwrap(); - // Row1 => 123.45 => unscaled=12345 => i128 0x000...3039 - // Row2 => 
-1.23 => unscaled=-123 => i128 0xFFFF...FF85 let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x39, @@ -1125,7 +1035,6 @@ mod tests { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x85, ]; - let mut data = Vec::new(); data.extend_from_slice(&row1); data.extend_from_slice(&row2); @@ -1142,36 +1051,33 @@ mod tests { #[test] fn test_decimal_decoding_bytes_with_nulls() { // Avro union => [ Decimal(4,1), null ] - // child => index=0 => [0x00], null => index=1 => [0x02] let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); let mut inner = Decoder::try_new(&dt).unwrap(); let mut decoder = Decoder::Nullable( - Nullability::NullFirst, + Nullability::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(inner), ); - // Decode three rows: [123.4, null, -123.4] - let mut data = Vec::new(); - // Row1 => child => [0x00], then decimal => e.g. 0x04D2 => 1234 => "123.4" - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); - // Row2 => null => [0x02] - data.extend_from_slice(&encode_avro_int(1)); - // Row3 => child => [0x00], then decimal => 0xFB2E => -1234 => "-123.4" - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - let arr = decoder.flush(None).unwrap(); - let dec_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec_arr.len(), 3); - assert!(dec_arr.is_valid(0)); - assert!(!dec_arr.is_valid(1)); - assert!(dec_arr.is_valid(2)); - assert_eq!(dec_arr.value_as_string(0), "123.4"); - assert_eq!(dec_arr.value_as_string(2), "-123.4"); + 'data_clear: { + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); // branch=0 => non-null + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // child's value: 1234 => "123.4" + data.extend_from_slice(&encode_avro_int(1)); // branch=1 => null + data.extend_from_slice(&encode_avro_int(0)); // branch=0 => non-null + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // child's value: -1234 => "-123.4" + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } } #[test] @@ -1180,11 +1086,10 @@ mod tests { let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); let mut inner = Decoder::try_new(&dt).unwrap(); let mut decoder = Decoder::Nullable( - Nullability::NullFirst, + Nullability::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(inner), ); - // Decode [1234.56, null, -1234.56] let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xE2, 0x40, @@ -1194,12 +1099,9 @@ mod tests { 0x1D, 0xC0, ]; let mut data = Vec::new(); - // Row1 => child => [0x00] data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&row1); - // Row2 => null => [0x02] data.extend_from_slice(&encode_avro_int(1)); - // Row3 => child => [0x00] 
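// For reference, each row below starts with the union branch index as a
// zig-zag int: encode_avro_int(0) emits 0x00 (the value branch here) and
// encode_avro_int(1) emits 0x02 (the null branch); because the payload is
// fixed(16), the value branch is followed by exactly 16 bytes with no
// length prefix.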
data.extend_from_slice(&encode_avro_int(0)); data.extend_from_slice(&row3); let mut cursor = AvroCursor::new(&data); @@ -1216,33 +1118,16 @@ mod tests { assert_eq!(dec_arr.value_as_string(2), "-1234.56"); } - // ------------------- - // Tests for List - // ------------------- #[test] fn test_list_decoding() { - // Avro array => block1(count=2), item1, item2, block2(count=0 => end) - // - // 1. Create 2 rows: - // Row1 => [10, 20] - // Row2 => [ ] - // - // 2. flush => should yield 2-element array => first row has 2 items, second row has 0 items let item_dt = AvroDataType::from_codec(Codec::Int32); let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); let mut decoder = Decoder::try_new(&list_dt).unwrap(); - // Row1 => block_count=2 => item=10 => item=20 => block_count=0 => end - // - 2 => zigzag => [0x04] - // - item=10 => zigzag => [0x14] - // - item=20 => zigzag => [0x28] - // - 0 => [0x00] let mut row1 = Vec::new(); - row1.extend_from_slice(&encode_avro_long(2)); // block_count=2 - row1.extend_from_slice(&encode_avro_int(10)); // item=10 - row1.extend_from_slice(&encode_avro_int(20)); // item=20 - row1.extend_from_slice(&encode_avro_long(0)); // end of array - - // Row2 => block_count=0 => empty array + row1.extend_from_slice(&encode_avro_long(2)); + row1.extend_from_slice(&encode_avro_int(10)); + row1.extend_from_slice(&encode_avro_int(20)); + row1.extend_from_slice(&encode_avro_long(0)); let mut row2 = Vec::new(); row2.extend_from_slice(&encode_avro_long(0)); let mut cursor = AvroCursor::new(&row1); @@ -1252,8 +1137,6 @@ mod tests { let array = decoder.flush(None).unwrap(); let list_arr = array.as_any().downcast_ref::().unwrap(); assert_eq!(list_arr.len(), 2); - // row0 => 2 items => [10, 20] - // row1 => 0 items let offsets = list_arr.value_offsets(); assert_eq!(offsets, &[0, 2, 2]); let values = list_arr.values(); @@ -1265,24 +1148,14 @@ mod tests { #[test] fn test_list_decoding_with_negative_block_count() { - // Start with single row => [1, 2, 3] - // We'll store them in a single negative block => block_count=-3 => #items=3 - // Then read block_size => let's pretend it's 9 bytes, etc. Then the items. - // Then a block_count=0 => done let item_dt = AvroDataType::from_codec(Codec::Int32); let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); let mut decoder = Decoder::try_new(&list_dt).unwrap(); - // block_count=-3 => zigzag => (-3 << 1) ^ (-3 >> 63) - // => -6 ^ -1 => ... - // Encode directly with `encode_avro_long(-3)`. let mut data = encode_avro_long(-3); - // Next => block_size => let's pretend 12 => encode_avro_long(12) data.extend_from_slice(&encode_avro_long(12)); - // Then 3 items => [1, 2, 3] data.extend_from_slice(&encode_avro_int(1)); data.extend_from_slice(&encode_avro_int(2)); data.extend_from_slice(&encode_avro_int(3)); - // Then block_count=0 => done data.extend_from_slice(&encode_avro_long(0)); let mut cursor = AvroCursor::new(&data); decoder.decode(&mut cursor).unwrap(); diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index d8722048a463..174a28fba62d 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
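// A sketch of the Avro block framing exercised by the negative-count list
// test above, reusing the encode_avro_* test helpers defined there; the
// function name is illustrative. A negative block count n announces (-n)
// items preceded by the block's size in bytes, which lets a reader skip
// the whole block without decoding each item.
fn encode_sized_array_block(items: &[i32]) -> Vec<u8> {
    let mut body = Vec::new();
    for item in items {
        body.extend_from_slice(&encode_avro_int(*item));
    }
    let mut out = encode_avro_long(-(items.len() as i64)); // negative count
    out.extend_from_slice(&encode_avro_long(body.len() as i64)); // block size
    out.extend_from_slice(&body);
    out.extend_from_slice(&encode_avro_long(0)); // end of array
    out
}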
+use crate::codec::Nullability; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -125,7 +126,7 @@ pub struct Record<'a> { pub name: &'a str, #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, @@ -140,14 +141,14 @@ pub struct Record<'a> { pub struct RecordField<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, - #[serde(borrow, default, skip_serializing_if = "Option::is_none")] - pub default: Option<&'a str>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, } /// An enumeration @@ -159,14 +160,14 @@ pub struct Enum<'a> { pub name: &'a str, #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] pub symbols: Vec<&'a str>, - #[serde(borrow, default, skip_serializing_if = "Option::is_none")] - pub default: Option<&'a str>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, #[serde(flatten)] pub attributes: Attributes<'a>, } @@ -209,6 +210,24 @@ pub struct Fixed<'a> { pub attributes: Attributes<'a>, } +/// An Avro data type (not an Avro schema) +#[derive(Debug, Clone)] +pub struct AvroDataType { + pub nullability: Option, + pub metadata: HashMap, + pub codec: crate::codec::Codec, +} + +impl AvroDataType { + /// Returns an Arrow [`Field`] with the given name, + /// respecting this type’s `nullability` (instead of forcing `true`). + pub fn field_with_name(&self, name: &str) -> arrow_schema::Field { + let d = self.codec.data_type(); + let is_nullable = self.nullability.is_some(); + arrow_schema::Field::new(name, d, is_nullable).with_metadata(self.metadata.clone()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -365,7 +384,7 @@ mod tests { default: None, } ], - attributes: Attributes::default(), + attributes: Default::default(), })) ); @@ -507,7 +526,7 @@ mod tests { aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - Schema::Complex(ComplexType::Map(Map { + Schema::Complex(ComplexType::Map(crate::schema::Map { values: Box::new(Schema::TypeName(TypeName::Primitive( PrimitiveType::Bytes ))), @@ -569,4 +588,29 @@ mod tests { assert_eq!(schema, with_aliases); } + + #[test] + fn test_default_parsing() { + // Test that a default value is correctly parsed for a record field. 
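// With `default` carried as a serde_json::Value, numeric, string, and
// boolean defaults all survive parsing unchanged: {"default": 10} becomes
// json!(10) and {"default": "default_str"} becomes json!("default_str"),
// while a field with no default stays None, as asserted below.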
+ let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "a", "type": "int", "default": 10}, + {"name": "b", "type": "string", "default": "default_str"}, + {"name": "c", "type": "boolean"} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = schema { + assert_eq!(rec.fields.len(), 3); + assert_eq!(rec.fields[0].default, Some(json!(10))); + assert_eq!(rec.fields[1].default, Some(json!("default_str"))); + assert_eq!(rec.fields[2].default, None); + } else { + panic!("Expected record schema"); + } + } } From eee39485051f3908da6c92f59529cdb9c5382ead Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 2 Feb 2025 22:58:03 -0600 Subject: [PATCH 26/38] Implemented `test_nonnullable_impala` in `arrow-avro/src/reader/mod.rs` Signed-off-by: Connor Sanders --- arrow-avro/src/reader/mod.rs | 283 +++++++++++++++++++++++++++++++- arrow-avro/src/reader/record.rs | 124 +++++--------- 2 files changed, 323 insertions(+), 84 deletions(-) diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 276b17475fb9..ca1ebd02c738 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -78,7 +78,8 @@ mod test { use crate::reader::{read_blocks, read_header}; use crate::test_util::arrow_test_data; use arrow_array::builder::{ - BooleanBuilder, Float32Builder, Int64Builder, ListBuilder, StringBuilder, StructBuilder, + ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, + ListBuilder, MapBuilder, StringBuilder, StructBuilder, }; use arrow_array::*; use arrow_array::{Array, Float64Array, Int32Array, RecordBatch, StringArray, StructArray}; @@ -754,4 +755,284 @@ mod test { "Decoded RecordBatch does not match expected data for nested records (batch size 3)" ); } + + #[test] + fn test_nonnullable_impala() { + let file = arrow_test_data("avro/nonnullable.impala.avro"); + let id = Int64Array::from(vec![Some(8)]); + let mut int_array_builder = ListBuilder::new(Int32Builder::new()); + { + let vb = int_array_builder.values(); + vb.append_value(-1); + } + int_array_builder.append(true); // finalize one sub-list + let int_array = int_array_builder.finish(); + let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + { + let inner_list_builder = iaa_builder.values(); + { + let vb = inner_list_builder.values(); + vb.append_value(-1); + vb.append_value(-2); + } + inner_list_builder.append(true); + inner_list_builder.append(true); + } + iaa_builder.append(true); + let int_array_array = iaa_builder.finish(); + use arrow_array::builder::MapFieldNames; + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut int_map_builder = + MapBuilder::new(Some(field_names), StringBuilder::new(), Int32Builder::new()); + { + let (keys, vals) = int_map_builder.entries(); + keys.append_value("k1"); + vals.append_value(-1); + } + int_map_builder.append(true).unwrap(); // finalize map for row 0 + let int_map = int_map_builder.finish(); + let field_names2 = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut ima_builder = ListBuilder::new(MapBuilder::new( + Some(field_names2), + StringBuilder::new(), + Int32Builder::new(), + )); + { + let map_builder = ima_builder.values(); + map_builder.append(true).unwrap(); + { + let (keys, vals) = map_builder.entries(); + keys.append_value("k1"); + 
vals.append_value(1); + } + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + } + ima_builder.append(true); + let int_map_array_ = ima_builder.finish(); + let mut nested_sb = StructBuilder::new( + vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + Arc::new(Field::new( + "c", + DataType::Struct( + vec![Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), + true, + ))), + true, + ))), + true, + )] + .into(), + ), + true, + )), + Arc::new(Field::new( + "G", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct( + vec![Field::new( + "h", + DataType::Struct( + vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )] + .into(), + ), + true, + )] + .into(), + ), + true, + ), + ] + .into(), + ), + false, + )), + false, + ), + true, + )), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(ListBuilder::new(Int32Builder::new())), + { + let d_field = Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), + true, + ))), + true, + ))), + true, + ); + Box::new(StructBuilder::new( + vec![Arc::new(d_field)], + vec![Box::new({ + let ef_struct_builder = StructBuilder::new( + vec![ + Arc::new(Field::new("e", DataType::Int32, true)), + Arc::new(Field::new("f", DataType::Utf8, true)), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + let list_of_ef = ListBuilder::new(ef_struct_builder); + ListBuilder::new(list_of_ef) + })], + )) + }, + { + let map_field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let i_list_builder = ListBuilder::new(Float64Builder::new()); + let h_struct = StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + true, + ))], + vec![Box::new(i_list_builder)], + ); + let g_value_builder = StructBuilder::new( + vec![Arc::new(Field::new( + "h", + DataType::Struct( + vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )] + .into(), + ), + true, + ))], + vec![Box::new(h_struct)], + ); + Box::new(MapBuilder::new( + Some(map_field_names), + StringBuilder::new(), + g_value_builder, + )) + }, + ], + ); + nested_sb.append(true); + { + let a_builder = nested_sb.field_builder::(0).unwrap(); + a_builder.append_value(-1); + } + { + let b_builder = nested_sb + .field_builder::>(1) + .unwrap(); + { + let vb = b_builder.values(); + vb.append_value(-1); + } + b_builder.append(true); + } + { + let c_struct_builder = nested_sb.field_builder::(2).unwrap(); + c_struct_builder.append(true); + let d_list_builder = c_struct_builder + .field_builder::>>(0) + .unwrap(); + { + let sub_list_builder = d_list_builder.values(); + { + let ef_struct = sub_list_builder.values(); + ef_struct.append(true); + { + let e_b = 
ef_struct.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); + } + sub_list_builder.append(true); + } + d_list_builder.append(true); + } + } + { + let g_map_builder = nested_sb + .field_builder::>(3) + .unwrap(); + g_map_builder.append(true).unwrap(); + } + let nested_struct = nested_sb.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("ID", Arc::new(id) as Arc, true), + ("Int_Array", Arc::new(int_array), true), + ("int_array_array", Arc::new(int_array_array), true), + ("Int_Map", Arc::new(int_map), true), + ("int_map_array", Arc::new(int_map_array_), true), + ("nested_Struct", Arc::new(nested_struct), true), + ]) + .unwrap(); + let batch_large = read_file(&file, 8); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index cd2ba24c759d..85887fcc6cbf 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -77,7 +77,6 @@ impl RecordDecoder { .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), arrays) } } @@ -86,69 +85,42 @@ impl RecordDecoder { #[derive(Debug)] enum Decoder { /// Primitive Types - /// - /// Avro `null` Null(usize), - /// Avro `boolean` Boolean(BooleanBufferBuilder), - /// Avro `int` => i32 Int32(Vec), - /// Avro `long` => i64 Int64(Vec), - /// Avro `float` => f32 Float32(Vec), - /// Avro `double` => f64 Float64(Vec), - /// Avro `bytes` => Arrow Binary Binary(OffsetBufferBuilder, Vec), - /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), - /// Complex Types - /// - /// Avro `record` + /// Complex Record(Fields, Vec), - /// Avro `enum` => Dictionary(int32 -> string) Enum(Arc<[String]>, Vec), - /// Avro `array` List(FieldRef, OffsetBufferBuilder, Box), - /// Avro `map` Map( FieldRef, OffsetBufferBuilder, OffsetBufferBuilder, Vec, Box, - usize, ), - /// Avro union that includes `null` Nullable(Nullability, NullBufferBuilder, Box), - /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` Fixed(i32, Vec), /// Logical Types - /// - /// Avro decimal => Arrow decimal Decimal(usize, Option, Option, DecimalBuilder), - /// Avro `date` => Date32 Date32(Vec), - /// Avro `time-millis` => Time32(Millisecond) TimeMillis(Vec), - /// Avro `time-micros` => Time64(Microsecond) TimeMicros(Vec), - /// Avro `timestamp-millis` (bool = UTC?) TimestampMillis(bool, Vec), - /// Avro `timestamp-micros` (bool = UTC?) TimestampMicros(bool, Vec), - /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) Interval(Vec), } impl Decoder { - /// Checks if the Decoder is nullable, i.e. wrapped in `Nullable`. fn is_nullable(&self) -> bool { matches!(self, Self::Nullable(_, _, _)) } - /// Create a `Decoder` from an [`AvroDataType`]. fn try_new(data_type: &AvroDataType) -> Result { let decoder = match &data_type.codec { Codec::Null => Self::Null(0), @@ -203,7 +175,6 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), Box::new(Self::try_new(value_type)?), - 0, ) } Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), @@ -234,7 +205,6 @@ impl Decoder { } } - /// Append a null to this decoder. 
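// A freestanding illustration of why a null map row must still push an
// offsets entry (see `append_null` below): offsets have to stay monotonic,
// so a null row contributes an empty range rather than shifting later
// rows. The function name is illustrative; the builder is the arrow-buffer
// type already used in this file.
fn offsets_with_null_row() {
    use arrow_buffer::OffsetBufferBuilder;
    let mut off = OffsetBufferBuilder::<i32>::new(4);
    off.push_length(2); // row 0 has two entries
    off.push_length(0); // row 1 is null, an empty range
    off.push_length(3); // row 2 has three entries
    let offsets = off.finish();
    assert_eq!(&offsets[..], &[0, 2, 2, 5]);
}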
fn append_null(&mut self) { match self { Self::Null(n) => *n += 1, @@ -257,9 +227,10 @@ impl Decoder { off.push_length(0); child.append_null(); } - Self::Map(_, key_off, map_off, _, _, entry_count) => { + Self::Map(_, key_off, map_off, _, child) => { key_off.push_length(0); - map_off.push_length(*entry_count); + map_off.push_length(0); + child.append_null(); } Self::Fixed(fsize, buf) => { buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); @@ -278,7 +249,6 @@ impl Decoder { } } - /// Decode a single row of data from `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { Self::Null(count) => *count += 1, @@ -302,15 +272,15 @@ impl Decoder { let total_items = read_array_blocks(buf, |b| child.decode(b))?; off.push_length(total_items); } - Self::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { + Self::Map(_, key_off, map_off, key_data, val_decoder) => { let newly_added = read_map_blocks(buf, |b| { let kb = b.get_bytes()?; key_off.push_length(kb.len()); key_data.extend_from_slice(kb); val_decoder.decode(b) })?; - *entry_count += newly_added; - map_off.push_length(*entry_count); + + map_off.push_length(newly_added); } Self::Nullable(nb, nulls, child) => { let branch = buf.get_int()?; @@ -372,7 +342,6 @@ impl Decoder { Ok(()) } - /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. fn flush(&mut self, nulls: Option) -> Result { match self { Self::Null(len) => { @@ -428,7 +397,7 @@ impl Decoder { let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); Ok(Arc::new(arr)) } - Self::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { + Self::Map(field, key_off, map_off, key_data, val_dec) => { let moff = flush_offsets(map_off); let koff = flush_offsets(key_off); let kd = flush_values(key_data).into(); @@ -444,7 +413,6 @@ impl Decoder { None, ); let map_arr = MapArray::new(field.clone(), moff, entries, nulls, false); - *entry_count = 0; Ok(Arc::new(map_arr)) } Self::Fixed(fsize, raw) => { @@ -518,7 +486,7 @@ fn read_array_blocks( 0 => break, n if n < 0 => { let item_count = (-n) as usize; - let _block_size = buf.get_long()?; // size (ignored) + let _block_size = buf.get_long()?; for _ in 0..item_count { decode_item(buf)?; } @@ -541,20 +509,31 @@ fn read_map_blocks( buf: &mut AvroCursor, mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { - let block_count = buf.get_long()?; - if block_count <= 0 { - Ok(0) - } else { - let n = block_count as usize; - for _ in 0..n { - decode_entry(buf)?; + let mut total_entries = 0usize; + loop { + let block_count = buf.get_long()?; + match block_count { + 0 => break, + n if n < 0 => { + let item_count = (-n) as usize; + let _block_size = buf.get_long()?; + for _ in 0..item_count { + decode_entry(buf)?; + } + total_entries += item_count; + } + n => { + let item_count = n as usize; + for _ in 0..item_count { + decode_entry(buf)?; + } + total_entries += item_count; + } } - Ok(n) } + Ok(total_entries) } -/// Flush a [`Vec`] of primitive values to a [`PrimitiveArray`], applying optional `nulls`. -#[inline] fn flush_primitive( values: &mut Vec, nulls: Option, @@ -562,14 +541,10 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -/// Flush an [`OffsetBufferBuilder`]. -#[inline] fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() } -/// Take ownership of `values`. 
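// Byte-level sketch of a multi-block map stream that the rewritten
// read_map_blocks above accepts, built with the encode_avro_* helpers from
// this file's tests; the function name is illustrative. Two single-entry
// blocks of a string-to-string map are followed by the 0 terminator.
fn two_block_map_stream() -> Vec<u8> {
    let mut data = Vec::new();
    data.extend_from_slice(&encode_avro_long(1)); // first block: one entry
    data.extend_from_slice(&encode_avro_bytes(b"k1")); // key
    data.extend_from_slice(&encode_avro_bytes(b"v1")); // value
    data.extend_from_slice(&encode_avro_long(1)); // second block: one entry
    data.extend_from_slice(&encode_avro_bytes(b"k2"));
    data.extend_from_slice(&encode_avro_bytes(b"v2"));
    data.extend_from_slice(&encode_avro_long(0)); // end of map
    data
}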
-#[inline] fn flush_values(values: &mut Vec) -> Vec { std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) } @@ -582,7 +557,6 @@ enum DecimalBuilder { } impl DecimalBuilder { - /// Create a new DecimalBuilder given precision, scale, and optional byte-size (`fixed`). fn new( precision: usize, scale: Option, @@ -622,7 +596,6 @@ impl DecimalBuilder { } } - /// Append sign-extended bytes to this decimal builder fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> { match self { Self::Decimal128(b) => { @@ -639,7 +612,6 @@ impl DecimalBuilder { Ok(()) } - /// Append a null decimal value (0) fn append_null(&mut self) -> Result<(), ArrowError> { match self { Self::Decimal128(b) => { @@ -654,7 +626,6 @@ impl DecimalBuilder { Ok(()) } - /// Finish building the decimal array, returning an [`ArrayRef`]. fn finish( self, nulls: Option, @@ -680,7 +651,6 @@ impl DecimalBuilder { } } -/// Sign-extend `raw` to 16 bytes. fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { let extended = sign_extend(raw, 16); if extended.len() != 16 { @@ -694,7 +664,6 @@ fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { Ok(arr) } -/// Sign-extend `raw` to 32 bytes. fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { let extended = sign_extend(raw, 32); if extended.len() != 32 { @@ -708,7 +677,6 @@ fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { Ok(arr) } -/// Sign-extend the first byte to produce `target_len` bytes total. fn sign_extend(raw: &[u8], target_len: usize) -> Vec { if raw.is_empty() { return vec![0; target_len]; @@ -729,7 +697,7 @@ mod tests { use super::*; use arrow_array::{ cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, - IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, + IntervalMonthDayNanoArray, ListArray, MapArray, }; use std::sync::Arc; @@ -982,32 +950,22 @@ mod tests { let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); let mut decoder = Decoder::try_new(&map_type).unwrap(); let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); // block_count=1 - data.extend_from_slice(&encode_avro_bytes(b"hello")); // key - data.extend_from_slice(&encode_avro_bytes(b"world")); // value + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); + data.extend_from_slice(&encode_avro_bytes(b"world")); + data.extend_from_slice(&encode_avro_long(0)); let mut cursor = AvroCursor::new(&data); decoder.decode(&mut cursor).unwrap(); let array = decoder.flush(None).unwrap(); let map_arr = array.as_any().downcast_ref::().unwrap(); assert_eq!(map_arr.len(), 1); assert_eq!(map_arr.value_length(0), 1); - let entries = map_arr.value(0); - let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); - let key_arr = struct_entries - .column_by_name("key") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let val_arr = struct_entries - .column_by_name("value") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(key_arr.value(0), "hello"); - assert_eq!(val_arr.value(0), "world"); + let struct_arr = map_arr.value(0); + assert_eq!(struct_arr.len(), 1); + let keys = struct_arr.column(0).as_string::(); + let vals = struct_arr.column(1).as_string::(); + assert_eq!(keys.value(0), "hello"); + assert_eq!(vals.value(0), "world"); } #[test] From 6d1ae4901075c68c54de1b85b6513fed4c971544 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Thu, 6 Feb 2025 
13:41:19 -0600 Subject: [PATCH 27/38] Implemented `test_nullable_impala` testcase in `arrow-avro/src/reader/mod.rs` Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 600 ++++++++++++-- arrow-avro/src/reader/cursor.rs | 1 + arrow-avro/src/reader/mod.rs | 72 +- arrow-avro/src/reader/record.rs | 1350 +++++++++++++++++++++++-------- arrow-avro/src/schema.rs | 174 +++- 5 files changed, 1803 insertions(+), 394 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 5ed4d58dd09c..d173362c3a91 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -27,11 +27,16 @@ use std::sync::Arc; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Nullability { - /// The nulls are encoded as the first union variant + /// The nulls are encoded as the first union variant => `[ "null", T ]` NullFirst, - /// The nulls are encoded as the second union variant + /// The nulls are encoded as the second union variant => `[ T, "null" ]` + /// + /// **Important**: In Impala’s out-of-spec approach, branch=0 => null, branch=1 => decode T. + /// This is reversed from the typical “standard” Avro interpretation for `[T,"null"]`. + /// + /// NullSecond, } @@ -82,9 +87,11 @@ impl AvroField { pub fn field(&self) -> Field { let mut fld = self.data_type.field_with_name(&self.name); if let Some(def_val) = &self.default { - let mut md = fld.metadata().clone(); - md.insert("avro.default".to_string(), def_val.to_string()); - fld = fld.with_metadata(md); + if !def_val.is_null() { + let mut md = fld.metadata().clone(); + md.insert("avro.default".to_string(), def_val.to_string()); + fld = fld.with_metadata(md); + } } fld } @@ -299,7 +306,6 @@ fn make_data_type<'a>( ))), } } - // complex Schema::Complex(c) => match c { ComplexType::Record(r) => { let ns = r.namespace.or(namespace); @@ -548,19 +554,56 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { } Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), Interval(IntervalUnit::MonthDayNano) => Codec::Duration, - other => { - let _ = other; - Codec::String - } + _ => Codec::String, } } #[cfg(test)] mod tests { use super::*; - use arrow_schema::Field; + use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; + use serde_json::json; + use std::collections::HashMap; use std::sync::Arc; + #[test] + fn test_skip_avro_default_null_in_metadata() { + let dt = AvroDataType::from_codec(Codec::Int32); + let field = AvroField { + name: "test_col".into(), + data_type: dt, + default: Some(json!(null)), + }; + let arrow_field = field.field(); + assert!(arrow_field.metadata().get("avro.default").is_none()); + } + + #[test] + fn test_store_avro_default_nonnull_in_metadata() { + let dt = AvroDataType::from_codec(Codec::Int32); + let field = AvroField { + name: "test_col".into(), + data_type: dt, + default: Some(json!(42)), + }; + let arrow_field = field.field(); + let md = arrow_field.metadata(); + let got = md.get("avro.default").cloned(); + assert_eq!(got, Some("42".to_string())); + } + + #[test] + fn test_no_default_metadata_if_none() { + let dt = AvroDataType::from_codec(Codec::String); + let field = AvroField { + name: "col".to_string(), + data_type: dt, + default: None, + }; + let arrow_field = field.field(); + assert!(arrow_field.metadata().get("avro.default").is_none()); + } + #[test] fn test_avro_field() { let field_codec = AvroDataType::from_codec(Codec::Int64); @@ 
-575,7 +618,7 @@ mod tests { assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); let arrow_field = avro_field.field(); assert_eq!(arrow_field.name(), "long_col"); - assert_eq!(arrow_field.data_type(), &Int64); + assert_eq!(arrow_field.data_type(), &DataType::Int64); assert!(!arrow_field.is_nullable()); } @@ -601,74 +644,74 @@ mod tests { let codec = Codec::Fixed(12); let dt = codec.data_type(); match dt { - FixedSizeBinary(n) => assert_eq!(n, 12), + DataType::FixedSizeBinary(n) => assert_eq!(n, 12), _ => panic!("Expected FixedSizeBinary(12)"), } } #[test] fn test_arrow_field_to_avro_field() { - let arrow_field = Field::new("Null", Null, true); + let arrow_field = Field::new("Null", DataType::Null, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Null)); - let arrow_field = Field::new("Boolean", Boolean, true); + let arrow_field = Field::new("Boolean", DataType::Boolean, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Boolean)); - let arrow_field = Field::new("Int32", Int32, true); + let arrow_field = Field::new("Int32", DataType::Int32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Int32)); - let arrow_field = Field::new("Int64", Int64, true); + let arrow_field = Field::new("Int64", DataType::Int64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Int64)); - let arrow_field = Field::new("Float32", Float32, true); + let arrow_field = Field::new("Float32", DataType::Float32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Float32)); - let arrow_field = Field::new("Float64", Float64, true); + let arrow_field = Field::new("Float64", DataType::Float64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Float64)); - let arrow_field = Field::new("Binary", Binary, true); + let arrow_field = Field::new("Binary", DataType::Binary, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Binary)); - let arrow_field = Field::new("Utf8", Utf8, true); + let arrow_field = Field::new("Utf8", DataType::Utf8, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::String)); - let arrow_field = Field::new("Decimal128", Decimal128(1, 2), true); + let arrow_field = Field::new("Decimal128", DataType::Decimal128(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(16)) )); - let arrow_field = Field::new("Decimal256", Decimal256(1, 2), true); + let arrow_field = Field::new("Decimal256", DataType::Decimal256(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(32)) )); - let arrow_field = Field::new("Date32", Date32, true); + let arrow_field = Field::new("Date32", DataType::Date32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Date32)); - let arrow_field = Field::new("Time32", Time32(TimeUnit::Millisecond), false); + let arrow_field = Field::new("Time32", 
DataType::Time32(TimeUnit::Millisecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::TimeMillis)); - let arrow_field = Field::new("Time32", Time64(TimeUnit::Microsecond), false); + let arrow_field = Field::new("Time32", DataType::Time64(TimeUnit::Microsecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::TimeMicros)); let arrow_field = Field::new( "utc_ts_ms", - Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -679,7 +722,7 @@ mod tests { let arrow_field = Field::new( "utc_ts_us", - Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -688,30 +731,45 @@ mod tests { Codec::TimestampMicros(true) )); - let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); + let arrow_field = Field::new( + "local_ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::TimestampMillis(false) )); - let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); + let arrow_field = Field::new( + "local_ts_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::TimestampMicros(false) )); - let arrow_field = Field::new("Interval", Interval(IntervalUnit::MonthDayNano), false); + let arrow_field = Field::new( + "Interval", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Duration)); let arrow_field = Field::new( "Struct", - Struct(Fields::from(vec![ - Field::new("a", Boolean, false), - Field::new("b", Float64, false), - ])), + DataType::Struct( + vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Float64, false), + ] + .into(), + ), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -728,7 +786,7 @@ mod tests { let arrow_field = Field::new( "DictionaryEnum", - Dictionary(Box::new(Utf8), Box::new(Int32)), + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -736,30 +794,32 @@ mod tests { let arrow_field = Field::new( "DictionaryString", - Dictionary(Box::new(Int32), Box::new(Boolean)), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Boolean)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::String)); - let field = Field::new("Utf8", Utf8, true); - let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); + // Array with nullable items + let field = Field::new("Utf8", DataType::Utf8, true); + let arrow_field = Field::new( + "Array with nullable items", + DataType::List(Arc::new(field)), + true, + ); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { - assert!(matches!( - avro_data_type.nullability, 
- Some(Nullability::NullFirst) - )); + assert_eq!(avro_data_type.nullability, Some(Nullability::NullFirst)); assert_eq!(avro_data_type.metadata.len(), 0); assert!(matches!(avro_data_type.codec, Codec::String)); } else { panic!("Expected Codec::Array"); } - let field = Field::new("Utf8", Utf8, false); + let field = Field::new("Utf8", DataType::Utf8, false); let arrow_field = Field::new( "Array with non-nullable items", - List(Arc::new(field)), + DataType::List(Arc::new(field)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -773,10 +833,10 @@ mod tests { let entries_field = Field::new( "entries", - Struct( + DataType::Struct( vec![ - Field::new("key", Utf8, false), - Field::new("value", Utf8, true), + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, true), ] .into(), ), @@ -784,15 +844,12 @@ mod tests { ); let arrow_field = Field::new( "Map with nullable items", - Map(Arc::new(entries_field), true), + DataType::Map(Arc::new(entries_field), true), true, ); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Map(avro_data_type) = &avro_field.data_type().codec { - assert!(matches!( - avro_data_type.nullability, - Some(Nullability::NullFirst) - )); + assert_eq!(avro_data_type.nullability, Some(Nullability::NullFirst)); assert_eq!(avro_data_type.metadata.len(), 0); assert!(matches!(avro_data_type.codec, Codec::String)); } else { @@ -801,15 +858,18 @@ mod tests { let arrow_field = Field::new( "Utf8", - Struct(Fields::from(vec![ - Field::new("key", Utf8, false), - Field::new("value", Utf8, false), - ])), + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ] + .into(), + ), false, ); let arrow_field = Field::new( "Map with non-nullable items", - Map(Arc::new(arrow_field), false), + DataType::Map(Arc::new(arrow_field), false), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -820,8 +880,7 @@ mod tests { } else { panic!("Expected Codec::Map"); } - - let arrow_field = Field::new("FixedSizeBinary", FixedSizeBinary(8), false); + let arrow_field = Field::new("FixedSizeBinary", DataType::FixedSizeBinary(8), false); let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = &avro_field.data_type().codec; assert!(matches!(codec, Codec::Fixed(8))); @@ -829,10 +888,9 @@ mod tests { #[test] fn test_arrow_field_to_avro_field_meta_namespace() { - let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( - "namespace".to_string(), - "arrow_meta_ns".to_string(), - )])); + let arrow_field = Field::new("test_meta", DataType::Utf8, true).with_metadata( + HashMap::from([("namespace".to_string(), "arrow_meta_ns".to_string())]), + ); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); let actual_str = format!("{:?}", avro_field.data_type().codec); @@ -846,4 +904,416 @@ mod tests { Some(&"arrow_meta_ns".to_string()) ); } + + #[test] + fn test_union_long_null() { + let json_schema = r#" + { + "type": "record", + "name": "test_long_null", + "fields": [ + {"name": "f0", "type": ["long", "null"]} + ] + } + "#; + let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "f0"); + let child_dt = fields[0].data_type(); + // "long" + "null" => NullSecond + 
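// Because the non-null variant is listed first in ["long", "null"], the
// resolver records NullSecond; in Arrow terms the column is simply a
// nullable Int64, whichever side of the union the null sits on.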
assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(child_dt.codec, Codec::Int64)); + } + _ => panic!("Expected a record with a single [long,null] field"), + } + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if let Codec::Record(fields) = &top_dt.codec { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "f0"); + let child_dt = fields[0].data_type(); + assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(child_dt.codec, Codec::Int64)); + } else { + panic!("Expected a record with a single [long,null] field (make_data_type)"); + } + } + + #[test] + fn test_union_array_of_int_null() { + let json_schema = r#" + { + "type":"record", + "name":"test_array_int_null", + "fields":[ + {"name":"arr","type":[{"type":"array","items":["int","null"]},"null"]} + ] + } + "#; + let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "arr"); + let child_dt = fields[0].data_type(); + assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(item_type) = &child_dt.codec { + assert_eq!(item_type.nullability, Some(Nullability::NullSecond)); + assert!(matches!(item_type.codec, Codec::Int32)); + } else { + panic!("Expected Codec::Array for 'arr' field"); + } + } + _ => panic!("Expected a record with a single union array field"), + } + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if let Codec::Record(fields) = &top_dt.codec { + assert_eq!(fields.len(), 1); + let arr_dt = fields[0].data_type(); + assert_eq!(arr_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(item_type) = &arr_dt.codec { + assert_eq!(item_type.nullability, Some(Nullability::NullSecond)); + assert!(matches!(item_type.codec, Codec::Int32)); + } else { + panic!("Expected Codec::Array (make_data_type)"); + } + } else { + panic!("Expected record (make_data_type)"); + } + } + + #[test] + fn test_union_nested_array_of_int_null() { + let json_schema = r#" + { + "type":"record", + "name":"test_nested_array_int_null", + "fields":[ + { + "name":"nested_arr", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":["int","null"] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "nested_arr"); + let outer_dt = fields[0].data_type(); + assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(mid_dt) = &outer_dt.codec { + assert_eq!(mid_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(inner_dt) = &mid_dt.codec { + assert_eq!(inner_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(inner_dt.codec, Codec::Int32)); + } else { + panic!("Expected inner Codec::Array for nested_arr"); + } + } else { + panic!("Expected outer Codec::Array for nested_arr"); + } + } + _ => panic!("Expected a record with a single nested union array field"), + } + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if 
let Codec::Record(fields) = &top_dt.codec { + assert_eq!(fields.len(), 1); + let outer_dt = fields[0].data_type(); + assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(mid_dt) = &outer_dt.codec { + assert_eq!(mid_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(inner_dt) = &mid_dt.codec { + assert_eq!(inner_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(inner_dt.codec, Codec::Int32)); + } else { + panic!("Expected inner array (make_data_type)"); + } + } else { + panic!("Expected outer array (make_data_type)"); + } + } else { + panic!("Expected record (make_data_type)"); + } + } + + #[test] + fn test_union_map_of_int_null() { + let json_schema = r#" + { + "type":"record", + "name":"test_map_int_null", + "fields":[ + {"name":"map_field","type":[{"type":"map","values":["int","null"]},"null"]} + ] + } + "#; + let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "map_field"); + let map_dt = fields[0].data_type(); + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Map(value_type) = &map_dt.codec { + assert_eq!(value_type.nullability, Some(Nullability::NullSecond)); + assert!(matches!(value_type.codec, Codec::Int32)); + } else { + panic!("Expected Codec::Map for map_field"); + } + } + _ => panic!("Expected a record with a single union map field"), + } + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if let Codec::Record(fields) = &top_dt.codec { + assert_eq!(fields.len(), 1); + let map_dt = fields[0].data_type(); + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Map(val_dt) = &map_dt.codec { + assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(val_dt.codec, Codec::Int32)); + } else { + panic!("Expected map in make_data_type"); + } + } else { + panic!("Expected record in make_data_type"); + } + } + + #[test] + fn test_union_map_array_of_int_null() { + let json_schema = r#" + { + "type":"record", + "name":"test_map_array_int_null", + "fields":[ + { + "name":"map_arr", + "type":[ + { + "type":"array", + "items":[ + { + "type":"map", + "values":["int","null"] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "map_arr"); + let outer_dt = fields[0].data_type(); + assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(map_dt) = &outer_dt.codec { + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Map(val_dt) = &map_dt.codec { + assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(val_dt.codec, Codec::Int32)); + } else { + panic!("Expected Codec::Map for map_arr items"); + } + } else { + panic!("Expected Codec::Array for map_arr"); + } + } + _ => panic!("Expected a record with a single union array-of-map field"), + } + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if let Codec::Record(fields) = &top_dt.codec { + assert_eq!(fields.len(), 
1); + let outer_dt = fields[0].data_type(); + assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Array(map_dt) = &outer_dt.codec { + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Map(val_dt) = &map_dt.codec { + assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(val_dt.codec, Codec::Int32)); + } else { + panic!("Expected Codec::Map in make_data_type"); + } + } else { + panic!("Expected Codec::Array in make_data_type"); + } + } else { + panic!("Expected record in make_data_type"); + } + } + + #[test] + fn test_union_nested_struct_out_of_spec() { + let json_schema = r#" + { + "type":"record","name":"topLevelRecord","fields":[ + {"name":"nested_struct","type":[ + { + "type":"record", + "name":"nested_struct", + "namespace":"topLevelRecord", + "fields":[ + {"name":"A","type":["int","null"]}, + { + "name":"b", + "type":[{"type":"array","items":["int","null"]},"null"] + }, + { + "name":"C", + "type":[ + { + "type":"record", + "name":"C", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"d", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":[ + { + "type":"record", + "name":"d", + "namespace":"topLevelRecord.nested_struct.C", + "fields":[ + {"name":"E","type":["int","null"]}, + {"name":"F","type":["string","null"]} + ] + }, + "null" + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + { + "name":"g", + "type":[ + { + "type":"map", + "values":[ + { + "type":"record", + "name":"g", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"H", + "type":[ + { + "type":"record", + "name":"H", + "namespace":"topLevelRecord.nested_struct.g", + "fields":[ + { + "name":"i", + "type":[ + { + "type":"array", + "items":["double","null"] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ]} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema).unwrap(); + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "nested_struct"); + let ns_dt = fields[0].data_type(); + assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Record(nested_fields) = &ns_dt.codec { + assert_eq!(nested_fields.len(), 4); + let field_a_dt = nested_fields[0].data_type(); + assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(field_a_dt.codec, Codec::Int32)); + } else { + panic!("Expected nested_struct to be a Record"); + } + } + _ => panic!("Expected top-level record with a single union-based nested_struct"), + } + let mut resolver = Resolver::default(); + let dt = make_data_type(&schema, None, &mut resolver).unwrap(); + if let Codec::Record(fields) = &dt.codec { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "nested_struct"); + let ns_dt = fields[0].data_type(); + assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); + if let Codec::Record(nested_fields) = &ns_dt.codec { + assert_eq!(nested_fields.len(), 4); + let field_a_dt = nested_fields[0].data_type(); + assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(field_a_dt.codec, Codec::Int32)); + } else { + panic!("Expected nested_struct to be a Record (make_data_type)"); + } + } else { + panic!("Expected top-level record (make_data_type)"); + } + } } diff --git a/arrow-avro/src/reader/cursor.rs 
b/arrow-avro/src/reader/cursor.rs
index 9e38a78c63ec..65c93dab42fe 100644
--- a/arrow-avro/src/reader/cursor.rs
+++ b/arrow-avro/src/reader/cursor.rs
@@ -14,6 +14,7 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
+
 use crate::reader::vlq::read_varint;
 use arrow_schema::ArrowError;

diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index ca1ebd02c738..ec4d260a706c 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -81,8 +81,10 @@ mod test {
         ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder,
         ListBuilder, MapBuilder, StringBuilder, StructBuilder,
     };
-    use arrow_array::*;
-    use arrow_array::{Array, Float64Array, Int32Array, RecordBatch, StringArray, StructArray};
+    use arrow_array::{
+        Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array,
+        Int64Array, ListArray, RecordBatch, StringArray, StructArray, TimestampMicrosecondArray,
+    };
     use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
     use arrow_data::ArrayDataBuilder;
     use arrow_schema::{DataType, Field, Fields, Schema};
@@ -1035,4 +1037,70 @@ mod test {
         let batch_small = read_file(&file, 3);
         assert_eq!(batch_small, expected, "Mismatch for batch_size=3");
     }
+
+    #[test]
+    fn test_nullable_impala() {
+        let file = arrow_test_data("avro/nullable.impala.avro");
+        let batch1 = read_file(&file, 3);
+        let batch2 = read_file(&file, 8);
+        assert_eq!(batch1, batch2);
+        let batch = batch1;
+        assert_eq!(batch.num_rows(), 7);
+        let id_array = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<Int64Array>()
+            .expect("id column should be an Int64Array");
+        let expected_ids = [1, 2, 3, 4, 5, 6, 7];
+        for (i, &expected_id) in expected_ids.iter().enumerate() {
+            assert_eq!(
+                id_array.value(i),
+                expected_id,
+                "Mismatch in id at row {}",
+                i
+            );
+        }
+        let int_array = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .expect("int_array column should be a ListArray");
+        {
+            let offsets = int_array.value_offsets();
+            let start = offsets[0] as usize;
+            let end = offsets[1] as usize;
+            let values = int_array
+                .values()
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("Values of int_array should be an Int32Array");
+            let row0: Vec<Option<i32>> = (start..end).map(|i| Some(values.value(i))).collect();
+            assert_eq!(
+                row0,
+                vec![Some(1), Some(2), Some(3)],
+                "Mismatch in int_array row 0"
+            );
+        }
+        let nested_struct = batch
+            .column(5)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .expect("nested_struct column should be a StructArray");
+        let a_array = nested_struct
+            .column_by_name("A")
+            .expect("Field A should exist in nested_struct")
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .expect("Field A should be an Int32Array");
+        assert_eq!(a_array.value(0), 1, "Mismatch in nested_struct.A at row 0");
+        assert!(
+            !a_array.is_valid(1),
+            "Expected null in nested_struct.A at row 1"
+        );
+        assert!(
+            !a_array.is_valid(3),
+            "Expected null in nested_struct.A at row 3"
+        );
+        assert_eq!(a_array.value(6), 7, "Mismatch in nested_struct.A at row 6");
+    }
 }
diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs
index 85887fcc6cbf..cb9855493d18 100644
--- a/arrow-avro/src/reader/record.rs
+++ b/arrow-avro/src/reader/record.rs
@@ -21,6 +21,7 @@ use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilde
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::*;
+use arrow_data::ArrayData;
 use arrow_schema::{
     ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit,
     Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
@@ -28,7 +29,6 @@ use arrow_schema::{
 use std::io::Read;
 use std::sync::Arc;

-/// The default capacity used for internal buffers
 const DEFAULT_CAPACITY: usize = 1024;

 /// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`].
@@ -41,9 +41,9 @@ impl RecordDecoder {
     /// Create a new [`RecordDecoder`] from an [`AvroDataType`] expected to be a `Record`.
     pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
         match Decoder::try_new(data_type)? {
-            Decoder::Record(fields, encodings) => Ok(Self {
+            Decoder::Record(fields, decoders) => Ok(Self {
                 schema: Arc::new(ArrowSchema::new(fields)),
-                fields: encodings,
+                fields: decoders,
             }),
             other => Err(ArrowError::ParseError(format!(
                 "Expected record got {other:?}"
@@ -70,18 +70,34 @@ impl RecordDecoder {
         Ok(cursor.position())
     }

-    /// Flush the accumulated data into a [`RecordBatch`], clearing internal state.
+    /// Flush into a [`RecordBatch`].
+    ///
+    /// - Flush each `Decoder` => `Arc<dyn Array>`
+    /// - Sanitize offsets in each final array => `sanitize_array_offsets(...)`
     pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
         let arrays = self
             .fields
             .iter_mut()
-            .map(|x| x.flush(None))
+            .map(|d| d.flush(None))
+            .collect::<Result<Vec<_>, _>>()?;
+        let sanitized_cols = arrays
+            .into_iter()
+            .map(sanitize_array_offsets)
             .collect::<Result<Vec<_>, _>>()?;
-        RecordBatch::try_new(self.schema.clone(), arrays)
+        RecordBatch::try_new(self.schema.clone(), sanitized_cols)
     }
 }

-/// Decoder for Avro data of various shapes.
+/// For 2-branch unions we store either `[null, T]` or `[T, null]`.
+///
+/// - `NullFirst`: `[null, T]` => branch=0 => null, branch=1 => decode T
+/// - `NullSecond`: `[T, null]` => branch=0 => decode T, branch=1 => null
+#[derive(Debug, Copy, Clone)]
+enum UnionOrder {
+    NullFirst,
+    NullSecond,
+}
+
 #[derive(Debug)]
 enum Decoder {
     /// Primitive Types
     Null(usize),
     Boolean(BooleanBufferBuilder),
     Int32(Vec<i32>),
     Int64(Vec<i64>),
     Float32(Vec<f32>),
     Float64(Vec<f64>),
     Binary(OffsetBufferBuilder<i32>, Vec<u8>),
     String(OffsetBufferBuilder<i32>, Vec<u8>),
-    /// Complex
+    /// Complex Types
     Record(Fields, Vec<Decoder>),
     Enum(Arc<[String]>, Vec<i32>),
     List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
@@ -104,7 +120,7 @@ enum Decoder {
         Vec<u8>,
         Box<Decoder>,
     ),
-    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
+    Nullable(UnionOrder, NullBufferBuilder, Box<Decoder>),
     Fixed(i32, Vec<u8>),
     /// Logical Types
     Decimal(usize, Option<usize>, Option<usize>, DecimalBuilder),
@@ -117,12 +133,8 @@ enum Decoder {
 }

 impl Decoder {
-    fn is_nullable(&self) -> bool {
-        matches!(self, Self::Nullable(_, _, _))
-    }
-
     fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
-        let decoder = match &data_type.codec {
+        let base = match &data_type.codec {
             Codec::Null => Self::Null(0),
             Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
             Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
@@ -138,29 +150,29 @@ impl Decoder {
                 Vec::with_capacity(DEFAULT_CAPACITY),
             ),
             Codec::Record(avro_fields) => {
-                let mut arrow_fields = Vec::with_capacity(avro_fields.len());
-                let mut decoders = Vec::with_capacity(avro_fields.len());
-                for avro_field in avro_fields.iter() {
-                    let d = Self::try_new(avro_field.data_type())?;
-                    arrow_fields.push(avro_field.field());
-                    decoders.push(d);
+                let mut fields = Vec::with_capacity(avro_fields.len());
+                let mut children = Vec::with_capacity(avro_fields.len());
+                for f in avro_fields.iter() {
+                    let child = Self::try_new(f.data_type())?;
+                    fields.push(f.field());
+                    children.push(child);
                 }
-                Self::Record(arrow_fields.into(), decoders)
+                Self::Record(fields.into(), children)
             }
-
Codec::Enum(keys, values) => { - Self::Enum(Arc::clone(keys), Vec::with_capacity(values.len())) + Codec::Enum(syms, _) => { + Self::Enum(Arc::clone(syms), Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Array(item) => { - let item_decoder = Box::new(Self::try_new(item)?); - let item_field = item.field_with_name("item").with_nullable(true); + Codec::Array(child) => { + let child_dec = Self::try_new(child)?; + let item_field = child.field_with_name("item").with_nullable(true); Self::List( Arc::new(item_field), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - item_decoder, + Box::new(child_dec), ) } - Codec::Map(value_type) => { - let val_field = value_type.field_with_name("value").with_nullable(true); + Codec::Map(child) => { + let val_field = child.field_with_name("value").with_nullable(true); let map_field = Arc::new(ArrowField::new( "entries", DataType::Struct(Fields::from(vec![ @@ -169,39 +181,46 @@ impl Decoder { ])), false, )); + let valdec = Self::try_new(child)?; Self::Map( map_field, OffsetBufferBuilder::new(DEFAULT_CAPACITY), OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), - Box::new(Self::try_new(value_type)?), + Box::new(valdec), ) } - Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Decimal(precision, scale, size) => { - let builder = DecimalBuilder::new(*precision, *scale, *size)?; - Self::Decimal(*precision, *scale, *size, builder) + Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Decimal(p, s, size) => { + let b = DecimalBuilder::new(*p, *s, *size)?; + Self::Decimal(*p, *s, *size, b) } Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)), Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { - Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Codec::TimestampMillis(utc) => { + Self::TimestampMillis(*utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::TimestampMicros(is_utc) => { - Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Codec::TimestampMicros(utc) => { + Self::TimestampMicros(*utc, Vec::with_capacity(DEFAULT_CAPACITY)) } Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), }; - match data_type.nullability { - Some(nb) => Ok(Self::Nullable( - nb, + let union_order = match data_type.nullability { + None => None, + Some(Nullability::NullFirst) => Some(UnionOrder::NullFirst), + Some(Nullability::NullSecond) => Some(UnionOrder::NullSecond), + }; + + match union_order { + Some(order) => Ok(Self::Nullable( + order, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(base), )), - None => Ok(decoder), + None => Ok(base), } } @@ -218,121 +237,129 @@ impl Decoder { Self::Float64(v) => v.push(0.0), Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), Self::Record(_, children) => { - for c in children.iter_mut() { + for c in children { c.append_null(); } } - Self::Enum(_, indices) => indices.push(0), - Self::List(_, off, child) => { + Self::Enum(_, idxs) => idxs.push(0), + Self::List(_, off, _) => { off.push_length(0); - child.append_null(); } - Self::Map(_, key_off, map_off, _, child) => { - key_off.push_length(0); - map_off.push_length(0); + Self::Map(_, _koff, moff, _kdata, _valdec) => { + moff.push_length(0); + } + Self::Nullable(_, nb, child) => { + 
nb.append(false); child.append_null(); } - Self::Fixed(fsize, buf) => { - buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); + Self::Fixed(sz, accum) => { + accum.extend(std::iter::repeat(0u8).take(*sz as usize)); } - Self::Decimal(_, _, _, builder) => { - let _ = builder.append_null(); + Self::Decimal(_, _, _, db) => { + let _ = db.append_null(); } - Self::Interval(intervals) => { - intervals.push(IntervalMonthDayNano { + Self::Interval(ivals) => { + ivals.push(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 0, }); } - Self::Nullable(_, _, _) => {} } } - fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + fn decode(&mut self, buf: &mut AvroCursor) -> Result<(), ArrowError> { match self { - Self::Null(count) => *count += 1, - Self::Boolean(values) => values.append(buf.get_bool()?), - Self::Int32(values) => values.push(buf.get_int()?), - Self::Int64(values) => values.push(buf.get_long()?), - Self::Float32(values) => values.push(buf.get_float()?), - Self::Float64(values) => values.push(buf.get_double()?), + Self::Null(n) => { + *n += 1; + } + Self::Boolean(b) => { + b.append(buf.get_bool()?); + } + Self::Int32(v) => { + v.push(buf.get_int()?); + } + Self::Int64(v) => { + v.push(buf.get_long()?); + } + Self::Float32(vals) => { + vals.push(buf.get_float()?); + } + Self::Float64(vals) => { + vals.push(buf.get_double()?); + } Self::Binary(off, data) | Self::String(off, data) => { let bytes = buf.get_bytes()?; off.push_length(bytes.len()); data.extend_from_slice(bytes); } Self::Record(_, children) => { - for c in children.iter_mut() { + for c in children { c.decode(buf)?; } } - Self::Enum(_, indices) => indices.push(buf.get_int()?), + Self::Enum(_, idxs) => { + idxs.push(buf.get_int()?); + } Self::List(_, off, child) => { - let total_items = read_array_blocks(buf, |b| child.decode(b))?; + let total_items = read_array_blocks(buf, |cursor| child.decode(cursor))?; off.push_length(total_items); } - Self::Map(_, key_off, map_off, key_data, val_decoder) => { - let newly_added = read_map_blocks(buf, |b| { - let kb = b.get_bytes()?; - key_off.push_length(kb.len()); - key_data.extend_from_slice(kb); - val_decoder.decode(b) + Self::Map(_, koff, moff, kdata, valdec) => { + let newly_added = read_map_blocks(buf, |cur| { + let kb = cur.get_bytes()?; + koff.push_length(kb.len()); + kdata.extend_from_slice(kb); + valdec.decode(cur) })?; - - map_off.push_length(newly_added); + moff.push_length(newly_added); } - Self::Nullable(nb, nulls, child) => { + Self::Nullable(order, nb, child) => { let branch = buf.get_int()?; - match nb { - Nullability::NullFirst => { + match order { + UnionOrder::NullFirst => { if branch == 0 { - nulls.append(false); + nb.append(false); child.append_null(); - } else if branch == 1 { - nulls.append(true); - child.decode(buf)?; } else { - return Err(ArrowError::ParseError(format!( - "Unsupported union branch index {branch} for Nullable (NullFirst)" - ))); + nb.append(true); + child.decode(buf)?; } } - Nullability::NullSecond => { + UnionOrder::NullSecond => { if branch == 0 { - nulls.append(true); + nb.append(true); child.decode(buf)?; - } else if branch == 1 { - nulls.append(false); - child.append_null(); } else { - return Err(ArrowError::ParseError(format!( - "Unsupported union branch index {branch} for Nullable (NullSecond)" - ))); + nb.append(false); + child.append_null(); } } } } - Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), - Self::Decimal(_, _, size, builder) => { - let bytes = match *size { - 
Some(sz) => buf.get_fixed(sz)?,
+            Self::Fixed(sz, accum) => {
+                let fx = buf.get_fixed(*sz as usize)?;
+                accum.extend_from_slice(fx);
+            }
+            Self::Decimal(_, _, fsz, db) => {
+                let raw = match *fsz {
+                    Some(n) => buf.get_fixed(n)?,
                     None => buf.get_bytes()?,
                 };
-                builder.append_bytes(bytes)?;
-            }
-            Self::Date32(values) => values.push(buf.get_int()?),
-            Self::TimeMillis(values) => values.push(buf.get_int()?),
-            Self::TimeMicros(values) => values.push(buf.get_long()?),
-            Self::TimestampMillis(_, values) => values.push(buf.get_long()?),
-            Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
-            Self::Interval(intervals) => {
-                let raw = buf.get_fixed(12)?;
-                let months = i32::from_le_bytes(raw[0..4].try_into().unwrap());
-                let days = i32::from_le_bytes(raw[4..8].try_into().unwrap());
-                let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap());
-                let nanos = millis as i64 * 1_000_000;
-                intervals.push(IntervalMonthDayNano {
+                db.append_bytes(raw)?;
+            }
+            Self::Date32(vals) => vals.push(buf.get_int()?),
+            Self::TimeMillis(vals) => vals.push(buf.get_int()?),
+            Self::TimeMicros(vals) => vals.push(buf.get_long()?),
+            Self::TimestampMillis(_, vals) => vals.push(buf.get_long()?),
+            Self::TimestampMicros(_, vals) => vals.push(buf.get_long()?),
+            Self::Interval(ivals) => {
+                let x = buf.get_fixed(12)?;
+                let months = i32::from_le_bytes(x[0..4].try_into().unwrap());
+                let days = i32::from_le_bytes(x[4..8].try_into().unwrap());
+                let ms = i32::from_le_bytes(x[8..12].try_into().unwrap());
+                let nanos = ms as i64 * 1_000_000;
+                ivals.push(IntervalMonthDayNano {
                     months,
                     days,
                     nanoseconds: nanos,
@@ -342,136 +369,225 @@ impl Decoder {
         Ok(())
     }

-    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
+    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<Arc<dyn Array>, ArrowError> {
         match self {
-            Self::Null(len) => {
-                let count = std::mem::replace(len, 0);
-                Ok(Arc::new(NullArray::new(count)))
+            Self::Null(count) => {
+                let c = std::mem::replace(count, 0);
+                Ok(Arc::new(NullArray::new(c)) as Arc<dyn Array>)
             }
             Self::Boolean(b) => {
                 let bits = b.finish();
-                Ok(Arc::new(BooleanArray::new(bits, nulls)))
+                Ok(Arc::new(BooleanArray::new(bits, nulls)) as Arc<dyn Array>)
+            }
+            Self::Int32(v) => {
+                let arr = flush_primitive::<Int32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Date32(v) => {
+                let arr = flush_primitive::<Date32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Int64(v) => {
+                let arr = flush_primitive::<Int64Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Float32(v) => {
+                let arr = flush_primitive::<Float32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Float64(v) => {
+                let arr = flush_primitive::<Float64Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
             }
-            Self::Int32(vals) => Ok(Arc::new(flush_primitive::<Int32Type>(vals, nulls))),
-            Self::Date32(vals) => Ok(Arc::new(flush_primitive::<Date32Type>(vals, nulls))),
-            Self::Int64(vals) => Ok(Arc::new(flush_primitive::<Int64Type>(vals, nulls))),
-            Self::Float32(vals) => Ok(Arc::new(flush_primitive::<Float32Type>(vals, nulls))),
-            Self::Float64(vals) => Ok(Arc::new(flush_primitive::<Float64Type>(vals, nulls))),
             Self::Binary(off, data) => {
                 let offsets = flush_offsets(off);
-                let values = flush_values(data).into();
-                Ok(Arc::new(BinaryArray::new(offsets, values, nulls)))
+                let vals = flush_values(data).into();
+                let arr = BinaryArray::new(offsets, vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
             }
             Self::String(off, data) => {
                 let offsets = flush_offsets(off);
-                let values = flush_values(data).into();
-                Ok(Arc::new(StringArray::new(offsets, values, nulls)))
+                let vals = flush_values(data).into();
+                let arr = StringArray::new(offsets, vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
             }
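            // [Editor's note] Illustrative sketch, not part of the patch: every
            // flush arm moves its accumulated buffer out of `self` (see the
            // `flush_values` helper below) so the decoder is immediately
            // reusable for the next batch. For a primitive column the flow is
            // roughly:
            //
            //     let mut vals = vec![1i32, 2, 3];                // accumulated values
            //     let taken = std::mem::take(&mut vals);          // buffer moved out
            //     let arr = Int32Array::new(taken.into(), None);  // nulls attached here
            //     assert_eq!((arr.len(), vals.len()), (3, 0));    // decoder state reset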
Self::Record(fields, children) => {
-                let mut arrays = Vec::with_capacity(children.len());
-                for c in children.iter_mut() {
-                    let a = c.flush(nulls.clone())?;
-                    arrays.push(a);
+                let mut child_arrays = Vec::with_capacity(children.len());
+                for c in children {
+                    child_arrays.push(c.flush(None)?);
                 }
-                Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls)))
+                let (fixed, final_nulls) = flush_record_children(child_arrays, nulls)?;
+                let sarr = StructArray::new(fields.clone(), fixed, final_nulls);
+                Ok(Arc::new(sarr) as Arc<dyn Array>)
             }
-            Self::Enum(symbols, indices) => {
-                let dict_values = StringArray::from_iter_values(symbols.iter());
-                let idxs: Int32Array = match nulls {
-                    Some(b) => {
-                        let buff = Buffer::from_slice_ref(&indices);
+            Self::Enum(symbols, idxs) => {
+                let dict_vals = StringArray::from_iter_values(symbols.iter());
+                let i32arr = match nulls {
+                    Some(nb) => {
+                        let buff = Buffer::from_slice_ref(&idxs);
                         PrimitiveArray::<Int32Type>::try_new(
                             arrow_buffer::ScalarBuffer::from(buff),
-                            Some(b),
+                            Some(nb),
                         )?
                     }
-                    None => Int32Array::from_iter_values(indices.iter().cloned()),
+                    None => Int32Array::from_iter_values(idxs.iter().cloned()),
                 };
-                let dict = DictionaryArray::<Int32Type>::try_new(idxs, Arc::new(dict_values))?;
-                indices.clear();
-                Ok(Arc::new(dict))
+                idxs.clear();
+                let d = DictionaryArray::<Int32Type>::try_new(i32arr, Arc::new(dict_vals))?;
+                Ok(Arc::new(d) as Arc<dyn Array>)
             }
-            Self::List(field, off, item_dec) => {
-                let child_arr = item_dec.flush(None)?;
+            Self::List(item_field, off, child) => {
+                let c = child.flush(None)?;
                 let offsets = flush_offsets(off);
-                let arr = ListArray::new(field.clone(), offsets, child_arr, nulls);
-                Ok(Arc::new(arr))
-            }
-            Self::Map(field, key_off, map_off, key_data, val_dec) => {
-                let moff = flush_offsets(map_off);
-                let koff = flush_offsets(key_off);
-                let kd = flush_values(key_data).into();
-                let val_arr = val_dec.flush(None)?;
+                let larr = ListArray::new(item_field.clone(), offsets, c, nulls);
+                Ok(Arc::new(larr) as Arc<dyn Array>)
+            }
+            Self::Map(map_field, k_off, m_off, kdata, valdec) => {
+                let moff = flush_offsets(m_off);
+                let koff = flush_offsets(k_off);
+                let kd = flush_values(kdata).into();
+                let val_arr = valdec.flush(None)?;
                 let key_arr = StringArray::new(koff, kd, None);
-                let struct_fields = vec![
-                    Arc::new(ArrowField::new("key", DataType::Utf8, false)),
-                    Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
-                ];
-                let entries = StructArray::new(
-                    Fields::from(struct_fields),
-                    vec![Arc::new(key_arr), val_arr],
+                let (fixed_keys, fixed_vals) = flush_map_children(&key_arr, &val_arr)?;
+                let entries_struct = StructArray::new(
+                    Fields::from(vec![
+                        Arc::new(ArrowField::new("key", DataType::Utf8, false)),
+                        Arc::new(ArrowField::new(
+                            "value",
+                            fixed_vals.data_type().clone(),
+                            true,
+                        )),
+                    ]),
+                    vec![Arc::new(fixed_keys), fixed_vals],
                     None,
                 );
-                let map_arr = MapArray::new(field.clone(), moff, entries, nulls, false);
-                Ok(Arc::new(map_arr))
+                let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
+                Ok(Arc::new(map_arr) as Arc<dyn Array>)
             }
-            Self::Fixed(fsize, raw) => {
-                let size = *fsize;
-                let buf: Buffer = flush_values(raw).into();
-                let array = FixedSizeBinaryArray::try_new(size, buf, nulls)
+            Self::Nullable(_, nb_builder, child) => {
+                let mask = nb_builder.finish();
+                child.flush(mask)
+            }
+            Self::Fixed(sz, accum) => {
+                let b: Buffer = flush_values(accum).into();
+                let arr = FixedSizeBinaryArray::try_new(*sz, b, nulls)
                     .map_err(|e| ArrowError::ParseError(e.to_string()))?;
-                Ok(Arc::new(array))
-            }
-            Self::Decimal(prec, sc, sz,
builder) => {
-                let precision = *prec;
-                let scale = sc.unwrap_or(0);
-                let new_builder = DecimalBuilder::new(precision, *sc, *sz)?;
-                let old_builder = std::mem::replace(builder, new_builder);
-                let arr = old_builder.finish(nulls, precision, scale)?;
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Decimal(precision, scale, sz, builder) => {
+                let p = *precision;
+                let s = scale.unwrap_or(0);
+                let new_b = DecimalBuilder::new(p, *scale, *sz)?;
+                let old = std::mem::replace(builder, new_b);
+                let arr = old.finish(nulls, p, s)?;
                 Ok(arr)
             }
-            Self::TimeMillis(vals) => Ok(Arc::new(flush_primitive::<Time32MillisecondType>(
-                vals, nulls,
-            ))),
-            Self::TimeMicros(vals) => Ok(Arc::new(flush_primitive::<Time64MicrosecondType>(
-                vals, nulls,
-            ))),
+            Self::TimeMillis(vals) => {
+                let arr = flush_primitive::<Time32MillisecondType>(vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::TimeMicros(vals) => {
+                let arr = flush_primitive::<Time64MicrosecondType>(vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
             Self::TimestampMillis(is_utc, vals) => {
                 let arr = flush_primitive::<TimestampMillisecondType>(vals, nulls)
                     .with_timezone_opt::<Arc<str>>(is_utc.then(|| "+00:00".into()));
-                Ok(Arc::new(arr))
+                Ok(Arc::new(arr) as Arc<dyn Array>)
             }
             Self::TimestampMicros(is_utc, vals) => {
                 let arr = flush_primitive::<TimestampMicrosecondType>(vals, nulls)
                     .with_timezone_opt::<Arc<str>>(is_utc.then(|| "+00:00".into()));
-                Ok(Arc::new(arr))
-            }
-            Self::Interval(vals) => {
-                let data_len = vals.len();
-                let mut builder =
-                    PrimitiveBuilder::<IntervalMonthDayNanoType>::with_capacity(data_len);
-                for v in vals.drain(..) {
-                    builder.append_value(v);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Interval(ivals) => {
+                let len = ivals.len();
+                let mut b = PrimitiveBuilder::<IntervalMonthDayNanoType>::with_capacity(len);
+                for v in ivals.drain(..) {
+                    b.append_value(v);
                 }
-                let arr = builder
+                let arr = b
                     .finish()
                     .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano));
                 if let Some(nb) = nulls {
                     let arr_data = arr.into_data().into_builder().nulls(Some(nb));
-                    let arr_data = unsafe { arr_data.build_unchecked() };
-                    Ok(Arc::new(PrimitiveArray::<IntervalMonthDayNanoType>::from(
-                        arr_data,
-                    )))
+                    let arr_data = arr_data.build()?;
+                    Ok(
+                        Arc::new(PrimitiveArray::<IntervalMonthDayNanoType>::from(arr_data))
+                            as Arc<dyn Array>,
+                    )
                 } else {
-                    Ok(Arc::new(arr))
+                    Ok(Arc::new(arr) as Arc<dyn Array>)
                 }
             }
-            Self::Nullable(_, ref mut nb, ref mut child) => {
-                let mask = nb.finish();
-                child.flush(mask)
+        }
+    }
+}
+
+fn flush_record_children(
+    mut kids: Vec<Arc<dyn Array>>,
+    parent_nulls: Option<NullBuffer>,
+) -> Result<(Vec<Arc<dyn Array>>, Option<NullBuffer>), ArrowError> {
+    let max_len = kids.iter().map(|c| c.len()).max().unwrap_or(0);
+    let fixed_parent_nulls = match parent_nulls {
+        None => None,
+        Some(nb) => {
+            let old_len = nb.len();
+            if old_len == max_len {
+                Some(nb)
+            } else if old_len < max_len {
+                let mut b = NullBufferBuilder::new(max_len);
+                for i in 0..old_len {
+                    b.append(nb.is_valid(i));
+                }
+                for _ in 0..(max_len - old_len) {
+                    b.append(false);
+                }
+                b.finish()
+            } else {
+                // truncate
+                let mut b = NullBufferBuilder::new(max_len);
+                for i in 0..max_len {
+                    b.append(nb.is_valid(i));
+                }
+                b.finish()
             }
         }
+    };
+    let mut out = Vec::with_capacity(kids.len());
+    for arr in kids {
+        let cur_len = arr.len();
+        if cur_len == max_len {
+            out.push(arr);
+        } else if cur_len < max_len {
+            let to_add = max_len - cur_len;
+            let appended = append_nulls(&arr, to_add)?;
+            out.push(appended);
+        } else {
+            // slice
+            let sliced = arr.slice(0, max_len);
+            out.push(sliced);
+        }
     }
+    Ok((out, fixed_parent_nulls))
+}
+
+fn flush_map_children(
+    key_arr: &StringArray,
+    val_arr: &Arc<dyn Array>,
+) -> Result<(StringArray, Arc<dyn Array>), ArrowError> {
+    let kl = key_arr.len();
+    let vl = val_arr.len();
+    if kl == vl {
+        return Ok((key_arr.clone(), val_arr.clone()));
+    }
+    if kl < vl {
+        let truncated =
val_arr.slice(0, kl);
+        return Ok((key_arr.clone(), truncated));
+    }
+    let to_add = kl - vl;
+    let appended = append_nulls(val_arr, to_add)?;
+    Ok((key_arr.clone(), appended))
 }

 /// Decode an Avro array in blocks until a 0 block_count signals end.
@@ -479,29 +595,27 @@ fn read_array_blocks(
     buf: &mut AvroCursor,
     mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
 ) -> Result<usize, ArrowError> {
-    let mut total_items = 0usize;
+    let mut total = 0usize;
     loop {
-        let block_count = buf.get_long()?;
-        match block_count {
-            0 => break,
-            n if n < 0 => {
-                let item_count = (-n) as usize;
-                let _block_size = buf.get_long()?;
-                for _ in 0..item_count {
-                    decode_item(buf)?;
-                }
-                total_items += item_count;
+        let blk = buf.get_long()?;
+        if blk == 0 {
+            break;
+        } else if blk < 0 {
+            let cnt = (-blk) as usize;
+            let _sz = buf.get_long()?;
+            for _i in 0..cnt {
+                decode_item(buf)?;
             }
-            n => {
-                let item_count = n as usize;
-                for _ in 0..item_count {
-                    decode_item(buf)?;
-                }
-                total_items += item_count;
+            total += cnt;
+        } else {
+            let cnt = blk as usize;
+            for _i in 0..cnt {
+                decode_item(buf)?;
             }
+            total += cnt;
         }
     }
-    Ok(total_items)
+    Ok(total)
 }

 /// Decode an Avro map in blocks until 0 block_count signals end.
@@ -509,44 +623,158 @@ fn read_map_blocks(
     buf: &mut AvroCursor,
     mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
 ) -> Result<usize, ArrowError> {
-    let mut total_entries = 0usize;
+    let mut total = 0usize;
     loop {
-        let block_count = buf.get_long()?;
-        match block_count {
-            0 => break,
-            n if n < 0 => {
-                let item_count = (-n) as usize;
-                let _block_size = buf.get_long()?;
-                for _ in 0..item_count {
-                    decode_entry(buf)?;
-                }
-                total_entries += item_count;
+        let blk = buf.get_long()?;
+        if blk == 0 {
+            break;
+        } else if blk < 0 {
+            let cnt = (-blk) as usize;
+            let _sz = buf.get_long()?;
+            for _i in 0..cnt {
+                decode_entry(buf)?;
             }
-            n => {
-                let item_count = n as usize;
-                for _ in 0..item_count {
-                    decode_entry(buf)?;
-                }
-                total_entries += item_count;
+            total += cnt;
+        } else {
+            let cnt = blk as usize;
+            for _i in 0..cnt {
+                decode_entry(buf)?;
             }
+            total += cnt;
         }
     }
-    Ok(total_entries)
+    Ok(total)
 }

 fn flush_primitive<T: ArrowPrimitiveType>(
-    values: &mut Vec<T::Native>,
-    nulls: Option<NullBuffer>,
+    vals: &mut Vec<T::Native>,
+    nb: Option<NullBuffer>,
 ) -> PrimitiveArray<T> {
-    PrimitiveArray::new(flush_values(values).into(), nulls)
+    let arr = PrimitiveArray::new(std::mem::replace(vals, Vec::new()).into(), nb);
+    arr
+}
+
+fn flush_offsets(ob: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
+    std::mem::replace(ob, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
+}
+
+fn flush_values<T>(vec: &mut Vec<T>) -> Vec<T> {
+    std::mem::replace(vec, Vec::with_capacity(DEFAULT_CAPACITY))
+}
+
+fn append_nulls(arr: &Arc<dyn Array>, count: usize) -> Result<Arc<dyn Array>, ArrowError> {
+    use arrow_data::transform::MutableArrayData;
+
+    let d = arr.to_data();
+    let mut mad = MutableArrayData::new(vec![&d], false, 0);
+    mad.extend(0, 0, arr.len());
+    mad.extend_nulls(count);
+    let out = mad.freeze();
+    let arr2 = make_array(out);
+    sanitize_array_offsets(arr2)
+}
+
+fn sanitize_offsets_vec(offsets: &[i32], child_len: i32) -> Vec<i32> {
+    let mut new_offsets = Vec::with_capacity(offsets.len());
+    let mut prev = 0;
+    for &offset in offsets {
+        // Clamp each offset between the previous value and the child length.
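+        // [Editor's note] Worked example of the clamp below (illustrative, not
+        // part of the patch): offsets [0, 5, 3, 9] with child_len 7 become
+        // [0, 5, 5, 7] -- the out-of-order 3 is raised to the running maximum
+        // 5 and the trailing 9 is capped at the child length 7, so the result
+        // is monotonically non-decreasing and in bounds.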
+        let clamped = offset.clamp(prev, child_len);
+        new_offsets.push(clamped);
+        if clamped > prev {
+            prev = clamped;
+        }
+    }
+    new_offsets
+}
+
+fn sanitize_offsets_array(
+    original_data: &ArrayData,
+    child: Arc<dyn Array>,
+    offsets: &[i32],
+) -> Result<ArrayData, ArrowError> {
+    let child_san = sanitize_array_offsets(child)?;
+    let child_len = child_san.len() as i32;
+    let new_offsets = sanitize_offsets_vec(offsets, child_len);
+    let final_len = new_offsets.len() - 1;
+    let mut new_data = original_data.clone();
+    let mut bufs = new_data.buffers().to_vec();
+    bufs[0] = Buffer::from_slice_ref(&new_offsets);
+    new_data = new_data
+        .into_builder()
+        .len(final_len)
+        .buffers(bufs)
+        .child_data(vec![child_san.to_data()])
+        .build()?;
+    Ok(new_data)
 }

-fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
-    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
+fn sanitize_struct_child(
+    array: Arc<dyn Array>,
+    target_len: usize,
+) -> Result<ArrayData, ArrowError> {
+    let sanitized = sanitize_array_offsets(array)?;
+    let sanitized_len = sanitized.len();
+    if sanitized_len == target_len {
+        Ok(sanitized.to_data())
+    } else if sanitized_len < target_len {
+        let to_add = target_len - sanitized_len;
+        let appended = append_nulls(&sanitized, to_add)?;
+        Ok(appended.to_data())
+    } else {
+        let sliced = sanitized.slice(0, target_len);
+        Ok(sliced.to_data())
+    }
 }

-fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
-    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
+/// Recursively sanitizes the offsets for arrays of List, Map, and Struct types.
+fn sanitize_array_offsets(array: Arc<dyn Array>) -> Result<Arc<dyn Array>, ArrowError> {
+    match array.data_type() {
+        DataType::List(_item) => {
+            let list_arr = array
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .ok_or_else(|| ArrowError::ParseError("Downcast to ListArray".into()))?;
+            let child = Arc::new(list_arr.values().clone()) as Arc<dyn Array>;
+            let new_data =
+                sanitize_offsets_array(&list_arr.to_data(), child, list_arr.value_offsets())?;
+            Ok(make_array(new_data))
+        }
+        DataType::Map(_field, _keys_sorted) => {
+            let map_arr = array
+                .as_any()
+                .downcast_ref::<MapArray>()
+                .ok_or_else(|| ArrowError::ParseError("Downcast to MapArray".into()))?;
+            let child = Arc::new(map_arr.entries().clone()) as Arc<dyn Array>;
+            let new_data =
+                sanitize_offsets_array(&map_arr.to_data(), child, map_arr.value_offsets())?;
+            Ok(make_array(new_data))
+        }
+        DataType::Struct(_fs) => {
+            let struct_arr = array
+                .as_any()
+                .downcast_ref::<StructArray>()
+                .ok_or_else(|| ArrowError::ParseError("Downcast to StructArray".into()))?;
+            let length = struct_arr.len();
+
+            let new_child_data = struct_arr
+                .columns()
+                .iter()
+                .map(|col| {
+                    let col_arc = Arc::new(col.clone()) as Arc<dyn Array>;
+                    sanitize_struct_child(col_arc, length)
+                })
+                .collect::<Result<Vec<_>, _>>()?;
+            let new_data = struct_arr
+                .to_data()
+                .clone()
+                .into_builder()
+                .child_data(new_child_data)
+                .build()?;
+            Ok(make_array(new_data))
+        }
+        _ => Ok(array),
+    }
 }

 /// A builder for Avro decimal, either 128-bit or 256-bit.
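// [Editor's note] A minimal sketch of the sign extension this builder relies
// on (see `sign_extend` below; the values are assumptions for illustration).
// Avro serializes decimals as big-endian two's complement, so negative inputs
// are padded with 0xFF bytes and non-negative ones with 0x00:
//
//     let raw = [0xFBu8, 0x2E];                      // -1234 in two bytes
//     let mut padded = [0xFF; 16];                   // sign bit set => 0xFF fill
//     padded[16 - raw.len()..].copy_from_slice(&raw);
//     assert_eq!(i128::from_be_bytes(padded), -1234);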
@@ -562,50 +790,49 @@ impl DecimalBuilder {
         scale: Option<usize>,
         size: Option<usize>,
     ) -> Result<Self, ArrowError> {
-        match size {
-            Some(s) if s > 16 && s <= 32 => Ok(Self::Decimal256(
-                Decimal256Builder::new()
-                    .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?,
-            )),
-            Some(s) if s <= 16 => Ok(Self::Decimal128(
-                Decimal128Builder::new()
-                    .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?,
-            )),
-            None => {
-                if precision <= DECIMAL128_MAX_PRECISION as usize {
-                    Ok(Self::Decimal128(
-                        Decimal128Builder::new()
-                            .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?,
-                    ))
-                } else if precision <= DECIMAL256_MAX_PRECISION as usize {
-                    Ok(Self::Decimal256(
-                        Decimal256Builder::new()
-                            .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?,
-                    ))
-                } else {
-                    Err(ArrowError::ParseError(format!(
-                        "Decimal precision {} exceeds maximum supported",
-                        precision
-                    )))
-                }
+        let prec = precision as u8;
+        let scl = scale.unwrap_or(0) as i8;
+        if let Some(s) = size {
+            if s <= 16 {
+                return Ok(Self::Decimal128(
+                    Decimal128Builder::new().with_precision_and_scale(prec, scl)?,
+                ));
             }
-            _ => Err(ArrowError::ParseError(format!(
-                "Unsupported decimal size: {:?}",
-                size
-            ))),
+            if s <= 32 {
+                return Ok(Self::Decimal256(
+                    Decimal256Builder::new().with_precision_and_scale(prec, scl)?,
+                ));
+            }
+            return Err(ArrowError::ParseError(format!(
+                "Unsupported decimal size: {s:?}"
+            )));
+        }
+        if precision <= DECIMAL128_MAX_PRECISION as usize {
+            Ok(Self::Decimal128(
+                Decimal128Builder::new().with_precision_and_scale(prec, scl)?,
+            ))
+        } else if precision <= DECIMAL256_MAX_PRECISION as usize {
+            Ok(Self::Decimal256(
+                Decimal256Builder::new().with_precision_and_scale(prec, scl)?,
+            ))
+        } else {
+            Err(ArrowError::ParseError(format!(
+                "Decimal precision {} exceeds maximum supported",
+                precision
+            )))
         }
     }

     fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> {
         match self {
             Self::Decimal128(b) => {
-                let padded = sign_extend_to_16(raw)?;
-                let val = i128::from_be_bytes(padded);
+                let ext = sign_extend_to_16(raw)?;
+                let val = i128::from_be_bytes(ext);
                 b.append_value(val);
             }
             Self::Decimal256(b) => {
-                let padded = sign_extend_to_32(raw)?;
-                let val = i256::from_be_bytes(padded);
+                let ext = sign_extend_to_32(raw)?;
+                let val = i256::from_be_bytes(ext);
                 b.append_value(val);
             }
         }
@@ -628,22 +855,22 @@ impl DecimalBuilder {

     fn finish(
         self,
-        nulls: Option<NullBuffer>,
+        nb: Option<NullBuffer>,
         precision: usize,
         scale: usize,
-    ) -> Result<ArrayRef, ArrowError> {
+    ) -> Result<Arc<dyn Array>, ArrowError> {
         match self {
             Self::Decimal128(mut b) => {
                 let arr = b.finish();
                 let vals = arr.values().clone();
-                let dec = Decimal128Array::new(vals, nulls)
+                let dec = Decimal128Array::new(vals, nb)
                     .with_precision_and_scale(precision as u8, scale as i8)?;
                 Ok(Arc::new(dec))
             }
             Self::Decimal256(mut b) => {
                 let arr = b.finish();
                 let vals = arr.values().clone();
-                let dec = Decimal256Array::new(vals, nulls)
+                let dec = Decimal256Array::new(vals, nb)
                     .with_precision_and_scale(precision as u8, scale as i8)?;
                 Ok(Arc::new(dec))
             }
@@ -652,28 +879,28 @@ impl DecimalBuilder {
 }

 fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> {
-    let extended = sign_extend(raw, 16);
-    if extended.len() != 16 {
+    let ext = sign_extend(raw, 16);
+    if ext.len() != 16 {
         return Err(ArrowError::ParseError(format!(
             "Failed to extend to 16 bytes, got {} bytes",
-            extended.len()
+            ext.len()
         )));
     }
     let mut arr = [0u8; 16];
-    arr.copy_from_slice(&extended);
+    arr.copy_from_slice(&ext);
     Ok(arr)
 }

 fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32],
ArrowError> {
-    let extended = sign_extend(raw, 32);
-    if extended.len() != 32 {
+    let ext = sign_extend(raw, 32);
+    if ext.len() != 32 {
         return Err(ArrowError::ParseError(format!(
             "Failed to extend to 32 bytes, got {} bytes",
-            extended.len()
+            ext.len()
         )));
     }
     let mut arr = [0u8; 32];
-    arr.copy_from_slice(&extended);
+    arr.copy_from_slice(&ext);
     Ok(arr)
 }

@@ -695,10 +922,9 @@ fn sign_extend(raw: &[u8], target_len: usize) -> Vec<u8> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow_array::{
-        cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray,
-        IntervalMonthDayNanoArray, ListArray, MapArray,
-    };
+    use crate::codec::AvroField;
+    use crate::schema::Schema;
+    use arrow_array::{cast::AsArray, Array, ListArray, MapArray, StructArray};
     use std::sync::Arc;

     fn encode_avro_int(value: i32) -> Vec<u8> {
@@ -724,9 +950,491 @@ mod tests {
     }

     fn encode_avro_bytes(bytes: &[u8]) -> Vec<u8> {
-        let mut buf = encode_avro_long(bytes.len() as i64);
-        buf.extend_from_slice(bytes);
-        buf
+        let mut out = encode_avro_long(bytes.len() as i64);
+        out.extend_from_slice(bytes);
+        out
+    }
+
+    fn encode_union_branch(branch_idx: i32) -> Vec<u8> {
+        encode_avro_int(branch_idx)
+    }
+
+    fn encode_array<T>(items: &[T], mut encode_item: impl FnMut(&T) -> Vec<u8>) -> Vec<u8> {
+        let mut out = Vec::new();
+        if !items.is_empty() {
+            out.extend_from_slice(&encode_avro_long(items.len() as i64));
+            for it in items {
+                out.extend_from_slice(&encode_item(it));
+            }
+        }
+        out.extend_from_slice(&encode_avro_long(0));
+        out
+    }
+
+    fn encode_map(entries: &[(&str, Vec<u8>)]) -> Vec<u8> {
+        let mut out = Vec::new();
+        if !entries.is_empty() {
+            out.extend_from_slice(&encode_avro_long(entries.len() as i64));
+            for (k, val) in entries {
+                out.extend_from_slice(&encode_avro_bytes(k.as_bytes()));
+                out.extend_from_slice(val);
+            }
+        }
+        out.extend_from_slice(&encode_avro_long(0));
+        out
+    }
+
+    #[test]
+    fn test_union_primitive_long_null_record_decoder() {
+        let json_schema = r#"
+        {
+          "type": "record",
+          "name": "topLevelRecord",
+          "fields": [
+            {
+              "name": "id",
+              "type": ["long","null"]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+        data.extend_from_slice(&encode_union_branch(0));
+        data.extend_from_slice(&encode_avro_long(1));
+        data.extend_from_slice(&encode_union_branch(1));
+        let used = record_decoder.decode(&data, 2).unwrap();
+        assert_eq!(used, data.len());
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 2);
+        let arr = batch.column(0).as_primitive::<Int64Type>();
+        assert_eq!(arr.value(0), 1);
+        assert!(arr.is_null(1));
+    }
+
+    #[test]
+    fn test_union_array_of_int_null_record_decoder() {
+        let json_schema = r#"
+        {
+          "type":"record",
+          "name":"topLevelRecord",
+          "fields":[
+            {
+              "name":"int_array",
+              "type":[
+                {
+                  "type":"array",
+                  "items":[ "int", "null" ]
+                },
+                "null"
+              ]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+
+        fn encode_int_or_null(opt_val: &Option<i32>) -> Vec<u8> {
+            match opt_val {
+                Some(v) => {
+                    let mut out = encode_union_branch(0);
+                    out.extend_from_slice(&encode_avro_int(*v));
+                    out
+                }
+                None => encode_union_branch(1),
+            }
+        }
+
+        data.extend_from_slice(&encode_union_branch(0));
+        let
row1_values = vec![Some(1), Some(2), Some(3)];
+        data.extend_from_slice(&encode_array(&row1_values, encode_int_or_null));
+        data.extend_from_slice(&encode_union_branch(0));
+        let row2_values = vec![None, Some(1), Some(2), None, Some(3), None];
+        data.extend_from_slice(&encode_array(&row2_values, encode_int_or_null));
+        data.extend_from_slice(&encode_union_branch(0));
+        data.extend_from_slice(&encode_avro_long(0)); // block_count=0 => end immediately
+        data.extend_from_slice(&encode_union_branch(1));
+        record_decoder.decode(&data, 4).unwrap();
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 4);
+        let list_arr = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        assert!(list_arr.is_null(3));
+        {
+            let start = list_arr.value_offsets()[0] as usize;
+            let end = list_arr.value_offsets()[1] as usize;
+            let child = list_arr.values().as_primitive::<Int32Type>();
+            assert_eq!(end - start, 3);
+            assert_eq!(child.value(start), 1);
+            assert_eq!(child.value(start + 1), 2);
+            assert_eq!(child.value(start + 2), 3);
+        }
+        {
+            let start = list_arr.value_offsets()[1] as usize;
+            let end = list_arr.value_offsets()[2] as usize;
+            let child = list_arr.values().as_primitive::<Int32Type>();
+            assert_eq!(end - start, 6);
+            // index-by-index
+            assert!(child.is_null(start)); // None
+            assert_eq!(child.value(start + 1), 1); // Some(1)
+            assert_eq!(child.value(start + 2), 2);
+            assert!(child.is_null(start + 3));
+            assert_eq!(child.value(start + 4), 3);
+            assert!(child.is_null(start + 5));
+        }
+        {
+            let start = list_arr.value_offsets()[2] as usize;
+            let end = list_arr.value_offsets()[3] as usize;
+            assert_eq!(end - start, 0);
+        }
+    }
+
+    #[test]
+    fn test_union_nested_array_of_int_null_record_decoder() {
+        let json_schema = r#"
+        {
+          "type":"record",
+          "name":"topLevelRecord",
+          "fields":[
+            {
+              "name":"int_array_Array",
+              "type":[
+                {
+                  "type":"array",
+                  "items":[
+                    {
+                      "type":"array",
+                      "items":[
+                        "int",
+                        "null"
+                      ]
+                    },
+                    "null"
+                  ]
+                },
+                "null"
+              ]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+
+        fn encode_inner(vals: &[Option<i32>]) -> Vec<u8> {
+            encode_array(vals, |o| match o {
+                Some(v) => {
+                    let mut out = encode_union_branch(0);
+                    out.extend_from_slice(&encode_avro_int(*v));
+                    out
+                }
+                None => encode_union_branch(1),
+            })
+        }
+
+        data.extend_from_slice(&encode_union_branch(0));
+        {
+            let outer_vals: Vec<Option<Vec<Option<i32>>>> =
+                vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), None])];
+            data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr {
+                Some(vlist) => {
+                    let mut out = encode_union_branch(0);
+                    out.extend_from_slice(&encode_inner(vlist));
+                    out
+                }
+                None => encode_union_branch(1),
+            }));
+        }
+        data.extend_from_slice(&encode_union_branch(0));
+        {
+            let outer_vals: Vec<Option<Vec<Option<i32>>>> = vec![None];
+            data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr {
+                Some(vlist) => {
+                    let mut out = encode_union_branch(0);
+                    out.extend_from_slice(&encode_inner(vlist));
+                    out
+                }
+                None => encode_union_branch(1),
+            }));
+        }
+        data.extend_from_slice(&encode_union_branch(1));
+        record_decoder.decode(&data, 3).unwrap();
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 3);
+        let outer_list = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        assert!(outer_list.is_null(2));
+        assert!(!outer_list.is_null(0));
+        let start =
outer_list.value_offsets()[0] as usize;
+        let end = outer_list.value_offsets()[1] as usize;
+        assert_eq!(end - start, 2);
+        let start2 = outer_list.value_offsets()[1] as usize;
+        let end2 = outer_list.value_offsets()[2] as usize;
+        assert_eq!(end2 - start2, 1);
+        let subitem_arr = outer_list.value(1);
+        let sub_list = subitem_arr.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(sub_list.len(), 1);
+        assert!(sub_list.is_null(0));
+    }
+
+    #[test]
+    fn test_union_map_of_int_null_record_decoder() {
+        let json_schema = r#"
+        {
+          "type":"record",
+          "name":"topLevelRecord",
+          "fields":[
+            {
+              "name":"int_map",
+              "type":[
+                {
+                  "type":"map",
+                  "values":[
+                    "int",
+                    "null"
+                  ]
+                },
+                "null"
+              ]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+        data.extend_from_slice(&encode_union_branch(0));
+        let row1_map = vec![
+            ("k1", {
+                let mut out = encode_union_branch(0);
+                out.extend_from_slice(&encode_avro_int(1));
+                out
+            }),
+            ("k2", { encode_union_branch(1) }),
+        ];
+        data.extend_from_slice(&encode_map(&row1_map));
+        data.extend_from_slice(&encode_union_branch(0));
+        let empty: [(&str, Vec<u8>); 0] = [];
+        data.extend_from_slice(&encode_map(&empty));
+        data.extend_from_slice(&encode_union_branch(1));
+        record_decoder.decode(&data, 3).unwrap();
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 3);
+        let map_arr = batch.column(0).as_any().downcast_ref::<MapArray>().unwrap();
+        assert_eq!(map_arr.len(), 3);
+        assert!(map_arr.is_null(2));
+        assert_eq!(map_arr.value_length(0), 2);
+        let binding = map_arr.value(0);
+        let struct_arr = binding.as_any().downcast_ref::<StructArray>().unwrap();
+        let keys = struct_arr.column(0).as_string::<i32>();
+        let vals = struct_arr.column(1).as_primitive::<Int32Type>();
+        assert_eq!(keys.value(0), "k1");
+        assert_eq!(vals.value(0), 1);
+        assert_eq!(keys.value(1), "k2");
+        assert!(vals.is_null(1));
+        assert_eq!(map_arr.value_length(1), 0);
+    }
+
+    #[test]
+    fn test_union_map_array_of_int_null_record_decoder() {
+        let json_schema = r#"
+        {
+          "type": "record",
+          "name": "topLevelRecord",
+          "fields": [
+            {
+              "name": "int_Map_Array",
+              "type": [
+                {
+                  "type": "array",
+                  "items": [
+                    {
+                      "type": "map",
+                      "values": [
+                        "int",
+                        "null"
+                      ]
+                    },
+                    "null"
+                  ]
+                },
+                "null"
+              ]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+        fn encode_map_int_null(entries: &[(&str, Option<i32>)]) -> Vec<u8> {
+            let items: Vec<(&str, Vec<u8>)> = entries
+                .iter()
+                .map(|(k, v)| {
+                    let val = match v {
+                        Some(x) => {
+                            let mut out = encode_union_branch(0);
+                            out.extend_from_slice(&encode_avro_int(*x));
+                            out
+                        }
+                        None => encode_union_branch(1),
+                    };
+                    (*k, val)
+                })
+                .collect();
+            encode_map(&items)
+        }
+        data.extend_from_slice(&encode_union_branch(0));
+        {
+            let mut arr_buf = encode_avro_long(1);
+            {
+                let mut item_buf = encode_union_branch(0);
+                item_buf.extend_from_slice(&encode_map_int_null(&[("k1", Some(1))]));
+                arr_buf.extend_from_slice(&item_buf);
+            }
+            arr_buf.extend_from_slice(&encode_avro_long(0));
+            data.extend_from_slice(&arr_buf);
+        }
+        data.extend_from_slice(&encode_union_branch(0));
+        {
+            let mut arr_buf = encode_avro_long(2); // 2 items
+            arr_buf.extend_from_slice(&encode_union_branch(1));
+            {
+
let mut item1 = encode_union_branch(0);
+                item1.extend_from_slice(&encode_map_int_null(&[("k2", None)]));
+                arr_buf.extend_from_slice(&item1);
+            }
+            arr_buf.extend_from_slice(&encode_avro_long(0)); // end
+            data.extend_from_slice(&arr_buf);
+        }
+        data.extend_from_slice(&encode_union_branch(1));
+        record_decoder.decode(&data, 3).unwrap();
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 3);
+        let outer_list = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<ListArray>()
+            .unwrap();
+        assert!(outer_list.is_null(2));
+        {
+            let start = outer_list.value_offsets()[0] as usize;
+            let end = outer_list.value_offsets()[1] as usize;
+            assert_eq!(end - start, 1);
+            let subarr = outer_list.value(0);
+            let sublist = subarr.as_any().downcast_ref::<ListArray>().unwrap();
+            assert_eq!(sublist.len(), 1);
+            assert!(!sublist.is_null(0));
+            let sub_value_0 = sublist.value(0);
+            let struct_arr = sub_value_0.as_any().downcast_ref::<StructArray>().unwrap();
+            let keys = struct_arr.column(0).as_string::<i32>();
+            let vals = struct_arr.column(1).as_primitive::<Int32Type>();
+            assert_eq!(keys.value(0), "k1");
+            assert_eq!(vals.value(0), 1);
+        }
+    }
+
+    #[test]
+    fn test_union_nested_struct_out_of_spec_record_decoder() {
+        let json_schema = r#"
+        {
+          "type":"record",
+          "name":"topLevelRecord",
+          "fields":[
+            {
+              "name":"nested_struct",
+              "type":[
+                {
+                  "type":"record",
+                  "name":"nested_struct",
+                  "namespace":"topLevelRecord",
+                  "fields":[
+                    {
+                      "name":"A",
+                      "type":[
+                        "int",
+                        "null"
+                      ]
+                    },
+                    {
+                      "name":"b",
+                      "type":[
+                        {
+                          "type":"array",
+                          "items":[
+                            "int",
+                            "null"
+                          ]
+                        },
+                        "null"
+                      ]
+                    }
+                  ]
+                },
+                "null"
+              ]
+            }
+          ]
+        }
+        "#;
+        let schema: Schema = serde_json::from_str(json_schema).unwrap();
+        let avro_record = AvroField::try_from(&schema).unwrap();
+        let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap();
+        let mut data = Vec::new();
+        data.extend_from_slice(&encode_union_branch(0));
+        data.extend_from_slice(&encode_union_branch(0));
+        data.extend_from_slice(&encode_avro_int(7));
+        data.extend_from_slice(&encode_union_branch(0));
+        let row1_b = [Some(1), Some(2)];
+        data.extend_from_slice(&encode_array(&row1_b, |val| match val {
+            Some(x) => {
+                let mut out = encode_union_branch(0);
+                out.extend_from_slice(&encode_avro_int(*x));
+                out
+            }
+            None => encode_union_branch(1),
+        }));
+        data.extend_from_slice(&encode_union_branch(0));
+        data.extend_from_slice(&encode_union_branch(1));
+        data.extend_from_slice(&encode_union_branch(1));
+        data.extend_from_slice(&encode_union_branch(1));
+        record_decoder.decode(&data, 3).unwrap();
+        let batch = record_decoder.flush().unwrap();
+        assert_eq!(batch.num_rows(), 3);
+        let col = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StructArray>()
+            .unwrap();
+        assert!(col.is_null(2));
+        let field_a = col.column(0).as_primitive::<Int32Type>();
+        let field_b = col.column(1).as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(field_a.value(0), 7);
+        {
+            let start = field_b.value_offsets()[0] as usize;
+            let end = field_b.value_offsets()[1] as usize;
+            let values = field_b.values().as_primitive::<Int32Type>();
+            assert_eq!(end - start, 2);
+            assert_eq!(values.value(start), 1);
+            assert_eq!(values.value(start + 1), 2);
+        }
+        assert!(field_a.is_null(1));
+        assert!(field_b.is_null(1));
     }

     #[test]
@@ -754,10 +1462,8 @@ mod tests {

     #[test]
     fn test_fixed_decoding() {
-        // `fixed(4)` => Arrow FixedSizeBinary(4)
         let dt = AvroDataType::from_codec(Codec::Fixed(4));
         let mut dec = Decoder::try_new(&dt).unwrap();
-        // 2 rows, each row => 4 bytes
         let row1 = [0xDE, 0xAD, 0xBE, 0xEF];
         let row2 = [0x01, 0x23, 0x45, 0x67];
         let
mut data = Vec::new();
@@ -779,7 +1485,7 @@
         let dt = AvroDataType::from_codec(Codec::Fixed(2));
         let child = Decoder::try_new(&dt).unwrap();
         let mut dec = Decoder::Nullable(
-            Nullability::NullSecond,
+            UnionOrder::NullSecond,
             NullBufferBuilder::new(DEFAULT_CAPACITY),
             Box::new(child),
         );
@@ -844,11 +1550,10 @@

     #[test]
     fn test_interval_decoding_with_nulls() {
-        // Avro union => [ interval, null ]
         let dt = AvroDataType::from_codec(Codec::Duration);
         let child = Decoder::try_new(&dt).unwrap();
         let mut dec = Decoder::Nullable(
-            Nullability::NullSecond,
+            UnionOrder::NullSecond,
             NullBufferBuilder::new(DEFAULT_CAPACITY),
             Box::new(child),
         );
@@ -858,9 +1563,9 @@
             0xF4, 0x01, 0x00, 0x00, // ms=500
         ];
         let mut data = Vec::new();
-        data.extend_from_slice(&encode_avro_int(0)); // branch=0: non-null
+        data.extend_from_slice(&encode_avro_int(0));
         data.extend_from_slice(&row1);
-        data.extend_from_slice(&encode_avro_int(1)); // branch=1: null
+        data.extend_from_slice(&encode_avro_int(1));
         let mut cursor = AvroCursor::new(&data);
         dec.decode(&mut cursor).unwrap(); // Row1
         dec.decode(&mut cursor).unwrap(); // Row2 (null)
@@ -914,7 +1619,7 @@
         let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([])));
         let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap();
         let mut nullable_decoder = Decoder::Nullable(
-            Nullability::NullSecond,
+            UnionOrder::NullSecond,
             NullBufferBuilder::new(DEFAULT_CAPACITY),
             Box::new(inner_decoder),
         );
@@ -1008,43 +1713,39 @@

     #[test]
     fn test_decimal_decoding_bytes_with_nulls() {
-        // Avro union => [ Decimal(4,1), null ]
         let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None));
         let mut inner = Decoder::try_new(&dt).unwrap();
         let mut decoder = Decoder::Nullable(
-            Nullability::NullSecond,
+            UnionOrder::NullSecond,
             NullBufferBuilder::new(DEFAULT_CAPACITY),
             Box::new(inner),
         );
-        'data_clear: {
-            let mut data = Vec::new();
-            data.extend_from_slice(&encode_avro_int(0)); // branch=0 => non-null
-            data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); // child's value: 1234 => "123.4"
-            data.extend_from_slice(&encode_avro_int(1)); // branch=1 => null
-            data.extend_from_slice(&encode_avro_int(0)); // branch=0 => non-null
-            data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); // child's value: -1234 => "-123.4"
-            let mut cursor = AvroCursor::new(&data);
-            decoder.decode(&mut cursor).unwrap();
-            decoder.decode(&mut cursor).unwrap();
-            decoder.decode(&mut cursor).unwrap();
-            let arr = decoder.flush(None).unwrap();
-            let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
-            assert_eq!(dec_arr.len(), 3);
-            assert!(dec_arr.is_valid(0));
-            assert!(!dec_arr.is_valid(1));
-            assert!(dec_arr.is_valid(2));
-            assert_eq!(dec_arr.value_as_string(0), "123.4");
-            assert_eq!(dec_arr.value_as_string(2), "-123.4");
-        }
+        let mut data = Vec::new();
+        data.extend_from_slice(&encode_avro_int(0));
+        data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2]));
+        data.extend_from_slice(&encode_avro_int(1));
+        data.extend_from_slice(&encode_avro_int(0));
+        data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E]));
+        let mut cursor = AvroCursor::new(&data);
+        decoder.decode(&mut cursor).unwrap(); // row1
+        decoder.decode(&mut cursor).unwrap(); // row2
+        decoder.decode(&mut cursor).unwrap(); // row3
+        let arr = decoder.flush(None).unwrap();
+        let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
+        assert_eq!(dec_arr.len(), 3);
+        assert!(dec_arr.is_valid(0));
+        assert!(!dec_arr.is_valid(1));
+        assert!(dec_arr.is_valid(2));
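+        // [Editor's note] The expected strings below follow from the bytes
+        // encoded above: [0x04, 0xD2] is 1234 and [0xFB, 0x2E] is -1234 in
+        // big-endian two's complement; rendered at scale 1 they become
+        // "123.4" and "-123.4".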
+ assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); } #[test] fn test_decimal_decoding_bytes_with_nulls_fixed_size() { - // Avro union => [Decimal(6,2,16), null] let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); let mut inner = Decoder::try_new(&dt).unwrap(); let mut decoder = Decoder::Nullable( - Nullability::NullSecond, + UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(inner), ); @@ -1086,8 +1787,7 @@ mod tests { row1.extend_from_slice(&encode_avro_int(10)); row1.extend_from_slice(&encode_avro_int(20)); row1.extend_from_slice(&encode_avro_long(0)); - let mut row2 = Vec::new(); - row2.extend_from_slice(&encode_avro_long(0)); + let mut row2 = encode_avro_long(0); let mut cursor = AvroCursor::new(&row1); decoder.decode(&mut cursor).unwrap(); let mut cursor2 = AvroCursor::new(&row2); diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 174a28fba62d..b05743887792 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -16,7 +16,7 @@ // under the License. use crate::codec::Nullability; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; /// The metadata key used for storing the JSON encoded [`Schema`] @@ -137,6 +137,8 @@ pub struct Record<'a> { } /// A field within a [`Record`] +/// +/// **Modified** to preserve any `"default": null` even in out-of-spec union ordering. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RecordField<'a> { #[serde(borrow)] @@ -147,10 +149,26 @@ pub struct RecordField<'a> { pub aliases: Vec<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, - #[serde(default, skip_serializing_if = "Option::is_none")] + #[serde( + default, + skip_serializing_if = "Option::is_none", + deserialize_with = "allow_out_of_spec_default" + )] pub default: Option, } +/// Custom parse logic that stores *any* default as raw JSON +/// (including "null" for non-null-first unions). 
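+///
+/// For example (the shape exercised by the union tests below), this keeps
+/// `"default": null` on an out-of-spec `["int","null"]` union field:
+/// `{"name": "i", "type": ["int","null"], "default": null}`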
+fn allow_out_of_spec_default<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + match serde_json::Value::deserialize(deserializer) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } +} + /// An enumeration /// /// @@ -613,4 +631,156 @@ mod tests { panic!("Expected record schema"); } } + + #[test] + fn test_union_int_null_with_default_null() { + let json_schema = r#" + { + "type": "record", + "name": "ImpalaNullableRecord", + "fields": [ + {"name": "i", "type": ["int","null"], "default": null} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = schema { + assert_eq!(rec.fields.len(), 1); + assert_eq!(rec.fields[0].name, "i"); + assert_eq!(rec.fields[0].default, Some(json!(null))); + let field_codec = + AvroField::try_from(&Schema::Complex(ComplexType::Record(rec))).unwrap(); + use arrow_schema::{DataType, Field, Fields}; + assert_eq!( + field_codec.field(), + Field::new( + "ImpalaNullableRecord", + DataType::Struct(Fields::from(vec![Field::new("i", DataType::Int32, true),])), + false + ) + ); + } else { + panic!("Expected record schema with union int|null, default null"); + } + } + + #[test] + fn test_union_impala_null_with_default_null() { + let json_schema = r#" + { + "type":"record","name":"topLevelRecord","fields":[ + {"name":"id","type":["long","null"]}, + {"name":"int_array","type":[{"type":"array","items":["int","null"]},"null"]}, + {"name":"int_array_Array","type":[{"type":"array","items":[{"type":"array","items":["int","null"]},"null"]},"null"]}, + {"name":"int_map","type":[{"type":"map","values":["int","null"]},"null"]}, + {"name":"int_Map_Array","type":[{"type":"array","items":[{"type":"map","values":["int","null"]},"null"]},"null"]}, + { + "name":"nested_struct", + "type":[ + { + "type":"record", + "name":"nested_struct", + "namespace":"topLevelRecord", + "fields":[ + {"name":"A","type":["int","null"]}, + {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]}, + { + "name":"C", + "type":[ + { + "type":"record", + "name":"C", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"d", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":[ + { + "type":"record", + "name":"d", + "namespace":"topLevelRecord.nested_struct.C", + "fields":[ + {"name":"E","type":["int","null"]}, + {"name":"F","type":["string","null"]} + ] + }, + "null" + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + { + "name":"g", + "type":[ + { + "type":"map", + "values":[ + { + "type":"record", + "name":"g", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"H", + "type":[ + { + "type":"record", + "name":"H", + "namespace":"topLevelRecord.nested_struct.g", + "fields":[ + { + "name":"i", + "type":[ + { + "type":"array", + "items":["double","null"] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = &schema { + assert_eq!(rec.name, "topLevelRecord"); + assert_eq!(rec.fields.len(), 6); + let _field_codec = AvroField::try_from(&schema).unwrap(); + } else { + panic!("Expected top-level record schema"); + } + } } From 332b3fb302702953436109b12b3b180c8686bc6a Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Thu, 6 Feb 2025 19:11:22 -0600 Subject: [PATCH 28/38] linter Signed-off-by: Connor Sanders --- 
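Notes: the lint pass rewrites three-way `if` / `else if` comparisons as a `match`
on `std::cmp::Ordering` (clippy's `comparison_chain` lint). A minimal sketch of
the pattern, using an illustrative function that is not part of this crate:

    use std::cmp::Ordering;

    fn reconcile(len: usize, target: usize) -> &'static str {
        // One arm per outcome replaces the chained comparisons.
        match len.cmp(&target) {
            Ordering::Equal => "keep",
            Ordering::Less => "pad with nulls",
            Ordering::Greater => "truncate",
        }
    }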
arrow-avro/src/reader/record.rs | 171 +++++++++++++++++--------------- 1 file changed, 93 insertions(+), 78 deletions(-) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index cb9855493d18..db7b1ad2a7db 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -26,6 +26,7 @@ use arrow_schema::{ ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; +use std::cmp::Ordering; use std::io::Read; use std::sync::Arc; @@ -424,7 +425,7 @@ impl Decoder { let dict_vals = StringArray::from_iter_values(symbols.iter()); let i32arr = match nulls { Some(nb) => { - let buff = Buffer::from_slice_ref(&idxs); + let buff = Buffer::from_slice_ref(&*idxs); PrimitiveArray::::try_new( arrow_buffer::ScalarBuffer::from(buff), Some(nb), @@ -524,51 +525,58 @@ impl Decoder { } } +type FlushResult = (Vec>, Option); + fn flush_record_children( mut kids: Vec>, parent_nulls: Option, -) -> Result<(Vec>, Option), ArrowError> { +) -> Result { let max_len = kids.iter().map(|c| c.len()).max().unwrap_or(0); + let fixed_parent_nulls = match parent_nulls { None => None, Some(nb) => { let old_len = nb.len(); - if old_len == max_len { - Some(nb) - } else if old_len < max_len { - let mut b = NullBufferBuilder::new(max_len); - for i in 0..old_len { - b.append(nb.is_valid(i)); - } - for _ in 0..(max_len - old_len) { - b.append(false); + match old_len.cmp(&max_len) { + Ordering::Equal => Some(nb), + Ordering::Less => { + let mut b = NullBufferBuilder::new(max_len); + for i in 0..old_len { + b.append(nb.is_valid(i)); + } + for _ in old_len..max_len { + b.append(false); + } + b.finish() } - b.finish() - } else { - // truncate - let mut b = NullBufferBuilder::new(max_len); - for i in 0..max_len { - b.append(nb.is_valid(i)); + Ordering::Greater => { + let mut b = NullBufferBuilder::new(max_len); + for i in 0..max_len { + b.append(nb.is_valid(i)); + } + b.finish() } - b.finish() } } }; + let mut out = Vec::with_capacity(kids.len()); for arr in kids { let cur_len = arr.len(); - if cur_len == max_len { - out.push(arr); - } else if cur_len < max_len { - let to_add = max_len - cur_len; - let appended = append_nulls(&arr, to_add)?; - out.push(appended); - } else { - // slice - let sliced = arr.slice(0, max_len); - out.push(sliced); + match cur_len.cmp(&max_len) { + Ordering::Equal => out.push(arr), + Ordering::Less => { + let to_add = max_len - cur_len; + let appended = append_nulls(&arr, to_add)?; + out.push(appended); + } + Ordering::Greater => { + let sliced = arr.slice(0, max_len); + out.push(sliced); + } } } + Ok((out, fixed_parent_nulls)) } @@ -578,16 +586,18 @@ fn flush_map_children( ) -> Result<(StringArray, Arc), ArrowError> { let kl = key_arr.len(); let vl = val_arr.len(); - if kl == vl { - return Ok((key_arr.clone(), val_arr.clone())); - } - if kl < vl { - let truncated = val_arr.slice(0, kl); - return Ok((key_arr.clone(), truncated)); + match kl.cmp(&vl) { + Ordering::Equal => Ok((key_arr.clone(), val_arr.clone())), + Ordering::Less => { + let truncated = val_arr.slice(0, kl); + Ok((key_arr.clone(), truncated)) + } + Ordering::Greater => { + let to_add = kl - vl; + let appended = append_nulls(val_arr, to_add)?; + Ok((key_arr.clone(), appended)) + } } - let to_add = kl - vl; - let appended = append_nulls(val_arr, to_add)?; - Ok((key_arr.clone(), appended)) } /// Decode an Avro array in blocks until a 0 block_count signals end. 
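 ///
 /// Sketch of the wire rule this loop implements (per the Avro spec): a
 /// negative block count encodes `abs(count)` items and is followed by a
 /// long giving the block's byte size, which readers may use to skip data.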
@@ -598,21 +608,23 @@ fn read_array_blocks(
 let mut total = 0usize;
 loop {
 let blk = buf.get_long()?;
- if blk == 0 {
- break;
- } else if blk < 0 {
- let cnt = (-blk) as usize;
- let _sz = buf.get_long()?;
- for _i in 0..cnt {
- decode_item(buf)?;
- }
- total += cnt;
- } else {
- let cnt = blk as usize;
- for _i in 0..cnt {
- decode_item(buf)?;
+ match blk.cmp(&0) {
+ Ordering::Equal => break,
+ Ordering::Less => {
+ let cnt = (-blk) as usize;
+ let _sz = buf.get_long()?;
+ for _i in 0..cnt {
+ decode_item(buf)?;
+ }
+ total += cnt;
+ }
+ Ordering::Greater => {
+ let cnt = blk as usize;
+ for _i in 0..cnt {
+ decode_item(buf)?;
+ }
+ total += cnt;
 }
- total += cnt;
 }
 }
 Ok(total)
@@ -626,21 +638,23 @@ fn read_map_blocks(
 let mut total = 0usize;
 loop {
 let blk = buf.get_long()?;
- if blk == 0 {
- break;
- } else if blk < 0 {
- let cnt = (-blk) as usize;
- let _sz = buf.get_long()?;
- for _i in 0..cnt {
- decode_entry(buf)?;
- }
- total += cnt;
- } else {
- let cnt = blk as usize;
- for _i in 0..cnt {
- decode_entry(buf)?;
+ match blk.cmp(&0) {
+ Ordering::Equal => break,
+ Ordering::Less => {
+ let cnt = (-blk) as usize;
+ let _sz = buf.get_long()?;
+ for _i in 0..cnt {
+ decode_entry(buf)?;
+ }
+ total += cnt;
+ }
+ Ordering::Greater => {
+ let cnt = blk as usize;
+ for _i in 0..cnt {
+ decode_entry(buf)?;
+ }
+ total += cnt;
 }
- total += cnt;
 }
 }
 Ok(total)
@@ -650,8 +664,7 @@ fn flush_primitive(
 vals: &mut Vec,
 nb: Option,
) -> PrimitiveArray {
- let arr = PrimitiveArray::new(std::mem::replace(vals, Vec::new()).into(), nb);
- arr
+ PrimitiveArray::new(std::mem::take(vals).into(), nb)
 }

 fn flush_offsets(ob: &mut OffsetBufferBuilder) -> OffsetBuffer {
@@ -715,15 +728,17 @@ fn sanitize_struct_child(
) -> Result {
 let sanitized = sanitize_array_offsets(array)?;
 let sanitized_len = sanitized.len();
- if sanitized_len == target_len {
- Ok(sanitized.to_data())
- } else if sanitized_len < target_len {
- let to_add = target_len - sanitized_len;
- let appended = append_nulls(&sanitized, to_add)?;
- Ok(appended.to_data())
- } else {
- let sliced = sanitized.slice(0, target_len);
- Ok(sliced.to_data())
+ match sanitized_len.cmp(&target_len) {
+ Ordering::Equal => Ok(sanitized.to_data()),
+ Ordering::Less => {
+ let to_add = target_len - sanitized_len;
+ let appended = append_nulls(&sanitized, to_add)?;
+ Ok(appended.to_data())
+ }
+ Ordering::Greater => {
+ let sliced = sanitized.slice(0, target_len);
+ Ok(sliced.to_data())
+ }
 }
 }

@@ -1224,7 +1239,7 @@ mod tests {
 out.extend_from_slice(&encode_avro_int(1));
 out
 }),
- ("k2", { encode_union_branch(1) }),
+ ("k2", encode_union_branch(1)),
 ];
 data.extend_from_slice(&encode_map(&row1_map));
 data.extend_from_slice(&encode_union_branch(0));
@@ -1787,7 +1802,7 @@ mod tests {
 row1.extend_from_slice(&encode_avro_int(10));
 row1.extend_from_slice(&encode_avro_int(20));
 row1.extend_from_slice(&encode_avro_long(0));
- let mut row2 = encode_avro_long(0);
+ let row2 = encode_avro_long(0);
 let mut cursor = AvroCursor::new(&row1);
 decoder.decode(&mut cursor).unwrap();
 let mut cursor2 = AvroCursor::new(&row2);

From fadbf5ebd3f7eb07af38c4c11c1dc2f606349bf7 Mon Sep 17 00:00:00 2001
From: Connor Sanders
Date: Fri, 7 Feb 2025 01:08:31 -0600
Subject: [PATCH 29/38] Added a `strict_mode` parameter to the Avro `read_file`
 method to enforce Avro schema strictness, e.g. for the Impala use case.
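When `strict_mode` is `true`, decoder construction fails on out-of-spec unions
of the form `[T, "null"]` instead of silently accepting them. A usage sketch
against this patch's helpers (the file path and bindings are illustrative):

    // Permissive: Impala-style ["int","null"] unions are accepted.
    let batch = read_file(&file, 8, false);
    // Strict: building the decoder returns a ParseError instead.
    let err = RecordDecoder::try_new(root.data_type(), true).unwrap_err();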
Signed-off-by: Connor Sanders --- arrow-avro/src/reader/mod.rs | 49 ++++++++++--------- arrow-avro/src/reader/record.rs | 85 +++++++++++++++++++-------------- 2 files changed, 75 insertions(+), 59 deletions(-) diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index ec4d260a706c..dad56535f8b6 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - //! Read Avro data to Arrow use crate::reader::block::{Block, BlockDecoder}; @@ -93,14 +92,17 @@ mod test { use std::io::BufReader; use std::sync::Arc; - fn read_file(file: &str, batch_size: usize) -> RecordBatch { + /// Helper to read an Avro file into a `RecordBatch`. + /// + /// - `strict_mode`: if `true`, we reject unions of the form `[T,"null"]`. + fn read_file(file: &str, batch_size: usize, strict_mode: bool) -> RecordBatch { let file = File::open(file).unwrap(); let mut reader = BufReader::new(file); let header = read_header(&mut reader).unwrap(); let compression = header.compression().unwrap(); let schema = header.schema().unwrap().unwrap(); let root = AvroField::try_from(&schema).unwrap(); - let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); + let mut decoder = RecordDecoder::try_new(root.data_type(), strict_mode).unwrap(); for result in read_blocks(reader) { let block = result.unwrap(); assert_eq!(block.sync, header.sync()); @@ -216,8 +218,9 @@ mod test { .unwrap(); for file in files { let file = arrow_test_data(file); - assert_eq!(read_file(&file, 8), expected); - assert_eq!(read_file(&file, 3), expected); + // Pass `false` for strict_mode so we don't fail on out-of-spec unions + assert_eq!(read_file(&file, 8, false), expected); + assert_eq!(read_file(&file, 3, false), expected); } } @@ -281,13 +284,13 @@ mod test { ]) .unwrap(); let file_path = arrow_test_data(file); - let batch_large = read_file(&file_path, 8); + let batch_large = read_file(&file_path, 8, false); assert_eq!( batch_large, expected, "Decoded RecordBatch does not match for file {}", file ); - let batch_small = read_file(&file_path, 3); + let batch_small = read_file(&file_path, 3, false); assert_eq!( batch_small, expected, "Decoded RecordBatch (batch size 3) does not match for file {}", @@ -333,13 +336,13 @@ mod test { ]) .unwrap(); let file_path = arrow_test_data(file); - let batch_large = read_file(&file_path, 8); + let batch_large = read_file(&file_path, 8, false); assert_eq!( batch_large, expected, "Decoded RecordBatch does not match for file {}", file ); - let batch_small = read_file(&file_path, 3); + let batch_small = read_file(&file_path, 3, false); assert_eq!( batch_small, expected, "Decoded RecordBatch (batch size 3) does not match for file {}", @@ -350,7 +353,7 @@ mod test { #[test] fn test_binary() { let file = arrow_test_data("avro/binary.avro"); - let batch = read_file(&file, 8); + let batch = read_file(&file, 8, false); let expected = RecordBatch::try_from_iter_with_nullable([( "foo", Arc::new(BinaryArray::from_iter_values(vec![ @@ -384,7 +387,7 @@ mod test { let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); for (file, precision, scale) in files { let file_path = arrow_test_data(file); - let actual_batch = read_file(&file_path, 8); + let actual_batch = read_file(&file_path, 8, false); let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) .with_precision_and_scale(precision, scale) .unwrap(); @@ 
-402,7 +405,7 @@ mod test { "Decoded RecordBatch does not match the expected Decimal128 data for file {}", file ); - let actual_batch_small = read_file(&file_path, 3); + let actual_batch_small = read_file(&file_path, 3, false); assert_eq!( actual_batch_small, expected_batch, "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", @@ -414,7 +417,7 @@ mod test { #[test] fn test_datapage_v2() { let file = arrow_test_data("avro/datapage_v2.snappy.avro"); - let batch = read_file(&file, 8); + let batch = read_file(&file, 8, false); let a = StringArray::from(vec![ Some("abc"), Some("abc"), @@ -459,7 +462,7 @@ mod test { #[test] fn test_dict_pages_offset_zero() { let file = arrow_test_data("avro/dict-page-offset-zero.avro"); - let batch = read_file(&file, 32); + let batch = read_file(&file, 32, false); let num_rows = batch.num_rows(); let expected_field = Int32Array::from(vec![Some(1552); num_rows]); let expected = RecordBatch::try_from_iter_with_nullable([( @@ -529,7 +532,7 @@ mod test { ("utf8_list", Arc::new(utf8_list) as Arc, true), ]) .unwrap(); - let batch = read_file(&file, 8); + let batch = read_file(&file, 8, false); assert_eq!(batch, expected); } @@ -596,9 +599,9 @@ mod test { ("b", Arc::new(b_expected) as Arc, true), ]) .unwrap(); - let left = read_file(&file, 8); + let left = read_file(&file, 8, false); assert_eq!(left, expected, "Mismatch for batch size=8"); - let left_small = read_file(&file, 3); + let left_small = read_file(&file, 3, false); assert_eq!(left_small, expected, "Mismatch for batch size=3"); } @@ -746,12 +749,12 @@ mod test { ]) .unwrap(); let file = arrow_test_data("avro/nested_records.avro"); - let batch_large = read_file(&file, 8); + let batch_large = read_file(&file, 8, false); assert_eq!( batch_large, expected, "Decoded RecordBatch does not match expected data for nested records (batch size 8)" ); - let batch_small = read_file(&file, 3); + let batch_small = read_file(&file, 3, false); assert_eq!( batch_small, expected, "Decoded RecordBatch does not match expected data for nested records (batch size 3)" @@ -1032,17 +1035,17 @@ mod test { ("nested_Struct", Arc::new(nested_struct), true), ]) .unwrap(); - let batch_large = read_file(&file, 8); + let batch_large = read_file(&file, 8, false); assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); - let batch_small = read_file(&file, 3); + let batch_small = read_file(&file, 3, false); assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); } #[test] fn test_nullable_impala() { let file = arrow_test_data("avro/nullable.impala.avro"); - let batch1 = read_file(&file, 3); - let batch2 = read_file(&file, 8); + let batch1 = read_file(&file, 3, false); + let batch2 = read_file(&file, 8, false); assert_eq!(batch1, batch2); let batch = batch1; assert_eq!(batch.num_rows(), 7); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index db7b1ad2a7db..f497571f2d3e 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -39,9 +39,12 @@ pub struct RecordDecoder { } impl RecordDecoder { - /// Create a new [`RecordDecoder`] from an [`AvroDataType`] expected to be a `Record`. - pub fn try_new(data_type: &AvroDataType) -> Result { - match Decoder::try_new(data_type)? { + /// Create a new [`RecordDecoder`] from an [`AvroDataType`] that must be a `Record`. + /// + /// - `strict_mode`: if `true`, we will throw an error if we encounter + /// a union of the form `[T, "null"]` (i.e. `Nullability::NullSecond`). 
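+ ///
+ /// Construction sketch (mirroring the `read_file` test helper in this
+ /// patch; `root` is an `AvroField` parsed from the file header):
+ /// `let decoder = RecordDecoder::try_new(root.data_type(), strict_mode)?;`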
+ pub fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result {
+ match Decoder::try_new(data_type, strict_mode)? {
 Decoder::Record(fields, decoders) => Ok(Self {
 schema: Arc::new(ArrowSchema::new(fields)),
 fields: decoders,
@@ -134,7 +137,8 @@ enum Decoder {
 }

 impl Decoder {
- fn try_new(data_type: &AvroDataType) -> Result {
+ fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result {
+ // 1) Create the "base" decoder for the underlying Avro codec
 let base = match &data_type.codec {
 Codec::Null => Self::Null(0),
 Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
@@ -154,7 +158,8 @@ impl Decoder {
 let mut fields = Vec::with_capacity(avro_fields.len());
 let mut children = Vec::with_capacity(avro_fields.len());
 for f in avro_fields.iter() {
- let child = Self::try_new(f.data_type())?;
+ // Recursively build a Decoder for each child
+ let child = Self::try_new(f.data_type(), strict_mode)?;
 fields.push(f.field());
 children.push(child);
 }
@@ -164,7 +169,7 @@ impl Decoder {
 Self::Enum(Arc::clone(syms), Vec::with_capacity(DEFAULT_CAPACITY))
 }
 Codec::Array(child) => {
- let child_dec = Self::try_new(child)?;
+ let child_dec = Self::try_new(child, strict_mode)?;
 let item_field = child.field_with_name("item").with_nullable(true);
 Self::List(
 Arc::new(item_field),
@@ -182,7 +187,7 @@ impl Decoder {
 ])),
 false,
 ));
- let valdec = Self::try_new(child)?;
+ let valdec = Self::try_new(child, strict_mode)?;
 Self::Map(
 map_field,
 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
@@ -208,21 +213,28 @@ impl Decoder {
 }
 Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)),
 };
-
 let union_order = match data_type.nullability {
 None => None,
 Some(Nullability::NullFirst) => Some(UnionOrder::NullFirst),
- Some(Nullability::NullSecond) => Some(UnionOrder::NullSecond),
+ Some(Nullability::NullSecond) => {
+ if strict_mode {
+ return Err(ArrowError::ParseError(
+ "Found Avro union of the form ['T','null'], which is disallowed in strict_mode"
+ .to_string(),
+ ));
+ }
+ Some(UnionOrder::NullSecond)
+ }
 };
-
- match union_order {
- Some(order) => Ok(Self::Nullable(
+ let decoder = match union_order {
+ Some(order) => Decoder::Nullable(
 order,
 NullBufferBuilder::new(DEFAULT_CAPACITY),
- Box::new(base),
- )),
- None => Ok(base),
- }
+ Box::new(base),
+ ),
+ None => base,
+ };
+ Ok(decoder)
 }

 fn append_null(&mut self) {
@@ -328,6 +340,7 @@ impl Decoder {
 }
 }
 UnionOrder::NullSecond => {
+ // In out-of-spec files: branch=0 => decode T, branch=1 => null
 if branch == 0 {
 nb.append(true);
 child.decode(buf)?;
@@ -691,7 +704,7 @@ fn sanitize_offsets_vec(offsets: &[i32], child_len: i32) -> Vec {
 let mut new_offsets = Vec::with_capacity(offsets.len());
 let mut prev = 0;
 for &offset in offsets {
- // Clamp each offset between the previous value and the child length.
+ // clamp each offset between the previous value and the child length let clamped = offset.clamp(prev, child_len); new_offsets.push(clamped); if clamped > prev { @@ -1015,7 +1028,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); data.extend_from_slice(&encode_union_branch(0)); data.extend_from_slice(&encode_avro_long(1)); @@ -1051,7 +1064,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); fn encode_int_or_null(opt_val: &Option) -> Vec { @@ -1143,7 +1156,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); fn encode_inner(vals: &[Option]) -> Vec { @@ -1230,7 +1243,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); data.extend_from_slice(&encode_union_branch(0)); let row1_map = vec![ @@ -1295,7 +1308,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); fn encode_map_int_null(entries: &[(&str, Option)]) -> Vec { let items: Vec<(&str, Vec)> = entries @@ -1409,7 +1422,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); let mut data = Vec::new(); data.extend_from_slice(&encode_union_branch(0)); data.extend_from_slice(&encode_union_branch(0)); @@ -1467,7 +1480,7 @@ mod tests { "#; let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_record = AvroField::try_from(&schema).unwrap(); - let record_decoder = RecordDecoder::try_new(avro_record.data_type()).unwrap(); + let record_decoder = RecordDecoder::try_new(avro_record.data_type(), true).unwrap(); let arrow_schema = record_decoder.schema(); assert_eq!(arrow_schema.fields().len(), 1); let field = arrow_schema.field(0); @@ -1478,7 +1491,7 @@ mod tests { #[test] fn test_fixed_decoding() { let dt = AvroDataType::from_codec(Codec::Fixed(4)); - let mut dec = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::try_new(&dt, true).unwrap(); let row1 = [0xDE, 0xAD, 0xBE, 0xEF]; let row2 = [0x01, 0x23, 0x45, 0x67]; let mut 
data = Vec::new(); @@ -1498,7 +1511,7 @@ mod tests { #[test] fn test_fixed_with_nulls() { let dt = AvroDataType::from_codec(Codec::Fixed(2)); - let child = Decoder::try_new(&dt).unwrap(); + let child = Decoder::try_new(&dt, true).unwrap(); let mut dec = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -1530,7 +1543,7 @@ mod tests { #[test] fn test_interval_decoding() { let dt = AvroDataType::from_codec(Codec::Duration); - let mut dec = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::try_new(&dt, true).unwrap(); let row1 = [ 0x01, 0x00, 0x00, 0x00, // months=1 0x02, 0x00, 0x00, 0x00, // days=2 @@ -1566,7 +1579,7 @@ mod tests { #[test] fn test_interval_decoding_with_nulls() { let dt = AvroDataType::from_codec(Codec::Duration); - let child = Decoder::try_new(&dt).unwrap(); + let child = Decoder::try_new(&dt, true).unwrap(); let mut dec = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -1602,7 +1615,7 @@ mod tests { fn test_enum_decoding() { let symbols = Arc::new(["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]); let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols, Arc::new([]))); - let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + let mut decoder = Decoder::try_new(&enum_dt, true).unwrap(); let mut data = Vec::new(); data.extend_from_slice(&encode_avro_int(1)); data.extend_from_slice(&encode_avro_int(0)); @@ -1632,7 +1645,7 @@ mod tests { // Union => [Enum(...), null] let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([]))); - let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); + let mut inner_decoder = Decoder::try_new(&enum_dt, true).unwrap(); let mut nullable_decoder = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -1668,7 +1681,7 @@ mod tests { fn test_map_decoding_one_entry() { let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type).unwrap(); + let mut decoder = Decoder::try_new(&map_type, true).unwrap(); let mut data = Vec::new(); data.extend_from_slice(&encode_avro_long(1)); data.extend_from_slice(&encode_avro_bytes(b"hello")); @@ -1692,7 +1705,7 @@ mod tests { fn test_map_decoding_empty() { let value_type = AvroDataType::from_codec(Codec::String); let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type).unwrap(); + let mut decoder = Decoder::try_new(&map_type, true).unwrap(); let data = encode_avro_long(0); decoder.decode(&mut AvroCursor::new(&data)).unwrap(); let array = decoder.flush(None).unwrap(); @@ -1704,7 +1717,7 @@ mod tests { #[test] fn test_decimal_decoding_fixed128() { let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); - let mut decoder = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::try_new(&dt, true).unwrap(); let row1 = [ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x39, @@ -1729,7 +1742,7 @@ mod tests { #[test] fn test_decimal_decoding_bytes_with_nulls() { let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); - let mut inner = Decoder::try_new(&dt).unwrap(); + let mut inner = Decoder::try_new(&dt, true).unwrap(); let mut decoder = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ 
-1758,7 +1771,7 @@ mod tests {
 #[test]
 fn test_decimal_decoding_bytes_with_nulls_fixed_size() {
 let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16)));
- let mut inner = Decoder::try_new(&dt).unwrap();
+ let mut inner = Decoder::try_new(&dt, true).unwrap();
 let mut decoder = Decoder::Nullable(
 UnionOrder::NullSecond,
 NullBufferBuilder::new(DEFAULT_CAPACITY),
 Box::new(inner),
 );
@@ -1796,7 +1809,7 @@ mod tests {
 fn test_list_decoding() {
 let item_dt = AvroDataType::from_codec(Codec::Int32);
 let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt)));
- let mut decoder = Decoder::try_new(&list_dt).unwrap();
+ let mut decoder = Decoder::try_new(&list_dt, true).unwrap();
 let mut row1 = Vec::new();
 row1.extend_from_slice(&encode_avro_long(2));
 row1.extend_from_slice(&encode_avro_int(10));
@@ -1823,7 +1836,7 @@ mod tests {
 fn test_list_decoding_with_negative_block_count() {
 let item_dt = AvroDataType::from_codec(Codec::Int32);
 let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt)));
- let mut decoder = Decoder::try_new(&list_dt).unwrap();
+ let mut decoder = Decoder::try_new(&list_dt, true).unwrap();
 let mut data = encode_avro_long(-3);
 data.extend_from_slice(&encode_avro_long(12));
 data.extend_from_slice(&encode_avro_int(1));

From dbe8555d98e5e673a5883a56116857fe24286b93 Mon Sep 17 00:00:00 2001
From: Connor Sanders
Date: Fri, 7 Feb 2025 20:04:51 -0600
Subject: [PATCH 30/38] Added `test_nulls_snappy` and `test_repeated_no_annotation` to `arrow-avro/src/reader/mod.rs`

Signed-off-by: Connor Sanders
---
 arrow-avro/src/reader/mod.rs | 94 ++++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index dad56535f8b6..28e87ceb3cca 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -1106,4 +1106,98 @@ mod test {
 );
 assert_eq!(a_array.value(6), 7, "Mismatch in nested_struct.A at row 6");
 }
+
+ #[test]
+ fn test_nulls_snappy() {
+ let file = arrow_test_data("avro/nulls.snappy.avro");
+ let batch_large = read_file(&file, 8, false);
+ use arrow_array::{Int32Array, StructArray};
+ use arrow_buffer::Buffer;
+ use arrow_data::ArrayDataBuilder;
+ use arrow_schema::{DataType, Field, Fields};
+ let b_c_int = Int32Array::from(vec![None; 8]);
+ let b_c_int_data = b_c_int.into_data();
+ let b_struct_field = Field::new("b_c_int", DataType::Int32, true);
+ let b_struct_type = DataType::Struct(Fields::from(vec![b_struct_field]));
+ let struct_validity = Buffer::from_iter((0..8).map(|_| true));
+ let b_struct_data = ArrayDataBuilder::new(b_struct_type)
+ .len(8)
+ .null_bit_buffer(Some(struct_validity))
+ .child_data(vec![b_c_int_data])
+ .build()
+ .unwrap();
+ let b_struct_array = StructArray::from(b_struct_data);
+ let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([(
+ "b_struct",
+ Arc::new(b_struct_array) as _,
+ true,
+ )])
+ .unwrap();
+ assert_eq!(batch_large, expected, "Mismatch for batch_size=8");
+ let batch_small = read_file(&file, 3, false);
+ assert_eq!(batch_small, expected, "Mismatch for batch_size=3");
+ }
+
+ #[test]
+ fn test_repeated_no_annotation() {
+ let file = arrow_test_data("avro/repeated_no_annotation.avro");
+ let batch_large = read_file(&file, 8, false);
+ use arrow_array::{Int32Array, Int64Array, ListArray, StringArray, StructArray};
+ use arrow_buffer::Buffer;
+ use arrow_data::ArrayDataBuilder;
+ use arrow_schema::{DataType, Field, Fields};
+ let id_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]);
+ let number_array =
Int64Array::from(vec![ + Some(5555555555), + Some(1111111111), + Some(1111111111), + Some(2222222222), + Some(3333333333), + ]); + let kind_array = + StringArray::from(vec![None, Some("home"), Some("home"), None, Some("mobile")]); + let phone_fields = Fields::from(vec![ + Field::new("number", DataType::Int64, true), + Field::new("kind", DataType::Utf8, true), + ]); + let phone_struct_data = ArrayDataBuilder::new(DataType::Struct(phone_fields)) + .len(5) // 5 phone entries total + .child_data(vec![number_array.into_data(), kind_array.into_data()]) + .build() + .unwrap(); + let phone_struct_array = StructArray::from(phone_struct_data); + let phone_list_offsets = Buffer::from_slice_ref([0, 0, 0, 0, 1, 2, 5]); + let phone_list_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_item_field = Field::new("item", phone_struct_array.data_type().clone(), true); + let phone_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(phone_item_field))) + .len(6) + .add_buffer(phone_list_offsets) + .null_bit_buffer(Some(phone_list_validity)) + .child_data(vec![phone_struct_array.into_data()]) + .build() + .unwrap(); + let phone_list_array = ListArray::from(phone_list_data); + let phone_numbers_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_numbers_field = Field::new("phone", phone_list_array.data_type().clone(), true); + let phone_numbers_struct_data = + ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![phone_numbers_field]))) + .len(6) + .null_bit_buffer(Some(phone_numbers_validity)) + .child_data(vec![phone_list_array.into_data()]) + .build() + .unwrap(); + let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); + let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(id_array) as _, true), + ( + "phoneNumbers", + Arc::new(phone_numbers_struct_array) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } } From ec6d0ca3a47a621fe372fd3f73c17c1475f3c0e2 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 8 Feb 2025 00:52:02 -0600 Subject: [PATCH 31/38] Added remaining Avro `read_file` tests to `arrow-avro/src/reader/mod.rs` Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 109 +++++------------------------------ arrow-avro/src/reader/mod.rs | 106 +++++++++++++++++++++++++++++++--- 2 files changed, 112 insertions(+), 103 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index d173362c3a91..a2475087b8a3 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -142,6 +142,7 @@ pub enum Codec { String, /// Complex Record(Arc<[AvroField]>), + /// Changed from `Dictionary(Utf8, Int32)` to `Dictionary(Int32, Utf8)` Enum(Arc<[String]>, Arc<[i32]>), Array(Arc), Map(Arc), @@ -174,7 +175,7 @@ impl Codec { let arrow_fields: Vec = fields.iter().map(|f| f.field()).collect(); Struct(arrow_fields.into()) } - Self::Enum(_, _) => Dictionary(Box::new(Utf8), Box::new(Int32)), + Self::Enum(_, _) => Dictionary(Box::new(Int32), Box::new(Utf8)), Self::Array(child_type) => { let child_dt = child_type.codec.data_type(); let child_md = child_type.metadata.clone(); @@ -507,12 +508,13 @@ fn arrow_type_to_codec(dt: &DataType) -> Codec { .collect(); Codec::Record(Arc::from(avro_fields)) } - Dictionary(dict_ty, _val_ty) => { - if let Utf8 = &**dict_ty { - Codec::Enum(Arc::from(Vec::new()), 
Arc::from(Vec::new())) - } else { - Codec::String + Dictionary(dict_ty, val_ty) => { + if let Int32 = &**dict_ty { + if let Utf8 = &**val_ty { + return Codec::Enum(Arc::from(Vec::new()), Arc::from(Vec::new())); + } } + Codec::String } List(item_field) => { let item_codec = arrow_type_to_codec(item_field.data_type()); @@ -786,7 +788,7 @@ mod tests { let arrow_field = Field::new( "DictionaryEnum", - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)), + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -794,7 +796,7 @@ mod tests { let arrow_field = Field::new( "DictionaryString", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Boolean)), + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Boolean)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -1185,92 +1187,7 @@ mod tests { "namespace":"topLevelRecord", "fields":[ {"name":"A","type":["int","null"]}, - { - "name":"b", - "type":[{"type":"array","items":["int","null"]},"null"] - }, - { - "name":"C", - "type":[ - { - "type":"record", - "name":"C", - "namespace":"topLevelRecord.nested_struct", - "fields":[ - { - "name":"d", - "type":[ - { - "type":"array", - "items":[ - { - "type":"array", - "items":[ - { - "type":"record", - "name":"d", - "namespace":"topLevelRecord.nested_struct.C", - "fields":[ - {"name":"E","type":["int","null"]}, - {"name":"F","type":["string","null"]} - ] - }, - "null" - ] - }, - "null" - ] - }, - "null" - ] - } - ] - }, - "null" - ] - }, - { - "name":"g", - "type":[ - { - "type":"map", - "values":[ - { - "type":"record", - "name":"g", - "namespace":"topLevelRecord.nested_struct", - "fields":[ - { - "name":"H", - "type":[ - { - "type":"record", - "name":"H", - "namespace":"topLevelRecord.nested_struct.g", - "fields":[ - { - "name":"i", - "type":[ - { - "type":"array", - "items":["double","null"] - }, - "null" - ] - } - ] - }, - "null" - ] - } - ] - }, - "null" - ] - }, - "null" - ] - } + {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]} ] }, "null" @@ -1287,7 +1204,7 @@ mod tests { let ns_dt = fields[0].data_type(); assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); if let Codec::Record(nested_fields) = &ns_dt.codec { - assert_eq!(nested_fields.len(), 4); + assert_eq!(nested_fields.len(), 2); let field_a_dt = nested_fields[0].data_type(); assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); assert!(matches!(field_a_dt.codec, Codec::Int32)); @@ -1305,7 +1222,7 @@ mod tests { let ns_dt = fields[0].data_type(); assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); if let Codec::Record(nested_fields) = &ns_dt.codec { - assert_eq!(nested_fields.len(), 4); + assert_eq!(nested_fields.len(), 2); let field_a_dt = nested_fields[0].data_type(); assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); assert!(matches!(field_a_dt.codec, Codec::Int32)); diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 28e87ceb3cca..f6c6ba0ab5d4 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -80,9 +80,11 @@ mod test { ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, ListBuilder, MapBuilder, StringBuilder, StructBuilder, }; + use arrow_array::types::Int32Type; use arrow_array::{ - Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array, - Int64Array, ListArray, RecordBatch, 
StringArray, StructArray, TimestampMicrosecondArray, + Array, BinaryArray, BooleanArray, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, + Float32Array, Float64Array, Int32Array, Int64Array, ListArray, RecordBatch, StringArray, + StructArray, TimestampMicrosecondArray, }; use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_data::ArrayDataBuilder; @@ -218,7 +220,6 @@ mod test { .unwrap(); for file in files { let file = arrow_test_data(file); - // Pass `false` for strict_mode so we don't fail on out-of-spec unions assert_eq!(read_file(&file, 8, false), expected); assert_eq!(read_file(&file, 3, false), expected); } @@ -407,10 +408,10 @@ mod test { ); let actual_batch_small = read_file(&file_path, 3, false); assert_eq!( - actual_batch_small, expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", - file - ); + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", + file + ); } } @@ -1200,4 +1201,95 @@ mod test { let batch_small = read_file(&file, 3, false); assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); } + + #[test] + fn test_simple() { + // Each entry: (filename, batch_size1, expected_batch, batch_size2) + let tests = [ + ("avro/simple_enum.avro", 4, build_expected_enum(), 2), + ("avro/simple_fixed.avro", 2, build_expected_fixed(), 1), + ]; + + fn build_expected_enum() -> RecordBatch { + let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); + let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); + let f1_dict = + DictionaryArray::::try_new(keys_f1, Arc::new(vals_f1)).unwrap(); + let keys_f2 = Int32Array::from(vec![2, 3, 0, 1]); + let vals_f2 = StringArray::from(vec!["e", "f", "g", "h"]); + let f2_dict = + DictionaryArray::::try_new(keys_f2, Arc::new(vals_f2)).unwrap(); + let keys_f3 = Int32Array::from(vec![Some(1), Some(2), None, Some(0)]); + let vals_f3 = StringArray::from(vec!["i", "j", "k"]); + let f3_dict = + DictionaryArray::::try_new(keys_f3, Arc::new(vals_f3)).unwrap(); + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", dict_type.clone(), false), + Field::new("f2", dict_type.clone(), false), + Field::new("f3", dict_type.clone(), true), + ])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1_dict) as Arc, + Arc::new(f2_dict) as Arc, + Arc::new(f3_dict) as Arc, + ], + ) + .unwrap() + } + + fn build_expected_fixed() -> RecordBatch { + let f1 = + FixedSizeBinaryArray::try_from_iter(vec![b"abcde", b"12345"].into_iter()).unwrap(); + let f2 = + FixedSizeBinaryArray::try_from_iter(vec![b"fghijklmno", b"1234567890"].into_iter()) + .unwrap(); + let f3 = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ABCDEF" as &[u8]), None].into_iter(), + 6, + ) + .unwrap(); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", DataType::FixedSizeBinary(5), false), + Field::new("f2", DataType::FixedSizeBinary(10), false), + Field::new("f3", DataType::FixedSizeBinary(6), true), + ])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1) as Arc, + Arc::new(f2) as Arc, + Arc::new(f3) as Arc, + ], + ) + .unwrap() + } + for (file_name, batch_size, expected, alt_batch_size) in tests { + let file = arrow_test_data(file_name); + let actual = read_file(&file, batch_size, false); + assert_eq!(actual, expected); + let actual2 = read_file(&file, 
alt_batch_size, false); + assert_eq!(actual2, expected); + } + } + + #[test] + fn test_single_nan() { + let file = crate::test_util::arrow_test_data("avro/single_nan.avro"); + let actual = read_file(&file, 1, false); + use arrow_array::Float64Array; + let schema = Arc::new(Schema::new(vec![Field::new( + "mycol", + DataType::Float64, + true, + )])); + let col = Float64Array::from(vec![None as Option]); + let expected = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap(); + assert_eq!(actual, expected); + let actual2 = read_file(&file, 2, false); + assert_eq!(actual2, expected); + } } From dbdf79af630886fde92a68173b2f7831fd1b8770 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 9 Feb 2025 07:35:16 -0600 Subject: [PATCH 32/38] Added Avro `Decoder`, `ReaderBuilder`, and `Reader` Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 103 +-- arrow-avro/src/reader/block.rs | 281 +++++- arrow-avro/src/reader/cursor.rs | 210 +++++ arrow-avro/src/reader/header.rs | 2 +- arrow-avro/src/reader/mod.rs | 840 ++++++++++++------ arrow-avro/src/reader/record.rs | 13 +- arrow-avro/test/data/nested_lists.snappy.avro | Bin 0 -> 407 bytes 7 files changed, 1100 insertions(+), 349 deletions(-) create mode 100644 arrow-avro/test/data/nested_lists.snappy.avro diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index a2475087b8a3..3ed9c315a0cd 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -620,14 +620,14 @@ mod tests { assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); let arrow_field = avro_field.field(); assert_eq!(arrow_field.name(), "long_col"); - assert_eq!(arrow_field.data_type(), &DataType::Int64); + assert_eq!(arrow_field.data_type(), &Int64); assert!(!arrow_field.is_nullable()); } #[test] fn test_avro_field_with_default() { let field_codec = AvroDataType::from_codec(Codec::Int32); - let default_value = serde_json::json!(123); + let default_value = json!(123); let avro_field = AvroField { name: "int_col".to_string(), data_type: field_codec.clone(), @@ -653,67 +653,67 @@ mod tests { #[test] fn test_arrow_field_to_avro_field() { - let arrow_field = Field::new("Null", DataType::Null, true); + let arrow_field = Field::new("Null", Null, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Null)); - let arrow_field = Field::new("Boolean", DataType::Boolean, true); + let arrow_field = Field::new("Boolean", Boolean, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Boolean)); - let arrow_field = Field::new("Int32", DataType::Int32, true); + let arrow_field = Field::new("Int32", Int32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Int32)); - let arrow_field = Field::new("Int64", DataType::Int64, true); + let arrow_field = Field::new("Int64", Int64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Int64)); - let arrow_field = Field::new("Float32", DataType::Float32, true); + let arrow_field = Field::new("Float32", Float32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Float32)); - let arrow_field = Field::new("Float64", DataType::Float64, true); + let arrow_field = Field::new("Float64", Float64, true); let avro_field = arrow_field_to_avro_field(&arrow_field); 
assert!(matches!(avro_field.data_type().codec, Codec::Float64)); - let arrow_field = Field::new("Binary", DataType::Binary, true); + let arrow_field = Field::new("Binary", Binary, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Binary)); - let arrow_field = Field::new("Utf8", DataType::Utf8, true); + let arrow_field = Field::new("Utf8", Utf8, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::String)); - let arrow_field = Field::new("Decimal128", DataType::Decimal128(1, 2), true); + let arrow_field = Field::new("Decimal128", Decimal128(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(16)) )); - let arrow_field = Field::new("Decimal256", DataType::Decimal256(1, 2), true); + let arrow_field = Field::new("Decimal256", Decimal256(1, 2), true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::Decimal(1, Some(2), Some(32)) )); - let arrow_field = Field::new("Date32", DataType::Date32, true); + let arrow_field = Field::new("Date32", Date32, true); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::Date32)); - let arrow_field = Field::new("Time32", DataType::Time32(TimeUnit::Millisecond), false); + let arrow_field = Field::new("Time32", Time32(TimeUnit::Millisecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::TimeMillis)); - let arrow_field = Field::new("Time32", DataType::Time64(TimeUnit::Microsecond), false); + let arrow_field = Field::new("Time32", Time64(TimeUnit::Microsecond), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!(avro_field.data_type().codec, Codec::TimeMicros)); let arrow_field = Field::new( "utc_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -724,7 +724,7 @@ mod tests { let arrow_field = Field::new( "utc_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -733,42 +733,30 @@ mod tests { Codec::TimestampMicros(true) )); - let arrow_field = Field::new( - "local_ts_ms", - DataType::Timestamp(TimeUnit::Millisecond, None), - false, - ); + let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::TimestampMillis(false) )); - let arrow_field = Field::new( - "local_ts_us", - DataType::Timestamp(TimeUnit::Microsecond, None), - false, - ); + let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); let avro_field = arrow_field_to_avro_field(&arrow_field); assert!(matches!( avro_field.data_type().codec, Codec::TimestampMicros(false) )); - let arrow_field = Field::new( - "Interval", - DataType::Interval(IntervalUnit::MonthDayNano), - false, - ); + let arrow_field = Field::new("Interval", Interval(IntervalUnit::MonthDayNano), false); let avro_field = arrow_field_to_avro_field(&arrow_field); 
assert!(matches!(avro_field.data_type().codec, Codec::Duration)); let arrow_field = Field::new( "Struct", - DataType::Struct( + Struct( vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Float64, false), + Field::new("a", Boolean, false), + Field::new("b", Float64, false), ] .into(), ), @@ -788,7 +776,7 @@ mod tests { let arrow_field = Field::new( "DictionaryEnum", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -796,7 +784,7 @@ mod tests { let arrow_field = Field::new( "DictionaryString", - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Boolean)), + Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Boolean)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -804,11 +792,7 @@ mod tests { // Array with nullable items let field = Field::new("Utf8", DataType::Utf8, true); - let arrow_field = Field::new( - "Array with nullable items", - DataType::List(Arc::new(field)), - true, - ); + let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { assert_eq!(avro_data_type.nullability, Some(Nullability::NullFirst)); @@ -835,10 +819,10 @@ mod tests { let entries_field = Field::new( "entries", - DataType::Struct( + Struct( vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Utf8, true), + Field::new("key", Utf8, false), + Field::new("value", Utf8, true), ] .into(), ), @@ -846,7 +830,7 @@ mod tests { ); let arrow_field = Field::new( "Map with nullable items", - DataType::Map(Arc::new(entries_field), true), + Map(Arc::new(entries_field), true), true, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -860,10 +844,10 @@ mod tests { let arrow_field = Field::new( "Utf8", - DataType::Struct( + Struct( vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Utf8, false), + Field::new("key", Utf8, false), + Field::new("value", Utf8, false), ] .into(), ), @@ -871,7 +855,7 @@ mod tests { ); let arrow_field = Field::new( "Map with non-nullable items", - DataType::Map(Arc::new(arrow_field), false), + Map(Arc::new(arrow_field), false), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -882,7 +866,7 @@ mod tests { } else { panic!("Expected Codec::Map"); } - let arrow_field = Field::new("FixedSizeBinary", DataType::FixedSizeBinary(8), false); + let arrow_field = Field::new("FixedSizeBinary", FixedSizeBinary(8), false); let avro_field = arrow_field_to_avro_field(&arrow_field); let codec = &avro_field.data_type().codec; assert!(matches!(codec, Codec::Fixed(8))); @@ -890,9 +874,10 @@ mod tests { #[test] fn test_arrow_field_to_avro_field_meta_namespace() { - let arrow_field = Field::new("test_meta", DataType::Utf8, true).with_metadata( - HashMap::from([("namespace".to_string(), "arrow_meta_ns".to_string())]), - ); + let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( + "namespace".to_string(), + "arrow_meta_ns".to_string(), + )])); let avro_field = arrow_field_to_avro_field(&arrow_field); assert_eq!(avro_field.name(), "test_meta"); let actual_str = format!("{:?}", avro_field.data_type().codec); @@ -918,7 +903,7 @@ mod tests { ] } "#; - let schema: crate::schema::Schema = 
serde_json::from_str(json_schema).unwrap(); + let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_field = AvroField::try_from(&schema).unwrap(); match &avro_field.data_type().codec { Codec::Record(fields) => { @@ -955,7 +940,7 @@ mod tests { ] } "#; - let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_field = AvroField::try_from(&schema).unwrap(); match &avro_field.data_type().codec { Codec::Record(fields) => { @@ -1015,7 +1000,7 @@ mod tests { ] } "#; - let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_field = AvroField::try_from(&schema).unwrap(); match &avro_field.data_type().codec { Codec::Record(fields) => { @@ -1070,7 +1055,7 @@ mod tests { ] } "#; - let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_field = AvroField::try_from(&schema).unwrap(); match &avro_field.data_type().codec { @@ -1131,7 +1116,7 @@ mod tests { ] } "#; - let schema: crate::schema::Schema = serde_json::from_str(json_schema).unwrap(); + let schema: Schema = serde_json::from_str(json_schema).unwrap(); let avro_field = AvroField::try_from(&schema).unwrap(); match &avro_field.data_type().codec { Codec::Record(fields) => { diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs index 479f0ef90909..21e0b231450b 100644 --- a/arrow-avro/src/reader/block.rs +++ b/arrow-avro/src/reader/block.rs @@ -77,6 +77,7 @@ impl BlockDecoder { /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf pub fn decode(&mut self, mut buf: &[u8]) -> Result { let max_read = buf.len(); + while !buf.is_empty() { match self.state { BlockDecoderState::Count => { @@ -86,7 +87,6 @@ impl BlockDecoder { "Block count cannot be negative, got {c}" )) })?; - self.state = BlockDecoderState::Size; } } @@ -108,23 +108,30 @@ impl BlockDecoder { buf = &buf[to_read..]; self.bytes_remaining -= to_read; if self.bytes_remaining == 0 { - self.bytes_remaining = 16; + self.bytes_remaining = 16; // Prepare to read the sync marker self.state = BlockDecoderState::Sync; } } BlockDecoderState::Sync => { let to_decode = buf.len().min(self.bytes_remaining); - let write = &mut self.in_progress.sync[16 - to_decode..]; - write[..to_decode].copy_from_slice(&buf[..to_decode]); + + // Fill sync bytes from left to right + let start = 16 - self.bytes_remaining; + let end = start + to_decode; + self.in_progress.sync[start..end].copy_from_slice(&buf[..to_decode]); + self.bytes_remaining -= to_decode; buf = &buf[to_decode..]; if self.bytes_remaining == 0 { self.state = BlockDecoderState::Finished; } } - BlockDecoderState::Finished => return Ok(max_read - buf.len()), + BlockDecoderState::Finished => { + return Ok(max_read - buf.len()); + } } } + Ok(max_read) } @@ -132,6 +139,7 @@ impl BlockDecoder { pub fn flush(&mut self) -> Option { match self.state { BlockDecoderState::Finished => { + // Reset to decode the next block self.state = BlockDecoderState::Count; Some(std::mem::take(&mut self.in_progress)) } @@ -139,3 +147,266 @@ impl BlockDecoder { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::ArrowError; + use std::convert::TryFrom; + + fn encode_vlq(value: i64) -> Vec { + let mut buf = vec![]; + let mut ux = ((value << 1) ^ (value >> 63)) as u64; // ZigZag + + loop { + let mut byte = (ux & 0x7F) as u8; + ux >>= 
7; + if ux != 0 { + byte |= 0x80; + } + buf.push(byte); + if ux == 0 { + break; + } + } + buf + } + + #[test] + fn test_empty_input() { + let mut decoder = BlockDecoder::default(); + let buf = []; + let read = decoder.decode(&buf).unwrap(); + assert_eq!(read, 0); + assert!(decoder.flush().is_none()); + } + + #[test] + fn test_single_block_full_buffer() { + let mut decoder = BlockDecoder::default(); + + let count_encoded = encode_vlq(10); + let size_encoded = encode_vlq(4); + let data = vec![1u8, 2, 3, 4]; + let sync_marker = vec![0xAB; 16]; + + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + + let read = decoder.decode(&input).unwrap(); + assert_eq!(read, input.len()); + + let block = decoder.flush().expect("Should produce a finished block"); + assert_eq!(block.count, 10); + assert_eq!(block.data, data); + + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + } + + #[test] + fn test_single_block_partial_buffer() { + let mut decoder = BlockDecoder::default(); + + let count_encoded = encode_vlq(2); + let size_encoded = encode_vlq(3); + let data = vec![10u8, 20, 30]; + let sync_marker = vec![0xCD; 16]; + + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + + // Split into 3 parts + let part1 = &input[0..1]; + let part2 = &input[1..2]; + let part3 = &input[2..]; + + let read = decoder.decode(part1).unwrap(); + assert_eq!(read, 1); + assert!(decoder.flush().is_none()); + + let read = decoder.decode(part2).unwrap(); + assert_eq!(read, 1); + assert!(decoder.flush().is_none()); + + let read = decoder.decode(part3).unwrap(); + assert_eq!(read, part3.len()); + + let block = decoder.flush().expect("Should produce a finished block"); + assert_eq!(block.count, 2); + assert_eq!(block.data, data); + + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + } + + #[test] + fn test_multiple_blocks_in_one_buffer() { + let mut decoder = BlockDecoder::default(); + + // Block1 + let block1_count = encode_vlq(1); + let block1_size = encode_vlq(2); + let block1_data = vec![0x01, 0x02]; + let block1_sync = vec![0xAA; 16]; + + // Block2 + let block2_count = encode_vlq(3); + let block2_size = encode_vlq(1); + let block2_data = vec![0x99]; + let block2_sync = vec![0xBB; 16]; + + let mut input = Vec::new(); + input.extend_from_slice(&block1_count); + input.extend_from_slice(&block1_size); + input.extend_from_slice(&block1_data); + input.extend_from_slice(&block1_sync); + + input.extend_from_slice(&block2_count); + input.extend_from_slice(&block2_size); + input.extend_from_slice(&block2_data); + input.extend_from_slice(&block2_sync); + + // Decode once + let read1 = decoder.decode(&input).unwrap(); + + let block1 = decoder.flush().expect("First block should be complete"); + assert_eq!(block1.count, 1); + assert_eq!(block1.data, block1_data); + + let expected_sync1: [u8; 16] = <[u8; 16]>::try_from(&block1_sync[..16]).unwrap(); + assert_eq!(block1.sync, expected_sync1); + + // Decode remainder for block2 + let remainder = &input[read1..]; + decoder.decode(remainder).unwrap(); + let block2 = decoder.flush().expect("Second block should be complete"); + assert_eq!(block2.count, 3); + assert_eq!(block2.data, 
block2_data); + + let expected_sync2: [u8; 16] = <[u8; 16]>::try_from(&block2_sync[..16]).unwrap(); + assert_eq!(block2.sync, expected_sync2); + } + + #[test] + fn test_negative_count_should_error() { + let mut decoder = BlockDecoder::default(); + + let bad_count = encode_vlq(-1); + let size = encode_vlq(5); + + let mut input = Vec::new(); + input.extend_from_slice(&bad_count); + input.extend_from_slice(&size); + + let err = decoder.decode(&input).unwrap_err(); + match err { + ArrowError::ParseError(msg) => { + assert!( + msg.contains("Block count cannot be negative"), + "Expected negative count parse error, got: {msg}" + ); + } + _ => panic!("Unexpected error type: {err:?}"), + } + } + + #[test] + fn test_negative_size_should_error() { + let mut decoder = BlockDecoder::default(); + + let count = encode_vlq(5); + let bad_size = encode_vlq(-10); + + let mut input = Vec::new(); + input.extend_from_slice(&count); + input.extend_from_slice(&bad_size); + + let err = decoder.decode(&input).unwrap_err(); + match err { + ArrowError::ParseError(msg) => { + assert!( + msg.contains("Block size cannot be negative"), + "Expected negative size parse error, got: {msg}" + ); + } + _ => panic!("Unexpected error type: {err:?}"), + } + } + + #[test] + fn test_partial_sync_across_multiple_calls() { + let mut decoder = BlockDecoder::default(); + + // count=1, size=2, data=[0x01,0x02], sync=[0xCC;16] + let count_encoded = encode_vlq(1); + let size_encoded = encode_vlq(2); + let data = vec![0x01, 0x02]; + let sync_marker = vec![0xCC; 16]; + + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + + // We'll feed all but the last 4 sync bytes first + let split_point = input.len() - 4; + let part1 = &input[..split_point]; + let part2 = &input[split_point..]; + + let read1 = decoder.decode(part1).unwrap(); + assert_eq!(read1, part1.len()); + // Not finished yet + assert!(decoder.flush().is_none()); + + let read2 = decoder.decode(part2).unwrap(); + assert_eq!(read2, part2.len()); + + let block = decoder.flush().expect("Block should be complete now"); + assert_eq!(block.count, 1); + assert_eq!(block.data, data); + + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync, "Should match [0xCC; 16]"); + } + + #[test] + fn test_already_finished_state() { + let mut decoder = BlockDecoder::default(); + + // count=2, size=1, data=[0xAB], sync=[0xFF;16] + let count_encoded = encode_vlq(2); + let size_encoded = encode_vlq(1); + let data = vec![0xAB]; + let sync_marker = vec![0xFF; 16]; + + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + + let read = decoder.decode(&input).unwrap(); + assert_eq!(read, input.len()); + + // Now we should have a block + let block = decoder.flush().expect("Should have a block"); + assert_eq!(block.count, 2); + assert_eq!(block.data, data); + + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + + // Attempt to decode again with empty + let read2 = decoder.decode(&[]).unwrap(); + assert_eq!(read2, 0); + assert!(decoder.flush().is_none()); + } +} diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 65c93dab42fe..04aa8049047c 100644 --- 
a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -136,3 +136,213 @@ impl<'a> AvroCursor<'a> { Ok(ret) } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::ArrowError; + + fn hex_to_bytes(hex: &str) -> Vec { + let mut bytes = vec![]; + let mut chars = hex.chars().collect::>(); + if chars.len() % 2 != 0 { + chars.insert(0, '0'); + } + for chunk in chars.chunks(2) { + let s = format!("{}{}", chunk[0], chunk[1]); + bytes.push(u8::from_str_radix(&s, 16).unwrap()); + } + bytes + } + + #[test] + fn test_new_and_position() { + let data = [1, 2, 3, 4]; + let cursor = AvroCursor::new(&data); + assert_eq!(cursor.position(), 0); + } + + #[test] + fn test_get_u8_ok() { + let data = [0x12, 0x34, 0x56]; + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_u8().unwrap(), 0x12); + assert_eq!(cursor.position(), 1); + assert_eq!(cursor.get_u8().unwrap(), 0x34); + assert_eq!(cursor.position(), 2); + assert_eq!(cursor.get_u8().unwrap(), 0x56); + assert_eq!(cursor.position(), 3); + } + + #[test] + fn test_get_u8_eof() { + let data = []; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_u8(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) + ); + } + + #[test] + fn test_get_bool_ok() { + let data = [0x00, 0x01, 0xFF]; + let mut cursor = AvroCursor::new(&data); + assert!(!cursor.get_bool().unwrap()); // 0x00 -> false + assert!(cursor.get_bool().unwrap()); // 0x01 -> true + assert!(cursor.get_bool().unwrap()); // 0xFF -> true (non-zero) + } + + #[test] + fn test_get_bool_eof() { + let data = []; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bool(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) + ); + } + + #[test] + fn test_read_vlq_ok() { + let data = [0x80, 0x01, 0x05]; + let mut cursor = AvroCursor::new(&data); + let val1 = cursor.read_vlq().unwrap(); + assert_eq!(val1, 128); + let val2 = cursor.read_vlq().unwrap(); + assert_eq!(val2, 5); + } + + #[test] + fn test_read_vlq_bad_varint() { + let data = [0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.read_vlq(); + assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); + } + + #[test] + fn test_get_int_ok() { + let data = [0x04, 0x03]; // encodes +2, -2 + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_int().unwrap(), 2); + assert_eq!(cursor.get_int().unwrap(), -2); + } + + #[test] + fn test_get_int_overflow() { + let data = [0x80, 0x80, 0x80, 0x80, 0x10]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_int(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("varint overflow")) + ); + } + + #[test] + fn test_get_long_ok() { + let data = [0x04, 0x03, 0xAC, 0x02]; + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_long().unwrap(), 2); + assert_eq!(cursor.get_long().unwrap(), -2); + assert_eq!(cursor.get_long().unwrap(), 150); + } + + #[test] + fn test_get_long_eof() { + let data = [0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_long(); + assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); + } + + #[test] + fn test_get_bytes_ok() { + let data = [0x06, 0xAA, 0xBB, 0xCC, 0x05, 0x01]; + let mut cursor = AvroCursor::new(&data); + let bytes = cursor.get_bytes().unwrap(); + assert_eq!(bytes, [0xAA, 0xBB, 0xCC]); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn 
test_get_bytes_overflow() { + let data = [0xAC, 0x02, 0x01, 0x02, 0x03]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bytes(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading bytes")) + ); + } + + #[test] + fn test_get_bytes_negative_length() { + let data = [0x01]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bytes(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("offset overflow")) + ); + } + + #[test] + fn test_get_float_ok() { + let data = [0x00, 0x00, 0x80, 0x3F, 0x01]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_float().unwrap(); + assert!((val - 1.0).abs() < f32::EPSILON); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn test_get_float_eof() { + let data = [0x00, 0x00, 0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_float(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading float")) + ); + } + + #[test] + fn test_get_double_ok() { + let data = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, 0x99]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_double().unwrap(); + assert!((val - 1.0).abs() < f64::EPSILON); + assert_eq!(cursor.position(), 8); + } + + #[test] + fn test_get_double_eof() { + let data = [0x00, 0x00, 0x00, 0x00]; // only 4 bytes + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_double(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading double")) + ); + } + + #[test] + fn test_get_fixed_ok() { + let data = [0x11, 0x22, 0x33, 0x44]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_fixed(2).unwrap(); + assert_eq!(val, [0x11, 0x22]); + assert_eq!(cursor.position(), 2); + + let val = cursor.get_fixed(2).unwrap(); + assert_eq!(val, [0x33, 0x44]); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn test_get_fixed_eof() { + let data = [0x11, 0x22]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_fixed(3); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading fixed")) + ); + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index f62b01922814..9b7d3456589e 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -376,7 +376,7 @@ mod test { sync: [0; 16], }; let schema = header.schema().unwrap().unwrap(); - if let crate::schema::Schema::Complex(crate::schema::ComplexType::Record(record)) = schema { + if let Schema::Complex(crate::schema::ComplexType::Record(record)) = schema { assert_eq!(record.fields.len(), 1); assert_eq!(record.fields[0].default, Some(serde_json::json!(10))); } else { diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index f6c6ba0ab5d4..ad5452d93228 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -14,12 +14,23 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -//! Read Avro data to Arrow + +//! Avro reader +//! +//! This module provides facilities to read Apache Avro-encoded files or streams +//! into Arrow's [`RecordBatch`] format. In particular, it introduces: +//! +//! * [`ReaderBuilder`]: Configures Avro reading, e.g., batch size +//! * [`Reader`]: Yields [`RecordBatch`] values, implementing [`Iterator`] +//! 
* [`Decoder`]: A low-level push-based decoder for Avro records

 use crate::reader::block::{Block, BlockDecoder};
 use crate::reader::header::{Header, HeaderDecoder};
-use arrow_schema::ArrowError;
+use crate::reader::record::RecordDecoder;
+use arrow_array::{RecordBatch, RecordBatchReader};
+use arrow_schema::{ArrowError, Schema, SchemaRef};
 use std::io::BufRead;
+use std::sync::Arc;

 mod header;
@@ -70,15 +81,238 @@ fn read_blocks<R: BufRead>(mut reader: R) -> impl Iterator<Item = Result<Block, ArrowError>> {

+/// A low-level push-based decoder that converts Avro-encoded rows into [`RecordBatch`]
+#[derive(Debug)]
+pub struct Decoder {
+    record_decoder: RecordDecoder,
+    batch_size: usize,
+}
+
+impl Decoder {
+    fn new(record_decoder: RecordDecoder, batch_size: usize) -> Self {
+        Self {
+            record_decoder,
+            batch_size,
+        }
+    }
+
+    /// Decode up to `to_read` Avro records from `data`, returning how many bytes were consumed.
+    ///
+    /// This can be called repeatedly with slices of Avro block data. Once `decode`
+    /// has processed a chunk of rows (for example, `batch_size` rows), call
+    /// [`Self::flush`] to convert the accumulated rows into a [`RecordBatch`].
+    ///
+    /// * `data` is Avro-encoded rows (potentially a partial block).
+    /// * `to_read` is the number of rows (Avro records, not bytes) to decode from the buffer.
+    pub fn decode(&mut self, data: &[u8], to_read: usize) -> Result<usize, ArrowError> {
+        self.record_decoder.decode(data, to_read)
+    }
+
+    /// Produce a [`RecordBatch`] from all fully decoded rows so far.
+    ///
+    /// Returns an error if a partially decoded row remains, if any type
+    /// conversion fails, or if no rows have been decoded yet.
+    pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
+        self.record_decoder.flush()
+    }
+
+    /// Return the configured batch size for this [`Decoder`].
+    pub fn batch_size(&self) -> usize {
+        self.batch_size
+    }
+}
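+
+// A minimal usage sketch of the push-based `Decoder` (the names `decoder`,
+// `block_data`, and `n_rows` are illustrative, not part of this patch):
+//
+//     let mut offset = 0;
+//     let mut remaining = n_rows;
+//     while remaining > 0 {
+//         let to_read = remaining.min(decoder.batch_size());
+//         offset += decoder.decode(&block_data[offset..], to_read)?;
+//         remaining -= to_read;
+//     }
+//     let batch = decoder.flush()?;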
+
+/// A builder to create an [`Avro Reader`](Reader) that reads Avro data
+/// into Arrow [`RecordBatch`]es.
+///
+/// ```
+/// # use std::fs::File;
+/// # use std::io::BufReader;
+/// # use arrow_avro::reader::{ReaderBuilder};
+/// let file = File::open("test/data/nested_lists.snappy.avro").unwrap();
+/// let buf_reader = BufReader::new(file);
+///
+/// let builder = ReaderBuilder::new().with_batch_size(1024);
+/// let reader = builder.build(buf_reader).unwrap();
+/// for maybe_batch in reader {
+///     let batch = maybe_batch.unwrap();
+///     // process batch
+/// }
+/// ```
+#[derive(Debug)]
+pub struct ReaderBuilder {
+    batch_size: usize,
+    strict_mode: bool,
+}
+
+impl Default for ReaderBuilder {
+    fn default() -> Self {
+        Self {
+            batch_size: 1024,
+            strict_mode: false,
+        }
+    }
+}
+
+impl ReaderBuilder {
+    /// Creates a new [`ReaderBuilder`] with default settings:
+    /// * `batch_size` = 1024
+    /// * `strict_mode` = false
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets the batch size in rows to read
+    pub fn with_batch_size(self, batch_size: usize) -> Self {
+        Self { batch_size, ..self }
+    }
+
+    /// Controls whether out-of-specification schemas, such as Impala's
+    /// union type with `null` as the second branch, should produce an error
+    /// (`strict_mode = true`) or be tolerated where possible.
+    pub fn with_strict_mode(self, strict_mode: bool) -> Self {
+        Self {
+            strict_mode,
+            ..self
+        }
+    }
+
+    /// Create a [`Reader`] with the provided [`BufRead`]
+    pub fn build<R: BufRead>(self, mut reader: R) -> Result<Reader<R>, ArrowError> {
+        let header = read_header(&mut reader)?;
+        let compression = header.compression()?;
+        let avro_schema = header
+            .schema()
+            .map_err(|e| ArrowError::ExternalError(Box::new(e)))?
+            .ok_or_else(|| {
+                ArrowError::ParseError("No Avro schema present in file header".to_string())
+            })?;
+        use crate::codec::AvroField;
+        let root_field = AvroField::try_from(&avro_schema)?;
+        let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?;
+        let decoder = Decoder::new(record_decoder, self.batch_size);
+        Ok(Reader {
+            reader,
+            header,
+            compression,
+            decoder,
+            finished: false,
+        })
+    }
+}
+
+/// An iterator over [`RecordBatch`] that reads from an Avro-encoded
+/// data stream (e.g. a file) using the schema stored in the Avro file header.
+///
+/// This parallels the design of [`arrow_json::Reader`].
+///
+/// # Example
+///
+/// ```
+/// # use std::fs::File;
+/// # use std::io::BufReader;
+/// # use arrow_avro::reader::{ReaderBuilder};
+/// # use arrow_schema::ArrowError;
+/// # fn read_avro(path: &str) -> Result<(), ArrowError> {
+/// let file = File::open(path)?;
+/// let buf_reader = BufReader::new(file);
+/// let mut reader = ReaderBuilder::new()
+///     .with_batch_size(500)
+///     .build(buf_reader)?;
+/// if let Some(batch) = reader.next() {
+///     let batch = batch?;
+///     println!("Decoded batch: {} rows, {} columns", batch.num_rows(), batch.num_columns());
+///     // process batch
+/// }
+/// Ok(())
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct Reader<R> {
+    /// The underlying buffered reader or stream from which to read Avro data.
+    reader: R,
+
+    /// The Avro file header, including sync marker and schema.
+    header: Header,
+
+    /// An optional compression codec (Snappy, BZip2, etc.) found in the Avro file.
+    compression: Option<crate::compression::CompressionCodec>,
+
+    /// A high-level decoder that wraps the low-level [`RecordDecoder`].
+    decoder: Decoder,
+
+    /// True if we have already returned the final batch or encountered EOF.
+    finished: bool,
+}
+
+impl<R: BufRead> Reader<R> {
+    /// Return the Arrow schema discovered from the Avro file's header.
+    pub fn schema(&self) -> SchemaRef {
+        self.decoder.record_decoder.schema().clone()
+    }
+}
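+
+// Sketch: because `Reader` implements both `Iterator` and `RecordBatchReader`
+// (below), it can be drained generically. Illustrative helper, not part of
+// this patch:
+//
+//     fn count_rows<R: BufRead>(reader: Reader<R>) -> Result<usize, ArrowError> {
+//         let mut rows = 0;
+//         for batch in reader {
+//             rows += batch?.num_rows();
+//         }
+//         Ok(rows)
+//     }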
+
+impl<R: BufRead> Reader<R> {
+    /// Reads the next [`RecordBatch`] from the file, returning `Ok(None)` if EOF
+    /// or if we have already yielded all data.
+    fn read_next_batch(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        if self.finished {
+            return Ok(None);
+        }
+        for block_result in read_blocks(&mut self.reader) {
+            let block = block_result?;
+            let block_data = if let Some(ref c) = self.compression {
+                c.decompress(&block.data)?
+            } else {
+                block.data
+            };
+            let mut offset = 0;
+            let mut remaining = block.count;
+            while remaining > 0 {
+                let to_read = std::cmp::min(remaining, self.decoder.batch_size());
+                let consumed = self.decoder.decode(&block_data[offset..], to_read)?;
+                offset += consumed;
+                remaining -= to_read;
+            }
+        }
+        let batch = self.decoder.flush()?;
+        self.finished = true;
+        Ok(Some(batch))
+    }
+}
+
+impl<R: BufRead> Iterator for Reader<R> {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.read_next_batch() {
+            Ok(Some(b)) => Some(Ok(b)),
+            Ok(None) => None,
+            Err(e) => Some(Err(e)),
+        }
+    }
+}
+
+impl<R: BufRead> RecordBatchReader for Reader<R> {
+    fn schema(&self) -> SchemaRef {
+        self.schema()
+    }
+}
+
 #[cfg(test)]
 mod test {
-    use crate::codec::AvroField;
-    use crate::reader::record::RecordDecoder;
-    use crate::reader::{read_blocks, read_header};
+    use super::*;
     use crate::test_util::arrow_test_data;
     use arrow_array::builder::{
         ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder,
-        ListBuilder, MapBuilder, StringBuilder, StructBuilder,
+        ListBuilder, MapBuilder, MapFieldNames, StringBuilder, StructBuilder,
     };
     use arrow_array::types::Int32Type;
     use arrow_array::{
@@ -94,35 +328,16 @@ mod test {
     use std::io::BufReader;
     use std::sync::Arc;

-    /// Helper to read an Avro file into a `RecordBatch`.
+    /// Test helper that opens an Avro file, builds an Avro `Reader` with
+    /// a fixed batch size, then returns that `Reader`.
     ///
-    /// - `strict_mode`: if `true`, we reject unions of the form `[T,"null"]`.
-    fn read_file(file: &str, batch_size: usize, strict_mode: bool) -> RecordBatch {
-        let file = File::open(file).unwrap();
-        let mut reader = BufReader::new(file);
-        let header = read_header(&mut reader).unwrap();
-        let compression = header.compression().unwrap();
-        let schema = header.schema().unwrap().unwrap();
-        let root = AvroField::try_from(&schema).unwrap();
-        let mut decoder = RecordDecoder::try_new(root.data_type(), strict_mode).unwrap();
-        for result in read_blocks(reader) {
-            let block = result.unwrap();
-            assert_eq!(block.sync, header.sync());
-            let block_data = if let Some(c) = compression {
-                c.decompress(&block.data).unwrap()
-            } else {
-                block.data
-            };
-            let mut offset = 0;
-            let mut remaining = block.count;
-            while remaining > 0 {
-                let to_read = remaining.min(batch_size);
-                offset += decoder.decode(&block_data[offset..], to_read).unwrap();
-                remaining -= to_read;
-            }
-            assert_eq!(offset, block_data.len());
-        }
-        decoder.flush().unwrap()
+    /// We ignore `schema` because Avro is self-describing; the file carries
+    /// its own schema, so no separate inference step is needed.
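+    ///
+    /// A minimal usage sketch (hypothetical file name; not compiled as a doctest):
+    ///
+    /// ```ignore
+    /// let mut reader = read_file("avro/alltypes_plain.avro", None);
+    /// while let Some(batch) = reader.next() {
+    ///     let batch = batch.unwrap();
+    ///     // each batch holds at most 64 rows, per the builder below
+    /// }
+    /// ```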
+    fn read_file(path: &str, _schema: Option<SchemaRef>) -> super::Reader<BufReader<File>> {
+        let file = File::open(path).unwrap();
+        let reader = BufReader::new(file);
+        let builder = ReaderBuilder::new().with_batch_size(64);
+        builder.build(reader).unwrap()
     }

     #[test]
@@ -183,14 +398,14 @@ mod test {
             (
                 "date_string_col",
                 Arc::new(BinaryArray::from_iter_values([
-                    [48, 51, 47, 48, 49, 47, 48, 57],
-                    [48, 51, 47, 48, 49, 47, 48, 57],
-                    [48, 52, 47, 48, 49, 47, 48, 57],
-                    [48, 52, 47, 48, 49, 47, 48, 57],
-                    [48, 50, 47, 48, 49, 47, 48, 57],
-                    [48, 50, 47, 48, 49, 47, 48, 57],
-                    [48, 49, 47, 48, 49, 47, 48, 57],
-                    [48, 49, 47, 48, 49, 47, 48, 57],
+                    b"03/01/09",
+                    b"03/01/09",
+                    b"04/01/09",
+                    b"04/01/09",
+                    b"02/01/09",
+                    b"02/01/09",
+                    b"01/01/09",
+                    b"01/01/09",
                 ])) as _,
                 true,
             ),
@@ -220,8 +435,12 @@ mod test {
         .unwrap();
         for file in files {
             let file = arrow_test_data(file);
-            assert_eq!(read_file(&file, 8, false), expected);
-            assert_eq!(read_file(&file, 3, false), expected);
+            let mut reader = read_file(&file, None);
+            let batch_large = reader.next().unwrap().unwrap();
+            assert_eq!(batch_large, expected);
+            let mut reader_small = read_file(&file, None);
+            let batch_small = reader_small.next().unwrap().unwrap();
+            assert_eq!(batch_small, expected);
         }
     }
@@ -285,16 +504,18 @@ mod test {
         ])
         .unwrap();
         let file_path = arrow_test_data(file);
-        let batch_large = read_file(&file_path, 8, false);
+        let mut reader = read_file(&file_path, None);
+        let batch_large = reader.next().unwrap().unwrap();
         assert_eq!(
             batch_large, expected,
             "Decoded RecordBatch does not match for file {}",
             file
         );
-        let batch_small = read_file(&file_path, 3, false);
+        let mut reader_small = read_file(&file_path, None);
+        let batch_small = reader_small.next().unwrap().unwrap();
         assert_eq!(
             batch_small, expected,
-            "Decoded RecordBatch (batch size 3) does not match for file {}",
+            "Decoded RecordBatch (batch size 64) does not match for file {}",
             file
         );
     }
@@ -337,16 +558,18 @@ mod test {
         ])
         .unwrap();
         let file_path = arrow_test_data(file);
-        let batch_large = read_file(&file_path, 8, false);
+        let mut reader = read_file(&file_path, None);
+        let batch_large = reader.next().unwrap().unwrap();
         assert_eq!(
             batch_large, expected,
             "Decoded RecordBatch does not match for file {}",
             file
         );
-        let batch_small = read_file(&file_path, 3, false);
+        let mut reader_small = read_file(&file_path, None);
+        let batch_small = reader_small.next().unwrap().unwrap();
         assert_eq!(
             batch_small, expected,
-            "Decoded RecordBatch (batch size 3) does not match for file {}",
+            "Decoded RecordBatch does not match for file {}",
             file
         );
     }
@@ -354,7 +577,8 @@ mod test {
     #[test]
     fn test_binary() {
         let file = arrow_test_data("avro/binary.avro");
-        let batch = read_file(&file, 8, false);
+        let mut reader = read_file(&file, None);
+        let batch = reader.next().unwrap().unwrap();
         let expected = RecordBatch::try_from_iter_with_nullable([(
             "foo",
             Arc::new(BinaryArray::from_iter_values(vec![
@@ -386,39 +610,40 @@ mod test {
             ("avro/int64_decimal.avro", 10, 2),
         ];
         let decimal_values: Vec<i128> = (1..=24).map(|n| n as i128 * 100).collect();
+
         for (file, precision, scale) in files {
             let file_path = arrow_test_data(file);
-            let actual_batch = read_file(&file_path, 8, false);
+            let mut reader = read_file(&file_path, None);
+            let actual_batch = reader.next().unwrap().unwrap();
+
             let expected_array = Decimal128Array::from_iter_values(decimal_values.clone())
                 .with_precision_and_scale(precision, scale)
                 .unwrap();
+
             let mut meta = HashMap::new();
             meta.insert("precision".to_string(), precision.to_string());
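            // The decoder surfaces the Avro decimal's precision and scale as
            // Arrow field metadata, so the expected schema must carry matching
            // entries for the equality assertion below to hold.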
meta.insert("scale".to_string(), scale.to_string()); let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) .with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); let expected_batch = RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) .expect("Failed to build expected RecordBatch"); + assert_eq!( actual_batch, expected_batch, "Decoded RecordBatch does not match the expected Decimal128 data for file {}", file ); - let actual_batch_small = read_file(&file_path, 3, false); - assert_eq!( - actual_batch_small, expected_batch, - "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", - file - ); } } #[test] fn test_datapage_v2() { let file = arrow_test_data("avro/datapage_v2.snappy.avro"); - let batch = read_file(&file, 8, false); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); let a = StringArray::from(vec![ Some("abc"), Some("abc"), @@ -463,8 +688,10 @@ mod test { #[test] fn test_dict_pages_offset_zero() { let file = arrow_test_data("avro/dict-page-offset-zero.avro"); - let batch = read_file(&file, 32, false); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); let num_rows = batch.num_rows(); + let expected_field = Int32Array::from(vec![Some(1552); num_rows]); let expected = RecordBatch::try_from_iter_with_nullable([( "l_partkey", @@ -478,6 +705,7 @@ mod test { #[test] fn test_list_columns() { let file = arrow_test_data("avro/list_columns.avro"); + let mut reader = read_file(&file, None); let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); { { @@ -533,13 +761,15 @@ mod test { ("utf8_list", Arc::new(utf8_list) as Arc, true), ]) .unwrap(); - let batch = read_file(&file, 8, false); + let batch = reader.next().unwrap().unwrap(); assert_eq!(batch, expected); } #[test] fn test_nested_lists() { let file = arrow_test_data("avro/nested_lists.snappy.avro"); + let mut reader = read_file(&file, None); + let left = reader.next().unwrap().unwrap(); let inner_values = StringArray::from(vec![ Some("a"), Some("b"), @@ -584,7 +814,7 @@ mod test { .unwrap(); let middle_list_array = ListArray::from(middle_list_data); let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); - let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all 3 rows valid + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all valid let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) .len(3) @@ -600,14 +830,14 @@ mod test { ("b", Arc::new(b_expected) as Arc, true), ]) .unwrap(); - let left = read_file(&file, 8, false); - assert_eq!(left, expected, "Mismatch for batch size=8"); - let left_small = read_file(&file, 3, false); - assert_eq!(left_small, expected, "Mismatch for batch size=3"); + assert_eq!(left, expected, "Mismatch for batch size=64"); } #[test] fn test_nested_records() { + let file = arrow_test_data("avro/nested_records.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); let f1_f1_2 = Int32Array::from(vec![10, 20]); let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; @@ -616,6 +846,7 @@ mod test { Arc::new(Field::new("f1_3_1", DataType::Float64, false)), Arc::new(f1_f1_3_1) as Arc, )]); + let f1_expected = StructArray::from(vec![ ( Arc::new(Field::new("f1_1", 
DataType::Utf8, false)), @@ -648,8 +879,8 @@ mod test { .map(|f| Arc::new(f.clone())) .collect::>>(), vec![ - Box::new(BooleanBuilder::new()) as Box, - Box::new(Float32Builder::new()) as Box, + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, ], ); let mut f2_list_builder = ListBuilder::new(f2_struct_builder); @@ -749,29 +980,20 @@ mod test { ("f4", Arc::new(f4_expected) as Arc, false), ]) .unwrap(); - let file = arrow_test_data("avro/nested_records.avro"); - let batch_large = read_file(&file, 8, false); - assert_eq!( - batch_large, expected, - "Decoded RecordBatch does not match expected data for nested records (batch size 8)" - ); - let batch_small = read_file(&file, 3, false); - assert_eq!( - batch_small, expected, - "Decoded RecordBatch does not match expected data for nested records (batch size 3)" - ); + assert_eq!(batch, expected, "Mismatch in nested_records.avro contents"); } #[test] fn test_nonnullable_impala() { let file = arrow_test_data("avro/nonnullable.impala.avro"); + let mut reader = read_file(&file, None); let id = Int64Array::from(vec![Some(8)]); let mut int_array_builder = ListBuilder::new(Int32Builder::new()); { let vb = int_array_builder.values(); vb.append_value(-1); } - int_array_builder.append(true); // finalize one sub-list + int_array_builder.append(true); let int_array = int_array_builder.finish(); let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); { @@ -786,7 +1008,6 @@ mod test { } iaa_builder.append(true); let int_array_array = iaa_builder.finish(); - use arrow_array::builder::MapFieldNames; let field_names = MapFieldNames { entry: "entries".to_string(), key: "key".to_string(), @@ -799,7 +1020,7 @@ mod test { keys.append_value("k1"); vals.append_value(-1); } - int_map_builder.append(true).unwrap(); // finalize map for row 0 + int_map_builder.append(true).unwrap(); let int_map = int_map_builder.finish(); let field_names2 = MapFieldNames { entry: "entries".to_string(), @@ -825,108 +1046,95 @@ mod test { } ima_builder.append(true); let int_map_array_ = ima_builder.finish(); - let mut nested_sb = StructBuilder::new( - vec![ - Arc::new(Field::new("a", DataType::Int32, true)), - Arc::new(Field::new( - "B", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - Arc::new(Field::new( - "c", - DataType::Struct( - vec![Field::new( - "D", - DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new( - "item", - DataType::Struct( - vec![ - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Utf8, true), - ] - .into(), - ), - true, - ))), - true, - ))), + let nested_schema_fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "c", + DataType::Struct(Fields::from(vec![Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ])), true, - )] - .into(), - ), + ))), + true, + ))), true, - )), - Arc::new(Field::new( - "G", - DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct( - vec![ - Field::new("key", DataType::Utf8, false), - Field::new( - "value", - DataType::Struct( - vec![Field::new( - "h", - DataType::Struct( - vec![Field::new( - "i", - DataType::List(Arc::new(Field::new( - "item", - DataType::Float64, - true, - ))), - true, 
- )] - .into(), - ), - true, - )] - .into(), - ), + )])), + true, + ), + Field::new( + "G", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct(Fields::from(vec![Field::new( + "h", + DataType::Struct(Fields::from(vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), true, - ), - ] - .into(), + )])), + true, + )])), + true, ), - false, - )), + ])), false, - ), - true, - )), - ], + )), + false, + ), + true, + ), + ]; + let nested_schema = Arc::new(Schema::new(nested_schema_fields.clone())); + let mut nested_sb = StructBuilder::new( + nested_schema_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), vec![ Box::new(Int32Builder::new()), Box::new(ListBuilder::new(Int32Builder::new())), { - let d_field = Field::new( + let d_list_field = Field::new( "D", DataType::List(Arc::new(Field::new( "item", DataType::List(Arc::new(Field::new( "item", - DataType::Struct( - vec![ - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Utf8, true), - ] - .into(), - ), + DataType::Struct(Fields::from(vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ])), true, ))), true, ))), true, ); - Box::new(StructBuilder::new( - vec![Arc::new(d_field)], - vec![Box::new({ - let ef_struct_builder = StructBuilder::new( + let struct_c_builder = StructBuilder::new( + vec![Arc::new(d_list_field)], + vec![Box::new(ListBuilder::new(ListBuilder::new( + StructBuilder::new( vec![ Arc::new(Field::new("e", DataType::Int32, true)), Arc::new(Field::new("f", DataType::Utf8, true)), @@ -935,32 +1143,41 @@ mod test { Box::new(Int32Builder::new()), Box::new(StringBuilder::new()), ], - ); - let list_of_ef = ListBuilder::new(ef_struct_builder); - ListBuilder::new(list_of_ef) - })], - )) + ), + )))], + ); + Box::new(struct_c_builder) }, { - let map_field_names = MapFieldNames { - entry: "entries".to_string(), - key: "key".to_string(), - value: "value".to_string(), - }; - let i_list_builder = ListBuilder::new(Float64Builder::new()); - let h_struct = StructBuilder::new( - vec![Arc::new(Field::new( - "i", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), - true, - ))], - vec![Box::new(i_list_builder)], + let i_list = Field::new( + "i", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + true, + ); + let h_struct = + Field::new("h", DataType::Struct(Fields::from(vec![i_list])), true); + let value_struct = Field::new( + "value", + DataType::Struct(Fields::from(vec![h_struct])), + true, ); - let g_value_builder = StructBuilder::new( - vec![Arc::new(Field::new( - "h", - DataType::Struct( - vec![Field::new( + let key_field = Field::new("key", DataType::Utf8, false); + let entries_field = Field::new( + "entries", + DataType::Struct(Fields::from(vec![key_field, value_struct])), + false, + ); + Box::new(MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + StringBuilder::new(), + StructBuilder::new( + vec![Arc::new(Field::new( + "h", + DataType::Struct(Fields::from(vec![Field::new( "i", DataType::List(Arc::new(Field::new( "item", @@ -968,17 +1185,22 @@ mod test { true, ))), true, - )] - .into(), - ), - true, - ))], - vec![Box::new(h_struct)], - ); - Box::new(MapBuilder::new( - Some(map_field_names), - StringBuilder::new(), - g_value_builder, + )])), + true, + ))], + 
vec![Box::new(StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + ))], + vec![Box::new(ListBuilder::new(Float64Builder::new()))], + ))], + ), )) }, ], @@ -987,8 +1209,6 @@ mod test { { let a_builder = nested_sb.field_builder::(0).unwrap(); a_builder.append_value(-1); - } - { let b_builder = nested_sb .field_builder::>(1) .unwrap(); @@ -997,57 +1217,131 @@ mod test { vb.append_value(-1); } b_builder.append(true); - } - { - let c_struct_builder = nested_sb.field_builder::(2).unwrap(); - c_struct_builder.append(true); - let d_list_builder = c_struct_builder - .field_builder::>>(0) - .unwrap(); + let c_sb = nested_sb.field_builder::(2).unwrap(); + c_sb.append(true); { - let sub_list_builder = d_list_builder.values(); + let d_list_builder = c_sb + .field_builder::>>(0) + .unwrap(); { - let ef_struct = sub_list_builder.values(); - ef_struct.append(true); + let sub_list_builder = d_list_builder.values(); { - let e_b = ef_struct.field_builder::(0).unwrap(); - e_b.append_value(-1); - let f_b = ef_struct.field_builder::(1).unwrap(); - f_b.append_value("nonnullable"); + let ef_struct_builder = sub_list_builder.values(); + ef_struct_builder.append(true); + { + let e_b = ef_struct_builder.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct_builder.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); + } + sub_list_builder.append(true); } - sub_list_builder.append(true); + d_list_builder.append(true); } - d_list_builder.append(true); } - } - { let g_map_builder = nested_sb .field_builder::>(3) .unwrap(); g_map_builder.append(true).unwrap(); + { + let (keys, values) = g_map_builder.entries(); + keys.append_value("k1"); + values.append(true); + let h_struct_builder = values.field_builder::(0).unwrap(); + h_struct_builder.append(true); + { + let i_list_builder = h_struct_builder + .field_builder::>(0) + .unwrap(); + i_list_builder.append(true); + } + } } let nested_struct = nested_sb.finish(); - let expected = RecordBatch::try_from_iter_with_nullable([ - ("ID", Arc::new(id) as Arc, true), - ("Int_Array", Arc::new(int_array), true), - ("int_array_array", Arc::new(int_array_array), true), - ("Int_Map", Arc::new(int_map), true), - ("int_map_array", Arc::new(int_map_array_), true), - ("nested_Struct", Arc::new(nested_struct), true), - ]) + let schema = Arc::new(Schema::new(vec![ + Field::new("ID", DataType::Int64, true), + Field::new( + "Int_Array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "int_array_array", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + true, + ), + Field::new( + "Int_Map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ), + Field::new( + "int_map_array", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ))), + true, + ), + Field::new( + "nested_Struct", + DataType::Struct(nested_schema.as_ref().fields.clone()), + true, + ), + ])); + let expected = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(id) as Arc, + 
Arc::new(int_array), + Arc::new(int_array_array), + Arc::new(int_map), + Arc::new(int_map_array_), + Arc::new(nested_struct), + ], + ) .unwrap(); - let batch_large = read_file(&file, 8, false); - assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); - let batch_small = read_file(&file, 3, false); - assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch, expected, "nonnullable impala avro data mismatch"); } #[test] fn test_nullable_impala() { + use arrow_array::{Int64Array, ListArray, StructArray}; let file = arrow_test_data("avro/nullable.impala.avro"); - let batch1 = read_file(&file, 3, false); - let batch2 = read_file(&file, 8, false); - assert_eq!(batch1, batch2); + let mut r1 = read_file(&file, None); + let batch1 = r1.next().unwrap().unwrap(); + let mut r2 = read_file(&file, None); + let batch2 = r2.next().unwrap().unwrap(); + assert_eq!( + batch1, batch2, + "Reading file multiple times should produce the same data" + ); let batch = batch1; assert_eq!(batch.num_rows(), 7); let id_array = batch @@ -1057,18 +1351,14 @@ mod test { .expect("id column should be an Int64Array"); let expected_ids = [1, 2, 3, 4, 5, 6, 7]; for (i, &expected_id) in expected_ids.iter().enumerate() { - assert_eq!( - id_array.value(i), - expected_id, - "Mismatch in id at row {}", - i - ); + assert_eq!(id_array.value(i), expected_id, "Mismatch in id at row {i}"); } let int_array = batch .column(1) .as_any() .downcast_ref::() .expect("int_array column should be a ListArray"); + { let offsets = int_array.value_offsets(); let start = offsets[0] as usize; @@ -1078,7 +1368,7 @@ mod test { .as_any() .downcast_ref::() .expect("Values of int_array should be an Int32Array"); - let row0: Vec> = (start..end).map(|i| Some(values.value(i))).collect(); + let row0: Vec> = (start..end).map(|idx| Some(values.value(idx))).collect(); assert_eq!( row0, vec![Some(1), Some(2), Some(3)], @@ -1111,16 +1401,13 @@ mod test { #[test] fn test_nulls_snappy() { let file = arrow_test_data("avro/nulls.snappy.avro"); - let batch_large = read_file(&file, 8, false); - use arrow_array::{Int32Array, StructArray}; - use arrow_buffer::Buffer; - use arrow_data::ArrayDataBuilder; - use arrow_schema::{DataType, Field, Fields}; + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); let b_c_int = Int32Array::from(vec![None; 8]); let b_c_int_data = b_c_int.into_data(); let b_struct_field = Field::new("b_c_int", DataType::Int32, true); - let b_struct_type = DataType::Struct(Fields::from(vec![b_struct_field])); - let struct_validity = Buffer::from_iter((0..8).map(|_| true)); + let b_struct_type = DataType::Struct(vec![b_struct_field].into()); + let struct_validity = arrow_buffer::Buffer::from_iter((0..8).map(|_| true)); let b_struct_data = ArrayDataBuilder::new(b_struct_type) .len(8) .null_bit_buffer(Some(struct_validity)) @@ -1128,21 +1415,21 @@ mod test { .build() .unwrap(); let b_struct_array = StructArray::from(b_struct_data); - let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([( + + let expected = RecordBatch::try_from_iter_with_nullable([( "b_struct", Arc::new(b_struct_array) as _, true, )]) .unwrap(); - assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); - let batch_small = read_file(&file, 3, false); - assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + assert_eq!(batch, expected); } #[test] fn test_repeated_no_annotation() { let file = arrow_test_data("avro/repeated_no_annotation.avro"); - 
let batch_large = read_file(&file, 8, false); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); use arrow_array::{Int32Array, Int64Array, ListArray, StringArray, StructArray}; use arrow_buffer::Buffer; use arrow_data::ArrayDataBuilder; @@ -1162,7 +1449,7 @@ mod test { Field::new("kind", DataType::Utf8, true), ]); let phone_struct_data = ArrayDataBuilder::new(DataType::Struct(phone_fields)) - .len(5) // 5 phone entries total + .len(5) .child_data(vec![number_array.into_data(), kind_array.into_data()]) .build() .unwrap(); @@ -1188,7 +1475,7 @@ mod test { .build() .unwrap(); let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); - let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([ + let expected = RecordBatch::try_from_iter_with_nullable([ ("id", Arc::new(id_array) as _, true), ( "phoneNumbers", @@ -1197,19 +1484,11 @@ mod test { ), ]) .unwrap(); - assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); - let batch_small = read_file(&file, 3, false); - assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + assert_eq!(batch, expected); } #[test] fn test_simple() { - // Each entry: (filename, batch_size1, expected_batch, batch_size2) - let tests = [ - ("avro/simple_enum.avro", 4, build_expected_enum(), 2), - ("avro/simple_fixed.avro", 2, build_expected_fixed(), 1), - ]; - fn build_expected_enum() -> RecordBatch { let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); @@ -1267,29 +1546,38 @@ mod test { ) .unwrap() } - for (file_name, batch_size, expected, alt_batch_size) in tests { + + // We list the two test files + let tests = [ + ("avro/simple_enum.avro", build_expected_enum()), + ("avro/simple_fixed.avro", build_expected_fixed()), + ]; + for (file_name, expected) in tests { let file = arrow_test_data(file_name); - let actual = read_file(&file, batch_size, false); - assert_eq!(actual, expected); - let actual2 = read_file(&file, alt_batch_size, false); - assert_eq!(actual2, expected); + let mut reader = read_file(&file, None); + let actual = reader + .next() + .expect("Should have a batch") + .expect("Error reading batch"); + assert_eq!(actual, expected, "Mismatch for file {file_name}"); } } #[test] fn test_single_nan() { - let file = crate::test_util::arrow_test_data("avro/single_nan.avro"); - let actual = read_file(&file, 1, false); - use arrow_array::Float64Array; + let file = arrow_test_data("avro/single_nan.avro"); + let mut reader = read_file(&file, None); + let batch = reader + .next() + .expect("Should have a batch") + .expect("Error reading single_nan batch"); let schema = Arc::new(Schema::new(vec![Field::new( "mycol", DataType::Float64, true, )])); - let col = Float64Array::from(vec![None as Option]); - let expected = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap(); - assert_eq!(actual, expected); - let actual2 = read_file(&file, 2, false); - assert_eq!(actual2, expected); + let col = arrow_array::Float64Array::from(vec![None]); + let expected = RecordBatch::try_new(schema.clone(), vec![Arc::new(col)]).unwrap(); + assert_eq!(batch, expected, "Mismatch in single_nan.avro data"); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index f497571f2d3e..316b55450b67 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -33,6 +33,7 @@ use std::sync::Arc; const DEFAULT_CAPACITY: usize = 1024; /// A decoder that converts Avro-encoded data into an Arrow 
[`RecordBatch`]. +#[derive(Debug)] pub struct RecordDecoder { schema: SchemaRef, fields: Vec, @@ -138,7 +139,6 @@ enum Decoder { impl Decoder { fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result { - // 1) Create the "base" decoder for the underlying Avro codec let base = match &data_type.codec { Codec::Null => Self::Null(0), Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -541,11 +541,10 @@ impl Decoder { type FlushResult = (Vec>, Option); fn flush_record_children( - mut kids: Vec>, + mut children: Vec>, parent_nulls: Option, ) -> Result { - let max_len = kids.iter().map(|c| c.len()).max().unwrap_or(0); - + let max_len = children.iter().map(|c| c.len()).max().unwrap_or(0); let fixed_parent_nulls = match parent_nulls { None => None, Some(nb) => { @@ -572,9 +571,8 @@ fn flush_record_children( } } }; - - let mut out = Vec::with_capacity(kids.len()); - for arr in kids { + let mut out = Vec::with_capacity(children.len()); + for arr in children { let cur_len = arr.len(); match cur_len.cmp(&max_len) { Ordering::Equal => out.push(arr), @@ -690,7 +688,6 @@ fn flush_values(vec: &mut Vec) -> Vec { fn append_nulls(arr: &Arc, count: usize) -> Result, ArrowError> { use arrow_data::transform::MutableArrayData; - let d = arr.to_data(); let mut mad = MutableArrayData::new(vec![&d], false, 0); mad.extend(0, 0, arr.len()); diff --git a/arrow-avro/test/data/nested_lists.snappy.avro b/arrow-avro/test/data/nested_lists.snappy.avro new file mode 100644 index 0000000000000000000000000000000000000000..6cbff89610a7fce5f817edd668a06f5b5ac76a5b GIT binary patch literal 407 zcmeZI%3@>_ODrqO*DFrWNX<<=#$2sbQdy9yWTjM;nw(#hqNJmgmzWFUm*f}tq?V=T z1i{49GE;L>ij}OQt6@qKfvO?8fnrc&5{rrwD}myfC8@a(#iU9o6_*rc=B0yNQks*a z6kCgr0e4Fh!YxXfc_j$lv9$*IMd^Bp1&Kf(>lGIy7G>*|r4|)u=I3!4>lx}9iGaf+ zIX@*enWs1}v7n%mA>d!>+<&~!>)XFJ@3vpKVl~?g#(WkA7FH%3rbGs&BnAd12Bu^N z1_l-;MlOyN1_nkUBSj#OQIS=OQ-vW_P=$ewQJT|LRE33!fmJ~VNCIIRPy+*#>x<@m GbmIXq^L?5C literal 0 HcmV?d00001 From fffcbed56da707841abd1a072678de206f3a5b73 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 11 Feb 2025 03:54:59 -0600 Subject: [PATCH 33/38] Implemented Avro `Decoder`, `ReaderBuilder`, and `Reader` + minor code cleanup Signed-off-by: Connor Sanders --- arrow-avro/Cargo.toml | 4 +- arrow-avro/src/codec.rs | 182 ++++---- arrow-avro/src/lib.rs | 3 +- arrow-avro/src/reader/block.rs | 57 +-- arrow-avro/src/reader/cursor.rs | 14 - arrow-avro/src/reader/header.rs | 1 - arrow-avro/src/reader/mod.rs | 576 +++++++++++++++++--------- arrow-avro/src/reader/record.rs | 14 +- arrow-avro/src/reader/vlq.rs | 2 +- arrow-avro/src/schema.rs | 19 - arrow-avro/test/data/simple_enum.avro | Bin 0 -> 411 bytes 11 files changed, 490 insertions(+), 382 deletions(-) create mode 100644 arrow-avro/test/data/simple_enum.avro diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index 433b16c3aa89..06f7d5d1cc06 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -55,4 +55,6 @@ crc = { version = "3.0", optional = true } [dev-dependencies] rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } - +bytes = "1.4" +futures = "0.3" +tokio = { version = "1.27", default-features = false, features = ["io-util", "macros", "rt-multi-thread"] } diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 3ed9c315a0cd..f75ee8167e8e 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -16,7 +16,6 @@ // under the License. 
use crate::schema::{ComplexType, PrimitiveType, Schema, TypeName}; -use arrow_array::Array; use arrow_schema::DataType::*; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, @@ -142,7 +141,6 @@ pub enum Codec { String, /// Complex Record(Arc<[AvroField]>), - /// Changed from `Dictionary(Utf8, Int32)` to `Dictionary(Int32, Utf8)` Enum(Arc<[String]>, Arc<[i32]>), Array(Arc), Map(Arc), @@ -476,97 +474,97 @@ fn make_data_type<'a>( } } -pub fn arrow_field_to_avro_field(field: &Field) -> AvroField { - let codec = arrow_type_to_codec(field.data_type()); - let top_null = field.is_nullable().then_some(Nullability::NullFirst); - let data_type = AvroDataType { - nullability: top_null, - metadata: field.metadata().clone(), - codec, - }; - AvroField { - name: field.name().to_string(), - data_type, - default: None, - } -} +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::{Field, IntervalUnit, TimeUnit}; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Arc; -fn arrow_type_to_codec(dt: &DataType) -> Codec { - match dt { - Null => Codec::Null, - Boolean => Codec::Boolean, - Int8 | Int16 | Int32 => Codec::Int32, - Int64 => Codec::Int64, - Float32 => Codec::Float32, - Float64 => Codec::Float64, - Binary | LargeBinary => Codec::Binary, - Utf8 => Codec::String, - Struct(fields) => { - let avro_fields: Vec = fields - .iter() - .map(|fref| arrow_field_to_avro_field(fref.as_ref())) - .collect(); - Codec::Record(Arc::from(avro_fields)) + pub fn arrow_field_to_avro_field(field: &Field) -> AvroField { + let codec = arrow_type_to_codec(field.data_type()); + let top_null = field.is_nullable().then_some(Nullability::NullFirst); + let data_type = AvroDataType { + nullability: top_null, + metadata: field.metadata().clone(), + codec, + }; + AvroField { + name: field.name().to_string(), + data_type, + default: None, } - Dictionary(dict_ty, val_ty) => { - if let Int32 = &**dict_ty { - if let Utf8 = &**val_ty { - return Codec::Enum(Arc::from(Vec::new()), Arc::from(Vec::new())); + } + + fn arrow_type_to_codec(dt: &DataType) -> Codec { + match dt { + Null => Codec::Null, + Boolean => Codec::Boolean, + Int8 | Int16 | Int32 => Codec::Int32, + Int64 => Codec::Int64, + Float32 => Codec::Float32, + Float64 => Codec::Float64, + Binary | LargeBinary => Codec::Binary, + Utf8 => Codec::String, + Struct(fields) => { + let avro_fields: Vec = fields + .iter() + .map(|fref| arrow_field_to_avro_field(fref.as_ref())) + .collect(); + Codec::Record(Arc::from(avro_fields)) + } + Dictionary(dict_ty, val_ty) => { + if let Int32 = &**dict_ty { + if let Utf8 = &**val_ty { + return Codec::Enum(Arc::from(Vec::new()), Arc::from(Vec::new())); + } } + Codec::String } - Codec::String - } - List(item_field) => { - let item_codec = arrow_type_to_codec(item_field.data_type()); - let child_nullability = item_field.is_nullable().then_some(Nullability::NullFirst); - let child_dt = AvroDataType { - codec: item_codec, - nullability: child_nullability, - metadata: item_field.metadata().clone(), - }; - Codec::Array(Arc::new(child_dt)) - } - Map(entries_field, _keys_sorted) => { - if let Struct(struct_fields) = entries_field.data_type() { - let val_field = &struct_fields[1]; - let val_codec = arrow_type_to_codec(val_field.data_type()); - let val_nullability = val_field.is_nullable().then_some(Nullability::NullFirst); - let val_dt = AvroDataType { - codec: val_codec, - nullability: val_nullability, - metadata: val_field.metadata().clone(), + List(item_field) => { + let 
item_codec = arrow_type_to_codec(item_field.data_type()); + let child_nullability = item_field.is_nullable().then_some(Nullability::NullFirst); + let child_dt = AvroDataType { + codec: item_codec, + nullability: child_nullability, + metadata: item_field.metadata().clone(), }; - Codec::Map(Arc::new(val_dt)) - } else { - Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) + Codec::Array(Arc::new(child_dt)) } + Map(entries_field, _keys_sorted) => { + if let Struct(struct_fields) = entries_field.data_type() { + let val_field = &struct_fields[1]; + let val_codec = arrow_type_to_codec(val_field.data_type()); + let val_nullability = val_field.is_nullable().then_some(Nullability::NullFirst); + let val_dt = AvroDataType { + codec: val_codec, + nullability: val_nullability, + metadata: val_field.metadata().clone(), + }; + Codec::Map(Arc::new(val_dt)) + } else { + Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) + } + } + FixedSizeBinary(n) => Codec::Fixed(*n), + Decimal128(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(16)), + Decimal256(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(32)), + Date32 => Codec::Date32, + Time32(TimeUnit::Millisecond) => Codec::TimeMillis, + Time64(TimeUnit::Microsecond) => Codec::TimeMicros, + Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMillis(true) + } + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMicros(true) + } + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), + Interval(IntervalUnit::MonthDayNano) => Codec::Duration, + _ => Codec::String, } - FixedSizeBinary(n) => Codec::Fixed(*n), - Decimal128(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(16)), - Decimal256(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(32)), - Date32 => Codec::Date32, - Time32(TimeUnit::Millisecond) => Codec::TimeMillis, - Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMillis(true) - } - Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), - Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMicros(true) - } - Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), - Interval(IntervalUnit::MonthDayNano) => Codec::Duration, - _ => Codec::String, } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; - use serde_json::json; - use std::collections::HashMap; - use std::sync::Arc; #[test] fn test_skip_avro_default_null_in_metadata() { @@ -646,7 +644,7 @@ mod tests { let codec = Codec::Fixed(12); let dt = codec.data_type(); match dt { - DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + FixedSizeBinary(n) => assert_eq!(n, 12), _ => panic!("Expected FixedSizeBinary(12)"), } } @@ -776,7 +774,7 @@ mod tests { let arrow_field = Field::new( "DictionaryEnum", - Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + Dictionary(Box::new(Int32), Box::new(Utf8)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -784,14 +782,13 @@ mod tests { let arrow_field = Field::new( "DictionaryString", - Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Boolean)), + Dictionary(Box::new(Utf8), Box::new(Boolean)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); 
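        // Only Dictionary(Int32, Utf8) has an Avro enum counterpart; any other
        // dictionary falls back to `Codec::String` inside the `Dictionary` arm
        // of `arrow_type_to_codec` above.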
assert!(matches!(avro_field.data_type().codec, Codec::String)); - // Array with nullable items - let field = Field::new("Utf8", DataType::Utf8, true); + let field = Field::new("Utf8", Utf8, true); let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); let avro_field = arrow_field_to_avro_field(&arrow_field); if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { @@ -802,10 +799,10 @@ mod tests { panic!("Expected Codec::Array"); } - let field = Field::new("Utf8", DataType::Utf8, false); + let field = Field::new("Utf8", Utf8, false); let arrow_field = Field::new( "Array with non-nullable items", - DataType::List(Arc::new(field)), + List(Arc::new(field)), false, ); let avro_field = arrow_field_to_avro_field(&arrow_field); @@ -910,7 +907,6 @@ mod tests { assert_eq!(fields.len(), 1); assert_eq!(fields[0].name(), "f0"); let child_dt = fields[0].data_type(); - // "long" + "null" => NullSecond assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); assert!(matches!(child_dt.codec, Codec::Int64)); } diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index d01d681b7af0..a8487a1e5358 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -21,7 +21,6 @@ //! [Apache Avro]: https://avro.apache.org/ #![warn(missing_docs)] -#![allow(unused)] // Temporary pub mod reader; mod schema; @@ -30,6 +29,8 @@ mod compression; mod codec; +pub use reader::{Decoder, Reader, ReaderBuilder}; + #[cfg(test)] mod test_util { pub fn arrow_test_data(path: &str) -> String { diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs index 21e0b231450b..43722da23938 100644 --- a/arrow-avro/src/reader/block.rs +++ b/arrow-avro/src/reader/block.rs @@ -77,7 +77,6 @@ impl BlockDecoder { /// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf pub fn decode(&mut self, mut buf: &[u8]) -> Result { let max_read = buf.len(); - while !buf.is_empty() { match self.state { BlockDecoderState::Count => { @@ -108,18 +107,15 @@ impl BlockDecoder { buf = &buf[to_read..]; self.bytes_remaining -= to_read; if self.bytes_remaining == 0 { - self.bytes_remaining = 16; // Prepare to read the sync marker + self.bytes_remaining = 16; self.state = BlockDecoderState::Sync; } } BlockDecoderState::Sync => { let to_decode = buf.len().min(self.bytes_remaining); - - // Fill sync bytes from left to right let start = 16 - self.bytes_remaining; let end = start + to_decode; self.in_progress.sync[start..end].copy_from_slice(&buf[..to_decode]); - self.bytes_remaining -= to_decode; buf = &buf[to_decode..]; if self.bytes_remaining == 0 { @@ -131,7 +127,6 @@ impl BlockDecoder { } } } - Ok(max_read) } @@ -139,7 +134,6 @@ impl BlockDecoder { pub fn flush(&mut self) -> Option { match self.state { BlockDecoderState::Finished => { - // Reset to decode the next block self.state = BlockDecoderState::Count; Some(std::mem::take(&mut self.in_progress)) } @@ -184,25 +178,20 @@ mod tests { #[test] fn test_single_block_full_buffer() { let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(10); let size_encoded = encode_vlq(4); let data = vec![1u8, 2, 3, 4]; let sync_marker = vec![0xAB; 16]; - let mut input = Vec::new(); input.extend_from_slice(&count_encoded); input.extend_from_slice(&size_encoded); input.extend_from_slice(&data); input.extend_from_slice(&sync_marker); - let read = decoder.decode(&input).unwrap(); assert_eq!(read, input.len()); - let block = decoder.flush().expect("Should produce a finished block"); assert_eq!(block.count, 10); assert_eq!(block.data, 
data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); assert_eq!(block.sync, expected_sync); } @@ -210,38 +199,30 @@ mod tests { #[test] fn test_single_block_partial_buffer() { let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(2); let size_encoded = encode_vlq(3); let data = vec![10u8, 20, 30]; let sync_marker = vec![0xCD; 16]; - let mut input = Vec::new(); input.extend_from_slice(&count_encoded); input.extend_from_slice(&size_encoded); input.extend_from_slice(&data); input.extend_from_slice(&sync_marker); - // Split into 3 parts let part1 = &input[0..1]; let part2 = &input[1..2]; let part3 = &input[2..]; - let read = decoder.decode(part1).unwrap(); assert_eq!(read, 1); assert!(decoder.flush().is_none()); - let read = decoder.decode(part2).unwrap(); assert_eq!(read, 1); assert!(decoder.flush().is_none()); - let read = decoder.decode(part3).unwrap(); assert_eq!(read, part3.len()); - let block = decoder.flush().expect("Should produce a finished block"); assert_eq!(block.count, 2); assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); assert_eq!(block.sync, expected_sync); } @@ -249,47 +230,36 @@ mod tests { #[test] fn test_multiple_blocks_in_one_buffer() { let mut decoder = BlockDecoder::default(); - // Block1 let block1_count = encode_vlq(1); let block1_size = encode_vlq(2); let block1_data = vec![0x01, 0x02]; let block1_sync = vec![0xAA; 16]; - // Block2 let block2_count = encode_vlq(3); let block2_size = encode_vlq(1); let block2_data = vec![0x99]; let block2_sync = vec![0xBB; 16]; - let mut input = Vec::new(); input.extend_from_slice(&block1_count); input.extend_from_slice(&block1_size); input.extend_from_slice(&block1_data); input.extend_from_slice(&block1_sync); - input.extend_from_slice(&block2_count); input.extend_from_slice(&block2_size); input.extend_from_slice(&block2_data); input.extend_from_slice(&block2_sync); - - // Decode once let read1 = decoder.decode(&input).unwrap(); - let block1 = decoder.flush().expect("First block should be complete"); assert_eq!(block1.count, 1); assert_eq!(block1.data, block1_data); - let expected_sync1: [u8; 16] = <[u8; 16]>::try_from(&block1_sync[..16]).unwrap(); assert_eq!(block1.sync, expected_sync1); - - // Decode remainder for block2 let remainder = &input[read1..]; decoder.decode(remainder).unwrap(); let block2 = decoder.flush().expect("Second block should be complete"); assert_eq!(block2.count, 3); assert_eq!(block2.data, block2_data); - let expected_sync2: [u8; 16] = <[u8; 16]>::try_from(&block2_sync[..16]).unwrap(); assert_eq!(block2.sync, expected_sync2); } @@ -297,14 +267,11 @@ mod tests { #[test] fn test_negative_count_should_error() { let mut decoder = BlockDecoder::default(); - let bad_count = encode_vlq(-1); let size = encode_vlq(5); - let mut input = Vec::new(); input.extend_from_slice(&bad_count); input.extend_from_slice(&size); - let err = decoder.decode(&input).unwrap_err(); match err { ArrowError::ParseError(msg) => { @@ -320,14 +287,11 @@ mod tests { #[test] fn test_negative_size_should_error() { let mut decoder = BlockDecoder::default(); - let count = encode_vlq(5); let bad_size = encode_vlq(-10); - let mut input = Vec::new(); input.extend_from_slice(&count); input.extend_from_slice(&bad_size); - let err = decoder.decode(&input).unwrap_err(); match err { ArrowError::ParseError(msg) => { @@ -343,36 +307,26 @@ mod tests { #[test] fn test_partial_sync_across_multiple_calls() { let mut decoder = 
BlockDecoder::default(); - - // count=1, size=2, data=[0x01,0x02], sync=[0xCC;16] let count_encoded = encode_vlq(1); let size_encoded = encode_vlq(2); let data = vec![0x01, 0x02]; let sync_marker = vec![0xCC; 16]; - let mut input = Vec::new(); input.extend_from_slice(&count_encoded); input.extend_from_slice(&size_encoded); input.extend_from_slice(&data); input.extend_from_slice(&sync_marker); - - // We'll feed all but the last 4 sync bytes first let split_point = input.len() - 4; let part1 = &input[..split_point]; let part2 = &input[split_point..]; - let read1 = decoder.decode(part1).unwrap(); assert_eq!(read1, part1.len()); - // Not finished yet assert!(decoder.flush().is_none()); - let read2 = decoder.decode(part2).unwrap(); assert_eq!(read2, part2.len()); - let block = decoder.flush().expect("Block should be complete now"); assert_eq!(block.count, 1); assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); assert_eq!(block.sync, expected_sync, "Should match [0xCC; 16]"); } @@ -380,31 +334,22 @@ mod tests { #[test] fn test_already_finished_state() { let mut decoder = BlockDecoder::default(); - - // count=2, size=1, data=[0xAB], sync=[0xFF;16] let count_encoded = encode_vlq(2); let size_encoded = encode_vlq(1); let data = vec![0xAB]; let sync_marker = vec![0xFF; 16]; - let mut input = Vec::new(); input.extend_from_slice(&count_encoded); input.extend_from_slice(&size_encoded); input.extend_from_slice(&data); input.extend_from_slice(&sync_marker); - let read = decoder.decode(&input).unwrap(); assert_eq!(read, input.len()); - - // Now we should have a block let block = decoder.flush().expect("Should have a block"); assert_eq!(block.count, 2); assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); assert_eq!(block.sync, expected_sync); - - // Attempt to decode again with empty let read2 = decoder.decode(&[]).unwrap(); assert_eq!(read2, 0); assert!(decoder.flush().is_none()); diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 04aa8049047c..ca98830be070 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. 
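// For reference: the `encode_vlq` helper used by the block-decoder tests above
// is assumed to implement Avro's zig-zag varint encoding of a long. A minimal
// sketch (not the crate's actual helper):
fn encode_vlq(value: i64) -> Vec<u8> {
    // Zig-zag: fold the sign bit into the low bit so small magnitudes stay short.
    let mut n = ((value << 1) ^ (value >> 63)) as u64;
    let mut out = Vec::new();
    // Emit 7 bits per byte, little-endian, MSB set on every byte but the last.
    while n >= 0x80 {
        out.push((n & 0x7F) as u8 | 0x80);
        n >>= 7;
    }
    out.push(n as u8);
    out
}
// e.g. encode_vlq(10) == [0x14] and encode_vlq(-1) == [0x01].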
- use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; @@ -142,19 +141,6 @@ mod tests { use super::*; use arrow_schema::ArrowError; - fn hex_to_bytes(hex: &str) -> Vec { - let mut bytes = vec![]; - let mut chars = hex.chars().collect::>(); - if chars.len() % 2 != 0 { - chars.insert(0, '0'); - } - for chunk in chars.chunks(2) { - let s = format!("{}{}", chunk[0], chunk[1]); - bytes.push(u8::from_str_radix(&s, 16).unwrap()); - } - bytes - } - #[test] fn test_new_and_position() { let data = [1, 2, 3, 4]; diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 9b7d3456589e..ecb53f1f101b 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -20,7 +20,6 @@ use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; use crate::reader::vlq::VLQDecoder; use crate::schema::{Schema, SCHEMA_METADATA_KEY}; use arrow_schema::ArrowError; -use std::io::BufRead; #[derive(Debug)] enum HeaderDecoderState { diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index ad5452d93228..eec5ff8d0d95 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -23,24 +23,124 @@ //! * [`ReaderBuilder`]: Configures Avro reading, e.g., batch size //! * [`Reader`]: Yields [`RecordBatch`] values, implementing [`Iterator`] //! * [`Decoder`]: A low-level push-based decoder for Avro records +//! +//! # Basic Usage +//! +//! [`Reader`] can be used directly with synchronous data sources, such as [`std::fs::File`]. +//! +//! ## Reading a Single Batch +//! +//! ``` +//! # use std::fs::File; +//! # use std::io::BufReader; +//! +//! let file = File::open("test/data/simple_enum.avro").unwrap(); +//! let mut avro = arrow_avro::ReaderBuilder::new().build(BufReader::new(file)).unwrap(); +//! let batch = avro.next().unwrap().unwrap(); +//! ``` +//! +//! # Async Usage +//! +//! The lower-level [`Decoder`] can be integrated with various forms of async data streams, +//! and is designed to be agnostic to different async IO primitives within +//! the Rust ecosystem. It works by incrementally decoding Avro data from byte slices. +//! +//! For example, see below for how it could be used with an arbitrary `Stream` of `Bytes`: +//! +//! ``` +//! # use std::task::{Poll, ready}; +//! # use bytes::{Buf, Bytes}; +//! # use arrow_schema::ArrowError; +//! # use futures::stream::{Stream, StreamExt}; +//! # use arrow_array::RecordBatch; +//! # use arrow_avro::reader::Decoder; +//! # +//! fn decode_stream + Unpin>( +//! mut decoder: Decoder, +//! mut input: S, +//! ) -> impl Stream> { +//! let mut buffered = Bytes::new(); +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! if buffered.is_empty() { +//! buffered = match ready!(input.poll_next_unpin(cx)) { +//! Some(b) => b, +//! None => break, +//! }; +//! } +//! let decoded = match decoder.decode(buffered.as_ref()) { +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! let read = buffered.len(); +//! buffered.advance(decoded); +//! if decoded != read { +//! break +//! } +//! } +//! // Convert any fully-decoded rows to a RecordBatch, if available +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! +//! In a similar vein, it can also be used with tokio-based IO primitives +//! +//! ``` +//! # use std::sync::Arc; +//! # use arrow_schema::{DataType, Field, Schema}; +//! # use std::pin::Pin; +//! # use std::task::{Poll, ready}; +//! # use futures::{Stream, TryStreamExt}; +//! # use tokio::io::AsyncBufRead; +//! 
# use arrow_array::RecordBatch; +//! # use arrow_avro::reader::Decoder; +//! # use arrow_schema::ArrowError; +//! fn decode_stream( +//! mut decoder: Decoder, +//! mut reader: R, +//! ) -> impl Stream> { +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) { +//! Ok(b) if b.is_empty() => break, +//! Ok(b) => b, +//! Err(e) => return Poll::Ready(Some(Err(e.into()))), +//! }; +//! let read = b.len(); +//! let decoded = match decoder.decode(b) { +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! Pin::new(&mut reader).consume(decoded); +//! if decoded != read { +//! break; +//! } +//! } +//! +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! -use crate::reader::block::{Block, BlockDecoder}; -use crate::reader::header::{Header, HeaderDecoder}; -use crate::reader::record::RecordDecoder; use arrow_array::{RecordBatch, RecordBatchReader}; -use arrow_schema::{ArrowError, Schema, SchemaRef}; +use arrow_schema::{ArrowError, SchemaRef}; use std::io::BufRead; -use std::sync::Arc; - -mod header; mod block; - mod cursor; +mod header; mod record; mod vlq; -/// Read a [`Header`] from the provided [`BufRead`] +use crate::codec::AvroField; +use crate::schema::Schema as AvroSchema; +use block::BlockDecoder; +use header::{Header, HeaderDecoder}; +use record::RecordDecoder; + +/// Read the Avro file header (magic, metadata, sync marker) from `reader`. fn read_header(mut reader: R) -> Result { let mut decoder = HeaderDecoder::default(); loop { @@ -55,145 +155,117 @@ fn read_header(mut reader: R) -> Result { break; } } - decoder - .flush() - .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) -} - -/// Return an iterator of [`Block`] from the provided [`BufRead`] -fn read_blocks(mut reader: R) -> impl Iterator> { - let mut decoder = BlockDecoder::default(); - let mut try_next = move || { - loop { - let buf = reader.fill_buf()?; - if buf.is_empty() { - break; - } - let read = buf.len(); - let decoded = decoder.decode(buf)?; - reader.consume(decoded); - if decoded != read { - break; - } - } - Ok(decoder.flush()) - }; - std::iter::from_fn(move || try_next().transpose()) + decoder.flush().ok_or_else(|| { + ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string()) + }) } -/// A low-level interface for decoding Avro-encoded bytes into Arrow [`RecordBatch`] -/// -/// This wraps [`RecordDecoder`] to allow incremental decoding of Avro blocks. -/// It parallels the JSON-based [`Decoder`](arrow_json::reader::Decoder), but -/// uses Avro’s block and sync marker approach. +/// A low-level interface for decoding Avro-encoded bytes into Arrow [`RecordBatch`]. #[derive(Debug)] pub struct Decoder { - /// Internal decoder that processes raw Avro-encoded records. record_decoder: RecordDecoder, - - /// The maximum number of records to read at once when decoding. - /// (This is used by higher-level readers that want to chunk data.) batch_size: usize, + decoded_rows: usize, } impl Decoder { - /// Create a new [`Decoder`], wrapping an existing [`RecordDecoder`] and using - /// the specified `batch_size`. - /// - /// The `record_decoder` typically comes from mapping the Avro file schema - /// into [`AvroField`], then calling [`RecordDecoder::try_new`]. + /// Create a new [`Decoder`], wrapping an existing [`RecordDecoder`]. 
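    // `decoded_rows` above tracks how many rows have been decoded since the
    // last `flush`; it is how `decode` knows to stop at `batch_size` and how
    // `flush` knows to return `Ok(None)` when no new rows are available.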
pub fn new(record_decoder: RecordDecoder, batch_size: usize) -> Self { Self { record_decoder, batch_size, + decoded_rows: 0, } } - /// Decode up to `to_read` Avro records from `data`, returning how many bytes were consumed. - /// - /// You can call this repeatedly with slices of Avro block data. Once you have called `decode` - /// enough times to process a chunk of rows (for example, `batch_size` rows), you may call - /// [`Self::flush`] to convert the accumulated rows to a [`RecordBatch`]. - /// - /// * `data` is Avro-encoded rows (potentially a partial block). - /// * `to_read` is how many rows to decode out of the buffer (not bytes, but Avro record count). - pub fn decode(&mut self, data: &[u8], to_read: usize) -> Result { - self.record_decoder.decode(data, to_read) + /// Return the Arrow schema for the rows decoded by this decoder + pub fn schema(&self) -> SchemaRef { + self.record_decoder.schema().clone() + } + + /// Return the configured maximum number of rows per batch + pub fn batch_size(&self) -> usize { + self.batch_size } - /// Produce a [`RecordBatch`] from all fully decoded rows so far. + /// Feed `data` into the decoder row by row until we either: + /// - consume all bytes in `data`, or + /// - reach `batch_size` decoded rows. /// - /// Returns an error if partial Avro rows remain, or if any type conversions - /// fail. Returns `Ok(RecordBatch)` if at least one row was decoded, or an - /// error if no rows have yet been decoded. - pub fn flush(&mut self) -> Result { - self.record_decoder.flush() + /// Returns the number of bytes consumed. + pub fn decode(&mut self, data: &[u8]) -> Result { + let mut total_consumed = 0usize; + while total_consumed < data.len() && self.decoded_rows < self.batch_size { + let consumed = self.record_decoder.decode(&data[total_consumed..], 1)?; + if consumed == 0 { + break; + } + total_consumed += consumed; + self.decoded_rows += 1; + } + Ok(total_consumed) } - /// Return the configured batch size for this [`Decoder`]. - pub fn batch_size(&self) -> usize { - self.batch_size + /// Produce a [`RecordBatch`] if at least one row is fully decoded, returning + /// `Ok(None)` if no new rows are available. + pub fn flush(&mut self) -> Result, ArrowError> { + if self.decoded_rows == 0 { + Ok(None) + } else { + let batch = self.record_decoder.flush()?; + self.decoded_rows = 0; + Ok(Some(batch)) + } } } /// A builder to create an [`Avro Reader`](Reader) that reads Avro data -/// into Arrow [`RecordBatch`]es. -/// -/// ``` -/// # use std::fs::File; -/// # use std::io::BufReader; -/// # use arrow_avro::reader::{ReaderBuilder}; -/// let file = File::open("test/data/nested_lists.snappy.avro").unwrap(); -/// let buf_reader = BufReader::new(file); -/// -/// let builder = ReaderBuilder::new().with_batch_size(1024); -/// let reader = builder.build(buf_reader).unwrap(); -/// for maybe_batch in reader { -/// let batch = maybe_batch.unwrap(); -/// // process batch -/// } -/// ``` -#[derive(Debug, Default)] +/// into Arrow [`RecordBatch`]. 
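// A minimal sketch of how `decode` and `flush` cooperate (assumptions: `data`
// already holds decompressed row bytes for this schema, and the `drain` helper
// is hypothetical, not part of this crate):
fn drain(decoder: &mut Decoder, data: &[u8]) -> Result<Vec<RecordBatch>, ArrowError> {
    let mut batches = Vec::new();
    let mut offset = 0;
    while offset < data.len() {
        let consumed = decoder.decode(&data[offset..])?; // up to `batch_size` rows
        if consumed == 0 {
            break; // defensive: nothing consumed, avoid spinning
        }
        offset += consumed;
        if let Some(batch) = decoder.flush()? {
            batches.push(batch); // rows decoded since the last flush
        }
    }
    Ok(batches)
}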
+#[derive(Debug)] pub struct ReaderBuilder { batch_size: usize, strict_mode: bool, } +impl Default for ReaderBuilder { + fn default() -> Self { + Self { + batch_size: 1024, + strict_mode: false, + } + } +} + impl ReaderBuilder { /// Creates a new [`ReaderBuilder`] with default settings: - /// * `batch_size` = 1024 - /// * `strict_mode` = false + /// - `batch_size` = 1024 + /// - `strict_mode` = false pub fn new() -> Self { Self::default() } - /// Sets the batch size in rows to read - pub fn with_batch_size(self, batch_size: usize) -> Self { - Self { batch_size, ..self } + /// Sets the row-based batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self } - /// Controls whether certain out of specification schema errors, - /// i.e. Impala's Union type with a null second - /// should produce an error (`strict_mode = true`) or be ignored - /// where possible. - pub fn with_strict_mode(self, strict_mode: bool) -> Self { - Self { - strict_mode, - ..self - } + /// Controls whether certain Avro unions of the form `[T, "null"]` should produce an error. + pub fn with_strict_mode(mut self, strict_mode: bool) -> Self { + self.strict_mode = strict_mode; + self } - /// Create a [`Reader`] with the provided [`BufRead`] + /// Create a [`Reader`] from this builder and a `BufRead` pub fn build(self, mut reader: R) -> Result, ArrowError> { let header = read_header(&mut reader)?; let compression = header.compression()?; - let avro_schema = header + let avro_schema: Option> = header .schema() - .map_err(|e| ArrowError::ExternalError(Box::new(e)))? - .ok_or_else(|| { - ArrowError::ParseError("No Avro schema present in file header".to_string()) - })?; - use crate::codec::AvroField; + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let avro_schema = avro_schema.ok_or_else(|| { + ArrowError::ParseError("No Avro schema present in file header".to_string()) + })?; let root_field = AvroField::try_from(&avro_schema)?; let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; let decoder = Decoder::new(record_decoder, self.batch_size); @@ -202,89 +274,114 @@ impl ReaderBuilder { header, compression, decoder, + block_decoder: BlockDecoder::default(), + block_data: Vec::new(), finished: false, }) } -} -/// An iterator over [`RecordBatch`] that reads from an Avro-encoded -/// data stream (e.g. a file) using the schema stored in the Avro file header. -/// -/// This parallels the design of [`arrow_json::Reader`]. -/// -/// # Example -/// -/// ``` -/// # use std::fs::File; -/// # use std::io::BufReader; -/// # use arrow_avro::reader::{ReaderBuilder}; -/// # use arrow_schema::ArrowError; -/// # fn read_avro(path: &str) -> Result<(), ArrowError> { -/// let file = File::open(path)?; -/// let buf_reader = BufReader::new(file); -/// let mut reader = ReaderBuilder::new() -/// .with_batch_size(500) -/// .build(buf_reader)?; -/// if let Some(batch) = reader.next() { -/// let batch = batch?; -/// println!("Decoded batch: {} rows, {} columns", batch.num_rows(), batch.num_columns()); -/// // process batch -/// } -/// Ok(()) -/// # } -/// ``` + /// Create a [`Decoder`] from this builder and a `BufRead` by + /// reading and parsing the Avro file's header. This will + /// not create a full [`Reader`]. 
+ pub fn build_decoder(self, mut reader: R) -> Result { + let header = read_header(&mut reader)?; + let avro_schema: Option> = header + .schema() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let avro_schema = avro_schema.ok_or_else(|| { + ArrowError::ParseError("No Avro schema present in file header".to_string()) + })?; + let root_field = AvroField::try_from(&avro_schema)?; + let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; + Ok(Decoder::new(record_decoder, self.batch_size)) + } +} + +/// A high-level Avro `Reader` that reads container-file blocks +/// and feeds them into a row-level [`Decoder`]. #[derive(Debug)] pub struct Reader { - /// The underlying buffered reader or stream from which to read Avro data. reader: R, - - /// The Avro file header, including sync marker and schema. header: Header, - - /// An optional compression codec (Snappy, BZip2, etc.) found in the Avro file. compression: Option, - - /// A high-level decoder that wraps the low-level [`RecordDecoder`]. decoder: Decoder, - - /// True if we have already returned the final batch or encountered EOF. + block_decoder: BlockDecoder, + block_data: Vec, finished: bool, } impl Reader { - /// Return the Arrow schema discovered from the Avro file's header. + /// Return the Arrow schema discovered from the Avro file header pub fn schema(&self) -> SchemaRef { - self.decoder.record_decoder.schema().clone() + self.decoder.schema() + } + + /// Return the Avro container-file header + pub fn avro_header(&self) -> &Header { + &self.header } } impl Reader { - /// Reads the next [`RecordBatch`] from the file, returning `Ok(None)` if EOF - /// or if we have already yielded all data. - fn read_next_batch(&mut self) -> Result, ArrowError> { + /// Reads the next [`RecordBatch`] from the Avro file or `Ok(None)` on EOF + fn read(&mut self) -> Result, ArrowError> { if self.finished { return Ok(None); } - for block_result in read_blocks(&mut self.reader) { - let block = block_result?; - let block_data = if let Some(ref c) = self.compression { - c.decompress(&block.data)? - } else { - block.data + loop { + if !self.block_data.is_empty() { + let consumed = self.decoder.decode(&self.block_data)?; + if consumed > 0 { + self.block_data.drain(..consumed); + } + match self.decoder.flush()? { + None => { + if !self.block_data.is_empty() { + break; + } + } + Some(batch) => { + return Ok(Some(batch)); + } + } + } + let maybe_block = { + let buf = self.reader.fill_buf()?; + if buf.is_empty() { + None + } else { + let read_len = buf.len(); + let consumed_len = self.block_decoder.decode(buf)?; + self.reader.consume(consumed_len); + if consumed_len == 0 && read_len != 0 { + return Err(ArrowError::ParseError( + "Could not decode next Avro block from partial data".to_string(), + )); + } + self.block_decoder.flush() + } }; - let mut offset = 0; - let mut remaining = block.count; - while remaining > 0 { - let to_read = std::cmp::min(remaining, self.decoder.batch_size()); - let consumed = self.decoder.decode(&block_data[offset..], to_read)?; - offset += consumed; - remaining -= to_read; + match maybe_block { + Some(block) => { + let block_data = if let Some(ref codec) = self.compression { + codec.decompress(&block.data)? 
+ } else { + block.data + }; + self.block_data = block_data; + } + None => { + self.finished = true; + if !self.block_data.is_empty() { + let consumed = self.decoder.decode(&self.block_data)?; + self.block_data.drain(..consumed); + } + return self.decoder.flush(); + } } } - let batch = self.decoder.flush()?; - self.finished = true; - Ok(Some(batch)) + self.decoder.flush() } } @@ -292,8 +389,8 @@ impl Iterator for Reader { type Item = Result; fn next(&mut self) -> Option { - match self.read_next_batch() { - Ok(Some(b)) => Some(Ok(b)), + match self.read() { + Ok(Some(batch)) => Some(Ok(batch)), Ok(None) => None, Err(e) => Some(Err(e)), } @@ -309,6 +406,7 @@ impl RecordBatchReader for Reader { #[cfg(test)] mod test { use super::*; + use crate::reader::vlq::VLQDecoder; use crate::test_util::arrow_test_data; use arrow_array::builder::{ ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, @@ -323,16 +421,15 @@ mod test { use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field, Fields, Schema}; + use bytes::{Buf, Bytes}; + use futures::{stream, Stream, StreamExt, TryStreamExt}; use std::collections::HashMap; + use std::fs; use std::fs::File; - use std::io::BufReader; + use std::io::{BufReader, Cursor}; use std::sync::Arc; + use std::task::{ready, Poll}; - /// Test helper that opens an Avro file, builds an Avro `Reader` with - /// a fixed batch size, then returns that `Reader`. - /// - /// We ignore `schema` because Avro is self-describing; the file has - /// its own schema. We also do not do a separate “infer” step. fn read_file(path: &str, _schema: Option) -> super::Reader> { let file = File::open(path).unwrap(); let reader = BufReader::new(file); @@ -340,6 +437,129 @@ mod test { builder.build(reader).unwrap() } + fn decode_stream + Unpin>( + mut decoder: Decoder, + mut input: S, + ) -> impl Stream> { + let mut buffered = Bytes::new(); + futures::stream::poll_fn(move |cx| { + loop { + if buffered.is_empty() { + buffered = match ready!(input.poll_next_unpin(cx)) { + Some(b) => b, + None => break, + }; + } + let decoded = match decoder.decode(buffered.as_ref()) { + Ok(decoded) => decoded, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + let read = buffered.len(); + buffered.advance(decoded); + if decoded != read { + break; + } + } + Poll::Ready(decoder.flush().transpose()) + }) + } + + #[test] + fn test_basic_usage_single_batch() { + let file = File::open(arrow_test_data("avro/simple_enum.avro")) + .expect("Failed to open test/data/simple_enum.avro"); + let mut avro = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Avro Reader"); + + let batch = avro + .next() + .expect("No batch found?") + .expect("Error reading batch"); + + assert!(batch.num_rows() > 0, "Expected at least 1 row"); + assert!(batch.num_columns() > 0, "Expected at least 1 column"); + } + + #[test] + fn test_reader_read() -> Result<(), ArrowError> { + let file_path = "test/data/simple_enum.avro"; + let file = File::open(file_path).expect("Failed to open Avro file"); + let mut reader_direct = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Reader"); + let mut direct_batches = Vec::new(); + while let Some(batch) = reader_direct.read()? 
{ + direct_batches.push(batch); + } + let file = File::open(file_path).expect("Failed to open Avro file"); + let reader_iter = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Reader"); + let iter_batches: Result, _> = reader_iter.collect(); + let iter_batches = iter_batches?; + assert_eq!(direct_batches, iter_batches); + Ok(()) + } + + #[tokio::test] + async fn test_async_decoder_with_bytes_stream() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/simple_enum.avro"); + let data = fs::read(&path).expect("Failed to read .avro file"); + let mut cursor = Cursor::new(&data); + let decoder: Decoder = ReaderBuilder::new().build_decoder(&mut cursor)?; + let header_consumed = cursor.position() as usize; + let mut remainder = &data[header_consumed..]; + let mut vlq_dec = VLQDecoder::default(); + let _block_count_i64 = vlq_dec + .long(&mut remainder) + .ok_or_else(|| ArrowError::ParseError("EOF reading block count".to_string()))?; + let block_size_i64 = vlq_dec + .long(&mut remainder) + .ok_or_else(|| ArrowError::ParseError("EOF reading block size".to_string()))?; + let block_size = block_size_i64 as usize; + if remainder.len() < block_size { + return Err(ArrowError::ParseError(format!( + "File truncated: Needed {} bytes for block data, got {}", + block_size, + remainder.len() + ))); + } + let block_data = &remainder[..block_size]; + remainder = &remainder[block_size..]; + if remainder.len() < 16 { + return Err(ArrowError::ParseError( + "Missing sync marker in Avro block".to_string(), + )); + } + let _sync_marker = &remainder[..16]; + let _remainder = &remainder[16..]; + let chunks = block_data + .chunks(16) + .map(Bytes::copy_from_slice) + .collect::>(); + let input_stream = stream::iter(chunks); + let record_batch_stream = decode_stream(decoder, input_stream); + let batches: Vec<_> = record_batch_stream.try_collect().await?; + assert!( + !batches.is_empty(), + "Should decode at least one batch from the block" + ); + let file = File::open(&path).unwrap(); + let mut sync_reader = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Could not build sync_reader"); + let expected_batch = sync_reader + .next() + .expect("No batch in file") + .expect("Sync decode failed"); + assert_eq!( + batches[0], expected_batch, + "Async decode differs from sync decode" + ); + Ok(()) + } + #[test] fn test_alltypes() { let files = [ @@ -1149,24 +1369,6 @@ mod test { Box::new(struct_c_builder) }, { - let i_list = Field::new( - "i", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), - true, - ); - let h_struct = - Field::new("h", DataType::Struct(Fields::from(vec![i_list])), true); - let value_struct = Field::new( - "value", - DataType::Struct(Fields::from(vec![h_struct])), - true, - ); - let key_field = Field::new("key", DataType::Utf8, false); - let entries_field = Field::new( - "entries", - DataType::Struct(Fields::from(vec![key_field, value_struct])), - false, - ); Box::new(MapBuilder::new( Some(MapFieldNames { entry: "entries".to_string(), diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 316b55450b67..d17ed0a39968 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -27,7 +27,6 @@ use arrow_schema::{ Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use std::cmp::Ordering; -use std::io::Read; use std::sync::Arc; const DEFAULT_CAPACITY: usize = 1024; @@ -340,7 +339,6 @@ impl Decoder { } } UnionOrder::NullSecond => { - // In out-of-spec 
files: branch=0 => decode T, branch=1 => null if branch == 0 { nb.append(true); child.decode(buf)?; @@ -541,7 +539,7 @@ impl Decoder { type FlushResult = (Vec>, Option); fn flush_record_children( - mut children: Vec>, + children: Vec>, parent_nulls: Option, ) -> Result { let max_len = children.iter().map(|c| c.len()).max().unwrap_or(0); @@ -1082,7 +1080,7 @@ mod tests { let row2_values = vec![None, Some(1), Some(2), None, Some(3), None]; data.extend_from_slice(&encode_array(&row2_values, encode_int_or_null)); data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_avro_long(0)); // block_count=0 => end immediately + data.extend_from_slice(&encode_avro_long(0)); data.extend_from_slice(&encode_union_branch(1)); record_decoder.decode(&data, 4).unwrap(); let batch = record_decoder.flush().unwrap(); @@ -1639,10 +1637,9 @@ mod tests { #[test] fn test_enum_decoding_with_nulls() { - // Union => [Enum(...), null] let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([]))); - let mut inner_decoder = Decoder::try_new(&enum_dt, true).unwrap(); + let inner_decoder = Decoder::try_new(&enum_dt, true).unwrap(); let mut nullable_decoder = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -1667,7 +1664,6 @@ mod tests { assert!(dict_arr.is_valid(0)); assert!(!dict_arr.is_valid(1)); assert!(dict_arr.is_valid(2)); - let keys = dict_arr.keys(); let dict_values = dict_arr.values().as_string::(); assert_eq!(dict_values.value(0), "RED"); assert_eq!(dict_values.value(1), "GREEN"); @@ -1739,7 +1735,7 @@ mod tests { #[test] fn test_decimal_decoding_bytes_with_nulls() { let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); - let mut inner = Decoder::try_new(&dt, true).unwrap(); + let inner = Decoder::try_new(&dt, true).unwrap(); let mut decoder = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), @@ -1768,7 +1764,7 @@ mod tests { #[test] fn test_decimal_decoding_bytes_with_nulls_fixed_size() { let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); - let mut inner = Decoder::try_new(&dt, true).unwrap(); + let inner = Decoder::try_new(&dt, true).unwrap(); let mut decoder = Decoder::Nullable( UnionOrder::NullSecond, NullBufferBuilder::new(DEFAULT_CAPACITY), diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index b198a0d66f24..818c1f53cc0a 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -84,7 +84,7 @@ fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { #[cold] fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { let mut value = 0; - for (count, byte) in buf.iter().take(10).enumerate() { + for (count, _) in buf.iter().take(10).enumerate() { let byte = buf[count]; value |= u64::from(byte & 0x7F) << (count * 7); if byte <= 0x7F { diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index b05743887792..cd377c2d5078 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
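// Aside on the `read_varint_slow` hunk above: binding the byte through the
// iterator, i.e. `for (count, &byte) in buf.iter().take(10).enumerate()`,
// would avoid the `buf[count]` re-index; the behavior is identical either way.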
-use crate::codec::Nullability; use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; @@ -228,24 +227,6 @@ pub struct Fixed<'a> { pub attributes: Attributes<'a>, } -/// An Avro data type (not an Avro schema) -#[derive(Debug, Clone)] -pub struct AvroDataType { - pub nullability: Option, - pub metadata: HashMap, - pub codec: crate::codec::Codec, -} - -impl AvroDataType { - /// Returns an Arrow [`Field`] with the given name, - /// respecting this type’s `nullability` (instead of forcing `true`). - pub fn field_with_name(&self, name: &str) -> arrow_schema::Field { - let d = self.codec.data_type(); - let is_nullable = self.nullability.is_some(); - arrow_schema::Field::new(name, d, is_nullable).with_metadata(self.metadata.clone()) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/arrow-avro/test/data/simple_enum.avro b/arrow-avro/test/data/simple_enum.avro new file mode 100644 index 0000000000000000000000000000000000000000..dbf0a42baae462801fa883bf5586dea8814b3df2 GIT binary patch literal 411 zcmZ`#F$%&!5RAtev=>W@Eky*i3;w|eh{fe{(ZowGEz5-_6PVksBe65lhUZ*LvPmnn1}0WjS4TzMAHVkM-?% UIc)BYj>8#aF5}#BT*iJ34>m)7+yDRo literal 0 HcmV?d00001 From 3dfe0df21428835abeaf2f27643ed2a819128ef5 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 11 Feb 2025 19:47:57 -0600 Subject: [PATCH 34/38] Optimized the code in Avro codec.rs and record.rs files. Signed-off-by: Connor Sanders --- arrow-avro/src/codec.rs | 177 +++++++++--------- arrow-avro/src/reader/record.rs | 308 +++++++------------------------- 2 files changed, 158 insertions(+), 327 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index f75ee8167e8e..f8607b4a9645 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{ComplexType, PrimitiveType, Schema, TypeName}; +use crate::schema::{Attributes, ComplexType, PrimitiveType, Schema, TypeName}; use arrow_schema::DataType::*; use arrow_schema::{ ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, @@ -43,7 +43,7 @@ pub enum Nullability { #[derive(Debug, Clone)] pub struct AvroDataType { pub nullability: Option, - pub metadata: HashMap, + pub metadata: Arc>, pub codec: Codec, } @@ -57,7 +57,7 @@ impl AvroDataType { AvroDataType { codec, nullability, - metadata, + metadata: Arc::new(metadata), } } @@ -69,7 +69,8 @@ impl AvroDataType { /// Returns an arrow [`Field`] with the given name, applying `nullability` if present. 
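// Note on the `Arc::try_unwrap(self.metadata.clone())` pattern used below: the
// clone made for the call guarantees a second strong reference, so
// `try_unwrap` always fails and the `(*arc).clone()` fallback always runs;
// the whole expression is equivalent to `self.metadata.as_ref().clone()`.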
pub fn field_with_name(&self, name: &str) -> Field { let is_nullable = self.nullability.is_some(); - Field::new(name, self.codec.data_type(), is_nullable).with_metadata(self.metadata.clone()) + let metadata = Arc::try_unwrap(self.metadata.clone()).unwrap_or_else(|arc| (*arc).clone()); + Field::new(name, self.codec.data_type(), is_nullable).with_metadata(metadata) } } @@ -176,14 +177,16 @@ impl Codec { Self::Enum(_, _) => Dictionary(Box::new(Int32), Box::new(Utf8)), Self::Array(child_type) => { let child_dt = child_type.codec.data_type(); - let child_md = child_type.metadata.clone(); + let child_md = Arc::try_unwrap(child_type.metadata.clone()) + .unwrap_or_else(|arc| (*arc).clone()); let child_field = Field::new(Field::LIST_FIELD_DEFAULT_NAME, child_dt, true) .with_metadata(child_md); List(Arc::new(child_field)) } Self::Map(value_type) => { let val_dt = value_type.codec.data_type(); - let val_md = value_type.metadata.clone(); + let val_md = Arc::try_unwrap(value_type.metadata.clone()) + .unwrap_or_else(|arc| (*arc).clone()); let val_field = Field::new("value", val_dt, true).with_metadata(val_md); Map( Arc::new(Field::new( @@ -272,6 +275,32 @@ impl<'a> Resolver<'a> { } } +fn parse_decimal_attributes( + attributes: &Attributes, + fallback_size: Option, + precision_required: bool, +) -> Result<(usize, usize, Option), ArrowError> { + let precision = attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .or(if precision_required { None } else { Some(10) }) + .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))? + as usize; + let scale = attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let size = attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .map(|s| s as usize) + .or(fallback_size); + Ok((precision, scale, size)) +} + /// Parses a [`AvroDataType`] from the provided [`Schema`], plus optional `namespace`. 
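// For reference, `parse_decimal_attributes` reads Avro decimal attributes of
// the form below (examples only, following the Avro specification):
//
//   {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}
//   {"type": "fixed", "name": "f", "size": 16,
//    "logicalType": "decimal", "precision": 6, "scale": 2}
//
// `precision` is mandatory on the `fixed` path (`precision_required = true`)
// and defaults to 10 otherwise; `scale` defaults to 0; `size` falls back to
// the caller-provided value when absent.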
fn make_data_type<'a>( schema: &Schema<'a>, @@ -281,30 +310,35 @@ fn make_data_type<'a>( match schema { Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { nullability: None, - metadata: Default::default(), + metadata: Arc::new(Default::default()), codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), Schema::Union(u) => { - let null_idx = u + let null_count = u .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (u.len() == 2, null_idx) { - (true, Some(0)) => { - let mut dt = make_data_type(&u[1], namespace, resolver)?; - dt.nullability = Some(Nullability::NullFirst); - Ok(dt) - } - (true, Some(1)) => { - let mut dt = make_data_type(&u[0], namespace, resolver)?; - dt.nullability = Some(Nullability::NullSecond); - Ok(dt) - } - _ => Err(ArrowError::NotYetImplemented(format!( + .filter(|x| *x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .count(); + if null_count == 1 && u.len() == 2 { + let null_idx = u + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .unwrap(); + let other_idx = if null_idx == 0 { 1 } else { 0 }; + let mut dt = make_data_type(&u[other_idx], namespace, resolver)?; + dt.nullability = if null_idx == 0 { + Some(Nullability::NullFirst) + } else { + Some(Nullability::NullSecond) + }; + Ok(dt) + } else { + Err(ArrowError::NotYetImplemented(format!( "Union of {u:?} not currently supported" - ))), + ))) } } + Schema::Complex(c) => match c { ComplexType::Record(r) => { let ns = r.namespace.or(namespace); @@ -322,7 +356,7 @@ fn make_data_type<'a>( .collect::, ArrowError>>()?; let rec = AvroDataType { nullability: None, - metadata: r.attributes.field_metadata(), + metadata: Arc::new(r.attributes.field_metadata()), codec: Codec::Record(Arc::from(fields)), }; resolver.register(r.name, ns, rec.clone()); @@ -331,7 +365,7 @@ fn make_data_type<'a>( ComplexType::Enum(e) => { let en = AvroDataType { nullability: None, - metadata: e.attributes.field_metadata(), + metadata: Arc::new(e.attributes.field_metadata()), codec: Codec::Enum( Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), Arc::from(vec![]), @@ -344,7 +378,7 @@ fn make_data_type<'a>( let child = make_data_type(&a.items, namespace, resolver)?; Ok(AvroDataType { nullability: None, - metadata: a.attributes.field_metadata(), + metadata: Arc::new(a.attributes.field_metadata()), codec: Codec::Array(Arc::new(child)), }) } @@ -352,42 +386,26 @@ fn make_data_type<'a>( let val = make_data_type(&m.values, namespace, resolver)?; Ok(AvroDataType { nullability: None, - metadata: m.attributes.field_metadata(), + metadata: Arc::new(m.attributes.field_metadata()), codec: Codec::Map(Arc::new(val)), }) } ComplexType::Fixed(fx) => { let size = fx.size as i32; if let Some("decimal") = fx.attributes.logical_type { - let precision = fx - .attributes - .additional - .get("precision") - .and_then(|v| v.as_u64()) - .ok_or_else(|| { - ArrowError::ParseError("Decimal requires precision".to_string()) - })?; - let scale = fx - .attributes - .additional - .get("scale") - .and_then(|v| v.as_u64()) - .unwrap_or(0); + let (precision, scale, _) = + parse_decimal_attributes(&fx.attributes, Some(size as usize), true)?; let dec = AvroDataType { nullability: None, - metadata: fx.attributes.field_metadata(), - codec: Codec::Decimal( - precision as usize, - Some(scale as usize), - Some(size as usize), - ), + metadata: Arc::new(fx.attributes.field_metadata()), + codec: Codec::Decimal(precision, 
Some(scale), Some(size as usize)), }; resolver.register(fx.name, namespace, dec.clone()); Ok(dec) } else { let fixed_dt = AvroDataType { nullability: None, - metadata: fx.attributes.field_metadata(), + metadata: Arc::new(fx.attributes.field_metadata()), codec: Codec::Fixed(size), }; resolver.register(fx.name, namespace, fixed_dt.clone()); @@ -395,43 +413,20 @@ fn make_data_type<'a>( } } }, + Schema::Type(t) => { let mut dt = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; match (t.attributes.logical_type, &mut dt.codec) { (Some("decimal"), Codec::Fixed(sz)) => { - let prec = t - .attributes - .additional - .get("precision") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - let sc = t - .attributes - .additional - .get("scale") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - *sz = t - .attributes - .additional - .get("size") - .and_then(|v| v.as_u64()) - .unwrap_or(*sz as u64) as i32; + let (prec, sc, size_opt) = + parse_decimal_attributes(&t.attributes, Some(*sz as usize), false)?; + if let Some(sz_actual) = size_opt { + *sz = sz_actual as i32; + } dt.codec = Codec::Decimal(prec, Some(sc), Some(*sz as usize)); } (Some("decimal"), Codec::Binary) => { - let prec = t - .attributes - .additional - .get("precision") - .and_then(|v| v.as_u64()) - .unwrap_or(10) as usize; - let sc = t - .attributes - .additional - .get("scale") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; + let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; dt.codec = Codec::Decimal(prec, Some(sc), None); } (Some("uuid"), Codec::String) => { @@ -462,12 +457,18 @@ fn make_data_type<'a>( dt.codec = Codec::Duration; } (Some(other), _) => { - dt.metadata.insert("logicalType".into(), other.into()); + if !dt.metadata.contains_key("logicalType") { + let mut arc_map = (*dt.metadata).clone(); + arc_map.insert("logicalType".into(), other.into()); + dt.metadata = Arc::new(arc_map); + } } (None, _) => {} } for (k, v) in &t.attributes.additional { - dt.metadata.insert(k.to_string(), v.to_string()); + let mut arc_map = (*dt.metadata).clone(); + arc_map.insert(k.to_string(), v.to_string()); + dt.metadata = Arc::new(arc_map); } Ok(dt) } @@ -487,7 +488,7 @@ mod tests { let top_null = field.is_nullable().then_some(Nullability::NullFirst); let data_type = AvroDataType { nullability: top_null, - metadata: field.metadata().clone(), + metadata: Arc::new(field.metadata().clone()), codec, }; AvroField { @@ -528,7 +529,7 @@ mod tests { let child_dt = AvroDataType { codec: item_codec, nullability: child_nullability, - metadata: item_field.metadata().clone(), + metadata: Arc::new(item_field.metadata().clone()), }; Codec::Array(Arc::new(child_dt)) } @@ -540,7 +541,7 @@ mod tests { let val_dt = AvroDataType { codec: val_codec, nullability: val_nullability, - metadata: val_field.metadata().clone(), + metadata: Arc::new(val_field.metadata().clone()), }; Codec::Map(Arc::new(val_dt)) } else { @@ -913,7 +914,7 @@ mod tests { _ => panic!("Expected a record with a single [long,null] field"), } let mut resolver = Resolver::default(); - let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &top_dt.codec { assert_eq!(fields.len(), 1); assert_eq!(fields[0].name(), "f0"); @@ -954,7 +955,7 @@ mod tests { _ => panic!("Expected a record with a single union array field"), } let mut resolver = Resolver::default(); - let top_dt = make_data_type(&schema, None, &mut 
resolver).unwrap(); + let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &top_dt.codec { assert_eq!(fields.len(), 1); let arr_dt = fields[0].data_type(); @@ -1019,7 +1020,7 @@ mod tests { _ => panic!("Expected a record with a single nested union array field"), } let mut resolver = Resolver::default(); - let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &top_dt.codec { assert_eq!(fields.len(), 1); let outer_dt = fields[0].data_type(); @@ -1070,7 +1071,7 @@ mod tests { _ => panic!("Expected a record with a single union map field"), } let mut resolver = Resolver::default(); - let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &top_dt.codec { assert_eq!(fields.len(), 1); let map_dt = fields[0].data_type(); @@ -1135,7 +1136,7 @@ mod tests { _ => panic!("Expected a record with a single union array-of-map field"), } let mut resolver = Resolver::default(); - let top_dt = make_data_type(&schema, None, &mut resolver).unwrap(); + let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &top_dt.codec { assert_eq!(fields.len(), 1); let outer_dt = fields[0].data_type(); @@ -1196,7 +1197,7 @@ mod tests { _ => panic!("Expected top-level record with a single union-based nested_struct"), } let mut resolver = Resolver::default(); - let dt = make_data_type(&schema, None, &mut resolver).unwrap(); + let dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); if let Codec::Record(fields) = &dt.codec { assert_eq!(fields.len(), 1); assert_eq!(fields[0].name(), "nested_struct"); diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index d17ed0a39968..f901d85b6611 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -21,7 +21,6 @@ use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilde use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; -use arrow_data::ArrayData; use arrow_schema::{ ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, @@ -74,21 +73,16 @@ impl RecordDecoder { Ok(cursor.position()) } - /// Flush into a [`RecordBatch`]. + /// Flush into a [`RecordBatch`], /// - /// - Flush each `Decoder` => `Arc` - /// - Sanitize offsets in each final array => `sanitize_array_offsets(...)` + /// We collect arrays from each `Decoder` and build a new [`RecordBatch`]. 
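    // The explicit length checks below (struct children, list/map null-buffer
    // lengths, map key/value parity) take over from the offset-sanitizing pass
    // that `flush` previously applied to every column: decoders are now
    // expected to produce mutually consistent lengths up front.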
pub fn flush(&mut self) -> Result { let arrays = self .fields .iter_mut() .map(|d| d.flush(None)) .collect::, _>>()?; - let sanitized_cols = arrays - .into_iter() - .map(sanitize_array_offsets) - .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), sanitized_cols) + RecordBatch::try_new(self.schema.clone(), arrays) } } @@ -218,7 +212,7 @@ impl Decoder { Some(Nullability::NullSecond) => { if strict_mode { return Err(ArrowError::ParseError( - "Found Avro union of the form ['T','null'], which is disallowed in strict_mode mode" + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" .to_string(), )); } @@ -428,8 +422,27 @@ impl Decoder { for c in children { child_arrays.push(c.flush(None)?); } - let (fixed, final_nulls) = flush_record_children(child_arrays, nulls)?; - let sarr = StructArray::new(fields.clone(), fixed, final_nulls); + let first_len = match child_arrays.first() { + Some(a) => a.len(), + None => 0, + }; + for (i, arr) in child_arrays.iter().enumerate() { + if arr.len() != first_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Inconsistent struct child length for field #{i}. Expected {first_len}, got {}", + arr.len() + ))); + } + } + if let Some(n) = &nulls { + if n.len() != first_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Struct null buffer length {} != struct fields length {first_len}", + n.len() + ))); + } + } + let sarr = StructArray::new(fields.clone(), child_arrays, nulls); Ok(Arc::new(sarr) as Arc) } Self::Enum(symbols, idxs) => { @@ -451,6 +464,15 @@ impl Decoder { Self::List(item_field, off, child) => { let c = child.flush(None)?; let offsets = flush_offsets(off); + let final_len = offsets.len() - 1; + if let Some(n) = &nulls { + if n.len() != final_len { + return Err(ArrowError::InvalidArgumentError(format!( + "List array null buffer length {} != final list length {final_len}", + n.len() + ))); + } + } let larr = ListArray::new(item_field.clone(), offsets, c, nulls); Ok(Arc::new(larr) as Arc) } @@ -460,17 +482,28 @@ impl Decoder { let kd = flush_values(kdata).into(); let val_arr = valdec.flush(None)?; let key_arr = StringArray::new(koff, kd, None); - let (fixed_keys, fixed_vals) = flush_map_children(&key_arr, &val_arr)?; + if key_arr.len() != val_arr.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Map keys length ({}) != map values length ({})", + key_arr.len(), + val_arr.len() + ))); + } + let final_len = moff.len() - 1; + if let Some(n) = &nulls { + if n.len() != final_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Map array null buffer length {} != final map length {final_len}", + n.len() + ))); + } + } let entries_struct = StructArray::new( Fields::from(vec![ Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new( - "value", - fixed_vals.data_type().clone(), - true, - )), + Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), ]), - vec![Arc::new(fixed_keys), fixed_vals], + vec![Arc::new(key_arr), val_arr], None, ); let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false); @@ -536,113 +569,24 @@ impl Decoder { } } -type FlushResult = (Vec>, Option); - -fn flush_record_children( - children: Vec>, - parent_nulls: Option, -) -> Result { - let max_len = children.iter().map(|c| c.len()).max().unwrap_or(0); - let fixed_parent_nulls = match parent_nulls { - None => None, - Some(nb) => { - let old_len = nb.len(); - match old_len.cmp(&max_len) { - Ordering::Equal => Some(nb), - Ordering::Less => { - 
let mut b = NullBufferBuilder::new(max_len); - for i in 0..old_len { - b.append(nb.is_valid(i)); - } - for _ in old_len..max_len { - b.append(false); - } - b.finish() - } - Ordering::Greater => { - let mut b = NullBufferBuilder::new(max_len); - for i in 0..max_len { - b.append(nb.is_valid(i)); - } - b.finish() - } - } - } - }; - let mut out = Vec::with_capacity(children.len()); - for arr in children { - let cur_len = arr.len(); - match cur_len.cmp(&max_len) { - Ordering::Equal => out.push(arr), - Ordering::Less => { - let to_add = max_len - cur_len; - let appended = append_nulls(&arr, to_add)?; - out.push(appended); - } - Ordering::Greater => { - let sliced = arr.slice(0, max_len); - out.push(sliced); - } - } - } - - Ok((out, fixed_parent_nulls)) -} - -fn flush_map_children( - key_arr: &StringArray, - val_arr: &Arc, -) -> Result<(StringArray, Arc), ArrowError> { - let kl = key_arr.len(); - let vl = val_arr.len(); - match kl.cmp(&vl) { - Ordering::Equal => Ok((key_arr.clone(), val_arr.clone())), - Ordering::Less => { - let truncated = val_arr.slice(0, kl); - Ok((key_arr.clone(), truncated)) - } - Ordering::Greater => { - let to_add = kl - vl; - let appended = append_nulls(val_arr, to_add)?; - Ok((key_arr.clone(), appended)) - } - } -} - -/// Decode an Avro array in blocks until a 0 block_count signals end. fn read_array_blocks( buf: &mut AvroCursor, - mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { - let mut total = 0usize; - loop { - let blk = buf.get_long()?; - match blk.cmp(&0) { - Ordering::Equal => break, - Ordering::Less => { - let cnt = (-blk) as usize; - let _sz = buf.get_long()?; - for _i in 0..cnt { - decode_item(buf)?; - } - total += cnt; - } - Ordering::Greater => { - let cnt = blk as usize; - for _i in 0..cnt { - decode_item(buf)?; - } - total += cnt; - } - } - } - Ok(total) + read_blockwise_items(buf, true, decode_item) } -/// Decode an Avro map in blocks until 0 block_count signals end. 
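// Avro encodes arrays and maps as a sequence of blocks: each block begins
// with a zig-zag long item count; a negative count means |count| items
// follow, with an extra long (the block's size in bytes, useful for skipping)
// written between the count and the items; a count of zero terminates the
// sequence. `read_blockwise_items` handles both signs in a single loop.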
fn read_map_blocks( buf: &mut AvroCursor, - mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + read_blockwise_items(buf, true, decode_entry) +} + +fn read_blockwise_items( + buf: &mut AvroCursor, + read_size_after_negative: bool, + mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, ) -> Result { let mut total = 0usize; loop { @@ -651,16 +595,18 @@ fn read_map_blocks( Ordering::Equal => break, Ordering::Less => { let cnt = (-blk) as usize; - let _sz = buf.get_long()?; - for _i in 0..cnt { - decode_entry(buf)?; + if read_size_after_negative { + let _size_in_bytes = buf.get_long()?; + } + for _ in 0..cnt { + decode_fn(buf)?; } total += cnt; } Ordering::Greater => { let cnt = blk as usize; for _i in 0..cnt { - decode_entry(buf)?; + decode_fn(buf)?; } total += cnt; } @@ -684,122 +630,6 @@ fn flush_values(vec: &mut Vec) -> Vec { std::mem::replace(vec, Vec::with_capacity(DEFAULT_CAPACITY)) } -fn append_nulls(arr: &Arc, count: usize) -> Result, ArrowError> { - use arrow_data::transform::MutableArrayData; - let d = arr.to_data(); - let mut mad = MutableArrayData::new(vec![&d], false, 0); - mad.extend(0, 0, arr.len()); - mad.extend_nulls(count); - let out = mad.freeze(); - let arr2 = make_array(out); - sanitize_array_offsets(arr2) -} - -fn sanitize_offsets_vec(offsets: &[i32], child_len: i32) -> Vec { - let mut new_offsets = Vec::with_capacity(offsets.len()); - let mut prev = 0; - for &offset in offsets { - // clamp each offset between the previous value and the child length - let clamped = offset.clamp(prev, child_len); - new_offsets.push(clamped); - if clamped > prev { - prev = clamped; - } - } - new_offsets -} - -fn sanitize_offsets_array( - original_data: &ArrayData, - child: Arc, - offsets: &[i32], -) -> Result { - let child_san = sanitize_array_offsets(child)?; - let child_len = child_san.len() as i32; - let new_offsets = sanitize_offsets_vec(offsets, child_len); - let final_len = new_offsets.len() - 1; - let mut new_data = original_data.clone(); - let mut bufs = new_data.buffers().to_vec(); - bufs[0] = Buffer::from_slice_ref(&new_offsets); - new_data = new_data - .into_builder() - .len(final_len) - .buffers(bufs) - .child_data(vec![child_san.to_data()]) - .build()?; - Ok(new_data) -} - -fn sanitize_struct_child( - array: Arc, - target_len: usize, -) -> Result { - let sanitized = sanitize_array_offsets(array)?; - let sanitized_len = sanitized.len(); - match sanitized_len.cmp(&target_len) { - Ordering::Equal => Ok(sanitized.to_data()), - Ordering::Less => { - let to_add = target_len - sanitized_len; - let appended = append_nulls(&sanitized, to_add)?; - Ok(appended.to_data()) - } - Ordering::Greater => { - let sliced = sanitized.slice(0, target_len); - Ok(sliced.to_data()) - } - } -} - -/// Recursively sanitizes the offsets for arrays of List, Map, and Struct types. 
-fn sanitize_array_offsets(array: Arc) -> Result, ArrowError> { - match array.data_type() { - DataType::List(_item) => { - let list_arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Downcast to ListArray".into()))?; - let child = Arc::new(list_arr.values().clone()) as Arc; - let new_data = - sanitize_offsets_array(&list_arr.to_data(), child, list_arr.value_offsets())?; - Ok(make_array(new_data)) - } - DataType::Map(_field, _keys_sorted) => { - let map_arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Downcast to MapArray".into()))?; - let child = Arc::new(map_arr.entries().clone()) as Arc; - let new_data = - sanitize_offsets_array(&map_arr.to_data(), child, map_arr.value_offsets())?; - Ok(make_array(new_data)) - } - DataType::Struct(_fs) => { - let struct_arr = array - .as_any() - .downcast_ref::() - .ok_or_else(|| ArrowError::ParseError("Downcast to StructArray".into()))?; - let length = struct_arr.len(); - - let new_child_data = struct_arr - .columns() - .iter() - .map(|col| { - let col_arc = Arc::new(col.clone()) as Arc; - sanitize_struct_child(col_arc, length) - }) - .collect::, _>>()?; - let new_data = struct_arr - .to_data() - .clone() - .into_builder() - .child_data(new_child_data) - .build()?; - Ok(make_array(new_data)) - } - _ => Ok(array), - } -} - /// A builder for Avro decimal, either 128-bit or 256-bit. #[derive(Debug)] enum DecimalBuilder { From 0de2eda5f64b7f1f2f2eabb2ce93e076c8e1dcd5 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Tue, 11 Feb 2025 21:21:14 -0600 Subject: [PATCH 35/38] Small arrow-avro lib.rs update Signed-off-by: Connor Sanders --- arrow-avro/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index a8487a1e5358..1d04129634b1 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -29,7 +29,7 @@ mod compression; mod codec; -pub use reader::{Decoder, Reader, ReaderBuilder}; +pub use self::reader::{Decoder, Reader, ReaderBuilder}; #[cfg(test)] mod test_util { From ad34fb44dc2d6964010759ee63d6a1c06cb83263 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sat, 15 Feb 2025 17:19:47 -0600 Subject: [PATCH 36/38] Reverted all changes not related to enhanced Avro codec and record decoder support + .avro file tests Signed-off-by: Connor Sanders --- arrow-avro/Cargo.toml | 5 +- arrow-avro/src/codec.rs | 742 ----------- arrow-avro/src/lib.rs | 3 +- arrow-avro/src/reader/block.rs | 214 ---- arrow-avro/src/reader/cursor.rs | 197 --- arrow-avro/src/reader/header.rs | 31 - arrow-avro/src/reader/mod.rs | 1092 +++++------------ arrow-avro/src/reader/record.rs | 908 -------------- arrow-avro/src/schema.rs | 177 --- arrow-avro/test/data/nested_lists.snappy.avro | Bin 407 -> 0 bytes arrow-avro/test/data/simple_enum.avro | Bin 411 -> 0 bytes 11 files changed, 303 insertions(+), 3066 deletions(-) delete mode 100644 arrow-avro/test/data/nested_lists.snappy.avro delete mode 100644 arrow-avro/test/data/simple_enum.avro diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index a9c237008140..331efda5680d 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -54,7 +54,4 @@ crc = { version = "3.0", optional = true } [dev-dependencies] -bytes = "1.4" -futures = "0.3" -tokio = { version = "1.27", default-features = false, features = ["io-util", "macros", "rt-multi-thread"] } -rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } \ No newline at end of file +rand = { 
version = "0.8", default-features = false, features = ["std", "std_rng"] } \ No newline at end of file diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index f8607b4a9645..1c8df7d70421 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -474,745 +474,3 @@ fn make_data_type<'a>( } } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow_schema::{Field, IntervalUnit, TimeUnit}; - use serde_json::json; - use std::collections::HashMap; - use std::sync::Arc; - - pub fn arrow_field_to_avro_field(field: &Field) -> AvroField { - let codec = arrow_type_to_codec(field.data_type()); - let top_null = field.is_nullable().then_some(Nullability::NullFirst); - let data_type = AvroDataType { - nullability: top_null, - metadata: Arc::new(field.metadata().clone()), - codec, - }; - AvroField { - name: field.name().to_string(), - data_type, - default: None, - } - } - - fn arrow_type_to_codec(dt: &DataType) -> Codec { - match dt { - Null => Codec::Null, - Boolean => Codec::Boolean, - Int8 | Int16 | Int32 => Codec::Int32, - Int64 => Codec::Int64, - Float32 => Codec::Float32, - Float64 => Codec::Float64, - Binary | LargeBinary => Codec::Binary, - Utf8 => Codec::String, - Struct(fields) => { - let avro_fields: Vec = fields - .iter() - .map(|fref| arrow_field_to_avro_field(fref.as_ref())) - .collect(); - Codec::Record(Arc::from(avro_fields)) - } - Dictionary(dict_ty, val_ty) => { - if let Int32 = &**dict_ty { - if let Utf8 = &**val_ty { - return Codec::Enum(Arc::from(Vec::new()), Arc::from(Vec::new())); - } - } - Codec::String - } - List(item_field) => { - let item_codec = arrow_type_to_codec(item_field.data_type()); - let child_nullability = item_field.is_nullable().then_some(Nullability::NullFirst); - let child_dt = AvroDataType { - codec: item_codec, - nullability: child_nullability, - metadata: Arc::new(item_field.metadata().clone()), - }; - Codec::Array(Arc::new(child_dt)) - } - Map(entries_field, _keys_sorted) => { - if let Struct(struct_fields) = entries_field.data_type() { - let val_field = &struct_fields[1]; - let val_codec = arrow_type_to_codec(val_field.data_type()); - let val_nullability = val_field.is_nullable().then_some(Nullability::NullFirst); - let val_dt = AvroDataType { - codec: val_codec, - nullability: val_nullability, - metadata: Arc::new(val_field.metadata().clone()), - }; - Codec::Map(Arc::new(val_dt)) - } else { - Codec::Map(Arc::new(AvroDataType::from_codec(Codec::String))) - } - } - FixedSizeBinary(n) => Codec::Fixed(*n), - Decimal128(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(16)), - Decimal256(p, s) => Codec::Decimal(*p as usize, Some(*s as usize), Some(32)), - Date32 => Codec::Date32, - Time32(TimeUnit::Millisecond) => Codec::TimeMillis, - Time64(TimeUnit::Microsecond) => Codec::TimeMicros, - Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMillis(true) - } - Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), - Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { - Codec::TimestampMicros(true) - } - Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), - Interval(IntervalUnit::MonthDayNano) => Codec::Duration, - _ => Codec::String, - } - } - - #[test] - fn test_skip_avro_default_null_in_metadata() { - let dt = AvroDataType::from_codec(Codec::Int32); - let field = AvroField { - name: "test_col".into(), - data_type: dt, - default: Some(json!(null)), - }; - let arrow_field = field.field(); - 
assert!(arrow_field.metadata().get("avro.default").is_none()); - } - - #[test] - fn test_store_avro_default_nonnull_in_metadata() { - let dt = AvroDataType::from_codec(Codec::Int32); - let field = AvroField { - name: "test_col".into(), - data_type: dt, - default: Some(json!(42)), - }; - let arrow_field = field.field(); - let md = arrow_field.metadata(); - let got = md.get("avro.default").cloned(); - assert_eq!(got, Some("42".to_string())); - } - - #[test] - fn test_no_default_metadata_if_none() { - let dt = AvroDataType::from_codec(Codec::String); - let field = AvroField { - name: "col".to_string(), - data_type: dt, - default: None, - }; - let arrow_field = field.field(); - assert!(arrow_field.metadata().get("avro.default").is_none()); - } - - #[test] - fn test_avro_field() { - let field_codec = AvroDataType::from_codec(Codec::Int64); - let avro_field = AvroField { - name: "long_col".to_string(), - data_type: field_codec.clone(), - default: None, - }; - assert_eq!(avro_field.name(), "long_col"); - let actual_str = format!("{:?}", avro_field.data_type().codec); - let expected_str = format!("{:?}", &Codec::Int64); - assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); - let arrow_field = avro_field.field(); - assert_eq!(arrow_field.name(), "long_col"); - assert_eq!(arrow_field.data_type(), &Int64); - assert!(!arrow_field.is_nullable()); - } - - #[test] - fn test_avro_field_with_default() { - let field_codec = AvroDataType::from_codec(Codec::Int32); - let default_value = json!(123); - let avro_field = AvroField { - name: "int_col".to_string(), - data_type: field_codec.clone(), - default: Some(default_value.clone()), - }; - let arrow_field = avro_field.field(); - let metadata = arrow_field.metadata(); - assert_eq!( - metadata.get("avro.default").unwrap(), - &default_value.to_string() - ); - } - - #[test] - fn test_codec_fixedsizebinary() { - let codec = Codec::Fixed(12); - let dt = codec.data_type(); - match dt { - FixedSizeBinary(n) => assert_eq!(n, 12), - _ => panic!("Expected FixedSizeBinary(12)"), - } - } - - #[test] - fn test_arrow_field_to_avro_field() { - let arrow_field = Field::new("Null", Null, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Null)); - - let arrow_field = Field::new("Boolean", Boolean, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Boolean)); - - let arrow_field = Field::new("Int32", Int32, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Int32)); - - let arrow_field = Field::new("Int64", Int64, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Int64)); - - let arrow_field = Field::new("Float32", Float32, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Float32)); - - let arrow_field = Field::new("Float64", Float64, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Float64)); - - let arrow_field = Field::new("Binary", Binary, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Binary)); - - let arrow_field = Field::new("Utf8", Utf8, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, 
Codec::String)); - - let arrow_field = Field::new("Decimal128", Decimal128(1, 2), true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::Decimal(1, Some(2), Some(16)) - )); - - let arrow_field = Field::new("Decimal256", Decimal256(1, 2), true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::Decimal(1, Some(2), Some(32)) - )); - - let arrow_field = Field::new("Date32", Date32, true); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Date32)); - - let arrow_field = Field::new("Time32", Time32(TimeUnit::Millisecond), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::TimeMillis)); - - let arrow_field = Field::new("Time32", Time64(TimeUnit::Microsecond), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::TimeMicros)); - - let arrow_field = Field::new( - "utc_ts_ms", - Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::TimestampMillis(true) - )); - - let arrow_field = Field::new( - "utc_ts_us", - Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::TimestampMicros(true) - )); - - let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::TimestampMillis(false) - )); - - let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!( - avro_field.data_type().codec, - Codec::TimestampMicros(false) - )); - - let arrow_field = Field::new("Interval", Interval(IntervalUnit::MonthDayNano), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Duration)); - - let arrow_field = Field::new( - "Struct", - Struct( - vec![ - Field::new("a", Boolean, false), - Field::new("b", Float64, false), - ] - .into(), - ), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 2); - assert_eq!(fields[0].name(), "a"); - assert!(matches!(fields[0].data_type().codec, Codec::Boolean)); - assert_eq!(fields[1].name(), "b"); - assert!(matches!(fields[1].data_type().codec, Codec::Float64)); - } - _ => panic!("Expected Record data type"), - } - - let arrow_field = Field::new( - "DictionaryEnum", - Dictionary(Box::new(Int32), Box::new(Utf8)), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::Enum(_, _))); - - let arrow_field = Field::new( - "DictionaryString", - Dictionary(Box::new(Utf8), Box::new(Boolean)), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert!(matches!(avro_field.data_type().codec, Codec::String)); - - let field = Field::new("Utf8", Utf8, true); - let arrow_field = Field::new("Array with nullable items", List(Arc::new(field)), true); - 
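- // Nullable list items surface as `Codec::Array` whose item type carries
- // `Nullability::NullFirst`, i.e. a writer would emit the item schema as
- // the union `["null", "string"]`.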
let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { - assert_eq!(avro_data_type.nullability, Some(Nullability::NullFirst)); - assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec, Codec::String)); - } else { - panic!("Expected Codec::Array"); - } - - let field = Field::new("Utf8", Utf8, false); - let arrow_field = Field::new( - "Array with non-nullable items", - List(Arc::new(field)), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Array(avro_data_type) = &avro_field.data_type().codec { - assert!(avro_data_type.nullability.is_none()); - assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec, Codec::String)); - } else { - panic!("Expected Codec::Array"); - } - - let entries_field = Field::new( - "entries", - Struct( - vec![ - Field::new("key", Utf8, false), - Field::new("value", Utf8, true), - ] - .into(), - ), - false, - ); - let arrow_field = Field::new( - "Map with nullable items", - Map(Arc::new(entries_field), true), - true, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Map(avro_data_type) = &avro_field.data_type().codec { - assert_eq!(avro_data_type.nullability, Some(Nullability::NullFirst)); - assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec, Codec::String)); - } else { - panic!("Expected Codec::Map"); - } - - let arrow_field = Field::new( - "Utf8", - Struct( - vec![ - Field::new("key", Utf8, false), - Field::new("value", Utf8, false), - ] - .into(), - ), - false, - ); - let arrow_field = Field::new( - "Map with non-nullable items", - Map(Arc::new(arrow_field), false), - false, - ); - let avro_field = arrow_field_to_avro_field(&arrow_field); - if let Codec::Map(avro_data_type) = &avro_field.data_type().codec { - assert!(avro_data_type.nullability.is_none()); - assert_eq!(avro_data_type.metadata.len(), 0); - assert!(matches!(avro_data_type.codec, Codec::String)); - } else { - panic!("Expected Codec::Map"); - } - let arrow_field = Field::new("FixedSizeBinary", FixedSizeBinary(8), false); - let avro_field = arrow_field_to_avro_field(&arrow_field); - let codec = &avro_field.data_type().codec; - assert!(matches!(codec, Codec::Fixed(8))); - } - - #[test] - fn test_arrow_field_to_avro_field_meta_namespace() { - let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( - "namespace".to_string(), - "arrow_meta_ns".to_string(), - )])); - let avro_field = arrow_field_to_avro_field(&arrow_field); - assert_eq!(avro_field.name(), "test_meta"); - let actual_str = format!("{:?}", avro_field.data_type().codec); - let expected_str = format!("{:?}", &Codec::String); - assert_eq!(actual_str, expected_str); - let actual_str = format!("{:?}", avro_field.data_type().nullability); - let expected_str = format!("{:?}", Some(Nullability::NullFirst)); - assert_eq!(actual_str, expected_str); - assert_eq!( - avro_field.data_type().metadata.get("namespace"), - Some(&"arrow_meta_ns".to_string()) - ); - } - - #[test] - fn test_union_long_null() { - let json_schema = r#" - { - "type": "record", - "name": "test_long_null", - "fields": [ - {"name": "f0", "type": ["long", "null"]} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - 
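- // `["long", "null"]` lists the non-null branch first, hence
- // `Nullability::NullSecond` below; a `["null", "long"]` union would
- // resolve to `Nullability::NullFirst` instead.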
assert_eq!(fields[0].name(), "f0"); - let child_dt = fields[0].data_type(); - assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(child_dt.codec, Codec::Int64)); - } - _ => panic!("Expected a record with a single [long,null] field"), - } - let mut resolver = Resolver::default(); - let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let Codec::Record(fields) = &top_dt.codec { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "f0"); - let child_dt = fields[0].data_type(); - assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(child_dt.codec, Codec::Int64)); - } else { - panic!("Expected a record with a single [long,null] field (make_data_type)"); - } - } - - #[test] - fn test_union_array_of_int_null() { - let json_schema = r#" - { - "type":"record", - "name":"test_array_int_null", - "fields":[ - {"name":"arr","type":[{"type":"array","items":["int","null"]},"null"]} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "arr"); - let child_dt = fields[0].data_type(); - assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(item_type) = &child_dt.codec { - assert_eq!(item_type.nullability, Some(Nullability::NullSecond)); - assert!(matches!(item_type.codec, Codec::Int32)); - } else { - panic!("Expected Codec::Array for 'arr' field"); - } - } - _ => panic!("Expected a record with a single union array field"), - } - let mut resolver = Resolver::default(); - let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let Codec::Record(fields) = &top_dt.codec { - assert_eq!(fields.len(), 1); - let arr_dt = fields[0].data_type(); - assert_eq!(arr_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(item_type) = &arr_dt.codec { - assert_eq!(item_type.nullability, Some(Nullability::NullSecond)); - assert!(matches!(item_type.codec, Codec::Int32)); - } else { - panic!("Expected Codec::Array (make_data_type)"); - } - } else { - panic!("Expected record (make_data_type)"); - } - } - - #[test] - fn test_union_nested_array_of_int_null() { - let json_schema = r#" - { - "type":"record", - "name":"test_nested_array_int_null", - "fields":[ - { - "name":"nested_arr", - "type":[ - { - "type":"array", - "items":[ - { - "type":"array", - "items":["int","null"] - }, - "null" - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "nested_arr"); - let outer_dt = fields[0].data_type(); - assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(mid_dt) = &outer_dt.codec { - assert_eq!(mid_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(inner_dt) = &mid_dt.codec { - assert_eq!(inner_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(inner_dt.codec, Codec::Int32)); - } else { - panic!("Expected inner Codec::Array for nested_arr"); - } - } else { - panic!("Expected outer Codec::Array for nested_arr"); - } - } - _ => panic!("Expected a record with a single nested union array field"), - } - let mut resolver = Resolver::default(); - let top_dt 
= super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let Codec::Record(fields) = &top_dt.codec { - assert_eq!(fields.len(), 1); - let outer_dt = fields[0].data_type(); - assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(mid_dt) = &outer_dt.codec { - assert_eq!(mid_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(inner_dt) = &mid_dt.codec { - assert_eq!(inner_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(inner_dt.codec, Codec::Int32)); - } else { - panic!("Expected inner array (make_data_type)"); - } - } else { - panic!("Expected outer array (make_data_type)"); - } - } else { - panic!("Expected record (make_data_type)"); - } - } - - #[test] - fn test_union_map_of_int_null() { - let json_schema = r#" - { - "type":"record", - "name":"test_map_int_null", - "fields":[ - {"name":"map_field","type":[{"type":"map","values":["int","null"]},"null"]} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "map_field"); - let map_dt = fields[0].data_type(); - assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Map(value_type) = &map_dt.codec { - assert_eq!(value_type.nullability, Some(Nullability::NullSecond)); - assert!(matches!(value_type.codec, Codec::Int32)); - } else { - panic!("Expected Codec::Map for map_field"); - } - } - _ => panic!("Expected a record with a single union map field"), - } - let mut resolver = Resolver::default(); - let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let Codec::Record(fields) = &top_dt.codec { - assert_eq!(fields.len(), 1); - let map_dt = fields[0].data_type(); - assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Map(val_dt) = &map_dt.codec { - assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(val_dt.codec, Codec::Int32)); - } else { - panic!("Expected map in make_data_type"); - } - } else { - panic!("Expected record in make_data_type"); - } - } - - #[test] - fn test_union_map_array_of_int_null() { - let json_schema = r#" - { - "type":"record", - "name":"test_map_array_int_null", - "fields":[ - { - "name":"map_arr", - "type":[ - { - "type":"array", - "items":[ - { - "type":"map", - "values":["int","null"] - }, - "null" - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "map_arr"); - let outer_dt = fields[0].data_type(); - assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(map_dt) = &outer_dt.codec { - assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Map(val_dt) = &map_dt.codec { - assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(val_dt.codec, Codec::Int32)); - } else { - panic!("Expected Codec::Map for map_arr items"); - } - } else { - panic!("Expected Codec::Array for map_arr"); - } - } - _ => panic!("Expected a record with a single union array-of-map field"), - } - let mut resolver = Resolver::default(); - let top_dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let 
Codec::Record(fields) = &top_dt.codec { - assert_eq!(fields.len(), 1); - let outer_dt = fields[0].data_type(); - assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Array(map_dt) = &outer_dt.codec { - assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Map(val_dt) = &map_dt.codec { - assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(val_dt.codec, Codec::Int32)); - } else { - panic!("Expected Codec::Map in make_data_type"); - } - } else { - panic!("Expected Codec::Array in make_data_type"); - } - } else { - panic!("Expected record in make_data_type"); - } - } - - #[test] - fn test_union_nested_struct_out_of_spec() { - let json_schema = r#" - { - "type":"record","name":"topLevelRecord","fields":[ - {"name":"nested_struct","type":[ - { - "type":"record", - "name":"nested_struct", - "namespace":"topLevelRecord", - "fields":[ - {"name":"A","type":["int","null"]}, - {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]} - ] - }, - "null" - ]} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); - match &avro_field.data_type().codec { - Codec::Record(fields) => { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "nested_struct"); - let ns_dt = fields[0].data_type(); - assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Record(nested_fields) = &ns_dt.codec { - assert_eq!(nested_fields.len(), 2); - let field_a_dt = nested_fields[0].data_type(); - assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(field_a_dt.codec, Codec::Int32)); - } else { - panic!("Expected nested_struct to be a Record"); - } - } - _ => panic!("Expected top-level record with a single union-based nested_struct"), - } - let mut resolver = Resolver::default(); - let dt = super::make_data_type(&schema, None, &mut resolver).unwrap(); - if let Codec::Record(fields) = &dt.codec { - assert_eq!(fields.len(), 1); - assert_eq!(fields[0].name(), "nested_struct"); - let ns_dt = fields[0].data_type(); - assert_eq!(ns_dt.nullability, Some(Nullability::NullSecond)); - if let Codec::Record(nested_fields) = &ns_dt.codec { - assert_eq!(nested_fields.len(), 2); - let field_a_dt = nested_fields[0].data_type(); - assert_eq!(field_a_dt.nullability, Some(Nullability::NullSecond)); - assert!(matches!(field_a_dt.codec, Codec::Int32)); - } else { - panic!("Expected nested_struct to be a Record (make_data_type)"); - } - } else { - panic!("Expected top-level record (make_data_type)"); - } - } -} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index 1d04129634b1..d01d681b7af0 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -21,6 +21,7 @@ //! 
[Apache Avro]: https://avro.apache.org/ #![warn(missing_docs)] +#![allow(unused)] // Temporary pub mod reader; mod schema; @@ -29,8 +30,6 @@ mod compression; mod codec; -pub use self::reader::{Decoder, Reader, ReaderBuilder}; - #[cfg(test)] mod test_util { pub fn arrow_test_data(path: &str) -> String { diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs index 43722da23938..e022031164e7 100644 --- a/arrow-avro/src/reader/block.rs +++ b/arrow-avro/src/reader/block.rs @@ -141,217 +141,3 @@ impl BlockDecoder { } } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow_schema::ArrowError; - use std::convert::TryFrom; - - fn encode_vlq(value: i64) -> Vec { - let mut buf = vec![]; - let mut ux = ((value << 1) ^ (value >> 63)) as u64; // ZigZag - - loop { - let mut byte = (ux & 0x7F) as u8; - ux >>= 7; - if ux != 0 { - byte |= 0x80; - } - buf.push(byte); - if ux == 0 { - break; - } - } - buf - } - - #[test] - fn test_empty_input() { - let mut decoder = BlockDecoder::default(); - let buf = []; - let read = decoder.decode(&buf).unwrap(); - assert_eq!(read, 0); - assert!(decoder.flush().is_none()); - } - - #[test] - fn test_single_block_full_buffer() { - let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(10); - let size_encoded = encode_vlq(4); - let data = vec![1u8, 2, 3, 4]; - let sync_marker = vec![0xAB; 16]; - let mut input = Vec::new(); - input.extend_from_slice(&count_encoded); - input.extend_from_slice(&size_encoded); - input.extend_from_slice(&data); - input.extend_from_slice(&sync_marker); - let read = decoder.decode(&input).unwrap(); - assert_eq!(read, input.len()); - let block = decoder.flush().expect("Should produce a finished block"); - assert_eq!(block.count, 10); - assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); - assert_eq!(block.sync, expected_sync); - } - - #[test] - fn test_single_block_partial_buffer() { - let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(2); - let size_encoded = encode_vlq(3); - let data = vec![10u8, 20, 30]; - let sync_marker = vec![0xCD; 16]; - let mut input = Vec::new(); - input.extend_from_slice(&count_encoded); - input.extend_from_slice(&size_encoded); - input.extend_from_slice(&data); - input.extend_from_slice(&sync_marker); - // Split into 3 parts - let part1 = &input[0..1]; - let part2 = &input[1..2]; - let part3 = &input[2..]; - let read = decoder.decode(part1).unwrap(); - assert_eq!(read, 1); - assert!(decoder.flush().is_none()); - let read = decoder.decode(part2).unwrap(); - assert_eq!(read, 1); - assert!(decoder.flush().is_none()); - let read = decoder.decode(part3).unwrap(); - assert_eq!(read, part3.len()); - let block = decoder.flush().expect("Should produce a finished block"); - assert_eq!(block.count, 2); - assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); - assert_eq!(block.sync, expected_sync); - } - - #[test] - fn test_multiple_blocks_in_one_buffer() { - let mut decoder = BlockDecoder::default(); - // Block1 - let block1_count = encode_vlq(1); - let block1_size = encode_vlq(2); - let block1_data = vec![0x01, 0x02]; - let block1_sync = vec![0xAA; 16]; - // Block2 - let block2_count = encode_vlq(3); - let block2_size = encode_vlq(1); - let block2_data = vec![0x99]; - let block2_sync = vec![0xBB; 16]; - let mut input = Vec::new(); - input.extend_from_slice(&block1_count); - input.extend_from_slice(&block1_size); - 
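- // Container-file block layout exercised here:
- //   <zig-zag varint row count> <zig-zag varint byte size> <payload> <16-byte sync>
- // so block1 is 0x02 (count = 1), 0x04 (size = 2), [0x01, 0x02], then 16 x 0xAA.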
input.extend_from_slice(&block1_data); - input.extend_from_slice(&block1_sync); - input.extend_from_slice(&block2_count); - input.extend_from_slice(&block2_size); - input.extend_from_slice(&block2_data); - input.extend_from_slice(&block2_sync); - let read1 = decoder.decode(&input).unwrap(); - let block1 = decoder.flush().expect("First block should be complete"); - assert_eq!(block1.count, 1); - assert_eq!(block1.data, block1_data); - let expected_sync1: [u8; 16] = <[u8; 16]>::try_from(&block1_sync[..16]).unwrap(); - assert_eq!(block1.sync, expected_sync1); - let remainder = &input[read1..]; - decoder.decode(remainder).unwrap(); - let block2 = decoder.flush().expect("Second block should be complete"); - assert_eq!(block2.count, 3); - assert_eq!(block2.data, block2_data); - let expected_sync2: [u8; 16] = <[u8; 16]>::try_from(&block2_sync[..16]).unwrap(); - assert_eq!(block2.sync, expected_sync2); - } - - #[test] - fn test_negative_count_should_error() { - let mut decoder = BlockDecoder::default(); - let bad_count = encode_vlq(-1); - let size = encode_vlq(5); - let mut input = Vec::new(); - input.extend_from_slice(&bad_count); - input.extend_from_slice(&size); - let err = decoder.decode(&input).unwrap_err(); - match err { - ArrowError::ParseError(msg) => { - assert!( - msg.contains("Block count cannot be negative"), - "Expected negative count parse error, got: {msg}" - ); - } - _ => panic!("Unexpected error type: {err:?}"), - } - } - - #[test] - fn test_negative_size_should_error() { - let mut decoder = BlockDecoder::default(); - let count = encode_vlq(5); - let bad_size = encode_vlq(-10); - let mut input = Vec::new(); - input.extend_from_slice(&count); - input.extend_from_slice(&bad_size); - let err = decoder.decode(&input).unwrap_err(); - match err { - ArrowError::ParseError(msg) => { - assert!( - msg.contains("Block size cannot be negative"), - "Expected negative size parse error, got: {msg}" - ); - } - _ => panic!("Unexpected error type: {err:?}"), - } - } - - #[test] - fn test_partial_sync_across_multiple_calls() { - let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(1); - let size_encoded = encode_vlq(2); - let data = vec![0x01, 0x02]; - let sync_marker = vec![0xCC; 16]; - let mut input = Vec::new(); - input.extend_from_slice(&count_encoded); - input.extend_from_slice(&size_encoded); - input.extend_from_slice(&data); - input.extend_from_slice(&sync_marker); - let split_point = input.len() - 4; - let part1 = &input[..split_point]; - let part2 = &input[split_point..]; - let read1 = decoder.decode(part1).unwrap(); - assert_eq!(read1, part1.len()); - assert!(decoder.flush().is_none()); - let read2 = decoder.decode(part2).unwrap(); - assert_eq!(read2, part2.len()); - let block = decoder.flush().expect("Block should be complete now"); - assert_eq!(block.count, 1); - assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); - assert_eq!(block.sync, expected_sync, "Should match [0xCC; 16]"); - } - - #[test] - fn test_already_finished_state() { - let mut decoder = BlockDecoder::default(); - let count_encoded = encode_vlq(2); - let size_encoded = encode_vlq(1); - let data = vec![0xAB]; - let sync_marker = vec![0xFF; 16]; - let mut input = Vec::new(); - input.extend_from_slice(&count_encoded); - input.extend_from_slice(&size_encoded); - input.extend_from_slice(&data); - input.extend_from_slice(&sync_marker); - let read = decoder.decode(&input).unwrap(); - assert_eq!(read, input.len()); - let block = 
decoder.flush().expect("Should have a block"); - assert_eq!(block.count, 2); - assert_eq!(block.data, data); - let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); - assert_eq!(block.sync, expected_sync); - let read2 = decoder.decode(&[]).unwrap(); - assert_eq!(read2, 0); - assert!(decoder.flush().is_none()); - } -} diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index ca98830be070..9e38a78c63ec 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -135,200 +135,3 @@ impl<'a> AvroCursor<'a> { Ok(ret) } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow_schema::ArrowError; - - #[test] - fn test_new_and_position() { - let data = [1, 2, 3, 4]; - let cursor = AvroCursor::new(&data); - assert_eq!(cursor.position(), 0); - } - - #[test] - fn test_get_u8_ok() { - let data = [0x12, 0x34, 0x56]; - let mut cursor = AvroCursor::new(&data); - assert_eq!(cursor.get_u8().unwrap(), 0x12); - assert_eq!(cursor.position(), 1); - assert_eq!(cursor.get_u8().unwrap(), 0x34); - assert_eq!(cursor.position(), 2); - assert_eq!(cursor.get_u8().unwrap(), 0x56); - assert_eq!(cursor.position(), 3); - } - - #[test] - fn test_get_u8_eof() { - let data = []; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_u8(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) - ); - } - - #[test] - fn test_get_bool_ok() { - let data = [0x00, 0x01, 0xFF]; - let mut cursor = AvroCursor::new(&data); - assert!(!cursor.get_bool().unwrap()); // 0x00 -> false - assert!(cursor.get_bool().unwrap()); // 0x01 -> true - assert!(cursor.get_bool().unwrap()); // 0xFF -> true (non-zero) - } - - #[test] - fn test_get_bool_eof() { - let data = []; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_bool(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) - ); - } - - #[test] - fn test_read_vlq_ok() { - let data = [0x80, 0x01, 0x05]; - let mut cursor = AvroCursor::new(&data); - let val1 = cursor.read_vlq().unwrap(); - assert_eq!(val1, 128); - let val2 = cursor.read_vlq().unwrap(); - assert_eq!(val2, 5); - } - - #[test] - fn test_read_vlq_bad_varint() { - let data = [0x80]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.read_vlq(); - assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); - } - - #[test] - fn test_get_int_ok() { - let data = [0x04, 0x03]; // encodes +2, -2 - let mut cursor = AvroCursor::new(&data); - assert_eq!(cursor.get_int().unwrap(), 2); - assert_eq!(cursor.get_int().unwrap(), -2); - } - - #[test] - fn test_get_int_overflow() { - let data = [0x80, 0x80, 0x80, 0x80, 0x10]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_int(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("varint overflow")) - ); - } - - #[test] - fn test_get_long_ok() { - let data = [0x04, 0x03, 0xAC, 0x02]; - let mut cursor = AvroCursor::new(&data); - assert_eq!(cursor.get_long().unwrap(), 2); - assert_eq!(cursor.get_long().unwrap(), -2); - assert_eq!(cursor.get_long().unwrap(), 150); - } - - #[test] - fn test_get_long_eof() { - let data = [0x80]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_long(); - assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); - } - - #[test] - fn test_get_bytes_ok() { - let data = [0x06, 0xAA, 0xBB, 0xCC, 0x05, 0x01]; - let mut cursor = 
AvroCursor::new(&data); - let bytes = cursor.get_bytes().unwrap(); - assert_eq!(bytes, [0xAA, 0xBB, 0xCC]); - assert_eq!(cursor.position(), 4); - } - - #[test] - fn test_get_bytes_overflow() { - let data = [0xAC, 0x02, 0x01, 0x02, 0x03]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_bytes(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading bytes")) - ); - } - - #[test] - fn test_get_bytes_negative_length() { - let data = [0x01]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_bytes(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("offset overflow")) - ); - } - - #[test] - fn test_get_float_ok() { - let data = [0x00, 0x00, 0x80, 0x3F, 0x01]; - let mut cursor = AvroCursor::new(&data); - let val = cursor.get_float().unwrap(); - assert!((val - 1.0).abs() < f32::EPSILON); - assert_eq!(cursor.position(), 4); - } - - #[test] - fn test_get_float_eof() { - let data = [0x00, 0x00, 0x80]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_float(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading float")) - ); - } - - #[test] - fn test_get_double_ok() { - let data = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, 0x99]; - let mut cursor = AvroCursor::new(&data); - let val = cursor.get_double().unwrap(); - assert!((val - 1.0).abs() < f64::EPSILON); - assert_eq!(cursor.position(), 8); - } - - #[test] - fn test_get_double_eof() { - let data = [0x00, 0x00, 0x00, 0x00]; // only 4 bytes - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_double(); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading double")) - ); - } - - #[test] - fn test_get_fixed_ok() { - let data = [0x11, 0x22, 0x33, 0x44]; - let mut cursor = AvroCursor::new(&data); - let val = cursor.get_fixed(2).unwrap(); - assert_eq!(val, [0x11, 0x22]); - assert_eq!(cursor.position(), 2); - - let val = cursor.get_fixed(2).unwrap(); - assert_eq!(val, [0x33, 0x44]); - assert_eq!(cursor.position(), 4); - } - - #[test] - fn test_get_fixed_eof() { - let data = [0x11, 0x22]; - let mut cursor = AvroCursor::new(&data); - let result = cursor.get_fixed(3); - assert!( - matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading fixed")) - ); - } -} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index ecb53f1f101b..99f2163fa5bb 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -351,35 +351,4 @@ mod test { 325166208089902833952788552656412487328 ); } - #[test] - fn test_header_schema_default() { - let json_schema = r#" - { - "type": "record", - "name": "TestRecord", - "fields": [ - {"name": "a", "type": "int", "default": 10} - ] - } - "#; - let key = "avro.schema"; - let key_bytes = key.as_bytes(); - let value_bytes = json_schema.as_bytes(); - let mut meta_buf = Vec::new(); - meta_buf.extend_from_slice(key_bytes); - meta_buf.extend_from_slice(value_bytes); - let meta_offsets = vec![key_bytes.len(), key_bytes.len() + value_bytes.len()]; - let header = Header { - meta_offsets, - meta_buf, - sync: [0; 16], - }; - let schema = header.schema().unwrap().unwrap(); - if let Schema::Complex(crate::schema::ComplexType::Record(record)) = schema { - assert_eq!(record.fields.len(), 1); - assert_eq!(record.fields[0].default, Some(serde_json::json!(10))); - } else { - panic!("Expected record schema"); - } - } } 
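The cursor and block tests above hand-encode every varint; for reference, a minimal standalone sketch (an illustration, not code from this patch) of the zig-zag varint decode those tests exercise:

#[cfg(test)]
mod zigzag_examples {
    /// Decode a little-endian 7-bit varint, then undo ZigZag
    /// (sketch only; no EOF or overflow handling).
    fn zigzag_decode(bytes: &[u8]) -> i64 {
        let mut ux: u64 = 0;
        let mut shift = 0;
        for &b in bytes {
            ux |= ((b & 0x7F) as u64) << shift;
            shift += 7;
            if b & 0x80 == 0 {
                break;
            }
        }
        ((ux >> 1) as i64) ^ -((ux & 1) as i64)
    }

    #[test]
    fn worked_examples() {
        assert_eq!(zigzag_decode(&[0x04]), 2); // as in test_get_int_ok
        assert_eq!(zigzag_decode(&[0x03]), -2);
        assert_eq!(zigzag_decode(&[0xAC, 0x02]), 150); // as in test_get_long_ok
    }
}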
diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index eec5ff8d0d95..4d0cbb035088 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -14,133 +14,22 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +//! Read Avro data to Arrow -//! Avro reader -//! -//! This module provides facilities to read Apache Avro-encoded files or streams -//! into Arrow's [`RecordBatch`] format. In particular, it introduces: -//! -//! * [`ReaderBuilder`]: Configures Avro reading, e.g., batch size -//! * [`Reader`]: Yields [`RecordBatch`] values, implementing [`Iterator`] -//! * [`Decoder`]: A low-level push-based decoder for Avro records -//! -//! # Basic Usage -//! -//! [`Reader`] can be used directly with synchronous data sources, such as [`std::fs::File`]. -//! -//! ## Reading a Single Batch -//! -//! ``` -//! # use std::fs::File; -//! # use std::io::BufReader; -//! -//! let file = File::open("test/data/simple_enum.avro").unwrap(); -//! let mut avro = arrow_avro::ReaderBuilder::new().build(BufReader::new(file)).unwrap(); -//! let batch = avro.next().unwrap().unwrap(); -//! ``` -//! -//! # Async Usage -//! -//! The lower-level [`Decoder`] can be integrated with various forms of async data streams, -//! and is designed to be agnostic to different async IO primitives within -//! the Rust ecosystem. It works by incrementally decoding Avro data from byte slices. -//! -//! For example, see below for how it could be used with an arbitrary `Stream` of `Bytes`: -//! -//! ``` -//! # use std::task::{Poll, ready}; -//! # use bytes::{Buf, Bytes}; -//! # use arrow_schema::ArrowError; -//! # use futures::stream::{Stream, StreamExt}; -//! # use arrow_array::RecordBatch; -//! # use arrow_avro::reader::Decoder; -//! # -//! fn decode_stream + Unpin>( -//! mut decoder: Decoder, -//! mut input: S, -//! ) -> impl Stream> { -//! let mut buffered = Bytes::new(); -//! futures::stream::poll_fn(move |cx| { -//! loop { -//! if buffered.is_empty() { -//! buffered = match ready!(input.poll_next_unpin(cx)) { -//! Some(b) => b, -//! None => break, -//! }; -//! } -//! let decoded = match decoder.decode(buffered.as_ref()) { -//! Ok(decoded) => decoded, -//! Err(e) => return Poll::Ready(Some(Err(e))), -//! }; -//! let read = buffered.len(); -//! buffered.advance(decoded); -//! if decoded != read { -//! break -//! } -//! } -//! // Convert any fully-decoded rows to a RecordBatch, if available -//! Poll::Ready(decoder.flush().transpose()) -//! }) -//! } -//! ``` -//! -//! In a similar vein, it can also be used with tokio-based IO primitives -//! -//! ``` -//! # use std::sync::Arc; -//! # use arrow_schema::{DataType, Field, Schema}; -//! # use std::pin::Pin; -//! # use std::task::{Poll, ready}; -//! # use futures::{Stream, TryStreamExt}; -//! # use tokio::io::AsyncBufRead; -//! # use arrow_array::RecordBatch; -//! # use arrow_avro::reader::Decoder; -//! # use arrow_schema::ArrowError; -//! fn decode_stream( -//! mut decoder: Decoder, -//! mut reader: R, -//! ) -> impl Stream> { -//! futures::stream::poll_fn(move |cx| { -//! loop { -//! let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) { -//! Ok(b) if b.is_empty() => break, -//! Ok(b) => b, -//! Err(e) => return Poll::Ready(Some(Err(e.into()))), -//! }; -//! let read = b.len(); -//! let decoded = match decoder.decode(b) { -//! Ok(decoded) => decoded, -//! Err(e) => return Poll::Ready(Some(Err(e))), -//! }; -//! 
Pin::new(&mut reader).consume(decoded); -//! if decoded != read { -//! break; -//! } -//! } -//! -//! Poll::Ready(decoder.flush().transpose()) -//! }) -//! } -//! ``` -//! - -use arrow_array::{RecordBatch, RecordBatchReader}; -use arrow_schema::{ArrowError, SchemaRef}; +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::header::{Header, HeaderDecoder}; +use arrow_schema::ArrowError; use std::io::BufRead; +mod header; + mod block; + mod cursor; -mod header; mod record; mod vlq; -use crate::codec::AvroField; -use crate::schema::Schema as AvroSchema; -use block::BlockDecoder; -use header::{Header, HeaderDecoder}; -use record::RecordDecoder; - -/// Read the Avro file header (magic, metadata, sync marker) from `reader`. +/// Read a [`Header`] from the provided [`BufRead`] fn read_header(mut reader: R) -> Result { let mut decoder = HeaderDecoder::default(); loop { @@ -155,262 +44,41 @@ fn read_header(mut reader: R) -> Result { break; } } - decoder.flush().ok_or_else(|| { - ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string()) - }) -} - -/// A low-level interface for decoding Avro-encoded bytes into Arrow [`RecordBatch`]. -#[derive(Debug)] -pub struct Decoder { - record_decoder: RecordDecoder, - batch_size: usize, - decoded_rows: usize, -} - -impl Decoder { - /// Create a new [`Decoder`], wrapping an existing [`RecordDecoder`]. - pub fn new(record_decoder: RecordDecoder, batch_size: usize) -> Self { - Self { - record_decoder, - batch_size, - decoded_rows: 0, - } - } - - /// Return the Arrow schema for the rows decoded by this decoder - pub fn schema(&self) -> SchemaRef { - self.record_decoder.schema().clone() - } - - /// Return the configured maximum number of rows per batch - pub fn batch_size(&self) -> usize { - self.batch_size - } - - /// Feed `data` into the decoder row by row until we either: - /// - consume all bytes in `data`, or - /// - reach `batch_size` decoded rows. - /// - /// Returns the number of bytes consumed. - pub fn decode(&mut self, data: &[u8]) -> Result { - let mut total_consumed = 0usize; - while total_consumed < data.len() && self.decoded_rows < self.batch_size { - let consumed = self.record_decoder.decode(&data[total_consumed..], 1)?; - if consumed == 0 { - break; - } - total_consumed += consumed; - self.decoded_rows += 1; - } - Ok(total_consumed) - } - - /// Produce a [`RecordBatch`] if at least one row is fully decoded, returning - /// `Ok(None)` if no new rows are available. - pub fn flush(&mut self) -> Result, ArrowError> { - if self.decoded_rows == 0 { - Ok(None) - } else { - let batch = self.record_decoder.flush()?; - self.decoded_rows = 0; - Ok(Some(batch)) - } - } -} - -/// A builder to create an [`Avro Reader`](Reader) that reads Avro data -/// into Arrow [`RecordBatch`]. -#[derive(Debug)] -pub struct ReaderBuilder { - batch_size: usize, - strict_mode: bool, -} - -impl Default for ReaderBuilder { - fn default() -> Self { - Self { - batch_size: 1024, - strict_mode: false, - } - } -} - -impl ReaderBuilder { - /// Creates a new [`ReaderBuilder`] with default settings: - /// - `batch_size` = 1024 - /// - `strict_mode` = false - pub fn new() -> Self { - Self::default() - } - - /// Sets the row-based batch size - pub fn with_batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; - self - } - - /// Controls whether certain Avro unions of the form `[T, "null"]` should produce an error. 
- pub fn with_strict_mode(mut self, strict_mode: bool) -> Self { - self.strict_mode = strict_mode; - self - } - - /// Create a [`Reader`] from this builder and a `BufRead` - pub fn build(self, mut reader: R) -> Result, ArrowError> { - let header = read_header(&mut reader)?; - let compression = header.compression()?; - let avro_schema: Option> = header - .schema() - .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; - let avro_schema = avro_schema.ok_or_else(|| { - ArrowError::ParseError("No Avro schema present in file header".to_string()) - })?; - let root_field = AvroField::try_from(&avro_schema)?; - let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; - let decoder = Decoder::new(record_decoder, self.batch_size); - Ok(Reader { - reader, - header, - compression, - decoder, - block_decoder: BlockDecoder::default(), - block_data: Vec::new(), - finished: false, - }) - } - - /// Create a [`Decoder`] from this builder and a `BufRead` by - /// reading and parsing the Avro file's header. This will - /// not create a full [`Reader`]. - pub fn build_decoder(self, mut reader: R) -> Result { - let header = read_header(&mut reader)?; - let avro_schema: Option> = header - .schema() - .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; - - let avro_schema = avro_schema.ok_or_else(|| { - ArrowError::ParseError("No Avro schema present in file header".to_string()) - })?; - let root_field = AvroField::try_from(&avro_schema)?; - let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; - Ok(Decoder::new(record_decoder, self.batch_size)) - } -} - -/// A high-level Avro `Reader` that reads container-file blocks -/// and feeds them into a row-level [`Decoder`]. -#[derive(Debug)] -pub struct Reader { - reader: R, - header: Header, - compression: Option, - decoder: Decoder, - block_decoder: BlockDecoder, - block_data: Vec, - finished: bool, -} - -impl Reader { - /// Return the Arrow schema discovered from the Avro file header - pub fn schema(&self) -> SchemaRef { - self.decoder.schema() - } - - /// Return the Avro container-file header - pub fn avro_header(&self) -> &Header { - &self.header - } + decoder + .flush() + .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) } -impl Reader { - /// Reads the next [`RecordBatch`] from the Avro file or `Ok(None)` on EOF - fn read(&mut self) -> Result, ArrowError> { - if self.finished { - return Ok(None); - } +/// Return an iterator of [`Block`] from the provided [`BufRead`] +fn read_blocks(mut reader: R) -> impl Iterator> { + let mut decoder = BlockDecoder::default(); + let mut try_next = move || { loop { - if !self.block_data.is_empty() { - let consumed = self.decoder.decode(&self.block_data)?; - if consumed > 0 { - self.block_data.drain(..consumed); - } - match self.decoder.flush()? 
{ - None => { - if !self.block_data.is_empty() { - break; - } - } - Some(batch) => { - return Ok(Some(batch)); - } - } + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; } - let maybe_block = { - let buf = self.reader.fill_buf()?; - if buf.is_empty() { - None - } else { - let read_len = buf.len(); - let consumed_len = self.block_decoder.decode(buf)?; - self.reader.consume(consumed_len); - if consumed_len == 0 && read_len != 0 { - return Err(ArrowError::ParseError( - "Could not decode next Avro block from partial data".to_string(), - )); - } - self.block_decoder.flush() - } - }; - match maybe_block { - Some(block) => { - let block_data = if let Some(ref codec) = self.compression { - codec.decompress(&block.data)? - } else { - block.data - }; - self.block_data = block_data; - } - None => { - self.finished = true; - if !self.block_data.is_empty() { - let consumed = self.decoder.decode(&self.block_data)?; - self.block_data.drain(..consumed); - } - return self.decoder.flush(); - } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; } } - self.decoder.flush() - } -} - -impl Iterator for Reader { - type Item = Result; - - fn next(&mut self) -> Option { - match self.read() { - Ok(Some(batch)) => Some(Ok(batch)), - Ok(None) => None, - Err(e) => Some(Err(e)), - } - } -} - -impl RecordBatchReader for Reader { - fn schema(&self) -> SchemaRef { - self.schema() - } + Ok(decoder.flush()) + }; + std::iter::from_fn(move || try_next().transpose()) } #[cfg(test)] mod test { - use super::*; - use crate::reader::vlq::VLQDecoder; + use crate::codec::AvroField; + use crate::reader::record::RecordDecoder; + use crate::reader::{read_blocks, read_header}; use crate::test_util::arrow_test_data; use arrow_array::builder::{ ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, - ListBuilder, MapBuilder, MapFieldNames, StringBuilder, StructBuilder, + ListBuilder, MapBuilder, StringBuilder, StructBuilder, }; use arrow_array::types::Int32Type; use arrow_array::{ @@ -421,143 +89,40 @@ mod test { use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{DataType, Field, Fields, Schema}; - use bytes::{Buf, Bytes}; - use futures::{stream, Stream, StreamExt, TryStreamExt}; use std::collections::HashMap; - use std::fs; use std::fs::File; - use std::io::{BufReader, Cursor}; + use std::io::BufReader; use std::sync::Arc; - use std::task::{ready, Poll}; - - fn read_file(path: &str, _schema: Option) -> super::Reader> { - let file = File::open(path).unwrap(); - let reader = BufReader::new(file); - let builder = ReaderBuilder::new().with_batch_size(64); - builder.build(reader).unwrap() - } - fn decode_stream + Unpin>( - mut decoder: Decoder, - mut input: S, - ) -> impl Stream> { - let mut buffered = Bytes::new(); - futures::stream::poll_fn(move |cx| { - loop { - if buffered.is_empty() { - buffered = match ready!(input.poll_next_unpin(cx)) { - Some(b) => b, - None => break, - }; - } - let decoded = match decoder.decode(buffered.as_ref()) { - Ok(decoded) => decoded, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - let read = buffered.len(); - buffered.advance(decoded); - if decoded != read { - break; - } + /// Helper to read an Avro file into a `RecordBatch`. + /// + /// - `strict_mode`: if `true`, we reject unions of the form `[T,"null"]`. 
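    // Note on the loop below: each decompressed block is fed to the
    // RecordDecoder in chunks of at most `batch_size` rows, advancing
    // `offset` by the bytes consumed; the per-block assert checks the
    // whole block was consumed, and a single `flush()` at the end
    // builds one RecordBatch for the entire file.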
+ fn read_file(file: &str, batch_size: usize, strict_mode: bool) -> RecordBatch { + let file = File::open(file).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + let compression = header.compression().unwrap(); + let schema = header.schema().unwrap().unwrap(); + let root = AvroField::try_from(&schema).unwrap(); + let mut decoder = RecordDecoder::try_new(root.data_type(), strict_mode).unwrap(); + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + let block_data = if let Some(c) = compression { + c.decompress(&block.data).unwrap() + } else { + block.data + }; + let mut offset = 0; + let mut remaining = block.count; + while remaining > 0 { + let to_read = remaining.min(batch_size); + offset += decoder.decode(&block_data[offset..], to_read).unwrap(); + remaining -= to_read; } - Poll::Ready(decoder.flush().transpose()) - }) - } - - #[test] - fn test_basic_usage_single_batch() { - let file = File::open(arrow_test_data("avro/simple_enum.avro")) - .expect("Failed to open test/data/simple_enum.avro"); - let mut avro = ReaderBuilder::new() - .build(BufReader::new(file)) - .expect("Failed to build Avro Reader"); - - let batch = avro - .next() - .expect("No batch found?") - .expect("Error reading batch"); - - assert!(batch.num_rows() > 0, "Expected at least 1 row"); - assert!(batch.num_columns() > 0, "Expected at least 1 column"); - } - - #[test] - fn test_reader_read() -> Result<(), ArrowError> { - let file_path = "test/data/simple_enum.avro"; - let file = File::open(file_path).expect("Failed to open Avro file"); - let mut reader_direct = ReaderBuilder::new() - .build(BufReader::new(file)) - .expect("Failed to build Reader"); - let mut direct_batches = Vec::new(); - while let Some(batch) = reader_direct.read()? 
{ - direct_batches.push(batch); + assert_eq!(offset, block_data.len()); } - let file = File::open(file_path).expect("Failed to open Avro file"); - let reader_iter = ReaderBuilder::new() - .build(BufReader::new(file)) - .expect("Failed to build Reader"); - let iter_batches: Result, _> = reader_iter.collect(); - let iter_batches = iter_batches?; - assert_eq!(direct_batches, iter_batches); - Ok(()) - } - - #[tokio::test] - async fn test_async_decoder_with_bytes_stream() -> Result<(), ArrowError> { - let path = arrow_test_data("avro/simple_enum.avro"); - let data = fs::read(&path).expect("Failed to read .avro file"); - let mut cursor = Cursor::new(&data); - let decoder: Decoder = ReaderBuilder::new().build_decoder(&mut cursor)?; - let header_consumed = cursor.position() as usize; - let mut remainder = &data[header_consumed..]; - let mut vlq_dec = VLQDecoder::default(); - let _block_count_i64 = vlq_dec - .long(&mut remainder) - .ok_or_else(|| ArrowError::ParseError("EOF reading block count".to_string()))?; - let block_size_i64 = vlq_dec - .long(&mut remainder) - .ok_or_else(|| ArrowError::ParseError("EOF reading block size".to_string()))?; - let block_size = block_size_i64 as usize; - if remainder.len() < block_size { - return Err(ArrowError::ParseError(format!( - "File truncated: Needed {} bytes for block data, got {}", - block_size, - remainder.len() - ))); - } - let block_data = &remainder[..block_size]; - remainder = &remainder[block_size..]; - if remainder.len() < 16 { - return Err(ArrowError::ParseError( - "Missing sync marker in Avro block".to_string(), - )); - } - let _sync_marker = &remainder[..16]; - let _remainder = &remainder[16..]; - let chunks = block_data - .chunks(16) - .map(Bytes::copy_from_slice) - .collect::>(); - let input_stream = stream::iter(chunks); - let record_batch_stream = decode_stream(decoder, input_stream); - let batches: Vec<_> = record_batch_stream.try_collect().await?; - assert!( - !batches.is_empty(), - "Should decode at least one batch from the block" - ); - let file = File::open(&path).unwrap(); - let mut sync_reader = ReaderBuilder::new() - .build(BufReader::new(file)) - .expect("Could not build sync_reader"); - let expected_batch = sync_reader - .next() - .expect("No batch in file") - .expect("Sync decode failed"); - assert_eq!( - batches[0], expected_batch, - "Async decode differs from sync decode" - ); - Ok(()) + decoder.flush().unwrap() } #[test] @@ -618,14 +183,14 @@ mod test { ( "date_string_col", Arc::new(BinaryArray::from_iter_values([ - b"03/01/09", - b"03/01/09", - b"04/01/09", - b"04/01/09", - b"02/01/09", - b"02/01/09", - b"01/01/09", - b"01/01/09", + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], ])) as _, true, ), @@ -655,12 +220,8 @@ mod test { .unwrap(); for file in files { let file = arrow_test_data(file); - let mut reader = read_file(&file, None); - let batch_large = reader.next().unwrap().unwrap(); - assert_eq!(batch_large, expected); - let mut reader_small = read_file(&file, None); - let batch_small = reader_small.next().unwrap().unwrap(); - assert_eq!(batch_small, expected); + assert_eq!(read_file(&file, 8, false), expected); + assert_eq!(read_file(&file, 3, false), expected); } } @@ -724,18 +285,16 @@ mod test { ]) .unwrap(); let file_path = arrow_test_data(file); - let mut reader = 
read_file(&file_path, None); - let batch_large = reader.next().unwrap().unwrap(); + let batch_large = read_file(&file_path, 8, false); assert_eq!( batch_large, expected, "Decoded RecordBatch does not match for file {}", file ); - let mut reader_small = read_file(&file_path, None); - let batch_small = reader_small.next().unwrap().unwrap(); + let batch_small = read_file(&file_path, 3, false); assert_eq!( batch_small, expected, - "Decoded RecordBatch (batch size 64) does not match for file {}", + "Decoded RecordBatch (batch size 3) does not match for file {}", file ); } @@ -778,18 +337,16 @@ mod test { ]) .unwrap(); let file_path = arrow_test_data(file); - let mut reader = read_file(&file_path, None); - let batch_large = reader.next().unwrap().unwrap(); + let batch_large = read_file(&file_path, 8, false); assert_eq!( batch_large, expected, "Decoded RecordBatch does not match for file {}", file ); - let mut reader_small = read_file(&file_path, None); - let batch_small = reader_small.next().unwrap().unwrap(); + let batch_small = read_file(&file_path, 3, false); assert_eq!( batch_small, expected, - "Decoded RecordBatch does not match for file {}", + "Decoded RecordBatch (batch size 3) does not match for file {}", file ); } @@ -797,8 +354,7 @@ mod test { #[test] fn test_binary() { let file = arrow_test_data("avro/binary.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); + let batch = read_file(&file, 8, false); let expected = RecordBatch::try_from_iter_with_nullable([( "foo", Arc::new(BinaryArray::from_iter_values(vec![ @@ -830,40 +386,39 @@ mod test { ("avro/int64_decimal.avro", 10, 2), ]; let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); - for (file, precision, scale) in files { let file_path = arrow_test_data(file); - let mut reader = read_file(&file_path, None); - let actual_batch = reader.next().unwrap().unwrap(); - + let actual_batch = read_file(&file_path, 8, false); let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) .with_precision_and_scale(precision, scale) .unwrap(); - let mut meta = HashMap::new(); meta.insert("precision".to_string(), precision.to_string()); meta.insert("scale".to_string(), scale.to_string()); let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) .with_metadata(meta); - let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); let expected_batch = RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) .expect("Failed to build expected RecordBatch"); - assert_eq!( actual_batch, expected_batch, "Decoded RecordBatch does not match the expected Decimal128 data for file {}", file ); + let actual_batch_small = read_file(&file_path, 3, false); + assert_eq!( + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", + file + ); } } #[test] fn test_datapage_v2() { let file = arrow_test_data("avro/datapage_v2.snappy.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); + let batch = read_file(&file, 8, false); let a = StringArray::from(vec![ Some("abc"), Some("abc"), @@ -908,10 +463,8 @@ mod test { #[test] fn test_dict_pages_offset_zero() { let file = arrow_test_data("avro/dict-page-offset-zero.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); + let batch = read_file(&file, 32, false); let num_rows = batch.num_rows(); - let expected_field = 
Int32Array::from(vec![Some(1552); num_rows]); let expected = RecordBatch::try_from_iter_with_nullable([( "l_partkey", @@ -925,7 +478,6 @@ mod test { #[test] fn test_list_columns() { let file = arrow_test_data("avro/list_columns.avro"); - let mut reader = read_file(&file, None); let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); { { @@ -981,15 +533,13 @@ mod test { ("utf8_list", Arc::new(utf8_list) as Arc, true), ]) .unwrap(); - let batch = reader.next().unwrap().unwrap(); + let batch = read_file(&file, 8, false); assert_eq!(batch, expected); } #[test] fn test_nested_lists() { let file = arrow_test_data("avro/nested_lists.snappy.avro"); - let mut reader = read_file(&file, None); - let left = reader.next().unwrap().unwrap(); let inner_values = StringArray::from(vec![ Some("a"), Some("b"), @@ -1034,7 +584,7 @@ mod test { .unwrap(); let middle_list_array = ListArray::from(middle_list_data); let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); - let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all valid + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all 3 rows valid let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) .len(3) @@ -1050,14 +600,14 @@ mod test { ("b", Arc::new(b_expected) as Arc, true), ]) .unwrap(); - assert_eq!(left, expected, "Mismatch for batch size=64"); + let left = read_file(&file, 8, false); + assert_eq!(left, expected, "Mismatch for batch size=8"); + let left_small = read_file(&file, 3, false); + assert_eq!(left_small, expected, "Mismatch for batch size=3"); } #[test] fn test_nested_records() { - let file = arrow_test_data("avro/nested_records.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); let f1_f1_2 = Int32Array::from(vec![10, 20]); let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; @@ -1066,7 +616,6 @@ mod test { Arc::new(Field::new("f1_3_1", DataType::Float64, false)), Arc::new(f1_f1_3_1) as Arc, )]); - let f1_expected = StructArray::from(vec![ ( Arc::new(Field::new("f1_1", DataType::Utf8, false)), @@ -1099,8 +648,8 @@ mod test { .map(|f| Arc::new(f.clone())) .collect::>>(), vec![ - Box::new(BooleanBuilder::new()) as Box, - Box::new(Float32Builder::new()) as Box, + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, ], ); let mut f2_list_builder = ListBuilder::new(f2_struct_builder); @@ -1142,7 +691,7 @@ mod test { let f2_expected = f2_list_builder.finish(); let mut f3_struct_builder = StructBuilder::new( vec![Arc::new(Field::new("f3_1", DataType::Utf8, false))], - vec![Box::new(StringBuilder::new()) as Box], + vec![Box::new(StringBuilder::new()) as Box], ); f3_struct_builder.append(true); { @@ -1200,20 +749,29 @@ mod test { ("f4", Arc::new(f4_expected) as Arc, false), ]) .unwrap(); - assert_eq!(batch, expected, "Mismatch in nested_records.avro contents"); + let file = arrow_test_data("avro/nested_records.avro"); + let batch_large = read_file(&file, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 8)" + ); + let batch_small = read_file(&file, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 3)" + ); } #[test] fn test_nonnullable_impala() { let file = arrow_test_data("avro/nonnullable.impala.avro"); - 
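// [Editorial aside, not part of the patch] The expected arrays in the test
// below are assembled with arrow builders. As a reference, a minimal sketch
// of the MapBuilder pattern this test leans on; the helper name
// `_sketch_int_map` is ours, not the patch's:
fn _sketch_int_map() -> arrow_array::MapArray {
    use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder};
    let mut b = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    let (keys, vals) = b.entries(); // mutable refs to the key/value builders
    keys.append_value("k1");
    vals.append_value(-1);
    b.append(true).unwrap(); // close map row 0 with a single entry
    b.finish() // MapArray with the default "entries"/"key"/"value" names
}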
let mut reader = read_file(&file, None); let id = Int64Array::from(vec![Some(8)]); let mut int_array_builder = ListBuilder::new(Int32Builder::new()); { let vb = int_array_builder.values(); vb.append_value(-1); } - int_array_builder.append(true); + int_array_builder.append(true); // finalize one sub-list let int_array = int_array_builder.finish(); let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); { @@ -1228,6 +786,7 @@ mod test { } iaa_builder.append(true); let int_array_array = iaa_builder.finish(); + use arrow_array::builder::MapFieldNames; let field_names = MapFieldNames { entry: "entries".to_string(), key: "key".to_string(), @@ -1240,7 +799,7 @@ mod test { keys.append_value("k1"); vals.append_value(-1); } - int_map_builder.append(true).unwrap(); + int_map_builder.append(true).unwrap(); // finalize map for row 0 let int_map = int_map_builder.finish(); let field_names2 = MapFieldNames { entry: "entries".to_string(), @@ -1266,95 +825,108 @@ mod test { } ima_builder.append(true); let int_map_array_ = ima_builder.finish(); - let nested_schema_fields = vec![ - Field::new("a", DataType::Int32, true), - Field::new( - "B", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "c", - DataType::Struct(Fields::from(vec![Field::new( - "D", - DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new( - "item", - DataType::Struct(Fields::from(vec![ - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Utf8, true), - ])), - true, - ))), - true, - ))), + let mut nested_sb = StructBuilder::new( + vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, - )])), - true, - ), - Field::new( - "G", - DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("key", DataType::Utf8, false), - Field::new( - "value", - DataType::Struct(Fields::from(vec![Field::new( - "h", - DataType::Struct(Fields::from(vec![Field::new( - "i", - DataType::List(Arc::new(Field::new( - "item", - DataType::Float64, - true, - ))), - true, - )])), + )), + Arc::new(Field::new( + "c", + DataType::Struct( + vec![Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), true, - )])), + ))), true, + ))), + true, + )] + .into(), + ), + true, + )), + Arc::new(Field::new( + "G", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct( + vec![Field::new( + "h", + DataType::Struct( + vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )] + .into(), + ), + true, + )] + .into(), + ), + true, + ), + ] + .into(), ), - ])), + false, + )), false, - )), - false, - ), - true, - ), - ]; - let nested_schema = Arc::new(Schema::new(nested_schema_fields.clone())); - let mut nested_sb = StructBuilder::new( - nested_schema_fields - .iter() - .map(|f| Arc::new(f.clone())) - .collect::>(), + ), + true, + )), + ], vec![ Box::new(Int32Builder::new()), Box::new(ListBuilder::new(Int32Builder::new())), { - let d_list_field = Field::new( + let d_field = Field::new( "D", DataType::List(Arc::new(Field::new( "item", 
DataType::List(Arc::new(Field::new( "item", - DataType::Struct(Fields::from(vec![ - Field::new("e", DataType::Int32, true), - Field::new("f", DataType::Utf8, true), - ])), + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), true, ))), true, ))), true, ); - let struct_c_builder = StructBuilder::new( - vec![Arc::new(d_list_field)], - vec![Box::new(ListBuilder::new(ListBuilder::new( - StructBuilder::new( + Box::new(StructBuilder::new( + vec![Arc::new(d_field)], + vec![Box::new({ + let ef_struct_builder = StructBuilder::new( vec![ Arc::new(Field::new("e", DataType::Int32, true)), Arc::new(Field::new("f", DataType::Utf8, true)), @@ -1363,35 +935,32 @@ mod test { Box::new(Int32Builder::new()), Box::new(StringBuilder::new()), ], - ), - )))], - ); - Box::new(struct_c_builder) + ); + let list_of_ef = ListBuilder::new(ef_struct_builder); + ListBuilder::new(list_of_ef) + })], + )) }, { - Box::new(MapBuilder::new( - Some(MapFieldNames { - entry: "entries".to_string(), - key: "key".to_string(), - value: "value".to_string(), - }), - StringBuilder::new(), - StructBuilder::new( - vec![Arc::new(Field::new( - "h", - DataType::Struct(Fields::from(vec![Field::new( - "i", - DataType::List(Arc::new(Field::new( - "item", - DataType::Float64, - true, - ))), - true, - )])), - true, - ))], - vec![Box::new(StructBuilder::new( - vec![Arc::new(Field::new( + let map_field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let i_list_builder = ListBuilder::new(Float64Builder::new()); + let h_struct = StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + true, + ))], + vec![Box::new(i_list_builder)], + ); + let g_value_builder = StructBuilder::new( + vec![Arc::new(Field::new( + "h", + DataType::Struct( + vec![Field::new( "i", DataType::List(Arc::new(Field::new( "item", @@ -1399,10 +968,17 @@ mod test { true, ))), true, - ))], - vec![Box::new(ListBuilder::new(Float64Builder::new()))], - ))], - ), + )] + .into(), + ), + true, + ))], + vec![Box::new(h_struct)], + ); + Box::new(MapBuilder::new( + Some(map_field_names), + StringBuilder::new(), + g_value_builder, )) }, ], @@ -1411,6 +987,8 @@ mod test { { let a_builder = nested_sb.field_builder::(0).unwrap(); a_builder.append_value(-1); + } + { let b_builder = nested_sb .field_builder::>(1) .unwrap(); @@ -1419,131 +997,57 @@ mod test { vb.append_value(-1); } b_builder.append(true); - let c_sb = nested_sb.field_builder::(2).unwrap(); - c_sb.append(true); + } + { + let c_struct_builder = nested_sb.field_builder::(2).unwrap(); + c_struct_builder.append(true); + let d_list_builder = c_struct_builder + .field_builder::>>(0) + .unwrap(); { - let d_list_builder = c_sb - .field_builder::>>(0) - .unwrap(); + let sub_list_builder = d_list_builder.values(); { - let sub_list_builder = d_list_builder.values(); + let ef_struct = sub_list_builder.values(); + ef_struct.append(true); { - let ef_struct_builder = sub_list_builder.values(); - ef_struct_builder.append(true); - { - let e_b = ef_struct_builder.field_builder::(0).unwrap(); - e_b.append_value(-1); - let f_b = ef_struct_builder.field_builder::(1).unwrap(); - f_b.append_value("nonnullable"); - } - sub_list_builder.append(true); + let e_b = ef_struct.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); } - d_list_builder.append(true); 
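// [Editorial aside, not part of the patch] A hedged mini-example of the
// positional `field_builder::<T>(i)` access pattern used throughout this
// nested-builder section; `_sketch_ef_struct` is an illustrative name of ours:
fn _sketch_ef_struct() -> arrow_array::StructArray {
    use arrow_array::builder::{Int32Builder, StringBuilder, StructBuilder};
    use arrow_schema::{DataType, Field};
    use std::sync::Arc;
    let mut sb = StructBuilder::new(
        vec![
            Arc::new(Field::new("e", DataType::Int32, true)),
            Arc::new(Field::new("f", DataType::Utf8, true)),
        ],
        vec![Box::new(Int32Builder::new()), Box::new(StringBuilder::new())],
    );
    sb.append(true); // mark row 0 valid
    sb.field_builder::<Int32Builder>(0).unwrap().append_value(-1);
    sb.field_builder::<StringBuilder>(1).unwrap().append_value("nonnullable");
    sb.finish()
}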
+ sub_list_builder.append(true); } + d_list_builder.append(true); } + } + { let g_map_builder = nested_sb .field_builder::>(3) .unwrap(); g_map_builder.append(true).unwrap(); - { - let (keys, values) = g_map_builder.entries(); - keys.append_value("k1"); - values.append(true); - let h_struct_builder = values.field_builder::(0).unwrap(); - h_struct_builder.append(true); - { - let i_list_builder = h_struct_builder - .field_builder::>(0) - .unwrap(); - i_list_builder.append(true); - } - } } let nested_struct = nested_sb.finish(); - let schema = Arc::new(Schema::new(vec![ - Field::new("ID", DataType::Int64, true), - Field::new( - "Int_Array", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ), - Field::new( - "int_array_array", - DataType::List(Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - ))), - true, - ), - Field::new( - "Int_Map", - DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Int32, true), - ])), - false, - )), - false, - ), - true, - ), - Field::new( - "int_map_array", - DataType::List(Arc::new(Field::new( - "item", - DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("key", DataType::Utf8, false), - Field::new("value", DataType::Int32, true), - ])), - false, - )), - false, - ), - true, - ))), - true, - ), - Field::new( - "nested_Struct", - DataType::Struct(nested_schema.as_ref().fields.clone()), - true, - ), - ])); - let expected = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(id) as Arc, - Arc::new(int_array), - Arc::new(int_array_array), - Arc::new(int_map), - Arc::new(int_map_array_), - Arc::new(nested_struct), - ], - ) + let expected = RecordBatch::try_from_iter_with_nullable([ + ("ID", Arc::new(id) as Arc, true), + ("Int_Array", Arc::new(int_array), true), + ("int_array_array", Arc::new(int_array_array), true), + ("Int_Map", Arc::new(int_map), true), + ("int_map_array", Arc::new(int_map_array_), true), + ("nested_Struct", Arc::new(nested_struct), true), + ]) .unwrap(); - let batch = reader.next().unwrap().unwrap(); - assert_eq!(batch, expected, "nonnullable impala avro data mismatch"); + let batch_large = read_file(&file, 8, false); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); } #[test] fn test_nullable_impala() { - use arrow_array::{Int64Array, ListArray, StructArray}; let file = arrow_test_data("avro/nullable.impala.avro"); - let mut r1 = read_file(&file, None); - let batch1 = r1.next().unwrap().unwrap(); - let mut r2 = read_file(&file, None); - let batch2 = r2.next().unwrap().unwrap(); - assert_eq!( - batch1, batch2, - "Reading file multiple times should produce the same data" - ); + let batch1 = read_file(&file, 3, false); + let batch2 = read_file(&file, 8, false); + assert_eq!(batch1, batch2); let batch = batch1; assert_eq!(batch.num_rows(), 7); let id_array = batch @@ -1553,14 +1057,18 @@ mod test { .expect("id column should be an Int64Array"); let expected_ids = [1, 2, 3, 4, 5, 6, 7]; for (i, &expected_id) in expected_ids.iter().enumerate() { - assert_eq!(id_array.value(i), expected_id, "Mismatch in id at row {i}"); + assert_eq!( + id_array.value(i), + expected_id, + "Mismatch in id at row {}", + i + ); } let int_array = batch .column(1) .as_any() .downcast_ref::() 
.expect("int_array column should be a ListArray"); - { let offsets = int_array.value_offsets(); let start = offsets[0] as usize; @@ -1570,7 +1078,7 @@ mod test { .as_any() .downcast_ref::() .expect("Values of int_array should be an Int32Array"); - let row0: Vec> = (start..end).map(|idx| Some(values.value(idx))).collect(); + let row0: Vec> = (start..end).map(|i| Some(values.value(i))).collect(); assert_eq!( row0, vec![Some(1), Some(2), Some(3)], @@ -1603,13 +1111,16 @@ mod test { #[test] fn test_nulls_snappy() { let file = arrow_test_data("avro/nulls.snappy.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); + let batch_large = read_file(&file, 8, false); + use arrow_array::{Int32Array, StructArray}; + use arrow_buffer::Buffer; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, Fields}; let b_c_int = Int32Array::from(vec![None; 8]); let b_c_int_data = b_c_int.into_data(); let b_struct_field = Field::new("b_c_int", DataType::Int32, true); - let b_struct_type = DataType::Struct(vec![b_struct_field].into()); - let struct_validity = arrow_buffer::Buffer::from_iter((0..8).map(|_| true)); + let b_struct_type = DataType::Struct(Fields::from(vec![b_struct_field])); + let struct_validity = Buffer::from_iter((0..8).map(|_| true)); let b_struct_data = ArrayDataBuilder::new(b_struct_type) .len(8) .null_bit_buffer(Some(struct_validity)) @@ -1617,21 +1128,21 @@ mod test { .build() .unwrap(); let b_struct_array = StructArray::from(b_struct_data); - - let expected = RecordBatch::try_from_iter_with_nullable([( + let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([( "b_struct", Arc::new(b_struct_array) as _, true, )]) .unwrap(); - assert_eq!(batch, expected); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); } #[test] fn test_repeated_no_annotation() { let file = arrow_test_data("avro/repeated_no_annotation.avro"); - let mut reader = read_file(&file, None); - let batch = reader.next().unwrap().unwrap(); + let batch_large = read_file(&file, 8, false); use arrow_array::{Int32Array, Int64Array, ListArray, StringArray, StructArray}; use arrow_buffer::Buffer; use arrow_data::ArrayDataBuilder; @@ -1677,7 +1188,7 @@ mod test { .build() .unwrap(); let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); - let expected = RecordBatch::try_from_iter_with_nullable([ + let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([ ("id", Arc::new(id_array) as _, true), ( "phoneNumbers", @@ -1686,11 +1197,19 @@ mod test { ), ]) .unwrap(); - assert_eq!(batch, expected); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); } #[test] fn test_simple() { + // Each entry: (filename, batch_size1, expected_batch, batch_size2) + let tests = [ + ("avro/simple_enum.avro", 4, build_expected_enum(), 2), + ("avro/simple_fixed.avro", 2, build_expected_fixed(), 1), + ]; + fn build_expected_enum() -> RecordBatch { let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); @@ -1748,38 +1267,29 @@ mod test { ) .unwrap() } - - // We list the two test files - let tests = [ - ("avro/simple_enum.avro", build_expected_enum()), - ("avro/simple_fixed.avro", build_expected_fixed()), - ]; - for (file_name, expected) in tests { + for 
(file_name, batch_size, expected, alt_batch_size) in tests { let file = arrow_test_data(file_name); - let mut reader = read_file(&file, None); - let actual = reader - .next() - .expect("Should have a batch") - .expect("Error reading batch"); - assert_eq!(actual, expected, "Mismatch for file {file_name}"); + let actual = read_file(&file, batch_size, false); + assert_eq!(actual, expected); + let actual2 = read_file(&file, alt_batch_size, false); + assert_eq!(actual2, expected); } } #[test] fn test_single_nan() { - let file = arrow_test_data("avro/single_nan.avro"); - let mut reader = read_file(&file, None); - let batch = reader - .next() - .expect("Should have a batch") - .expect("Error reading single_nan batch"); + let file = crate::test_util::arrow_test_data("avro/single_nan.avro"); + let actual = read_file(&file, 1, false); + use arrow_array::Float64Array; let schema = Arc::new(Schema::new(vec![Field::new( "mycol", DataType::Float64, true, )])); - let col = arrow_array::Float64Array::from(vec![None]); - let expected = RecordBatch::try_new(schema.clone(), vec![Arc::new(col)]).unwrap(); - assert_eq!(batch, expected, "Mismatch in single_nan.avro data"); + let col = Float64Array::from(vec![None]); + let expected = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap(); + assert_eq!(actual, expected); + let actual2 = read_file(&file, 2, false); + assert_eq!(actual2, expected); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index f901d85b6611..3f56997f5733 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -771,911 +771,3 @@ fn sign_extend(raw: &[u8], target_len: usize) -> Vec { out.extend_from_slice(raw); out } - -#[cfg(test)] -mod tests { - use super::*; - use crate::codec::AvroField; - use crate::schema::Schema; - use arrow_array::{cast::AsArray, Array, ListArray, MapArray, StructArray}; - use std::sync::Arc; - - fn encode_avro_int(value: i32) -> Vec { - let mut buf = Vec::new(); - let mut v = (value << 1) ^ (value >> 31); - while v & !0x7F != 0 { - buf.push(((v & 0x7F) | 0x80) as u8); - v >>= 7; - } - buf.push(v as u8); - buf - } - - fn encode_avro_long(value: i64) -> Vec { - let mut buf = Vec::new(); - let mut v = (value << 1) ^ (value >> 63); - while v & !0x7F != 0 { - buf.push(((v & 0x7F) | 0x80) as u8); - v >>= 7; - } - buf.push(v as u8); - buf - } - - fn encode_avro_bytes(bytes: &[u8]) -> Vec { - let mut out = encode_avro_long(bytes.len() as i64); - out.extend_from_slice(bytes); - out - } - - fn encode_union_branch(branch_idx: i32) -> Vec { - encode_avro_int(branch_idx) - } - - fn encode_array(items: &[T], mut encode_item: impl FnMut(&T) -> Vec) -> Vec { - let mut out = Vec::new(); - if !items.is_empty() { - out.extend_from_slice(&encode_avro_long(items.len() as i64)); - for it in items { - out.extend_from_slice(&encode_item(it)); - } - } - out.extend_from_slice(&encode_avro_long(0)); - out - } - - fn encode_map(entries: &[(&str, Vec)]) -> Vec { - let mut out = Vec::new(); - if !entries.is_empty() { - out.extend_from_slice(&encode_avro_long(entries.len() as i64)); - for (k, val) in entries { - out.extend_from_slice(&encode_avro_bytes(k.as_bytes())); - out.extend_from_slice(val); - } - } - out.extend_from_slice(&encode_avro_long(0)); - out - } - - #[test] - fn test_union_primitive_long_null_record_decoder() { - let json_schema = r#" - { - "type": "record", - "name": "topLevelRecord", - "fields": [ - { - "name": "id", - "type": ["long","null"] - } - ] - } - "#; - let schema: Schema = 
serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_avro_long(1)); - data.extend_from_slice(&encode_union_branch(1)); - let used = record_decoder.decode(&data, 2).unwrap(); - assert_eq!(used, data.len()); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 2); - let arr = batch.column(0).as_primitive::(); - assert_eq!(arr.value(0), 1); - assert!(arr.is_null(1)); - } - - #[test] - fn test_union_array_of_int_null_record_decoder() { - let json_schema = r#" - { - "type":"record", - "name":"topLevelRecord", - "fields":[ - { - "name":"int_array", - "type":[ - { - "type":"array", - "items":[ "int", "null" ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - - fn encode_int_or_null(opt_val: &Option) -> Vec { - match opt_val { - Some(v) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_avro_int(*v)); - out - } - None => encode_union_branch(1), - } - } - - data.extend_from_slice(&encode_union_branch(0)); - let row1_values = vec![Some(1), Some(2), Some(3)]; - data.extend_from_slice(&encode_array(&row1_values, encode_int_or_null)); - data.extend_from_slice(&encode_union_branch(0)); - let row2_values = vec![None, Some(1), Some(2), None, Some(3), None]; - data.extend_from_slice(&encode_array(&row2_values, encode_int_or_null)); - data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_avro_long(0)); - data.extend_from_slice(&encode_union_branch(1)); - record_decoder.decode(&data, 4).unwrap(); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 4); - let list_arr = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(list_arr.is_null(3)); - { - let start = list_arr.value_offsets()[0] as usize; - let end = list_arr.value_offsets()[1] as usize; - let child = list_arr.values().as_primitive::(); - assert_eq!(end - start, 3); - assert_eq!(child.value(start), 1); - assert_eq!(child.value(start + 1), 2); - assert_eq!(child.value(start + 2), 3); - } - { - let start = list_arr.value_offsets()[1] as usize; - let end = list_arr.value_offsets()[2] as usize; - let child = list_arr.values().as_primitive::(); - assert_eq!(end - start, 6); - // index-by-index - assert!(child.is_null(start)); // None - assert_eq!(child.value(start + 1), 1); // Some(1) - assert_eq!(child.value(start + 2), 2); - assert!(child.is_null(start + 3)); - assert_eq!(child.value(start + 4), 3); - assert!(child.is_null(start + 5)); - } - { - let start = list_arr.value_offsets()[2] as usize; - let end = list_arr.value_offsets()[3] as usize; - assert_eq!(end - start, 0); - } - } - - #[test] - fn test_union_nested_array_of_int_null_record_decoder() { - let json_schema = r#" - { - "type":"record", - "name":"topLevelRecord", - "fields":[ - { - "name":"int_array_Array", - "type":[ - { - "type":"array", - "items":[ - { - "type":"array", - "items":[ - "int", - "null" - ] - }, - "null" - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let 
mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - - fn encode_inner(vals: &[Option]) -> Vec { - encode_array(vals, |o| match o { - Some(v) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_avro_int(*v)); - out - } - None => encode_union_branch(1), - }) - } - - data.extend_from_slice(&encode_union_branch(0)); - { - let outer_vals: Vec>>> = - vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), None])]; - data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr { - Some(vlist) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_inner(vlist)); - out - } - None => encode_union_branch(1), - })); - } - data.extend_from_slice(&encode_union_branch(0)); - { - let outer_vals: Vec>>> = vec![None]; - data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr { - Some(vlist) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_inner(vlist)); - out - } - None => encode_union_branch(1), - })); - } - data.extend_from_slice(&encode_union_branch(1)); - record_decoder.decode(&data, 3).unwrap(); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 3); - let outer_list = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(outer_list.is_null(2)); - assert!(!outer_list.is_null(0)); - let start = outer_list.value_offsets()[0] as usize; - let end = outer_list.value_offsets()[1] as usize; - assert_eq!(end - start, 2); - let start2 = outer_list.value_offsets()[1] as usize; - let end2 = outer_list.value_offsets()[2] as usize; - assert_eq!(end2 - start2, 1); - let subitem_arr = outer_list.value(1); - let sub_list = subitem_arr.as_any().downcast_ref::().unwrap(); - assert_eq!(sub_list.len(), 1); - assert!(sub_list.is_null(0)); - } - - #[test] - fn test_union_map_of_int_null_record_decoder() { - let json_schema = r#" - { - "type":"record", - "name":"topLevelRecord", - "fields":[ - { - "name":"int_map", - "type":[ - { - "type":"map", - "values":[ - "int", - "null" - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - data.extend_from_slice(&encode_union_branch(0)); - let row1_map = vec![ - ("k1", { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_avro_int(1)); - out - }), - ("k2", encode_union_branch(1)), - ]; - data.extend_from_slice(&encode_map(&row1_map)); - data.extend_from_slice(&encode_union_branch(0)); - let empty: [(&str, Vec); 0] = []; - data.extend_from_slice(&encode_map(&empty)); - data.extend_from_slice(&encode_union_branch(1)); - record_decoder.decode(&data, 3).unwrap(); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 3); - let map_arr = batch.column(0).as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 3); - assert!(map_arr.is_null(2)); - assert_eq!(map_arr.value_length(0), 2); - let binding = map_arr.value(0); - let struct_arr = binding.as_any().downcast_ref::().unwrap(); - let keys = struct_arr.column(0).as_string::(); - let vals = struct_arr.column(1).as_primitive::(); - assert_eq!(keys.value(0), "k1"); - assert_eq!(vals.value(0), 1); - assert_eq!(keys.value(1), "k2"); - assert!(vals.is_null(1)); - assert_eq!(map_arr.value_length(1), 0); - } - - #[test] - fn 
test_union_map_array_of_int_null_record_decoder() { - let json_schema = r#" - { - "type": "record", - "name": "topLevelRecord", - "fields": [ - { - "name": "int_Map_Array", - "type": [ - { - "type": "array", - "items": [ - { - "type": "map", - "values": [ - "int", - "null" - ] - }, - "null" - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - fn encode_map_int_null(entries: &[(&str, Option)]) -> Vec { - let items: Vec<(&str, Vec)> = entries - .iter() - .map(|(k, v)| { - let val = match v { - Some(x) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_avro_int(*x)); - out - } - None => encode_union_branch(1), - }; - (*k, val) - }) - .collect(); - encode_map(&items) - } - data.extend_from_slice(&encode_union_branch(0)); - { - let mut arr_buf = encode_avro_long(1); - { - let mut item_buf = encode_union_branch(0); - item_buf.extend_from_slice(&encode_map_int_null(&[("k1", Some(1))])); - arr_buf.extend_from_slice(&item_buf); - } - arr_buf.extend_from_slice(&encode_avro_long(0)); - data.extend_from_slice(&arr_buf); - } - data.extend_from_slice(&encode_union_branch(0)); - { - let mut arr_buf = encode_avro_long(2); // 2 items - arr_buf.extend_from_slice(&encode_union_branch(1)); - { - let mut item1 = encode_union_branch(0); - item1.extend_from_slice(&encode_map_int_null(&[("k2", None)])); - arr_buf.extend_from_slice(&item1); - } - arr_buf.extend_from_slice(&encode_avro_long(0)); // end - data.extend_from_slice(&arr_buf); - } - data.extend_from_slice(&encode_union_branch(1)); - record_decoder.decode(&data, 3).unwrap(); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 3); - let outer_list = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(outer_list.is_null(2)); - { - let start = outer_list.value_offsets()[0] as usize; - let end = outer_list.value_offsets()[1] as usize; - assert_eq!(end - start, 1); - let subarr = outer_list.value(0); - let sublist = subarr.as_any().downcast_ref::().unwrap(); - assert_eq!(sublist.len(), 1); - assert!(!sublist.is_null(0)); - let sub_value_0 = sublist.value(0); - let struct_arr = sub_value_0.as_any().downcast_ref::().unwrap(); - let keys = struct_arr.column(0).as_string::(); - let vals = struct_arr.column(1).as_primitive::(); - assert_eq!(keys.value(0), "k1"); - assert_eq!(vals.value(0), 1); - } - } - - #[test] - fn test_union_nested_struct_out_of_spec_record_decoder() { - let json_schema = r#" - { - "type":"record", - "name":"topLevelRecord", - "fields":[ - { - "name":"nested_struct", - "type":[ - { - "type":"record", - "name":"nested_struct", - "namespace":"topLevelRecord", - "fields":[ - { - "name":"A", - "type":[ - "int", - "null" - ] - }, - { - "name":"b", - "type":[ - { - "type":"array", - "items":[ - "int", - "null" - ] - }, - "null" - ] - } - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); - let mut data = Vec::new(); - data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_avro_int(7)); - data.extend_from_slice(&encode_union_branch(0)); - let row1_b = [Some(1), 
Some(2)]; - data.extend_from_slice(&encode_array(&row1_b, |val| match val { - Some(x) => { - let mut out = encode_union_branch(0); - out.extend_from_slice(&encode_avro_int(*x)); - out - } - None => encode_union_branch(1), - })); - data.extend_from_slice(&encode_union_branch(0)); - data.extend_from_slice(&encode_union_branch(1)); - data.extend_from_slice(&encode_union_branch(1)); - data.extend_from_slice(&encode_union_branch(1)); - record_decoder.decode(&data, 3).unwrap(); - let batch = record_decoder.flush().unwrap(); - assert_eq!(batch.num_rows(), 3); - let col = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - assert!(col.is_null(2)); - let field_a = col.column(0).as_primitive::(); - let field_b = col.column(1).as_any().downcast_ref::().unwrap(); - assert_eq!(field_a.value(0), 7); - { - let start = field_b.value_offsets()[0] as usize; - let end = field_b.value_offsets()[1] as usize; - let values = field_b.values().as_primitive::(); - assert_eq!(end - start, 2); - assert_eq!(values.value(start), 1); - assert_eq!(values.value(start + 1), 2); - } - assert!(field_a.is_null(1)); - assert!(field_b.is_null(1)); - } - - #[test] - fn test_record_decoder_default_metadata() { - use crate::codec::AvroField; - use crate::schema::Schema; - let json_schema = r#" - { - "type": "record", - "name": "TestRecord", - "fields": [ - {"name": "default_int", "type": "int", "default": 42} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - let avro_record = AvroField::try_from(&schema).unwrap(); - let record_decoder = RecordDecoder::try_new(avro_record.data_type(), true).unwrap(); - let arrow_schema = record_decoder.schema(); - assert_eq!(arrow_schema.fields().len(), 1); - let field = arrow_schema.field(0); - let metadata = field.metadata(); - assert_eq!(metadata.get("avro.default").unwrap(), "42"); - } - - #[test] - fn test_fixed_decoding() { - let dt = AvroDataType::from_codec(Codec::Fixed(4)); - let mut dec = Decoder::try_new(&dt, true).unwrap(); - let row1 = [0xDE, 0xAD, 0xBE, 0xEF]; - let row2 = [0x01, 0x23, 0x45, 0x67]; - let mut data = Vec::new(); - data.extend_from_slice(&row1); - data.extend_from_slice(&row2); - let mut cursor = AvroCursor::new(&data); - dec.decode(&mut cursor).unwrap(); - dec.decode(&mut cursor).unwrap(); - let arr = dec.flush(None).unwrap(); - let fsb = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(fsb.len(), 2); - assert_eq!(fsb.value_length(), 4); - assert_eq!(fsb.value(0), row1); - assert_eq!(fsb.value(1), row2); - } - - #[test] - fn test_fixed_with_nulls() { - let dt = AvroDataType::from_codec(Codec::Fixed(2)); - let child = Decoder::try_new(&dt, true).unwrap(); - let mut dec = Decoder::Nullable( - UnionOrder::NullSecond, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(child), - ); - let row1 = [0x11, 0x22]; - let row3 = [0x55, 0x66]; - let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&row1); - data.extend_from_slice(&encode_avro_int(1)); - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&row3); - let mut cursor = AvroCursor::new(&data); - dec.decode(&mut cursor).unwrap(); // Row1 - dec.decode(&mut cursor).unwrap(); // Row2 (null) - dec.decode(&mut cursor).unwrap(); // Row3 - let arr = dec.flush(None).unwrap(); - let fsb = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(fsb.len(), 3); - assert!(fsb.is_valid(0)); - assert!(!fsb.is_valid(1)); - assert!(fsb.is_valid(2)); - assert_eq!(fsb.value_length(), 2); - assert_eq!(fsb.value(0), row1); - 
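// [Editorial aside, not part of the patch] The encode_avro_int /
// encode_avro_long helpers above apply Avro's zigzag-then-varint scheme.
// For reference, a hedged decode counterpart; `_sketch_read_zigzag_long`
// is an illustrative name of ours:
fn _sketch_read_zigzag_long(buf: &[u8]) -> Option<(i64, usize)> {
    let mut acc: u64 = 0;
    for (i, b) in buf.iter().take(10).enumerate() {
        acc |= u64::from(b & 0x7F) << (7 * i); // 7 payload bits per byte
        if b & 0x80 == 0 {
            // Undo zigzag: the (0, -1, 1, -2, ...) interleaving back to i64.
            let v = ((acc >> 1) as i64) ^ -((acc & 1) as i64);
            return Some((v, i + 1)); // decoded value and bytes consumed
        }
    }
    None // truncated or over-long varint
}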
assert_eq!(fsb.value(2), row3);
-    }
-
-    #[test]
-    fn test_interval_decoding() {
-        let dt = AvroDataType::from_codec(Codec::Duration);
-        let mut dec = Decoder::try_new(&dt, true).unwrap();
-        let row1 = [
-            0x01, 0x00, 0x00, 0x00, // months=1
-            0x02, 0x00, 0x00, 0x00, // days=2
-            0x64, 0x00, 0x00, 0x00, // ms=100
-        ];
-        let row2 = [
-            0xFF, 0xFF, 0xFF, 0xFF, // months=-1
-            0x0A, 0x00, 0x00, 0x00, // days=10
-            0x0F, 0x27, 0x00, 0x00, // ms=9999
-        ];
-        let mut data = Vec::new();
-        data.extend_from_slice(&row1);
-        data.extend_from_slice(&row2);
-        let mut cursor = AvroCursor::new(&data);
-        dec.decode(&mut cursor).unwrap();
-        dec.decode(&mut cursor).unwrap();
-        let arr = dec.flush(None).unwrap();
-        let intervals = arr
-            .as_any()
-            .downcast_ref::<IntervalMonthDayNanoArray>()
-            .unwrap();
-        assert_eq!(intervals.len(), 2);
-        let val0 = intervals.value(0);
-        assert_eq!(val0.months, 1);
-        assert_eq!(val0.days, 2);
-        assert_eq!(val0.nanoseconds, 100_000_000);
-        let val1 = intervals.value(1);
-        assert_eq!(val1.months, -1);
-        assert_eq!(val1.days, 10);
-        assert_eq!(val1.nanoseconds, 9_999_000_000);
-    }
-
-    #[test]
-    fn test_interval_decoding_with_nulls() {
-        let dt = AvroDataType::from_codec(Codec::Duration);
-        let child = Decoder::try_new(&dt, true).unwrap();
-        let mut dec = Decoder::Nullable(
-            UnionOrder::NullSecond,
-            NullBufferBuilder::new(DEFAULT_CAPACITY),
-            Box::new(child),
-        );
-        let row1 = [
-            0x02, 0x00, 0x00, 0x00, // months=2
-            0x03, 0x00, 0x00, 0x00, // days=3
-            0xF4, 0x01, 0x00, 0x00, // ms=500
-        ];
-        let mut data = Vec::new();
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&row1);
-        data.extend_from_slice(&encode_avro_int(1));
-        let mut cursor = AvroCursor::new(&data);
-        dec.decode(&mut cursor).unwrap(); // Row1
-        dec.decode(&mut cursor).unwrap(); // Row2 (null)
-        let arr = dec.flush(None).unwrap();
-        let intervals = arr
-            .as_any()
-            .downcast_ref::<IntervalMonthDayNanoArray>()
-            .unwrap();
-        assert_eq!(intervals.len(), 2);
-        assert!(intervals.is_valid(0));
-        assert!(!intervals.is_valid(1));
-        let val0 = intervals.value(0);
-        assert_eq!(val0.months, 2);
-        assert_eq!(val0.days, 3);
-        assert_eq!(val0.nanoseconds, 500_000_000);
-    }
-
-    #[test]
-    fn test_enum_decoding() {
-        let symbols = Arc::new(["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]);
-        let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols, Arc::new([])));
-        let mut decoder = Decoder::try_new(&enum_dt, true).unwrap();
-        let mut data = Vec::new();
-        data.extend_from_slice(&encode_avro_int(1));
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&encode_avro_int(2));
-        let mut cursor = AvroCursor::new(&data);
-        decoder.decode(&mut cursor).unwrap();
-        decoder.decode(&mut cursor).unwrap();
-        decoder.decode(&mut cursor).unwrap();
-        let array = decoder.flush(None).unwrap();
-        let dict_arr = array
-            .as_any()
-            .downcast_ref::<DictionaryArray<Int32Type>>()
-            .unwrap();
-        assert_eq!(dict_arr.len(), 3);
-        let keys = dict_arr.keys();
-        assert_eq!(keys.value(0), 1);
-        assert_eq!(keys.value(1), 0);
-        assert_eq!(keys.value(2), 2);
-        let dict_values = dict_arr.values().as_string::<i32>();
-        assert_eq!(dict_values.value(0), "RED");
-        assert_eq!(dict_values.value(1), "GREEN");
-        assert_eq!(dict_values.value(2), "BLUE");
-    }
-
-    #[test]
-    fn test_enum_decoding_with_nulls() {
-        let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()];
-        let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([])));
-        let inner_decoder = Decoder::try_new(&enum_dt, true).unwrap();
-        let mut nullable_decoder = Decoder::Nullable(
-
UnionOrder::NullSecond, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(inner_decoder), - ); - let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&encode_avro_int(1)); - data.extend_from_slice(&encode_avro_int(1)); - data.extend_from_slice(&encode_avro_int(0)); - data.extend_from_slice(&encode_avro_int(0)); - let mut cursor = AvroCursor::new(&data); - nullable_decoder.decode(&mut cursor).unwrap(); - nullable_decoder.decode(&mut cursor).unwrap(); - nullable_decoder.decode(&mut cursor).unwrap(); - let array = nullable_decoder.flush(None).unwrap(); - let dict_arr = array - .as_any() - .downcast_ref::>() - .unwrap(); - assert_eq!(dict_arr.len(), 3); - assert!(dict_arr.is_valid(0)); - assert!(!dict_arr.is_valid(1)); - assert!(dict_arr.is_valid(2)); - let dict_values = dict_arr.values().as_string::(); - assert_eq!(dict_values.value(0), "RED"); - assert_eq!(dict_values.value(1), "GREEN"); - assert_eq!(dict_values.value(2), "BLUE"); - } - - #[test] - fn test_map_decoding_one_entry() { - let value_type = AvroDataType::from_codec(Codec::String); - let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type, true).unwrap(); - let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); - data.extend_from_slice(&encode_avro_bytes(b"hello")); - data.extend_from_slice(&encode_avro_bytes(b"world")); - data.extend_from_slice(&encode_avro_long(0)); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - let array = decoder.flush(None).unwrap(); - let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); - assert_eq!(map_arr.value_length(0), 1); - let struct_arr = map_arr.value(0); - assert_eq!(struct_arr.len(), 1); - let keys = struct_arr.column(0).as_string::(); - let vals = struct_arr.column(1).as_string::(); - assert_eq!(keys.value(0), "hello"); - assert_eq!(vals.value(0), "world"); - } - - #[test] - fn test_map_decoding_empty() { - let value_type = AvroDataType::from_codec(Codec::String); - let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type, true).unwrap(); - let data = encode_avro_long(0); - decoder.decode(&mut AvroCursor::new(&data)).unwrap(); - let array = decoder.flush(None).unwrap(); - let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); - assert_eq!(map_arr.value_length(0), 0); - } - - #[test] - fn test_decimal_decoding_fixed128() { - let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); - let mut decoder = Decoder::try_new(&dt, true).unwrap(); - let row1 = [ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x30, 0x39, - ]; - let row2 = [ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0x85, - ]; - let mut data = Vec::new(); - data.extend_from_slice(&row1); - data.extend_from_slice(&row2); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - decoder.decode(&mut cursor).unwrap(); - let arr = decoder.flush(None).unwrap(); - let dec = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(dec.len(), 2); - assert_eq!(dec.value_as_string(0), "123.45"); - assert_eq!(dec.value_as_string(1), "-1.23"); - } - - #[test] - fn test_decimal_decoding_bytes_with_nulls() { - let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); - let inner = Decoder::try_new(&dt, 
true).unwrap();
-        let mut decoder = Decoder::Nullable(
-            UnionOrder::NullSecond,
-            NullBufferBuilder::new(DEFAULT_CAPACITY),
-            Box::new(inner),
-        );
-        let mut data = Vec::new();
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2]));
-        data.extend_from_slice(&encode_avro_int(1));
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E]));
-        let mut cursor = AvroCursor::new(&data);
-        decoder.decode(&mut cursor).unwrap(); // row1
-        decoder.decode(&mut cursor).unwrap(); // row2
-        decoder.decode(&mut cursor).unwrap(); // row3
-        let arr = decoder.flush(None).unwrap();
-        let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
-        assert_eq!(dec_arr.len(), 3);
-        assert!(dec_arr.is_valid(0));
-        assert!(!dec_arr.is_valid(1));
-        assert!(dec_arr.is_valid(2));
-        assert_eq!(dec_arr.value_as_string(0), "123.4");
-        assert_eq!(dec_arr.value_as_string(2), "-123.4");
-    }
-
-    #[test]
-    fn test_decimal_decoding_bytes_with_nulls_fixed_size() {
-        let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16)));
-        let inner = Decoder::try_new(&dt, true).unwrap();
-        let mut decoder = Decoder::Nullable(
-            UnionOrder::NullSecond,
-            NullBufferBuilder::new(DEFAULT_CAPACITY),
-            Box::new(inner),
-        );
-        let row1 = [
-            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
-            0xE2, 0x40,
-        ];
-        let row3 = [
-            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
-            0x1D, 0xC0,
-        ];
-        let mut data = Vec::new();
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&row1);
-        data.extend_from_slice(&encode_avro_int(1));
-        data.extend_from_slice(&encode_avro_int(0));
-        data.extend_from_slice(&row3);
-        let mut cursor = AvroCursor::new(&data);
-        decoder.decode(&mut cursor).unwrap();
-        decoder.decode(&mut cursor).unwrap();
-        decoder.decode(&mut cursor).unwrap();
-        let arr = decoder.flush(None).unwrap();
-        let dec_arr = arr.as_any().downcast_ref::<Decimal128Array>().unwrap();
-        assert_eq!(dec_arr.len(), 3);
-        assert!(dec_arr.is_valid(0));
-        assert!(!dec_arr.is_valid(1));
-        assert!(dec_arr.is_valid(2));
-        assert_eq!(dec_arr.value_as_string(0), "1234.56");
-        assert_eq!(dec_arr.value_as_string(2), "-1234.56");
-    }
-
-    #[test]
-    fn test_list_decoding() {
-        let item_dt = AvroDataType::from_codec(Codec::Int32);
-        let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt)));
-        let mut decoder = Decoder::try_new(&list_dt, true).unwrap();
-        let mut row1 = Vec::new();
-        row1.extend_from_slice(&encode_avro_long(2));
-        row1.extend_from_slice(&encode_avro_int(10));
-        row1.extend_from_slice(&encode_avro_int(20));
-        row1.extend_from_slice(&encode_avro_long(0));
-        let row2 = encode_avro_long(0);
-        let mut cursor = AvroCursor::new(&row1);
-        decoder.decode(&mut cursor).unwrap();
-        let mut cursor2 = AvroCursor::new(&row2);
-        decoder.decode(&mut cursor2).unwrap();
-        let array = decoder.flush(None).unwrap();
-        let list_arr = array.as_any().downcast_ref::<ListArray>().unwrap();
-        assert_eq!(list_arr.len(), 2);
-        let offsets = list_arr.value_offsets();
-        assert_eq!(offsets, &[0, 2, 2]);
-        let values = list_arr.values();
-        let int_arr = values.as_primitive::<Int32Type>();
-        assert_eq!(int_arr.len(), 2);
-        assert_eq!(int_arr.value(0), 10);
-        assert_eq!(int_arr.value(1), 20);
-    }
-
-    #[test]
-    fn test_list_decoding_with_negative_block_count() {
-        let item_dt = AvroDataType::from_codec(Codec::Int32);
-        let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt)));
-        let mut decoder =
Decoder::try_new(&list_dt, true).unwrap(); - let mut data = encode_avro_long(-3); - data.extend_from_slice(&encode_avro_long(12)); - data.extend_from_slice(&encode_avro_int(1)); - data.extend_from_slice(&encode_avro_int(2)); - data.extend_from_slice(&encode_avro_int(3)); - data.extend_from_slice(&encode_avro_long(0)); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - let array = decoder.flush(None).unwrap(); - let list_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(list_arr.len(), 1); - assert_eq!(list_arr.value_length(0), 3); - let values = list_arr.values().as_primitive::(); - assert_eq!(values.len(), 3); - assert_eq!(values.value(0), 1); - assert_eq!(values.value(1), 2); - assert_eq!(values.value(2), 3); - } -} diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index cd377c2d5078..6380eef5b839 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -587,181 +587,4 @@ mod tests { assert_eq!(schema, with_aliases); } - - #[test] - fn test_default_parsing() { - // Test that a default value is correctly parsed for a record field. - let json_schema = r#" - { - "type": "record", - "name": "TestRecord", - "fields": [ - {"name": "a", "type": "int", "default": 10}, - {"name": "b", "type": "string", "default": "default_str"}, - {"name": "c", "type": "boolean"} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - if let Schema::Complex(ComplexType::Record(rec)) = schema { - assert_eq!(rec.fields.len(), 3); - assert_eq!(rec.fields[0].default, Some(json!(10))); - assert_eq!(rec.fields[1].default, Some(json!("default_str"))); - assert_eq!(rec.fields[2].default, None); - } else { - panic!("Expected record schema"); - } - } - - #[test] - fn test_union_int_null_with_default_null() { - let json_schema = r#" - { - "type": "record", - "name": "ImpalaNullableRecord", - "fields": [ - {"name": "i", "type": ["int","null"], "default": null} - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - if let Schema::Complex(ComplexType::Record(rec)) = schema { - assert_eq!(rec.fields.len(), 1); - assert_eq!(rec.fields[0].name, "i"); - assert_eq!(rec.fields[0].default, Some(json!(null))); - let field_codec = - AvroField::try_from(&Schema::Complex(ComplexType::Record(rec))).unwrap(); - use arrow_schema::{DataType, Field, Fields}; - assert_eq!( - field_codec.field(), - Field::new( - "ImpalaNullableRecord", - DataType::Struct(Fields::from(vec![Field::new("i", DataType::Int32, true),])), - false - ) - ); - } else { - panic!("Expected record schema with union int|null, default null"); - } - } - - #[test] - fn test_union_impala_null_with_default_null() { - let json_schema = r#" - { - "type":"record","name":"topLevelRecord","fields":[ - {"name":"id","type":["long","null"]}, - {"name":"int_array","type":[{"type":"array","items":["int","null"]},"null"]}, - {"name":"int_array_Array","type":[{"type":"array","items":[{"type":"array","items":["int","null"]},"null"]},"null"]}, - {"name":"int_map","type":[{"type":"map","values":["int","null"]},"null"]}, - {"name":"int_Map_Array","type":[{"type":"array","items":[{"type":"map","values":["int","null"]},"null"]},"null"]}, - { - "name":"nested_struct", - "type":[ - { - "type":"record", - "name":"nested_struct", - "namespace":"topLevelRecord", - "fields":[ - {"name":"A","type":["int","null"]}, - {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]}, - { - "name":"C", - "type":[ - { - "type":"record", - "name":"C", - 
"namespace":"topLevelRecord.nested_struct", - "fields":[ - { - "name":"d", - "type":[ - { - "type":"array", - "items":[ - { - "type":"array", - "items":[ - { - "type":"record", - "name":"d", - "namespace":"topLevelRecord.nested_struct.C", - "fields":[ - {"name":"E","type":["int","null"]}, - {"name":"F","type":["string","null"]} - ] - }, - "null" - ] - }, - "null" - ] - }, - "null" - ] - } - ] - }, - "null" - ] - }, - { - "name":"g", - "type":[ - { - "type":"map", - "values":[ - { - "type":"record", - "name":"g", - "namespace":"topLevelRecord.nested_struct", - "fields":[ - { - "name":"H", - "type":[ - { - "type":"record", - "name":"H", - "namespace":"topLevelRecord.nested_struct.g", - "fields":[ - { - "name":"i", - "type":[ - { - "type":"array", - "items":["double","null"] - }, - "null" - ] - } - ] - }, - "null" - ] - } - ] - }, - "null" - ] - }, - "null" - ] - } - ] - }, - "null" - ] - } - ] - } - "#; - let schema: Schema = serde_json::from_str(json_schema).unwrap(); - if let Schema::Complex(ComplexType::Record(rec)) = &schema { - assert_eq!(rec.name, "topLevelRecord"); - assert_eq!(rec.fields.len(), 6); - let _field_codec = AvroField::try_from(&schema).unwrap(); - } else { - panic!("Expected top-level record schema"); - } - } } diff --git a/arrow-avro/test/data/nested_lists.snappy.avro b/arrow-avro/test/data/nested_lists.snappy.avro deleted file mode 100644 index 6cbff89610a7fce5f817edd668a06f5b5ac76a5b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 407 zcmeZI%3@>_ODrqO*DFrWNX<<=#$2sbQdy9yWTjM;nw(#hqNJmgmzWFUm*f}tq?V=T z1i{49GE;L>ij}OQt6@qKfvO?8fnrc&5{rrwD}myfC8@a(#iU9o6_*rc=B0yNQks*a z6kCgr0e4Fh!YxXfc_j$lv9$*IMd^Bp1&Kf(>lGIy7G>*|r4|)u=I3!4>lx}9iGaf+ zIX@*enWs1}v7n%mA>d!>+<&~!>)XFJ@3vpKVl~?g#(WkA7FH%3rbGs&BnAd12Bu^N z1_l-;MlOyN1_nkUBSj#OQIS=OQ-vW_P=$ewQJT|LRE33!fmJ~VNCIIRPy+*#>x<@m GbmIXq^L?5C diff --git a/arrow-avro/test/data/simple_enum.avro b/arrow-avro/test/data/simple_enum.avro deleted file mode 100644 index dbf0a42baae462801fa883bf5586dea8814b3df2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 411 zcmZ`#F$%&!5RAtev=>W@Eky*i3;w|eh{fe{(ZowGEz5-_6PVksBe65lhUZ*LvPmnn1}0WjS4TzMAHVkM-?% UIc)BYj>8#aF5}#BT*iJ34>m)7+yDRo From 3424300f33f5e064e1c082b29f51bef0fb0fc7f7 Mon Sep 17 00:00:00 2001 From: Connor Sanders Date: Sun, 16 Feb 2025 12:36:13 -0600 Subject: [PATCH 37/38] Fixed dev-dependency lint issue and removed some currently unused functions from codec.rs --- arrow-avro/Cargo.toml | 3 +-- arrow-avro/src/codec.rs | 18 ------------------ 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index 331efda5680d..b3fb6bfa9384 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -42,7 +42,6 @@ snappy = ["snap", "crc"] arrow-schema = { workspace = true } arrow-buffer = { workspace = true } arrow-array = { workspace = true } -arrow-data = { workspace = true } serde_json = { version = "1.0", default-features = false, features = ["std"] } serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } @@ -52,6 +51,6 @@ bzip2 = { version = "0.4.4", default-features = false, optional = true } xz = { version = "0.1.0", default-features = false, optional = true } crc = { version = "3.0", optional = true } - [dev-dependencies] +arrow-data = { workspace = true } rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } \ No newline at end 
of file diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 1c8df7d70421..57b2383c3d09 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -48,24 +48,6 @@ pub struct AvroDataType { } impl AvroDataType { - /// Create a new AvroDataType with the given parts. - pub fn new( - codec: Codec, - nullability: Option, - metadata: HashMap, - ) -> Self { - AvroDataType { - codec, - nullability, - metadata: Arc::new(metadata), - } - } - - /// Create a new AvroDataType from a `Codec`, with default (no) nullability and empty metadata. - pub fn from_codec(codec: Codec) -> Self { - Self::new(codec, None, Default::default()) - } - /// Returns an arrow [`Field`] with the given name, applying `nullability` if present. pub fn field_with_name(&self, name: &str) -> Field { let is_nullable = self.nullability.is_some(); From 7ffda2bb92d67e350a6f60fe99b7c9849a41cda9 Mon Sep 17 00:00:00 2001 From: Nathaniel Davis Date: Thu, 17 Apr 2025 09:31:55 -0500 Subject: [PATCH 38/38] adds check for empty fields in record flush --- arrow-avro/src/codec.rs | 8 +++++++- arrow-avro/src/reader/mod.rs | 27 +++++++++++++++++++++------ arrow-avro/src/reader/record.rs | 12 +++++++++--- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 57b2383c3d09..b39c74e483c2 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -345,9 +345,15 @@ fn make_data_type<'a>( Ok(rec) } ComplexType::Enum(e) => { + // Insert "avro.enum.symbols" into metadata so we can preserve it. + let mut md = e.attributes.field_metadata(); + if let Ok(symbols_json) = serde_json::to_string(&e.symbols) { + md.insert("avro.enum.symbols".to_string(), symbols_json); + } + let en = AvroDataType { nullability: None, - metadata: Arc::new(e.attributes.field_metadata()), + metadata: Arc::new(md), codec: Codec::Enum( Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), Arc::from(vec![]), diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 4d0cbb035088..b9cea79d1e3a 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -1211,6 +1211,7 @@ mod test { ]; fn build_expected_enum() -> RecordBatch { + // Build the DictionaryArrays for f1, f2, f3 let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); let f1_dict = @@ -1225,11 +1226,25 @@ mod test { DictionaryArray::::try_new(keys_f3, Arc::new(vals_f3)).unwrap(); let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); - let expected_schema = Arc::new(Schema::new(vec![ - Field::new("f1", dict_type.clone(), false), - Field::new("f2", dict_type.clone(), false), - Field::new("f3", dict_type.clone(), true), - ])); + let mut md_f1 = HashMap::new(); + md_f1.insert( + "avro.enum.symbols".to_string(), + r#"["a","b","c","d"]"#.to_string(), + ); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + let mut md_f2 = HashMap::new(); + md_f2.insert( + "avro.enum.symbols".to_string(), + r#"["e","f","g","h"]"#.to_string(), + ); + let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + let mut md_f3 = HashMap::new(); + md_f3.insert( + "avro.enum.symbols".to_string(), + r#"["i","j","k"]"#.to_string(), + ); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field])); RecordBatch::try_new( expected_schema, vec![ @@ -1278,7 +1293,7 @@ mod 
test {
     #[test]
     fn test_single_nan() {
-        let file = crate::test_util::arrow_test_data("avro/single_nan.avro");
+        let file = arrow_test_data("avro/single_nan.avro");
         let actual = read_file(&file, 1, false);
         use arrow_array::Float64Array;
         let schema = Arc::new(Schema::new(vec![Field::new(
diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs
index 3f56997f5733..517b3ac80750 100644
--- a/arrow-avro/src/reader/record.rs
+++ b/arrow-avro/src/reader/record.rs
@@ -17,7 +17,8 @@
 use crate::codec::{AvroDataType, Codec, Nullability};
 use crate::reader::cursor::AvroCursor;
-use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder};
+use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder, StructBuilder};
+use arrow_array::cast::AsArray;
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::*;
@@ -442,8 +443,13 @@ impl Decoder {
                     )));
                 }
             }
-            let sarr = StructArray::new(fields.clone(), child_arrays, nulls);
-            Ok(Arc::new(sarr) as Arc<dyn Array>)
+            match fields.is_empty() {
+                true => Ok(Arc::new(StructArray::new_empty_fields(1, None)) as Arc<dyn Array>),
+                false => Ok(
+                    Arc::new(StructArray::new(fields.clone(), child_arrays, nulls))
+                        as Arc<dyn Array>,
+                ),
+            }
         }
         Self::Enum(symbols, idxs) => {
             let dict_vals = StringArray::from_iter_values(symbols.iter());
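// [Editorial aside, not part of the patch] Why the empty-fields branch in
// the hunk above matters: a StructArray with no children cannot derive its
// row count from child lengths, so it must carry an explicit length. A
// hedged illustration (`_sketch_empty_struct` is our name, not the patch's):
fn _sketch_empty_struct(rows: usize) -> arrow_array::StructArray {
    use arrow_array::{Array, StructArray};
    let arr = StructArray::new_empty_fields(rows, None); // zero fields, explicit len
    assert_eq!(arr.len(), rows);
    arr
}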