Skip to content

Commit a8f1f8a

Browse files
authored
Extended datatypes & signatures support for NULLIF function (#4737)
* extended nullif datatypes & signatures support * sqllogictests & type inheritance
1 parent 8ec511e commit a8f1f8a

File tree

3 files changed

+207
-44
lines changed

3 files changed

+207
-44
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
statement ok
19+
CREATE TABLE test(
20+
int_field INT,
21+
bool_field BOOLEAN,
22+
text_field TEXT,
23+
more_ints INT
24+
) as VALUES
25+
(1, true, 'abc', 2),
26+
(2, false, 'def', 2),
27+
(3, NULL, 'ghij', 3),
28+
(NULL, NULL, NULL, 4),
29+
(4, false, 'zxc', 5),
30+
(NULL, true, NULL, 6)
31+
;
32+
33+
# Arrays tests
34+
query T
35+
SELECT NULLIF(int_field, 2) FROM test;
36+
----
37+
1
38+
NULL
39+
3
40+
NULL
41+
4
42+
NULL
43+
44+
query T
45+
SELECT NULLIF(bool_field, false) FROM test;
46+
----
47+
true
48+
NULL
49+
NULL
50+
NULL
51+
NULL
52+
true
53+
54+
query T
55+
SELECT NULLIF(text_field, 'zxc') FROM test;
56+
----
57+
abc
58+
def
59+
ghij
60+
NULL
61+
NULL
62+
NULL
63+
64+
query T
65+
SELECT NULLIF(int_field, more_ints) FROM test;
66+
----
67+
1
68+
NULL
69+
NULL
70+
NULL
71+
4
72+
NULL
73+
74+
query T
75+
SELECT NULLIF(3, int_field) FROM test;
76+
----
77+
3
78+
3
79+
NULL
80+
3
81+
3
82+
3
83+
84+
# Scalar values tests
85+
query T
86+
SELECT NULLIF(1, 1);
87+
----
88+
NULL
89+
90+
query T
91+
SELECT NULLIF(1, 3);
92+
----
93+
1
94+
95+
query T
96+
SELECT NULLIF(NULL, NULL);
97+
----
98+
NULL

datafusion/expr/src/nullif.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,6 @@ pub static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
3232
DataType::Int64,
3333
DataType::Float32,
3434
DataType::Float64,
35+
DataType::Utf8,
36+
DataType::LargeUtf8,
3537
];

datafusion/physical-expr/src/expressions/nullif.rs

Lines changed: 107 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,52 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
2018
use arrow::array::Array;
21-
use arrow::array::*;
2219
use arrow::compute::eq_dyn;
2320
use arrow::compute::nullif::nullif;
24-
use arrow::datatypes::DataType;
25-
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result};
21+
use datafusion_common::{cast::as_boolean_array, DataFusionError, Result, ScalarValue};
2622
use datafusion_expr::ColumnarValue;
2723

2824
use super::binary::array_eq_scalar;
2925

30-
/// Invoke a compute kernel on a primitive array and a Boolean Array
31-
macro_rules! compute_bool_array_op {
32-
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
33-
let ll = $LEFT
34-
.as_any()
35-
.downcast_ref::<$DT>()
36-
.expect("compute_op failed to downcast array");
37-
let rr = as_boolean_array($RIGHT).expect("compute_op failed to downcast array");
38-
Ok(Arc::new($OP(&ll, &rr)?) as ArrayRef)
39-
}};
40-
}
41-
42-
/// Binary op between primitive and boolean arrays
43-
macro_rules! primitive_bool_array_op {
44-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
45-
match $LEFT.data_type() {
46-
DataType::Int8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int8Array),
47-
DataType::Int16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int16Array),
48-
DataType::Int32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int32Array),
49-
DataType::Int64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Int64Array),
50-
DataType::UInt8 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt8Array),
51-
DataType::UInt16 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt16Array),
52-
DataType::UInt32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt32Array),
53-
DataType::UInt64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, UInt64Array),
54-
DataType::Float32 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float32Array),
55-
DataType::Float64 => compute_bool_array_op!($LEFT, $RIGHT, $OP, Float64Array),
56-
other => Err(DataFusionError::Internal(format!(
57-
"Unsupported data type {:?} for NULLIF/primitive/boolean operator",
58-
other
59-
))),
60-
}
61-
}};
62-
}
63-
6426
/// Implements NULLIF(expr1, expr2)
6527
/// Args: 0 - left expr is any array
6628
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
@@ -79,7 +41,7 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
7941
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
8042
let cond_array = array_eq_scalar(lhs, rhs)?;
8143

