// Patch summary: DataFusion Substrait consumer (consumer.rs) plus roundtrip test updates.
//
// `from_substrait_plan`, `plan_rel::RelType::Root` arm:
//   After converting `root.input` with `from_substrait_rel`, the names listed in
//   `root.names` are applied to the resulting plan's output schema.
//   * `root.names` empty                      -> plan returned unchanged.
//   * renamed schema equivalent in names/types -> plan returned unchanged.
//   * top node is `Projection`                -> rebuilt with renamed `p.expr`.
//   * top node is `Aggregate`                 -> rebuilt with renamed `a.aggr_expr`
//                                                (`a.group_expr` is passed through as is).
//   * any other top node                      -> wrapped in a new `Projection` over
//                                                `col(..)` of each schema column.
//
// New helper `rename_expressions`:
//   Zips the expressions with `new_schema.fields()`. When an expression's type
//   (`get_type(input_schema)`) equals the field's data type it is only aliased via
//   `alias_if_changed`; otherwise it is first wrapped in `Expr::Cast` to the field's
//   data type and then aliased.
//
// New helper `make_renamed_schema`:
//   Iterates the schema's (qualifier, field) pairs, requesting a name from
//   `next_struct_field_name` for every top-level field and, through the nested
//   `rename_inner_fields`, for every `Struct` member. `List` and `LargeList` recurse into
//   the element's data type without requesting a name for the element itself; all other
//   data types are cloned unchanged. Returns a substrait error when the final `name_idx`
//   differs from `dfs_names.len()`.
//
// Tests:
//   * adds `select_with_alias`;
//   * `simple_intersect_table_reuse` and `roundtrip_literal_list` now call `roundtrip`;
//   * `aggregate_case` and `simple_intersect` expect aliased aggregates and pass `true`
//     for the schema assertion; `roundtrip_literal_struct` expects an aliased projection;
//   * `assert_expected_plan` now checks the schema before comparing plan strings.
//
// NOTE(review): in the `Aggregate` arm only `a.aggr_expr` is handed to
//   `rename_expressions`, which zips starting at the first field of `renamed_schema`.
//   Presumably the Aggregate output schema lists group expressions before aggregate
//   expressions (TODO confirm); if so, a plan with a non-empty `group_expr` would pair
//   aggregate expressions with the names/types of the group columns.
// NOTE(review): `zip` stops at the shorter input, so a length mismatch between the
//   expressions and the renamed fields is not reported by `rename_expressions`.
// NOTE(review): this copy of the patch appears to have lost its line breaks and the text
//   between `<` and `>` (e.g. `Result>`, `&Vec,`, `.collect::>>()?`), and the body of the
//   producer.rs hunk is missing; compare against the upstream commit before applying.
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 8a483db8c4d6..648a281832e1 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -17,7 +17,7 @@ use async_recursion::async_recursion; use datafusion::arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; use datafusion::common::{ not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, @@ -29,12 +29,13 @@ use url::Url; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, BinaryExpr, Case, EmptyRelation, Expr, - LogicalPlan, Operator, ScalarUDF, Values, + aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, + EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, + Values, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion::prelude::JoinType; @@ -225,6 +226,7 @@ pub async fn from_substrait_plan( None => not_impl_err!("Cannot parse empty extension"), }) .collect::>>()?; + // Parse relations match plan.relations.len() { 1 => { @@ -234,7 +236,29 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, rel, &function_extension).await?) }, plan_rel::RelType::Root(root) => { - Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) 
+ let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.equivalent_names_and_types(plan.schema()) { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?)) + } } }, None => plan_err!("Cannot parse plan relation: None") @@ -284,6 +308,105 @@ pub fn extract_projection( } } +fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema: DFSchemaRef, +) -> Result> { + exprs + .into_iter() + .zip(new_schema.fields()) + .map(|(old_expr, new_field)| { + if &old_expr.get_type(input_schema)? 
== new_field.data_type() { + // Alias column if needed + old_expr.alias_if_changed(new_field.name().into()) + } else { + // Use Cast to rename inner struct fields + alias column if needed + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + .alias_if_changed(new_field.name().into()) + } + }) + .collect() +} + +fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> Result { + fn rename_inner_fields( + dtype: &DataType, + dfs_names: &Vec, + name_idx: &mut usize, + ) -> Result { + match dtype { + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|f| { + let name = next_struct_field_name(0, dfs_names, name_idx)?; + Ok((**f).to_owned().with_name(name).with_data_type( + rename_inner_fields(f.data_type(), dfs_names, name_idx)?, + )) + }) + .collect::>()?; + Ok(DataType::Struct(fields)) + } + DataType::List(inner) => Ok(DataType::List(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + DataType::LargeList(inner) => Ok(DataType::LargeList(FieldRef::new( + (**inner).to_owned().with_data_type(rename_inner_fields( + inner.data_type(), + dfs_names, + name_idx, + )?), + ))), + _ => Ok(dtype.to_owned()), + } + } + + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .map(|(q, f)| { + let name = next_struct_field_name(0, dfs_names, &mut name_idx)?; + Ok(( + q.cloned(), + (**f) + .to_owned() + .with_name(name) + .with_data_type(rename_inner_fields( + f.data_type(), + dfs_names, + &mut name_idx, + )?), + )) + }) + .collect::>>()? 
+ .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + Ok(Arc::new(DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + )?)) +} + /// Convert Substrait Rel to DataFusion DataFrame #[async_recursion] pub async fn from_substrait_rel( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6c8be4aa9b12..88dc894eccd2 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -115,7 +115,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result<()> { roundtrip("SELECT * FROM data").await } +#[tokio::test] +async fn select_with_alias() -> Result<()> { + roundtrip("SELECT a AS aliased_a FROM data").await +} + #[tokio::test] async fn select_with_filter() -> Result<()> { roundtrip("SELECT * FROM data WHERE a > 1").await @@ -367,9 +372,9 @@ async fn implicit_cast() -> Result<()> { async fn aggregate_case() -> Result<()> { assert_expected_plan( "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\ \n TableScan: data projection=[a]", - false // NULL vs Int64(NULL) + true ) .await } @@ -589,32 +594,23 @@ async fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { + // Substrait treats both COUNT(*) and COUNT(1) the same assert_expected_plan( "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: 
groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ \n Projection: \ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ \n TableScan: data2 projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) + true ) .await } #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", - false // COUNT(*) vs COUNT(Int64(1)) - ) - .await + roundtrip("SELECT COUNT(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await } #[tokio::test] @@ -694,20 +690,14 @@ async fn all_type_literal() -> Result<()> { #[tokio::test] async fn roundtrip_literal_list() -> Result<()> { - assert_expected_plan( - "SELECT [[1,2,3], [], NULL, [NULL]] FROM data", - "Projection: List([[1, 2, 3], [], , []])\ - \n TableScan: data projection=[]", - false, // "List(..)" vs "make_array(..)" - ) - .await + roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await } #[tokio::test] async fn roundtrip_literal_struct() -> Result<()> { assert_expected_plan( "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", - "Projection: Struct({c0:1,c1:true,c2:})\ + "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\ \n TableScan: data projection=[]", false, // "Struct(..)" vs "struct(..)" ) @@ -980,12 +970,13 @@ async fn assert_expected_plan( println!("{proto:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(expected_plan_str, &plan2str); - if assert_schema { assert_eq!(plan.schema(), plan2.schema()); } + + let plan2str = format!("{plan2:?}"); + assert_eq!(expected_plan_str, &plan2str); + Ok(()) }