@@ -23,6 +23,7 @@ use arrow::{
2323 row:: { RowConverter , Rows , SortField } ,
2424} ;
2525use datafusion_expr:: { ColumnarValue , Operator } ;
26+ use parking_lot:: RwLock ;
2627use std:: mem:: size_of;
2728use std:: { cmp:: Ordering , collections:: BinaryHeap , sync:: Arc } ;
2829
@@ -121,13 +122,36 @@ pub struct TopK {
121122 /// Common sort prefix between the input and the sort expressions to allow early exit optimization
122123 common_sort_prefix : Arc < [ PhysicalSortExpr ] > ,
123124 /// Filter matching the state of the `TopK` heap used for dynamic filter pushdown
124- filter : Option < Arc < DynamicFilterPhysicalExpr > > ,
125+ filter : TopKDynamicFilters ,
125126 /// If true, indicates that all rows of subsequent batches are guaranteed
126127 /// to be greater (by byte order, after row conversion) than the top K,
127128 /// which means the top K won't change and the computation can be finished early.
128129 pub ( crate ) finished : bool ,
129130}
130131
132+ #[ derive( Debug , Clone ) ]
133+ pub struct TopKDynamicFilters {
134+ /// The current *global* threshold for the dynamic filter.
135+ /// This is shared across all partitions and is updated by any of them.
136+ thresholds : Arc < RwLock < Option < Vec < ScalarValue > > > > ,
137+ /// The expression used to evaluate the dynamic filter
138+ expr : Arc < DynamicFilterPhysicalExpr > ,
139+ }
140+
141+ impl TopKDynamicFilters {
142+ /// Create a new `TopKDynamicFilters` with the given expression
143+ pub fn new ( expr : Arc < DynamicFilterPhysicalExpr > ) -> Self {
144+ Self {
145+ thresholds : Arc :: new ( RwLock :: new ( None ) ) ,
146+ expr,
147+ }
148+ }
149+
150+ pub fn expr ( & self ) -> Arc < DynamicFilterPhysicalExpr > {
151+ Arc :: clone ( & self . expr )
152+ }
153+ }
154+
131155// Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
132156const ESTIMATED_BYTES_PER_ROW : usize = 20 ;
133157
@@ -160,7 +184,7 @@ impl TopK {
160184 batch_size : usize ,
161185 runtime : Arc < RuntimeEnv > ,
162186 metrics : & ExecutionPlanMetricsSet ,
163- filter : Option < Arc < DynamicFilterPhysicalExpr > > ,
187+ filter : TopKDynamicFilters ,
164188 ) -> Result < Self > {
165189 let reservation = MemoryConsumer :: new ( format ! ( "TopK[{partition_id}]" ) )
166190 . register ( & runtime. memory_pool ) ;
@@ -214,41 +238,39 @@ impl TopK {
214238
215239 let mut selected_rows = None ;
216240
217- if let Some ( filter) = self . filter . as_ref ( ) {
218- // If a filter is provided, update it with the new rows
219- let filter = filter. current ( ) ?;
220- let filtered = filter. evaluate ( & batch) ?;
221- let num_rows = batch. num_rows ( ) ;
222- let array = filtered. into_array ( num_rows) ?;
223- let mut filter = array. as_boolean ( ) . clone ( ) ;
224- let true_count = filter. true_count ( ) ;
225- if true_count == 0 {
226- // nothing to filter, so no need to update
227- return Ok ( ( ) ) ;
241+ // Evaluate the current dynamic filter against the batch; only rows that pass it are kept below
242+ let filter = self . filter . expr . current ( ) ?;
243+ let filtered = filter. evaluate ( & batch) ?;
244+ let num_rows = batch. num_rows ( ) ;
245+ let array = filtered. into_array ( num_rows) ?;
246+ let mut filter = array. as_boolean ( ) . clone ( ) ;
247+ let true_count = filter. true_count ( ) ;
248+ if true_count == 0 {
249+ // nothing to filter, so no need to update
250+ return Ok ( ( ) ) ;
251+ }
252+ // only update the keys / rows if the filter does not match all rows
253+ if true_count < num_rows {
254+ // Indices in `set_indices` should be correct if filter contains nulls
255+ // So we prepare the filter here. Note this is also done in the `FilterBuilder`
256+ // so there is no overhead to do this here.
257+ if filter. nulls ( ) . is_some ( ) {
258+ filter = prep_null_mask_filter ( & filter) ;
228259 }
229- // only update the keys / rows if the filter does not match all rows
230- if true_count < num_rows {
231- // Indices in `set_indices` should be correct if filter contains nulls
232- // So we prepare the filter here. Note this is also done in the `FilterBuilder`
233- // so there is no overhead to do this here.
234- if filter. nulls ( ) . is_some ( ) {
235- filter = prep_null_mask_filter ( & filter) ;
236- }
237260
238- let filter_predicate = FilterBuilder :: new ( & filter) ;
239- let filter_predicate = if sort_keys. len ( ) > 1 {
240- // Optimize filter when it has multiple sort keys
241- filter_predicate. optimize ( ) . build ( )
242- } else {
243- filter_predicate. build ( )
244- } ;
245- selected_rows = Some ( filter) ;
246- sort_keys = sort_keys
247- . iter ( )
248- . map ( |key| filter_predicate. filter ( key) . map_err ( |x| x. into ( ) ) )
249- . collect :: < Result < Vec < _ > > > ( ) ?;
250- }
251- } ;
261+ let filter_predicate = FilterBuilder :: new ( & filter) ;
262+ let filter_predicate = if sort_keys. len ( ) > 1 {
263+ // Optimize filter when it has multiple sort keys
264+ filter_predicate. optimize ( ) . build ( )
265+ } else {
266+ filter_predicate. build ( )
267+ } ;
268+ selected_rows = Some ( filter) ;
269+ sort_keys = sort_keys
270+ . iter ( )
271+ . map ( |key| filter_predicate. filter ( key) . map_err ( |x| x. into ( ) ) )
272+ . collect :: < Result < Vec < _ > > > ( ) ?;
273+ }
252274 // reuse existing `Rows` to avoid reallocations
253275 let rows = & mut self . scratch_rows ;
254276 rows. clear ( ) ;
@@ -319,13 +341,88 @@ impl TopK {
319341 /// (a > 2 OR (a = 2 AND b < 3))
320342 /// ```
321343 fn update_filter ( & mut self ) -> Result < ( ) > {
322- let Some ( filter) = & self . filter else {
323- return Ok ( ( ) ) ;
324- } ;
325344 let Some ( thresholds) = self . heap . get_threshold_values ( & self . expr ) ? else {
326345 return Ok ( ( ) ) ;
327346 } ;
328347
348+ // Are the new thresholds more selective than our existing ones?
349+ let should_update = {
350+ if let Some ( current) = self . filter . thresholds . write ( ) . as_mut ( ) {
351+ assert ! ( current. len( ) == thresholds. len( ) ) ;
352+ // Check if new thresholds are more selective than current ones
353+ let mut more_selective = false ;
354+ for ( ( current_value, new_value) , sort_expr) in
355+ current. iter ( ) . zip ( thresholds. iter ( ) ) . zip ( self . expr . iter ( ) )
356+ {
357+ // Handle null cases
358+ let ( current_is_null, new_is_null) =
359+ ( current_value. is_null ( ) , new_value. is_null ( ) ) ;
360+
361+ match ( current_is_null, new_is_null) {
362+ ( true , true ) => {
363+ // Both null, continue checking next values
364+ }
365+ ( true , false ) => {
366+ // Current is null, new is not null
367+ // For nulls_first: null < non-null, so new value is less selective
368+ // For nulls_last: null > non-null, so new value is more selective
369+ more_selective = !sort_expr. options . nulls_first ;
370+ break ;
371+ }
372+ ( false , true ) => {
373+ // Current is not null, new is null
374+ // For nulls_first: non-null > null, so new value is more selective
375+ // For nulls_last: non-null < null, so new value is less selective
376+ more_selective = sort_expr. options . nulls_first ;
377+ break ;
378+ }
379+ ( false , false ) => {
380+ // Neither is null, compare values
381+ match current_value. partial_cmp ( new_value) {
382+ Some ( ordering) => {
383+ match ordering {
384+ Ordering :: Equal => {
385+ // Continue checking next values
386+ }
387+ Ordering :: Less => {
388+ // For descending sort: new > current means more selective
389+ // For ascending sort: new > current means less selective
390+ more_selective = sort_expr. options . descending ;
391+ break ;
392+ }
393+ Ordering :: Greater => {
394+ // For descending sort: new < current means less selective
395+ // For ascending sort: new < current means more selective
396+ more_selective =
397+ !sort_expr. options . descending ;
398+ break ;
399+ }
400+ }
401+ }
402+ None => {
403+ // If values can't be compared, don't update
404+ more_selective = false ;
405+ break ;
406+ }
407+ }
408+ }
409+ }
410+ }
411+ // If the new thresholds are more selective, update the current ones
412+ if more_selective {
413+ * current = thresholds. clone ( ) ;
414+ }
415+ more_selective
416+ } else {
417+ // No current thresholds, so update with the new ones
418+ true
419+ }
420+ } ;
421+
422+ if !should_update {
423+ return Ok ( ( ) ) ;
424+ }
425+
329426 // Create filter expressions for each threshold
330427 let mut filters: Vec < Arc < dyn PhysicalExpr > > =
331428 Vec :: with_capacity ( thresholds. len ( ) ) ;
@@ -405,7 +502,7 @@ impl TopK {
405502
406503 if let Some ( predicate) = dynamic_predicate {
407504 if !predicate. eq ( & lit ( true ) ) {
408- filter. update ( predicate) ?;
505+ self . filter . expr . update ( predicate) ?;
409506 }
410507 }
411508
@@ -1053,7 +1150,10 @@ mod tests {
10531150 2 ,
10541151 runtime,
10551152 & metrics,
1056- None ,
1153+ TopKDynamicFilters :: new ( Arc :: new ( DynamicFilterPhysicalExpr :: new (
1154+ vec ! [ ] ,
1155+ lit ( true ) ,
1156+ ) ) ) ,
10571157 ) ?;
10581158
10591159 // Create the first batch with two columns:
0 commit comments