Skip to content

Commit cc42894

Browse files
haohuaijin and alamb authored
fix: struct field don't push down to TableScan (#8774)
* fix: struct don't push down to TableScan
* add similar to test and apply comment
* remove catch all in outer_columns_helper
* minor
* fix clippy

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent ff27d90 commit cc42894

File tree

1 file changed: +253 additions, -62 deletions

datafusion/optimizer/src/optimize_projections.rs

Lines changed: 253 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,11 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
583583
///
584584
/// # Returns
585585
///
586-
/// If the function can safely infer all outer-referenced columns, returns a
587-
/// `Some(HashSet<Column>)` containing these columns. Otherwise, returns `None`.
588-
fn outer_columns(expr: &Expr) -> Option<HashSet<Column>> {
586+
/// returns a `HashSet<Column>` containing all outer-referenced columns.
587+
fn outer_columns(expr: &Expr) -> HashSet<Column> {
589588
let mut columns = HashSet::new();
590-
outer_columns_helper(expr, &mut columns).then_some(columns)
589+
outer_columns_helper(expr, &mut columns);
590+
columns
591591
}
592592

593593
/// A recursive subroutine that accumulates outer-referenced columns by the
@@ -598,87 +598,104 @@ fn outer_columns(expr: &Expr) -> Option<HashSet<Column>> {
598598
/// * `expr` - The expression to analyze for outer-referenced columns.
599599
/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
600600
/// columns are collected.
601-
///
602-
/// Returns `true` if it can safely collect all outer-referenced columns.
603-
/// Otherwise, returns `false`.
604-
fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) -> bool {
601+
fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) {
605602
match expr {
606603
Expr::OuterReferenceColumn(_, col) => {
607604
columns.insert(col.clone());
608-
true
609605
}
610606
Expr::BinaryExpr(binary_expr) => {
611-
outer_columns_helper(&binary_expr.left, columns)
612-
&& outer_columns_helper(&binary_expr.right, columns)
607+
outer_columns_helper(&binary_expr.left, columns);
608+
outer_columns_helper(&binary_expr.right, columns);
613609
}
614610
Expr::ScalarSubquery(subquery) => {
615611
let exprs = subquery.outer_ref_columns.iter();
616-
outer_columns_helper_multi(exprs, columns)
612+
outer_columns_helper_multi(exprs, columns);
617613
}
618614
Expr::Exists(exists) => {
619615
let exprs = exists.subquery.outer_ref_columns.iter();
620-
outer_columns_helper_multi(exprs, columns)
616+
outer_columns_helper_multi(exprs, columns);
621617
}
622618
Expr::Alias(alias) => outer_columns_helper(&alias.expr, columns),
623619
Expr::InSubquery(insubquery) => {
624620
let exprs = insubquery.subquery.outer_ref_columns.iter();
625-
outer_columns_helper_multi(exprs, columns)
621+
outer_columns_helper_multi(exprs, columns);
626622
}
627-
Expr::IsNotNull(expr) | Expr::IsNull(expr) => outer_columns_helper(expr, columns),
628623
Expr::Cast(cast) => outer_columns_helper(&cast.expr, columns),
629624
Expr::Sort(sort) => outer_columns_helper(&sort.expr, columns),
630625
Expr::AggregateFunction(aggregate_fn) => {
631-
outer_columns_helper_multi(aggregate_fn.args.iter(), columns)
632-
&& aggregate_fn
633-
.order_by
634-
.as_ref()
635-
.map_or(true, |obs| outer_columns_helper_multi(obs.iter(), columns))
636-
&& aggregate_fn
637-
.filter
638-
.as_ref()
639-
.map_or(true, |filter| outer_columns_helper(filter, columns))
626+
outer_columns_helper_multi(aggregate_fn.args.iter(), columns);
627+
if let Some(filter) = aggregate_fn.filter.as_ref() {
628+
outer_columns_helper(filter, columns);
629+
}
630+
if let Some(obs) = aggregate_fn.order_by.as_ref() {
631+
outer_columns_helper_multi(obs.iter(), columns);
632+
}
640633
}
641634
Expr::WindowFunction(window_fn) => {
642-
outer_columns_helper_multi(window_fn.args.iter(), columns)
643-
&& outer_columns_helper_multi(window_fn.order_by.iter(), columns)
644-
&& outer_columns_helper_multi(window_fn.partition_by.iter(), columns)
635+
outer_columns_helper_multi(window_fn.args.iter(), columns);
636+
outer_columns_helper_multi(window_fn.order_by.iter(), columns);
637+
outer_columns_helper_multi(window_fn.partition_by.iter(), columns);
645638
}
646639
Expr::GroupingSet(groupingset) => match groupingset {
647-
GroupingSet::GroupingSets(multi_exprs) => multi_exprs
648-
.iter()
649-
.all(|e| outer_columns_helper_multi(e.iter(), columns)),
640+
GroupingSet::GroupingSets(multi_exprs) => {
641+
multi_exprs
642+
.iter()
643+
.for_each(|e| outer_columns_helper_multi(e.iter(), columns));
644+
}
650645
GroupingSet::Cube(exprs) | GroupingSet::Rollup(exprs) => {
651-
outer_columns_helper_multi(exprs.iter(), columns)
646+
outer_columns_helper_multi(exprs.iter(), columns);
652647
}
653648
},
654649
Expr::ScalarFunction(scalar_fn) => {
655-
outer_columns_helper_multi(scalar_fn.args.iter(), columns)
650+
outer_columns_helper_multi(scalar_fn.args.iter(), columns);
656651
}
657652
Expr::Like(like) => {
658-
outer_columns_helper(&like.expr, columns)
659-
&& outer_columns_helper(&like.pattern, columns)
653+
outer_columns_helper(&like.expr, columns);
654+
outer_columns_helper(&like.pattern, columns);
660655
}
661656
Expr::InList(in_list) => {
662-
outer_columns_helper(&in_list.expr, columns)
663-
&& outer_columns_helper_multi(in_list.list.iter(), columns)
657+
outer_columns_helper(&in_list.expr, columns);
658+
outer_columns_helper_multi(in_list.list.iter(), columns);
664659
}
665660
Expr::Case(case) => {
666661
let when_then_exprs = case
667662
.when_then_expr
668663
.iter()
669664
.flat_map(|(first, second)| [first.as_ref(), second.as_ref()]);
670-
outer_columns_helper_multi(when_then_exprs, columns)
671-
&& case
672-
.expr
673-
.as_ref()
674-
.map_or(true, |expr| outer_columns_helper(expr, columns))
675-
&& case
676-
.else_expr
677-
.as_ref()
678-
.map_or(true, |expr| outer_columns_helper(expr, columns))
665+
outer_columns_helper_multi(when_then_exprs, columns);
666+
if let Some(expr) = case.expr.as_ref() {
667+
outer_columns_helper(expr, columns);
668+
}
669+
if let Some(expr) = case.else_expr.as_ref() {
670+
outer_columns_helper(expr, columns);
671+
}
672+
}
673+
Expr::SimilarTo(similar_to) => {
674+
outer_columns_helper(&similar_to.expr, columns);
675+
outer_columns_helper(&similar_to.pattern, columns);
676+
}
677+
Expr::TryCast(try_cast) => outer_columns_helper(&try_cast.expr, columns),
678+
Expr::GetIndexedField(index) => outer_columns_helper(&index.expr, columns),
679+
Expr::Between(between) => {
680+
outer_columns_helper(&between.expr, columns);
681+
outer_columns_helper(&between.low, columns);
682+
outer_columns_helper(&between.high, columns);
679683
}
680-
Expr::Column(_) | Expr::Literal(_) | Expr::Wildcard { .. } => true,
681-
_ => false,
684+
Expr::Not(expr)
685+
| Expr::IsNotFalse(expr)
686+
| Expr::IsFalse(expr)
687+
| Expr::IsTrue(expr)
688+
| Expr::IsNotTrue(expr)
689+
| Expr::IsUnknown(expr)
690+
| Expr::IsNotUnknown(expr)
691+
| Expr::IsNotNull(expr)
692+
| Expr::IsNull(expr)
693+
| Expr::Negative(expr) => outer_columns_helper(expr, columns),
694+
Expr::Column(_)
695+
| Expr::Literal(_)
696+
| Expr::Wildcard { .. }
697+
| Expr::ScalarVariable { .. }
698+
| Expr::Placeholder(_) => (),
682699
}
683700
}
684701

@@ -690,14 +707,11 @@ fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) -> bool {
690707
/// * `exprs` - The expressions to analyze for outer-referenced columns.
691708
/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
692709
/// columns are collected.
693-
///
694-
/// Returns `true` if it can safely collect all outer-referenced columns.
695-
/// Otherwise, returns `false`.
696710
fn outer_columns_helper_multi<'a>(
697-
mut exprs: impl Iterator<Item = &'a Expr>,
711+
exprs: impl Iterator<Item = &'a Expr>,
698712
columns: &mut HashSet<Column>,
699-
) -> bool {
700-
exprs.all(|e| outer_columns_helper(e, columns))
713+
) {
714+
exprs.for_each(|e| outer_columns_helper(e, columns));
701715
}
702716

