diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 52e35985698f..271ba6ddcff5 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -938,6 +938,11 @@ config_namespace! {
         /// HashJoin can work more efficiently than SortMergeJoin but consumes more memory
         pub prefer_hash_join: bool, default = true

+        /// When set to true, the piecewise merge join is enabled. PiecewiseMergeJoin is
+        /// currently experimental. The physical planner will opt for PiecewiseMergeJoin
+        /// when there is exactly one range filter in the join condition.
+        pub enable_piecewise_merge_join: bool, default = false
+
         /// The maximum estimated size in bytes for one input side of a HashJoin
         /// will be collected into a single partition
         pub hash_join_single_partition_threshold: usize, default = 1024 * 1024
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 0fa17deea129..bea51d31baac 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -78,10 +78,11 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
 use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
+use datafusion_expr::utils::split_conjunction;
 use datafusion_expr::{
-    Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
-    Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
-    WindowFrameBound, WriteOp,
+    Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
+    FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan,
+    WindowFrame, WindowFrameBound, WriteOp,
 };
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
 use datafusion_physical_expr::expressions::Literal;
@@ -91,6 +92,7 @@ use datafusion_physical_expr::{
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::execution_plan::InvariantLevel;
+use datafusion_physical_plan::joins::PiecewiseMergeJoinExec;
 use datafusion_physical_plan::metrics::MetricType;
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
 use datafusion_physical_plan::recursive_query::RecursiveQueryExec;
@@ -1133,8 +1135,42 @@ impl DefaultPhysicalPlanner {
                 })
                 .collect::>()?;

+            // TODO: `num_range_filters` can be used later on for ASOF joins (`num_range_filters > 1`)
+            let mut num_range_filters = 0;
+            let mut range_filters: Vec<Expr> = Vec::new();
+            let mut total_filters = 0;
+
             let join_filter = match filter {
                 Some(expr) => {
+                    let split_expr = split_conjunction(expr);
+                    for expr in split_expr.iter() {
+                        match *expr {
+                            Expr::BinaryExpr(BinaryExpr {
+                                left: _,
+                                right: _,
+                                op,
+                            }) => {
+                                if matches!(
+                                    op,
+                                    Operator::Lt
+                                        | Operator::LtEq
+                                        | Operator::Gt
+                                        | Operator::GtEq
+                                ) {
+                                    range_filters.push((**expr).clone());
+                                    num_range_filters += 1;
+                                }
+                                total_filters += 1;
+                            }
+                            // TODO: Handle `Expr::Between` for IEJoins; it counts as two range
+                            // predicates, which is why PWMJ does not deal with it
+                            // Expr::Between(_) => {},
+                            _ => {
+                                total_filters += 1;
+                            }
+                        }
+                    }
+
                    // Extract columns from the filter expression and save them in a HashSet
                    let cols = expr.column_refs();

@@ -1190,6 +1226,7 @@ impl DefaultPhysicalPlanner {
                     )?;
                     let filter_schema = Schema::new_with_metadata(filter_fields, metadata);
+
                     let filter_expr = create_physical_expr(
                         expr,
                         &filter_df_schema,
@@ -1212,10 +1249,125 @@ impl DefaultPhysicalPlanner {
             let prefer_hash_join =
                 session_state.config_options().optimizer.prefer_hash_join;

+            // TODO: Allow PWMJ to deal with residual equijoin conditions
             let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
                 if join_filter.is_none() && matches!(join_type, JoinType::Inner) {
                     // cross join if there is no join conditions and no join filter set
                     Arc::new(CrossJoinExec::new(physical_left, physical_right))
+                } else if num_range_filters == 1
+                    && total_filters == 1
+                    && !matches!(
+                        join_type,
+                        JoinType::LeftSemi
+                            | JoinType::RightSemi
+                            | JoinType::LeftAnti
+                            | JoinType::RightAnti
+                            | JoinType::LeftMark
+                            | JoinType::RightMark
+                    )
+                    && session_state
+                        .config_options()
+                        .optimizer
+                        .enable_piecewise_merge_join
+                {
+                    let Expr::BinaryExpr(be) = &range_filters[0] else {
+                        return plan_err!(
+                            "Unsupported expression for PWMJ: Expected `Expr::BinaryExpr`"
+                        );
+                    };
+
+                    let mut op = be.op;
+                    if !matches!(
+                        op,
+                        Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq
+                    ) {
+                        return plan_err!(
+                            "Unsupported operator for PWMJ: {:?}. Expected one of <, <=, >, >=",
+                            op
+                        );
+                    }
+
+                    fn reverse_ineq(op: Operator) -> Operator {
+                        match op {
+                            Operator::Lt => Operator::Gt,
+                            Operator::LtEq => Operator::GtEq,
+                            Operator::Gt => Operator::Lt,
+                            Operator::GtEq => Operator::LtEq,
+                            _ => op,
+                        }
+                    }
+
+                    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
+                    enum Side {
+                        Left,
+                        Right,
+                        Both,
+                    }
+
+                    let side_of = |e: &Expr| -> Result<Side> {
+                        let cols = e.column_refs();
+                        let any_left = cols
+                            .iter()
+                            .any(|c| left_df_schema.index_of_column(c).is_ok());
+                        let any_right = cols
+                            .iter()
+                            .any(|c| right_df_schema.index_of_column(c).is_ok());
+
+                        Ok(match (any_left, any_right) {
+                            (true, false) => Side::Left,
+                            (false, true) => Side::Right,
+                            // Columns from both inputs, or no columns at all (e.g. a
+                            // literal operand): no single side can be assigned, so we
+                            // fall back to `NestedLoopJoinExec` below
+                            _ => Side::Both,
+                        })
+                    };
+
+                    let mut lhs_logical = &be.left;
+                    let mut rhs_logical = &be.right;
+
+                    let left_side = side_of(lhs_logical)?;
+                    let right_side = side_of(rhs_logical)?;
+                    if matches!(left_side, Side::Both)
+                        || matches!(right_side, Side::Both)
+                    {
+                        return Ok(Arc::new(NestedLoopJoinExec::try_new(
+                            physical_left,
+                            physical_right,
+                            join_filter,
+                            join_type,
+                            None,
+                        )?));
+                    }
+
+                    if left_side == Side::Right && right_side == Side::Left {
+                        std::mem::swap(&mut lhs_logical, &mut rhs_logical);
+                        op = reverse_ineq(op);
+                    } else if !(left_side == Side::Left && right_side == Side::Right)
+                    {
+                        return plan_err!(
+                            "Unsupported join condition for PWMJ: both sides of {:?} reference the same input",
+                            op
+                        );
+                    }
+
+                    let on_left = create_physical_expr(
+                        lhs_logical,
+                        left_df_schema,
+                        session_state.execution_props(),
+                    )?;
+                    let on_right = create_physical_expr(
+                        rhs_logical,
+                        right_df_schema,
+                        session_state.execution_props(),
+                    )?;
+
+                    Arc::new(PiecewiseMergeJoinExec::try_new(
+                        physical_left,
+                        physical_right,
+                        (on_left, on_right),
+                        op,
+                        *join_type,
+                        session_state.config().target_partitions(),
+                    )?)
} else { // there is no equal join condition, use the nested loop join Arc::new(NestedLoopJoinExec::try_new( diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index adc00d9fe75e..88c50c2eb2ce 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -637,6 +637,7 @@ impl HashJoinStream { let (left_side, right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, + true, ); let empty_right_batch = RecordBatch::new_empty(self.right.schema()); // use the left and right indices to produce the batch result diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434..b0c28cf994f7 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -24,11 +24,13 @@ pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet +pub use piecewise_merge_join::PiecewiseMergeJoinExec; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; mod nested_loop_join; +mod piecewise_merge_join; mod sort_merge_join; mod stream_join_utils; mod symmetric_hash_join; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs new file mode 100644 index 000000000000..646905e0d787 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -0,0 +1,1550 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner) + +use arrow::array::{new_null_array, Array, PrimitiveBuilder}; +use arrow::compute::{take, BatchCoalescer}; +use arrow::datatypes::UInt32Type; +use arrow::{ + array::{ArrayRef, RecordBatch, UInt32Array}, + compute::{sort_to_indices, take_record_batch}, +}; +use arrow_schema::{Schema, SchemaRef, SortOptions}; +use datafusion_common::NullEquality; +use datafusion_common::{internal_err, Result}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::PhysicalExprRef; +use futures::{Stream, StreamExt}; +use std::{cmp::Ordering, task::ready}; +use std::{sync::Arc, task::Poll}; + +use crate::handle_state; +use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; +use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; +use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; +use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; + +pub(super) enum PiecewiseMergeJoinStreamState { + WaitBufferedSide, + FetchStreamBatch, + ProcessStreamBatch(SortedStreamBatch), + ProcessUnmatched, + Completed, +} + +impl PiecewiseMergeJoinStreamState { + // Grab mutable reference to the current stream batch + fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> { + match self { + PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), + _ => internal_err!("Expected streamed batch in StreamBatch"), + } + } +} + +/// The stream side incoming batch with required sort order. +/// +/// Note the compare key in the join predicate might include expressions on the original +/// columns, so we store the evaluated compare key separately. +/// e.g. For join predicate `buffer.v1 < (stream.v1 + 1)`, the `compare_key_values` field stores +/// the evaluated `stream.v1 + 1` array. 
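+///
+/// For illustration, with hypothetical values for the predicate above:
+///
+/// ```text
+/// stream.v1 (raw batch):        [3, 1, 2]
+/// compare key (stream.v1 + 1):  [4, 2, 3]
+///
+/// After sorting by the compare key, the batch and the key are reordered together:
+/// stream.v1: [1, 2, 3]          compare key: [2, 3, 4]
+/// ```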
+pub(super) struct SortedStreamBatch {
+    pub batch: RecordBatch,
+    compare_key_values: Vec<ArrayRef>,
+}
+
+impl SortedStreamBatch {
+    #[allow(dead_code)]
+    fn new(batch: RecordBatch, compare_key_values: Vec<ArrayRef>) -> Self {
+        Self {
+            batch,
+            compare_key_values,
+        }
+    }
+
+    fn compare_key_values(&self) -> &Vec<ArrayRef> {
+        &self.compare_key_values
+    }
+}
+
+pub(super) struct ClassicPWMJStream {
+    // Output schema of the `PiecewiseMergeJoin`
+    pub schema: Arc<Schema>,
+
+    // Physical expression that is evaluated on the streamed side.
+    // We do not need `on_buffered` as it is already evaluated when
+    // creating the buffered side, which happens before initializing
+    // `PiecewiseMergeJoinStream`
+    pub on_streamed: PhysicalExprRef,
+    // Type of join
+    pub join_type: JoinType,
+    // Comparison operator
+    pub operator: Operator,
+    // Streamed batch
+    pub streamed: SendableRecordBatchStream,
+    // Streamed schema
+    streamed_schema: SchemaRef,
+    // Buffered side data
+    buffered_side: BufferedSide,
+    // Tracks the state of the `PiecewiseMergeJoin`
+    state: PiecewiseMergeJoinStreamState,
+    // Sort option for the streamed side (specifies whether
+    // the sort is ascending or descending)
+    sort_option: SortOptions,
+    // Metrics for build + probe joins
+    join_metrics: BuildProbeJoinMetrics,
+    // Tracks incremental state for emitting record batches
+    batch_process_state: BatchProcessState,
+}
+
+impl RecordBatchStream for ClassicPWMJStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
+// `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`.
+//
+// Classic Joins
+// 1. `WaitBufferedSide` - Load the buffered side data into memory.
+// 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. When the stream is exhausted,
+//    the state switches to `Completed` if there are still remaining partitions to process;
+//    only the last partition to finish switches to `ProcessUnmatched`.
+// 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data.
+// 4. `ProcessUnmatched` - For Right and Inner joins the state moves straight to `Completed`;
+//    for Left and Full joins the unmatched buffered rows are emitted first.
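+//
+// Sketch of the transitions described above:
+//
+//   WaitBufferedSide ──▶ FetchStreamBatch ◀──▶ ProcessStreamBatch
+//                              │
+//                              ├──▶ ProcessUnmatched ──▶ Completed   (last partition)
+//                              └──▶ Completed                        (other partitions)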
+impl ClassicPWMJStream { + // Creates a new `PiecewiseMergeJoinStream` instance + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: Arc, + on_streamed: PhysicalExprRef, + join_type: JoinType, + operator: Operator, + streamed: SendableRecordBatchStream, + buffered_side: BufferedSide, + state: PiecewiseMergeJoinStreamState, + sort_option: SortOptions, + join_metrics: BuildProbeJoinMetrics, + batch_size: usize, + ) -> Self { + Self { + schema: Arc::clone(&schema), + on_streamed, + join_type, + operator, + streamed_schema: streamed.schema(), + streamed, + buffered_side, + state, + sort_option, + join_metrics, + batch_process_state: BatchProcessState::new(schema, batch_size), + } + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + PiecewiseMergeJoinStreamState::WaitBufferedSide => { + handle_state!(ready!(self.collect_buffered_side(cx))) + } + PiecewiseMergeJoinStreamState::FetchStreamBatch => { + handle_state!(ready!(self.fetch_stream_batch(cx))) + } + PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { + handle_state!(self.process_stream_batch()) + } + PiecewiseMergeJoinStreamState::ProcessUnmatched => { + handle_state!(self.process_unmatched_buffered_batch()) + } + PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + // Collects buffered side data + fn collect_buffered_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); + let buffered_data = ready!(self + .buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx))?; + build_timer.done(); + + // We will start fetching stream batches for classic joins + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + + self.buffered_side = + BufferedSide::Ready(BufferedSideReadyState { buffered_data }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Fetches incoming stream batches + fn fetch_stream_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.streamed.poll_next_unpin(cx)) { + None => { + if self + .buffered_side + .try_as_ready_mut()? + .buffered_data + .remaining_partitions + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + self.batch_process_state.reset(); + self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched; + } else { + self.state = PiecewiseMergeJoinStreamState::Completed; + } + } + Some(Ok(batch)) => { + // Evaluate the streamed physical expression on the stream batch + let stream_values: ArrayRef = self + .on_streamed + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + // Sort stream values and change the streamed record batch accordingly + let indices = sort_to_indices( + stream_values.as_ref(), + Some(self.sort_option), + None, + )?; + let stream_batch = take_record_batch(&batch, &indices)?; + let stream_values = take(stream_values.as_ref(), &indices, None)?; + + // Reset BatchProcessState before processing a new stream batch + self.batch_process_state.reset(); + self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( + SortedStreamBatch { + batch: stream_batch, + compare_key_values: vec![stream_values], + }, + ); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Only classic join will call. 
This function will process stream batches and evaluate against + // the buffered side data. + fn process_stream_batch( + &mut self, + ) -> Result>> { + let buffered_side = self.buffered_side.try_as_ready_mut()?; + let stream_batch = self.state.try_as_process_stream_batch_mut()?; + + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + // Produce more work + let batch = resolve_classic_join( + buffered_side, + stream_batch, + Arc::clone(&self.schema), + self.operator, + self.sort_option, + self.join_type, + &mut self.batch_process_state, + )?; + + if !self.batch_process_state.continue_process { + // We finished scanning this stream batch. + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(b) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + return Ok(StatefulStreamResult::Ready(Some(b))); + } + + // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. + if self.batch_process_state.output_batches.is_empty() { + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + } + + Ok(StatefulStreamResult::Ready(Some(batch))) + } + + // Process remaining unmatched rows + fn process_unmatched_buffered_batch( + &mut self, + ) -> Result>> { + // Return early for `JoinType::Right` and `JoinType::Inner` + if matches!(self.join_type, JoinType::Right | JoinType::Inner) { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + } + + if !self.batch_process_state.continue_process { + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + } + + let buffered_data = + Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + + let (buffered_indices, _streamed_indices) = get_final_indices_from_shared_bitmap( + &buffered_data.visited_indices_bitmap, + self.join_type, + true, + ); + + let new_buffered_batch = + take_record_batch(buffered_data.batch(), &buffered_indices)?; + let mut buffered_columns = new_buffered_batch.columns().to_vec(); + + let streamed_columns: Vec = self + .streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), new_buffered_batch.num_rows())) + .collect(); + + buffered_columns.extend(streamed_columns); + + let batch = RecordBatch::try_new(Arc::clone(&self.schema), buffered_columns)?; + + self.batch_process_state.output_batches.push_batch(batch)?; + + self.batch_process_state.continue_process = false; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.state = PiecewiseMergeJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.state = 
PiecewiseMergeJoinStreamState::Completed;
+        self.batch_process_state.reset();
+        Ok(StatefulStreamResult::Ready(None))
+    }
+}
+
+struct BatchProcessState {
+    // Coalesces matched rows into output batches of the configured size
+    output_batches: Box<BatchCoalescer>,
+    // Used to store the unmatched stream indices for `JoinType::Right` and `JoinType::Full`
+    unmatched_indices: PrimitiveBuilder<UInt32Type>,
+    // Used to store the start index on the buffered side; used to resume processing on the correct
+    // row
+    start_buffer_idx: usize,
+    // Used to store the start index on the stream side; used to resume processing on the correct
+    // row
+    start_stream_idx: usize,
+    // Signals if we found a match for the current stream row
+    found: bool,
+    // Signals to continue processing the current stream batch
+    continue_process: bool,
+    // Signals that the leading nulls on both sides have already been skipped
+    processed_null_count: bool,
+}
+
+impl BatchProcessState {
+    pub(crate) fn new(schema: Arc<Schema>, batch_size: usize) -> Self {
+        Self {
+            output_batches: Box::new(BatchCoalescer::new(schema, batch_size)),
+            unmatched_indices: PrimitiveBuilder::new(),
+            start_buffer_idx: 0,
+            start_stream_idx: 0,
+            found: false,
+            continue_process: true,
+            processed_null_count: false,
+        }
+    }
+
+    pub(crate) fn reset(&mut self) {
+        self.unmatched_indices = PrimitiveBuilder::new();
+        self.start_buffer_idx = 0;
+        self.start_stream_idx = 0;
+        self.found = false;
+        self.continue_process = true;
+        self.processed_null_count = false;
+    }
+}
+
+impl Stream for ClassicPWMJStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
+// For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted.
+#[allow(clippy::too_many_arguments)]
+fn resolve_classic_join(
+    buffered_side: &mut BufferedSideReadyState,
+    stream_batch: &SortedStreamBatch,
+    join_schema: Arc<Schema>,
+    operator: Operator,
+    sort_options: SortOptions,
+    join_type: JoinType,
+    batch_process_state: &mut BatchProcessState,
+) -> Result<RecordBatch> {
+    let buffered_len = buffered_side.buffered_data.values().len();
+    let stream_values = stream_batch.compare_key_values();
+
+    let mut buffer_idx = batch_process_state.start_buffer_idx;
+    let mut stream_idx = batch_process_state.start_stream_idx;
+
+    if !batch_process_state.processed_null_count {
+        let buffered_null_idx = buffered_side.buffered_data.values().null_count();
+        let stream_null_idx = stream_values[0].null_count();
+        buffer_idx = buffered_null_idx;
+        stream_idx = stream_null_idx;
+        batch_process_state.processed_null_count = true;
+    }
+
+    // Our buffer_idx variable allows us to start probing on the buffered side where we last matched
+    // in the previous stream row.
+    for row_idx in stream_idx..stream_batch.batch.num_rows() {
+        while buffer_idx < buffered_len {
+            let compare = {
+                let buffered_values = buffered_side.buffered_data.values();
+                compare_join_arrays(
+                    &[Arc::clone(&stream_values[0])],
+                    row_idx,
+                    &[Arc::clone(buffered_values)],
+                    buffer_idx,
+                    &[sort_options],
+                    NullEquality::NullEqualsNothing,
+                )?
+            };
+
+            // If we find a match, we append all indices and move to the next stream row index
+            match operator {
+                Operator::Gt | Operator::Lt => {
+                    if matches!(compare, Ordering::Less) {
+                        batch_process_state.found = true;
+                        let count = buffered_len - buffer_idx;
+
+                        let batch = build_matched_indices_and_set_buffered_bitmap(
+                            (buffer_idx, count),
+                            (row_idx, count),
+                            buffered_side,
+                            stream_batch,
+                            join_type,
+                            Arc::clone(&join_schema),
+                        )?;
+
+                        batch_process_state.output_batches.push_batch(batch)?;
+
+                        // Flush batch and update pointers if we have a completed batch
+                        if let Some(batch) =
+                            batch_process_state.output_batches.next_completed_batch()
+                        {
+                            batch_process_state.found = false;
+                            batch_process_state.start_buffer_idx = buffer_idx;
+                            batch_process_state.start_stream_idx = row_idx + 1;
+                            return Ok(batch);
+                        }
+
+                        break;
+                    }
+                }
+                Operator::GtEq | Operator::LtEq => {
+                    if matches!(compare, Ordering::Equal | Ordering::Less) {
+                        batch_process_state.found = true;
+                        let count = buffered_len - buffer_idx;
+                        let batch = build_matched_indices_and_set_buffered_bitmap(
+                            (buffer_idx, count),
+                            (row_idx, count),
+                            buffered_side,
+                            stream_batch,
+                            join_type,
+                            Arc::clone(&join_schema),
+                        )?;
+
+                        // Flush batch and update pointers if we have a completed batch
+                        batch_process_state.output_batches.push_batch(batch)?;
+                        if let Some(batch) =
+                            batch_process_state.output_batches.next_completed_batch()
+                        {
+                            batch_process_state.found = false;
+                            batch_process_state.start_buffer_idx = buffer_idx;
+                            batch_process_state.start_stream_idx = row_idx + 1;
+                            return Ok(batch);
+                        }
+
+                        break;
+                    }
+                }
+                _ => {
+                    return internal_err!(
+                        "PiecewiseMergeJoin should not contain operator {}",
+                        operator
+                    )
+                }
+            };
+
+            // Increment buffer_idx after every row
+            buffer_idx += 1;
+        }
+
+        // If a match was not found for the current stream row, its index is appended
+        // to the unmatched indices to be flushed later.
+        if matches!(join_type, JoinType::Right | JoinType::Full)
+            && !batch_process_state.found
+        {
+            batch_process_state
+                .unmatched_indices
+                .append_value(row_idx as u32);
+        }
+
+        batch_process_state.found = false;
+    }
+
+    // Flush all unmatched indices on the streamed side
+    if matches!(join_type, JoinType::Right | JoinType::Full) {
+        let batch = create_unmatched_batch(
+            &mut batch_process_state.unmatched_indices,
+            stream_batch,
+            Arc::clone(&join_schema),
+        )?;
+
+        batch_process_state.output_batches.push_batch(batch)?;
+    }
+
+    batch_process_state.continue_process = false;
+    Ok(RecordBatch::new_empty(Arc::clone(&join_schema)))
+}
+
+// Builds a record batch from index ranges on the buffered and streamed side.
+//
+// The two ranges are buffered_range: (start index, count) and streamed_range: (start index, count),
+// matching batch.slice(start, count).
+fn build_matched_indices_and_set_buffered_bitmap( + buffered_range: (usize, usize), + streamed_range: (usize, usize), + buffered_side: &mut BufferedSideReadyState, + stream_batch: &SortedStreamBatch, + join_type: JoinType, + join_schema: Arc, +) -> Result { + // Mark the buffered indices as visited + if need_produce_result_in_final(join_type) { + let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock(); + for i in buffered_range.0..buffered_range.0 + buffered_range.1 { + bitmap.set_bit(i, true); + } + } + + let new_buffered_batch = buffered_side + .buffered_data + .batch() + .slice(buffered_range.0, buffered_range.1); + let mut buffered_columns = new_buffered_batch.columns().to_vec(); + + let indices = UInt32Array::from_value(streamed_range.0 as u32, streamed_range.1); + let new_stream_batch = take_record_batch(&stream_batch.batch, &indices)?; + let streamed_columns = new_stream_batch.columns().to_vec(); + + buffered_columns.extend(streamed_columns); + + Ok(RecordBatch::try_new( + Arc::clone(&join_schema), + buffered_columns, + )?) +} + +// Creates a record batch from the unmatched indices on the streamed side +fn create_unmatched_batch( + streamed_indices: &mut PrimitiveBuilder, + stream_batch: &SortedStreamBatch, + join_schema: Arc, +) -> Result { + let streamed_indices = streamed_indices.finish(); + let new_stream_batch = take_record_batch(&stream_batch.batch, &streamed_indices)?; + let streamed_columns = new_stream_batch.columns().to_vec(); + let buffered_cols_len = join_schema.fields().len() - streamed_columns.len(); + + let num_rows = new_stream_batch.num_rows(); + let mut buffered_columns: Vec = join_schema + .fields() + .iter() + .take(buffered_cols_len) + .map(|field| new_null_array(field.data_type(), num_rows)) + .collect(); + + buffered_columns.extend(streamed_columns); + + Ok(RecordBatch::try_new( + Arc::clone(&join_schema), + buffered_columns, + )?) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + common, + joins::PiecewiseMergeJoinExec, + test::{build_table_i32, TestMemoryExec}, + ExecutionPlan, + }; + use arrow::array::{Date32Array, Date64Array}; + use arrow_schema::{DataType, Field}; + use datafusion_common::test_util::batches_to_string; + use datafusion_execution::TaskContext; + use datafusion_expr::JoinType; + use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use insta::assert_snapshot; + use std::sync::Arc; + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn join( + left: Arc, + right: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options(left, right, on, operator, join_type).await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = join(left, right, on, operator, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + 
let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 20 | 3 | 80 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_unsorted() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 10 | 3 | 70 | + | 3 | 1 | 9 | 10 | 3 | 70 | + | 3 | 1 | 9 | 20 | 2 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 2 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![2, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 2 | 7 | 30 | 1 | 90 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 1 | 2 | 7 | 20 | 2 | 80 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 2 | 3 | 8 | 10 | 3 | 70 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_left() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // (empty) + // +----+----+----+ + let left = build_table( + ("a1", &Vec::::new()), + ("b1", &Vec::::new()), + ("c1", &Vec::::new()), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 1 | 1 | 1 | + // | 2 | 2 | 2 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c2", &vec![1, 2]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_full_greater_than_equal_to() -> Result<()> { + // +----+----+-----+ + // | a1 | b1 | c1 | + // +----+----+-----+ + // | 1 | 1 | 100 | + // | 2 | 2 | 200 | + // +----+----+-----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![100, 200]), + ); + + // +----+----+-----+ + // | a2 | b1 | c2 | + // +----+----+-----+ + // | 10 | 3 | 300 | + // | 20 | 2 | 400 | + // +----+----+-----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 2]), + ("c2", &vec![300, 400]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | | | | 10 | 3 | 300 | + | 1 | 1 | 100 | | | | + +----+----+-----+----+----+-----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 10 | 3 | 70 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 5 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![5, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 2 | 90 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 20 | 3 | 80 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 3, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 1 | 9 | 30 | 5 | 90 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_equal_with_dups() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 4 | 8 | + // | 3 | 2 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 4, 2]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 4 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) 
as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + + // Expected grouping follows right.b1 descending (4, 3, 2) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 4 | 8 | 10 | 4 | 70 | + | 3 | 2 | 9 | 10 | 4 | 70 | + | 3 | 2 | 9 | 20 | 3 | 80 | + | 3 | 2 | 9 | 30 | 2 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_unsorted_right() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 4]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 1 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + // Grouped by right in ascending evaluation for > (1,2,3) + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 2 | 8 | 20 | 1 | 80 | + | 3 | 4 | 9 | 20 | 1 | 80 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_less_than_equal_with_left_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // +----+----+----+ + let right = build_table(("a2", &vec![10]), ("b1", &vec![3]), ("c2", &vec![70])); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 1 | 9 | 10 | 3 | 70 | + | 1 | 5 | 7 | | | | + | 2 | 4 | 8 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than_equal_with_right_nulls_on_no_match() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 5 | 80 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 5]), + ("c2", &vec![70, 80]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | | | | 10 | 3 | 70 | + | | | | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_single_row_left_less_than() -> Result<()> { + let left = build_table(("a1", &vec![42]), ("b1", &vec![5]), ("c1", &vec![999])); + + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1, 5, 7]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+----+ + | 42 | 5 | 999 | 30 | 7 | 90 | + +----+----+-----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + let right = build_table( + ("a2", &Vec::::new()), + ("b1", &Vec::::new()), + ("c2", &Vec::::new()), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date32_inner_less_than() -> Result<()> { + // +----+-------+----+ + // | a1 | b1 | c1 | + // +----+-------+----+ + // | 1 | 19107 | 7 | + // | 2 | 19107 | 8 | + // | 3 | 19105 | 9 | + // +----+-------+----+ + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19107, 19105]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+-------+----+ + // | a2 | b1 | c2 | + // +----+-------+----+ + // | 10 | 19105 | 70 | + // | 20 | 19103 | 80 | + // | 30 | 19107 | 90 | + // +----+-------+----+ + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19105, 19103, 19107]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_inner_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 70 | + // | 20 | 1650503441000 | 80 | + // | 30 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs new file mode 100644 index 000000000000..987f3e9df45a --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -0,0 +1,748 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::{
+    array::{ArrayRef, BooleanBufferBuilder, RecordBatch},
+    compute::concat_batches,
+    util::bit_util,
+};
+use arrow_schema::{SchemaRef, SortOptions};
+use datafusion_common::not_impl_err;
+use datafusion_common::{internal_err, JoinSide, Result};
+use datafusion_execution::{
+    memory_pool::{MemoryConsumer, MemoryReservation},
+    SendableRecordBatchStream,
+};
+use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::equivalence::join_equivalence_properties;
+use datafusion_physical_expr::{
+    Distribution, LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef,
+    PhysicalSortExpr,
+};
+use datafusion_physical_expr_common::physical_expr::fmt_sql;
+use futures::TryStreamExt;
+use parking_lot::Mutex;
+use std::fmt::Formatter;
+use std::sync::atomic::AtomicUsize;
+use std::sync::Arc;
+
+use crate::execution_plan::{boundedness_from_children, EmissionType};
+
+use crate::joins::piecewise_merge_join::classic_join::{
+    ClassicPWMJStream, PiecewiseMergeJoinStreamState,
+};
+use crate::joins::piecewise_merge_join::utils::{
+    build_visited_indices_map, is_existence_join, is_right_existence_join,
+};
+use crate::joins::utils::asymmetric_join_output_partitioning;
+use crate::{
+    joins::{
+        utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut},
+        SharedBitmapBuilder,
+    },
+    metrics::ExecutionPlanMetricsSet,
+    spill::get_record_batch_memory_size,
+    ExecutionPlan, PlanProperties,
+};
+use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties};
+
+/// `PiecewiseMergeJoinExec` is a join execution plan that evaluates a single range filter and shows
+/// much better performance for these workloads than `NestedLoopJoin`.
+///
+/// The physical planner will choose to evaluate this join when there is only one comparison filter. This
+/// is a binary expression which contains [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], or
+/// [`Operator::GtEq`]:
+/// Examples:
+/// - `col0` < `colb`, `col0` <= `colb`, `col0` > `colb`, `col0` >= `colb`
+///
+/// # Execution Plan Inputs
+/// For `PiecewiseMergeJoin` we label the right input as the 'streamed' side and the left input as the
+/// 'buffered' side.
+///
+/// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and is able to sort streamed record
+/// batches during processing. The sorted input must specifically be ascending/descending based on the operator.
+///
+/// # Algorithms
+/// Classic joins are processed differently compared to existence joins.
+///
+/// ## Classic Joins (Inner, Full, Left, Right)
+/// For classic joins we buffer the left (build) side and stream the right (probe) side.
+/// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures
+/// that when we find the first matching pair of rows, we can emit the current stream row joined with all
+/// remaining buffered rows from the match position onward, without rescanning earlier buffered rows.
+///
+/// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators
+/// they are sorted in **ascending** order. This choice ensures that the pointer on the buffered side can advance
+/// monotonically as we stream new batches from the stream side.
+///
+/// The streamed side may arrive unsorted, so this operator sorts each incoming batch in memory before
+/// processing. The buffered side is required to be globally sorted; the plan declares this requirement
+/// in `requires_input_order`, which allows the optimizer to automatically insert a `SortExec` on that side if needed.
+/// By the time this operator runs, the buffered side is guaranteed to be in the proper order.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// for stream_row in stream_batch:
+///     for buffer_row in buffer_batch:
+///         if compare(stream_row, buffer_row):
+///             output stream_row X buffer_batch[buffer_row:]
+///         else:
+///             continue
+/// ```
+///
+/// The streamed side (larger) drives the loop: every row on the stream side scans the buffered side
+/// to find its first match. Because each match emits many rows at once, output handling can be
+/// better vectorized for performance.
+///
+/// Here is an example:
+///
+/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt` (<). For each
+/// row on the streamed side we move a pointer on the buffered side until it matches the condition. Once we
+/// reach the matching row (in this case row 1 on the streamed side has its first match at row 2 on the
+/// buffered side; 100 < 200 is true), we can emit all rows after that match. We can emit the rows like this
+/// because if the batch is sorted in ascending order, every subsequent row will also satisfy the condition as
+/// they will all be larger values.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (100), (200), (500)) AS streamed(a)
+/// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+///     ON streamed.a < buffered.b;
+///
+/// Processing Row 1:
+///
+/// Sorted Buffered Side                                           Sorted Streamed Side
+///   ┌──────────────────┐                                          ┌──────────────────┐
+/// 1 │ 100              │                                        1 │ 100              │
+///   ├──────────────────┤                                          ├──────────────────┤
+/// 2 │ 200              │ ─┐                                     2 │ 200              │
+///   ├──────────────────┤  │  For row 1 on the streamed side       ├──────────────────┤
+/// 3 │ 200              │  │  with value 100, we emit rows 2 - 5 3 │ 500              │
+///   ├──────────────────┤  │  as matches when the operator is      └──────────────────┘
+/// 4 │ 300              │  │  `Operator::Lt` (<), emitting all
+///   ├──────────────────┤  │  rows after the first match (row
+/// 5 │ 400              │ ─┘  2 on the buffered side; 100 < 200)
+///   └──────────────────┘
+///
+/// Processing Row 2:
+/// Because the streamed side is sorted, probing for row 2 can resume at the position where row 1
+/// found its first match; the buffered pointer never needs to move backwards.
+///
+/// Sorted Buffered Side                                           Sorted Streamed Side
+///   ┌──────────────────┐                                          ┌──────────────────┐
+/// 1 │ 100              │                                        1 │ 100              │
+///   ├──────────────────┤                                          ├──────────────────┤
+/// 2 │ 200              │ <- Start here when probing for the     2 │ 200              │
+///   ├──────────────────┤    streamed side row 2.                  ├──────────────────┤
+/// 3 │ 200              │                                        3 │ 500              │
+///   ├──────────────────┤                                          └──────────────────┘
+/// 4 │ 300              │
+///   ├──────────────────┤
+/// 5 │ 400              │
+///   └──────────────────┘
+///
+/// ```
+///
+/// ## Existence Joins (Semi, Anti, Mark)
+/// Existence joins are orders of magnitude faster with a `PiecewiseMergeJoin` as we only need to find
+/// the min/max value of the streamed side to be able to emit all matches on the buffered side. By putting
+///
+/// ## Existence Joins (Semi, Anti, Mark)
+/// Existence joins are orders of magnitude faster with a `PiecewiseMergeJoin`, as we only need to
+/// find the min/max value of the streamed side to be able to emit all matches on the buffered
+/// side. By putting the side we need to mark onto the sorted buffered side, we can emit all of
+/// these matches at once.
+///
+/// For less-than operators (`<`, `<=`) the buffered side is sorted in ascending order, and for
+/// greater-than operators (`>`, `>=`) in descending order. `SortExec` is used to enforce sorting
+/// on the buffered side; the streamed side does not need to be sorted because we only need its
+/// min/max value.
+///
+/// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the
+/// buffered side.
+///
+/// The pseudocode for the algorithm looks like this:
+///
+/// ```text
+/// // Using the example of a less than (`<`) operation
+/// let min = min_batch(streamed_batch)
+///
+/// for buffer_row in buffer_batch:
+///     if min < buffer_row:
+///         output buffer_batch[buffer_row:]
+/// ```
+///
+/// We only need to find the min/max value and iterate through the buffered side once.
+///
+/// Here is an example:
+/// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt`
+/// (<). Because the operator is `Operator::Lt` we can find the minimum value in the streamed side;
+/// in this case it is 200. We can then advance a pointer from the start of the buffered side until
+/// we find the first value that satisfies the predicate. All rows after that first matched value
+/// satisfy the condition 200 < x, so we can mark all of those rows as matched.
+///
+/// ```text
+/// SQL statement:
+/// SELECT *
+/// FROM (VALUES (500), (200), (300)) AS streamed(a)
+/// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b)
+/// ON streamed.a < buffered.b;
+///
+///       Sorted Buffered Side                 Unsorted Streamed Side
+///        ┌──────────────────┐                 ┌──────────────────┐
+///      1 │       100        │               1 │       500        │
+///        ├──────────────────┤                 ├──────────────────┤
+///      2 │       200        │               2 │       200        │
+///        ├──────────────────┤                 ├──────────────────┤
+///      3 │       200        │               3 │       300        │
+///        ├──────────────────┤                 └──────────────────┘
+///      4 │       300        │ ─┐
+///        ├──────────────────┤  |  We emit matches for rows 4 - 5
+///      5 │       400        │ ─┘  on the buffered side.
+///        └──────────────────┘
+///   min value: 200
+/// ```
+///
+/// For existence joins, the buffered side must be sorted ascending for `Operator::Lt` (<) /
+/// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) / `Operator::GtEq` (>=); classic
+/// joins use the opposite orientation, as described above.
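A matching sketch for the existence-join path (again editorial, simplified to integers and the ascending/`<` orientation from the example; the real code would mark bits in a shared `BooleanBufferBuilder` instead of returning a `Vec<bool>`):

```rust
/// Mark every buffered row `b` with `min(streamed) < b`: one scan finds the
/// first match, and in ascending order every later row matches as well.
fn left_semi_pwmj_lt(streamed: &[i32], buffered_sorted_asc: &[i32]) -> Vec<bool> {
    let mut marks = vec![false; buffered_sorted_asc.len()];
    if let Some(&min) = streamed.iter().min() {
        if let Some(first) = buffered_sorted_asc.iter().position(|&b| min < b) {
            for mark in &mut marks[first..] {
                *mark = true;
            }
        }
    }
    marks
}

fn main() {
    // Mirrors the worked example: min(500, 200, 300) = 200, so buffered
    // rows 4 and 5 (values 300 and 400) are the only matches.
    let marks = left_semi_pwmj_lt(&[500, 200, 300], &[100, 200, 200, 300, 400]);
    assert_eq!(marks, vec![false, false, false, true, true]);
}
```

This is why the streamed side needs no sort here: a single min (or max, for `>`/`>=`) is enough to decide every buffered row at once.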
+///
+/// # Partitioning Logic
+/// Piecewise Merge Join requires a single partition on the buffered side plus a round-robin
+/// partitioned streamed side. A counter on the buffered side tracks when all streamed partitions
+/// have finished execution, which allows the remaining unmatched rows to be processed for Left and
+/// Full joins. The last partition to finish execution is responsible for outputting the unmatched
+/// rows.
+///
+/// # Performance Explanation (cost)
+/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the
+/// breakdown:
+///
+/// R: Buffered Side
+/// S: Streamed Side
+///
+/// ## Piecewise Merge Join (PWMJ)
+///
+/// ### Classic Join
+/// Requires sorting the probe side and, for each probe row, scanning the buffered side until the
+/// first match is found.
+/// Complexity: `O(sort(S) + num_of_batches(|S|) * scan(R))`.
+///
+/// ### Mark Join
+/// Computes the min/max range of the probe keys (no sort is needed) and scans the buffered side
+/// only within that range.
+/// Complexity: `O(|S| + scan(R[range]))`.
+///
+/// ## Nested Loop Join
+/// Compares every row from `S` with every row from `R`.
+/// Complexity: `O(|S| * |R|)`.
+///
+/// # Further Reference Material
+/// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html)
+#[derive(Debug)]
+pub struct PiecewiseMergeJoinExec {
+    /// Left buffered execution plan
+    pub buffered: Arc<dyn ExecutionPlan>,
+    /// Right streamed execution plan
+    pub streamed: Arc<dyn ExecutionPlan>,
+    /// The two expressions being compared
+    pub on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
+    /// Comparison operator in the range predicate
+    pub operator: Operator,
+    /// How the join is performed
+    pub join_type: JoinType,
+    /// The schema once the join is applied
+    schema: SchemaRef,
+    /// Buffered data
+    buffered_fut: OnceAsync<BufferedSideData>,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+
+    /// Sort expressions; see the [`PiecewiseMergeJoinExec`] docs above for details
+    ///
+    /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations
+    left_child_plan_required_order: LexOrdering,
+    /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations.
+    /// Unsorted for mark joins
+    #[allow(unused)]
+    right_batch_required_orders: LexOrdering,
+
+    /// The sort order of all join columns used in sorting the streamed and buffered execution plans
+    sort_options: SortOptions,
+    /// Cache holding plan properties like equivalences, output partitioning etc.
+    cache: PlanProperties,
+    /// Number of partitions to process
+    num_partitions: usize,
+}
+
+impl PiecewiseMergeJoinExec {
+    pub fn try_new(
+        buffered: Arc<dyn ExecutionPlan>,
+        streamed: Arc<dyn ExecutionPlan>,
+        on: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
+        operator: Operator,
+        join_type: JoinType,
+        num_partitions: usize,
+    ) -> Result<Self> {
+        // TODO: Implement existence joins for PiecewiseMergeJoin
+        if is_existence_join(join_type) {
+            return not_impl_err!(
+                "Existence Joins are currently not supported for PiecewiseMergeJoin"
+            );
+        }
+
+        // Take the operator and enforce a sort order on the streamed + buffered side based on
+        // the operator type.
+        let sort_options = match operator {
+            Operator::Lt | Operator::LtEq => {
+                // For left existence joins the inputs will be swapped so the sort
+                // options are switched
+                if is_right_existence_join(join_type) {
+                    SortOptions::new(false, true)
+                } else {
+                    SortOptions::new(true, true)
+                }
+            }
+            Operator::Gt | Operator::GtEq => {
+                if is_right_existence_join(join_type) {
+                    SortOptions::new(true, true)
+                } else {
+                    SortOptions::new(false, true)
+                }
+            }
+            _ => {
+                return internal_err!(
+                    "Non-range operators are not supported in PiecewiseMergeJoinExec"
+                )
+            }
+        };
+
+        // Use the same `sort_options` on both sides so values compare consistently later
+        let left_child_plan_required_order =
+            vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)];
+        let right_batch_required_orders =
+            vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)];
+
+        let Some(left_child_plan_required_order) =
+            LexOrdering::new(left_child_plan_required_order)
+        else {
+            return internal_err!(
+                "PiecewiseMergeJoinExec requires valid sort expressions for its left side"
+            );
+        };
+        let Some(right_batch_required_orders) =
+            LexOrdering::new(right_batch_required_orders)
+        else {
+            return internal_err!(
+                "PiecewiseMergeJoinExec requires valid sort expressions for its right side"
+            );
+        };
+
+        let buffered_schema = buffered.schema();
+        let streamed_schema = streamed.schema();
+
+        // Create output schema for the join
+        let schema =
+            Arc::new(build_join_schema(&buffered_schema, &streamed_schema, &join_type).0);
+        let cache = Self::compute_properties(
+            &buffered,
+            &streamed,
+            Arc::clone(&schema),
+            join_type,
+            &on,
+        )?;
+
+        Ok(Self {
+            streamed,
+            buffered,
+            on,
+            operator,
+            join_type,
+            schema,
+            buffered_fut: Default::default(),
+            metrics: ExecutionPlanMetricsSet::new(),
+            left_child_plan_required_order,
+            right_batch_required_orders,
+            sort_options,
+            cache,
+            num_partitions,
+        })
+    }
+
+    /// Reference to buffered side execution plan
+    pub fn buffered(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.buffered
+    }
+
+    /// Reference to streamed side execution plan
+    pub fn streamed(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.streamed
+    }
+
+    /// Join type
+    pub fn join_type(&self) -> JoinType {
+        self.join_type
+    }
+
+    /// Reference to sort options
+    pub fn sort_options(&self) -> &SortOptions {
+        &self.sort_options
+    }
+
+    /// Get probe side (streamed side) for the PiecewiseMergeJoin.
+    /// In the current implementation, the probe side is determined according to the join type.
+    pub fn probe_side(join_type: &JoinType) -> JoinSide {
+        match join_type {
+            JoinType::Right
+            | JoinType::Inner
+            | JoinType::Full
+            | JoinType::RightSemi
+            | JoinType::RightAnti
+            | JoinType::RightMark => JoinSide::Right,
+            JoinType::Left
+            | JoinType::LeftAnti
+            | JoinType::LeftSemi
+            | JoinType::LeftMark => JoinSide::Left,
+        }
+    }
+
+    pub fn compute_properties(
+        buffered: &Arc<dyn ExecutionPlan>,
+        streamed: &Arc<dyn ExecutionPlan>,
+        schema: SchemaRef,
+        join_type: JoinType,
+        join_on: &(PhysicalExprRef, PhysicalExprRef),
+    ) -> Result<PlanProperties> {
+        let eq_properties = join_equivalence_properties(
+            buffered.equivalence_properties().clone(),
+            streamed.equivalence_properties().clone(),
+            &join_type,
+            schema,
+            &Self::maintains_input_order(join_type),
+            Some(Self::probe_side(&join_type)),
+            std::slice::from_ref(join_on),
+        )?;
+
+        let output_partitioning =
+            asymmetric_join_output_partitioning(buffered, streamed, &join_type)?;
+
+        Ok(PlanProperties::new(
+            eq_properties,
+            output_partitioning,
+            EmissionType::Incremental,
+            boundedness_from_children([buffered, streamed]),
+        ))
+    }
+
+    // TODO: Add input-order tracking. Currently every entry is `false`, indicating the
+    // input order is not maintained. However, for certain join types the order is
+    // maintained; this can be updated in the future after more testing.
+    fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
+        match join_type {
+            // The existence side is expected to come in sorted
+            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
+                vec![false, false]
+            }
+            JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
+                vec![false, false]
+            }
+            // Left, Right, Full, and Inner joins are not guaranteed to maintain
+            // input order, as the streamed side will be sorted during
+            // execution for `PiecewiseMergeJoin`
+            _ => vec![false, false],
+        }
+    }
+
+    // TODO
+    pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
+        todo!()
+    }
+}
+
+impl ExecutionPlan for PiecewiseMergeJoinExec {
+    fn name(&self) -> &str {
+        "PiecewiseMergeJoinExec"
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.buffered, &self.streamed]
+    }
+
+    fn required_input_distribution(&self) -> Vec<Distribution> {
+        vec![
+            Distribution::SinglePartition,
+            Distribution::UnspecifiedDistribution,
+        ]
+    }
+
+    fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
+        // Existence joins don't need one side to be sorted.
+        if is_right_existence_join(self.join_type) {
+            unimplemented!()
+        } else {
+            // The buffered (left) side must come in sorted; the streamed (right) side
+            // is sorted in memory during execution, so no requirement is declared for it.
+            vec![
+                Some(OrderingRequirements::from(
+                    self.left_child_plan_required_order.clone(),
+                )),
+                None,
+            ]
+        }
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match &children[..] {
+            [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new(
+                Arc::clone(left),
+                Arc::clone(right),
+                self.on.clone(),
+                self.operator,
+                self.join_type,
+                self.num_partitions,
+            )?)),
+            _ => internal_err!(
+                "PiecewiseMergeJoin should have 2 children, found {}",
+                children.len()
+            ),
+        }
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        let on_buffered = Arc::clone(&self.on.0);
+        let on_streamed = Arc::clone(&self.on.1);
+
+        let metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
+        let buffered_fut = self.buffered_fut.try_once(|| {
+            let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput")
+                .register(context.memory_pool());
+
+            let buffered_stream = self.buffered.execute(0, Arc::clone(&context))?;
+            Ok(build_buffered_data(
+                buffered_stream,
+                Arc::clone(&on_buffered),
+                metrics.clone(),
+                reservation,
+                build_visited_indices_map(self.join_type),
+                self.num_partitions,
+            ))
+        })?;
+
+        let streamed = self.streamed.execute(partition, Arc::clone(&context))?;
+
+        let batch_size = context.session_config().batch_size();
+
+        // TODO: Add existence joins; this case is guarded at the physical planner
+        if is_existence_join(self.join_type()) {
+            unreachable!()
+        } else {
+            Ok(Box::pin(ClassicPWMJStream::try_new(
+                Arc::clone(&self.schema),
+                on_streamed,
+                self.join_type,
+                self.operator,
+                streamed,
+                BufferedSide::Initial(BufferedSideInitialState { buffered_fut }),
+                PiecewiseMergeJoinStreamState::WaitBufferedSide,
+                self.sort_options,
+                metrics,
+                batch_size,
+            )))
+        }
+    }
+}
+
+impl DisplayAs for PiecewiseMergeJoinExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        let on_str = format!(
+            "({} {} {})",
+            fmt_sql(self.on.0.as_ref()),
+            self.operator,
+            fmt_sql(self.on.1.as_ref())
+        );
+
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "PiecewiseMergeJoin: operator={:?}, join_type={:?}, on={}",
+                    self.operator, self.join_type, on_str
+                )
+            }
+
+            DisplayFormatType::TreeRender => {
+                writeln!(f, "operator={:?}", self.operator)?;
+                if self.join_type != JoinType::Inner {
+                    writeln!(f, "join_type={:?}", self.join_type)?;
+                }
+                writeln!(f, "on={on_str}")
+            }
+        }
+    }
+}
+
+async fn build_buffered_data(
+    buffered: SendableRecordBatchStream,
+    on_buffered: PhysicalExprRef,
+    metrics: BuildProbeJoinMetrics,
+    reservation: MemoryReservation,
+    build_map: bool,
+    remaining_partitions: usize,
+) -> Result<BufferedSideData> {
+    let schema = buffered.schema();
+
+    // Combine batches and record the number of rows
+    let initial = (Vec::new(), 0, metrics, reservation);
+    let (batches, num_rows, metrics, mut reservation) = buffered
+        .try_fold(initial, |mut acc, batch| async {
+            let batch_size = get_record_batch_memory_size(&batch);
+            acc.3.try_grow(batch_size)?;
+            acc.2.build_mem_used.add(batch_size);
+            acc.2.build_input_batches.add(1);
+            acc.2.build_input_rows.add(batch.num_rows());
+            // Update row count
+            acc.1 += batch.num_rows();
+            // Push batch to output
+            acc.0.push(batch);
+            Ok(acc)
+        })
+        .await?;
+
+    let single_batch = concat_batches(&schema, batches.iter())?;
+
+    // Evaluate the physical expression on the buffered side.
+    let buffered_values = on_buffered
+        .evaluate(&single_batch)?
+        .into_array(single_batch.num_rows())?;
+
+    // The size estimation covers the concatenated batch plus the memory of
+    // the evaluated join key values
+    let size_estimation = get_record_batch_memory_size(&single_batch)
+        + buffered_values.get_array_memory_size();
+    reservation.try_grow(size_estimation)?;
+    metrics.build_mem_used.add(size_estimation);
+
+    // Create the visited-indices bitmap only if the join type requires it
+    let visited_indices_bitmap = if build_map {
+        let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8);
+        reservation.try_grow(bitmap_size)?;
+        metrics.build_mem_used.add(bitmap_size);
+
+        let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows());
+        bitmap_buffer.append_n(num_rows, false);
+        bitmap_buffer
+    } else {
+        BooleanBufferBuilder::new(0)
+    };
+
+    let buffered_data = BufferedSideData::new(
+        single_batch,
+        buffered_values,
+        Mutex::new(visited_indices_bitmap),
+        remaining_partitions,
+        reservation,
+    );
+
+    Ok(buffered_data)
+}
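`build_buffered_data` above threads a `remaining_partitions` count into the shared buffered-side state. The pattern it enables — the last streamed partition to finish emits the unmatched buffered rows for Left/Full joins — boils down to an atomic countdown. A hedged sketch of that coordination (`finish_partition` is a hypothetical helper name; the actual logic lives in the stream implementation):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Returns true iff the caller is the last streamed partition to finish and is
/// therefore responsible for scanning the shared visited-indices bitmap and
/// emitting the still-unmatched buffered rows.
fn finish_partition(remaining_partitions: &AtomicUsize) -> bool {
    // `fetch_sub` returns the previous value: the partition that observes 1
    // has just brought the counter to 0 and finished last.
    remaining_partitions.fetch_sub(1, Ordering::SeqCst) == 1
}

fn main() {
    let counter = AtomicUsize::new(3); // three streamed partitions
    assert!(!finish_partition(&counter));
    assert!(!finish_partition(&counter));
    assert!(finish_partition(&counter)); // the last one emits unmatched rows
}
```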
+
+pub(super) struct BufferedSideData {
+    pub(super) batch: RecordBatch,
+    values: ArrayRef,
+    pub(super) visited_indices_bitmap: SharedBitmapBuilder,
+    pub(super) remaining_partitions: AtomicUsize,
+    _reservation: MemoryReservation,
+}
+
+impl BufferedSideData {
+    pub(super) fn new(
+        batch: RecordBatch,
+        values: ArrayRef,
+        visited_indices_bitmap: SharedBitmapBuilder,
+        remaining_partitions: usize,
+        reservation: MemoryReservation,
+    ) -> Self {
+        Self {
+            batch,
+            values,
+            visited_indices_bitmap,
+            remaining_partitions: AtomicUsize::new(remaining_partitions),
+            _reservation: reservation,
+        }
+    }
+
+    pub(super) fn batch(&self) -> &RecordBatch {
+        &self.batch
+    }
+
+    pub(super) fn values(&self) -> &ArrayRef {
+        &self.values
+    }
+}
+
+pub(super) enum BufferedSide {
+    /// Indicates that the build side has not been collected yet
+    Initial(BufferedSideInitialState),
+    /// Indicates that the build-side data has been collected
+    Ready(BufferedSideReadyState),
+}
+
+impl BufferedSide {
+    /// Returns a mutable reference to the initial state of the buffered side
+    pub(super) fn try_as_initial_mut(&mut self) -> Result<&mut BufferedSideInitialState> {
+        match self {
+            BufferedSide::Initial(state) => Ok(state),
+            _ => internal_err!("Expected build side in initial state"),
+        }
+    }
+
+    pub(super) fn try_as_ready(&self) -> Result<&BufferedSideReadyState> {
+        match self {
+            BufferedSide::Ready(state) => Ok(state),
+            _ => {
+                internal_err!("Expected build side in ready state")
+            }
+        }
+    }
+
+    /// Tries to extract `BufferedSideReadyState` from the `BufferedSide` enum.
+    /// Returns an error if the state is not `Ready`.
+    pub(super) fn try_as_ready_mut(&mut self) -> Result<&mut BufferedSideReadyState> {
+        match self {
+            BufferedSide::Ready(state) => Ok(state),
+            _ => internal_err!("Expected build side in ready state"),
+        }
+    }
+}
+
+pub(super) struct BufferedSideInitialState {
+    pub(crate) buffered_fut: OnceFut<BufferedSideData>,
+}
+
+pub(super) struct BufferedSideReadyState {
+    /// Collected build-side data
+    pub(super) buffered_data: Arc<BufferedSideData>,
+}
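`BufferedSide` is a small two-state machine: a stream starts in `Initial` holding a `OnceFut`, resolves it exactly once, and then swaps itself into `Ready` for all later polls. A stripped-down sketch of that transition with a plain value standing in for the future (editorial illustration only):

```rust
enum BufferedSide {
    Initial(String), // stands in for the pending `OnceFut<BufferedSideData>`
    Ready(String),   // stands in for `BufferedSideReadyState`
}

impl BufferedSide {
    /// Advance the state machine, mirroring how the join stream awaits the
    /// buffered future once and serves every later call from `Ready`.
    fn advance(&mut self) {
        if let BufferedSide::Initial(pending) = self {
            let resolved = format!("resolved: {pending}"); // stands in for `.await`
            *self = BufferedSide::Ready(resolved);
        }
    }
}

fn main() {
    let mut side = BufferedSide::Initial("buffered batches".to_string());
    side.advance();
    assert!(matches!(side, BufferedSide::Ready(_)));
    side.advance(); // idempotent: already Ready
}
```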
diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs
new file mode 100644
index 000000000000..c85a7cc16f65
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! PiecewiseMergeJoin is currently experimental
+
+pub use exec::PiecewiseMergeJoinExec;
+
+mod classic_join;
+mod exec;
+mod utils;
diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs
new file mode 100644
index 000000000000..5bbb496322b5
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_expr::JoinType;
+
+// Returns true if the join is a right existence join
+pub(super) fn is_right_existence_join(join_type: JoinType) -> bool {
+    matches!(
+        join_type,
+        JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
+    )
+}
+
+// Returns true if the join is an existence join
+pub(super) fn is_existence_join(join_type: JoinType) -> bool {
+    matches!(
+        join_type,
+        JoinType::LeftAnti
+            | JoinType::RightAnti
+            | JoinType::LeftSemi
+            | JoinType::RightSemi
+            | JoinType::LeftMark
+            | JoinType::RightMark
+    )
+}
+
+// Returns true if the join type needs to record buffered-side matches
+// for classic joins (unmatched buffered rows are emitted at the end)
+pub(super) fn need_produce_result_in_final(join_type: JoinType) -> bool {
+    matches!(join_type, JoinType::Full | JoinType::Left)
+}
+
+// Returns true if we need to build the buffered-side bitmap for
+// marking matched rows on the buffered side.
+pub(super) fn build_visited_indices_map(join_type: JoinType) -> bool {
+    matches!(
+        join_type,
+        JoinType::Full
+            | JoinType::Left
+            | JoinType::LeftAnti
+            | JoinType::RightAnti
+            | JoinType::LeftSemi
+            | JoinType::RightSemi
+            | JoinType::LeftMark
+            | JoinType::RightMark
+    )
+}
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index 879f47638d2c..5a2e3669ab5e 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -34,7 +34,7 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
-use crate::joins::utils::JoinFilter;
+use crate::joins::utils::{compare_join_arrays, JoinFilter};
 use crate::spill::spill_manager::SpillManager;
 use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
 
@@ -1865,101 +1865,6 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec<ArrayRef>
 
-/// Get comparison result of two rows of join arrays
-fn compare_join_arrays(
-    left_arrays: &[ArrayRef],
-    left: usize,
-    right_arrays: &[ArrayRef],
-    right: usize,
-    sort_options: &[SortOptions],
-    null_equality: NullEquality,
-) -> Result<Ordering> {
-    let mut res = Ordering::Equal;
-    for ((left_array, right_array), sort_options) in
-        left_arrays.iter().zip(right_arrays).zip(sort_options)
-    {
-        macro_rules! compare_value {
compare_value { - ($T:ty) => {{ - let left_array = left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = right_array.as_any().downcast_ref::<$T>().unwrap(); - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_value = &left_array.value(left); - let right_value = &right_array.value(right); - res = left_value.partial_cmp(right_value).unwrap(); - if sort_options.descending { - res = res.reverse(); - } - } - (true, false) => { - res = if sort_options.nulls_first { - Ordering::Less - } else { - Ordering::Greater - }; - } - (false, true) => { - res = if sort_options.nulls_first { - Ordering::Greater - } else { - Ordering::Less - }; - } - _ => { - res = match null_equality { - NullEquality::NullEqualsNothing => Ordering::Less, - NullEquality::NullEqualsNull => Ordering::Equal, - }; - } - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal32(..) => compare_value!(Decimal32Array), - DataType::Decimal64(..) => compare_value!(Decimal64Array), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !res.is_eq() { - break; - } - } - Ok(res) -} - /// A faster version of compare_join_arrays() that only output whether /// the given two rows are equal fn is_join_arrays_equal( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c50bfce93a2d..78652d443d3c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! 
 //! Join related functionality used both on logical and physical plans
 
-use std::cmp::min;
+use std::cmp::{min, Ordering};
 use std::collections::HashSet;
 use std::fmt::{self, Debug};
 use std::future::Future;
@@ -43,7 +43,13 @@ use arrow::array::{
 use arrow::array::{
     BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
     UInt32Array, UInt32Builder, UInt64Array,
 };
-use arrow::array::{ArrayRef, BooleanArray};
+use arrow::array::{
+    ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
+    Decimal128Array, Decimal32Array, Decimal64Array, FixedSizeBinaryArray, Float32Array,
+    Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray,
+    LargeStringArray, StringArray, StringViewArray, TimestampMicrosecondArray,
+    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
+    UInt16Array, UInt8Array,
+};
 use arrow::buffer::{BooleanBuffer, NullBuffer};
 use arrow::compute::kernels::cmp::eq;
 use arrow::compute::{self, and, take, FilterBuilder};
@@ -51,12 +57,13 @@ use arrow::datatypes::{
 use arrow::datatypes::{
     ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
 };
 use arrow_ord::cmp::not_distinct;
-use arrow_schema::ArrowError;
+use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit};
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
+    not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result,
+    SharedResult,
 };
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_expr::Operator;
@@ -284,7 +291,7 @@ pub fn build_join_schema(
         JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(),
         JoinType::LeftMark => {
             let right_field = once((
-                Field::new("mark", arrow::datatypes::DataType::Boolean, false),
+                Field::new("mark", DataType::Boolean, false),
                 ColumnIndex {
                     index: 0,
                     side: JoinSide::None,
@@ -295,7 +302,7 @@ pub fn build_join_schema(
         JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(),
         JoinType::RightMark => {
             let left_field = once((
-                Field::new("mark", arrow_schema::DataType::Boolean, false),
+                Field::new("mark", DataType::Boolean, false),
                 ColumnIndex {
                     index: 0,
                     side: JoinSide::None,
@@ -812,9 +819,10 @@ pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
 pub(crate) fn get_final_indices_from_shared_bitmap(
     shared_bitmap: &SharedBitmapBuilder,
     join_type: JoinType,
+    piecewise: bool,
 ) -> (UInt64Array, UInt32Array) {
     let bitmap = shared_bitmap.lock();
-    get_final_indices_from_bit_map(&bitmap, join_type)
+    get_final_indices_from_bit_map(&bitmap, join_type, piecewise)
 }
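The `piecewise` flag added here exists because `PiecewiseMergeJoin` reuses the left-side bitmap machinery for right-side join types as well. A simplified sketch of the mark-join index extraction that `get_final_indices_from_bit_map` (shown in the next hunk) performs — plain vectors in place of the Arrow `UInt64Array`/`UInt32Array` the real function builds:

```rust
/// Simplified mark-join extraction: every buffered row is emitted exactly once,
/// paired with Some(0) when its bit is set (matched) and None otherwise.
fn mark_join_indices(bitmap: &[bool]) -> (Vec<u64>, Vec<Option<u32>>) {
    let left_indices = (0..bitmap.len() as u64).collect();
    let right_indices = bitmap.iter().map(|&matched| matched.then_some(0)).collect();
    (left_indices, right_indices)
}

fn main() {
    // Buffered rows 0 and 2 were matched during the scan.
    let (left, right) = mark_join_indices(&[true, false, true]);
    assert_eq!(left, vec![0, 1, 2]);
    assert_eq!(right, vec![Some(0), None, Some(0)]);
}
```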
 
 /// In the end of join execution, need to use bit map of the matched
@@ -829,16 +837,22 @@ pub(crate) fn get_final_indices_from_shared_bitmap(
 pub(crate) fn get_final_indices_from_bit_map(
     left_bit_map: &BooleanBufferBuilder,
     join_type: JoinType,
+    // We add a flag indicating whether the call comes from `PiecewiseMergeJoin`,
+    // because there the bitmap can correspond to both left and right `JoinType`s
+    piecewise: bool,
 ) -> (UInt64Array, UInt32Array) {
     let left_size = left_bit_map.len();
-    if join_type == JoinType::LeftMark {
+    if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise)
+    {
         let left_indices = (0..left_size as u64).collect::<UInt64Array>();
         let right_indices = (0..left_size)
             .map(|idx| left_bit_map.get_bit(idx).then_some(0))
             .collect::<UInt32Array>();
         return (left_indices, right_indices);
     }
-    let left_indices = if join_type == JoinType::LeftSemi {
+    let left_indices = if join_type == JoinType::LeftSemi
+        || (join_type == JoinType::RightSemi && piecewise)
+    {
         (0..left_size)
             .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
             .collect::<UInt64Array>()
@@ -1749,6 +1763,99 @@ fn eq_dyn_null(
     }
 }
 
+/// Get comparison result of two rows of join arrays
+pub fn compare_join_arrays(
+    left_arrays: &[ArrayRef],
+    left: usize,
+    right_arrays: &[ArrayRef],
+    right: usize,
+    sort_options: &[SortOptions],
+    null_equality: NullEquality,
+) -> Result<Ordering> {
+    let mut res = Ordering::Equal;
+    for ((left_array, right_array), sort_options) in
+        left_arrays.iter().zip(right_arrays).zip(sort_options)
+    {
+        macro_rules! compare_value {
+            ($T:ty) => {{
+                let left_array = left_array.as_any().downcast_ref::<$T>().unwrap();
+                let right_array = right_array.as_any().downcast_ref::<$T>().unwrap();
+                match (left_array.is_null(left), right_array.is_null(right)) {
+                    (false, false) => {
+                        let left_value = &left_array.value(left);
+                        let right_value = &right_array.value(right);
+                        res = left_value.partial_cmp(right_value).unwrap();
+                        if sort_options.descending {
+                            res = res.reverse();
+                        }
+                    }
+                    (true, false) => {
+                        res = if sort_options.nulls_first {
+                            Ordering::Less
+                        } else {
+                            Ordering::Greater
+                        };
+                    }
+                    (false, true) => {
+                        res = if sort_options.nulls_first {
+                            Ordering::Greater
+                        } else {
+                            Ordering::Less
+                        };
+                    }
+                    _ => {
+                        res = match null_equality {
+                            NullEquality::NullEqualsNothing => Ordering::Less,
+                            NullEquality::NullEqualsNull => Ordering::Equal,
+                        };
+                    }
+                }
+            }};
+        }
+
+        match left_array.data_type() {
+            DataType::Null => {}
+            DataType::Boolean => compare_value!(BooleanArray),
+            DataType::Int8 => compare_value!(Int8Array),
+            DataType::Int16 => compare_value!(Int16Array),
+            DataType::Int32 => compare_value!(Int32Array),
+            DataType::Int64 => compare_value!(Int64Array),
+            DataType::UInt8 => compare_value!(UInt8Array),
+            DataType::UInt16 => compare_value!(UInt16Array),
+            DataType::UInt32 => compare_value!(UInt32Array),
+            DataType::UInt64 => compare_value!(UInt64Array),
+            DataType::Float32 => compare_value!(Float32Array),
+            DataType::Float64 => compare_value!(Float64Array),
+            DataType::Binary => compare_value!(BinaryArray),
+            DataType::BinaryView => compare_value!(BinaryViewArray),
+            DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray),
+            DataType::LargeBinary => compare_value!(LargeBinaryArray),
+            DataType::Utf8 => compare_value!(StringArray),
+            DataType::Utf8View => compare_value!(StringViewArray),
+            DataType::LargeUtf8 => compare_value!(LargeStringArray),
+            DataType::Decimal32(..) => compare_value!(Decimal32Array),
+            DataType::Decimal64(..) => compare_value!(Decimal64Array),
+            DataType::Decimal128(..) => compare_value!(Decimal128Array),
+            DataType::Timestamp(time_unit, None) => match time_unit {
+                TimeUnit::Second => compare_value!(TimestampSecondArray),
+                TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray),
+                TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray),
+                TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray),
+            },
+            DataType::Date32 => compare_value!(Date32Array),
+            DataType::Date64 => compare_value!(Date64Array),
+            dt => {
+                return not_impl_err!(
+                    "Unsupported data type in sort merge join comparator: {}",
+                    dt
+                );
+            }
+        }
+        if !res.is_eq() {
+            break;
+        }
+    }
+    Ok(res)
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 412e36b8124f..b15ec026372d 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -291,6 +291,7 @@ datafusion.optimizer.default_filter_selectivity 20
 datafusion.optimizer.enable_distinct_aggregation_soft_limit true
 datafusion.optimizer.enable_dynamic_filter_pushdown true
 datafusion.optimizer.enable_join_dynamic_filter_pushdown true
+datafusion.optimizer.enable_piecewise_merge_join false
 datafusion.optimizer.enable_round_robin_repartition true
 datafusion.optimizer.enable_topk_aggregation true
 datafusion.optimizer.enable_topk_dynamic_filter_pushdown true
@@ -410,6 +411,7 @@ datafusion.optimizer.default_filter_selectivity 20 The default filter selectivit
 datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read.
 datafusion.optimizer.enable_dynamic_filter_pushdown true When set to true attempts to push down dynamic filters generated by operators (topk & join) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown` & `enable_topk_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden.
 datafusion.optimizer.enable_join_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down Join dynamic filters into the file scan phase.
+datafusion.optimizer.enable_piecewise_merge_join false When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter.
 datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores
 datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible
 datafusion.optimizer.enable_topk_dynamic_filter_pushdown true When set to true, the optimizer will attempt to push down TopK dynamic filters into the file scan phase.
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index 9472395da641..0174321dd831 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -4240,7 +4240,7 @@ physical_plan
 03)----DataSourceExec: partitions=1, partition_sizes=[2]
 04)----DataSourceExec: partitions=1, partition_sizes=[2]
 
-## Test join.on.is_empty() && join.filter.is_some()
+## Test join.on.is_empty() && join.filter.is_some() -> a single range filter now plans a PWMJ
 query TT
 EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2;
 ----
@@ -5193,6 +5193,40 @@ SELECT c
 8
 9
 
+# PiecewiseMergeJoin Test
+statement ok
+set datafusion.optimizer.enable_piecewise_merge_join = true;
+
+query II
+SELECT join_t1.t1_id, join_t2.t2_id
+FROM join_t1
+INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id
+WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1
+ORDER BY 1
+----
+22 11
+33 11
+44 11
+
+query TT
+EXPLAIN
+SELECT join_t1.t1_id, join_t2.t2_id
+FROM join_t1
+INNER JOIN join_t2 ON join_t1.t1_id > join_t2.t2_id
+WHERE join_t1.t1_id > 10 AND join_t2.t2_int > 1
+ORDER BY 1
+----
+physical_plan
+01)SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[false]
+02)--PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id)
+03)----SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false]
+04)------CoalesceBatchesExec: target_batch_size=3
+05)--------FilterExec: t1_id@0 > 10
+06)----------DataSourceExec: partitions=1, partition_sizes=[1]
+07)----CoalesceBatchesExec: target_batch_size=3
+08)------FilterExec: t2_int@1 > 1, projection=[t2_id@0]
+09)--------DataSourceExec: partitions=1, partition_sizes=[1]
+
 statement ok
 DROP TABLE t1;
 
@@ -5201,3 +5235,6 @@ DROP TABLE t2;
 
 statement ok
 set datafusion.explain.physical_plan_only = false;
+
+statement ok
+set datafusion.optimizer.enable_piecewise_merge_join = false;
diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt
new file mode 100644
index 000000000000..0014b3c545f2
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/pwmj.slt
@@ -0,0 +1,354 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+statement ok
+set datafusion.optimizer.enable_piecewise_merge_join = true;
+
+statement ok
+CREATE TABLE join_t1 (t1_id INT);
+
+statement ok
+CREATE TABLE join_t2 (t2_id INT, t2_name TEXT, t2_int INT);
+
+statement ok
+INSERT INTO join_t1 VALUES (11), (22), (33), (44);
+
+statement ok
+INSERT INTO join_t2 VALUES
+  (11, 'z', 3),
+  (22, 'y', 1),
+  (44, 'x', 3),
+  (55, 'w', 3);
+
+query II
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id > t2.t2_id
+WHERE t1.t1_id > 10
+  AND t2.t2_int > 1
+ORDER BY 1;
+----
+22 11
+33 11
+44 11
+
+# Checking `SELECT *`
+query IITI
+SELECT *
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id > t2.t2_id
+WHERE t1.t1_id > 10
+  AND t2.t2_int > 1
+ORDER BY 1;
+----
+22 11 z 3
+33 11 z 3
+44 11 z 3
+
+query TT
+EXPLAIN
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id > t2.t2_id
+WHERE t1.t1_id > 10
+  AND t2.t2_int > 1
+ORDER BY 1;
+----
+logical_plan
+01)Sort: t1.t1_id ASC NULLS LAST
+02)--Inner Join: Filter: t1.t1_id > t2.t2_id
+03)----SubqueryAlias: t1
+04)------Filter: join_t1.t1_id > Int32(10)
+05)--------TableScan: join_t1 projection=[t1_id]
+06)----SubqueryAlias: t2
+07)------Projection: join_t2.t2_id
+08)--------Filter: join_t2.t2_int > Int32(1)
+09)----------TableScan: join_t2 projection=[t2_id, t2_int]
+physical_plan
+01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST]
+02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
+03)----PiecewiseMergeJoin: operator=Gt, join_type=Inner, on=(t1_id > t2_id)
+04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false]
+05)--------CoalesceBatchesExec: target_batch_size=8192
+06)----------FilterExec: t1_id@0 > 10
+07)------------DataSourceExec: partitions=1, partition_sizes=[1]
+08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+09)--------CoalesceBatchesExec: target_batch_size=8192
+10)----------FilterExec: t2_int@1 > 1, projection=[t2_id@0]
+11)------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id >= t2.t2_id
+WHERE t1.t1_id >= 22
+  AND t2.t2_int = 3
+ORDER BY 1,2;
+----
+22 11
+33 11
+44 11
+44 44
+
+query TT
+EXPLAIN
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id >= t2.t2_id
+WHERE t1.t1_id >= 22
+  AND t2.t2_int = 3
+ORDER BY 1,2;
+----
+logical_plan
+01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST
+02)--Inner Join: Filter: t1.t1_id >= t2.t2_id
+03)----SubqueryAlias: t1
+04)------Filter: join_t1.t1_id >= Int32(22)
+05)--------TableScan: join_t1 projection=[t1_id]
+06)----SubqueryAlias: t2
+07)------Projection: join_t2.t2_id
+08)--------Filter: join_t2.t2_int = Int32(3)
+09)----------TableScan: join_t2 projection=[t2_id, t2_int]
+physical_plan
+01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST]
+02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
+03)----PiecewiseMergeJoin: operator=GtEq, join_type=Inner, on=(t1_id >= t2_id)
+04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false]
+05)--------CoalesceBatchesExec: target_batch_size=8192
+06)----------FilterExec: t1_id@0 >= 22
+07)------------DataSourceExec: partitions=1, partition_sizes=[1]
+08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+09)--------CoalesceBatchesExec: target_batch_size=8192
+10)----------FilterExec: t2_int@1 = 3, projection=[t2_id@0]
+11)------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id < t2.t2_id
+WHERE t2.t2_int >= 3
+ORDER BY 1,2;
+----
+11 44
+11 55
+22 44
+22 55
+33 44
+33 55
+44 55
+
+query TT
+EXPLAIN
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id < t2.t2_id
+WHERE t2.t2_int >= 3
+ORDER BY 1,2;
+----
+logical_plan
+01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST
+02)--Inner Join: Filter: t1.t1_id < t2.t2_id
+03)----SubqueryAlias: t1
+04)------TableScan: join_t1 projection=[t1_id]
+05)----SubqueryAlias: t2
+06)------Projection: join_t2.t2_id
+07)--------Filter: join_t2.t2_int >= Int32(3)
+08)----------TableScan: join_t2 projection=[t2_id, t2_int]
+physical_plan
+01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST]
+02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
+03)----PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(t1_id < t2_id)
+04)------SortExec: expr=[t1_id@0 DESC], preserve_partitioning=[false]
+05)--------DataSourceExec: partitions=1, partition_sizes=[1]
+06)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)--------CoalesceBatchesExec: target_batch_size=8192
+08)----------FilterExec: t2_int@1 >= 3, projection=[t2_id@0]
+09)------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+
+query II
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id < (t2.t2_id + 1)
+WHERE t2.t2_int >= 3
+ORDER BY 1,2;
+----
+11 11
+11 44
+11 55
+22 44
+22 55
+33 44
+33 55
+44 44
+44 55
+
+query TT
+EXPLAIN
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id < (t2.t2_id + 1)
+WHERE t2.t2_int >= 3
+ORDER BY 1,2;
+----
+logical_plan
+01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST
+02)--Inner Join: Filter: CAST(t1.t1_id AS Int64) < CAST(t2.t2_id AS Int64) + Int64(1)
+03)----SubqueryAlias: t1
+04)------TableScan: join_t1 projection=[t1_id]
+05)----SubqueryAlias: t2
+06)------Projection: join_t2.t2_id
+07)--------Filter: join_t2.t2_int >= Int32(3)
+08)----------TableScan: join_t2 projection=[t2_id, t2_int]
+physical_plan
+01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST]
+02)--SortExec: expr=[t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST], preserve_partitioning=[true]
+03)----PiecewiseMergeJoin: operator=Lt, join_type=Inner, on=(CAST(t1_id AS Int64) < CAST(t2_id AS Int64) + 1)
+04)------SortExec: expr=[CAST(t1_id@0 AS Int64) DESC], preserve_partitioning=[false]
+05)--------DataSourceExec: partitions=1, partition_sizes=[1]
+06)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)--------CoalesceBatchesExec: target_batch_size=8192
+08)----------FilterExec: t2_int@1 >= 3, projection=[t2_id@0]
+09)------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id <= t2.t2_id
+WHERE t1.t1_id IN (11, 44)
+  AND t2.t2_name <> 'y'
+ORDER BY 1,2;
+----
+11 11
+11 44
+11 55
+44 44
+44 55
+
+query TT
+EXPLAIN
+SELECT t1.t1_id, t2.t2_id
+FROM join_t1 t1
+JOIN join_t2 t2
+  ON t1.t1_id <= t2.t2_id
+WHERE t1.t1_id IN (11, 44)
+  AND t2.t2_name <> 'y'
+ORDER BY 1,2;
+----
+logical_plan
+01)Sort: t1.t1_id ASC NULLS LAST, t2.t2_id ASC NULLS LAST
+02)--Inner Join: Filter: t1.t1_id <= t2.t2_id
+03)----SubqueryAlias: t1
+04)------Filter: join_t1.t1_id = Int32(11) OR join_t1.t1_id = Int32(44)
+05)--------TableScan: join_t1 projection=[t1_id]
+06)----SubqueryAlias: t2
+07)------Projection: join_t2.t2_id
+08)--------Filter: join_t2.t2_name != Utf8View("y")
+09)----------TableScan: join_t2 projection=[t2_id, t2_name]
+physical_plan
+01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST, t2_id@1 ASC NULLS LAST]
+02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true]
+03)----PiecewiseMergeJoin: operator=LtEq, join_type=Inner, on=(t1_id <= t2_id)
+04)------SortExec: expr=[t1_id@0 DESC], preserve_partitioning=[false]
+05)--------CoalesceBatchesExec: target_batch_size=8192
+06)----------FilterExec: t1_id@0 = 11 OR t1_id@0 = 44
+07)------------DataSourceExec: partitions=1, partition_sizes=[1]
+08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+09)--------CoalesceBatchesExec: target_batch_size=8192
+10)----------FilterExec: t2_name@1 != y, projection=[t2_id@0]
+11)------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+statement ok
+CREATE TABLE null_join_t1 (id INT);
+
+statement ok
+CREATE TABLE null_join_t2 (id INT);
+
+statement ok
+INSERT INTO null_join_t1 VALUES (1), (2), (NULL);
+
+statement ok
+INSERT INTO null_join_t2 VALUES (1), (NULL), (3);
+
+query II
+SELECT t1.id AS left_id, t2.id AS right_id
+FROM null_join_t1 t1
+JOIN null_join_t2 t2
+  ON t1.id > t2.id
+ORDER BY 1,2;
+----
+2 1
+
+# Verify that this query falls back to Nested Loop Join (the range expression
+# references both sides in a single operand)
+query II
+SELECT t1.id AS left_id, t2.id AS right_id
+FROM null_join_t1 t1
+JOIN null_join_t2 t2
+  ON t1.id < (t1.id + t2.id)
+ORDER BY 1,2;
+----
+1 1
+1 3
+2 1
+2 3
+
+query TT
+EXPLAIN
+SELECT t1.id AS left_id, t2.id AS right_id
+FROM null_join_t1 t1
+JOIN null_join_t2 t2
+  ON t1.id < (t1.id + t2.id)
+ORDER BY 1,2;
+----
+logical_plan
+01)Sort: left_id ASC NULLS LAST, right_id ASC NULLS LAST
+02)--Projection: t1.id AS left_id, t2.id AS right_id
+03)----Inner Join: Filter: t1.id < t1.id + t2.id
+04)------SubqueryAlias: t1
+05)--------TableScan: null_join_t1 projection=[id]
+06)------SubqueryAlias: t2
+07)--------TableScan: null_join_t2 projection=[id]
+physical_plan
+01)SortExec: expr=[left_id@0 ASC NULLS LAST, right_id@1 ASC NULLS LAST], preserve_partitioning=[false]
+02)--ProjectionExec: expr=[id@0 as left_id, id@1 as right_id]
+03)----NestedLoopJoinExec: join_type=Inner, filter=id@0 < id@0 + id@1
+04)------DataSourceExec: partitions=1, partition_sizes=[1]
+05)------DataSourceExec: partitions=1, partition_sizes=[1]
+
+query II
+SELECT t1.id AS left_id, t2.id AS right_id
+FROM null_join_t1 t1
+JOIN null_join_t2 t2
+  ON t1.id < t2.id
+ORDER BY 1,2;
+----
+1 3
+2 3
+
+statement ok
+set datafusion.optimizer.enable_piecewise_merge_join = false;
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index a302cecfa622..7ec1864b4667 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -148,6 +148,7 @@ The following configuration settings are available:
 | datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan |
 | datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys |
 | datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory |
+| datafusion.optimizer.enable_piecewise_merge_join | false | When set to true, piecewise merge join is enabled. PiecewiseMergeJoin is currently experimental. Physical planner will opt for PiecewiseMergeJoin when there is only one range filter. |
 | datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition |
 | datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition |
 | datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). |
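For completeness, a hedged end-to-end sketch of exercising the new option from the Rust API (assumes a DataFusion build that includes this patch; `set_bool`, `new_with_config`, `create_physical_plan`, and `displayable` are existing DataFusion APIs):

```rust
use datafusion::error::Result;
use datafusion::physical_plan::displayable;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // Opt in to the experimental operator; it defaults to false.
    let config = SessionConfig::new()
        .set_bool("datafusion.optimizer.enable_piecewise_merge_join", true);
    let ctx = SessionContext::new_with_config(config);

    ctx.sql("CREATE TABLE t1 (a INT)").await?;
    ctx.sql("CREATE TABLE t2 (b INT)").await?;

    // A single range predicate and no equijoin keys: eligible for PWMJ.
    let df = ctx.sql("SELECT * FROM t1 JOIN t2 ON t1.a < t2.b").await?;
    let plan = df.create_physical_plan().await?;

    // With the flag on, the plan should contain a PiecewiseMergeJoin node.
    println!("{}", displayable(plan.as_ref()).indent(true));
    Ok(())
}
```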