1616// under the License.
1717
1818//! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
19- use std:: collections:: HashSet ;
2019use std:: sync:: Arc ;
2120
2221use crate :: { utils, OptimizerConfig , OptimizerRule } ;
2322
23+ use crate :: join_key_set:: JoinKeySet ;
2424use datafusion_common:: { plan_err, Result } ;
2525use datafusion_expr:: expr:: { BinaryExpr , Expr } ;
2626use datafusion_expr:: logical_plan:: {
@@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin {
5555 plan : & LogicalPlan ,
5656 config : & dyn OptimizerConfig ,
5757 ) -> Result < Option < LogicalPlan > > {
58- let mut possible_join_keys: Vec < ( Expr , Expr ) > = vec ! [ ] ;
58+ let mut possible_join_keys = JoinKeySet :: new ( ) ;
5959 let mut all_inputs: Vec < LogicalPlan > = vec ! [ ] ;
6060 let parent_predicate = match plan {
6161 LogicalPlan :: Filter ( filter) => {
@@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin {
7676 extract_possible_join_keys (
7777 & filter. predicate ,
7878 & mut possible_join_keys,
79- ) ? ;
79+ ) ;
8080 Some ( & filter. predicate )
8181 }
8282 _ => {
@@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin {
101101 } ;
102102
103103 // Join keys are handled locally:
104- let mut all_join_keys = HashSet :: < ( Expr , Expr ) > :: new ( ) ;
104+ let mut all_join_keys = JoinKeySet :: new ( ) ;
105105 let mut left = all_inputs. remove ( 0 ) ;
106106 while !all_inputs. is_empty ( ) {
107107 left = find_inner_join (
@@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin {
131131 . map ( |f| Some ( LogicalPlan :: Filter ( f) ) )
132132 } else {
133133 // Remove join expressions from filter:
134- match remove_join_expressions ( predicate, & all_join_keys) ? {
134+ match remove_join_expressions ( predicate. clone ( ) , & all_join_keys) {
135135 Some ( filter_expr) => Filter :: try_new ( filter_expr, Arc :: new ( left) )
136136 . map ( |f| Some ( LogicalPlan :: Filter ( f) ) ) ,
137137 _ => Ok ( Some ( left) ) ,
@@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin {
150150/// Returns a boolean indicating whether the flattening was successful.
151151fn try_flatten_join_inputs (
152152 plan : & LogicalPlan ,
153- possible_join_keys : & mut Vec < ( Expr , Expr ) > ,
153+ possible_join_keys : & mut JoinKeySet ,
154154 all_inputs : & mut Vec < LogicalPlan > ,
155155) -> Result < bool > {
156156 let children = match plan {
@@ -160,7 +160,7 @@ fn try_flatten_join_inputs(
160160 // issue: https://github.com/apache/datafusion/issues/4844
161161 return Ok ( false ) ;
162162 }
163- possible_join_keys. extend ( join. on . clone ( ) ) ;
163+ possible_join_keys. insert_all ( join. on . iter ( ) ) ;
164164 vec ! [ & join. left, & join. right]
165165 }
166166 LogicalPlan :: CrossJoin ( join) => {
@@ -204,8 +204,8 @@ fn try_flatten_join_inputs(
204204fn find_inner_join (
205205 left_input : & LogicalPlan ,
206206 rights : & mut Vec < LogicalPlan > ,
207- possible_join_keys : & [ ( Expr , Expr ) ] ,
208- all_join_keys : & mut HashSet < ( Expr , Expr ) > ,
207+ possible_join_keys : & JoinKeySet ,
208+ all_join_keys : & mut JoinKeySet ,
209209) -> Result < LogicalPlan > {
210210 for ( i, right_input) in rights. iter ( ) . enumerate ( ) {
211211 let mut join_keys = vec ! [ ] ;
@@ -228,7 +228,7 @@ fn find_inner_join(
228228
229229 // Found one or more matching join keys
230230 if !join_keys. is_empty ( ) {
231- all_join_keys. extend ( join_keys. clone ( ) ) ;
231+ all_join_keys. insert_all ( join_keys. iter ( ) ) ;
232232 let right_input = rights. remove ( i) ;
233233 let join_schema = Arc :: new ( build_join_schema (
234234 left_input. schema ( ) ,
@@ -265,90 +265,67 @@ fn find_inner_join(
265265 } ) )
266266}
267267
268- fn intersect (
269- accum : & mut Vec < ( Expr , Expr ) > ,
270- vec1 : & [ ( Expr , Expr ) ] ,
271- vec2 : & [ ( Expr , Expr ) ] ,
272- ) {
273- if !( vec1. is_empty ( ) || vec2. is_empty ( ) ) {
274- for x1 in vec1. iter ( ) {
275- for x2 in vec2. iter ( ) {
276- if x1. 0 == x2. 0 && x1. 1 == x2. 1 || x1. 1 == x2. 0 && x1. 0 == x2. 1 {
277- accum. push ( ( x1. 0 . clone ( ) , x1. 1 . clone ( ) ) ) ;
278- }
279- }
280- }
281- }
282- }
283-
284268/// Extract join keys from a WHERE clause
285- fn extract_possible_join_keys ( expr : & Expr , accum : & mut Vec < ( Expr , Expr ) > ) -> Result < ( ) > {
269+ fn extract_possible_join_keys ( expr : & Expr , join_keys : & mut JoinKeySet ) {
286270 if let Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) = expr {
287271 match op {
288272 Operator :: Eq => {
289- // Ensure that we don't add the same Join keys multiple times
290- if !( accum. contains ( & ( * left. clone ( ) , * right. clone ( ) ) )
291- || accum. contains ( & ( * right. clone ( ) , * left. clone ( ) ) ) )
292- {
293- accum. push ( ( * left. clone ( ) , * right. clone ( ) ) ) ;
294- }
273+ // insert handles ensuring we don't add the same Join keys multiple times
274+ join_keys. insert ( left, right) ;
295275 }
296276 Operator :: And => {
297- extract_possible_join_keys ( left, accum ) ? ;
298- extract_possible_join_keys ( right, accum ) ?
277+ extract_possible_join_keys ( left, join_keys ) ;
278+ extract_possible_join_keys ( right, join_keys )
299279 }
300280 // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
301281 Operator :: Or => {
302- let mut left_join_keys = vec ! [ ] ;
303- let mut right_join_keys = vec ! [ ] ;
282+ let mut left_join_keys = JoinKeySet :: new ( ) ;
283+ let mut right_join_keys = JoinKeySet :: new ( ) ;
304284
305- extract_possible_join_keys ( left, & mut left_join_keys) ? ;
306- extract_possible_join_keys ( right, & mut right_join_keys) ? ;
285+ extract_possible_join_keys ( left, & mut left_join_keys) ;
286+ extract_possible_join_keys ( right, & mut right_join_keys) ;
307287
308- intersect ( accum , & left_join_keys, & right_join_keys)
288+ join_keys . insert_intersection ( left_join_keys, right_join_keys)
309289 }
310290 _ => ( ) ,
311291 } ;
312292 }
313- Ok ( ( ) )
314293}
315294
316295/// Remove join expressions from a filter expression
317- /// Returns Some() when there are few remaining predicates in filter_expr
318- /// Returns None otherwise
319- fn remove_join_expressions (
320- expr : & Expr ,
321- join_keys : & HashSet < ( Expr , Expr ) > ,
322- ) -> Result < Option < Expr > > {
296+ ///
297+ /// # Returns
298+ /// * `Some()` when there are few remaining predicates in filter_expr
299+ /// * `None` otherwise
300+ fn remove_join_expressions ( expr : Expr , join_keys : & JoinKeySet ) -> Option < Expr > {
323301 match expr {
324- Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) => {
325- match op {
326- Operator :: Eq => {
327- if join_keys. contains ( & ( * left. clone ( ) , * right. clone ( ) ) )
328- || join_keys. contains ( & ( * right. clone ( ) , * left. clone ( ) ) )
329- {
330- Ok ( None )
331- } else {
332- Ok ( Some ( expr. clone ( ) ) )
333- }
334- }
335- // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
336- Operator :: And | Operator :: Or => {
337- let l = remove_join_expressions ( left, join_keys) ?;
338- let r = remove_join_expressions ( right, join_keys) ?;
339- match ( l, r) {
340- ( Some ( ll) , Some ( rr) ) => Ok ( Some ( Expr :: BinaryExpr (
341- BinaryExpr :: new ( Box :: new ( ll) , * op, Box :: new ( rr) ) ,
342- ) ) ) ,
343- ( Some ( ll) , _) => Ok ( Some ( ll) ) ,
344- ( _, Some ( rr) ) => Ok ( Some ( rr) ) ,
345- _ => Ok ( None ) ,
346- }
347- }
348- _ => Ok ( Some ( expr. clone ( ) ) ) ,
302+ Expr :: BinaryExpr ( BinaryExpr {
303+ left,
304+ op : Operator :: Eq ,
305+ right,
306+ } ) if join_keys. contains ( & left, & right) => {
307+ // was a join key, so remove it
308+ None
309+ }
310+ // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
311+ Expr :: BinaryExpr ( BinaryExpr { left, op, right } )
312+ if matches ! ( op, Operator :: And | Operator :: Or ) =>
313+ {
314+ let l = remove_join_expressions ( * left, join_keys) ;
315+ let r = remove_join_expressions ( * right, join_keys) ;
316+ match ( l, r) {
317+ ( Some ( ll) , Some ( rr) ) => Some ( Expr :: BinaryExpr ( BinaryExpr :: new (
318+ Box :: new ( ll) ,
319+ op,
320+ Box :: new ( rr) ,
321+ ) ) ) ,
322+ ( Some ( ll) , _) => Some ( ll) ,
323+ ( _, Some ( rr) ) => Some ( rr) ,
324+ _ => None ,
349325 }
350326 }
351- _ => Ok ( Some ( expr. clone ( ) ) ) ,
327+
328+ _ => Some ( expr) ,
352329 }
353330}
354331
0 commit comments