From af4f5e70b0e46ab65eefe9f801550e1c28aa0b0a Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 30 Oct 2025 18:34:37 -0400 Subject: [PATCH 1/5] feat: Add Semi/Anti join to PiecewiseMergeJoin --- datafusion/core/src/physical_planner.rs | 11 +- .../physical-optimizer/src/join_selection.rs | 108 +- .../piecewise_merge_join/classic_join.rs | 49 +- .../src/joins/piecewise_merge_join/exec.rs | 130 +- .../piecewise_merge_join/existence_join.rs | 1061 +++++++++++++++++ .../src/joins/piecewise_merge_join/mod.rs | 1 + .../src/joins/piecewise_merge_join/utils.rs | 20 +- datafusion/sqllogictest/test_files/pwmj.slt | 216 +++- 8 files changed, 1434 insertions(+), 162 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c280b50a9f07..1995ebb9e878 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1256,15 +1256,7 @@ impl DefaultPhysicalPlanner { Arc::new(CrossJoinExec::new(physical_left, physical_right)) } else if num_range_filters == 1 && total_filters == 1 - && !matches!( - join_type, - JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::RightMark - ) + && !matches!(join_type, JoinType::LeftMark | JoinType::RightMark) && session_state .config_options() .optimizer @@ -1366,7 +1358,6 @@ impl DefaultPhysicalPlanner { (on_left, on_right), op, *join_type, - session_state.config().target_partitions(), )?) } else { // there is no equal join condition, use the nested loop join diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index 1db4d7b30565..b13c120c5acb 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -35,7 +35,7 @@ use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, - StreamJoinPartitionMode, SymmetricHashJoinExec, + PiecewiseMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use std::sync::Arc; @@ -256,59 +256,75 @@ fn statistical_join_selection_subrule( collect_threshold_byte_size: usize, collect_threshold_num_rows: usize, ) -> Result>> { - let transformed = - if let Some(hash_join) = plan.as_any().downcast_ref::() { - match hash_join.partition_mode() { - PartitionMode::Auto => try_collect_left( - hash_join, - false, - collect_threshold_byte_size, - collect_threshold_num_rows, - )? + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { + match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left( + hash_join, + false, + collect_threshold_byte_size, + collect_threshold_num_rows, + )? + .map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? .map_or_else( || partitioned_hash_join(hash_join).map(Some), |v| Ok(Some(v)), )?, - PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - hash_join - .swap_inputs(PartitionMode::Partitioned) - .map(Some)? - } else { - None - } + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? + } else { + None } } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - cross_join.swap_inputs().map(Some)? - } else { - None - } - } else if let Some(nl_join) = plan.as_any().downcast_ref::() { - let left = nl_join.left(); - let right = nl_join.right(); - if nl_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - nl_join.swap_inputs().map(Some)? - } else { - None - } + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right)? { + cross_join.swap_inputs().map(Some)? } else { None - }; + } + } else if let Some(nl_join) = plan.as_any().downcast_ref::() { + let left = nl_join.left(); + let right = nl_join.right(); + if nl_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + nl_join.swap_inputs().map(Some)? + } else { + None + } + } else if let Some(pwmj) = plan.as_any().downcast_ref::() { + let left = pwmj.buffered(); + let right = pwmj.streamed(); + if pwmj.join_type().supports_swap() + // Put ! here because should_swap_join_order returns true if left > right but + // PiecewiseMergeJoin wants the left side to be the larger one, so only swap if + // left < right + && (!should_swap_join_order(&**left, &**right)? + || matches!(pwmj.join_type(), JoinType::RightSemi | JoinType::RightAnti)) + && !matches!(pwmj.join_type(), JoinType::LeftSemi | JoinType::LeftAnti) + { + pwmj.swap_inputs().map(Some)? + } else { + None + } + } else { + None + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 646905e0d787..7657951a869e 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -40,7 +40,7 @@ use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; -pub(super) enum PiecewiseMergeJoinStreamState { +pub(super) enum ClassicPWMJStreamState { WaitBufferedSide, FetchStreamBatch, ProcessStreamBatch(SortedStreamBatch), @@ -48,11 +48,11 @@ pub(super) enum PiecewiseMergeJoinStreamState { Completed, } -impl PiecewiseMergeJoinStreamState { +impl ClassicPWMJStreamState { // Grab mutable reference to the current stream batch fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> { match self { - PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), + ClassicPWMJStreamState::ProcessStreamBatch(state) => Ok(state), _ => internal_err!("Expected streamed batch in StreamBatch"), } } @@ -103,7 +103,7 @@ pub(super) struct ClassicPWMJStream { // Buffered side data buffered_side: BufferedSide, // Tracks the state of the `PiecewiseMergeJoin` - state: PiecewiseMergeJoinStreamState, + state: ClassicPWMJStreamState, // Sort option for streamed side (specifies whether // the sort is ascending or descending) sort_option: SortOptions, @@ -119,7 +119,7 @@ impl RecordBatchStream for ClassicPWMJStream { } } -// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, +// `ClassicPWMJStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, // `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`. // // Classic Joins @@ -140,7 +140,7 @@ impl ClassicPWMJStream { operator: Operator, streamed: SendableRecordBatchStream, buffered_side: BufferedSide, - state: PiecewiseMergeJoinStreamState, + state: ClassicPWMJStreamState, sort_option: SortOptions, join_metrics: BuildProbeJoinMetrics, batch_size: usize, @@ -166,19 +166,19 @@ impl ClassicPWMJStream { ) -> Poll>> { loop { return match self.state { - PiecewiseMergeJoinStreamState::WaitBufferedSide => { + ClassicPWMJStreamState::WaitBufferedSide => { handle_state!(ready!(self.collect_buffered_side(cx))) } - PiecewiseMergeJoinStreamState::FetchStreamBatch => { + ClassicPWMJStreamState::FetchStreamBatch => { handle_state!(ready!(self.fetch_stream_batch(cx))) } - PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { + ClassicPWMJStreamState::ProcessStreamBatch(_) => { handle_state!(self.process_stream_batch()) } - PiecewiseMergeJoinStreamState::ProcessUnmatched => { + ClassicPWMJStreamState::ProcessUnmatched => { handle_state!(self.process_unmatched_buffered_batch()) } - PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), + ClassicPWMJStreamState::Completed => Poll::Ready(None), }; } } @@ -197,7 +197,7 @@ impl ClassicPWMJStream { build_timer.done(); // We will start fetching stream batches for classic joins - self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + self.state = ClassicPWMJStreamState::FetchStreamBatch; self.buffered_side = BufferedSide::Ready(BufferedSideReadyState { buffered_data }); @@ -221,9 +221,9 @@ impl ClassicPWMJStream { == 1 { self.batch_process_state.reset(); - self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched; + self.state = ClassicPWMJStreamState::ProcessUnmatched; } else { - self.state = PiecewiseMergeJoinStreamState::Completed; + self.state = ClassicPWMJStreamState::Completed; } } Some(Ok(batch)) => { @@ -247,12 +247,11 @@ impl ClassicPWMJStream { // Reset BatchProcessState before processing a new stream batch self.batch_process_state.reset(); - self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( - SortedStreamBatch { + self.state = + ClassicPWMJStreamState::ProcessStreamBatch(SortedStreamBatch { batch: stream_batch, compare_key_values: vec![stream_values], - }, - ); + }); } Some(Err(err)) => return Poll::Ready(Err(err)), }; @@ -297,13 +296,13 @@ impl ClassicPWMJStream { .output_batches .next_completed_batch() { - self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + self.state = ClassicPWMJStreamState::FetchStreamBatch; return Ok(StatefulStreamResult::Ready(Some(b))); } // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. if self.batch_process_state.output_batches.is_empty() { - self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + self.state = ClassicPWMJStreamState::FetchStreamBatch; return Ok(StatefulStreamResult::Ready(Some(batch))); } @@ -318,7 +317,7 @@ impl ClassicPWMJStream { ) -> Result>> { // Return early for `JoinType::Right` and `JoinType::Inner` if matches!(self.join_type, JoinType::Right | JoinType::Inner) { - self.state = PiecewiseMergeJoinStreamState::Completed; + self.state = ClassicPWMJStreamState::Completed; return Ok(StatefulStreamResult::Ready(None)); } @@ -339,7 +338,7 @@ impl ClassicPWMJStream { .output_batches .next_completed_batch() { - self.state = PiecewiseMergeJoinStreamState::Completed; + self.state = ClassicPWMJStreamState::Completed; return Ok(StatefulStreamResult::Ready(Some(batch))); } } @@ -387,11 +386,11 @@ impl ClassicPWMJStream { .output_batches .next_completed_batch() { - self.state = PiecewiseMergeJoinStreamState::Completed; + self.state = ClassicPWMJStreamState::Completed; return Ok(StatefulStreamResult::Ready(Some(batch))); } - self.state = PiecewiseMergeJoinStreamState::Completed; + self.state = ClassicPWMJStreamState::Completed; self.batch_process_state.reset(); Ok(StatefulStreamResult::Ready(None)) } @@ -743,7 +742,7 @@ mod tests { operator: Operator, join_type: JoinType, ) -> Result { - PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1) + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) } async fn join_collect( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 987f3e9df45a..0aa2cbf20b7b 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -22,7 +22,7 @@ use arrow::{ util::bit_util, }; use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::not_impl_err; +use datafusion_common::ScalarValue; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, @@ -44,12 +44,17 @@ use std::sync::Arc; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::piecewise_merge_join::classic_join::{ - ClassicPWMJStream, PiecewiseMergeJoinStreamState, + ClassicPWMJStream, ClassicPWMJStreamState, +}; +use crate::joins::piecewise_merge_join::existence_join::{ + ExistencePWMJStream, ExistencePWMJStreamState, }; use crate::joins::piecewise_merge_join::utils::{ - build_visited_indices_map, is_existence_join, is_right_existence_join, + build_visited_indices_map, is_existence_join, +}; +use crate::joins::utils::{ + asymmetric_join_output_partitioning, reorder_output_after_swap, }; -use crate::joins::utils::asymmetric_join_output_partitioning; use crate::{ joins::{ utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut}, @@ -271,17 +276,20 @@ pub struct PiecewiseMergeJoinExec { /// /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations left_child_plan_required_order: LexOrdering, - /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations - /// Unsorted for mark joins - #[allow(unused)] - right_batch_required_orders: LexOrdering, /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. sort_options: SortOptions, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, - /// Number of partitions to process - num_partitions: usize, + + /// Both buffered and streamed partitions are tracked so if there is swapping, then the correct + /// number of partitions are tracked. + /// + /// Number of output partitions for buffered side + #[allow(unused)] + buffered_partitions: usize, + /// Number of output partitions for streamed side + streamed_partitions: usize, } impl PiecewiseMergeJoinExec { @@ -291,14 +299,9 @@ impl PiecewiseMergeJoinExec { on: (Arc, Arc), operator: Operator, join_type: JoinType, - num_partitions: usize, ) -> Result { - // TODO: Implement existence joins for PiecewiseMergeJoin - if is_existence_join(join_type) { - return not_impl_err!( - "Existence Joins are currently not supported for PiecewiseMergeJoin" - ); - } + let buffered_partitions = buffered.output_partitioning().partition_count(); + let streamed_partitions = streamed.output_partitioning().partition_count(); // Take the operator and enforce a sort order on the streamed + buffered side based on // the operator type. @@ -306,19 +309,9 @@ impl PiecewiseMergeJoinExec { Operator::Lt | Operator::LtEq => { // For left existence joins the inputs will be swapped so the sort // options are switched - if is_right_existence_join(join_type) { - SortOptions::new(false, true) - } else { - SortOptions::new(true, true) - } - } - Operator::Gt | Operator::GtEq => { - if is_right_existence_join(join_type) { - SortOptions::new(true, true) - } else { - SortOptions::new(false, true) - } + SortOptions::new(true, true) } + Operator::Gt | Operator::GtEq => SortOptions::new(false, true), _ => { return internal_err!( "Cannot contain non-range operator in PiecewiseMergeJoinExec" @@ -329,8 +322,6 @@ impl PiecewiseMergeJoinExec { // Give the same `sort_option for comparison later` let left_child_plan_required_order = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; - let right_batch_required_orders = - vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; let Some(left_child_plan_required_order) = LexOrdering::new(left_child_plan_required_order) @@ -339,13 +330,6 @@ impl PiecewiseMergeJoinExec { "PiecewiseMergeJoinExec requires valid sort expressions for its left side" ); }; - let Some(right_batch_required_orders) = - LexOrdering::new(right_batch_required_orders) - else { - return internal_err!( - "PiecewiseMergeJoinExec requires valid sort expressions for its right side" - ); - }; let buffered_schema = buffered.schema(); let streamed_schema = streamed.schema(); @@ -371,10 +355,10 @@ impl PiecewiseMergeJoinExec { buffered_fut: Default::default(), metrics: ExecutionPlanMetricsSet::new(), left_child_plan_required_order, - right_batch_required_orders, sort_options, cache, - num_partitions, + buffered_partitions, + streamed_partitions, }) } @@ -462,9 +446,31 @@ impl PiecewiseMergeJoinExec { } } - // TODO pub fn swap_inputs(&self) -> Result> { - todo!() + let left = self.buffered(); + let right = self.streamed(); + + let new_join = PiecewiseMergeJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + (Arc::clone(&self.on.1), Arc::clone(&self.on.0)), + self.operator.swap().unwrap(), + self.join_type.swap(), + )?; + + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } } } @@ -493,18 +499,13 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } fn required_input_ordering(&self) -> Vec> { - // Existence joins don't need to be sorted on one side. - if is_right_existence_join(self.join_type) { - unimplemented!() - } else { - // Sort the right side in memory, so we do not need to enforce any sorting - vec![ - Some(OrderingRequirements::from( - self.left_child_plan_required_order.clone(), - )), - None, - ] - } + // Sort the right side in memory, so we do not need to enforce any sorting + vec![ + Some(OrderingRequirements::from( + self.left_child_plan_required_order.clone(), + )), + None, + ] } fn with_new_children( @@ -518,7 +519,6 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.on.clone(), self.operator, self.join_type, - self.num_partitions, )?)), _ => internal_err!( "PiecewiseMergeJoin should have 2 children, found {}", @@ -547,17 +547,25 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { metrics.clone(), reservation, build_visited_indices_map(self.join_type), - self.num_partitions, + self.streamed_partitions, )) })?; let streamed = self.streamed.execute(partition, Arc::clone(&context))?; - let batch_size = context.session_config().batch_size(); - // TODO: Add existence joins + this is guarded at physical planner if is_existence_join(self.join_type()) { - unreachable!() + Ok(Box::pin(ExistencePWMJStream::try_new( + Arc::clone(&self.schema), + on_streamed, + self.join_type, + self.operator, + streamed, + BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), + ExistencePWMJStreamState::CollectBufferedSide, + metrics, + batch_size, + ))) } else { Ok(Box::pin(ClassicPWMJStream::try_new( Arc::clone(&self.schema), @@ -566,7 +574,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self.operator, streamed, BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), - PiecewiseMergeJoinStreamState::WaitBufferedSide, + ClassicPWMJStreamState::WaitBufferedSide, self.sort_options, metrics, batch_size, @@ -674,6 +682,7 @@ pub(super) struct BufferedSideData { values: ArrayRef, pub(super) visited_indices_bitmap: SharedBitmapBuilder, pub(super) remaining_partitions: AtomicUsize, + pub(super) min_max_value: Arc>>, _reservation: MemoryReservation, } @@ -690,6 +699,7 @@ impl BufferedSideData { values, visited_indices_bitmap, remaining_partitions: AtomicUsize::new(remaining_partitions), + min_max_value: Arc::new(Mutex::new(None)), _reservation: reservation, } } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs new file mode 100644 index 000000000000..e891ee299ae5 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs @@ -0,0 +1,1061 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementation for PiecewiseMergeJoin's Existence Join (Left, Right, Full, Inner) + +use arrow::array::Array; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::compute::BatchCoalescer; +use arrow_schema::{Schema, SchemaRef, SortOptions}; +use datafusion_common::NullEquality; +use datafusion_common::{internal_err, Result}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion_expr::Accumulator; +use datafusion_expr::{JoinType, Operator}; +use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_expr::PhysicalExprRef; +use futures::{Stream, StreamExt}; +use std::{cmp::Ordering, task::ready}; +use std::{sync::Arc, task::Poll}; + +use crate::handle_state; +use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; +use crate::joins::utils::compare_join_arrays; +use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; + +pub(super) enum ExistencePWMJStreamState { + CollectBufferedSide, + FetchAndProcessStreamBatch, + ProcessMatched, + Completed, +} + +pub(crate) struct ExistencePWMJStream { + // Output schema of the `PiecewiseMergeJoin` + pub schema: Arc, + + // Physical expression that is evaluated on the streamed side + // We do not need on_buffered as this is already evaluated when + // creating the buffered side which happens before initializing + // `PiecewiseMergeJoinStream` + pub on_streamed: PhysicalExprRef, + // Type of join + pub join_type: JoinType, + // Comparison operator + pub operator: Operator, + // Streamed batch + pub streamed: SendableRecordBatchStream, + // Buffered side data + buffered_side: BufferedSide, + // Tracks the state of the `PiecewiseMergeJoin` + state: ExistencePWMJStreamState, + // Metrics for build + probe joins + join_metrics: BuildProbeJoinMetrics, + // Tracking incremental state for emitting record batches + batch_process_state: BatchProcessState, +} + +impl RecordBatchStream for ExistencePWMJStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +// `ExistencePWMJStreamState` is separated into `CollectBufferedSide`, `FetchAndProcessStreamBatch`, +// `ProcessMatched`, and `Completed`. +// +// Classic Joins +// 1. `CollectBufferedSide` - Load in the buffered side data into memory. +// 2. `FetchAndProcessStreamBatch` - Finds the min/max value in the stream batch and keep a running +// min/max value across all stream batches. +// 3. `ProcessMatched` - Compare the min/max value against the buffered side data to determine +// which rows match the join condition and produce output batches. +// 4. `Completed` - All data has been processed. +impl ExistencePWMJStream { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + schema: Arc, + on_streamed: PhysicalExprRef, + join_type: JoinType, + operator: Operator, + streamed: SendableRecordBatchStream, + buffered_side: BufferedSide, + state: ExistencePWMJStreamState, + join_metrics: BuildProbeJoinMetrics, + batch_size: usize, + ) -> Self { + Self { + schema: Arc::clone(&schema), + on_streamed, + join_type, + operator, + streamed, + buffered_side, + state, + join_metrics, + batch_process_state: BatchProcessState::new(schema, batch_size), + } + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + ExistencePWMJStreamState::CollectBufferedSide => { + handle_state!(ready!(self.collect_buffered_side(cx))) + } + ExistencePWMJStreamState::FetchAndProcessStreamBatch => { + handle_state!(ready!(self.fetch_and_process_stream_batch(cx))) + } + ExistencePWMJStreamState::ProcessMatched => { + handle_state!(self.process_matched_batch()) + } + ExistencePWMJStreamState::Completed => Poll::Ready(None), + }; + } + } + + // Collects buffered side data + fn collect_buffered_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); + let buffered_data = ready!(self + .buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx))?; + build_timer.done(); + + // Start fetching stream batches for classic joins + self.state = ExistencePWMJStreamState::FetchAndProcessStreamBatch; + + self.buffered_side = + BufferedSide::Ready(BufferedSideReadyState { buffered_data }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + fn fetch_and_process_stream_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + // Load RecordBatches from the streamed side + match ready!(self.streamed.poll_next_unpin(cx)) { + None => { + if self + .buffered_side + .try_as_ready_mut()? + .buffered_data + .remaining_partitions + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + self.state = ExistencePWMJStreamState::ProcessMatched; + } else { + self.state = ExistencePWMJStreamState::Completed; + } + } + Some(Ok(batch)) => { + // Evaluate the streamed physical expression on the stream batch + let stream_values: ArrayRef = self + .on_streamed + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + let buffered_data = + self.buffered_side.try_as_ready_mut()?.buffered_data.clone(); + + let min_max_val = + if matches!(self.operator, Operator::Lt | Operator::LtEq) { + let mut max_accumulator = + MaxAccumulator::try_new(stream_values.data_type())?; + max_accumulator.update_batch(&[stream_values])?; + max_accumulator.evaluate()? + } else { + let mut min_accumulator = + MinAccumulator::try_new(stream_values.data_type())?; + min_accumulator.update_batch(&[stream_values])?; + min_accumulator.evaluate()? + }; + + let mut min_max = buffered_data.min_max_value.lock(); + match &mut *min_max { + None => { + if !min_max_val.is_null() { + *min_max = Some(min_max_val) + } + } + Some(cur) => { + if !min_max_val.is_null() { + let matches = if matches!(self.operator, Operator::Lt) { + min_max_val > *cur + } else if matches!(self.operator, Operator::LtEq) { + min_max_val >= *cur + } else if matches!(self.operator, Operator::Gt) { + min_max_val < *cur + } else { + min_max_val <= *cur + }; + + if matches { + *cur = min_max_val; + } + } + } + } + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + fn process_matched_batch( + &mut self, + ) -> Result>> { + // If batches were already returned, keep processing the coalescer until empty + if self.batch_process_state.finished_processing_batches { + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } else { + self.state = ExistencePWMJStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + } + + let min_max = self + .buffered_side + .try_as_ready_mut()? + .buffered_data + .min_max_value + .as_ref() + .lock() + .clone(); + + // If no min/max value was found in the streamed side, then + // for anti joins, return the whole buffered side as output + if min_max.is_none() { + let buffered_values = + self.buffered_side.try_as_ready()?.buffered_data.values(); + + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) { + let whole_buffer_batch = self + .buffered_side + .try_as_ready()? + .buffered_data + .batch() + .slice(0, buffered_values.len()); + self.state = ExistencePWMJStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(RecordBatch::try_new( + Arc::clone(&self.schema), + whole_buffer_batch.columns().to_vec(), + )?))); + } else { + self.state = ExistencePWMJStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(RecordBatch::new_empty( + Arc::clone(&self.schema), + )))); + } + } + + // Convert the min_max value to array for comparison + let min_max_array = min_max.unwrap().to_array()?; + let buffered_values = self.buffered_side.try_as_ready()?.buffered_data.values(); + + // Start the pointer after the null values + let mut buffer_idx = buffered_values.null_count(); + + while buffer_idx < buffered_values.len() { + let compare = { + compare_join_arrays( + &[Arc::clone(buffered_values)], + buffer_idx, + // Always index into the single value which is the min/max value + &[Arc::clone(&min_max_array)], + 0, + &[SortOptions::default()], + NullEquality::NullEqualsNothing, + )? + }; + + let matched = match self.operator { + Operator::Gt => { + matches!(compare, Ordering::Greater) + } + Operator::GtEq => { + matches!(compare, Ordering::Greater | Ordering::Equal) + } + Operator::Lt => { + matches!(compare, Ordering::Less) + } + Operator::LtEq => { + matches!(compare, Ordering::Less | Ordering::Equal) + } + _ => { + return internal_err!( + "PiecewiseMergeJoin should not contain operator, {}", + self.operator + ); + } + }; + + if matched { + break; + } + + buffer_idx += 1; + } + + // Determine the start index and length of the new buffered batch + // For anti joins, we include all rows before the matched index + let start_buffer_idx = + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) { + 0 + } else { + buffer_idx + }; + + let buffer_length = + if matches!(self.join_type, JoinType::LeftAnti | JoinType::RightAnti) { + buffer_idx + } else { + buffered_values.len() - buffer_idx + }; + + let new_buffered_batch = self + .buffered_side + .try_as_ready()? + .buffered_data + .batch() + .slice(start_buffer_idx, buffer_length); + + let buffered_columns = new_buffered_batch.columns().to_vec(); + let batch = RecordBatch::try_new(Arc::clone(&self.schema), buffered_columns)?; + + if buffer_length == 0 { + self.state = ExistencePWMJStreamState::Completed; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.batch_process_state.output_batches.push_batch(batch)?; + + self.batch_process_state + .output_batches + .finish_buffered_batch()?; + + if let Some(batch) = self + .batch_process_state + .output_batches + .next_completed_batch() + { + self.batch_process_state.finished_processing_batches = true; + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + self.state = ExistencePWMJStreamState::Completed; + Ok(StatefulStreamResult::Continue) + } +} + +impl Stream for ExistencePWMJStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +struct BatchProcessState { + // Coalescer for output batches + output_batches: Box, + // Flag to indicate if we have finished processing batches + finished_processing_batches: bool, +} + +impl BatchProcessState { + pub fn new(join_schema: Arc, batch_size: usize) -> Self { + Self { + output_batches: Box::new(BatchCoalescer::new(join_schema, batch_size)), + finished_processing_batches: false, + } + } +} + +// Tests for Exitence Joins can only properly handle +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + common, + joins::PiecewiseMergeJoinExec, + test::{build_table_i32, TestMemoryExec}, + ExecutionPlan, + }; + use arrow::array::{Date32Array, Int32Array}; + use arrow_schema::{DataType, Field}; + use datafusion_common::test_util::batches_to_string; + use datafusion_execution::TaskContext; + use datafusion_expr::JoinType; + use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use insta::assert_snapshot; + use std::sync::Arc; + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn join( + left: Arc, + right: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options(left, right, on, operator, join_type).await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = join(left, right, on, operator, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_left_mark_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 3 | 7 | + | 2 | 2 | 8 | + | 3 | 1 | 9 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 3 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 3 | 3 | 9 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 3 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + | 2 | 2 | 8 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_less_than_equal() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 2, 1]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 3 | 7 | + | 2 | 2 | 8 | + | 3 | 1 | 9 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_greater_than_equal() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 2 | 8 | + // | 3 | 3 | 9 | + // +----+----+----+ + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_date32_greater_than() -> Result<()> { + // +------------+------------+------------+ + // | a1 | b1 | c1 | + // +------------+------------+------------+ + // | 1970-01-11 | 1970-01-11 | 1970-01-01 | + // | 1970-01-21 | 1970-01-21 | 1970-01-01 | + // | 1970-01-31 | 1970-01-31 | 1970-01-01 | + // +------------+------------+------------+ + let left = build_date_table( + ("a1", &vec![10, 20, 30]), + ("b1", &vec![10, 20, 30]), + ("c1", &vec![0, 0, 0]), + ); + + // +------------+------------+------------+ + // | a2 | b1 | c2 | + // +------------+------------+------------+ + // | 1970-04-11 | 1970-01-16 | 1970-01-01 | + // | 1970-07-19 | 1970-01-26 | 1970-01-01 | + // | 1970-10-27 | 1970-02-05 | 1970-01-01 | + // +------------+------------+------------+ + let right = build_date_table( + ("a2", &vec![100, 200, 300]), + ("b1", &vec![15, 25, 35]), + ("c2", &vec![0, 0, 0]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+ + | a1 | b1 | c1 | + +------------+------------+------------+ + | 1970-01-21 | 1970-01-21 | 1970-01-01 | + | 1970-01-31 | 1970-01-31 | 1970-01-01 | + +------------+------------+------------+ + "#); + + Ok(()) + } + + // TESTING NULL CASES + fn build_table_i32_nullable( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())) as _, + Arc::new(Int32Array::from(b.1.clone())) as _, + Arc::new(Int32Array::from(c.1.clone())) as _, + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + #[tokio::test] + async fn join_left_semi_greater_than_with_nulls() -> Result<()> { + // +----+------+----+ + // | a1 | b1 | c1 | + // +----+------+----+ + // | 1 | NULL | 7 | + // | 2 | 1 | 8 | + // | 3 | 2 | 9 | + // | 4 | 3 | 10 | + // +----+------+----+ + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("b1", &vec![None, Some(1), Some(2), Some(3)]), + ("c1", &vec![Some(7), Some(8), Some(9), Some(10)]), + ); + + // +----+------+-----+ + // | a2 | b1 | c2 | + // +----+------+-----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | NULL | 90 | + // | 40 | 4 | 100 | + // +----+------+-----+ + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20), Some(30), Some(40)]), + ("b1", &vec![Some(2), Some(3), None, Some(4)]), + ("c2", &vec![Some(70), Some(80), Some(90), Some(100)]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 4 | 3 | 10 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_greater_than_with_nulls() -> Result<()> { + // +----+------+----+ + // | a1 | b1 | c1 | + // +----+------+----+ + // | 1 | NULL | 7 | + // | 2 | 1 | 8 | + // | 3 | 2 | 9 | + // | 4 | 3 | 10 | + // +----+------+----+ + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("b1", &vec![None, Some(1), Some(2), Some(3)]), + ("c1", &vec![Some(7), Some(8), Some(9), Some(10)]), + ); + + // +----+------+-----+ + // | a2 | b1 | c2 | + // +----+------+-----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | NULL | 90 | + // | 40 | 4 | 100 | + // +----+------+-----+ + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20), Some(30), Some(40)]), + ("b1", &vec![Some(2), Some(3), None, Some(4)]), + ("c2", &vec![Some(70), Some(80), Some(90), Some(100)]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | | 7 | + | 2 | 1 | 8 | + | 3 | 2 | 9 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_less_than_equal_with_nulls() -> Result<()> { + // +----+------+----+ + // | a1 | b1 | c1 | + // +----+------+----+ + // | 1 | NULL | 7 | + // | 2 | 3 | 8 | + // | 3 | 2 | 9 | + // | 4 | 1 | 10 | + // +----+------+----+ + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("b1", &vec![None, Some(3), Some(2), Some(1)]), + ("c1", &vec![Some(7), Some(8), Some(9), Some(10)]), + ); + + // +----+------+----+ + // | a2 | b1 | c2 | + // +----+------+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | NULL | 90 | + // +----+------+----+ + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20), Some(30)]), + ("b1", &vec![Some(2), Some(3), None]), + ("c2", &vec![Some(70), Some(80), Some(90)]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 2 | 3 | 8 | + | 3 | 2 | 9 | + | 4 | 1 | 10 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_with_all_right_nulls() -> Result<()> { + // +----+------+----+ + // | a1 | b1 | c1 | + // +----+------+----+ + // | 1 | 10 | 7 | + // | 2 | NULL | 8 | + // +----+------+----+ + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2)]), + ("b1", &vec![Some(10), None]), + ("c1", &vec![Some(7), Some(8)]), + ); + + // +----+------+----+ + // | a2 | b1 | c2 | + // +----+------+----+ + // | 10 | NULL | 70 | + // | 20 | NULL | 80 | + // +----+------+----+ + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20)]), + ("b1", &vec![None, None]), + ("c2", &vec![Some(70), Some(80)]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 10 | 7 | + | 2 | | 8 | + +----+----+----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_with_all_right_nulls() -> Result<()> { + // +----+------+----+ + // | a1 | b1 | c1 | + // +----+------+----+ + // | 1 | 10 | 7 | + // | 2 | NULL | 8 | + // +----+------+----+ + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(2)]), + ("b1", &vec![Some(10), None]), + ("c1", &vec![Some(7), Some(8)]), + ); + + // +----+------+----+ + // | a2 | b1 | c2 | + // +----+------+----+ + // | 10 | NULL | 70 | + // | 20 | NULL | 80 | + // +----+------+----+ + let right = build_table_i32_nullable( + ("a2", &vec![Some(10), Some(20)]), + ("b1", &vec![None, None]), + ("c2", &vec![Some(70), Some(80)]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 10 | 7 | + | 2 | | 8 | + +----+----+----+ + "#); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs index c85a7cc16f65..8c6815ad6c63 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/mod.rs @@ -21,4 +21,5 @@ pub use exec::PiecewiseMergeJoinExec; mod classic_join; mod exec; +mod existence_join; mod utils; diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs index 5bbb496322b5..3c70274441d6 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/utils.rs @@ -17,14 +17,6 @@ use datafusion_expr::JoinType; -// Returns boolean for whether the join is a right existence join -pub(super) fn is_right_existence_join(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark - ) -} - // Returns boolean for whether the join is an existence join pub(super) fn is_existence_join(join_type: JoinType) -> bool { matches!( @@ -47,15 +39,5 @@ pub(super) fn need_produce_result_in_final(join_type: JoinType) -> bool { // Returns boolean for whether or not we need to build the buffered side // bitmap for marking matched rows on the buffered side. pub(super) fn build_visited_indices_map(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Full - | JoinType::Left - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftMark - | JoinType::RightMark - ) + matches!(join_type, JoinType::Full | JoinType::Left) } diff --git a/datafusion/sqllogictest/test_files/pwmj.slt b/datafusion/sqllogictest/test_files/pwmj.slt index eafa4d0ba394..f07ab1292e54 100644 --- a/datafusion/sqllogictest/test_files/pwmj.slt +++ b/datafusion/sqllogictest/test_files/pwmj.slt @@ -58,9 +58,9 @@ WHERE t1.t1_id > 10 AND t2.t2_int > 1 ORDER BY 1; ---- -22 11 z 3 -33 11 z 3 44 11 z 3 +33 11 z 3 +22 11 z 3 query TT EXPLAIN @@ -350,5 +350,217 @@ ORDER BY 1,2; 1 3 2 3 +query I +SELECT t1.t1_id +FROM join_t1 t1 +LEFT SEMI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t2.t2_int > 1 +WHERE t1.t1_id > 10 +ORDER BY 1; +---- +22 +33 +44 + +query TT +EXPLAIN +SELECT t1.t1_id +FROM join_t1 t1 +LEFT SEMI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t2.t2_int > 1 +WHERE t1.t1_id > 10 +ORDER BY 1; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST +02)--LeftSemi Join: Filter: t1.t1_id > t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id > Int32(10) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int > Int32(1) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Gt, join_type=LeftSemi, on=(t1_id > t2_id) +04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t1_id@0 > 10 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +11)------------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t1.t1_id +FROM join_t1 t1 +LEFT ANTI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t2.t2_int > 1 +WHERE t1.t1_id > 10 +ORDER BY 1; +---- +11 + +query TT +EXPLAIN +SELECT t1.t1_id +FROM join_t1 t1 +LEFT ANTI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t2.t2_int > 1 +WHERE t1.t1_id > 10 +ORDER BY 1; +---- +logical_plan +01)Sort: t1.t1_id ASC NULLS LAST +02)--LeftAnti Join: Filter: t1.t1_id > t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id > Int32(10) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int > Int32(1) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Gt, join_type=LeftAnti, on=(t1_id > t2_id) +04)------SortExec: expr=[t1_id@0 ASC], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t1_id@0 > 10 +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +11)------------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.t2_id +FROM join_t1 t1 +RIGHT SEMI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +11 + +query TT +EXPLAIN +SELECT t2.t2_id +FROM join_t1 t1 +RIGHT SEMI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +logical_plan +01)Sort: t2.t2_id ASC NULLS LAST +02)--RightSemi Join: Filter: t1.t1_id > t2.t2_id +03)----SubqueryAlias: t1 +04)------Filter: join_t1.t1_id > Int32(10) +05)--------TableScan: join_t1 projection=[t1_id] +06)----SubqueryAlias: t2 +07)------Projection: join_t2.t2_id +08)--------Filter: join_t2.t2_int > Int32(1) +09)----------TableScan: join_t2 projection=[t2_id, t2_int] +physical_plan +01)SortPreservingMergeExec: [t2_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t2_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----PiecewiseMergeJoin: operator=Lt, join_type=LeftSemi, on=(t2_id < t1_id) +04)------SortExec: expr=[t2_id@0 DESC], preserve_partitioning=[false] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] +08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------CoalesceBatchesExec: target_batch_size=8192 +10)----------FilterExec: t1_id@0 > 10 +11)------------DataSourceExec: partitions=1, partition_sizes=[1] + +query I +SELECT t2.t2_id +FROM join_t1 t1 +RIGHT ANTI JOIN join_t2 t2 + ON t1.t1_id > t2.t2_id + AND t1.t1_id > 10 + AND t2.t2_int > 1 +ORDER BY 1; +---- +22 +44 +55 + +query I +SELECT t1.id +FROM null_join_t1 t1 +LEFT SEMI JOIN null_join_t2 t2 + ON t1.id < t2.id +ORDER BY 1; +---- +1 +2 + +query I +SELECT t1.id +FROM null_join_t1 t1 +LEFT ANTI JOIN null_join_t2 t2 + ON t1.id < t2.id +ORDER BY 1; +---- +NULL + +query I +SELECT t2.id +FROM null_join_t1 t1 +RIGHT SEMI JOIN null_join_t2 t2 + ON t1.id < t2.id +WHERE t2.id IS NOT NULL +ORDER BY 1; +---- +3 + +query I +SELECT t2.id +FROM null_join_t1 t1 +RIGHT ANTI JOIN null_join_t2 t2 + ON t1.id < t2.id +WHERE t2.id IS NOT NULL +ORDER BY 1; +---- +1 + +query TT +EXPLAIN +SELECT t2.id +FROM null_join_t1 t1 +RIGHT ANTI JOIN null_join_t2 t2 + ON t1.id < t2.id +WHERE t2.id IS NOT NULL +ORDER BY 1; +---- +logical_plan +01)Sort: t2.id ASC NULLS LAST +02)--RightAnti Join: Filter: t1.id < t2.id +03)----SubqueryAlias: t1 +04)------TableScan: null_join_t1 projection=[id] +05)----SubqueryAlias: t2 +06)------Filter: null_join_t2.id IS NOT NULL +07)--------TableScan: null_join_t2 projection=[id] +physical_plan +01)SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--PiecewiseMergeJoin: operator=Gt, join_type=LeftAnti, on=(id > id) +03)----SortExec: expr=[id@0 ASC], preserve_partitioning=[false] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------FilterExec: id@0 IS NOT NULL +06)----------DataSourceExec: partitions=1, partition_sizes=[1] +07)----DataSourceExec: partitions=1, partition_sizes=[1] + statement ok set datafusion.optimizer.enable_piecewise_merge_join = false; From 6861080106d266252416b5db65229b934379389a Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 30 Oct 2025 18:47:20 -0400 Subject: [PATCH 2/5] Clarify swapping --- .../physical-plan/src/joins/piecewise_merge_join/exec.rs | 4 ++++ .../src/joins/piecewise_merge_join/existence_join.rs | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index 0aa2cbf20b7b..d538c777dadb 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -446,6 +446,10 @@ impl PiecewiseMergeJoinExec { } } + // Inner, Outer, Left, Right joins are swapped so that left side is larger than right + // Left Semi/Anti joins are not swapped while Right Semi/Anti joins are always swapped. This + // is so that the algorithm of finding the streamed side min/max value, then probing the buffered + // side works properly. pub fn swap_inputs(&self) -> Result> { let left = self.buffered(); let right = self.streamed(); diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs index e891ee299ae5..7f6faa1fd2d0 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs @@ -409,7 +409,8 @@ impl BatchProcessState { } } -// Tests for Exitence Joins can only properly handle +// Tests for Existence Joins can only properly handle Left Semi/Anti joins because +// Right Semi/Anti are swapped #[cfg(test)] mod tests { use super::*; From 8bd2a1bb0f9f46e19a736945ed2900ba6b42a526 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 30 Oct 2025 18:53:30 -0400 Subject: [PATCH 3/5] clippy --- .../src/joins/piecewise_merge_join/existence_join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs index 7f6faa1fd2d0..55f2f1c0ebb4 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs @@ -184,7 +184,7 @@ impl ExistencePWMJStream { self.join_metrics.input_rows.add(batch.num_rows()); let buffered_data = - self.buffered_side.try_as_ready_mut()?.buffered_data.clone(); + Arc::clone(&self.buffered_side.try_as_ready_mut()?.buffered_data); let min_max_val = if matches!(self.operator, Operator::Lt | Operator::LtEq) { @@ -409,7 +409,7 @@ impl BatchProcessState { } } -// Tests for Existence Joins can only properly handle Left Semi/Anti joins because +// Tests for Existence Joins can only properly handle Left Semi/Anti joins because // Right Semi/Anti are swapped #[cfg(test)] mod tests { From 9055043443aa416a515199d224b4b1af2f2e44b2 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Thu, 30 Oct 2025 20:29:51 -0400 Subject: [PATCH 4/5] Update datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs --- .../physical-plan/src/joins/piecewise_merge_join/exec.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index d538c777dadb..8ecdb3597238 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -503,7 +503,8 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } fn required_input_ordering(&self) -> Vec> { - // Sort the right side in memory, so we do not need to enforce any sorting + // The left side is sorted for classic and existence joins. + // For classic joins, the right side is sorted in memory so there is no need to sort vec![ Some(OrderingRequirements::from( self.left_child_plan_required_order.clone(), From 89a8b2b5f5f426e32067d9d5eba5a21a37bcccd7 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 30 Oct 2025 20:50:11 -0400 Subject: [PATCH 5/5] fix test --- .../src/joins/piecewise_merge_join/existence_join.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs index 55f2f1c0ebb4..ffa3ad8868d4 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/existence_join.rs @@ -1052,8 +1052,6 @@ mod tests { +----+----+----+ | a1 | b1 | c1 | +----+----+----+ - | 1 | 10 | 7 | - | 2 | | 8 | +----+----+----+ "#);