Skip to content

Commit 176f1af

Browse files
committed
[test] unit test for join optim
1 parent 77a711a commit 176f1af

File tree

2 files changed

+180
-5
lines changed

2 files changed

+180
-5
lines changed

datafusion/src/physical_optimizer/hash_build_probe_order.rs

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ use arrow::datatypes::Schema;
2323
use crate::execution::context::ExecutionConfig;
2424
use crate::logical_plan::JoinType;
2525
use crate::physical_plan::cross_join::CrossJoinExec;
26+
use crate::physical_plan::expressions::Column;
2627
use crate::physical_plan::hash_join::HashJoinExec;
2728
use crate::physical_plan::projection::ProjectionExec;
28-
use crate::physical_plan::{expressions, ExecutionPlan, PhysicalExpr};
29+
use crate::physical_plan::{ExecutionPlan, PhysicalExpr};
2930

3031
use super::optimizer::PhysicalOptimizerRule;
3132
use super::utils::optimize_children;
@@ -84,15 +85,14 @@ fn swap_reverting_projection(
8485
) -> Vec<(Arc<dyn PhysicalExpr>, String)> {
8586
let right_cols = right_schema.fields().iter().enumerate().map(|(i, f)| {
8687
(
87-
Arc::new(expressions::Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
88+
Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>,
8889
f.name().to_owned(),
8990
)
9091
});
9192
let right_len = right_cols.len();
9293
let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| {
9394
(
94-
Arc::new(expressions::Column::new(f.name(), right_len + i))
95-
as Arc<dyn PhysicalExpr>,
95+
Arc::new(Column::new(f.name(), right_len + i)) as Arc<dyn PhysicalExpr>,
9696
f.name().to_owned(),
9797
)
9898
});
@@ -153,11 +153,107 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
153153

154154
#[cfg(test)]
155155
mod tests {
156+
use crate::{
157+
physical_plan::{hash_join::PartitionMode, Statistics},
158+
test::exec::StatisticsExec,
159+
};
160+
156161
use super::*;
157162
use std::sync::Arc;
158163

159164
use arrow::datatypes::{DataType, Field, Schema};
160165

166+
fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
167+
let big = Arc::new(StatisticsExec::new(
168+
Statistics {
169+
num_rows: Some(100000),
170+
..Default::default()
171+
},
172+
Schema::new(vec![Field::new("big_col", DataType::Int32, false)]),
173+
));
174+
175+
let small = Arc::new(StatisticsExec::new(
176+
Statistics {
177+
num_rows: Some(10),
178+
..Default::default()
179+
},
180+
Schema::new(vec![Field::new("small_col", DataType::Int32, false)]),
181+
));
182+
(big, small)
183+
}
184+
185+
#[tokio::test]
186+
async fn test_join_with_swap() {
187+
let (big, small) = create_big_and_small();
188+
189+
let join = HashJoinExec::try_new(
190+
Arc::clone(&big),
191+
Arc::clone(&small),
192+
vec![(
193+
Column::new_with_schema("big_col", &big.schema()).unwrap(),
194+
Column::new_with_schema("small_col", &small.schema()).unwrap(),
195+
)],
196+
&JoinType::Left,
197+
PartitionMode::CollectLeft,
198+
)
199+
.unwrap();
200+
201+
let optimized_join = HashBuildProbeOrder::new()
202+
.optimize(Arc::new(join), &ExecutionConfig::new())
203+
.unwrap();
204+
205+
let swapping_projection = optimized_join
206+
.as_any()
207+
.downcast_ref::<ProjectionExec>()
208+
.expect("A proj is required to swap columns back to their original order");
209+
210+
assert_eq!(swapping_projection.expr().len(), 2);
211+
let (col, name) = &swapping_projection.expr()[0];
212+
assert_eq!(name, "big_col");
213+
assert_col_expr(col, "big_col", 1);
214+
let (col, name) = &swapping_projection.expr()[1];
215+
assert_eq!(name, "small_col");
216+
assert_col_expr(col, "small_col", 0);
217+
218+
let swapped_join = swapping_projection
219+
.input()
220+
.as_any()
221+
.downcast_ref::<HashJoinExec>()
222+
.expect("The type of the plan should not be changed");
223+
224+
assert_eq!(swapped_join.left().statistics().num_rows, Some(10));
225+
assert_eq!(swapped_join.right().statistics().num_rows, Some(100000));
226+
}
227+
228+
#[tokio::test]
229+
async fn test_join_no_swap() {
230+
let (big, small) = create_big_and_small();
231+
232+
let join = HashJoinExec::try_new(
233+
Arc::clone(&small),
234+
Arc::clone(&big),
235+
vec![(
236+
Column::new_with_schema("small_col", &small.schema()).unwrap(),
237+
Column::new_with_schema("big_col", &big.schema()).unwrap(),
238+
)],
239+
&JoinType::Left,
240+
PartitionMode::CollectLeft,
241+
)
242+
.unwrap();
243+
244+
let optimized_join = HashBuildProbeOrder::new()
245+
.optimize(Arc::new(join), &ExecutionConfig::new())
246+
.unwrap();
247+
248+
let swapped_join = optimized_join
249+
.as_any()
250+
.downcast_ref::<HashJoinExec>()
251+
.expect("The type of the plan should not be changed");
252+
253+
assert_eq!(swapped_join.left().statistics().num_rows, Some(10));
254+
assert_eq!(swapped_join.right().statistics().num_rows, Some(100000));
255+
}
256+
161257
#[tokio::test]
162258
async fn test_swap_reverting_projection() {
163259
let left_schema = Schema::new(vec![
@@ -187,7 +283,7 @@ mod tests {
187283
fn assert_col_expr(expr: &Arc<dyn PhysicalExpr>, name: &str, index: usize) {
188284
let col = expr
189285
.as_any()
190-
.downcast_ref::<expressions::Column>()
286+
.downcast_ref::<Column>()
191287
.expect("Projection items should be Column expression");
192288
assert_eq!(col.name(), name);
193289
assert_eq!(col.index(), index);

datafusion/src/test/exec.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,82 @@ impl ExecutionPlan for ErrorExec {
393393
Statistics::default()
394394
}
395395
}
396+
397+
/// A mock execution plan that simply returns the provided statistics
398+
#[derive(Debug, Clone)]
399+
pub struct StatisticsExec {
400+
stats: Statistics,
401+
schema: Arc<Schema>,
402+
}
403+
impl StatisticsExec {
404+
pub fn new(stats: Statistics, schema: Schema) -> Self {
405+
assert!(
406+
stats
407+
.column_statistics
408+
.as_ref()
409+
.map(|cols| cols.len() == schema.fields().len())
410+
.unwrap_or(true),
411+
"if defined, the column statistics vector length should be the number of fields"
412+
);
413+
Self {
414+
stats,
415+
schema: Arc::new(schema),
416+
}
417+
}
418+
}
419+
#[async_trait]
420+
impl ExecutionPlan for StatisticsExec {
421+
fn as_any(&self) -> &dyn Any {
422+
self
423+
}
424+
425+
fn schema(&self) -> SchemaRef {
426+
Arc::clone(&self.schema)
427+
}
428+
429+
fn output_partitioning(&self) -> Partitioning {
430+
Partitioning::UnknownPartitioning(2)
431+
}
432+
433+
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
434+
vec![]
435+
}
436+
437+
fn with_new_children(
438+
&self,
439+
children: Vec<Arc<dyn ExecutionPlan>>,
440+
) -> Result<Arc<dyn ExecutionPlan>> {
441+
if children.is_empty() {
442+
Ok(Arc::new(self.clone()))
443+
} else {
444+
Err(DataFusionError::Internal(
445+
"Children cannot be replaced in CustomExecutionPlan".to_owned(),
446+
))
447+
}
448+
}
449+
450+
async fn execute(&self, _partition: usize) -> Result<SendableRecordBatchStream> {
451+
unimplemented!("This plan only serves for testing statistics")
452+
}
453+
454+
fn statistics(&self) -> Statistics {
455+
self.stats.clone()
456+
}
457+
458+
fn fmt_as(
459+
&self,
460+
t: DisplayFormatType,
461+
f: &mut std::fmt::Formatter,
462+
) -> std::fmt::Result {
463+
match t {
464+
DisplayFormatType::Default => {
465+
write!(
466+
f,
467+
"StatisticsExec: col_count={}, row_count={:?}",
468+
self.schema.fields().len(),
469+
self.stats.num_rows,
470+
)
471+
}
472+
}
473+
}
474+
}

0 commit comments

Comments
 (0)