@@ -583,11 +583,11 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
583583///
584584/// # Returns
585585///
586- /// If the function can safely infer all outer-referenced columns, returns a
587- /// `Some(HashSet<Column>)` containing these columns. Otherwise, returns `None`.
588- fn outer_columns ( expr : & Expr ) -> Option < HashSet < Column > > {
586+ /// returns a `HashSet<Column>` containing all outer-referenced columns.
587+ fn outer_columns ( expr : & Expr ) -> HashSet < Column > {
589588 let mut columns = HashSet :: new ( ) ;
590- outer_columns_helper ( expr, & mut columns) . then_some ( columns)
589+ outer_columns_helper ( expr, & mut columns) ;
590+ columns
591591}
592592
593593/// A recursive subroutine that accumulates outer-referenced columns by the
@@ -598,87 +598,104 @@ fn outer_columns(expr: &Expr) -> Option<HashSet<Column>> {
598598/// * `expr` - The expression to analyze for outer-referenced columns.
599599/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
600600/// columns are collected.
601- ///
602- /// Returns `true` if it can safely collect all outer-referenced columns.
603- /// Otherwise, returns `false`.
604- fn outer_columns_helper ( expr : & Expr , columns : & mut HashSet < Column > ) -> bool {
601+ fn outer_columns_helper ( expr : & Expr , columns : & mut HashSet < Column > ) {
605602 match expr {
606603 Expr :: OuterReferenceColumn ( _, col) => {
607604 columns. insert ( col. clone ( ) ) ;
608- true
609605 }
610606 Expr :: BinaryExpr ( binary_expr) => {
611- outer_columns_helper ( & binary_expr. left , columns)
612- && outer_columns_helper ( & binary_expr. right , columns)
607+ outer_columns_helper ( & binary_expr. left , columns) ;
608+ outer_columns_helper ( & binary_expr. right , columns) ;
613609 }
614610 Expr :: ScalarSubquery ( subquery) => {
615611 let exprs = subquery. outer_ref_columns . iter ( ) ;
616- outer_columns_helper_multi ( exprs, columns)
612+ outer_columns_helper_multi ( exprs, columns) ;
617613 }
618614 Expr :: Exists ( exists) => {
619615 let exprs = exists. subquery . outer_ref_columns . iter ( ) ;
620- outer_columns_helper_multi ( exprs, columns)
616+ outer_columns_helper_multi ( exprs, columns) ;
621617 }
622618 Expr :: Alias ( alias) => outer_columns_helper ( & alias. expr , columns) ,
623619 Expr :: InSubquery ( insubquery) => {
624620 let exprs = insubquery. subquery . outer_ref_columns . iter ( ) ;
625- outer_columns_helper_multi ( exprs, columns)
621+ outer_columns_helper_multi ( exprs, columns) ;
626622 }
627- Expr :: IsNotNull ( expr) | Expr :: IsNull ( expr) => outer_columns_helper ( expr, columns) ,
628623 Expr :: Cast ( cast) => outer_columns_helper ( & cast. expr , columns) ,
629624 Expr :: Sort ( sort) => outer_columns_helper ( & sort. expr , columns) ,
630625 Expr :: AggregateFunction ( aggregate_fn) => {
631- outer_columns_helper_multi ( aggregate_fn. args . iter ( ) , columns)
632- && aggregate_fn
633- . order_by
634- . as_ref ( )
635- . map_or ( true , |obs| outer_columns_helper_multi ( obs. iter ( ) , columns) )
636- && aggregate_fn
637- . filter
638- . as_ref ( )
639- . map_or ( true , |filter| outer_columns_helper ( filter, columns) )
626+ outer_columns_helper_multi ( aggregate_fn. args . iter ( ) , columns) ;
627+ if let Some ( filter) = aggregate_fn. filter . as_ref ( ) {
628+ outer_columns_helper ( filter, columns) ;
629+ }
630+ if let Some ( obs) = aggregate_fn. order_by . as_ref ( ) {
631+ outer_columns_helper_multi ( obs. iter ( ) , columns) ;
632+ }
640633 }
641634 Expr :: WindowFunction ( window_fn) => {
642- outer_columns_helper_multi ( window_fn. args . iter ( ) , columns)
643- && outer_columns_helper_multi ( window_fn. order_by . iter ( ) , columns)
644- && outer_columns_helper_multi ( window_fn. partition_by . iter ( ) , columns)
635+ outer_columns_helper_multi ( window_fn. args . iter ( ) , columns) ;
636+ outer_columns_helper_multi ( window_fn. order_by . iter ( ) , columns) ;
637+ outer_columns_helper_multi ( window_fn. partition_by . iter ( ) , columns) ;
645638 }
646639 Expr :: GroupingSet ( groupingset) => match groupingset {
647- GroupingSet :: GroupingSets ( multi_exprs) => multi_exprs
648- . iter ( )
649- . all ( |e| outer_columns_helper_multi ( e. iter ( ) , columns) ) ,
640+ GroupingSet :: GroupingSets ( multi_exprs) => {
641+ multi_exprs
642+ . iter ( )
643+ . for_each ( |e| outer_columns_helper_multi ( e. iter ( ) , columns) ) ;
644+ }
650645 GroupingSet :: Cube ( exprs) | GroupingSet :: Rollup ( exprs) => {
651- outer_columns_helper_multi ( exprs. iter ( ) , columns)
646+ outer_columns_helper_multi ( exprs. iter ( ) , columns) ;
652647 }
653648 } ,
654649 Expr :: ScalarFunction ( scalar_fn) => {
655- outer_columns_helper_multi ( scalar_fn. args . iter ( ) , columns)
650+ outer_columns_helper_multi ( scalar_fn. args . iter ( ) , columns) ;
656651 }
657652 Expr :: Like ( like) => {
658- outer_columns_helper ( & like. expr , columns)
659- && outer_columns_helper ( & like. pattern , columns)
653+ outer_columns_helper ( & like. expr , columns) ;
654+ outer_columns_helper ( & like. pattern , columns) ;
660655 }
661656 Expr :: InList ( in_list) => {
662- outer_columns_helper ( & in_list. expr , columns)
663- && outer_columns_helper_multi ( in_list. list . iter ( ) , columns)
657+ outer_columns_helper ( & in_list. expr , columns) ;
658+ outer_columns_helper_multi ( in_list. list . iter ( ) , columns) ;
664659 }
665660 Expr :: Case ( case) => {
666661 let when_then_exprs = case
667662 . when_then_expr
668663 . iter ( )
669664 . flat_map ( |( first, second) | [ first. as_ref ( ) , second. as_ref ( ) ] ) ;
670- outer_columns_helper_multi ( when_then_exprs, columns)
671- && case
672- . expr
673- . as_ref ( )
674- . map_or ( true , |expr| outer_columns_helper ( expr, columns) )
675- && case
676- . else_expr
677- . as_ref ( )
678- . map_or ( true , |expr| outer_columns_helper ( expr, columns) )
665+ outer_columns_helper_multi ( when_then_exprs, columns) ;
666+ if let Some ( expr) = case. expr . as_ref ( ) {
667+ outer_columns_helper ( expr, columns) ;
668+ }
669+ if let Some ( expr) = case. else_expr . as_ref ( ) {
670+ outer_columns_helper ( expr, columns) ;
671+ }
672+ }
673+ Expr :: SimilarTo ( similar_to) => {
674+ outer_columns_helper ( & similar_to. expr , columns) ;
675+ outer_columns_helper ( & similar_to. pattern , columns) ;
676+ }
677+ Expr :: TryCast ( try_cast) => outer_columns_helper ( & try_cast. expr , columns) ,
678+ Expr :: GetIndexedField ( index) => outer_columns_helper ( & index. expr , columns) ,
679+ Expr :: Between ( between) => {
680+ outer_columns_helper ( & between. expr , columns) ;
681+ outer_columns_helper ( & between. low , columns) ;
682+ outer_columns_helper ( & between. high , columns) ;
679683 }
680- Expr :: Column ( _) | Expr :: Literal ( _) | Expr :: Wildcard { .. } => true ,
681- _ => false ,
684+ Expr :: Not ( expr)
685+ | Expr :: IsNotFalse ( expr)
686+ | Expr :: IsFalse ( expr)
687+ | Expr :: IsTrue ( expr)
688+ | Expr :: IsNotTrue ( expr)
689+ | Expr :: IsUnknown ( expr)
690+ | Expr :: IsNotUnknown ( expr)
691+ | Expr :: IsNotNull ( expr)
692+ | Expr :: IsNull ( expr)
693+ | Expr :: Negative ( expr) => outer_columns_helper ( expr, columns) ,
694+ Expr :: Column ( _)
695+ | Expr :: Literal ( _)
696+ | Expr :: Wildcard { .. }
697+ | Expr :: ScalarVariable { .. }
698+ | Expr :: Placeholder ( _) => ( ) ,
682699 }
683700}
684701
@@ -690,14 +707,11 @@ fn outer_columns_helper(expr: &Expr, columns: &mut HashSet<Column>) -> bool {
690707/// * `exprs` - The expressions to analyze for outer-referenced columns.
691708/// * `columns` - A mutable reference to a `HashSet<Column>` where detected
692709/// columns are collected.
693- ///
694- /// Returns `true` if it can safely collect all outer-referenced columns.
695- /// Otherwise, returns `false`.
696710fn outer_columns_helper_multi < ' a > (
697- mut exprs : impl Iterator < Item = & ' a Expr > ,
711+ exprs : impl Iterator < Item = & ' a Expr > ,
698712 columns : & mut HashSet < Column > ,
699- ) -> bool {
700- exprs. all ( |e| outer_columns_helper ( e, columns) )
713+ ) {
714+ exprs. for_each ( |e| outer_columns_helper ( e, columns) ) ;
701715}
702716
703717/// Generates the required expressions (columns) that reside at `indices` of
@@ -766,13 +780,7 @@ fn indices_referred_by_expr(
766780) -> Result < Vec < usize > > {
767781 let mut cols = expr. to_columns ( ) ?;
768782 // Get outer-referenced columns:
769- if let Some ( outer_cols) = outer_columns ( expr) {
770- cols. extend ( outer_cols) ;
771- } else {
772- // Expression is not known to contain outer columns or not. Hence, do
773- // not assume anything and require all the schema indices at the input:
774- return Ok ( ( 0 ..input_schema. fields ( ) . len ( ) ) . collect ( ) ) ;
775- }
783+ cols. extend ( outer_columns ( expr) ) ;
776784 Ok ( cols
777785 . iter ( )
778786 . flat_map ( |col| input_schema. index_of_column ( col) )
@@ -978,8 +986,8 @@ mod tests {
978986 use arrow:: datatypes:: { DataType , Field , Schema } ;
979987 use datafusion_common:: { Result , TableReference } ;
980988 use datafusion_expr:: {
981- binary_expr, col, count, lit, logical_plan:: builder:: LogicalPlanBuilder ,
982- table_scan, Expr , LogicalPlan , Operator ,
989+ binary_expr, col, count, lit, logical_plan:: builder:: LogicalPlanBuilder , not ,
990+ table_scan, try_cast , Expr , Like , LogicalPlan , Operator ,
983991 } ;
984992
985993 fn assert_optimized_plan_equal ( plan : & LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -1060,4 +1068,187 @@ mod tests {
10601068 \n TableScan: ?table? projection=[]";
10611069 assert_optimized_plan_equal ( & plan, expected)
10621070 }
1071+
// Projecting a single field of struct column `s` should keep only `s` in the
// scan and drop the unrelated column `a`.
#[test]
fn test_struct_field_push_down() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new_struct(
            "s",
            vec![
                Field::new("x", DataType::Int64, false),
                Field::new("y", DataType::Int64, false),
            ],
            false,
        ),
    ]));

    let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("s").field("x")])?
        .build()?;
    let expected = "Projection: (?table?.s)[x]\
    \n  TableScan: ?table? projection=[s]";
    assert_optimized_plan_equal(&plan, expected)
}
1094+
// Arithmetic negation of `a` should leave only `a` in the scan projection.
#[test]
fn test_neg_push_down() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![-col("a")])?
        .build()?;

    let expected = "Projection: (- test.a)\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1106+
