From 6b63dccea0325a525c4bcfcae6f7eb283ad42d6b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 9 Jun 2024 17:51:18 -0700 Subject: [PATCH 1/4] feat: Add method to add analyzer rules to SessionContext Signed-off-by: Kevin Su --- datafusion/core/src/execution/context/mod.rs | 10 ++++++++++ datafusion/core/src/execution/session_state.rs | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e247263964cd..61482fad4b63 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -75,6 +75,7 @@ use url::Url; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; +use datafusion_optimizer::AnalyzerRule; mod avro; mod csv; @@ -331,6 +332,15 @@ impl SessionContext { self } + /// Adds an analyzer rule to the `SessionState` in the current `SessionContext`. + pub fn add_analyzer_rule( + self, + analyzer_rule: Arc, + ) -> Self { + self.state.write().add_analyzer_rule(analyzer_rule); + self + } + /// Registers an [`ObjectStore`] to be used with a specific URL prefix. /// /// See [`RuntimeEnv::register_object_store`] for more details. diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index fed101bd239b..357d4be70031 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -387,9 +387,9 @@ impl SessionState { /// Add `analyzer_rule` to the end of the list of /// [`AnalyzerRule`]s used to rewrite queries. pub fn add_analyzer_rule( - mut self, + &mut self, analyzer_rule: Arc, - ) -> Self { + ) -> &Self { self.analyzer.rules.push(analyzer_rule); self } From e13471939c57c66f5b8b2ffd296fc66e5b5d6bb1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 12 Jun 2024 11:10:57 -0700 Subject: [PATCH 2/4] Add a test Signed-off-by: Kevin Su --- .../tests/user_defined/user_defined_plan.rs | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 07622e48afaf..7a7eaa8b8dac 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -92,6 +92,9 @@ use datafusion::{ }; use async_trait::async_trait; +use datafusion_common::config::ConfigOptions; +use datafusion_optimizer::analyzer::inline_table_scan::InlineTableScan; +use datafusion_optimizer::AnalyzerRule; use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches @@ -246,10 +249,12 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime) + let mut state = SessionState::new_with_config_rt(config, runtime) .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - SessionContext::new_with_state(state) + state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + let ctx = SessionContext::new_with_state(state); + ctx.add_analyzer_rule(Arc::new(MyAnalyzerRule {})) } // ------ The implementation of the TopK code follows ----- @@ -619,3 +624,15 @@ impl RecordBatchStream for TopKReader { self.input.schema() } } + +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} From 62c5ab30c3a55e5663dcf33393cf4a619a5daf0c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 12 Jun 2024 11:15:10 -0700 Subject: [PATCH 3/4] Add analyze_plan Signed-off-by: Kevin Su --- .../tests/user_defined/user_defined_plan.rs | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 7a7eaa8b8dac..f2e47f28422c 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -93,6 +93,9 @@ use datafusion::{ use async_trait::async_trait; use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::Filter; use datafusion_optimizer::analyzer::inline_table_scan::InlineTableScan; use datafusion_optimizer::AnalyzerRule; use futures::{Stream, StreamExt}; @@ -636,3 +639,37 @@ impl AnalyzerRule for MyAnalyzerRule { "my_analyzer_rule" } } + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(|plan| { + Ok(match plan { + LogicalPlan::Filter(filter) => { + let predicate = Self::analyze_expr(filter.predicate.clone())?; + Transformed::yes(LogicalPlan::Filter(Filter::try_new( + predicate, + filter.input, + )?)) + } + _ => Transformed::no(plan), + }) + }) + .data() + } + + fn analyze_expr(expr: Expr) -> Result { + expr.transform(|expr| { + // closure is invoked for all sub expressions + Ok(match expr { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(expr), + }) + }) + .data() + } +} From 925f2c6a9796f102b25ed39d959a42bcb18ce209 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jun 2024 00:55:47 -0700 Subject: [PATCH 4/4] update test Signed-off-by: Kevin Su --- .../tests/user_defined/user_defined_plan.rs | 86 +++++++++++++------ 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index f2e47f28422c..09de6b134ea7 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -95,8 +95,7 @@ use async_trait::async_trait; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::Filter; -use datafusion_optimizer::analyzer::inline_table_scan::InlineTableScan; +use datafusion_expr::Projection; use datafusion_optimizer::AnalyzerRule; use futures::{Stream, StreamExt}; @@ -136,11 +135,13 @@ async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result Result<()> { @@ -168,6 +169,34 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re Ok(()) } +// Run the query using the specified execution context and compare it +// to the known result +async fn run_and_compare_query_with_analyzer_rule( + mut ctx: SessionContext, + description: &str, +) -> Result<()> { + let expected = vec![ + "+------------+--------------------------+", + "| UInt64(42) | arrow_typeof(UInt64(42)) |", + "+------------+--------------------------+", + "| 42 | UInt64 |", + "+------------+--------------------------+", + ]; + + let s = exec_sql(&mut ctx, QUERY2).await?; + let actual = s.lines().collect::>(); + + assert_eq!( + expected, + actual, + "output mismatch for {}. Expectedn\n{}Actual:\n{}", + description, + expected.join("\n"), + s + ); + Ok(()) +} + // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( @@ -212,6 +241,13 @@ async fn normal_query() -> Result<()> { run_and_compare_query(ctx, "Default context").await } +#[tokio::test] +// Run the query using default planners, optimizer and custom analyzer rule +async fn normal_query_with_analyzer() -> Result<()> { + let ctx = SessionContext::new().add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await +} + #[tokio::test] // Run the query using topk optimization async fn topk_query() -> Result<()> { @@ -256,8 +292,7 @@ fn make_topk_context() -> SessionContext { .with_query_planner(Arc::new(TopKQueryPlanner {})) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); state.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); - let ctx = SessionContext::new_with_state(state); - ctx.add_analyzer_rule(Arc::new(MyAnalyzerRule {})) + SessionContext::new_with_state(state) } // ------ The implementation of the TopK code follows ----- @@ -644,11 +679,11 @@ impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { plan.transform(|plan| { Ok(match plan { - LogicalPlan::Filter(filter) => { - let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, + LogicalPlan::Projection(projection) => { + let expr = Self::analyze_expr(projection.expr.clone())?; + Transformed::yes(LogicalPlan::Projection(Projection::try_new( + expr, + projection.input, )?)) } _ => Transformed::no(plan), @@ -657,19 +692,22 @@ impl MyAnalyzerRule { .data() } - fn analyze_expr(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Literal(ScalarValue::Int64(i)) => { - // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) - } - _ => Transformed::no(expr), + fn analyze_expr(expr: Vec) -> Result> { + expr.into_iter() + .map(|e| { + e.transform(|e| { + Ok(match e { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(e), + }) + }) + .data() }) - }) - .data() + .collect() } }