-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Improve round scalar function unparsing for Postgres
#12744
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,12 +18,17 @@ | |
| use std::sync::Arc; | ||
|
|
||
| use arrow_schema::TimeUnit; | ||
| use datafusion_expr::Expr; | ||
| use regex::Regex; | ||
| use sqlparser::{ | ||
| ast::{self, Ident, ObjectName, TimezoneInfo}, | ||
| ast::{self, Function, Ident, ObjectName, TimezoneInfo}, | ||
| keywords::ALL_KEYWORDS, | ||
| }; | ||
|
|
||
| use datafusion_common::Result; | ||
|
|
||
| use super::{utils::date_part_to_sql, Unparser}; | ||
|
|
||
| /// `Dialect` to use for Unparsing | ||
| /// | ||
| /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) | ||
|
|
@@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync { | |
| fn supports_column_alias_in_table_alias(&self) -> bool { | ||
| true | ||
| } | ||
|
|
||
| /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. | ||
| /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is | ||
| /// a custom implementation for the function. | ||
| fn scalar_function_to_sql_overrides( | ||
| &self, | ||
| _unparser: &Unparser, | ||
| _func_name: &str, | ||
| _args: &[Expr], | ||
| ) -> Result<Option<ast::Expr>> { | ||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| /// `IntervalStyle` to use for unparsing | ||
|
|
@@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect { | |
| fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { | ||
| sqlparser::ast::DataType::DoublePrecision | ||
| } | ||
|
|
||
| fn scalar_function_to_sql_overrides( | ||
| &self, | ||
| unparser: &Unparser, | ||
| func_name: &str, | ||
| args: &[Expr], | ||
| ) -> Result<Option<ast::Expr>> { | ||
| if func_name == "round" { | ||
| return Ok(Some( | ||
| self.round_to_sql_enforce_numeric(unparser, func_name, args)?, | ||
| )); | ||
| } | ||
|
|
||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| impl PostgreSqlDialect { | ||
| fn round_to_sql_enforce_numeric( | ||
| &self, | ||
| unparser: &Unparser, | ||
| func_name: &str, | ||
| args: &[Expr], | ||
| ) -> Result<ast::Expr> { | ||
| let mut args = unparser.function_args_to_sql(args)?; | ||
|
|
||
| // Enforce the first argument to be Numeric | ||
| if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) = | ||
| args.first_mut() | ||
| { | ||
| if let ast::Expr::Cast { data_type, .. } = expr { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternative approach considered was removing casting altogether, but it seems less robust as the argument can be a complex expression. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than checking after unparsing, maybe this code could check before unparsing -- as in get the type of Expr in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is a good idea 👍. Unfortunately, we can't enforce |
||
| // Don't create an additional cast wrapper if we can update the existing one | ||
| *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); | ||
| } else { | ||
| // Wrap the expression in a new cast | ||
| *expr = ast::Expr::Cast { | ||
| kind: ast::CastKind::Cast, | ||
| expr: Box::new(expr.clone()), | ||
| data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), | ||
| format: None, | ||
| }; | ||
| } | ||
| } | ||
|
|
||
| Ok(ast::Expr::Function(Function { | ||
| name: ast::ObjectName(vec![Ident { | ||
| value: func_name.to_string(), | ||
| quote_style: None, | ||
| }]), | ||
| args: ast::FunctionArguments::List(ast::FunctionArgumentList { | ||
| duplicate_treatment: None, | ||
| args, | ||
| clauses: vec![], | ||
| }), | ||
| filter: None, | ||
| null_treatment: None, | ||
| over: None, | ||
| within_group: vec![], | ||
| parameters: ast::FunctionArguments::None, | ||
| })) | ||
| } | ||
| } | ||
|
|
||
| pub struct MySqlDialect {} | ||
|
|
@@ -211,6 +289,19 @@ impl Dialect for MySqlDialect { | |
| ) -> ast::DataType { | ||
| ast::DataType::Datetime(None) | ||
| } | ||
|
|
||
| fn scalar_function_to_sql_overrides( | ||
| &self, | ||
| unparser: &Unparser, | ||
| func_name: &str, | ||
| args: &[Expr], | ||
| ) -> Result<Option<ast::Expr>> { | ||
| if func_name == "date_part" { | ||
| return date_part_to_sql(unparser, self.date_field_extract_style(), args); | ||
| } | ||
|
|
||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| pub struct SqliteDialect {} | ||
|
|
@@ -231,6 +322,19 @@ impl Dialect for SqliteDialect { | |
| fn supports_column_alias_in_table_alias(&self) -> bool { | ||
| false | ||
| } | ||
|
|
||
| fn scalar_function_to_sql_overrides( | ||
| &self, | ||
| unparser: &Unparser, | ||
| func_name: &str, | ||
| args: &[Expr], | ||
| ) -> Result<Option<ast::Expr>> { | ||
| if func_name == "date_part" { | ||
| return date_part_to_sql(unparser, self.date_field_extract_style(), args); | ||
| } | ||
|
|
||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| pub struct CustomDialect { | ||
|
|
@@ -339,6 +443,19 @@ impl Dialect for CustomDialect { | |
| fn supports_column_alias_in_table_alias(&self) -> bool { | ||
| self.supports_column_alias_in_table_alias | ||
| } | ||
|
|
||
| fn scalar_function_to_sql_overrides( | ||
| &self, | ||
| unparser: &Unparser, | ||
| func_name: &str, | ||
| args: &[Expr], | ||
| ) -> Result<Option<ast::Expr>> { | ||
| if func_name == "date_part" { | ||
| return date_part_to_sql(unparser, self.date_field_extract_style(), args); | ||
| } | ||
|
|
||
| Ok(None) | ||
| } | ||
| } | ||
|
|
||
| /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
Could you please add documentation here explaining what this function does / is used for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@alamb - thank you for pointing to this, missed this part, documented the
scalar_function_to_sql_overrides