// `a IS NULL` should leave only `a` in the scan projection.
#[test]
fn test_is_null() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_null()])?
        .build()?;

    let expected = "Projection: test.a IS NULL\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1118+
// `a IS NOT NULL` should leave only `a` in the scan projection.
#[test]
fn test_is_not_null() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_not_null()])?
        .build()?;

    let expected = "Projection: test.a IS NOT NULL\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1130+
// `a IS TRUE` should leave only `a` in the scan projection.
#[test]
fn test_is_true() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_true()])?
        .build()?;

    let expected = "Projection: test.a IS TRUE\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1142+
// `a IS NOT TRUE` should leave only `a` in the scan projection.
#[test]
fn test_is_not_true() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_not_true()])?
        .build()?;

    let expected = "Projection: test.a IS NOT TRUE\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1154+
// `a IS FALSE` should leave only `a` in the scan projection.
#[test]
fn test_is_false() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_false()])?
        .build()?;

    let expected = "Projection: test.a IS FALSE\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1166+
// `a IS NOT FALSE` should leave only `a` in the scan projection.
#[test]
fn test_is_not_false() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_not_false()])?
        .build()?;

    let expected = "Projection: test.a IS NOT FALSE\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1178+
// `a IS UNKNOWN` should leave only `a` in the scan projection.
#[test]
fn test_is_unknown() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_unknown()])?
        .build()?;

    let expected = "Projection: test.a IS UNKNOWN\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1190+
