Skip to content

Commit a08e62b

Browse files
committed
add a test and fix consuming bound types
1 parent 80c7bb3 commit a08e62b

File tree

5 files changed

+206
-17
lines changed

5 files changed

+206
-17
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use datafusion::arrow::datatypes::{
2020
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
2121
};
2222
use datafusion::common::{
23-
not_impl_datafusion_err, not_impl_err, substrait_datafusion_err, substrait_err,
24-
DFSchema, DFSchemaRef,
23+
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
24+
substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2525
};
2626
use substrait::proto::expression::literal::IntervalDayToSecond;
2727
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
@@ -57,7 +57,7 @@ use substrait::proto::{
5757
reference_segment::ReferenceType::StructField,
5858
window_function::bound as SubstraitBound,
5959
window_function::bound::Kind as BoundKind, window_function::Bound,
60-
MaskExpression, RexType,
60+
window_function::BoundsType, MaskExpression, RexType,
6161
},
6262
extensions::simple_extension_declaration::MappingType,
6363
function_argument::ArgType,
@@ -71,7 +71,6 @@ use substrait::proto::{
7171
use substrait::proto::{FunctionArgument, SortField};
7272

7373
use datafusion::arrow::array::GenericListArray;
74-
use datafusion::common::plan_err;
7574
use datafusion::common::scalar::ScalarStructBuilder;
7675
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
7776
use std::collections::HashMap;
@@ -1188,15 +1187,24 @@ pub async fn from_substrait_rex(
11881187
let order_by =
11891188
from_substrait_sorts(ctx, &window.sorts, input_schema, extensions)
11901189
.await?;
1191-
// Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units
1192-
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
1193-
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
1194-
// TODO: Consider the cases where window frame is specified in query and is different from default
1195-
let units = if order_by.is_empty() {
1196-
WindowFrameUnits::Rows
1197-
} else {
1198-
WindowFrameUnits::Range
1199-
};
1190+
1191+
let bound_units =
1192+
match BoundsType::try_from(window.bounds_type).map_err(|e| {
1193+
plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
1194+
})? {
1195+
BoundsType::Rows => WindowFrameUnits::Rows,
1196+
BoundsType::Range => WindowFrameUnits::Range,
1197+
BoundsType::Unspecified => {
1198+
// If the plan does not specify the bounds type, then we use a simple logic to determine the units
1199+
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
1200+
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
1201+
if order_by.is_empty() {
1202+
WindowFrameUnits::Rows
1203+
} else {
1204+
WindowFrameUnits::Range
1205+
}
1206+
}
1207+
};
12001208
Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
12011209
fun,
12021210
args: from_substrait_func_args(
@@ -1215,7 +1223,7 @@ pub async fn from_substrait_rex(
12151223
.await?,
12161224
order_by,
12171225
window_frame: datafusion::logical_expr::WindowFrame::new_bounds(
1218-
units,
1226+
bound_units,
12191227
from_substrait_bound(&window.lower_bound, true)?,
12201228
from_substrait_bound(&window.upper_bound, false)?,
12211229
),

datafusion/substrait/tests/cases/logical_plans.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mod tests {
2828
use substrait::proto::Plan;
2929

3030
#[tokio::test]
31-
async fn function_compound_signature() -> Result<()> {
31+
async fn scalar_function_compound_signature() -> Result<()> {
3232
// DataFusion currently produces Substrait that refers to functions only by their name.
3333
// However, the Substrait spec requires that functions be identified by their compound signature.
3434
// This test confirms that DataFusion is able to consume plans following the spec, even though
@@ -39,7 +39,7 @@ mod tests {
3939

4040
// File generated with substrait-java's Isthmus:
4141
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
42-
let proto = read_json("tests/testdata/select_not_bool.substrait.json");
42+
let proto = read_json("tests/testdata/test_plans/select_not_bool.substrait.json");
4343

4444
let plan = from_substrait_plan(&ctx, &proto).await?;
4545

@@ -51,13 +51,41 @@ mod tests {
5151
Ok(())
5252
}
5353

54+
// Aggregate function compound signature is tested through TPCH plans
55+
56+
#[tokio::test]
57+
async fn window_function_compound_signature() -> Result<()> {
58+
// DataFusion currently produces Substrait that refers to functions only by their name.
59+
// However, the Substrait spec requires that functions be identified by their compound signature.
60+
// This test confirms that DataFusion is able to consume plans following the spec, even though
61+
// we don't yet produce such plans.
62+
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.
63+
64+
let ctx = create_context().await?;
65+
66+
// File generated with substrait-java's Isthmus:
67+
// ./isthmus-cli/build/graal/isthmus "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" -c "create table data (d int, part int, ord int)"
68+
let proto = read_json("tests/testdata/test_plans/select_window.substrait.json");
69+
70+
let plan = from_substrait_plan(&ctx, &proto).await?;
71+
72+
assert_eq!(
73+
format!("{:?}", plan),
74+
"Projection: sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
75+
\n WindowAggr: windowExpr=[[sum(DATA.a) PARTITION BY [DATA.b] ORDER BY [DATA.c ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
76+
\n TableScan: DATA projection=[a, b, c, d, e, f]"
77+
);
78+
Ok(())
79+
}
80+
5481
#[tokio::test]
5582
async fn non_nullable_lists() -> Result<()> {
5683
// DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable.
5784
// That's because implementing the non-nullability consistently is non-trivial.
5885
// This test confirms that reading a plan with non-nullable lists works as expected.
5986
let ctx = create_context().await?;
60-
let proto = read_json("tests/testdata/non_nullable_lists.substrait.json");
87+
let proto =
88+
read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json");
6189

6290
let plan = from_substrait_plan(&ctx, &proto).await?;
6391

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"extensionUriAnchor": 1,
5+
"uri": "/functions_arithmetic.yaml"
6+
}
7+
],
8+
"extensions": [
9+
{
10+
"extensionFunction": {
11+
"extensionUriReference": 1,
12+
"functionAnchor": 0,
13+
"name": "sum:i32"
14+
}
15+
}
16+
],
17+
"relations": [
18+
{
19+
"root": {
20+
"input": {
21+
"project": {
22+
"common": {
23+
"emit": {
24+
"outputMapping": [
25+
3
26+
]
27+
}
28+
},
29+
"input": {
30+
"read": {
31+
"common": {
32+
"direct": {
33+
}
34+
},
35+
"baseSchema": {
36+
"names": [
37+
"D",
38+
"PART",
39+
"ORD"
40+
],
41+
"struct": {
42+
"types": [
43+
{
44+
"i32": {
45+
"typeVariationReference": 0,
46+
"nullability": "NULLABILITY_NULLABLE"
47+
}
48+
},
49+
{
50+
"i32": {
51+
"typeVariationReference": 0,
52+
"nullability": "NULLABILITY_NULLABLE"
53+
}
54+
},
55+
{
56+
"i32": {
57+
"typeVariationReference": 0,
58+
"nullability": "NULLABILITY_NULLABLE"
59+
}
60+
}
61+
],
62+
"typeVariationReference": 0,
63+
"nullability": "NULLABILITY_REQUIRED"
64+
}
65+
},
66+
"namedTable": {
67+
"names": [
68+
"DATA"
69+
]
70+
}
71+
}
72+
},
73+
"expressions": [
74+
{
75+
"windowFunction": {
76+
"functionReference": 0,
77+
"partitions": [
78+
{
79+
"selection": {
80+
"directReference": {
81+
"structField": {
82+
"field": 1
83+
}
84+
},
85+
"rootReference": {
86+
}
87+
}
88+
}
89+
],
90+
"sorts": [
91+
{
92+
"expr": {
93+
"selection": {
94+
"directReference": {
95+
"structField": {
96+
"field": 2
97+
}
98+
},
99+
"rootReference": {
100+
}
101+
}
102+
},
103+
"direction": "SORT_DIRECTION_ASC_NULLS_LAST"
104+
}
105+
],
106+
"upperBound": {
107+
"unbounded": {
108+
}
109+
},
110+
"lowerBound": {
111+
"preceding": {
112+
"offset": "1"
113+
}
114+
},
115+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
116+
"outputType": {
117+
"i32": {
118+
"typeVariationReference": 0,
119+
"nullability": "NULLABILITY_NULLABLE"
120+
}
121+
},
122+
"args": [],
123+
"arguments": [
124+
{
125+
"value": {
126+
"selection": {
127+
"directReference": {
128+
"structField": {
129+
"field": 0
130+
}
131+
},
132+
"rootReference": {
133+
}
134+
}
135+
}
136+
}
137+
],
138+
"invocation": "AGGREGATION_INVOCATION_ALL",
139+
"options": [],
140+
"boundsType": "BOUNDS_TYPE_ROWS"
141+
}
142+
}
143+
]
144+
}
145+
},
146+
"names": [
147+
"LEAD_EXPR"
148+
]
149+
}
150+
}
151+
],
152+
"expectedTypeUrls": []
153+
}

0 commit comments

Comments
 (0)