1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: sync:: Arc ;
19-
2018use arrow:: array:: Array ;
21- use arrow:: array:: * ;
2219use arrow:: compute:: eq_dyn;
2320use arrow:: compute:: nullif:: nullif;
24- use arrow:: datatypes:: DataType ;
25- use datafusion_common:: { cast:: as_boolean_array, DataFusionError , Result } ;
21+ use datafusion_common:: { cast:: as_boolean_array, DataFusionError , Result , ScalarValue } ;
2622use datafusion_expr:: ColumnarValue ;
2723
2824use super :: binary:: array_eq_scalar;
2925
30- /// Invoke a compute kernel on a primitive array and a Boolean Array
31- macro_rules! compute_bool_array_op {
32- ( $LEFT: expr, $RIGHT: expr, $OP: ident, $DT: ident) => { {
33- let ll = $LEFT
34- . as_any( )
35- . downcast_ref:: <$DT>( )
36- . expect( "compute_op failed to downcast array" ) ;
37- let rr = as_boolean_array( $RIGHT) . expect( "compute_op failed to downcast array" ) ;
38- Ok ( Arc :: new( $OP( & ll, & rr) ?) as ArrayRef )
39- } } ;
40- }
41-
42- /// Binary op between primitive and boolean arrays
43- macro_rules! primitive_bool_array_op {
44- ( $LEFT: expr, $RIGHT: expr, $OP: ident) => { {
45- match $LEFT. data_type( ) {
46- DataType :: Int8 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Int8Array ) ,
47- DataType :: Int16 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Int16Array ) ,
48- DataType :: Int32 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Int32Array ) ,
49- DataType :: Int64 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Int64Array ) ,
50- DataType :: UInt8 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, UInt8Array ) ,
51- DataType :: UInt16 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, UInt16Array ) ,
52- DataType :: UInt32 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, UInt32Array ) ,
53- DataType :: UInt64 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, UInt64Array ) ,
54- DataType :: Float32 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Float32Array ) ,
55- DataType :: Float64 => compute_bool_array_op!( $LEFT, $RIGHT, $OP, Float64Array ) ,
56- other => Err ( DataFusionError :: Internal ( format!(
57- "Unsupported data type {:?} for NULLIF/primitive/boolean operator" ,
58- other
59- ) ) ) ,
60- }
61- } } ;
62- }
63-
6426/// Implements NULLIF(expr1, expr2)
6527/// Args: 0 - left expr is any array
6628/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
@@ -79,7 +41,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
7941 ( ColumnarValue :: Array ( lhs) , ColumnarValue :: Scalar ( rhs) ) => {
8042 let cond_array = array_eq_scalar ( lhs, rhs) ?;
8143
82- let array = primitive_bool_array_op ! ( lhs, & cond_array, nullif ) ?;
44+ let array = nullif ( lhs, as_boolean_array ( & cond_array) ? ) ?;
8345
8446 Ok ( ColumnarValue :: Array ( array) )
8547 }
@@ -88,17 +50,34 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8850 let cond_array = eq_dyn ( lhs, rhs) ?;
8951
9052 // Now, invoke nullif on the result
91- let array = primitive_bool_array_op ! ( lhs, & cond_array, nullif) ?;
53+ let array = nullif ( lhs, as_boolean_array ( & cond_array) ?) ?;
54+ Ok ( ColumnarValue :: Array ( array) )
55+ }
56+ ( ColumnarValue :: Scalar ( lhs) , ColumnarValue :: Array ( rhs) ) => {
57+ // Similar to Array-Array case, except of ScalarValue -> Array cast
58+ let lhs = lhs. to_array_of_size ( rhs. len ( ) ) ;
59+ let cond_array = eq_dyn ( & lhs, rhs) ?;
60+
61+ let array = nullif ( & lhs, as_boolean_array ( & cond_array) ?) ?;
9262 Ok ( ColumnarValue :: Array ( array) )
9363 }
94- _ => Err ( DataFusionError :: NotImplemented (
95- "nullif does not support a literal as first argument" . to_string ( ) ,
96- ) ) ,
64+ ( ColumnarValue :: Scalar ( lhs) , ColumnarValue :: Scalar ( rhs) ) => {
65+ let val: ScalarValue = match lhs. eq ( rhs) {
66+ true => lhs. get_datatype ( ) . try_into ( ) ?,
67+ false => lhs. clone ( ) ,
68+ } ;
69+
70+ Ok ( ColumnarValue :: Scalar ( val) )
71+ }
9772 }
9873}
9974
10075#[ cfg( test) ]
10176mod tests {
77+ use std:: sync:: Arc ;
78+
79+ use arrow:: array:: * ;
80+
10281 use super :: * ;
10382 use datafusion_common:: { Result , ScalarValue } ;
10483
@@ -162,4 +141,88 @@ mod tests {
162141 assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
163142 Ok ( ( ) )
164143 }
144+
145+ #[ test]
146+ fn nullif_boolean ( ) -> Result < ( ) > {
147+ let a = BooleanArray :: from ( vec ! [ Some ( true ) , Some ( false ) , None ] ) ;
148+ let a = ColumnarValue :: Array ( Arc :: new ( a) ) ;
149+
150+ let lit_array = ColumnarValue :: Scalar ( ScalarValue :: Boolean ( Some ( false ) ) ) ;
151+
152+ let result = nullif_func ( & [ a, lit_array] ) ?;
153+ let result = result. into_array ( 0 ) ;
154+
155+ let expected =
156+ Arc :: new ( BooleanArray :: from ( vec ! [ Some ( true ) , None , None ] ) ) as ArrayRef ;
157+
158+ assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
159+ Ok ( ( ) )
160+ }
161+
162+ #[ test]
163+ fn nullif_string ( ) -> Result < ( ) > {
164+ let a = StringArray :: from ( vec ! [ Some ( "foo" ) , Some ( "bar" ) , None , Some ( "baz" ) ] ) ;
165+ let a = ColumnarValue :: Array ( Arc :: new ( a) ) ;
166+
167+ let lit_array = ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( "bar" . to_string ( ) ) ) ) ;
168+
169+ let result = nullif_func ( & [ a, lit_array] ) ?;
170+ let result = result. into_array ( 0 ) ;
171+
172+ let expected = Arc :: new ( StringArray :: from ( vec ! [
173+ Some ( "foo" ) ,
174+ None ,
175+ None ,
176+ Some ( "baz" ) ,
177+ ] ) ) as ArrayRef ;
178+
179+ assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
180+ Ok ( ( ) )
181+ }
182+
183+ #[ test]
184+ fn nullif_literal_first ( ) -> Result < ( ) > {
185+ let a = Int32Array :: from ( vec ! [ Some ( 1 ) , Some ( 2 ) , None , None , Some ( 3 ) , Some ( 4 ) ] ) ;
186+ let a = ColumnarValue :: Array ( Arc :: new ( a) ) ;
187+
188+ let lit_array = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 2i32 ) ) ) ;
189+
190+ let result = nullif_func ( & [ lit_array, a] ) ?;
191+ let result = result. into_array ( 0 ) ;
192+
193+ let expected = Arc :: new ( Int32Array :: from ( vec ! [
194+ Some ( 2 ) ,
195+ None ,
196+ Some ( 2 ) ,
197+ Some ( 2 ) ,
198+ Some ( 2 ) ,
199+ Some ( 2 ) ,
200+ ] ) ) as ArrayRef ;
201+ assert_eq ! ( expected. as_ref( ) , result. as_ref( ) ) ;
202+ Ok ( ( ) )
203+ }
204+
205+ #[ test]
206+ fn nullif_scalar ( ) -> Result < ( ) > {
207+ let a_eq = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 2i32 ) ) ) ;
208+ let b_eq = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 2i32 ) ) ) ;
209+
210+ let result_eq = nullif_func ( & [ a_eq, b_eq] ) ?;
211+ let result_eq = result_eq. into_array ( 1 ) ;
212+
213+ let expected_eq = Arc :: new ( Int32Array :: from ( vec ! [ None ] ) ) as ArrayRef ;
214+
215+ assert_eq ! ( expected_eq. as_ref( ) , result_eq. as_ref( ) ) ;
216+
217+ let a_neq = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 2i32 ) ) ) ;
218+ let b_neq = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 1i32 ) ) ) ;
219+
220+ let result_neq = nullif_func ( & [ a_neq, b_neq] ) ?;
221+ let result_neq = result_neq. into_array ( 1 ) ;
222+
223+ let expected_neq = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 2i32 ) ] ) ) as ArrayRef ;
224+ assert_eq ! ( expected_neq. as_ref( ) , result_neq. as_ref( ) ) ;
225+
226+ Ok ( ( ) )
227+ }
165228}
0 commit comments