Commit 78a2433

Add aggregate udf via ffi

1 parent c476e61, commit 78a2433

File tree: 5 files changed (+233, -7 lines)
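
For orientation, the following minimal sketch condenses the two Python tests added in this commit; it assumes the datafusion_ffi_example extension module has been built and is importable.

import pyarrow as pa
from datafusion import SessionContext, udaf
from datafusion_ffi_example import MySumUDF

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [
        pa.array([1, 2, 3, None], type=pa.int64()),
        pa.array([1, 1, 2, 2], type=pa.int64()),
    ],
    names=["a", "b"],
)
ctx.register_record_batches("test_table", [[batch]])

# MySumUDF is a Rust-backed aggregate exported across the FFI boundary as a
# PyCapsule; udaf() detects the capsule and wraps it like any other UDAF.
ctx.register_udaf(udaf(MySumUDF()))
ctx.sql("select my_custom_sum(a) from test_table group by b").show()
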
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa
from datafusion import SessionContext, col, udaf
from datafusion_ffi_example import MySumUDF


def setup_context_with_table():
    ctx = SessionContext()

    # Pick numbers here so we get the same value in both groups
    # since we cannot be certain of the output order of batches
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([1, 2, 3, None], type=pa.int64()),
            pa.array([1, 1, 2, 2], type=pa.int64()),
        ],
        names=["a", "b"],
    )
    ctx.register_record_batches("test_table", [[batch]])
    return ctx


def test_ffi_aggregate_register():
    ctx = setup_context_with_table()
    my_udaf = udaf(MySumUDF())
    ctx.register_udaf(my_udaf)

    result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect()

    assert len(result) == 2
    assert result[0].num_columns == 1

    result = [r.column(0) for r in result]
    expected = [
        pa.array([3], type=pa.int64()),
        pa.array([3], type=pa.int64()),
    ]

    assert result == expected


def test_ffi_aggregate_call_directly():
    ctx = setup_context_with_table()
    my_udaf = udaf(MySumUDF())

    result = (
        ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect()
    )

    assert len(result) == 2
    assert result[0].num_columns == 2

    result = [r.column(1) for r in result]
    expected = [
        pa.array([3], type=pa.int64()),
        pa.array([3], type=pa.int64()),
    ]

    assert result == expected
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::{Array, BooleanArray};
use arrow_schema::{DataType, FieldRef};
use datafusion::common::ScalarValue;
use datafusion::error::Result as DataFusionResult;
use datafusion::functions_aggregate::sum::Sum;
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::{
    Accumulator, AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, ScalarUDF,
    ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use datafusion_ffi::udf::FFI_ScalarUDF;
use pyo3::types::PyCapsule;
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
use std::any::Any;
use std::sync::Arc;

#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)]
#[derive(Debug, Clone)]
pub(crate) struct MySumUDF {
    signature: Signature,
    inner: Arc<Sum>,
}

#[pymethods]
impl MySumUDF {
    #[new]
    fn new() -> Self {
        Self {
            signature: Signature::new(
                TypeSignature::Exact(vec![DataType::Int64]),
                Volatility::Immutable,
            ),
            inner: Arc::new(Sum::new()),
        }
    }

    fn __datafusion_aggregate_udf__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        let name = cr"datafusion_aggregate_udf".into();

        let func = Arc::new(AggregateUDF::from(self.clone()));
        let provider = FFI_AggregateUDF::from(func);

        PyCapsule::new(py, provider, Some(name))
    }
}

impl AggregateUDFImpl for MySumUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "my_custom_sum"
    }

    fn signature(&self) -> &Signature {
        self.inner.signature()
    }

    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
        self.inner.return_type(arg_types)
    }

    fn accumulator(&self, acc_args: AccumulatorArgs) -> DataFusionResult<Box<dyn Accumulator>> {
        self.inner.accumulator(acc_args)
    }

    fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
        self.inner.coerce_types(arg_types)
    }
}
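
MySumUDF delegates its signature, return type, coercion, and accumulator to DataFusion's built-in Sum and only adds the __datafusion_aggregate_udf__ capsule export. A minimal sketch of calling it through the DataFrame API, assuming the ctx and test_table from the sketch above:

from datafusion import col, udaf
from datafusion_ffi_example import MySumUDF

my_udaf = udaf(MySumUDF())
# my_udaf(col("a")) builds an aggregate expression that runs the Rust-side accumulator.
ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).show()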

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 4 additions & 1 deletion
@@ -16,13 +16,15 @@
 // under the License.

 use crate::catalog_provider::MyCatalogProvider;
+use crate::aggregate_udf::MySumUDF;
 use crate::scalar_udf::IsNullUDF;
 use crate::table_function::MyTableFunction;
 use crate::table_provider::MyTableProvider;
 use pyo3::prelude::*;

 pub(crate) mod catalog_provider;
-mod scalar_udf;
+pub(crate) mod aggregate_udf;
+pub(crate) mod scalar_udf;
 pub(crate) mod table_function;
 pub(crate) mod table_provider;

@@ -32,5 +34,6 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<MyTableFunction>()?;
     m.add_class::<MyCatalogProvider>()?;
     m.add_class::<IsNullUDF>()?;
+    m.add_class::<MySumUDF>()?;
     Ok(())
 }

python/datafusion/user_defined.py

Lines changed: 34 additions & 1 deletion
@@ -277,6 +277,12 @@ def evaluate(self) -> pa.Scalar:
         """Return the resultant value."""


+class AggregateUDFExportable(Protocol):
+    """Type hint for object that has __datafusion_aggregate_udf__ PyCapsule."""
+
+    def __datafusion_aggregate_udf__(self) -> object: ...  # noqa: D105
+
+
 class AggregateUDF:
     """Class for performing scalar user-defined functions (UDF).

@@ -298,6 +304,9 @@ def __init__(
         See :py:func:`udaf` for a convenience function and argument
         descriptions.
         """
+        if hasattr(accumulator, "__datafusion_aggregate_udf__"):
+            self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
+            return
         self._udaf = df_internal.AggregateUDF(
             name,
             accumulator,
@@ -342,7 +351,7 @@ def udaf(
     ) -> AggregateUDF: ...

     @staticmethod
-    def udaf(*args: Any, **kwargs: Any):  # noqa: D417
+    def udaf(*args: Any, **kwargs: Any):  # noqa: D417, C901
         """Create a new User-Defined Aggregate Function (UDAF).

         This class allows you to define an aggregate function that can be used in
@@ -399,6 +408,10 @@ def udf4() -> Summarize:
         Args:
             accum: The accumulator python function. Only needed when calling as a
                 function. Skip this argument when using ``udaf`` as a decorator.
+                If you have a Rust-backed AggregateUDF within a PyCapsule, you can
+                pass this parameter and ignore the rest. They will be determined
+                directly from the underlying function. See the online documentation
+                for more information.
             input_types: The data types of the arguments to ``accum``.
             return_type: The data type of the return value.
             state_type: The data types of the intermediate accumulation.
@@ -457,12 +470,32 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:

             return decorator

+        if hasattr(args[0], "__datafusion_aggregate_udf__"):
+            return AggregateUDF.from_pycapsule(args[0])
+
         if args and callable(args[0]):
             # Case 1: Used as a function, require the first parameter to be callable
             return _function(*args, **kwargs)
         # Case 2: Used as a decorator with parameters
         return _decorator(*args, **kwargs)

+    @staticmethod
+    def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
+        """Create an Aggregate UDF from an AggregateUDF PyCapsule object.
+
+        This function will instantiate an Aggregate UDF that uses a DataFusion
+        AggregateUDF exported via the FFI bindings.
+        """
+        name = str(func.__class__)
+        return AggregateUDF(
+            name=name,
+            accumulator=func,
+            input_types=None,
+            return_type=None,
+            state_type=None,
+            volatility=None,
+        )
+

 class WindowEvaluator:
     """Evaluator class for user-defined window functions (UDWF).

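Any object exposing __datafusion_aggregate_udf__ satisfies the new AggregateUDFExportable protocol, so it can be handed to udaf() or AggregateUDF.from_pycapsule() directly. A hypothetical helper (not part of this commit) illustrating the protocol:

from datafusion import SessionContext
from datafusion.user_defined import AggregateUDF, AggregateUDFExportable


def register_exported_udaf(ctx: SessionContext, exportable: AggregateUDFExportable) -> None:
    # from_pycapsule derives a name from the exporting object and leaves the
    # type information to the underlying Rust AggregateUDF, so no input,
    # return, or state types need to be supplied here.
    ctx.register_udaf(AggregateUDF.from_pycapsule(exportable))
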
src/udaf.rs

Lines changed: 26 additions & 5 deletions
@@ -19,6 +19,10 @@ use std::sync::Arc;

 use pyo3::{prelude::*, types::PyTuple};

+use crate::common::data_type::PyScalarValue;
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
+use crate::expr::PyExpr;
+use crate::utils::{parse_volatility, validate_pycapsule};
 use datafusion::arrow::array::{Array, ArrayRef};
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
@@ -27,11 +31,8 @@ use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_expr::{
     create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF,
 };
-
-use crate::common::data_type::PyScalarValue;
-use crate::errors::to_datafusion_err;
-use crate::expr::PyExpr;
-use crate::utils::parse_volatility;
+use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
+use pyo3::types::PyCapsule;

 #[derive(Debug)]
 struct RustAccumulator {
@@ -183,6 +184,26 @@ impl PyAggregateUDF {
         Ok(Self { function })
     }

+    #[staticmethod]
+    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.hasattr("__datafusion_aggregate_udf__")? {
+            let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
+
+            let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
+            let udaf: ForeignAggregateUDF = udaf.try_into()?;
+
+            Ok(Self {
+                function: udaf.into(),
+            })
+        } else {
+            Err(crate::errors::PyDataFusionError::Common(
+                "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
+            ))
+        }
+    }
+
     /// creates a new PyExpr with the call of the udf
     #[pyo3(signature = (*args))]
     fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