703717
/// Generates the required expressions (columns) that reside at `indices` of
@@ -766,13 +780,7 @@ fn indices_referred_by_expr(
766780
) -> Result<Vec<usize>> {
767781
let mut cols = expr.to_columns()?;
768782
// Get outer-referenced columns:
769-
if let Some(outer_cols) = outer_columns(expr) {
770-
cols.extend(outer_cols);
771-
} else {
772-
// Expression is not known to contain outer columns or not. Hence, do
773-
// not assume anything and require all the schema indices at the input:
774-
return Ok((0..input_schema.fields().len()).collect());
775-
}
783+
cols.extend(outer_columns(expr));
776784
Ok(cols
777785
.iter()
778786
.flat_map(|col| input_schema.index_of_column(col))
@@ -978,8 +986,8 @@ mod tests {
978986
use arrow::datatypes::{DataType, Field, Schema};
979987
use datafusion_common::{Result, TableReference};
980988
use datafusion_expr::{
981-
binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder,
982-
table_scan, Expr, LogicalPlan, Operator,
989+
binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not,
990+
table_scan, try_cast, Expr, Like, LogicalPlan, Operator,
983991
};
984992

985993
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1060,4 +1068,187 @@ mod tests {
10601068
\n TableScan: ?table? projection=[]";
10611069
assert_optimized_plan_equal(&plan, expected)
10621070
}
1071+
1072+
#[test]
1073+
fn test_struct_field_push_down() -> Result<()> {
1074+
let schema = Arc::new(Schema::new(vec![
1075+
Field::new("a", DataType::Int64, false),
1076+
Field::new_struct(
1077+
"s",
1078+
vec![
1079+
Field::new("x", DataType::Int64, false),
1080+
Field::new("y", DataType::Int64, false),
1081+
],
1082+
false,
1083+
),
1084+
]));
1085+
1086+
let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?;
1087+
let plan = LogicalPlanBuilder::from(table_scan)
1088+
.project(vec![col("s").field("x")])?
1089+
.build()?;
1090+
let expected = "Projection: (?table?.s)[x]\
1091+
\n TableScan: ?table? projection=[s]";
1092+
assert_optimized_plan_equal(&plan, expected)
1093+
}
1094+
1095+
#[test]
1096+
fn test_neg_push_down() -> Result<()> {
1097+
let table_scan = test_table_scan()?;
1098+
let plan = LogicalPlanBuilder::from(table_scan)
1099+
.project(vec![-col("a")])?
1100+
.build()?;
1101+
1102+
let expected = "Projection: (- test.a)\
1103+
\n TableScan: test projection=[a]";
1104+
assert_optimized_plan_equal(&plan, expected)
1105+
}
1106+
1107+
#[test]
1108+
fn test_is_null() -> Result<()> {
1109+
let table_scan = test_table_scan()?;
1110+
let plan = LogicalPlanBuilder::from(table_scan)
1111+
.project(vec![col("a").is_null()])?
1112+
.build()?;
1113+
1114+
let expected = "Projection: test.a IS NULL\
1115+
\n TableScan: test projection=[a]";
1116+
assert_optimized_plan_equal(&plan, expected)
1117+
}
1118+
1119+
#[test]
1120+
fn test_is_not_null() -> Result<()> {
1121+
let table_scan = test_table_scan()?;
1122+
let plan = LogicalPlanBuilder::from(table_scan)
1123+
.project(vec![col("a").is_not_null()])?
1124+
.build()?;
1125+
1126+
let expected = "Projection: test.a IS NOT NULL\
1127+
\n TableScan: test projection=[a]";
1128+
assert_optimized_plan_equal(&plan, expected)
1129+
}
1130+
1131+
#[test]
1132+
fn test_is_true() -> Result<()> {
1133+
let table_scan = test_table_scan()?;
1134+
let plan = LogicalPlanBuilder::from(table_scan)
1135+
.project(vec![col("a").is_true()])?
1136+
.build()?;
1137+
1138+
let expected = "Projection: test.a IS TRUE\
1139+
\n TableScan: test projection=[a]";
1140+
assert_optimized_plan_equal(&plan, expected)
1141+
}
1142+
1143+
#[test]
1144+
fn test_is_not_true() -> Result<()> {
1145+
let table_scan = test_table_scan()?;
1146+
let plan = LogicalPlanBuilder::from(table_scan)
1147+
.project(vec![col("a").is_not_true()])?
1148+
.build()?;
1149+
1150+
let expected = "Projection: test.a IS NOT TRUE\
1151+
\n TableScan: test projection=[a]";
1152+
assert_optimized_plan_equal(&plan, expected)
1153+
}
1154+
1155+
#[test]
1156+
fn test_is_false() -> Result<()> {
1157+
let table_scan = test_table_scan()?;
1158+
let plan = LogicalPlanBuilder::from(table_scan)
1159+
.project(vec![col("a").is_false()])?
1160+
.build()?;
1161+
1162+
let expected = "Projection: test.a IS FALSE\
1163+
\n TableScan: test projection=[a]";
1164+
assert_optimized_plan_equal(&plan, expected)
1165+
}
1166+
1167+
#[test]
1168+
fn test_is_not_false() -> Result<()> {
1169+
let table_scan = test_table_scan()?;
1170+
let plan = LogicalPlanBuilder::from(table_scan)
1171+
.project(vec![col("a").is_not_false()])?
1172+
.build()?;
1173+
1174+
let expected = "Projection: test.a IS NOT FALSE\
1175+
\n TableScan: test projection=[a]";
1176+
assert_optimized_plan_equal(&plan, expected)
1177+
}
1178+
1179+
#[test]
1180+
fn test_is_unknown() -> Result<()> {
1181+
let table_scan = test_table_scan()?;
1182+
let plan = LogicalPlanBuilder::from(table_scan)
1183+
.project(vec![col("a").is_unknown()])?
1184+
.build()?;
1185+
1186+
let expected = "Projection: test.a IS UNKNOWN\
1187+
\n TableScan: test projection=[a]";
1188+
assert_optimized_plan_equal(&plan, expected)
1189+
}
1190+
1191+
#[test]
1192+
fn test_is_not_unknown() -> Result<()> {
1193+
let table_scan = test_table_scan()?;
1194+
let plan = LogicalPlanBuilder::from(table_scan)
1195+
.project(vec![col("a").is_not_unknown()])?
1196+
.build()?;
1197+
1198+
let expected = "Projection: test.a IS NOT UNKNOWN\
1199+
\n TableScan: test projection=[a]";
1200+
assert_optimized_plan_equal(&plan, expected)
1201+
}
1202+
1203+
#[test]
1204+
fn test_not() -> Result<()> {
1205+
let table_scan = test_table_scan()?;
1206+
let plan = LogicalPlanBuilder::from(table_scan)
1207+
.project(vec![not(col("a"))])?
1208+
.build()?;
1209+
1210+
let expected = "Projection: NOT test.a\
1211+
\n TableScan: test projection=[a]";
1212+
assert_optimized_plan_equal(&plan, expected)
1213+
}
1214+
1215+
#[test]
1216+
fn test_try_cast() -> Result<()> {
1217+
let table_scan = test_table_scan()?;
1218+
let plan = LogicalPlanBuilder::from(table_scan)
1219+
.project(vec![try_cast(col("a"), DataType::Float64)])?
1220+
.build()?;
1221+
1222+
let expected = "Projection: TRY_CAST(test.a AS Float64)\
1223+
\n TableScan: test projection=[a]";
1224+
assert_optimized_plan_equal(&plan, expected)
1225+
}
1226+
1227+
#[test]
1228+
fn test_similar_to() -> Result<()> {
1229+
let table_scan = test_table_scan()?;
1230+
let expr = Box::new(col("a"));
1231+
let pattern = Box::new(lit("[0-9]"));
1232+
let similar_to_expr =
1233+
Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
1234+
let plan = LogicalPlanBuilder::from(table_scan)
1235+
.project(vec![similar_to_expr])?
1236+
.build()?;
1237+
1238+
let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\
1239+
\n TableScan: test projection=[a]";
1240+
assert_optimized_plan_equal(&plan, expected)
1241+
}
1242+
1243+
#[test]
1244+
fn test_between() -> Result<()> {
1245+
let table_scan = test_table_scan()?;
1246+
let plan = LogicalPlanBuilder::from(table_scan)
1247+
.project(vec![col("a").between(lit(1), lit(3))])?
1248+
.build()?;
1249+
1250+
let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\
1251+
\n TableScan: test projection=[a]";
1252+
assert_optimized_plan_equal(&plan, expected)
1253+
}
10631254
}

0 commit comments

Comments
 (0)