Skip to content

Commit 73f5e24

Browse files
committed
Initial commit for window ffi integration
1 parent 5f55a59 commit 73f5e24

File tree

7 files changed

+237
-23
lines changed

7 files changed

+237
-23
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pyarrow as pa
21+
from datafusion import SessionContext, col, udwf
22+
from datafusion_ffi_example import MyRankUDF
23+
24+
25+
def setup_context_with_table():
26+
ctx = SessionContext()
27+
28+
# Pick numbers here so we get the same value in both groups
29+
# since we cannot be certain of the output order of batches
30+
batch = pa.RecordBatch.from_arrays(
31+
[
32+
pa.array([40, 10, 30, 20], type=pa.int64()),
33+
],
34+
names=["a"],
35+
)
36+
ctx.register_record_batches("test_table", [[batch]])
37+
return ctx
38+
39+
40+
def test_ffi_window_register():
41+
ctx = setup_context_with_table()
42+
my_udwf = udwf(MyRankUDF())
43+
ctx.register_udwf(my_udwf)
44+
45+
result = ctx.sql(
46+
"select a, my_custom_rank() over (order by a) from test_table"
47+
).collect()
48+
assert len(result) == 1
49+
assert result[0].num_columns == 2
50+
51+
results = [
52+
(result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4)
53+
]
54+
results.sort()
55+
56+
expected = [
57+
(10, 1),
58+
(20, 2),
59+
(30, 3),
60+
(40, 4),
61+
]
62+
assert results == expected
63+
64+
65+
def test_ffi_window_call_directly():
66+
ctx = setup_context_with_table()
67+
my_udwf = udwf(MyRankUDF())
68+
69+
result = (
70+
ctx.table("test_table")
71+
.select(col("a"), my_udwf().order_by(col("a")).build())
72+
.collect()
73+
)
74+
75+
assert len(result) == 1
76+
assert result[0].num_columns == 2
77+
78+
results = [
79+
(result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4)
80+
]
81+
results.sort()
82+
83+
expected = [
84+
(10, 1),
85+
(20, 2),
86+
(30, 3),
87+
(40, 4),
88+
]
89+
assert results == expected

examples/datafusion-ffi-example/src/aggregate_udf.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow_array::{Array, BooleanArray};
19-
use arrow_schema::{DataType, FieldRef};
20-
use datafusion::common::ScalarValue;
18+
use arrow_schema::DataType;
2119
use datafusion::error::Result as DataFusionResult;
2220
use datafusion::functions_aggregate::sum::Sum;
2321
use datafusion::logical_expr::function::AccumulatorArgs;
24-
use datafusion::logical_expr::{
25-
Accumulator, AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, ScalarUDF,
26-
ScalarUDFImpl, Signature, TypeSignature, Volatility,
27-
};
22+
use datafusion::logical_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
2823
use datafusion_ffi::udaf::FFI_AggregateUDF;
29-
use datafusion_ffi::udf::FFI_ScalarUDF;
3024
use pyo3::types::PyCapsule;
3125
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
3226
use std::any::Any;
@@ -35,7 +29,6 @@ use std::sync::Arc;
3529
#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)]
3630
#[derive(Debug, Clone)]
3731
pub(crate) struct MySumUDF {
38-
signature: Signature,
3932
inner: Arc<Sum>,
4033
}
4134

@@ -44,10 +37,6 @@ impl MySumUDF {
4437
#[new]
4538
fn new() -> Self {
4639
Self {
47-
signature: Signature::new(
48-
TypeSignature::Exact(vec![DataType::Int64]),
49-
Volatility::Immutable,
50-
),
5140
inner: Arc::new(Sum::new()),
5241
}
5342
}

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,21 @@ use crate::aggregate_udf::MySumUDF;
1919
use crate::scalar_udf::IsNullUDF;
2020
use crate::table_function::MyTableFunction;
2121
use crate::table_provider::MyTableProvider;
22+
use crate::window_udf::MyRankUDF;
2223
use pyo3::prelude::*;
2324

2425
pub(crate) mod aggregate_udf;
2526
pub(crate) mod scalar_udf;
2627
pub(crate) mod table_function;
2728
pub(crate) mod table_provider;
29+
pub(crate) mod window_udf;
2830

2931
#[pymodule]
3032
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
3133
m.add_class::<MyTableProvider>()?;
3234
m.add_class::<MyTableFunction>()?;
3335
m.add_class::<IsNullUDF>()?;
3436
m.add_class::<MySumUDF>()?;
37+
m.add_class::<MyRankUDF>()?;
3538
Ok(())
3639
}

