@@ -23,9 +23,10 @@ use arrow::datatypes::Schema;
2323use crate :: execution:: context:: ExecutionConfig ;
2424use crate :: logical_plan:: JoinType ;
2525use crate :: physical_plan:: cross_join:: CrossJoinExec ;
26+ use crate :: physical_plan:: expressions:: Column ;
2627use crate :: physical_plan:: hash_join:: HashJoinExec ;
2728use crate :: physical_plan:: projection:: ProjectionExec ;
28- use crate :: physical_plan:: { expressions , ExecutionPlan , PhysicalExpr } ;
29+ use crate :: physical_plan:: { ExecutionPlan , PhysicalExpr } ;
2930
3031use super :: optimizer:: PhysicalOptimizerRule ;
3132use super :: utils:: optimize_children;
@@ -84,15 +85,14 @@ fn swap_reverting_projection(
8485) -> Vec < ( Arc < dyn PhysicalExpr > , String ) > {
8586 let right_cols = right_schema. fields ( ) . iter ( ) . enumerate ( ) . map ( |( i, f) | {
8687 (
87- Arc :: new ( expressions :: Column :: new ( f. name ( ) , i) ) as Arc < dyn PhysicalExpr > ,
88+ Arc :: new ( Column :: new ( f. name ( ) , i) ) as Arc < dyn PhysicalExpr > ,
8889 f. name ( ) . to_owned ( ) ,
8990 )
9091 } ) ;
9192 let right_len = right_cols. len ( ) ;
9293 let left_cols = left_schema. fields ( ) . iter ( ) . enumerate ( ) . map ( |( i, f) | {
9394 (
94- Arc :: new ( expressions:: Column :: new ( f. name ( ) , right_len + i) )
95- as Arc < dyn PhysicalExpr > ,
95+ Arc :: new ( Column :: new ( f. name ( ) , right_len + i) ) as Arc < dyn PhysicalExpr > ,
9696 f. name ( ) . to_owned ( ) ,
9797 )
9898 } ) ;
@@ -153,11 +153,107 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
153153
154154#[ cfg( test) ]
155155mod tests {
156+ use crate :: {
157+ physical_plan:: { hash_join:: PartitionMode , Statistics } ,
158+ test:: exec:: StatisticsExec ,
159+ } ;
160+
156161 use super :: * ;
157162 use std:: sync:: Arc ;
158163
159164 use arrow:: datatypes:: { DataType , Field , Schema } ;
160165
166+ fn create_big_and_small ( ) -> ( Arc < dyn ExecutionPlan > , Arc < dyn ExecutionPlan > ) {
167+ let big = Arc :: new ( StatisticsExec :: new (
168+ Statistics {
169+ num_rows : Some ( 100000 ) ,
170+ ..Default :: default ( )
171+ } ,
172+ Schema :: new ( vec ! [ Field :: new( "big_col" , DataType :: Int32 , false ) ] ) ,
173+ ) ) ;
174+
175+ let small = Arc :: new ( StatisticsExec :: new (
176+ Statistics {
177+ num_rows : Some ( 10 ) ,
178+ ..Default :: default ( )
179+ } ,
180+ Schema :: new ( vec ! [ Field :: new( "small_col" , DataType :: Int32 , false ) ] ) ,
181+ ) ) ;
182+ ( big, small)
183+ }
184+
185+ #[ tokio:: test]
186+ async fn test_join_with_swap ( ) {
187+ let ( big, small) = create_big_and_small ( ) ;
188+
189+ let join = HashJoinExec :: try_new (
190+ Arc :: clone ( & big) ,
191+ Arc :: clone ( & small) ,
192+ vec ! [ (
193+ Column :: new_with_schema( "big_col" , & big. schema( ) ) . unwrap( ) ,
194+ Column :: new_with_schema( "small_col" , & small. schema( ) ) . unwrap( ) ,
195+ ) ] ,
196+ & JoinType :: Left ,
197+ PartitionMode :: CollectLeft ,
198+ )
199+ . unwrap ( ) ;
200+
201+ let optimized_join = HashBuildProbeOrder :: new ( )
202+ . optimize ( Arc :: new ( join) , & ExecutionConfig :: new ( ) )
203+ . unwrap ( ) ;
204+
205+ let swapping_projection = optimized_join
206+ . as_any ( )
207+ . downcast_ref :: < ProjectionExec > ( )
208+ . expect ( "A proj is required to swap columns back to their original order" ) ;
209+
210+ assert_eq ! ( swapping_projection. expr( ) . len( ) , 2 ) ;
211+ let ( col, name) = & swapping_projection. expr ( ) [ 0 ] ;
212+ assert_eq ! ( name, "big_col" ) ;
213+ assert_col_expr ( col, "big_col" , 1 ) ;
214+ let ( col, name) = & swapping_projection. expr ( ) [ 1 ] ;
215+ assert_eq ! ( name, "small_col" ) ;
216+ assert_col_expr ( col, "small_col" , 0 ) ;
217+
218+ let swapped_join = swapping_projection
219+ . input ( )
220+ . as_any ( )
221+ . downcast_ref :: < HashJoinExec > ( )
222+ . expect ( "The type of the plan should not be changed" ) ;
223+
224+ assert_eq ! ( swapped_join. left( ) . statistics( ) . num_rows, Some ( 10 ) ) ;
225+ assert_eq ! ( swapped_join. right( ) . statistics( ) . num_rows, Some ( 100000 ) ) ;
226+ }
227+
228+ #[ tokio:: test]
229+ async fn test_join_no_swap ( ) {
230+ let ( big, small) = create_big_and_small ( ) ;
231+
232+ let join = HashJoinExec :: try_new (
233+ Arc :: clone ( & small) ,
234+ Arc :: clone ( & big) ,
235+ vec ! [ (
236+ Column :: new_with_schema( "small_col" , & small. schema( ) ) . unwrap( ) ,
237+ Column :: new_with_schema( "big_col" , & big. schema( ) ) . unwrap( ) ,
238+ ) ] ,
239+ & JoinType :: Left ,
240+ PartitionMode :: CollectLeft ,
241+ )
242+ . unwrap ( ) ;
243+
244+ let optimized_join = HashBuildProbeOrder :: new ( )
245+ . optimize ( Arc :: new ( join) , & ExecutionConfig :: new ( ) )
246+ . unwrap ( ) ;
247+
248+ let swapped_join = optimized_join
249+ . as_any ( )
250+ . downcast_ref :: < HashJoinExec > ( )
251+ . expect ( "The type of the plan should not be changed" ) ;
252+
253+ assert_eq ! ( swapped_join. left( ) . statistics( ) . num_rows, Some ( 10 ) ) ;
254+ assert_eq ! ( swapped_join. right( ) . statistics( ) . num_rows, Some ( 100000 ) ) ;
255+ }
256+
161257 #[ tokio:: test]
162258 async fn test_swap_reverting_projection ( ) {
163259 let left_schema = Schema :: new ( vec ! [
@@ -187,7 +283,7 @@ mod tests {
187283 fn assert_col_expr ( expr : & Arc < dyn PhysicalExpr > , name : & str , index : usize ) {
188284 let col = expr
189285 . as_any ( )
190- . downcast_ref :: < expressions :: Column > ( )
286+ . downcast_ref :: < Column > ( )
191287 . expect ( "Projection items should be Column expression" ) ;
192288 assert_eq ! ( col. name( ) , name) ;
193289 assert_eq ! ( col. index( ) , index) ;
0 commit comments