Skip to content

Commit 44db49c

Browse files
Blizzaraalamb
authored andcommitted
Support consuming Substrait with compound signature function names (apache#10653)
* Support consuming Substrait with compound signature function names Substrait 0.32.0+ requires functions to be specified using compound names, which include the function name as well as the arguments it takes. We don't necessarily need that information while consuming the plans, but we need to support those compound names. * Add a test for using "not:bool" * clippy fixes * Add license to new file * Apply suggestions from code review Co-authored-by: Andrew Lamb <[email protected]> * remove prost-types dep as it's replaced by pbjson-types --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent eed27f0 commit 44db49c

File tree

6 files changed

+182
-10
lines changed

6 files changed

+182
-10
lines changed

datafusion/substrait/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ chrono = { workspace = true }
3737
datafusion = { workspace = true, default-features = true }
3838
itertools = { workspace = true }
3939
object_store = { workspace = true }
40+
pbjson-types = "0.6"
4041
prost = "0.12"
41-
prost-types = "0.12"
42-
substrait = "0.34.0"
42+
substrait = { version = "0.34.0", features = ["serde"] }
4343

4444
[dev-dependencies]
45+
serde_json = "1.0"
4546
tokio = { workspace = true }
4647

4748
[features]

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ fn scalar_function_type_from_str(
124124
name: &str,
125125
) -> Result<ScalarFunctionType> {
126126
let s = ctx.state();
127+
let name = match name.rsplit_once(':') {
128+
// Since 0.32.0, Substrait requires the function names to be in a compound format
129+
// https://substrait.io/extensions/#function-signature-compound-names
130+
// for example, `add:i8_i8`.
131+
// On the consumer side, we don't really care about the signature though, just the name.
132+
Some((name, _)) => name,
133+
None => name,
134+
};
135+
127136
if let Some(func) = s.scalar_functions().get(name) {
128137
return Ok(ScalarFunctionType::Udf(func.to_owned()));
129138
}
@@ -1525,7 +1534,7 @@ fn from_substrait_literal(
15251534
return substrait_err!("Interval year month value is empty");
15261535
};
15271536
let value_slice: [u8; 4] =
1528-
raw_val.value.clone().try_into().map_err(|_| {
1537+
(*raw_val.value).try_into().map_err(|_| {
15291538
substrait_datafusion_err!(
15301539
"Failed to parse interval year month value"
15311540
)
@@ -1537,7 +1546,7 @@ fn from_substrait_literal(
15371546
return substrait_err!("Interval day time value is empty");
15381547
};
15391548
let value_slice: [u8; 8] =
1540-
raw_val.value.clone().try_into().map_err(|_| {
1549+
(*raw_val.value).try_into().map_err(|_| {
15411550
substrait_datafusion_err!(
15421551
"Failed to parse interval day time value"
15431552
)
@@ -1549,7 +1558,7 @@ fn from_substrait_literal(
15491558
return substrait_err!("Interval month day nano value is empty");
15501559
};
15511560
let value_slice: [u8; 16] =
1552-
raw_val.value.clone().try_into().map_err(|_| {
1561+
(*raw_val.value).try_into().map_err(|_| {
15531562
substrait_datafusion_err!(
15541563
"Failed to parse interval month day nano value"
15551564
)

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use datafusion::logical_expr::expr::{
4545
};
4646
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
4747
use datafusion::prelude::Expr;
48-
use prost_types::Any as ProtoAny;
48+
use pbjson_types::Any as ProtoAny;
4949
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
5050
use substrait::proto::expression::literal::user_defined::Val;
5151
use substrait::proto::expression::literal::UserDefined;
@@ -547,7 +547,7 @@ pub fn to_substrait_rel(
547547
.serialize_logical_plan(extension_plan.node.as_ref())?;
548548
let detail = ProtoAny {
549549
type_url: extension_plan.node.name().to_string(),
550-
value: extension_bytes,
550+
value: extension_bytes.into(),
551551
};
552552
let mut inputs_rel = extension_plan
553553
.node
@@ -1919,7 +1919,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
19191919
}],
19201920
val: Some(Val::Value(ProtoAny {
19211921
type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
1922-
value: bytes.to_vec(),
1922+
value: bytes.to_vec().into(),
19231923
})),
19241924
}),
19251925
INTERVAL_YEAR_MONTH_TYPE_REF,
@@ -1942,7 +1942,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
19421942
type_parameters: vec![i64_param.clone(), i64_param],
19431943
val: Some(Val::Value(ProtoAny {
19441944
type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
1945-
value: bytes.to_vec(),
1945+
value: bytes.to_vec().into(),
19461946
})),
19471947
}),
19481948
INTERVAL_MONTH_DAY_NANO_TYPE_REF,
@@ -1965,7 +1965,7 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
19651965
}],
19661966
val: Some(Val::Value(ProtoAny {
19671967
type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
1968-
value: bytes.to_vec(),
1968+
value: bytes.to_vec().into(),
19691969
})),
19701970
}),
19711971
INTERVAL_DAY_TIME_TYPE_REF,
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Tests for reading substrait plans produced by other systems
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use datafusion::common::Result;
23+
use datafusion::prelude::{CsvReadOptions, SessionContext};
24+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
25+
use std::fs::File;
26+
use std::io::BufReader;
27+
use substrait::proto::Plan;
28+
29+
#[tokio::test]
30+
async fn function_compound_signature() -> Result<()> {
31+
// DataFusion currently produces Substrait that refers to functions only by their name.
32+
// However, the Substrait spec requires that functions be identified by their compound signature.
33+
// This test confirms that DataFusion is able to consume plans following the spec, even though
34+
// we don't yet produce such plans.
35+
// Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests.
36+
37+
let ctx = create_context().await?;
38+
39+
// File generated with substrait-java's Isthmus:
40+
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
41+
let path = "tests/testdata/select_not_bool.substrait.json";
42+
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
43+
File::open(path).expect("file not found"),
44+
))
45+
.expect("failed to parse json");
46+
47+
let plan = from_substrait_plan(&ctx, &proto).await?;
48+
49+
assert_eq!(
50+
format!("{:?}", plan),
51+
"Projection: NOT DATA.a\
52+
\n TableScan: DATA projection=[a, b, c, d, e, f]"
53+
);
54+
Ok(())
55+
}
56+
57+
async fn create_context() -> datafusion::common::Result<SessionContext> {
58+
let ctx = SessionContext::new();
59+
ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new())
60+
.await?;
61+
Ok(ctx)
62+
}
63+
}

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod logical_plans;
1819
mod roundtrip_logical_plan;
1920
mod roundtrip_physical_plan;
2021
mod serialize;
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"extensionUriAnchor": 1,
5+
"uri": "/functions_boolean.yaml"
6+
}
7+
],
8+
"extensions": [
9+
{
10+
"extensionFunction": {
11+
"extensionUriReference": 1,
12+
"functionAnchor": 0,
13+
"name": "not:bool"
14+
}
15+
}
16+
],
17+
"relations": [
18+
{
19+
"root": {
20+
"input": {
21+
"project": {
22+
"common": {
23+
"emit": {
24+
"outputMapping": [
25+
1
26+
]
27+
}
28+
},
29+
"input": {
30+
"read": {
31+
"common": {
32+
"direct": {
33+
}
34+
},
35+
"baseSchema": {
36+
"names": [
37+
"D"
38+
],
39+
"struct": {
40+
"types": [
41+
{
42+
"bool": {
43+
"typeVariationReference": 0,
44+
"nullability": "NULLABILITY_NULLABLE"
45+
}
46+
}
47+
],
48+
"typeVariationReference": 0,
49+
"nullability": "NULLABILITY_REQUIRED"
50+
}
51+
},
52+
"namedTable": {
53+
"names": [
54+
"DATA"
55+
]
56+
}
57+
}
58+
},
59+
"expressions": [
60+
{
61+
"scalarFunction": {
62+
"functionReference": 0,
63+
"args": [],
64+
"outputType": {
65+
"bool": {
66+
"typeVariationReference": 0,
67+
"nullability": "NULLABILITY_NULLABLE"
68+
}
69+
},
70+
"arguments": [
71+
{
72+
"value": {
73+
"selection": {
74+
"directReference": {
75+
"structField": {
76+
"field": 0
77+
}
78+
},
79+
"rootReference": {
80+
}
81+
}
82+
}
83+
}
84+
],
85+
"options": []
86+
}
87+
}
88+
]
89+
}
90+
},
91+
"names": [
92+
"EXPR$0"
93+
]
94+
}
95+
}
96+
],
97+
"expectedTypeUrls": []
98+
}

0 commit comments

Comments
 (0)