examples/datafusion-ffi-example/src/scalar_udf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use arrow_array::{Array, BooleanArray};
19-
use arrow_schema::{DataType, FieldRef};
19+
use arrow_schema::DataType;
2020
use datafusion::common::ScalarValue;
2121
use datafusion::error::Result as DataFusionResult;
2222
use datafusion::logical_expr::{
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow_schema::{DataType, FieldRef};
19+
use datafusion::error::Result as DataFusionResult;
20+
use datafusion::functions_window::rank::rank_udwf;
21+
use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
22+
use datafusion::logical_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl};
23+
use datafusion_ffi::udwf::FFI_WindowUDF;
24+
use pyo3::types::PyCapsule;
25+
use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
26+
use std::any::Any;
27+
use std::sync::Arc;
28+
29+
#[pyclass(name = "MyRankUDF", module = "datafusion_ffi_example", subclass)]
30+
#[derive(Debug, Clone)]
31+
pub(crate) struct MyRankUDF {
32+
inner: Arc<WindowUDF>,
33+
}
34+
35+
#[pymethods]
36+
impl MyRankUDF {
37+
#[new]
38+
fn new() -> Self {
39+
Self { inner: rank_udwf() }
40+
}
41+
42+
fn __datafusion_window_udf__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
43+
let name = cr"datafusion_window_udf".into();
44+
45+
let func = Arc::new(WindowUDF::from(self.clone()));
46+
let provider = FFI_WindowUDF::from(func);
47+
48+
PyCapsule::new(py, provider, Some(name))
49+
}
50+
}
51+
52+
impl WindowUDFImpl for MyRankUDF {
53+
fn as_any(&self) -> &dyn Any {
54+
self
55+
}
56+
57+
fn name(&self) -> &str {
58+
"my_custom_rank"
59+
}
60+
61+
fn signature(&self) -> &Signature {
62+
self.inner.signature()
63+
}
64+
65+
fn partition_evaluator(
66+
&self,
67+
partition_evaluator_args: PartitionEvaluatorArgs,
68+
) -> DataFusionResult<Box<dyn PartitionEvaluator>> {
69+
self.inner
70+
.inner()
71+
.partition_evaluator(partition_evaluator_args)
72+
}
73+
74+
fn field(&self, field_args: WindowUDFFieldArgs) -> DataFusionResult<FieldRef> {
75+
self.inner.inner().field(field_args)
76+
}
77+
78+
fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
79+
self.inner.coerce_types(arg_types)
80+
}
81+
}

