Skip to content

Commit 7f0c71b

Browse files
Customize window frame support for dialect (apache#14288)
* Customize window frame support for dialect (#70) * Customize window frame support for dialect * fix: ignore frame only when frame implies no frame * Add comments. Move the window frame determination logic to the dialect method * Update datafusion/sql/src/unparser/dialect.rs * Update test case * fix --------- Co-authored-by: Phillip LeBlanc <[email protected]> * fix clippy --------- Co-authored-by: Phillip LeBlanc <[email protected]>
1 parent ecc5694 commit 7f0c71b

File tree

2 files changed

+87
-6
lines changed

2 files changed

+87
-6
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ use datafusion_expr::Expr;
2323
use regex::Regex;
2424
use sqlparser::tokenizer::Span;
2525
use sqlparser::{
26-
ast::{self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo},
26+
ast::{
27+
self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo, WindowFrameBound,
28+
},
2729
keywords::ALL_KEYWORDS,
2830
};
2931

@@ -153,6 +155,18 @@ pub trait Dialect: Send + Sync {
153155
Ok(None)
154156
}
155157

158+
/// Lets the dialect decide whether the window frame clause is written out when unparsing,
159+
/// based on the window function name and the frame's start and end bounds.
160+
/// Returns `false` if the given function name / window frame bounds indicate that no window frame should be emitted; this default implementation always returns `true`.
161+
fn window_func_support_window_frame(
162+
&self,
163+
_func_name: &str,
164+
_start_bound: &WindowFrameBound,
165+
_end_bound: &WindowFrameBound,
166+
) -> bool {
167+
true
168+
}
169+
156170
/// Extends the dialect's default rules for unparsing scalar functions.
157171
/// This is useful for supporting application-specific UDFs or custom engine extensions.
158172
fn with_custom_scalar_overrides(
@@ -500,6 +514,7 @@ pub struct CustomDialect {
500514
supports_column_alias_in_table_alias: bool,
501515
requires_derived_table_alias: bool,
502516
division_operator: BinaryOperator,
517+
window_func_support_window_frame: bool,
503518
full_qualified_col: bool,
504519
unnest_as_table_factor: bool,
505520
}
@@ -527,6 +542,7 @@ impl Default for CustomDialect {
527542
supports_column_alias_in_table_alias: true,
528543
requires_derived_table_alias: false,
529544
division_operator: BinaryOperator::Divide,
545+
window_func_support_window_frame: true,
530546
full_qualified_col: false,
531547
unnest_as_table_factor: false,
532548
}
@@ -634,6 +650,15 @@ impl Dialect for CustomDialect {
634650
self.division_operator.clone()
635651
}
636652

653+
// Returns the boolean stored on this `CustomDialect`; the function name and
// the frame bounds are not consulted, so the answer is the same for every call.
fn window_func_support_window_frame(
654+
&self,
655+
_func_name: &str,
656+
_start_bound: &WindowFrameBound,
657+
_end_bound: &WindowFrameBound,
658+
) -> bool {
659+
self.window_func_support_window_frame
660+
}
661+
637662
fn full_qualified_col(&self) -> bool {
638663
self.full_qualified_col
639664
}
@@ -675,6 +700,7 @@ pub struct CustomDialectBuilder {
675700
supports_column_alias_in_table_alias: bool,
676701
requires_derived_table_alias: bool,
677702
division_operator: BinaryOperator,
703+
window_func_support_window_frame: bool,
678704
full_qualified_col: bool,
679705
unnest_as_table_factor: bool,
680706
}
@@ -708,6 +734,7 @@ impl CustomDialectBuilder {
708734
supports_column_alias_in_table_alias: true,
709735
requires_derived_table_alias: false,
710736
division_operator: BinaryOperator::Divide,
737+
window_func_support_window_frame: true,
711738
full_qualified_col: false,
712739
unnest_as_table_factor: false,
713740
}
@@ -733,6 +760,7 @@ impl CustomDialectBuilder {
733760
.supports_column_alias_in_table_alias,
734761
requires_derived_table_alias: self.requires_derived_table_alias,
735762
division_operator: self.division_operator,
763+
window_func_support_window_frame: self.window_func_support_window_frame,
736764
full_qualified_col: self.full_qualified_col,
737765
unnest_as_table_factor: self.unnest_as_table_factor,
738766
}
@@ -857,6 +885,14 @@ impl CustomDialectBuilder {
857885
self
858886
}
859887

888+
/// Customize whether the dialect reports that window functions support a window frame.
/// The value is stored on the builder and returned so calls can be chained.
pub fn with_window_func_support_window_frame(
889+
mut self,
890+
window_func_support_window_frame: bool,
891+
) -> Self {
892+
self.window_func_support_window_frame = window_func_support_window_frame;
893+
self
894+
}
895+
860896
/// Customize the dialect to allow full qualified column names
861897
pub fn with_full_qualified_col(mut self, full_qualified_col: bool) -> Self {
862898
self.full_qualified_col = full_qualified_col;

datafusion/sql/src/unparser/expr.rs

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,18 +217,29 @@ impl Unparser<'_> {
217217

218218
let start_bound = self.convert_bound(&window_frame.start_bound)?;
219219
let end_bound = self.convert_bound(&window_frame.end_bound)?;
220+
221+
let window_frame = if self.dialect.window_func_support_window_frame(
222+
func_name,
223+
&start_bound,
224+
&end_bound,
225+
) {
226+
Some(ast::WindowFrame {
227+
units,
228+
start_bound,
229+
end_bound: Some(end_bound),
230+
})
231+
} else {
232+
None
233+
};
234+
220235
let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
221236
window_name: None,
222237
partition_by: partition_by
223238
.iter()
224239
.map(|e| self.expr_to_sql_inner(e))
225240
.collect::<Result<Vec<_>>>()?,
226241
order_by,
227-
window_frame: Some(ast::WindowFrame {
228-
units,
229-
start_bound,
230-
end_bound: Option::from(end_bound),
231-
}),
242+
window_frame,
232243
}));
233244

234245
Ok(ast::Expr::Function(Function {
@@ -1632,6 +1643,7 @@ mod tests {
16321643
use datafusion_functions_aggregate::expr_fn::sum;
16331644
use datafusion_functions_nested::expr_fn::{array_element, make_array};
16341645
use datafusion_functions_nested::map::map;
1646+
use datafusion_functions_window::rank::rank_udwf;
16351647
use datafusion_functions_window::row_number::row_number_udwf;
16361648

16371649
use crate::unparser::dialect::{
@@ -2677,6 +2689,39 @@ mod tests {
26772689
Ok(())
26782690
}
26792691

2692+
// Checks that the `with_window_func_support_window_frame` builder flag controls
// whether the window frame clause appears in the SQL unparsed for `rank()`.
#[test]
2693+
fn test_window_func_support_window_frame() -> Result<()> {
2694+
// Dialect built without touching the flag.
let default_dialect: Arc<dyn Dialect> =
2695+
Arc::new(CustomDialectBuilder::new().build());
2696+
2697+
// Dialect built with the flag explicitly set to `false`.
let test_dialect: Arc<dyn Dialect> = Arc::new(
2698+
CustomDialectBuilder::new()
2699+
.with_window_func_support_window_frame(false)
2700+
.build(),
2701+
);
2702+
2703+
// The same window expression is unparsed with each dialect and compared
// against the SQL text paired with that dialect.
for (dialect, expected) in [
2704+
(
2705+
default_dialect,
2706+
"rank() OVER (ORDER BY a ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
2707+
),
2708+
(test_dialect, "rank() OVER (ORDER BY a ASC NULLS FIRST)"),
2709+
] {
2710+
let unparser = Unparser::new(dialect.as_ref());
2711+
let func = WindowFunctionDefinition::WindowUDF(rank_udwf());
2712+
let mut window_func = WindowFunction::new(func, vec![]);
2713+
window_func.order_by = vec![Sort::new(col("a"), true, true)];
2714+
let expr = Expr::WindowFunction(window_func);
2715+
let ast = unparser.expr_to_sql(&expr)?;
2716+
2717+
let actual = ast.to_string();
2718+
let expected = expected.to_string();
2719+
2720+
assert_eq!(actual, expected);
2721+
}
2722+
Ok(())
2723+
}
2724+
26802725
#[test]
26812726
fn test_utf8_view_to_sql() -> Result<()> {
26822727
let dialect = CustomDialectBuilder::new()

0 commit comments

Comments
 (0)