diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml
index c60413c5939d..47fbb654e8cb 100644
--- a/arrow-avro/Cargo.toml
+++ b/arrow-avro/Cargo.toml
@@ -52,10 +52,13 @@ zstd = { version = "0.13", default-features = false, optional = true }
 crc = { version = "3.0", optional = true }
 
 [dev-dependencies]
-rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] }
+rand = { version = "0.9.1", default-features = false, features = ["std", "std_rng", "thread_rng"] }
 criterion = { version = "0.6.0", default-features = false }
 tempfile = "3.3"
 arrow = { workspace = true }
+futures = "0.3.31"
+bytes = "1.10.1"
+async-stream = "0.3.6"
 
 [[bench]]
 name = "avro_reader"
diff --git a/arrow-avro/benches/avro_reader.rs b/arrow-avro/benches/avro_reader.rs
index bea69b149138..2f2a3a10dbf3 100644
--- a/arrow-avro/benches/avro_reader.rs
+++ b/arrow-avro/benches/avro_reader.rs
@@ -20,7 +20,7 @@
 //! This benchmark suite compares the performance characteristics of StringArray vs
 //! StringViewArray across three key dimensions:
 //! 1. Array creation performance
-//! 2. String value access operations 
+//! 2. String value access operations
 //! 3. Avro file reading with each array type
 
 use std::fs::File;
@@ -31,7 +31,6 @@ use std::time::Duration;
 use arrow::array::RecordBatch;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow_array::{ArrayRef, Int32Array, StringArray, StringViewArray};
-use arrow_avro::ReadOptions;
 use arrow_schema::ArrowError;
 use criterion::*;
 use tempfile::NamedTempFile;
@@ -79,7 +78,7 @@ fn create_avro_test_file(row_count: usize, str_length: usize) -> Result<NamedTempFile, ArrowError> {
-fn read_avro_test_file(file_path: &str, options: &ReadOptions) -> Result<RecordBatch, ArrowError> {
+fn read_avro_test_file(file_path: &str, use_utf8view: bool) -> Result<RecordBatch, ArrowError> {
     let file = File::open(file_path)?;
     let mut reader = BufReader::new(file);
@@ -110,7 +109,7 @@
         ints.push(i32::from_le_bytes(int_bytes));
     }
 
-    let string_array: ArrayRef = if options.use_utf8view() {
+    let string_array: ArrayRef = if use_utf8view {
         Arc::new(StringViewArray::from_iter(
             strings.iter().map(|s| Some(s.as_str())),
         ))
@@ -123,7 +122,7 @@
     let int_array: ArrayRef = Arc::new(Int32Array::from(ints));
 
     let schema = Arc::new(Schema::new(vec![
-        if options.use_utf8view() {
+        if use_utf8view {
             Field::new("string_field", DataType::Utf8View, false)
         } else {
             Field::new("string_field", DataType::Utf8, false)
@@ -244,16 +243,14 @@ fn bench_avro_reader(c: &mut Criterion) {
 
     group.bench_function(format!("string_array_{str_length}_chars"), |b| {
         b.iter(|| {
-            let options = ReadOptions::default();
-            let batch = read_avro_test_file(file_path, &options).unwrap();
+            let batch = read_avro_test_file(file_path, false).unwrap();
             std::hint::black_box(batch)
         })
    });

     group.bench_function(format!("string_view_{str_length}_chars"), |b| {
         b.iter(|| {
-            let options = ReadOptions::default().with_utf8view(true);
-            let batch = read_avro_test_file(file_path, &options).unwrap();
+            let batch = read_avro_test_file(file_path, true).unwrap();
             std::hint::black_box(batch)
         })
     });
diff --git a/arrow-avro/examples/read_with_utf8view.rs b/arrow-avro/examples/read_with_utf8view.rs
index d79f8dad565d..707be575168a 100644
--- a/arrow-avro/examples/read_with_utf8view.rs
+++ b/arrow-avro/examples/read_with_utf8view.rs
@@ -22,13 +22,11 @@
 
 use std::env;
 use std::fs::File;
-use std::io::{BufReader, Seek, SeekFrom};
-use std::sync::Arc;
+use std::io::BufReader;
 use std::time::Instant;
 
-use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray};
-use arrow_avro::reader::ReadOptions;
-use arrow_schema::{ArrowError, DataType, Field, Schema};
+use arrow_array::{RecordBatch, StringArray, StringViewArray};
+use arrow_avro::reader::ReaderBuilder;
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     let args: Vec<String> = env::args().collect();
@@ -41,20 +39,26 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     };
 
     let file = File::open(file_path)?;
-    let mut reader = BufReader::new(file);
+    // Open a second handle for the Utf8View pass; File::try_clone would
+    // share the seek position and leave this handle at EOF after the first read.
+    let file_for_view = File::open(file_path)?;
 
     let start = Instant::now();
-    let batch = read_avro_with_options(&mut reader, &ReadOptions::default())?;
+    let reader = BufReader::new(file);
+    let avro_reader = ReaderBuilder::new().build(reader)?;
+    let schema = avro_reader.schema();
+    let batches: Vec<RecordBatch> = avro_reader.collect::<Result<_, _>>()?;
     let regular_duration = start.elapsed();
 
-    reader.seek(SeekFrom::Start(0))?;
-    let start = Instant::now();
-    let options = ReadOptions::default().with_utf8view(true);
-    let batch_view = read_avro_with_options(&mut reader, &options)?;
+    let start = Instant::now();
+    let reader_view = BufReader::new(file_for_view);
+    let avro_reader_view = ReaderBuilder::new()
+        .with_utf8_view(true)
+        .build(reader_view)?;
+    let batches_view: Vec<RecordBatch> = avro_reader_view.collect::<Result<_, _>>()?;
     let view_duration = start.elapsed();
 
-    println!("Read {} rows from {}", batch.num_rows(), file_path);
+    let num_rows = batches.iter().map(|b| b.num_rows()).sum::<usize>();
+
+    println!("Read {num_rows} rows from {file_path}");
     println!("Reading with StringArray: {regular_duration:?}");
     println!("Reading with StringViewArray: {view_duration:?}");
@@ -70,7 +74,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         );
     }
 
-    for (i, field) in batch.schema().fields().iter().enumerate() {
+    if batches.is_empty() {
+        println!("No data read from file.");
+        return Ok(());
+    }
+
+    // Inspect the first batch from each run to show the array types
+    let batch = &batches[0];
+    let batch_view = &batches_view[0];
+
+    for (i, field) in schema.fields().iter().enumerate() {
         let col = batch.column(i);
         let col_view = batch_view.column(i);
 
@@ -93,29 +106,3 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     Ok(())
 }
-
-fn read_avro_with_options(
-    reader: &mut BufReader<File>,
-    options: &ReadOptions,
-) -> Result<RecordBatch, ArrowError> {
-    reader.get_mut().seek(SeekFrom::Start(0))?;
-
-    let mock_schema = Schema::new(vec![
-        Field::new("string_field", DataType::Utf8, false),
-        Field::new("int_field", DataType::Int32, false),
-    ]);
-
-    let string_data = vec!["avro1", "avro2", "avro3", "avro4", "avro5"];
-    let int_data = vec![1, 2, 3, 4, 5];
-
-    let string_array: ArrayRef = if options.use_utf8view() {
-        Arc::new(StringViewArray::from(string_data))
-    } else {
-        Arc::new(StringArray::from(string_data))
-    };
-
-    let int_array: ArrayRef = Arc::new(Int32Array::from(int_data));
-
-    RecordBatch::try_new(Arc::new(mock_schema), vec![string_array, int_array])
-        .map_err(|e| ArrowError::ComputeError(format!("Failed to create record batch: {e}")))
-}
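The rewritten example reduces to the following pattern. A minimal sketch (the `data.avro` path and the printed columns are placeholders, not part of this PR):

```rust
use std::fs::File;
use std::io::BufReader;
use arrow_avro::reader::ReaderBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical input; any Avro container file works here.
    let file = File::open("data.avro")?;
    let reader = ReaderBuilder::new()
        .with_utf8_view(true) // strings become DataType::Utf8View
        .build(BufReader::new(file))?;
    let schema = reader.schema(); // capture before the iterator consumes the reader
    for batch in reader {
        let batch = batch?;
        for (i, field) in schema.fields().iter().enumerate() {
            // With utf8_view enabled, the string columns report Utf8View.
            println!("{}: {}", field.name(), batch.column(i).data_type());
        }
    }
    Ok(())
}
```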
diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs
index 8d7500b35c04..399037fdf9f7 100644
--- a/arrow-avro/src/codec.rs
+++ b/arrow-avro/src/codec.rs
@@ -16,7 +16,6 @@
 // under the License.
 
 use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName};
-use arrow_schema::DataType::{Decimal128, Decimal256};
 use arrow_schema::{
     ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef,
     TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
@@ -251,9 +250,9 @@ impl Codec {
             }
         };
         if too_large_for_128 {
-            Decimal256(p, s)
+            DataType::Decimal256(p, s)
         } else {
-            Decimal128(p, s)
+            DataType::Decimal128(p, s)
         }
     }
     Self::Uuid => DataType::FixedSizeBinary(16),
diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs
index e413e0aa9173..ae13c3861842 100644
--- a/arrow-avro/src/lib.rs
+++ b/arrow-avro/src/lib.rs
@@ -50,8 +50,6 @@ pub mod compression;
 /// Avro data types and Arrow data types.
 pub mod codec;
 
-pub use reader::ReadOptions;
-
 /// Extension trait for AvroField to add Utf8View support
 ///
 /// This trait adds methods for working with Utf8View support to the AvroField struct.
diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs
index 91026dbd6aed..2f8b3a2b9552 100644
--- a/arrow-avro/src/reader/mod.rs
+++ b/arrow-avro/src/reader/mod.rs
@@ -15,11 +15,84 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Read Avro data to Arrow
-
-use crate::reader::block::{Block, BlockDecoder};
-use crate::reader::header::{Header, HeaderDecoder};
-use arrow_schema::ArrowError;
+//! Avro reader
+//!
+//! This module provides facilities to read Apache Avro-encoded files or streams
+//! into Arrow's `RecordBatch` format. In particular, it introduces:
+//!
+//! * `ReaderBuilder`: Configures Avro reading, e.g., batch size
+//! * `Reader`: Yields `RecordBatch` values, implementing `Iterator`
+//! * `Decoder`: A low-level push-based decoder for Avro records
+//!
+//! # Basic Usage
+//!
+//! `Reader` can be used directly with synchronous data sources, such as [`std::fs::File`].
+//!
+//! ## Reading a Single Batch
+//!
+//! ```
+//! # use std::fs::File;
+//! # use std::io::BufReader;
+//! # use arrow_avro::reader::ReaderBuilder;
+//!
+//! let file = File::open("../testing/data/avro/alltypes_plain.avro").unwrap();
+//! let mut avro = ReaderBuilder::new().build(BufReader::new(file)).unwrap();
+//! let batch = avro.next().unwrap();
+//! ```
+//!
+//! # Async Usage
+//!
+//! The lower-level `Decoder` can be integrated with various forms of async data streams,
+//! and is designed to be agnostic to different async IO primitives within
+//! the Rust ecosystem. It works by incrementally decoding Avro data from byte slices.
+//!
+//! For example, see below for how it could be used with an arbitrary `Stream` of `Bytes`:
+//!
+//! ```
+//! # use std::task::{Poll, ready};
+//! # use bytes::{Buf, Bytes};
+//! # use arrow_schema::ArrowError;
+//! # use futures::stream::{Stream, StreamExt};
+//! # use arrow_array::RecordBatch;
+//! # use arrow_avro::reader::Decoder;
+//!
+//! fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
+//!     mut decoder: Decoder,
+//!     mut input: S,
+//! ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
+//!     let mut buffered = Bytes::new();
+//!     futures::stream::poll_fn(move |cx| {
+//!         loop {
+//!             if buffered.is_empty() {
+//!                 buffered = match ready!(input.poll_next_unpin(cx)) {
+//!                     Some(b) => b,
+//!                     None => break,
+//!                 };
+//!             }
+//!             let decoded = match decoder.decode(buffered.as_ref()) {
+//!                 Ok(decoded) => decoded,
+//!                 Err(e) => return Poll::Ready(Some(Err(e))),
+//!             };
+//!             let read = buffered.len();
+//!             buffered.advance(decoded);
+//!             if decoded != read {
+//!                 break
+//!             }
+//!         }
+//!         // Convert any fully-decoded rows to a RecordBatch, if available
+//!         Poll::Ready(decoder.flush().transpose())
+//!     })
+//! }
+//! ```
+//!
+
+use crate::codec::AvroField;
+use crate::schema::Schema as AvroSchema;
+use arrow_array::{RecordBatch, RecordBatchReader};
+use arrow_schema::{ArrowError, SchemaRef};
+use block::BlockDecoder;
+use header::{Header, HeaderDecoder};
+use record::RecordDecoder;
 use std::io::BufRead;
 
 mod block;
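The new `ReaderBuilder::build_decoder` below pairs with the `decode_stream` helper from the docs above. A minimal sketch of the wiring (assumptions: `decode_stream` is in scope, `serde_json` is available, and the one-field schema and hand-rolled row bytes mirror the new `test_decode_stream_with_schema` test rather than any real data source):

```rust
use std::io::Cursor;
use arrow_array::RecordBatch;
use arrow_avro::reader::ReaderBuilder;
use arrow_schema::ArrowError;
use bytes::Bytes;
use futures::{executor::block_on, stream, TryStreamExt};

fn main() -> Result<(), ArrowError> {
    // With an explicit schema there is no container-file header to parse,
    // so the reader argument can be an empty placeholder.
    let schema: arrow_avro::schema::Schema = serde_json::from_str(
        r#"{"type":"record","name":"t","fields":[{"name":"f","type":"string"}]}"#,
    )
    .unwrap();
    let decoder = ReaderBuilder::new()
        .with_batch_size(1024)
        .with_schema(schema)
        .build_decoder(&mut Cursor::new(&[] as &[u8]))?;
    // One Avro-encoded row: zig-zag varint length prefix, then UTF-8 bytes.
    let mut row = vec![(5u8) << 1];
    row.extend_from_slice(b"hello");
    // Box::pin makes the stream Unpin, as decode_stream requires.
    let input = Box::pin(stream::once(async { Bytes::from(row) }));
    let batches: Vec<RecordBatch> = block_on(decode_stream(decoder, input).try_collect())?;
    assert_eq!(batches[0].num_rows(), 1);
    Ok(())
}
```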
@@ -28,90 +101,292 @@
 mod header;
 mod record;
 mod vlq;
 
-/// Configuration options for reading Avro data into Arrow arrays
-///
-/// This struct contains configuration options that control how Avro data is
-/// converted into Arrow arrays. It allows customizing various aspects of the
-/// data conversion process.
-///
-/// # Examples
-///
-/// ```
-/// # use arrow_avro::reader::ReadOptions;
-/// // Use default options (regular StringArray for strings)
-/// let default_options = ReadOptions::default();
-///
-/// // Enable Utf8View support for better string performance
-/// let options = ReadOptions::default()
-///     .with_utf8view(true);
-/// ```
-#[derive(Default, Debug, Clone)]
-pub struct ReadOptions {
-    use_utf8view: bool,
+/// Read the Avro file header (magic, metadata, sync marker) from `reader`.
+fn read_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
+    let mut decoder = HeaderDecoder::default();
+    loop {
+        let buf = reader.fill_buf()?;
+        if buf.is_empty() {
+            break;
+        }
+        let read = buf.len();
+        let decoded = decoder.decode(buf)?;
+        reader.consume(decoded);
+        if decoded != read {
+            break;
+        }
+    }
+    decoder.flush().ok_or_else(|| {
+        ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string())
+    })
 }
 
-impl ReadOptions {
-    /// Create a new `ReadOptions` with default values
+/// A low-level interface for decoding Avro-encoded bytes into Arrow `RecordBatch`.
+#[derive(Debug)]
+pub struct Decoder {
+    record_decoder: RecordDecoder,
+    batch_size: usize,
+    decoded_rows: usize,
+}
+
+impl Decoder {
+    fn new(record_decoder: RecordDecoder, batch_size: usize) -> Self {
+        Self {
+            record_decoder,
+            batch_size,
+            decoded_rows: 0,
+        }
+    }
+
+    /// Return the Arrow schema for the rows decoded by this decoder
+    pub fn schema(&self) -> SchemaRef {
+        self.record_decoder.schema().clone()
+    }
+
+    /// Return the configured maximum number of rows per batch
+    pub fn batch_size(&self) -> usize {
+        self.batch_size
+    }
+
+    /// Feed `data` into the decoder row by row until we either:
+    /// - consume all bytes in `data`, or
+    /// - reach `batch_size` decoded rows.
+    ///
+    /// Returns the number of bytes consumed.
+    pub fn decode(&mut self, data: &[u8]) -> Result<usize, ArrowError> {
+        let mut total_consumed = 0usize;
+        while total_consumed < data.len() && self.decoded_rows < self.batch_size {
+            let consumed = self.record_decoder.decode(&data[total_consumed..], 1)?;
+            if consumed == 0 {
+                break;
+            }
+            total_consumed += consumed;
+            self.decoded_rows += 1;
+        }
+        Ok(total_consumed)
+    }
+
+    /// Produce a `RecordBatch` if at least one row is fully decoded, returning
+    /// `Ok(None)` if no new rows are available.
+    pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        if self.decoded_rows == 0 {
+            Ok(None)
+        } else {
+            let batch = self.record_decoder.flush()?;
+            self.decoded_rows = 0;
+            Ok(Some(batch))
+        }
+    }
+
+    /// Returns the number of rows that can be added to this decoder before it is full.
+    pub fn capacity(&self) -> usize {
+        self.batch_size.saturating_sub(self.decoded_rows)
+    }
+
+    /// Returns true if the decoder has reached its capacity for the current batch.
+    pub fn batch_is_full(&self) -> bool {
+        self.capacity() == 0
+    }
+}
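A synchronous counterpart to the async doc example, driving `decode`, `batch_is_full`, and `flush` by hand. A sketch only, not part of the PR; it assumes every chunk ends on a row boundary, since a zero-byte `decode` with spare capacity is treated here as "cannot make progress":

```rust
use arrow_array::RecordBatch;
use arrow_avro::reader::Decoder;
use arrow_schema::ArrowError;

// Drain pre-framed chunks of Avro row data through a Decoder, emitting a
// batch whenever `batch_size` rows have accumulated.
fn drain_chunks(
    decoder: &mut Decoder,
    chunks: impl IntoIterator<Item = Vec<u8>>,
) -> Result<Vec<RecordBatch>, ArrowError> {
    let mut out = Vec::new();
    for chunk in chunks {
        let mut offset = 0;
        while offset < chunk.len() {
            let consumed = decoder.decode(&chunk[offset..])?;
            offset += consumed;
            if decoder.batch_is_full() {
                // Full batch: flush restores capacity for the next decode call.
                if let Some(batch) = decoder.flush()? {
                    out.push(batch);
                }
            } else if consumed == 0 {
                break; // no progress and not full: partial row, give up on chunk
            }
        }
    }
    // Flush any partially filled final batch.
    if let Some(batch) = decoder.flush()? {
        out.push(batch);
    }
    Ok(out)
}
```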
 
+/// A builder to create an [`Avro Reader`](Reader) that reads Avro data
+/// into Arrow `RecordBatch`.
+#[derive(Debug)]
+pub struct ReaderBuilder {
+    batch_size: usize,
+    strict_mode: bool,
+    utf8_view: bool,
+    schema: Option<AvroSchema<'static>>,
+}
+
+impl Default for ReaderBuilder {
+    fn default() -> Self {
+        Self {
+            batch_size: 1024,
+            strict_mode: false,
+            utf8_view: false,
+            schema: None,
+        }
+    }
+}
+
+impl ReaderBuilder {
+    /// Creates a new [`ReaderBuilder`] with default settings:
+    /// - `batch_size` = 1024
+    /// - `strict_mode` = false
+    /// - `utf8_view` = false
+    /// - `schema` = None
     pub fn new() -> Self {
         Self::default()
     }
 
+    fn make_record_decoder(&self, schema: &AvroSchema<'_>) -> Result<RecordDecoder, ArrowError> {
+        let root_field = AvroField::try_from(schema)?;
+        RecordDecoder::try_new_with_options(
+            root_field.data_type(),
+            self.utf8_view,
+            self.strict_mode,
+        )
+    }
+
+    fn build_impl<R: BufRead>(self, reader: &mut R) -> Result<(Header, Decoder), ArrowError> {
+        let header = read_header(reader)?;
+        let record_decoder = if let Some(schema) = &self.schema {
+            self.make_record_decoder(schema)?
+        } else {
+            let avro_schema: Option<AvroSchema<'_>> = header
+                .schema()
+                .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+            let avro_schema = avro_schema.ok_or_else(|| {
+                ArrowError::ParseError("No Avro schema present in file header".to_string())
+            })?;
+            self.make_record_decoder(&avro_schema)?
+        };
+        let decoder = Decoder::new(record_decoder, self.batch_size);
+        Ok((header, decoder))
+    }
+
+    /// Sets the row-based batch size
+    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
+        self.batch_size = batch_size;
+        self
+    }
+
     /// Set whether to use StringViewArray for string data
     ///
     /// When enabled, string data from Avro files will be loaded into
     /// Arrow's StringViewArray instead of the standard StringArray.
-    pub fn with_utf8view(mut self, use_utf8view: bool) -> Self {
-        self.use_utf8view = use_utf8view;
+    pub fn with_utf8_view(mut self, utf8_view: bool) -> Self {
+        self.utf8_view = utf8_view;
         self
     }
 
     /// Get whether StringViewArray is enabled for string data
     pub fn use_utf8view(&self) -> bool {
-        self.use_utf8view
+        self.utf8_view
     }
-}
 
-/// Read a [`Header`] from the provided [`BufRead`]
-fn read_header<R: BufRead>(mut reader: R) -> Result<Header, ArrowError> {
-    let mut decoder = HeaderDecoder::default();
-    loop {
-        let buf = reader.fill_buf()?;
-        if buf.is_empty() {
-            break;
-        }
-        let read = buf.len();
-        let decoded = decoder.decode(buf)?;
-        reader.consume(decoded);
-        if decoded != read {
-            break
+    /// Controls whether certain Avro unions of the form `[T, "null"]` should produce an error.
+    pub fn with_strict_mode(mut self, strict_mode: bool) -> Self {
+        self.strict_mode = strict_mode;
+        self
+    }
+
+    /// Sets the Avro schema.
+    ///
+    /// If a schema is not provided, the schema will be read from the Avro file header.
+    pub fn with_schema(mut self, schema: AvroSchema<'static>) -> Self {
+        self.schema = Some(schema);
+        self
+    }
+
+    /// Create a [`Reader`] from this builder and a `BufRead`
+    pub fn build<R: BufRead>(self, mut reader: R) -> Result<Reader<R>, ArrowError> {
+        let (header, decoder) = self.build_impl(&mut reader)?;
+        Ok(Reader {
+            reader,
+            header,
+            decoder,
+            block_decoder: BlockDecoder::default(),
+            block_data: Vec::new(),
+            block_cursor: 0,
+            finished: false,
+        })
+    }
+
+    /// Create a [`Decoder`] from this builder and a `BufRead` by
+    /// reading and parsing the Avro file's header. This will
+    /// not create a full [`Reader`].
+    pub fn build_decoder<R: BufRead>(self, mut reader: R) -> Result<Decoder, ArrowError> {
+        match self.schema {
+            Some(ref schema) => {
+                let record_decoder = self.make_record_decoder(schema)?;
+                Ok(Decoder::new(record_decoder, self.batch_size))
+            }
+            None => {
+                let (_, decoder) = self.build_impl(&mut reader)?;
+                Ok(decoder)
+            }
         }
     }
+}
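Taken together, the builder surface now reads as follows. A minimal sketch (the path is a placeholder; `with_schema` is omitted here and shown with `build_decoder` above):

```rust
use std::fs::File;
use std::io::BufReader;
use arrow_avro::reader::ReaderBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.avro")?; // placeholder path
    let mut reader = ReaderBuilder::new()
        .with_batch_size(4096) // rows per RecordBatch
        .with_utf8_view(true) // Utf8View instead of Utf8 for strings
        .with_strict_mode(true) // error on certain `[T, "null"]` unions
        .build(BufReader::new(file))?;
    while let Some(batch) = reader.next().transpose()? {
        println!("{} rows", batch.num_rows());
    }
    Ok(())
}
```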
 
-    decoder
-        .flush()
-        .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string()))
+/// A high-level Avro `Reader` that reads container-file blocks
+/// and feeds them into a row-level [`Decoder`].
+#[derive(Debug)]
+pub struct Reader<R> {
+    reader: R,
+    header: Header,
+    decoder: Decoder,
+    block_decoder: BlockDecoder,
+    block_data: Vec<u8>,
+    block_cursor: usize,
+    finished: bool,
 }
 
-/// Return an iterator of [`Block`] from the provided [`BufRead`]
-fn read_blocks<R: BufRead>(mut reader: R) -> impl Iterator<Item = Result<Block, ArrowError>> {
-    let mut decoder = BlockDecoder::default();
+impl<R: BufRead> Reader<R> {
+    /// Return the Arrow schema discovered from the Avro file header
+    pub fn schema(&self) -> SchemaRef {
+        self.decoder.schema()
+    }
 
-    let mut try_next = move || {
-        loop {
-            let buf = reader.fill_buf()?;
-            if buf.is_empty() {
-                break;
+    /// Return the Avro container-file header
+    pub fn avro_header(&self) -> &Header {
+        &self.header
+    }
+
+    /// Reads the next [`RecordBatch`] from the Avro file or `Ok(None)` on EOF
+    fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+        'outer: while !self.finished && !self.decoder.batch_is_full() {
+            while self.block_cursor == self.block_data.len() {
+                let buf = self.reader.fill_buf()?;
+                if buf.is_empty() {
+                    self.finished = true;
+                    break 'outer;
+                }
+                // Try to decode another block from the buffered reader.
+                let consumed = self.block_decoder.decode(buf)?;
+                self.reader.consume(consumed);
+                if let Some(block) = self.block_decoder.flush() {
+                    // Successfully decoded a block.
+                    let block_data = if let Some(ref codec) = self.header.compression()? {
+                        codec.decompress(&block.data)?
+                    } else {
+                        block.data
+                    };
+                    self.block_data = block_data;
+                    self.block_cursor = 0;
+                } else if consumed == 0 {
+                    // The block decoder made no progress on a non-empty buffer.
+                    return Err(ArrowError::ParseError(
+                        "Could not decode next Avro block from partial data".to_string(),
+                    ));
+                }
             }
-            let read = buf.len();
-            let decoded = decoder.decode(buf)?;
-            reader.consume(decoded);
-            if decoded != read {
-                break;
+            // Try to decode more rows from the current block.
+            let consumed = self.decoder.decode(&self.block_data[self.block_cursor..])?;
+            if consumed == 0 && self.block_cursor < self.block_data.len() {
+                self.block_cursor = self.block_data.len();
+            } else {
+                self.block_cursor += consumed;
             }
         }
-        Ok(decoder.flush())
-    };
-    std::iter::from_fn(move || try_next().transpose())
+        self.decoder.flush()
+    }
+}
+
+impl<R: BufRead> Iterator for Reader<R> {
+    type Item = Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.read().transpose()
+    }
+}
+
+impl<R: BufRead> RecordBatchReader for Reader<R> {
+    fn schema(&self) -> SchemaRef {
+        self.schema()
+    }
 }
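Because `Reader` implements `RecordBatchReader`, it can be handed to any sink that is generic over that trait, with the schema available before the first batch is decoded. A sketch (placeholder path; `arrow` is available here as a dev-dependency, as the tests below rely on):

```rust
use std::fs::File;
use std::io::BufReader;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_avro::reader::ReaderBuilder;
use arrow_schema::ArrowError;

// Concatenate everything a RecordBatchReader produces into one batch,
// the same way the tests below use arrow::compute::concat_batches.
fn read_all(reader: impl RecordBatchReader) -> Result<RecordBatch, ArrowError> {
    let schema = reader.schema();
    let batches: Vec<RecordBatch> = reader.collect::<Result<_, _>>()?;
    arrow::compute::concat_batches(&schema, &batches)
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("data.avro")?; // placeholder path
    let reader = ReaderBuilder::new().build(BufReader::new(file))?;
    println!("{} total rows", read_all(reader)?.num_rows());
    Ok(())
}
```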
 
 #[cfg(test)]
@@ -119,61 +394,51 @@ mod test {
     use crate::codec::{AvroDataType, AvroField, Codec};
     use crate::compression::CompressionCodec;
     use crate::reader::record::RecordDecoder;
-    use crate::reader::{read_blocks, read_header};
+    use crate::reader::vlq::VLQDecoder;
+    use crate::reader::{read_header, Decoder, ReaderBuilder};
     use crate::test_util::arrow_test_data;
     use arrow_array::types::Int32Type;
     use arrow_array::*;
-    use arrow_schema::{DataType, Field, Schema};
+    use arrow_schema::{ArrowError, DataType, Field, Schema};
+    use bytes::{Buf, BufMut, Bytes};
+    use futures::executor::block_on;
+    use futures::{stream, Stream, StreamExt, TryStreamExt};
     use std::collections::HashMap;
+    use std::fs;
     use std::fs::File;
-    use std::io::BufReader;
+    use std::io::{BufReader, Cursor, Read};
     use std::sync::Arc;
-
-    fn read_file(file: &str, batch_size: usize) -> RecordBatch {
-        read_file_with_options(file, batch_size, &crate::ReadOptions::default())
+    use std::task::{ready, Poll};
+
+    fn read_file(path: &str, batch_size: usize, utf8_view: bool) -> RecordBatch {
+        let file = File::open(path).unwrap();
+        let reader = ReaderBuilder::new()
+            .with_batch_size(batch_size)
+            .with_utf8_view(utf8_view)
+            .build(BufReader::new(file))
+            .unwrap();
+        let schema = reader.schema();
+        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
+        arrow::compute::concat_batches(&schema, &batches).unwrap()
     }
 
-    fn read_file_with_options(
-        file: &str,
-        batch_size: usize,
-        options: &crate::ReadOptions,
-    ) -> RecordBatch {
-        let file = File::open(file).unwrap();
-        let mut reader = BufReader::new(file);
-        let header = read_header(&mut reader).unwrap();
-        let compression = header.compression().unwrap();
-        let schema = header.schema().unwrap().unwrap();
-        let root = AvroField::try_from(&schema).unwrap();
-
-        let mut decoder =
-            RecordDecoder::try_new_with_options(root.data_type(), options.clone()).unwrap();
-
-        for result in read_blocks(reader) {
-            let block = result.unwrap();
-            assert_eq!(block.sync, header.sync());
-
-            let mut decode_data = |data: &[u8]| {
-                let mut offset = 0;
-                let mut remaining = block.count;
-                while remaining > 0 {
-                    let to_read = remaining.min(batch_size);
-                    if to_read == 0 {
-                        break;
-                    }
-                    offset += decoder.decode(&data[offset..], to_read).unwrap();
-                    remaining -= to_read;
+    fn decode_stream<S: Stream<Item = Bytes> + Unpin>(
+        mut decoder: Decoder,
+        mut input: S,
+    ) -> impl Stream<Item = Result<RecordBatch, ArrowError>> {
+        async_stream::try_stream! {
+            if let Some(data) = input.next().await {
+                let consumed = decoder.decode(&data)?;
+                if consumed < data.len() {
+                    Err(ArrowError::ParseError(
+                        "did not consume all bytes".to_string(),
+                    ))?;
                 }
-                assert_eq!(offset, data.len());
-            };
-
-            if let Some(c) = compression {
-                let decompressed = c.decompress(&block.data).unwrap();
-                decode_data(&decompressed);
-            } else {
-                decode_data(&block.data);
+            }
+            if let Some(batch) = decoder.flush()? {
+                yield batch
             }
         }
-        decoder.flush().unwrap()
     }
 
     #[test]
@@ -311,8 +576,97 @@
 
         for file in files {
             let file = arrow_test_data(file);
-            assert_eq!(read_file(&file, 8), expected);
-            assert_eq!(read_file(&file, 3), expected);
+            assert_eq!(read_file(&file, 8, false), expected);
+            assert_eq!(read_file(&file, 3, false), expected);
         }
     }
+
+    #[test]
+    fn test_decode_stream_with_schema() {
+        struct TestCase<'a> {
+            name: &'a str,
+            schema: &'a str,
+            expected_error: Option<&'a str>,
+        }
+        let tests = vec![
+            TestCase {
+                name: "success",
+                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#,
+                expected_error: None,
+            },
+            TestCase {
+                name: "valid schema invalid data",
+                schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#,
+                expected_error: Some("did not consume all bytes"),
+            },
+        ];
+        for test in tests {
+            let schema_s2: crate::schema::Schema = serde_json::from_str(test.schema).unwrap();
+            let record_val = "some_string";
+            let mut body = vec![];
+            body.push((record_val.len() as u8) << 1);
+            body.extend_from_slice(record_val.as_bytes());
+            let mut reader_placeholder = Cursor::new(&[] as &[u8]);
+            let builder = ReaderBuilder::new()
+                .with_batch_size(1)
+                .with_schema(schema_s2);
+            let decoder_result = builder.build_decoder(&mut reader_placeholder);
+            let decoder = match decoder_result {
+                Ok(decoder) => decoder,
+                Err(e) => {
+                    if let Some(expected) = test.expected_error {
+                        assert!(
+                            e.to_string().contains(expected),
+                            "Test '{}' failed: unexpected error message at build.\nExpected to contain: '{expected}'\nActual: '{e}'",
+                            test.name,
+                        );
+                        continue;
+                    } else {
+                        panic!("Test '{}' failed at decoder build: {e}", test.name);
+                    }
+                }
+            };
+            let stream = Box::pin(stream::once(async { Bytes::from(body) }));
+            let decoded_stream = decode_stream(decoder, stream);
+            let batches_result: Result<Vec<RecordBatch>, ArrowError> =
+                block_on(decoded_stream.try_collect());
+            match (batches_result, test.expected_error) {
+                (Ok(batches), None) => {
+                    let batch =
+                        arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap();
+                    let expected_field = Field::new("f2", DataType::Utf8, false);
+                    let expected_schema = Arc::new(Schema::new(vec![expected_field]));
+                    let expected_array = Arc::new(StringArray::from(vec![record_val]));
+                    let expected_batch =
+                        RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap();
+                    assert_eq!(batch, expected_batch, "Test '{}' failed", test.name);
+                    assert_eq!(
+                        batch.schema().field(0).name(),
+                        "f2",
+                        "Test '{}' failed",
+                        test.name
+                    );
+                }
+                (Err(e), Some(expected)) => {
+                    assert!(
+                        e.to_string().contains(expected),
+                        "Test '{}' failed: unexpected error message at decode.\nExpected to contain: '{expected}'\nActual: '{e}'",
+                        test.name,
+                    );
+                }
+                (Ok(batches), Some(expected)) => {
+                    panic!(
+                        "Test '{}' was expected to fail with '{expected}', but it succeeded with: {:?}",
+                        test.name, batches
+                    );
+                }
+                (Err(e), None) => {
+                    panic!(
+                        "Test '{}' was not expected to fail, but it did with '{e}'",
+                        test.name
+                    );
+                }
+            }
+        }
+    }
 
@@ -327,7 +681,7 @@
         let decimal_values: Vec<i128> = (1..=24).map(|n| n as i128 * 100).collect();
         for (file, precision, scale) in files {
             let file_path = arrow_test_data(file);
-            let actual_batch = read_file(&file_path, 8);
+            let actual_batch = read_file(&file_path, 8, false);
             let expected_array = Decimal128Array::from_iter_values(decimal_values.clone())
                 .with_precision_and_scale(precision, scale)
                 .unwrap();
@@ -344,7 +698,7 @@
                 actual_batch, expected_batch,
"Decoded RecordBatch does not match the expected Decimal128 data for file {file}" ); - let actual_batch_small = read_file(&file_path, 3); + let actual_batch_small = read_file(&file_path, 3, false); assert_eq!( actual_batch_small, expected_batch, @@ -434,9 +788,9 @@ mod test { } for (file_name, batch_size, expected, alt_batch_size) in tests { let file = arrow_test_data(file_name); - let actual = read_file(&file, batch_size); + let actual = read_file(&file, batch_size, false); assert_eq!(actual, expected); - let actual2 = read_file(&file, alt_batch_size); + let actual2 = read_file(&file, alt_batch_size, false); assert_eq!(actual2, expected); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 8cb9c433e928..972a416a6a51 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -19,7 +19,6 @@ use crate::codec::{AvroDataType, Codec, Nullability}; use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; use crate::reader::header::Header; -use crate::reader::ReadOptions; use crate::schema::*; use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_array::types::*; @@ -36,35 +35,84 @@ use std::sync::Arc; const DEFAULT_CAPACITY: usize = 1024; +#[derive(Debug)] +pub(crate) struct RecordDecoderBuilder<'a> { + data_type: &'a AvroDataType, + use_utf8view: bool, + strict_mode: bool, +} + +impl<'a> RecordDecoderBuilder<'a> { + pub(crate) fn new(data_type: &'a AvroDataType) -> Self { + Self { + data_type, + use_utf8view: false, + strict_mode: false, + } + } + + pub(crate) fn with_utf8_view(mut self, use_utf8view: bool) -> Self { + self.use_utf8view = use_utf8view; + self + } + + pub(crate) fn with_strict_mode(mut self, strict_mode: bool) -> Self { + self.strict_mode = strict_mode; + self + } + + /// Builds the `RecordDecoder`. + pub(crate) fn build(self) -> Result { + RecordDecoder::try_new_with_options(self.data_type, self.use_utf8view, self.strict_mode) + } +} + /// Decodes avro encoded data into [`RecordBatch`] -pub struct RecordDecoder { +#[derive(Debug)] +pub(crate) struct RecordDecoder { schema: SchemaRef, fields: Vec, use_utf8view: bool, + strict_mode: bool, } impl RecordDecoder { + /// Creates a new `RecordDecoderBuilder` for configuring a `RecordDecoder`. + pub(crate) fn new(data_type: &'_ AvroDataType) -> Self { + RecordDecoderBuilder::new(data_type).build().unwrap() + } + /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with default options - pub fn try_new(data_type: &AvroDataType) -> Result { - Self::try_new_with_options(data_type, ReadOptions::default()) + pub(crate) fn try_new(data_type: &AvroDataType) -> Result { + RecordDecoderBuilder::new(data_type) + .with_utf8_view(true) + .with_strict_mode(true) + .build() } - /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options + /// Creates a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options. /// /// This method allows you to customize how the Avro data is decoded into Arrow arrays. /// - /// # Parameters - /// * `data_type` - The Avro data type to decode - /// * `options` - Configuration options for decoding - pub fn try_new_with_options( + /// # Arguments + /// * `data_type` - The Avro data type to decode. + /// * `use_utf8view` - A flag indicating whether to use `Utf8View` for string types. + /// * `strict_mode` - A flag to enable strict decoding, returning an error if the data + /// does not conform to the schema. 
+    ///
+    /// # Errors
+    /// This function will return an error if the provided `data_type` is not a `Record`.
+    pub(crate) fn try_new_with_options(
         data_type: &AvroDataType,
-        options: ReadOptions,
+        use_utf8view: bool,
+        strict_mode: bool,
     ) -> Result<Self, ArrowError> {
         match Decoder::try_new(data_type)? {
             Decoder::Record(fields, encodings) => Ok(Self {
                 schema: Arc::new(ArrowSchema::new(fields)),
                 fields: encodings,
-                use_utf8view: options.use_utf8view(),
+                use_utf8view,
+                strict_mode,
             }),
             encoding => Err(ArrowError::ParseError(format!(
                 "Expected record got {encoding:?}"
@@ -72,12 +120,13 @@ impl RecordDecoder {
         }
     }
 
-    pub fn schema(&self) -> &SchemaRef {
+    /// Returns the decoder's `SchemaRef`
+    pub(crate) fn schema(&self) -> &SchemaRef {
         &self.schema
     }
 
     /// Decode `count` records from `buf`
-    pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
+    pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
         let mut cursor = AvroCursor::new(buf);
         for _ in 0..count {
             for field in &mut self.fields {
@@ -88,7 +137,7 @@
     }
 
     /// Flush the decoded records into a [`RecordBatch`]
-    pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
+    pub(crate) fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
         let arrays = self
             .fields
            .iter_mut()
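For crate-internal callers, the intended use of the new builder mirrors `ReaderBuilder::make_record_decoder` in reader/mod.rs. A sketch under that assumption (`RecordDecoder` remains `pub(crate)`, so this only applies inside the crate):

```rust
use crate::codec::AvroField;
use crate::reader::record::{RecordDecoder, RecordDecoderBuilder};
use crate::schema::Schema;
use arrow_schema::ArrowError;

// Build a RecordDecoder for a parsed Avro schema, from inside the crate.
fn record_decoder_for(schema: &Schema<'_>, utf8_view: bool) -> Result<RecordDecoder, ArrowError> {
    let root = AvroField::try_from(schema)?;
    RecordDecoderBuilder::new(root.data_type())
        .with_utf8_view(utf8_view)
        .with_strict_mode(false)
        .build()
}
```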