python/datafusion/user_defined.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
247247
This function will instantiate a Scalar UDF that uses a DataFusion
248248
ScalarUDF that is exported via the FFI bindings.
249249
"""
250-
name = str(udf.__class__)
250+
name = str(func.__class__)
251251
return ScalarUDF(
252252
name=name,
253253
func=func,
@@ -409,9 +409,9 @@ def udf4() -> Summarize:
409409
accum: The accumulator python function. Only needed when calling as a
410410
function. Skip this argument when using ``udaf`` as a decorator.
411411
If you have a Rust backed AggregateUDF within a PyCapsule, you can
412-
pass this parameter and ignore the rest. They will be determined directly
413-
from the underlying function. See the online documentation for more
414-
information.
412+
pass this parameter and ignore the rest. They will be determined
413+
directly from the underlying function. See the online documentation
414+
for more information.
415415
input_types: The data types of the arguments to ``accum``.
416416
return_type: The data type of the return value.
417417
state_type: The data types of the intermediate accumulation.
@@ -486,7 +486,7 @@ def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
486486
This function will instantiate a Aggregate UDF that uses a DataFusion
487487
AggregateUDF that is exported via the FFI bindings.
488488
"""
489-
name = str(udf.__class__)
489+
name = str(func.__class__)
490490
return AggregateUDF(
491491
name=name,
492492
accumulator=func,
@@ -656,6 +656,12 @@ def include_rank(self) -> bool:
656656
return False
657657

658658

659+
class WindowUDFExportable(Protocol):
660+
"""Type hint for object that has __datafusion_window_udf__ PyCapsule."""
661+
662+
def __datafusion_window_udf__(self) -> object: ... # noqa: D105
663+
664+
659665
class WindowUDF:
660666
"""Class for performing window user-defined functions (UDF).
661667
@@ -676,6 +682,9 @@ def __init__(
676682
See :py:func:`udwf` for a convenience function and argument
677683
descriptions.
678684
"""
685+
if hasattr(func, "__datafusion_window_udf__"):
686+
self._udwf = df_internal.WindowUDF.from_pycapsule(func)
687+
return
679688
self._udwf = df_internal.WindowUDF(
680689
name, func, input_types, return_type, str(volatility)
681690
)
@@ -751,7 +760,10 @@ def biased_numbers() -> BiasedNumbers:
751760
752761
Args:
753762
func: Only needed when calling as a function. Skip this argument when
754-
using ``udwf`` as a decorator.
763+
using ``udwf`` as a decorator. If you have a Rust backed WindowUDF
764+
within a PyCapsule, you can pass this parameter and ignore the rest.
765+
They will be determined directly from the underlying function. See
766+
the online documentation for more information.
755767
input_types: The data types of the arguments.
756768
return_type: The data type of the return value.
757769
volatility: See :py:class:`Volatility` for allowed values.
@@ -760,6 +772,9 @@ def biased_numbers() -> BiasedNumbers:
760772
Returns:
761773
A user-defined window function that can be used in window function calls.
762774
"""
775+
if hasattr(args[0], "__datafusion_window_udf__"):
776+
return WindowUDF.from_pycapsule(args[0])
777+
763778
if args and callable(args[0]):
764779
# Case 1: Used as a function, require the first parameter to be callable
765780
return WindowUDF._create_window_udf(*args, **kwargs)
@@ -827,6 +842,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
827842

828843
return decorator
829844

845+
@staticmethod
846+
def from_pycapsule(func: WindowUDFExportable) -> WindowUDF:
847+
"""Create a Window UDF from WindowUDF PyCapsule object.
848+
849+
This function will instantiate a Window UDF that uses a DataFusion
850+
WindowUDF that is exported via the FFI bindings.
851+
"""
852+
name = str(func.__class__)
853+
return WindowUDF(
854+
name=name,
855+
func=func,
856+
input_types=None,
857+
return_type=None,
858+
volatility=None,
859+
)
860+
830861

831862
class TableFunction:
832863
"""Class for performing user-defined table functions (UDTF).

src/udwf.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@ use pyo3::exceptions::PyValueError;
2727
use pyo3::prelude::*;
2828

2929
use crate::common::data_type::PyScalarValue;
30-
use crate::errors::to_datafusion_err;
30+
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3131
use crate::expr::PyExpr;
32-
use crate::utils::parse_volatility;
32+
use crate::utils::{parse_volatility, validate_pycapsule};
3333
use datafusion::arrow::datatypes::DataType;
3434
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
3535
use datafusion::error::{DataFusionError, Result};
3636
use datafusion::logical_expr::{
3737
PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
3838
};
39-
use pyo3::types::{PyList, PyTuple};
39+
use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF};
40+
use pyo3::types::{PyCapsule, PyList, PyTuple};
4041

4142
#[derive(Debug)]
4243
struct RustPartitionEvaluator {
@@ -245,6 +246,26 @@ impl PyWindowUDF {
245246
Ok(self.function.call(args).into())
246247
}
247248

249+
#[staticmethod]
250+
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
251+
if func.hasattr("__datafusion_window_udf__")? {
252+
let capsule = func.getattr("__datafusion_window_udf__")?.call0()?;
253+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
254+
validate_pycapsule(capsule, "datafusion_window_udf")?;
255+
256+
let udwf = unsafe { capsule.reference::<FFI_WindowUDF>() };
257+
let udwf: ForeignWindowUDF = udwf.try_into()?;
258+
259+
Ok(Self {
260+
function: udwf.into(),
261+
})
262+
} else {
263+
Err(crate::errors::PyDataFusionError::Common(
264+
"__datafusion_window_udf__ does not exist on WindowUDF object.".to_string(),
265+
))
266+
}
267+
}
268+
248269
fn __repr__(&self) -> PyResult<String> {
249270
Ok(format!("WindowUDF({})", self.function.name()))
250271
}

0 commit comments

Comments
 (0)