82-
let array = primitive_bool_array_op!(lhs, &cond_array, nullif)?;
44+
let array = nullif(lhs, as_boolean_array(&cond_array)?)?;
8345

8446
Ok(ColumnarValue::Array(array))
8547
}
@@ -88,17 +50,34 @@ pub fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
8850
let cond_array = eq_dyn(lhs, rhs)?;
8951

9052
// Now, invoke nullif on the result
91-
let array = primitive_bool_array_op!(lhs, &cond_array, nullif)?;
53+
let array = nullif(lhs, as_boolean_array(&cond_array)?)?;
54+
Ok(ColumnarValue::Array(array))
55+
}
56+
(ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
57+
// Similar to Array-Array case, except of ScalarValue -> Array cast
58+
let lhs = lhs.to_array_of_size(rhs.len());
59+
let cond_array = eq_dyn(&lhs, rhs)?;
60+
61+
let array = nullif(&lhs, as_boolean_array(&cond_array)?)?;
9262
Ok(ColumnarValue::Array(array))
9363
}
94-
_ => Err(DataFusionError::NotImplemented(
95-
"nullif does not support a literal as first argument".to_string(),
96-
)),
64+
(ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
65+
let val: ScalarValue = match lhs.eq(rhs) {
66+
true => lhs.get_datatype().try_into()?,
67+
false => lhs.clone(),
68+
};
69+
70+
Ok(ColumnarValue::Scalar(val))
71+
}
9772
}
9873
}
9974

10075
#[cfg(test)]
10176
mod tests {
77+
use std::sync::Arc;
78+
79+
use arrow::array::*;
80+
10281
use super::*;
10382
use datafusion_common::{Result, ScalarValue};
10483

@@ -162,4 +141,88 @@ mod tests {
162141
assert_eq!(expected.as_ref(), result.as_ref());
163142
Ok(())
164143
}
144+
145+
#[test]
146+
fn nullif_boolean() -> Result<()> {
147+
let a = BooleanArray::from(vec![Some(true), Some(false), None]);
148+
let a = ColumnarValue::Array(Arc::new(a));
149+
150+
let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
151+
152+
let result = nullif_func(&[a, lit_array])?;
153+
let result = result.into_array(0);
154+
155+
let expected =
156+
Arc::new(BooleanArray::from(vec![Some(true), None, None])) as ArrayRef;
157+
158+
assert_eq!(expected.as_ref(), result.as_ref());
159+
Ok(())
160+
}
161+
162+
#[test]
163+
fn nullif_string() -> Result<()> {
164+
let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]);
165+
let a = ColumnarValue::Array(Arc::new(a));
166+
167+
let lit_array = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bar".to_string())));
168+
169+
let result = nullif_func(&[a, lit_array])?;
170+
let result = result.into_array(0);
171+
172+
let expected = Arc::new(StringArray::from(vec![
173+
Some("foo"),
174+
None,
175+
None,
176+
Some("baz"),
177+
])) as ArrayRef;
178+
179+
assert_eq!(expected.as_ref(), result.as_ref());
180+
Ok(())
181+
}
182+
183+
#[test]
184+
fn nullif_literal_first() -> Result<()> {
185+
let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]);
186+
let a = ColumnarValue::Array(Arc::new(a));
187+
188+
let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));
189+
190+
let result = nullif_func(&[lit_array, a])?;
191+
let result = result.into_array(0);
192+
193+
let expected = Arc::new(Int32Array::from(vec![
194+
Some(2),
195+
None,
196+
Some(2),
197+
Some(2),
198+
Some(2),
199+
Some(2),
200+
])) as ArrayRef;
201+
assert_eq!(expected.as_ref(), result.as_ref());
202+
Ok(())
203+
}
204+
205+
#[test]
206+
fn nullif_scalar() -> Result<()> {
207+
let a_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));
208+
let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));
209+
210+
let result_eq = nullif_func(&[a_eq, b_eq])?;
211+
let result_eq = result_eq.into_array(1);
212+
213+
let expected_eq = Arc::new(Int32Array::from(vec![None])) as ArrayRef;
214+
215+
assert_eq!(expected_eq.as_ref(), result_eq.as_ref());
216+
217+
let a_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));
218+
let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32)));
219+
220+
let result_neq = nullif_func(&[a_neq, b_neq])?;
221+
let result_neq = result_neq.into_array(1);
222+
223+
let expected_neq = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef;
224+
assert_eq!(expected_neq.as_ref(), result_neq.as_ref());
225+
226+
Ok(())
227+
}
165228
}

0 commit comments

Comments
 (0)