Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Count
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

There are still a bunch of these functions -- at some point we can probably file tickets and spread out the work / fun to migrate them over.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I constantly file 1~2 issues for #8708 so the tickets could kept on the first page

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps that is how PRs like #10898 show up -- I can't keep up anymore. Thanks @jayzhan211 -- really nice

Count,
/// Minimum
Min,
/// Maximum
Expand Down Expand Up @@ -89,7 +87,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Count => "COUNT",
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Expand Down Expand Up @@ -135,7 +132,6 @@ impl FromStr for AggregateFunction {
"bit_xor" => AggregateFunction::BitXor,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
Expand Down Expand Up @@ -190,7 +186,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Count => Ok(DataType::Int64),
AggregateFunction::Max | AggregateFunction::Min => {
// For min and max agg function, the returned type is same as input type.
// The coerced_data_types is same with input_types.
Expand Down Expand Up @@ -249,7 +244,6 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
Expand Down
13 changes: 0 additions & 13 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2135,18 +2135,6 @@ mod test {

use super::*;

#[test]
fn test_count_return_type() -> Result<()> {
let fun = find_df_window_func("count").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
assert_eq!(DataType::Int64, observed);

Ok(())
}

#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
Expand Down Expand Up @@ -2250,7 +2238,6 @@ mod test {
"nth_value",
"min",
"max",
"count",
"avg",
];
for name in names {
Expand Down
26 changes: 0 additions & 26 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,19 +192,6 @@ pub fn avg(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the count() aggregate function
// TODO: Remove this and use `expr_fn::count` instead
pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![expr],
false,
None,
None,
None,
))
}

/// Return a new expression with bitwise AND
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Expand Down Expand Up @@ -250,19 +237,6 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
))
}

/// Create an expression to represent the count(distinct) aggregate function
// TODO: Remove this and use `expr_fn::count_distinct` instead
pub fn count_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Count,
vec![expr],
true,
None,
None,
None,
))
}

/// Create an in_list expression
pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
Expr::InList(InList::new(Box::new(expr), list, negated))
Expand Down
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2965,11 +2965,13 @@ mod tests {
use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet};

use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, Constraint, ScalarValue};

use crate::test::function_stub::count;

fn employee_schema() -> Schema {
Schema::new(vec![
Field::new("id", DataType::Int32, false),
Expand Down
86 changes: 85 additions & 1 deletion datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{
use arrow::datatypes::{
DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use datafusion_common::{exec_err, Result};
use datafusion_common::{exec_err, not_impl_err, Result};

macro_rules! create_func {
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
Expand Down Expand Up @@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr {
))
}

create_func!(Count, count_udaf);

pub fn count(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
count_udaf(),
vec![expr],
false,
None,
None,
None,
))
}

/// Stub `sum` used for optimizer testing
#[derive(Debug)]
pub struct Sum {
Expand Down Expand Up @@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum {
AggregateOrderSensitivity::Insensitive
}
}

/// Testing stub implementation of COUNT aggregate
pub struct Count {
signature: Signature,
aliases: Vec<String>,
}

impl std::fmt::Debug for Count {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("Count")
.field("name", &self.name())
.field("signature", &self.signature)
.finish()
}
}

impl Default for Count {
fn default() -> Self {
Self::new()
}
}

impl Count {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for Count {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"COUNT"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
not_impl_err!("no impl for stub")
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
not_impl_err!("no impl for stub")
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("no impl for stub")
}

fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::Identical
}
}
2 changes: 0 additions & 2 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count => Ok(input_types.to_vec()),
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
Expand Down Expand Up @@ -525,7 +524,6 @@ mod tests {
// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::Count,
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ regex-syntax = "0.8.0"
[dev-dependencies]
arrow-buffer = { workspace = true }
ctor = { workspace = true }
datafusion-functions-aggregate = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 -- I thought we were trying to avoid making the optimizer depend on the function library -- can we use the sub implementation instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was fine for dev-dependencies?

If not, when can we import functions for dev-dependencies, when we should not 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 -- I can't remember anymore.

It is probably cleaner to to avoid the dependencies, but we can clean it up as a follow on PR

datafusion-sql = { workspace = true }
env_logger = { workspace = true }
42 changes: 12 additions & 30 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{
aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
};
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand Down Expand Up @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool {
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
match aggregate_function {
matches!(aggregate_function,
AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
..
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true,
AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
),
args,
..
} if args.len() == 1 && is_wildcard(&args[0]) => true,
_ => false,
}
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
let args = &window_function.args;
match window_function.fun {
WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
) if args.len() == 1 && is_wildcard(&args[0]) => true,
matches!(window_function.fun,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) =>
{
true
}
_ => false,
}
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down Expand Up @@ -121,14 +101,16 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::Sort;
use datafusion_expr::test::function_stub::sum;
use datafusion_expr::{
col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame,
WindowFrameBound, WindowFrameUnits,
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound,
WindowFrameUnits,
};
use datafusion_functions_aggregate::count::count_udaf;
use std::sync::Arc;

use datafusion_functions_aggregate::expr_fn::{count, sum};

fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_analyzed_plan_eq_display_indent(
Arc::new(CountWildcardRule::new()),
Expand Down Expand Up @@ -239,7 +221,7 @@ mod tests {

let plan = LogicalPlanBuilder::from(table_scan)
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
Expand Down
10 changes: 2 additions & 8 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch(
Expr::AggregateFunction(expr::AggregateFunction {
func_def, ..
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
if matches!(fun, datafusion_expr::AggregateFunction::Count) {
Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(
0,
))))
} else {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::BuiltIn(_fun) => {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::UDF(fun) => {
if fun.name() == "COUNT" {
Expand Down
6 changes: 4 additions & 2 deletions datafusion/optimizer/src/eliminate_group_by_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ mod tests {
use datafusion_common::Result;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl,
Signature, TypeSignature,
col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature,
};

use datafusion_functions_aggregate::expr_fn::count;

use std::sync::Arc;

#[derive(Debug)]
Expand Down
Loading