1818//! Optimizer rule to replace nested unions to single union.
1919use crate :: { OptimizerConfig , OptimizerRule } ;
2020use datafusion_common:: Result ;
21- use datafusion_expr:: {
22- builder:: project_with_column_index,
23- expr_rewriter:: coerce_plan_expr_for_schema,
24- logical_plan:: { LogicalPlan , Projection , Union } ,
25- } ;
21+ use datafusion_expr:: logical_plan:: { LogicalPlan , Union } ;
2622
2723use crate :: optimizer:: ApplyOrder ;
24+ use datafusion_expr:: expr_rewriter:: coerce_plan_expr_for_schema;
2825use std:: sync:: Arc ;
2926
3027#[ derive( Default ) ]
@@ -38,6 +35,8 @@ impl EliminateNestedUnion {
3835 }
3936}
4037
38+ pub fn get_union_schema ( ) { }
39+
4140impl OptimizerRule for EliminateNestedUnion {
4241 fn try_optimize (
4342 & self ,
@@ -48,32 +47,24 @@ impl OptimizerRule for EliminateNestedUnion {
4847 LogicalPlan :: Union ( union) => {
4948 let Union { inputs, schema } = union;
5049
51- let union_schema = schema. clone ( ) ;
52-
5350 let inputs = inputs
5451 . into_iter ( )
55- . flat_map ( |plan| match Arc :: as_ref ( plan) {
56- LogicalPlan :: Union ( Union { inputs, .. } ) => inputs. clone ( ) ,
57- _ => vec ! [ Arc :: clone( plan) ] ,
58- } )
59- . map ( |plan| {
60- let plan = coerce_plan_expr_for_schema ( & plan, & union_schema) ?;
61- match plan {
62- LogicalPlan :: Projection ( Projection {
63- expr, input, ..
64- } ) => Ok ( Arc :: new ( project_with_column_index (
65- expr,
66- input,
67- union_schema. clone ( ) ,
68- ) ?) ) ,
69- _ => Ok ( Arc :: new ( plan) ) ,
70- }
52+ . flat_map ( |plan| match plan. as_ref ( ) {
53+ LogicalPlan :: Union ( Union { inputs, schema } ) => inputs
54+ . into_iter ( )
55+ . map ( |plan| {
56+ Arc :: new (
57+ coerce_plan_expr_for_schema ( plan, schema) . unwrap ( ) ,
58+ )
59+ } )
60+ . collect :: < Vec < _ > > ( ) ,
61+ _ => vec ! [ plan. clone( ) ] ,
7162 } )
72- . collect :: < Result < Vec < _ > > > ( ) ? ;
63+ . collect :: < Vec < _ > > ( ) ;
7364
7465 Ok ( Some ( LogicalPlan :: Union ( Union {
7566 inputs,
76- schema : union_schema ,
67+ schema : schema . clone ( ) ,
7768 } ) ) )
7869 }
7970 _ => Ok ( None ) ,
@@ -94,13 +85,13 @@ mod tests {
9485 use super :: * ;
9586 use crate :: test:: * ;
9687 use arrow:: datatypes:: { DataType , Field , Schema } ;
97- use datafusion_expr:: logical_plan:: table_scan;
88+ use datafusion_expr:: { col , logical_plan:: table_scan} ;
9889
9990 fn schema ( ) -> Schema {
10091 Schema :: new ( vec ! [
10192 Field :: new( "id" , DataType :: Int32 , false ) ,
10293 Field :: new( "key" , DataType :: Utf8 , false ) ,
103- Field :: new( "value" , DataType :: Int32 , false ) ,
94+ Field :: new( "value" , DataType :: Float64 , false ) ,
10495 ] )
10596 }
10697
@@ -143,4 +134,81 @@ mod tests {
143134 \n TableScan: table";
144135 assert_optimized_plan_equal ( & plan, expected)
145136 }
137+
138+ // We don't need to use project_with_column_index in logical optimizer,
139+ // after LogicalPlanBuilder::union, we already have all equal expression aliases
140+ #[ test]
141+ fn eliminate_nested_union_with_projection ( ) -> Result < ( ) > {
142+ let plan_builder = table_scan ( Some ( "table" ) , & schema ( ) , None ) ?;
143+
144+ let plan = plan_builder
145+ . clone ( )
146+ . union (
147+ plan_builder
148+ . clone ( )
149+ . project ( vec ! [ col( "id" ) . alias( "table_id" ) , col( "key" ) , col( "value" ) ] ) ?
150+ . build ( ) ?,
151+ ) ?
152+ . union (
153+ plan_builder
154+ . clone ( )
155+ . project ( vec ! [ col( "id" ) . alias( "_id" ) , col( "key" ) , col( "value" ) ] ) ?
156+ . build ( ) ?,
157+ ) ?
158+ . build ( ) ?;
159+
160+ let expected = "Union\
161+ \n TableScan: table\
162+ \n Projection: table.id AS id, table.key, table.value\
163+ \n TableScan: table\
164+ \n Projection: table.id AS id, table.key, table.value\
165+ \n TableScan: table";
166+ assert_optimized_plan_equal ( & plan, expected)
167+ }
168+
169+ #[ test]
170+ fn eliminate_nested_union_with_type_cast_projection ( ) -> Result < ( ) > {
171+ let table_1 = table_scan (
172+ Some ( "table_1" ) ,
173+ & Schema :: new ( vec ! [
174+ Field :: new( "id" , DataType :: Int64 , false ) ,
175+ Field :: new( "key" , DataType :: Utf8 , false ) ,
176+ Field :: new( "value" , DataType :: Float64 , false ) ,
177+ ] ) ,
178+ None ,
179+ ) ?;
180+
181+ let table_2 = table_scan (
182+ Some ( "table_1" ) ,
183+ & Schema :: new ( vec ! [
184+ Field :: new( "id" , DataType :: Int32 , false ) ,
185+ Field :: new( "key" , DataType :: Utf8 , false ) ,
186+ Field :: new( "value" , DataType :: Float32 , false ) ,
187+ ] ) ,
188+ None ,
189+ ) ?;
190+
191+ let table_3 = table_scan (
192+ Some ( "table_1" ) ,
193+ & Schema :: new ( vec ! [
194+ Field :: new( "id" , DataType :: Int16 , false ) ,
195+ Field :: new( "key" , DataType :: Utf8 , false ) ,
196+ Field :: new( "value" , DataType :: Float32 , false ) ,
197+ ] ) ,
198+ None ,
199+ ) ?;
200+
201+ let plan = table_1
202+ . union ( table_2. build ( ) ?) ?
203+ . union ( table_3. build ( ) ?) ?
204+ . build ( ) ?;
205+
206+ let expected = "Union\
207+ \n TableScan: table_1\
208+ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
209+ \n TableScan: table_1\
210+ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\
211+ \n TableScan: table_1";
212+ assert_optimized_plan_equal ( & plan, expected)
213+ }
146214}
0 commit comments