1717
1818//! Physical expression schema rewriting utilities 
1919
20- use  std:: sync:: Arc ; 
2120use  std:: cmp:: Ordering ; 
21+ use  std:: sync:: Arc ; 
2222
2323use  arrow:: compute:: can_cast_types; 
2424use  arrow:: datatypes:: { 
@@ -230,7 +230,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
230230            left. as_any ( ) . downcast_ref :: < CastExpr > ( ) , 
231231            right. as_any ( ) . downcast_ref :: < Literal > ( ) , 
232232        )  { 
233-             if  let  Some ( optimized)  = self . unwrap_cast_with_literal ( cast_expr,  literal,  * op) ? { 
233+             if  let  Some ( optimized)  =
234+                 self . unwrap_cast_with_literal ( cast_expr,  literal,  * op) ?
235+             { 
234236                return  Ok ( Some ( Arc :: new ( BinaryExpr :: new ( 
235237                    optimized. 0 , 
236238                    * op, 
@@ -244,7 +246,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
244246            left. as_any ( ) . downcast_ref :: < Literal > ( ) , 
245247            right. as_any ( ) . downcast_ref :: < CastExpr > ( ) , 
246248        )  { 
247-             if  let  Some ( optimized)  = self . unwrap_cast_with_literal ( cast_expr,  literal,  * op) ? { 
249+             if  let  Some ( optimized)  =
250+                 self . unwrap_cast_with_literal ( cast_expr,  literal,  * op) ?
251+             { 
248252                return  Ok ( Some ( Arc :: new ( BinaryExpr :: new ( 
249253                    optimized. 1 , 
250254                    * op, 
@@ -265,32 +269,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
265269    )  -> Result < Option < ( Arc < dyn  PhysicalExpr > ,  Arc < dyn  PhysicalExpr > ) > >  { 
266270        // Get the inner expression (what's being cast) 
267271        let  inner_expr = cast_expr. expr ( ) ; 
268-          
272+ 
269273        // Handle the case where inner expression might be another cast (due to schema rewriting) 
270274        // This can happen when the schema rewriter adds a cast to a column, and then we have 
271275        // an original cast on top of that. 
272-         let  ( final_inner_expr,  column)  = if  let  Some ( inner_cast)  = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( )  { 
273-             // We have a nested cast, check if the inner cast's expression is a column 
274-             let  inner_inner_expr = inner_cast. expr ( ) ; 
275-             if  let  Some ( col)  = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( )  { 
276-                 ( inner_inner_expr,  col) 
276+         let  ( final_inner_expr,  column)  =
277+             if  let  Some ( inner_cast)  = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( )  { 
278+                 // We have a nested cast, check if the inner cast's expression is a column 
279+                 let  inner_inner_expr = inner_cast. expr ( ) ; 
280+                 if  let  Some ( col)  = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( )  { 
281+                     ( inner_inner_expr,  col) 
282+                 }  else  { 
283+                     return  Ok ( None ) ; 
284+                 } 
285+             }  else  if  let  Some ( col)  = inner_expr. as_any ( ) . downcast_ref :: < Column > ( )  { 
286+                 ( inner_expr,  col) 
277287            }  else  { 
278288                return  Ok ( None ) ; 
279-             } 
280-         }  else  if  let  Some ( col)  = inner_expr. as_any ( ) . downcast_ref :: < Column > ( )  { 
281-             ( inner_expr,  col) 
282-         }  else  { 
283-             return  Ok ( None ) ; 
284-         } ; 
289+             } ; 
285290
286291        // Get the column's data type from the physical schema 
287-         let  column_data_type = match  self . physical_file_schema . field_with_name ( column. name ( ) )  { 
288-             Ok ( field)  => field. data_type ( ) , 
289-             Err ( _)  => return  Ok ( None ) ,  // Column not found, can't optimize 
290-         } ; 
292+         let  column_data_type =
293+             match  self . physical_file_schema . field_with_name ( column. name ( ) )  { 
294+                 Ok ( field)  => field. data_type ( ) , 
295+                 Err ( _)  => return  Ok ( None ) ,  // Column not found, can't optimize 
296+             } ; 
291297
292298        // Try to cast the literal to the column's data type 
293-         if  let  Some ( casted_literal)  = try_cast_literal_to_type ( literal. value ( ) ,  column_data_type,  op)  { 
299+         if  let  Some ( casted_literal)  =
300+             try_cast_literal_to_type ( literal. value ( ) ,  column_data_type,  op) 
301+         { 
294302            return  Ok ( Some ( ( 
295303                Arc :: clone ( final_inner_expr) , 
296304                expressions:: lit ( casted_literal) , 
@@ -323,7 +331,6 @@ fn cast_literal_to_type_with_op(
323331    target_type :  & DataType , 
324332    op :  Operator , 
325333)  -> Option < ScalarValue >  { 
326-     
327334    match  ( op,  lit_value)  { 
328335        ( 
329336            Operator :: Eq  | Operator :: NotEq , 
@@ -754,22 +761,27 @@ mod tests {
754761        let  column_expr = Arc :: new ( Column :: new ( "a" ,  0 ) ) ; 
755762        let  cast_expr = Arc :: new ( CastExpr :: new ( column_expr,  DataType :: Int64 ,  None ) ) ; 
756763        let  literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ; 
757-         let  binary_expr = Arc :: new ( BinaryExpr :: new ( 
758-             cast_expr, 
759-             Operator :: Eq , 
760-             literal_expr, 
761-         ) ) ; 
764+         let  binary_expr =
765+             Arc :: new ( BinaryExpr :: new ( cast_expr,  Operator :: Eq ,  literal_expr) ) ; 
762766
763767        let  result = rewriter. rewrite ( binary_expr. clone ( )  as  Arc < dyn  PhysicalExpr > ) ?; 
764768
765769        // The result should be a binary expression with the cast unwrapped 
766770        let  result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ; 
767-          
771+ 
768772        // Left side should be the original column (no cast) 
769-         assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ; 
770-         
773+         assert ! ( result_binary
774+             . left( ) 
775+             . as_any( ) 
776+             . downcast_ref:: <Column >( ) 
777+             . is_some( ) ) ; 
778+ 
771779        // Right side should be a literal with the value cast to Int32 
772-         let  right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ; 
780+         let  right_literal = result_binary
781+             . right ( ) 
782+             . as_any ( ) 
783+             . downcast_ref :: < Literal > ( ) 
784+             . unwrap ( ) ; 
773785        assert_eq ! ( * right_literal. value( ) ,  ScalarValue :: Int32 ( Some ( 123 ) ) ) ; 
774786
775787        Ok ( ( ) ) 
@@ -787,23 +799,28 @@ mod tests {
787799        let  literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ; 
788800        let  column_expr = Arc :: new ( Column :: new ( "a" ,  0 ) ) ; 
789801        let  cast_expr = Arc :: new ( CastExpr :: new ( column_expr,  DataType :: Int64 ,  None ) ) ; 
790-         let  binary_expr = Arc :: new ( BinaryExpr :: new ( 
791-             literal_expr, 
792-             Operator :: Eq , 
793-             cast_expr, 
794-         ) ) ; 
802+         let  binary_expr =
803+             Arc :: new ( BinaryExpr :: new ( literal_expr,  Operator :: Eq ,  cast_expr) ) ; 
795804
796805        let  result = rewriter. rewrite ( binary_expr) ?; 
797806
798807        // The result should be a binary expression with the cast unwrapped 
799808        let  result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ; 
800-          
809+ 
801810        // Left side should be a literal with the value cast to Int32 
802-         let  left_literal = result_binary. left ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ; 
811+         let  left_literal = result_binary
812+             . left ( ) 
813+             . as_any ( ) 
814+             . downcast_ref :: < Literal > ( ) 
815+             . unwrap ( ) ; 
803816        assert_eq ! ( * left_literal. value( ) ,  ScalarValue :: Int32 ( Some ( 123 ) ) ) ; 
804-          
817+ 
805818        // Right side should be the original column (no cast) 
806-         assert ! ( result_binary. right( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ; 
819+         assert ! ( result_binary
820+             . right( ) 
821+             . as_any( ) 
822+             . downcast_ref:: <Column >( ) 
823+             . is_some( ) ) ; 
807824
808825        Ok ( ( ) ) 
809826    } 
@@ -820,22 +837,27 @@ mod tests {
820837        let  column_expr = Arc :: new ( Column :: new ( "a" ,  0 ) ) ; 
821838        let  cast_expr = Arc :: new ( CastExpr :: new ( column_expr,  DataType :: Utf8 ,  None ) ) ; 
822839        let  literal_expr = expressions:: lit ( ScalarValue :: Utf8 ( Some ( "123" . to_string ( ) ) ) ) ; 
823-         let  binary_expr = Arc :: new ( BinaryExpr :: new ( 
824-             cast_expr, 
825-             Operator :: Eq , 
826-             literal_expr, 
827-         ) ) ; 
840+         let  binary_expr =
841+             Arc :: new ( BinaryExpr :: new ( cast_expr,  Operator :: Eq ,  literal_expr) ) ; 
828842
829843        let  result = rewriter. rewrite ( binary_expr) ?; 
830844
831845        // The result should be a binary expression with the cast unwrapped 
832846        let  result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ; 
833-          
847+ 
834848        // Left side should be the original column (no cast) 
835-         assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ; 
836-         
849+         assert ! ( result_binary
850+             . left( ) 
851+             . as_any( ) 
852+             . downcast_ref:: <Column >( ) 
853+             . is_some( ) ) ; 
854+ 
837855        // Right side should be a literal with the value cast to Int32 
838-         let  right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ; 
856+         let  right_literal = result_binary
857+             . right ( ) 
858+             . as_any ( ) 
859+             . downcast_ref :: < Literal > ( ) 
860+             . unwrap ( ) ; 
839861        assert_eq ! ( * right_literal. value( ) ,  ScalarValue :: Int32 ( Some ( 123 ) ) ) ; 
840862
841863        Ok ( ( ) ) 
@@ -844,7 +866,8 @@ mod tests {
844866    #[ test]  
845867    fn  test_no_unwrap_cast_optimization_when_not_applicable ( )  -> Result < ( ) >  { 
846868        // Test case where optimization should not apply - unsupported cast 
847-         let  physical_schema = Schema :: new ( vec ! [ Field :: new( "a" ,  DataType :: Float32 ,  false ) ] ) ; 
869+         let  physical_schema =
870+             Schema :: new ( vec ! [ Field :: new( "a" ,  DataType :: Float32 ,  false ) ] ) ; 
848871        let  logical_schema = Schema :: new ( vec ! [ Field :: new( "a" ,  DataType :: Int64 ,  false ) ] ) ; 
849872
850873        let  rewriter = PhysicalExprSchemaRewriter :: new ( & physical_schema,  & logical_schema) ; 
@@ -854,18 +877,19 @@ mod tests {
854877        let  column_expr = Arc :: new ( Column :: new ( "a" ,  0 ) ) ; 
855878        let  cast_expr = Arc :: new ( CastExpr :: new ( column_expr,  DataType :: Int64 ,  None ) ) ; 
856879        let  literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ; 
857-         let  binary_expr = Arc :: new ( BinaryExpr :: new ( 
858-             cast_expr, 
859-             Operator :: Eq , 
860-             literal_expr, 
861-         ) ) ; 
880+         let  binary_expr =
881+             Arc :: new ( BinaryExpr :: new ( cast_expr,  Operator :: Eq ,  literal_expr) ) ; 
862882
863883        let  result = rewriter. rewrite ( binary_expr) ?; 
864884
865885        // The result should still be a binary expression with a cast on the left side 
866886        // since Float32 is not in our supported types for unwrap cast optimization 
867887        let  result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ; 
868-         assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <CastExpr >( ) . is_some( ) ) ; 
888+         assert ! ( result_binary
889+             . left( ) 
890+             . as_any( ) 
891+             . downcast_ref:: <CastExpr >( ) 
892+             . is_some( ) ) ; 
869893
870894        Ok ( ( ) ) 
871895    } 
0 commit comments