Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ use url::Url;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_optimizer::AnalyzerRule;

mod avro;
mod csv;
Expand Down Expand Up @@ -331,6 +332,15 @@ impl SessionContext {
self
}

/// Adds an analyzer rule to the `SessionState` in the current `SessionContext`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have an example, or a doc test, for this method

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree -- we are tracking adding an example for how to use custom analyzer rules in #10855, so perhaps we can add the example as part of that ticket (I think @goldmedal said he may have some time to work on that eventually)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I have a WIP PR to improve these examples -- I hope to get it up for review sometime this weekend

/// Adds an analyzer rule to the `SessionState` in the current `SessionContext`,
/// returning the context so calls can be chained builder-style.
pub fn add_analyzer_rule(
    self,
    analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
    {
        // Scope the write-lock guard so it is released before `self` is moved out.
        let mut state = self.state.write();
        state.add_analyzer_rule(analyzer_rule);
    }
    self
}

/// Registers an [`ObjectStore`] to be used with a specific URL prefix.
///
/// See [`RuntimeEnv::register_object_store`] for more details.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,9 @@ impl SessionState {
/// Add `analyzer_rule` to the end of the list of
/// [`AnalyzerRule`]s used to rewrite queries.
pub fn add_analyzer_rule(
mut self,
&mut self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think technically this is an API change as now the api takes a mut reference rather than self

However, I think the change is good as now add_analyzer_rule looks more like a standard mutation style api (that takes &mut self) rather than a builder style (self)

What do you think about adding an api to make things consistent? (we could do this as a separate PR)

    pub fn with_analyzer_rule(
      mut self, 
        analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
  ) -> Self {
..
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also update add_optimizer_rule add_physical_optimizer_rule? (mut self -> &mut self)

Also, do we need to add with_optimizer_rule and with_physical_optimizer_rule to make it consistent?
If so, I can do it in a separate PR.

analyzer_rule: Arc<dyn AnalyzerRule + Send + Sync>,
) -> Self {
) -> &Self {
self.analyzer.rules.push(analyzer_rule);
self
}
Expand Down
99 changes: 95 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ use datafusion::{
};

use async_trait::async_trait;
use datafusion_common::tree_node::Transformed;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::Projection;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use futures::{Stream, StreamExt};

/// Execute the specified sql and return the resulting record batches
Expand Down Expand Up @@ -132,11 +136,13 @@ async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result<SessionC
Ok(ctx)
}

// SQL queries exercised by the tests below. Note: `QUERY1` was previously
// declared twice (a moved-line diff artifact); a `const` may only be defined
// once, so the duplicate is removed here.
const QUERY: &str =
    "SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3";

const QUERY1: &str = "SELECT * FROM sales limit 3";

// Query used to exercise the custom analyzer rule (Int64 -> UInt64 rewrite)
const QUERY2: &str = "SELECT 42, arrow_typeof(42)";

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Result<()> {
Expand Down Expand Up @@ -164,6 +170,34 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re
Ok(())
}

// Run QUERY2 using the specified execution context (which must have
// `MyAnalyzerRule` registered) and compare the output to the known result.
//
// Fixes: assertion message typo ("Expectedn" -> "Expected:") and restores the
// cell padding in the expected data row, which the page scrape collapsed —
// Arrow's pretty printer pads every cell to the separator width, so the
// unpadded literal could never match the actual output.
async fn run_and_compare_query_with_analyzer_rule(
    mut ctx: SessionContext,
    description: &str,
) -> Result<()> {
    // Expected output after MyAnalyzerRule rewrites the Int64 literal `42`
    // into a UInt64 literal.
    let expected = vec![
        "+------------+--------------------------+",
        "| UInt64(42) | arrow_typeof(UInt64(42)) |",
        "+------------+--------------------------+",
        "| 42         | UInt64                   |",
        "+------------+--------------------------+",
    ];

    let s = exec_sql(&mut ctx, QUERY2).await?;
    let actual = s.lines().collect::<Vec<_>>();

    assert_eq!(
        expected,
        actual,
        "output mismatch for {}. Expected:\n{}\nActual:\n{}",
        description,
        expected.join("\n"),
        s
    );
    Ok(())
}

