1717
1818//! Support the coercion rule for aggregate function.
1919
20- use crate :: arrow:: datatypes:: Schema ;
2120use crate :: error:: { DataFusionError , Result } ;
2221use crate :: physical_plan:: aggregates:: AggregateFunction ;
2322use crate :: physical_plan:: expressions:: {
@@ -26,6 +25,10 @@ use crate::physical_plan::expressions::{
2625} ;
2726use crate :: physical_plan:: functions:: { Signature , TypeSignature } ;
2827use crate :: physical_plan:: PhysicalExpr ;
28+ use crate :: {
29+ arrow:: datatypes:: Schema ,
30+ physical_plan:: expressions:: is_approx_quantile_supported_arg_type,
31+ } ;
2932use arrow:: datatypes:: DataType ;
3033use std:: ops:: Deref ;
3134use std:: sync:: Arc ;
@@ -37,24 +40,9 @@ pub(crate) fn coerce_types(
3740 input_types : & [ DataType ] ,
3841 signature : & Signature ,
3942) -> Result < Vec < DataType > > {
40- match signature. type_signature {
41- TypeSignature :: Uniform ( agg_count, _) | TypeSignature :: Any ( agg_count) => {
42- if input_types. len ( ) != agg_count {
43- return Err ( DataFusionError :: Plan ( format ! (
44- "The function {:?} expects {:?} arguments, but {:?} were provided" ,
45- agg_fun,
46- agg_count,
47- input_types. len( )
48- ) ) ) ;
49- }
50- }
51- _ => {
52- return Err ( DataFusionError :: Internal ( format ! (
53- "Aggregate functions do not support this {:?}" ,
54- signature
55- ) ) ) ;
56- }
57- } ;
43+ // Validate input_types matches (at least one of) the func signature.
44+ check_arg_count ( agg_fun, input_types, & signature. type_signature ) ?;
45+
5846 match agg_fun {
5947 AggregateFunction :: Count | AggregateFunction :: ApproxDistinct => {
6048 Ok ( input_types. to_vec ( ) )
@@ -123,7 +111,75 @@ pub(crate) fn coerce_types(
123111 }
124112 Ok ( input_types. to_vec ( ) )
125113 }
114+ AggregateFunction :: ApproxQuantile => {
115+ if !is_approx_quantile_supported_arg_type ( & input_types[ 0 ] ) {
116+ return Err ( DataFusionError :: Plan ( format ! (
117+ "The function {:?} does not support inputs of type {:?}." ,
118+ agg_fun, input_types[ 0 ]
119+ ) ) ) ;
120+ }
121+ if !matches ! ( input_types[ 1 ] , DataType :: Float64 ) {
122+ return Err ( DataFusionError :: Plan ( format ! (
123+ "The quantile argument for {:?} must be Float64, not {:?}." ,
124+ agg_fun, input_types[ 1 ]
125+ ) ) ) ;
126+ }
127+ Ok ( input_types. to_vec ( ) )
128+ }
129+ }
130+ }
131+
132+ /// Validate the length of `input_types` matches the `signature` for `agg_fun`.
133+ ///
134+ /// This method DOES NOT validate the argument types - only that (at least one,
135+ /// in the case of [`TypeSignature::OneOf`]) signature matches the desired
136+ /// number of input types.
137+ fn check_arg_count (
138+ agg_fun : & AggregateFunction ,
139+ input_types : & [ DataType ] ,
140+ signature : & TypeSignature ,
141+ ) -> Result < ( ) > {
142+ match signature {
143+ TypeSignature :: Uniform ( agg_count, _) | TypeSignature :: Any ( agg_count) => {
144+ if input_types. len ( ) != * agg_count {
145+ return Err ( DataFusionError :: Plan ( format ! (
146+ "The function {:?} expects {:?} arguments, but {:?} were provided" ,
147+ agg_fun,
148+ agg_count,
149+ input_types. len( )
150+ ) ) ) ;
151+ }
152+ }
153+ TypeSignature :: Exact ( types) => {
154+ if types. len ( ) != input_types. len ( ) {
155+ return Err ( DataFusionError :: Plan ( format ! (
156+ "The function {:?} expects {:?} arguments, but {:?} were provided" ,
157+ agg_fun,
158+ types. len( ) ,
159+ input_types. len( )
160+ ) ) ) ;
161+ }
162+ }
163+ TypeSignature :: OneOf ( variants) => {
164+ let ok = variants
165+ . iter ( )
166+ . any ( |v| check_arg_count ( agg_fun, input_types, v) . is_ok ( ) ) ;
167+ if !ok {
168+ return Err ( DataFusionError :: Plan ( format ! (
169+ "The function {:?} does not accept {:?} function arguments." ,
170+ agg_fun,
171+ input_types. len( )
172+ ) ) ) ;
173+ }
174+ }
175+ _ => {
176+ return Err ( DataFusionError :: Internal ( format ! (
177+ "Aggregate functions do not support this {:?}" ,
178+ signature
179+ ) ) ) ;
180+ }
126181 }
182+ Ok ( ( ) )
127183}
128184
129185fn get_min_max_result_type ( input_types : & [ DataType ] ) -> Result < Vec < DataType > > {
@@ -239,5 +295,25 @@ mod tests {
239295 assert_eq ! ( * input_type, result. unwrap( ) ) ;
240296 }
241297 }
298+
299+ // ApproxQuantile input types
300+ let input_types = vec ! [
301+ vec![ DataType :: Int8 , DataType :: Float64 ] ,
302+ vec![ DataType :: Int16 , DataType :: Float64 ] ,
303+ vec![ DataType :: Int32 , DataType :: Float64 ] ,
304+ vec![ DataType :: Int64 , DataType :: Float64 ] ,
305+ vec![ DataType :: UInt8 , DataType :: Float64 ] ,
306+ vec![ DataType :: UInt16 , DataType :: Float64 ] ,
307+ vec![ DataType :: UInt32 , DataType :: Float64 ] ,
308+ vec![ DataType :: UInt64 , DataType :: Float64 ] ,
309+ vec![ DataType :: Float32 , DataType :: Float64 ] ,
310+ vec![ DataType :: Float64 , DataType :: Float64 ] ,
311+ ] ;
312+ for input_type in & input_types {
313+ let signature = aggregates:: signature ( & AggregateFunction :: ApproxQuantile ) ;
314+ let result =
315+ coerce_types ( & AggregateFunction :: ApproxQuantile , input_type, & signature) ;
316+ assert_eq ! ( * input_type, result. unwrap( ) ) ;
317+ }
242318 }
243319}
0 commit comments