1717
1818//! Logical Expressions: [`Expr`]
1919
20- use std:: collections:: HashSet ;
20+ use std:: collections:: { HashMap , HashSet } ;
2121use std:: fmt:: { self , Display , Formatter , Write } ;
2222use std:: hash:: { Hash , Hasher } ;
2323use std:: mem;
@@ -1380,7 +1380,7 @@ impl Expr {
13801380 /// // refs contains "a" and "b"
13811381 /// assert_eq!(refs.len(), 2);
13821382 /// assert!(refs.contains(&Column::new_unqualified("a")));
1383- /// assert!(refs.contains(&Column::new_unqualified("b")));
1383+ /// assert!(refs.contains(&Column::new_unqualified("b")));
13841384 /// ```
13851385 pub fn column_refs ( & self ) -> HashSet < & Column > {
13861386 let mut using_columns = HashSet :: new ( ) ;
@@ -1401,6 +1401,41 @@ impl Expr {
14011401 . expect ( "traversal is infallable" ) ;
14021402 }
14031403
1404+ /// Return all references to columns and their occurrence counts in the expression.
1405+ ///
1406+ /// # Example
1407+ /// ```
1408+ /// # use std::collections::HashMap;
1409+ /// # use datafusion_common::Column;
1410+ /// # use datafusion_expr::col;
1411+ /// // For an expression `a + (b * a)`
1412+ /// let expr = col("a") + (col("b") * col("a"));
1413+ /// let mut refs = expr.column_refs_counts();
1414+ /// // refs contains "a" and "b"
1415+ /// assert_eq!(refs.len(), 2);
1416+ /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2);
1417+ /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1);
1418+ /// ```
1419+ pub fn column_refs_counts ( & self ) -> HashMap < & Column , usize > {
1420+ let mut map = HashMap :: new ( ) ;
1421+ self . add_column_ref_counts ( & mut map) ;
1422+ map
1423+ }
1424+
1425+ /// Adds references to all columns and their occurrence counts in the expression to
1426+ /// the map.
1427+ ///
1428+ /// See [`Self::column_refs`] for details
1429+ pub fn add_column_ref_counts < ' a > ( & ' a self , map : & mut HashMap < & ' a Column , usize > ) {
1430+ self . apply ( |expr| {
1431+ if let Expr :: Column ( col) = expr {
1432+ * map. entry ( col) . or_default ( ) += 1 ;
1433+ }
1434+ Ok ( TreeNodeRecursion :: Continue )
1435+ } )
1436+ . expect ( "traversal is infallable" ) ;
1437+ }
1438+
14041439 /// Returns true if there are any column references in this Expr
14051440 pub fn any_column_refs ( & self ) -> bool {
14061441 self . exists ( |expr| Ok ( matches ! ( expr, Expr :: Column ( _) ) ) )
0 commit comments