// Run the query using the specified execution context and compare it
// to the known result
async fn run_and_compare_query_with_auto_schemas(
Expand Down Expand Up @@ -208,6 +242,13 @@ async fn normal_query() -> Result<()> {
run_and_compare_query(ctx, "Default context").await
}

#[tokio::test]
/// Run the query with the default planners and optimizer, plus the custom
/// `MyAnalyzerRule` registered on the context.
async fn normal_query_with_analyzer() -> Result<()> {
    let rule = Arc::new(MyAnalyzerRule {});
    let ctx = SessionContext::new().add_analyzer_rule(rule);
    run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await
}

#[tokio::test]
// Run the query using topk optimization
async fn topk_query() -> Result<()> {
Expand Down Expand Up @@ -248,9 +289,10 @@ async fn topk_plan() -> Result<()> {
/// Create a `SessionContext` configured with the custom TopK query planner,
/// the TopK optimizer rule, and `MyAnalyzerRule`.
///
/// Fix: the scraped diff left both the old `let state = ...` and the new
/// `let mut state = ...` lines in place — a duplicate binding of the same
/// chain; only the merged (mutable) form is kept.
fn make_topk_context() -> SessionContext {
    let config = SessionConfig::new().with_target_partitions(48);
    let runtime = Arc::new(RuntimeEnv::default());
    // planner and optimizer rule use the builder-style API (consume/return
    // Self); add_analyzer_rule mutates the state in place, hence `mut`.
    let mut state = SessionState::new_with_config_rt(config, runtime)
        .with_query_planner(Arc::new(TopKQueryPlanner {}))
        .add_optimizer_rule(Arc::new(TopKOptimizerRule {}));
    state.add_analyzer_rule(Arc::new(MyAnalyzerRule {}));
    SessionContext::new_with_state(state)
}

Expand Down Expand Up @@ -633,3 +675,52 @@ impl RecordBatchStream for TopKReader {
self.input.schema()
}
}

/// Example `AnalyzerRule` that rewrites `Int64` literals inside projection
/// expressions to `UInt64`; used to exercise the `add_analyzer_rule` APIs.
struct MyAnalyzerRule {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to add test that exercises these APIs?

For example, perhaps a test that does something like

select 42, arrow_typeof(42)

Which I think this code will print out 42 and UInt?


impl AnalyzerRule for MyAnalyzerRule {
    // Entry point invoked by the analyzer framework; delegates to the
    // associated function that walks and rewrites the plan. The config
    // options are unused by this rule.
    fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> {
        Self::analyze_plan(plan)
    }

    // Name under which this rule is reported (e.g. in logs / EXPLAIN output).
    fn name(&self) -> &str {
        "my_analyzer_rule"
    }
}

impl MyAnalyzerRule {
    /// Walk the plan bottom-up, rewriting the expression list of every
    /// `Projection` node; all other node kinds pass through unchanged.
    fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
        plan.transform(|node| match node {
            LogicalPlan::Projection(projection) => {
                let rewritten = Self::analyze_expr(projection.expr.clone())?;
                let new_projection = Projection::try_new(rewritten, projection.input)?;
                Ok(Transformed::yes(LogicalPlan::Projection(new_projection)))
            }
            other => Ok(Transformed::no(other)),
        })
        .data()
    }

    /// Rewrite each expression in `expr`, converting every `Int64` literal
    /// into the equivalent `UInt64` literal (NULL literals stay NULL).
    fn analyze_expr(expr: Vec<Expr>) -> Result<Vec<Expr>> {
        expr.into_iter()
            .map(|e| {
                e.transform(|node| match node {
                    Expr::Literal(ScalarValue::Int64(value)) => {
                        // transform to UInt64, preserving an absent (NULL) value
                        let as_u64 = value.map(|v| v as u64);
                        Ok(Transformed::yes(Expr::Literal(ScalarValue::UInt64(as_u64))))
                    }
                    other => Ok(Transformed::no(other)),
                })
                .data()
            })
            .collect()
    }
}