Skip to content

Commit 04e8820

Browse files
committed
support tpch_1 consumer_producer_test
1 parent c012e9c commit 04e8820

File tree

6 files changed

+968
-3
lines changed

6 files changed

+968
-3
lines changed

datafusion/substrait/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ object_store = { workspace = true }
4141
pbjson-types = "0.6"
4242
prost = "0.12"
4343
substrait = { version = "0.34.0", features = ["serde"] }
44+
url = { workspace = true }
4445

4546
[dev-dependencies]
4647
serde_json = "1.0"

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ use datafusion::arrow::datatypes::{
2222
use datafusion::common::{
2323
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
2424
};
25+
use substrait::proto::expression::literal::IntervalDayToSecond;
26+
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
27+
use url::Url;
2528

2629
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
2730
use datafusion::execution::FunctionRegistry;
@@ -408,7 +411,6 @@ pub async fn from_substrait_rel(
408411
};
409412
aggr_expr.push(agg_func?.as_ref().clone());
410413
}
411-
412414
input.aggregate(group_expr, aggr_expr)?.build()
413415
} else {
414416
not_impl_err!("Aggregate without an input is not valid")
@@ -569,7 +571,80 @@ pub async fn from_substrait_rel(
569571

570572
Ok(LogicalPlan::Values(Values { schema, values }))
571573
}
572-
_ => not_impl_err!("Only NamedTable and VirtualTable reads are supported"),
574+
Some(ReadType::LocalFiles(lf)) => {
575+
/// Returns the final path component of the URI `name` as an owned `String`.
///
/// Yields `None` when `name` cannot be parsed as a URL or when the parsed
/// path has no final component.
fn extract_filename(name: &str) -> Option<String> {
576+
// A `file://` prefix followed by anything other than a third slash is
// rewritten to `file:///` (first occurrence only); every other input is
// copied unchanged.
// NOTE(review): presumably this keeps the first segment after the scheme
// from being parsed as a host instead of a path — confirm against `url`.
let corrected_url =
577+
if name.starts_with("file://") && !name.starts_with("file:///") {
578+
name.replacen("file://", "file:///", 1)
579+
} else {
580+
name.to_string()
581+
};
582+
583+
// Parse failures are discarded via `.ok()`, so the caller only sees `None`.
Url::parse(&corrected_url).ok().and_then(|url| {
584+
let path = url.path();
585+
std::path::Path::new(path)
586+
.file_name()
587+
// Lossy conversion: non-UTF-8 bytes are replaced rather than rejected.
.map(|filename| filename.to_string_lossy().to_string())
588+
})
589+
}
590+
591+
// we could use the file name to check the original table provider
592+
// TODO: currently does not support multiple local files
593+
let filename: Option<String> =
594+
lf.items.first().and_then(|x| match x.path_type.as_ref() {
595+
Some(UriFile(name)) => extract_filename(name),
596+
_ => None,
597+
});
598+
599+
if lf.items.len() > 1 || !filename.is_some() {
600+
return not_impl_err!(
601+
"Only NamedTable and VirtualTable reads are supported"
602+
);
603+
}
604+
let name = filename.unwrap();
605+
// the `unwrap` above cannot panic: the guard has already returned when `filename` is `None`
606+
let table_reference = TableReference::Bare { table: name.into() };
607+
let t = ctx.table(table_reference).await?;
608+
let t = t.into_optimized_plan()?;
609+
match &read.projection {
610+
Some(MaskExpression { select, .. }) => match &select.as_ref() {
611+
Some(projection) => {
612+
let column_indices: Vec<usize> = projection
613+
.struct_items
614+
.iter()
615+
.map(|item| item.field as usize)
616+
.collect();
617+
match &t {
618+
LogicalPlan::TableScan(scan) => {
619+
let fields = column_indices
620+
.iter()
621+
.map(|i| {
622+
scan.projected_schema.qualified_field(*i)
623+
})
624+
.map(|(qualifier, field)| {
625+
(qualifier.cloned(), Arc::new(field.clone()))
626+
})
627+
.collect();
628+
let mut scan = scan.clone();
629+
scan.projection = Some(column_indices);
630+
scan.projected_schema =
631+
DFSchemaRef::new(DFSchema::new_with_metadata(
632+
fields,
633+
HashMap::new(),
634+
)?);
635+
Ok(LogicalPlan::TableScan(scan))
636+
}
637+
_ => plan_err!("unexpected plan for table"),
638+
}
639+
}
640+
_ => Ok(t),
641+
},
642+
_ => Ok(t),
643+
}
644+
}
645+
_ => {
646+
not_impl_err!("Only NamedTable and VirtualTable reads are supported")
647+
}
573648
},
574649
Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) {
575650
Ok(set_op) => match set_op {
@@ -810,14 +885,21 @@ pub async fn from_substrait_agg_func(
810885
f.function_reference
811886
);
812887
};
813-
888+
let function_name = function_name.split(':').next().unwrap_or(function_name);
814889
// try udaf first, then built-in aggr fn.
815890
if let Ok(fun) = ctx.udaf(function_name) {
816891
Ok(Arc::new(Expr::AggregateFunction(
817892
expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
818893
)))
819894
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
820895
{
896+
match &fun {
897+
// count(*) arrives with no arguments; substitute the literal 1 so COUNT has an argument
898+
aggregate_function::AggregateFunction::Count if args.is_empty() => {
899+
args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
900+
}
901+
_ => {}
902+
}
821903
Ok(Arc::new(Expr::AggregateFunction(
822904
expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None),
823905
)))
@@ -1253,6 +1335,8 @@ fn from_substrait_type(
12531335
r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type(
12541336
s, dfs_names, name_idx,
12551337
)?)),
1338+
r#type::Kind::Varchar(_) => Ok(DataType::Utf8),
1339+
r#type::Kind::FixedChar(c) => Ok(DataType::Utf8),
12561340
_ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
12571341
},
12581342
_ => not_impl_err!("`None` Substrait kind is not supported"),
@@ -1541,6 +1625,13 @@ fn from_substrait_literal(
15411625
Some(LiteralType::Null(ntype)) => {
15421626
from_substrait_null(ntype, dfs_names, name_idx)?
15431627
}
1628+
Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
1629+
days,
1630+
seconds,
1631+
microseconds,
1632+
})) => {
1633+
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
1634+
}
15441635
Some(LiteralType::UserDefined(user_defined)) => {
15451636
match user_defined.type_reference {
15461637
INTERVAL_YEAR_MONTH_TYPE_REF => {

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ mod logical_plans;
1919
mod roundtrip_logical_plan;
2020
mod roundtrip_physical_plan;
2121
mod serialize;
22+
mod tpch;
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Tests for the TPC-H Substrait plans contained in <https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans>
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use datafusion::common::Result;
23+
use datafusion::execution::options::ParquetReadOptions;
24+
use datafusion::prelude::SessionContext;
25+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
26+
use std::fs::File;
27+
use std::io::BufReader;
28+
use substrait::proto::Plan;
29+
30+
/// Reads the Substrait plan stored as JSON in `tests/testdata/query_1.json`,
/// converts it with `from_substrait_plan`, and compares the `Debug` rendering
/// of the resulting plan against the expected plan text.
///
/// Panics (via `expect`) if the JSON file cannot be opened or deserialized;
/// errors from context creation or plan conversion are returned with `?`.
#[tokio::test]
31+
async fn tpch_test_1() -> Result<()> {
32+
// Context with the table the expected plan refers to already registered.
let ctx = create_context().await?;
33+
// Relative path — NOTE(review): presumably resolved from the crate root when run under `cargo test`; confirm.
let path = "tests/testdata/query_1.json";
34+
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
35+
File::open(path).expect("file not found"),
36+
))
37+
.expect("failed to parse json");
38+
39+
let plan = from_substrait_plan(&ctx, &proto).await?;
40+
41+
// The comparison is on the `{:?}` rendering of the plan, so any change in
// that rendering requires updating the expected text below.
assert_eq!(
42+
format!("{:?}", plan),
43+
"Sort: FILENAME_PLACEHOLDER_0.l_returnflag ASC NULLS LAST, FILENAME_PLACEHOLDER_0.l_linestatus ASC NULLS LAST\n \
44+
Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus]], aggr=[[SUM(FILENAME_PLACEHOLDER_0.l_quantity), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount), SUM(FILENAME_PLACEHOLDER_0.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_0.l_discount * Int32(1) + FILENAME_PLACEHOLDER_0.l_tax), AVG(FILENAME_PLACEHOLDER_0.l_quantity), AVG(FILENAME_PLACEHOLDER_0.l_extendedprice), AVG(FILENAME_PLACEHOLDER_0.l_discount), COUNT(Int64(1))]]\n \
45+
Projection: FILENAME_PLACEHOLDER_0.l_returnflag, FILENAME_PLACEHOLDER_0.l_linestatus, FILENAME_PLACEHOLDER_0.l_quantity, FILENAME_PLACEHOLDER_0.l_extendedprice, FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount), FILENAME_PLACEHOLDER_0.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_0.l_discount) * (CAST(Int32(1) AS Decimal128(19, 0)) + FILENAME_PLACEHOLDER_0.l_tax), FILENAME_PLACEHOLDER_0.l_discount\n \
46+
Filter: FILENAME_PLACEHOLDER_0.l_shipdate <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 120, milliseconds: 0 }\")\n \
47+
TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]"
48+
);
49+
Ok(())
50+
}
51+
52+
/// Builds a fresh `SessionContext` and registers the Parquet file
/// `tests/testdata/tpch/lineitem.parquet` under the table name
/// `FILENAME_PLACEHOLDER_0`, using default `ParquetReadOptions`.
///
/// Returns the context, or the error produced by `register_parquet`.
async fn create_context() -> datafusion::common::Result<SessionContext> {
53+
let ctx = SessionContext::new();
54+
ctx.register_parquet(
55+
// NOTE(review): presumably matches the file name in the plan's local-file
// URI, which the consumer uses as the table name — confirm against query_1.json.
"FILENAME_PLACEHOLDER_0",
56+
"tests/testdata/tpch/lineitem.parquet",
57+
ParquetReadOptions::default(),
58+
)
59+
.await?;
60+
Ok(ctx)
61+
}
62+
}

0 commit comments

Comments
 (0)