From 2b6ccd334f3cf2b48444cf238ccff7d241989874 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Tue, 28 May 2024 22:44:22 -0400 Subject: [PATCH 1/9] Without migrating tests --- datafusion/expr/src/aggregate_function.rs | 7 - .../expr/src/type_coercion/aggregates.rs | 2 +- datafusion/functions-aggregate/src/lib.rs | 2 + .../functions-aggregate/src/variance.rs | 254 ++++++++++++++++++ .../physical-expr/src/aggregate/build_in.rs | 85 +----- .../physical-expr/src/aggregate/variance.rs | 91 +++---- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../proto/src/physical_plan/to_proto.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + 13 files changed, 317 insertions(+), 142 deletions(-) create mode 100644 datafusion/functions-aggregate/src/variance.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index fb5a8db550e3..8f683cabe6d6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -49,8 +49,6 @@ pub enum AggregateFunction { ArrayAgg, /// N'th value in a group according to some ordering NthValue, - /// Variance (Sample) - Variance, /// Variance (Population) VariancePop, /// Standard Deviation (Sample) @@ -111,7 +109,6 @@ impl AggregateFunction { ApproxDistinct => "APPROX_DISTINCT", ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", - Variance => "VAR", VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", @@ -169,9 +166,7 @@ impl FromStr for AggregateFunction { "stddev" => AggregateFunction::Stddev, "stddev_pop" => AggregateFunction::StddevPop, "stddev_samp" => AggregateFunction::Stddev, - "var" => AggregateFunction::Variance, "var_pop" => AggregateFunction::VariancePop, - "var_samp" => AggregateFunction::Variance, "regr_slope" => AggregateFunction::RegrSlope, "regr_intercept" => AggregateFunction::RegrIntercept, "regr_count" => AggregateFunction::RegrCount, @@ -235,7 +230,6 @@ impl AggregateFunction { AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Ok(DataType::Boolean) } - AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), AggregateFunction::VariancePop => { variance_return_type(&coerced_data_types[0]) } @@ -315,7 +309,6 @@ impl AggregateFunction { } AggregateFunction::Avg | AggregateFunction::Sum - | AggregateFunction::Variance | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 6bd204c53c61..b7004e200d70 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -173,7 +173,7 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::Variance | AggregateFunction::VariancePop => { + AggregateFunction::VariancePop => { if !is_variance_support_arg_type(&input_types[0]) { return plan_err!( "The function {:?} does not support inputs of type {:?}.", diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index ac40a90aaec6..f9d96646ba7b 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,6 +58,7 @@ pub mod macros; pub mod covariance; pub mod first_last; pub mod median; +pub mod variance; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -80,6 +81,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { covariance::covar_samp_udaf(), covariance::covar_pop_udaf(), median::median_udaf(), + variance::var_samp_udaf(), ]; functions.into_iter().try_for_each(|udf| { diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs new file mode 100644 index 000000000000..56d7eb2d8d29 --- /dev/null +++ b/datafusion/functions-aggregate/src/variance.rs @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`CovarianceSample`]: covariance sample aggregations. + +use std::fmt::Debug; + +use arrow::{ + array::{ArrayRef, Float64Array, UInt64Array}, + compute::kernels::cast, + datatypes::{DataType, Field}, +}; + +use datafusion_common::{downcast_value, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::stats::StatsType; + +make_udaf_expr_and_func!( + VarianceSample, + var_sample, + y x, + "Computes the sample variance.", + var_samp_udaf +); + +pub struct VarianceSample { + signature: Signature, + aliases: Vec, +} + +impl Debug for VarianceSample { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("VarianceSample") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for VarianceSample { + fn default() -> Self { + Self::new() + } +} + +impl VarianceSample { + pub fn new() -> Self { + Self { + aliases: vec![String::from("var")], + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for VarianceSample { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "var" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Variance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// An accumulator to compute variance +/// The algrithm used is an online implementation and numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. + +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: f64, + mean: f64, + count: u64, + stats_type: StatsType, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + m2: 0_f64, + mean: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> f64 { + self.mean + } + + pub fn get_m2(&self) -> f64 { + self.m2 + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count + 1; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + + self.count += 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count - 1; + let delta1 = self.mean - value; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = new_mean - value; + let new_m2 = self.m2 - delta1 * delta2; + + self.count -= 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let means = downcast_value!(states[1], Float64Array); + let m2s = downcast_value!(states[2], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = self.count + c; + let new_mean = self.mean * self.count as f64 / new_count as f64 + + means.value(i) * c as f64 / new_count as f64; + let delta = self.mean - means.value(i); + let new_m2 = self.m2 + + m2s.value(i) + + delta * delta * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + Ok(ScalarValue::Float64(match self.count { + 0 => None, + 1 => { + if let StatsType::Population = self.stats_type { + Some(0.0) + } else { + None + } + } + _ => Some(self.m2 / count as f64), + })) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index e10008995463..daf28696ebcb 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -166,14 +166,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Avg, true) => { return not_impl_err!("AVG(DISTINCT) aggregations are not available"); } - (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Variance, true) => { - return not_impl_err!("VAR(DISTINCT) aggregations are not available"); - } (AggregateFunction::VariancePop, false) => Arc::new( expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type), ), @@ -373,12 +365,13 @@ pub fn create_aggregate_expr( #[cfg(test)] mod tests { use arrow::datatypes::{DataType, Field}; + use expressions::{StddevPop, VariancePop}; use super::*; use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, Stddev, Sum, Variance, + Max, Min, Stddev, Sum, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -749,43 +742,7 @@ mod tests { Ok(()) } - #[test] - fn test_variance_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Variance]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } + // TODO (yyin): Add back test #[test] fn test_var_pop_expr() -> Result<()> { @@ -812,8 +769,8 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); + if fun == AggregateFunction::VariancePop { + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", DataType::Float64, true), @@ -850,7 +807,7 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { + if fun == AggregateFunction::Stddev { assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( @@ -888,8 +845,8 @@ mod tests { &input_schema, "c1", )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); + if fun == AggregateFunction::StddevPop { + assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( Field::new("c1", DataType::Float64, true), @@ -1055,31 +1012,7 @@ mod tests { assert!(observed.is_err()); } - #[test] - fn test_variance_return_type() -> Result<()> { - let observed = AggregateFunction::Variance.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_variance_no_utf8() { - let observed = AggregateFunction::Variance.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } + // TODO (yyin): Add back tests to sqllogictest #[test] fn test_stddev_return_type() -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 989041097730..c480db4946ec 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -332,6 +332,7 @@ impl Accumulator for VarianceAccumulator { } } +// TODO (yyin): Move to aggregations.slt #[cfg(test)] mod tests { use super::*; @@ -359,23 +360,23 @@ mod tests { generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64)) } - #[test] - fn variance_f64_3() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) - } - - #[test] - fn variance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - Variance, - ScalarValue::from(0.9033333333333333_f64) - ) - } + // #[test] + // fn variance_f64_3() -> Result<()> { + // let a: ArrayRef = + // Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + // generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) + // } + + // #[test] + // fn variance_f64_4() -> Result<()> { + // let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); + // generic_test_op!( + // a, + // DataType::Float64, + // Variance, + // ScalarValue::from(0.9033333333333333_f64) + // ) + // } #[test] fn variance_i32() -> Result<()> { @@ -397,22 +398,22 @@ mod tests { generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64)) } - #[test] - fn test_variance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + // #[test] + // fn test_variance_1_input() -> Result<()> { + // let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); + // let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + // let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); + // let agg = Arc::new(Variance::new( + // col("a", &schema)?, + // "bla".to_string(), + // DataType::Float64, + // )); + // let actual = aggregate(&batch, agg).unwrap(); + // assert_eq!(actual, ScalarValue::Float64(None)); - Ok(()) - } + // Ok(()) + // } #[test] fn variance_i32_with_nulls() -> Result<()> { @@ -431,22 +432,22 @@ mod tests { ) } - #[test] - fn variance_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + // #[test] + // fn variance_i32_all_nulls() -> Result<()> { + // let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); + // let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + // let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); + // let agg = Arc::new(Variance::new( + // col("a", &schema)?, + // "bla".to_string(), + // DataType::Float64, + // )); + // let actual = aggregate(&batch, agg).unwrap(); + // assert_eq!(actual, ScalarValue::Float64(None)); - Ok(()) - } + // Ok(()) + // } #[test] fn variance_f64_merge_1() -> Result<()> { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index cb0ae0f551f2..e20379399ea2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -473,7 +473,7 @@ enum AggregateFunction { COUNT = 4; APPROX_DISTINCT = 5; ARRAY_AGG = 6; - VARIANCE = 7; + // VARIANCE = 7; VARIANCE_POP = 8; // COVARIANCE = 9; // COVARIANCE_POP = 10; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 2edbae24294b..7c9f9f3c2a36 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -428,7 +428,6 @@ impl serde::Serialize for AggregateFunction { Self::Count => "COUNT", Self::ApproxDistinct => "APPROX_DISTINCT", Self::ArrayAgg => "ARRAY_AGG", - Self::Variance => "VARIANCE", Self::VariancePop => "VARIANCE_POP", Self::Stddev => "STDDEV", Self::StddevPop => "STDDEV_POP", @@ -471,7 +470,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "COUNT", "APPROX_DISTINCT", "ARRAY_AGG", - "VARIANCE", "VARIANCE_POP", "STDDEV", "STDDEV_POP", @@ -543,7 +541,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "COUNT" => Ok(AggregateFunction::Count), "APPROX_DISTINCT" => Ok(AggregateFunction::ApproxDistinct), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "VARIANCE" => Ok(AggregateFunction::Variance), "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), "STDDEV" => Ok(AggregateFunction::Stddev), "STDDEV_POP" => Ok(AggregateFunction::StddevPop), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e9407cc65bb1..009df9f172b0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1897,7 +1897,7 @@ pub enum AggregateFunction { Count = 4, ApproxDistinct = 5, ArrayAgg = 6, - Variance = 7, + /// VARIANCE = 7; VariancePop = 8, /// COVARIANCE = 9; /// COVARIANCE_POP = 10; @@ -1940,7 +1940,6 @@ impl AggregateFunction { AggregateFunction::Count => "COUNT", AggregateFunction::ApproxDistinct => "APPROX_DISTINCT", AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Variance => "VARIANCE", AggregateFunction::VariancePop => "VARIANCE_POP", AggregateFunction::Stddev => "STDDEV", AggregateFunction::StddevPop => "STDDEV_POP", @@ -1979,7 +1978,6 @@ impl AggregateFunction { "COUNT" => Some(Self::Count), "APPROX_DISTINCT" => Some(Self::ApproxDistinct), "ARRAY_AGG" => Some(Self::ArrayAgg), - "VARIANCE" => Some(Self::Variance), "VARIANCE_POP" => Some(Self::VariancePop), "STDDEV" => Some(Self::Stddev), "STDDEV_POP" => Some(Self::StddevPop), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 905c6654cfe9..144c2fbac570 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -141,7 +141,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ApproxDistinct => Self::ApproxDistinct, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Variance => Self::Variance, protobuf::AggregateFunction::VariancePop => Self::VariancePop, protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b0059aff615b..618d745143bf 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -112,7 +112,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ApproxDistinct => Self::ApproxDistinct, AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Variance => Self::Variance, AggregateFunction::VariancePop => Self::VariancePop, AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, @@ -410,7 +409,6 @@ pub fn serialize_expr( AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, AggregateFunction::VariancePop => { protobuf::AggregateFunction::VariancePop } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index d0af2f8338e0..ed3a8448083a 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -29,7 +29,7 @@ use datafusion::physical_plan::expressions::{ DistinctCount, DistinctSum, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, - StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift, + StringAgg, Sum, TryCastExpr, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -273,8 +273,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Max } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Avg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Variance } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::VariancePop } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d2721dfafd14..91cf01453aba 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,6 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; +use datafusion::functions_aggregate::variance::var_sample; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -624,6 +625,7 @@ async fn roundtrip_expr_api() -> Result<()> { covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), median(lit(2)), + var_sample(lit(1.5), lit(2.2)), ]; // ensure expressions created with the expr api can be round tripped From e4e9e572b387549151e769ca5887e7f5f583b314 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Wed, 29 May 2024 08:05:55 -0400 Subject: [PATCH 2/9] Should fail VAR(DISTINCT) but doesn't --- .../functions-aggregate/src/variance.rs | 9 ++- .../physical-expr/src/aggregate/build_in.rs | 4 - .../physical-expr/src/aggregate/variance.rs | 75 ------------------- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/sqllogictest/Cargo.toml | 10 ++- 5 files changed, 15 insertions(+), 85 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 56d7eb2d8d29..703b21fc1dc1 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -28,7 +28,6 @@ use arrow::{ use datafusion_common::{downcast_value, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, }; @@ -65,8 +64,8 @@ impl Default for VarianceSample { impl VarianceSample { pub fn new() -> Self { Self { - aliases: vec![String::from("var")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + aliases: vec![String::from("var_sample")], + signature: Signature::numeric(1, Volatility::Immutable), } } } @@ -251,4 +250,8 @@ impl Accumulator for VarianceAccumulator { fn size(&self) -> usize { std::mem::size_of_val(self) } + + fn supports_retract_batch(&self) -> bool { + true + } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index daf28696ebcb..03fd49a4d1df 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -742,8 +742,6 @@ mod tests { Ok(()) } - // TODO (yyin): Add back test - #[test] fn test_var_pop_expr() -> Result<()> { let funcs = vec![AggregateFunction::VariancePop]; @@ -1012,8 +1010,6 @@ mod tests { assert!(observed.is_err()); } - // TODO (yyin): Add back tests to sqllogictest - #[test] fn test_stddev_return_type() -> Result<()> { let observed = AggregateFunction::Stddev.return_type(&[DataType::Float32])?; diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index c480db4946ec..ae49d139af55 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -35,13 +35,6 @@ use datafusion_common::downcast_value; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; -/// VAR and VAR_SAMP aggregate expression -#[derive(Debug)] -pub struct Variance { - name: String, - expr: Arc, -} - /// VAR_POP aggregate expression #[derive(Debug)] pub struct VariancePop { @@ -49,74 +42,6 @@ pub struct VariancePop { expr: Arc, } -impl Variance { - /// Create a new VARIANCE aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of variance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for Variance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Variance { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - impl VariancePop { /// Create a new VAR_POP aggregate function pub fn new( diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 1e9644f75afe..324699af5b5c 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -60,7 +60,7 @@ pub use crate::aggregate::stddev::{Stddev, StddevPop}; pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; -pub use crate::aggregate::variance::{Variance, VariancePop}; +pub use crate::aggregate::variance::VariancePop; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index c652c8041ff1..3b1f0dfd6d89 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -40,7 +40,7 @@ bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } clap = { version = "4.4.8", features = ["derive", "env"] } -datafusion = { workspace = true, default-features = true } +datafusion = { workspace = true, default-features = true, features = ["avro"] } datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } @@ -60,7 +60,13 @@ tokio-postgres = { version = "0.7.7", optional = true } [features] avro = ["datafusion/avro"] -postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] +postgres = [ + "bytes", + "chrono", + "tokio-postgres", + "postgres-types", + "postgres-protocol", +] [dev-dependencies] env_logger = { workspace = true } From 758ff77bded3a9647e9602b0999a2e986834621f Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Wed, 29 May 2024 08:21:52 -0400 Subject: [PATCH 3/9] Pass all other tests. --- datafusion/functions-aggregate/src/variance.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 703b21fc1dc1..756aa2124a70 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -64,7 +64,7 @@ impl Default for VarianceSample { impl VarianceSample { pub fn new() -> Self { Self { - aliases: vec![String::from("var_sample")], + aliases: vec![String::from("var_sample"), String::from("var_samp")], signature: Signature::numeric(1, Volatility::Immutable), } } From 09fecca0b9e317a28e2dc2c6a0be8e01f655be21 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Mon, 3 Jun 2024 22:52:49 -0400 Subject: [PATCH 4/9] Return error for var(distinct) --- datafusion/functions-aggregate/src/variance.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 756aa2124a70..08d7edd0d302 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -25,7 +25,7 @@ use arrow::{ datatypes::{DataType, Field}, }; -use datafusion_common::{downcast_value, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, @@ -100,7 +100,11 @@ impl AggregateUDFImpl for VarianceSample { ]) } - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("VAR(DISTINCT) aggregations are not available") + } + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) } From d9a1562aa39f64d0c17785536b0d94527fe71b61 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Mon, 3 Jun 2024 23:13:32 -0400 Subject: [PATCH 5/9] Migrate tests --- .../physical-expr/src/aggregate/variance.rs | 53 ------------------- .../sqllogictest/test_files/aggregate.slt | 48 +++++++++++++++++ datafusion/sqllogictest/test_files/order.slt | 2 +- .../test_files/sort_merge_join.slt | 1 + datafusion/sqllogictest/test_files/unnest.slt | 4 +- 5 files changed, 52 insertions(+), 56 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index ae49d139af55..7c5e3b4a37eb 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -257,7 +257,6 @@ impl Accumulator for VarianceAccumulator { } } -// TODO (yyin): Move to aggregations.slt #[cfg(test)] mod tests { use super::*; @@ -285,24 +284,6 @@ mod tests { generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64)) } - // #[test] - // fn variance_f64_3() -> Result<()> { - // let a: ArrayRef = - // Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - // generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) - // } - - // #[test] - // fn variance_f64_4() -> Result<()> { - // let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - // generic_test_op!( - // a, - // DataType::Float64, - // Variance, - // ScalarValue::from(0.9033333333333333_f64) - // ) - // } - #[test] fn variance_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -323,23 +304,6 @@ mod tests { generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64)) } - // #[test] - // fn test_variance_1_input() -> Result<()> { - // let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - // let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - // let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - // let agg = Arc::new(Variance::new( - // col("a", &schema)?, - // "bla".to_string(), - // DataType::Float64, - // )); - // let actual = aggregate(&batch, agg).unwrap(); - // assert_eq!(actual, ScalarValue::Float64(None)); - - // Ok(()) - // } - #[test] fn variance_i32_with_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![ @@ -357,23 +321,6 @@ mod tests { ) } - // #[test] - // fn variance_i32_all_nulls() -> Result<()> { - // let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - // let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - // let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - // let agg = Arc::new(Variance::new( - // col("a", &schema)?, - // "bla".to_string(), - // DataType::Float64, - // )); - // let actual = aggregate(&batch, agg).unwrap(); - // assert_eq!(actual, ScalarValue::Float64(None)); - - // Ok(()) - // } - #[test] fn variance_f64_merge_1() -> Result<()> { let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 256fddd9f254..a2b9fa8b0e03 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2338,6 +2338,54 @@ select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; statement ok drop table t; +# variance_f64_1 +statement ok +create table t (c double) as values (1), (2), (3), (4), (5); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +2.5 Float64 + +statement ok +drop table t; + +# variance_f64_2 +statement ok +create table t (c double) as values (1.1), (2), (3); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# variance_1_input +statement ok +create table t (a double not null) as values (1); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# variance_i32_all_nulls +statement ok +create table t (a int) as values (null), (null); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + # simple_mean query R select mean(c1) from test diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index d7f10537d02a..2678e8cbd1ba 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -1131,4 +1131,4 @@ physical_plan 01)SortPreservingMergeExec: [c@0 ASC NULLS LAST] 02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c] 03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true \ No newline at end of file +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 3a27d9693d00..0cdcfabce4bd 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -344,6 +344,7 @@ t1 as ( select 11 a, 13 b) select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) ) order by 1, 2; +---- query II select * from ( diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index bdd7e6631c16..8866cd009c32 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -65,7 +65,7 @@ select * from unnest(struct(1,2,3)); ---- 1 2 3 -## Multiple unnest expression in from clause +## Multiple unnest expression in from clause query IIII select * from unnest(struct(1,2,3)),unnest([4,5,6]); ---- @@ -446,7 +446,7 @@ query error DataFusion error: type_coercion\ncaused by\nThis feature is not impl select sum(unnest(generate_series(1,10))); ## TODO: support unnest as a child expr -query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression +query error DataFusion error: Internal error: unnest on struct can ony be applied at the root level of select expression select arrow_typeof(unnest(column5)) from unnest_table; From f3aea7c4cd20a3552eaf716533278380d032b658 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Mon, 3 Jun 2024 23:50:07 -0400 Subject: [PATCH 6/9] Fix tests --- .../sqllogictest/test_files/aggregate.slt | 18 ++++++++++++++++++ .../test_files/sort_merge_join.slt | 1 - 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 91234dd314bf..56ec0342577f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2344,6 +2344,12 @@ create table t (c double) as values (1), (2), (3), (4), (5); query RT select var(c), arrow_typeof(var(c)) from t; +---- +2.5 Float64 + +statement ok +drop table t; + # aggregate stddev f64_1 statement ok create table t (c1 double) as values (1), (2); @@ -2506,6 +2512,12 @@ create table t (c double) as values (1.1), (2), (3); query RT select var(c), arrow_typeof(var(c)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + # aggregate variance f64_4 statement ok create table t (c1 double) as values (1.1), (2), (3); @@ -2536,6 +2548,12 @@ create table t (a int) as values (null), (null); query RT select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + # aggregate variance i32 statement ok create table t (c1 int) as values (1), (2), (3), (4), (5); diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index ce738c7a6f3e..1fd8b0a346da 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -518,4 +518,3 @@ set datafusion.optimizer.prefer_hash_join = true; statement ok set datafusion.execution.batch_size = 8192; - From d5d43ef14b689536e40a635e193716a45ebdb462 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Tue, 4 Jun 2024 07:48:16 -0400 Subject: [PATCH 7/9] Lint --- datafusion/functions-aggregate/src/lib.rs | 2 +- datafusion/functions-aggregate/src/variance.rs | 8 +++++--- datafusion/physical-expr/src/aggregate/build_in.rs | 2 +- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b9cc7cb5dbf0..43a3d12813fa 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -58,8 +58,8 @@ pub mod macros; pub mod covariance; pub mod first_last; pub mod median; -pub mod variance; pub mod sum; +pub mod variance; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 08d7edd0d302..51af8a49e3b9 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`CovarianceSample`]: covariance sample aggregations. +//! [`VarianceSample`]: covariance sample aggregations. use std::fmt::Debug; @@ -25,7 +25,9 @@ use arrow::{ datatypes::{DataType, Field}, }; -use datafusion_common::{downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, @@ -102,7 +104,7 @@ impl AggregateUDFImpl for VarianceSample { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return not_impl_err!("VAR(DISTINCT) aggregations are not available") + return not_impl_err!("VAR(DISTINCT) aggregations are not available"); } Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index abeaf16d272d..07409dd1f4dc 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -365,7 +365,7 @@ mod tests { use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, Stddev, Sum, Variance + Max, Min, Stddev, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 620d4d129a00..c530365a2b48 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,9 +32,9 @@ use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; -use datafusion::functions_aggregate::variance::var_sample; -use datafusion::functions_aggregate::expr_fn::{covar_pop, covar_samp, first_value}; +use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::functions_aggregate::median::median; +use datafusion::functions_aggregate::variance::var_sample; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; From 65da63cea812587829d730c2c0929d8cd3b1c5ec Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Tue, 4 Jun 2024 08:09:17 -0400 Subject: [PATCH 8/9] Fix tests --- datafusion/functions-aggregate/src/variance.rs | 2 +- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 51af8a49e3b9..b5d467d0e780 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -38,7 +38,7 @@ use datafusion_physical_expr_common::aggregate::stats::StatsType; make_udaf_expr_and_func!( VarianceSample, var_sample, - y x, + expression, "Computes the sample variance.", var_samp_udaf ); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c530365a2b48..f279ddb11bcb 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -653,7 +653,7 @@ async fn roundtrip_expr_api() -> Result<()> { covar_pop(lit(1.5), lit(2.2)), sum(lit(1)), median(lit(2)), - var_sample(lit(1.5), lit(2.2)), + var_sample(lit(2.2)), ]; // ensure expressions created with the expr api can be round tripped From 590c04304d0445699c7d840d2894e57a9158fd27 Mon Sep 17 00:00:00 2001 From: Yue Yin Date: Tue, 4 Jun 2024 21:11:02 -0400 Subject: [PATCH 9/9] Fix use --- datafusion/functions-aggregate/src/lib.rs | 1 + datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 43a3d12813fa..ff02d25ad00b 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -75,6 +75,7 @@ pub mod expr_fn { pub use super::first_last::last_value; pub use super::median::median; pub use super::sum::sum; + pub use super::variance::var_sample; } /// Returns all default aggregate functions diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f279ddb11bcb..deae97fecc96 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,10 +31,9 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; -use datafusion::functions_aggregate::expr_fn::first_value; -use datafusion::functions_aggregate::median::median; -use datafusion::functions_aggregate::variance::var_sample; +use datafusion::functions_aggregate::expr_fn::{ + covar_pop, covar_samp, first_value, median, var_sample, +}; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions};