diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f6b556fc6448..044eda2c31ca 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -27,8 +27,8 @@ use datafusion::logical_expr::{ BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator, }; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound, - WindowFrameUnits, + expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, WindowFrameBound, WindowFrameUnits, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; @@ -38,7 +38,8 @@ use datafusion::{ prelude::{Column, SessionContext}, scalar::ScalarValue, }; -use substrait::proto::expression::{Literal, ScalarFunction}; +use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -550,6 +551,45 @@ pub async fn from_substrait_rel( let plan = plan.from_template(&plan.expressions(), &inputs); Ok(LogicalPlan::Extension(Extension { node: plan })) } + Some(RelType::Exchange(exchange)) => { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = + from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); + } + Partitioning::Hash( + partition_columns, + exchange.partition_count as usize, + ) + } + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) + } + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); + } + }; + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) + } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } @@ -725,27 +765,9 @@ pub async fn from_substrait_rex( negated: false, }))) } - Some(RexType::Selection(field_ref)) => match &field_ref.reference_type { - Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { - Some(StructField(x)) => match &x.child.as_ref() { - Some(_) => not_impl_err!( - "Direct reference StructField with child is not supported" - ), - None => { - let column = - input_schema.field(x.field as usize).qualified_column(); - Ok(Arc::new(Expr::Column(Column { - relation: column.relation, - name: column.name, - }))) - } - }, - _ => not_impl_err!( - "Direct reference with types other than StructField is not supported" - ), - }, - _ => not_impl_err!("unsupported field ref type"), - }, + Some(RexType::Selection(field_ref)) => Ok(Arc::new( + from_substrait_field_reference(field_ref, input_schema)?, + )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -1245,6 +1267,32 @@ fn from_substrait_null(null_type: &Type) -> Result { } } +fn from_substrait_field_reference( + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> Result { + match &field_ref.reference_type { + Some(DirectReference(direct)) => match &direct.reference_type.as_ref() { + Some(StructField(x)) => match &x.child.as_ref() { + Some(_) => not_impl_err!( + "Direct reference StructField with child is not supported" + ), + None => { + let column = input_schema.field(x.field as usize).qualified_column(); + Ok(Expr::Column(Column { + relation: column.relation, + name: column.name, + })) + } + }, + _ => not_impl_err!( + "Direct reference with types other than StructField is not supported" + ), + }, + _ => not_impl_err!("unsupported field ref type"), + } +} + /// Build [`Expr`] from its name and required inputs. struct BuiltinExprBuilder { expr_name: String, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c5f1278be6e0..58c98c10d963 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -19,7 +19,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; -use datafusion::logical_expr::{CrossJoin, Distinct, Like, WindowFrameUnits}; +use datafusion::logical_expr::{ + CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, +}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, @@ -28,8 +30,8 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::DFSchemaRef; use datafusion::common::{exec_err, internal_err, not_impl_err}; +use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ @@ -39,8 +41,9 @@ use datafusion::logical_expr::expr::{ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; +use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::CrossRel; +use substrait::proto::{CrossRel, ExchangeRel}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -410,6 +413,53 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Project(project_rel)), })) } + LogicalPlan::Repartition(repartition) => { + let input = + to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let partition_count = match repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(num) => num, + Partitioning::Hash(_, num) => num, + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let exchange_kind = match &repartition.partitioning_scheme { + Partitioning::RoundRobinBatch(_) => { + ExchangeKind::RoundRobin(RoundRobin::default()) + } + Partitioning::Hash(exprs, _) => { + let fields = exprs + .iter() + .map(|e| { + try_to_substrait_field_reference( + e, + repartition.input.schema(), + ) + }) + .collect::>>()?; + ExchangeKind::ScatterByFields(ScatterFields { fields }) + } + Partitioning::DistributeBy(_) => { + return not_impl_err!( + "Physical plan does not support DistributeBy partitioning" + ) + } + }; + let exchange_rel = ExchangeRel { + common: None, + input: Some(input), + exchange_kind: Some(exchange_kind), + advanced_extension: None, + partition_count: partition_count as i32, + targets: vec![], + }; + Ok(Box::new(Rel { + rel_type: Some(RelType::Exchange(Box::new(exchange_rel))), + })) + } LogicalPlan::Extension(extension_plan) => { let extension_bytes = ctx .state() @@ -1751,6 +1801,31 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { } } +/// Try to convert an [Expr] to a [FieldReference]. +/// Returns `Err` if the [Expr] is not a [Expr::Column]. +fn try_to_substrait_field_reference( + expr: &Expr, + schema: &DFSchemaRef, +) -> Result { + match expr { + Expr::Column(col) => { + let index = schema.index_of_column(col)?; + Ok(FieldReference { + reference_type: Some(ReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(reference_segment::ReferenceType::StructField( + Box::new(reference_segment::StructField { + field: index as i32, + child: None, + }), + )), + })), + root_type: None, + }) + } + _ => substrait_err!("Expect a `Column` expr, but found {expr:?}"), + } +} + fn substrait_sort_field( expr: &Expr, schema: &DFSchemaRef, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 691fba864449..fb829e6b242d 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,7 +32,7 @@ use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ - Extension, LogicalPlan, UserDefinedLogicalNode, Volatility, + Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -698,6 +698,40 @@ async fn roundtrip_aggregate_udf() -> Result<()> { roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await } +#[tokio::test] +async fn roundtrip_repartition_roundrobin() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::RoundRobinBatch(8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_repartition_hash() -> Result<()> { + let ctx = create_context().await?; + let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(scan_plan), + partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), + }); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + let plan2 = ctx.state().optimize(&plan2)?; + + assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + Ok(()) +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type {