// `a IS NOT UNKNOWN` should leave only `a` in the scan projection.
#[test]
fn test_is_not_unknown() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").is_not_unknown()])?
        .build()?;

    let expected = "Projection: test.a IS NOT UNKNOWN\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1202+
// `NOT a` should leave only `a` in the scan projection.
#[test]
fn test_not() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![not(col("a"))])?
        .build()?;

    let expected = "Projection: NOT test.a\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1214+
// `TRY_CAST(a AS Float64)` should leave only `a` in the scan projection.
#[test]
fn test_try_cast() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![try_cast(col("a"), DataType::Float64)])?
        .build()?;

    let expected = "Projection: TRY_CAST(test.a AS Float64)\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1226+
// `a SIMILAR TO '[0-9]'` should leave only `a` in the scan projection. The
// expression is assembled by hand from a `Like` value wrapped in
// `Expr::SimilarTo`.
#[test]
fn test_similar_to() -> Result<()> {
    let table_scan = test_table_scan()?;
    let expr = Box::new(col("a"));
    let pattern = Box::new(lit("[0-9]"));
    let similar_to_expr =
        Expr::SimilarTo(Like::new(false, expr, pattern, None, false));
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![similar_to_expr])?
        .build()?;

    let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
1242+
// `a BETWEEN 1 AND 3` should leave only `a` in the scan projection.
#[test]
fn test_between() -> Result<()> {
    let table_scan = test_table_scan()?;
    let plan = LogicalPlanBuilder::from(table_scan)
        .project(vec![col("a").between(lit(1), lit(3))])?
        .build()?;

    let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\
    \n  TableScan: test projection=[a]";
    assert_optimized_plan_equal(&plan, expected)
}
10631254}
0 commit comments