diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs index c807d99d77d9..fb126aa9e6de 100644 --- a/arrow-select/src/coalesce.rs +++ b/arrow-select/src/coalesce.rs @@ -21,13 +21,13 @@ //! [`filter`]: crate::filter::filter //! [`take`]: crate::take::take use crate::concat::concat_batches; -use arrow_array::StringViewArray; +use crate::filter::filter_record_batch; use arrow_array::{cast::AsArray, Array, ArrayRef, RecordBatch}; +use arrow_array::{BooleanArray, StringViewArray}; use arrow_data::ByteView; use arrow_schema::{ArrowError, SchemaRef}; use std::collections::VecDeque; use std::sync::Arc; - // Originally From DataFusion's coalesce module: // https://github.com/apache/datafusion/blob/9d2f04996604e709ee440b65f41e7b882f50b788/datafusion/physical-plan/src/coalesce/mod.rs#L26-L25 @@ -155,9 +155,62 @@ impl BatchCoalescer { Arc::clone(&self.schema) } - /// Push next batch into the Coalescer + /// Push a batch into the Coalescer after applying a filter + /// + /// This is semantically equivalent of calling [`Self::push_batch`] + /// with the results from [`filter_record_batch`] + /// + /// # Example + /// # Example + /// ``` + /// # use arrow_array::{record_batch, BooleanArray}; + /// # use arrow_select::coalesce::BatchCoalescer; + /// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + /// let batch2 = record_batch!(("a", Int32, [4, 5, 6])).unwrap(); + /// // Apply a filter to each batch to pick the first and last row + /// let filter = BooleanArray::from(vec![true, false, true]); + /// // create a new Coalescer that targets creating 1000 row batches + /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000); + /// coalescer.push_batch_with_filter(batch1, &filter); + /// coalescer.push_batch_with_filter(batch2, &filter); + /// // finsh and retrieve the created batch + /// coalescer.finish_buffered_batch().unwrap(); + /// let completed_batch = coalescer.next_completed_batch().unwrap(); + /// // filtered out 2 and 5: + /// let expected_batch = record_batch!(("a", Int32, [1, 3, 4, 6])).unwrap(); + /// assert_eq!(completed_batch, expected_batch); + /// ``` + pub fn push_batch_with_filter( + &mut self, + batch: RecordBatch, + filter: &BooleanArray, + ) -> Result<(), ArrowError> { + // TODO: optimize this to avoid materializing (copying the results + // of filter to a new batch) + let filtered_batch = filter_record_batch(&batch, filter)?; + self.push_batch(filtered_batch) + } + + /// Push all the rows from `batch` into the Coalescer /// /// See [`Self::next_completed_batch()`] to retrieve any completed batches. + /// + /// # Example + /// ``` + /// # use arrow_array::record_batch; + /// # use arrow_select::coalesce::BatchCoalescer; + /// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap(); + /// let batch2 = record_batch!(("a", Int32, [4, 5, 6])).unwrap(); + /// // create a new Coalescer that targets creating 1000 row batches + /// let mut coalescer = BatchCoalescer::new(batch1.schema(), 1000); + /// coalescer.push_batch(batch1); + /// coalescer.push_batch(batch2); + /// // finsh and retrieve the created batch + /// coalescer.finish_buffered_batch().unwrap(); + /// let completed_batch = coalescer.next_completed_batch().unwrap(); + /// let expected_batch = record_batch!(("a", Int32, [1, 2, 3, 4, 5, 6])).unwrap(); + /// assert_eq!(completed_batch, expected_batch); + /// ``` pub fn push_batch(&mut self, batch: RecordBatch) -> Result<(), ArrowError> { if batch.num_rows() == 0 { // If the batch is empty, we don't need to do anything diff --git a/arrow/benches/coalesce_kernels.rs b/arrow/benches/coalesce_kernels.rs index 16db07e38875..1168d4b023cd 100644 --- a/arrow/benches/coalesce_kernels.rs +++ b/arrow/benches/coalesce_kernels.rs @@ -214,10 +214,9 @@ fn filter_streams( while num_output_batches > 0 { let filter = filter_stream.next_filter(); let batch = data_stream.next_batch(); - // Apply the filter to the input batch - let filtered_batch = arrow_select::filter::filter_record_batch(batch, filter).unwrap(); - // Add the filtered batch to the coalescer - coalescer.push_batch(filtered_batch).unwrap(); + coalescer + .push_batch_with_filter(batch.clone(), filter) + .unwrap(); // consume (but discard) the output batch if coalescer.next_completed_batch().is_some() { num_output_batches -= 1;