Skip to content

Commit 0fa97ad

Browse files
committed
cleanup
Signed-off-by: jayzhan211 <[email protected]>
1 parent d623d2a commit 0fa97ad

File tree

7 files changed

+41
-59
lines changed

7 files changed

+41
-59
lines changed

datafusion/common/src/scalar.rs

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -758,40 +758,45 @@ impl ScalarValue {
758758
Some(v) => {
759759
let array = PrimitiveArray::<T>::new(vec![v].into(), None)
760760
.with_data_type(d.clone());
761-
let res = Self::try_from_array(&array, 0).unwrap();
762-
// let res2 = ScalarValue::List(Arc::new(array));
763-
// thread 'tokio-runtime-worker' panicked at 'assertion failed: `(left == right)`
764-
// left: `Int64(0)`,
765-
// right: `List([PrimitiveArray<Int64>
766-
// [
767-
// 0,
768-
// ]])`', /Users/jayzhan/arrow-datafusion/datafusion/common/src/scalar.rs:763:17
769-
// assert_eq!(res, res2);
770-
res
761+
Self::try_from_array(&array, 0).unwrap()
771762
}
772763
}
773764
}
774765

766+
// ListArray compatible version of new_primitive
775767
pub fn new_primitives<T: ArrowPrimitiveType>(
776768
values: Vec<Option<T::Native>>,
777769
d: &DataType,
778770
) -> Self {
779-
if values.len() == 0 {
771+
if values.is_empty() {
780772
return d.try_into().unwrap();
781773
}
774+
775+
// We need to convert it to ScalarValue::Primitive (Int64(0)) instead of ScalarValue::List (List([PrimitiveArray<Int64> [0]]))
782776
if values.len() == 1 {
783777
return Self::new_primitive::<T>(values[0], d);
784778
}
785-
let mut vals = vec![];
779+
780+
let mut array = Vec::with_capacity(values.len());
781+
let mut nulls = Vec::with_capacity(values.len());
782+
786783
for a in values {
787784
match a {
788-
None => return d.try_into().unwrap(),
789-
Some(v) => vals.push(v),
785+
Some(v) => {
786+
array.push(v);
787+
nulls.push(true);
788+
}
789+
None => {
790+
array.push(T::Native::default());
791+
nulls.push(false);
792+
}
790793
}
791794
}
792-
let array = PrimitiveArray::<T>::new(vals.into(), None).with_data_type(d.clone());
793795

794-
ScalarValue::List(Arc::new(array))
796+
let arr = PrimitiveArray::<T>::new(array.into(), Some(NullBuffer::from(nulls)))
797+
.with_data_type(d.clone());
798+
799+
ScalarValue::List(Arc::new(arr))
795800
}
796801

797802
/// Create a decimal Scalar from value/precision and scale.

datafusion/expr/src/built_in_function.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,7 @@ impl BuiltinScalarFunction {
532532
Ok(data_type)
533533
}
534534
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
535-
BuiltinScalarFunction::ArrayAggregate => {
536-
// TODO: Fix this
537-
Ok(Int64)
538-
}
535+
BuiltinScalarFunction::ArrayAggregate => unimplemented!("ArrayAggregate is based on Aggreation function, so no return value for it."),
539536
BuiltinScalarFunction::ArrayConcat => {
540537
let mut expr_type = Null;
541538
let mut max_dims = 0;

datafusion/physical-expr/src/aggregate/sum.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,10 @@ impl<T: ArrowNumericType> Accumulator for SumAccumulator<T> {
233233
}
234234

235235
fn evaluate(&self) -> Result<ScalarValue> {
236-
if !self.sum.is_empty() {
237-
let arr = ScalarValue::new_primitives::<T>(self.sum.clone(), &self.data_type);
238-
Ok(arr)
239-
} else {
240-
Ok(ScalarValue::new_primitives::<T>(vec![], &self.data_type))
241-
}
236+
Ok(ScalarValue::new_primitives::<T>(
237+
self.sum.clone(),
238+
&self.data_type,
239+
))
242240
}
243241

244242
fn size(&self) -> usize {

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -620,8 +620,7 @@ enum ScalarFunction {
620620
ArrayEmpty = 115;
621621
ArrayPopBack = 116;
622622
StringToArray = 117;
623-
ArraySum = 118;
624-
ArrayAggregate = 119;
623+
ArrayAggregate = 118;
625624
}
626625

627626
message ScalarFunctionNode {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 0 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 1 addition & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/sql/src/expr/function.rs

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
19-
use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result};
19+
use datafusion_common::{
20+
internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue,
21+
};
2022
use datafusion_expr;
2123
use datafusion_expr::expr::{ScalarFunction, ScalarUDF};
2224
use datafusion_expr::function::suggest_valid_function;
@@ -62,33 +64,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
6264
// Translate array_aggregate to aggregate function with array argument.
6365
if fun == BuiltinScalarFunction::ArrayAggregate {
6466
let fun = match &args[1] {
65-
Expr::Literal(datafusion_common::ScalarValue::Utf8(Some(name))) => {
66-
match name.as_str() {
67-
"sum" => AggregateFunction::Sum,
68-
"count" => AggregateFunction::Count,
69-
"min" => AggregateFunction::Min,
70-
"max" => AggregateFunction::Max,
71-
"avg" => AggregateFunction::Avg,
72-
_ => {
73-
return Err(DataFusionError::NotImplemented(format!(
74-
"Aggregate function {name} is not implemented"
75-
)))
76-
}
67+
Expr::Literal(ScalarValue::Utf8(Some(name))) => match name.as_str() {
68+
"sum" => AggregateFunction::Sum,
69+
_ => {
70+
return not_impl_err!(
71+
"Aggregate function {name} is not implemented"
72+
)
7773
}
78-
}
79-
_ => {
80-
return Err(DataFusionError::Internal(format!(
81-
"Aggregate function name is not a string"
82-
)))
83-
}
74+
},
75+
_ => return internal_err!("Aggregate function name is not a string"),
8476
};
85-
println!("args: {:?}", args);
8677
let args = vec![args[0].to_owned()];
87-
let e = Expr::AggregateFunction(expr::AggregateFunction::new(
78+
return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
8879
fun, args, false, None, None,
89-
));
90-
println!("e: {:?}", e);
91-
return Ok(e);
80+
)));
9281
}
9382

9483
return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)));

0 commit comments

Comments
 (0)