diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 87ed1b8f4140..fbaa402e703c 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -27,7 +27,7 @@ use sqlparser::{ use datafusion_common::Result; -use super::{utils::date_part_to_sql, Unparser}; +use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; /// `Dialect` to use for Unparsing /// @@ -80,6 +80,11 @@ pub trait Dialect: Send + Sync { DateFieldExtractStyle::DatePart } + /// The character length extraction style to use: `CharacterLengthStyle` + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::CharacterLength + } + /// The SQL type to use for Arrow Int64 unparsing /// Most dialects use BigInt, but some, like MySQL, require SIGNED fn int64_cast_dtype(&self) -> ast::DataType { @@ -176,6 +181,17 @@ pub enum DateFieldExtractStyle { Strftime, } +/// `CharacterLengthStyle` to use for unparsing +/// +/// Different DBMSs use different names for the function that calculates the number of characters in a string +/// `Length` style uses length(x) +/// `CharacterLength` style uses character_length(x) +#[derive(Clone, Copy, PartialEq)] +pub enum CharacterLengthStyle { + Length, + CharacterLength, +} + pub struct DefaultDialect {} impl Dialect for DefaultDialect { @@ -271,6 +287,35 @@ impl PostgreSqlDialect { } } +pub struct DuckDBDialect {} + +impl Dialect for DuckDBDialect { + fn identifier_quote_style(&self, _: &str) -> Option { + Some('"') + } + + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "character_length" { + return character_length_to_sql( + unparser, + self.character_length_style(), + args, + ); + } + + Ok(None) + } +} + pub struct MySqlDialect {} impl Dialect for MySqlDialect { @@ 
-347,6 +392,10 @@ impl Dialect for SqliteDialect { ast::DataType::Text } + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + fn supports_column_alias_in_table_alias(&self) -> bool { false } @@ -357,11 +406,15 @@ impl Dialect for SqliteDialect { func_name: &str, args: &[Expr], ) -> Result> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + date_part_to_sql(unparser, self.date_field_extract_style(), args) + } + "character_length" => { + character_length_to_sql(unparser, self.character_length_style(), args) + } + _ => Ok(None), } - - Ok(None) } } @@ -374,6 +427,7 @@ pub struct CustomDialect { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -395,6 +449,7 @@ impl Default for CustomDialect { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -454,6 +509,10 @@ impl Dialect for CustomDialect { self.date_field_extract_style } + fn character_length_style(&self) -> CharacterLengthStyle { + self.character_length_style + } + fn int64_cast_dtype(&self) -> ast::DataType { self.int64_cast_dtype.clone() } @@ -488,11 +547,15 @@ impl Dialect for CustomDialect { func_name: &str, args: &[Expr], ) -> Result> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + date_part_to_sql(unparser, 
self.date_field_extract_style(), args) + } + "character_length" => { + character_length_to_sql(unparser, self.character_length_style(), args) + } + _ => Ok(None), } - - Ok(None) } fn requires_derived_table_alias(&self) -> bool { @@ -527,6 +590,7 @@ pub struct CustomDialectBuilder { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -554,6 +618,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -578,6 +643,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, + character_length_style: self.character_length_style, int64_cast_dtype: self.int64_cast_dtype, int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, @@ -620,6 +686,15 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific character_length_style listed in `CharacterLengthStyle` + pub fn with_character_length_style( + mut self, + character_length_style: CharacterLengthStyle, + ) -> Self { + self.character_length_style = character_length_style; + self + } + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. 
pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { self.float64_ast_dtype = float64_ast_dtype; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8f6ffa51f76a..d09bd6e8b90c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1488,8 +1488,8 @@ mod tests { use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ - CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, - PostgreSqlDialect, + CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, + Dialect, PostgreSqlDialect, }; use super::*; @@ -2007,6 +2007,33 @@ mod tests { Ok(()) } + #[test] + fn test_character_length_scalar_to_expr() { + let tests = [ + (CharacterLengthStyle::Length, "length(x)"), + (CharacterLengthStyle::CharacterLength, "character_length(x)"), + ]; + + for (style, expected) in tests { + let dialect = CustomDialectBuilder::new() + .with_character_length_style(style) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = ScalarUDF::new_from_impl( + datafusion_functions::unicode::character_length::CharacterLengthFunc::new( + ), + ) + .call(vec![col("x")]); + + let ast = unparser.expr_to_sql(&expr).expect("to be unparsed"); + + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + #[test] fn test_interval_scalar_to_expr() { let tests = [ diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 284956cef195..d0f80da83d63 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -28,7 +28,10 @@ use datafusion_expr::{ }; use sqlparser::ast; -use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; +use super::{ + dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle, + rewrite::TableAliasRewriter, Unparser, +}; /// Recursively searches children of [LogicalPlan] to 
find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -445,3 +448,19 @@ pub(crate) fn date_part_to_sql( Ok(None) } + +pub(crate) fn character_length_to_sql( + unparser: &Unparser, + style: CharacterLengthStyle, + character_length_args: &[Expr], +) -> Result> { + let func_name = match style { + CharacterLengthStyle::CharacterLength => "character_length", + CharacterLengthStyle::Length => "length", + }; + + Ok(Some(unparser.scalar_function_to_sql( + func_name, + character_length_args, + )?)) +}