@@ -399,15 +399,18 @@ def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block:
399399 window_spec = windows .unbound ()
400400
401401 original_columns = block .value_columns
402- block , shift_columns = block .multi_apply_window_op (
403- original_columns , agg_ops .ShiftOp (periods ), window_spec = window_spec
404- )
405402 exprs = []
406- for original_col , shifted_col in zip (original_columns , shift_columns ):
407- change_expr = ops .sub_op .as_expr (original_col , shifted_col )
408- pct_change_expr = ops .div_op .as_expr (change_expr , shifted_col )
403+ for original_col in original_columns :
404+ shift_expr = agg_expressions .WindowExpression (
405+ agg_expressions .UnaryAggregation (
406+ agg_ops .ShiftOp (periods ), ex .deref (original_col )
407+ ),
408+ window_spec ,
409+ )
410+ change_expr = ops .sub_op .as_expr (original_col , shift_expr )
411+ pct_change_expr = ops .div_op .as_expr (change_expr , shift_expr )
409412 exprs .append (pct_change_expr )
410- return block .project_exprs (exprs , labels = column_labels , drop = True )
413+ return block .project_block_exprs (exprs , labels = column_labels , drop = True )
411414
412415
413416def rank (
@@ -428,16 +431,11 @@ def rank(
428431
429432 columns = columns or tuple (col for col in block .value_columns )
430433 labels = [block .col_id_to_label [id ] for id in columns ]
431- # Step 1: Calculate row numbers for each row
432- # Identify null values to be treated according to na_option param
433- rownum_col_ids = []
434- nullity_col_ids = []
434+
435+ result_exprs = []
435436 for col in columns :
436- block , nullity_col_id = block .apply_unary_op (
437- col ,
438- ops .isnull_op ,
439- )
440- nullity_col_ids .append (nullity_col_id )
437+ # Step 1: Calculate row numbers for each row
438+ # Identify null values to be treated according to na_option param
441439 window_ordering = (
442440 ordering .OrderingExpression (
443441 ex .deref (col ),
@@ -448,87 +446,66 @@ def rank(
448446 ),
449447 )
450448 # Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools
451- block , rownum_id = block . apply_window_op (
452- col if na_option == "keep" else nullity_col_id ,
453- agg_ops . dense_rank_op if method == "dense" else agg_ops . count_op ,
454- window_spec = windows . unbound (
455- grouping_keys = grouping_cols , ordering = window_ordering
456- )
449+ target_expr = (
450+ ex . deref ( col ) if na_option == "keep" else ops . isnull_op . as_expr ( col )
451+ )
452+ window_op = agg_ops . dense_rank_op if method == "dense" else agg_ops . count_op
453+ window_spec = (
454+ windows . unbound ( grouping_keys = grouping_cols , ordering = window_ordering )
457455 if method == "dense"
458456 else windows .rows (
459457 end = 0 , ordering = window_ordering , grouping_keys = grouping_cols
460- ),
461- skip_reproject_unsafe = (col != columns [- 1 ]),
458+ )
459+ )
460+ result_expr : ex .Expression = agg_expressions .WindowExpression (
461+ agg_expressions .UnaryAggregation (window_op , target_expr ), window_spec
462462 )
463463 if pct :
464- block , max_id = block .apply_window_op (
465- rownum_id , agg_ops .max_op , windows .unbound (grouping_keys = grouping_cols )
464+ result_expr = ops .div_op .as_expr (
465+ result_expr ,
466+ agg_expressions .WindowExpression (
467+ agg_expressions .UnaryAggregation (agg_ops .max_op , result_expr ),
468+ windows .unbound (grouping_keys = grouping_cols ),
469+ ),
466470 )
467- block , rownum_id = block .project_expr (ops .div_op .as_expr (rownum_id , max_id ))
468-
469- rownum_col_ids .append (rownum_id )
470-
471- # Step 2: Apply aggregate to groups of like input values.
472- # This step is skipped for method=='first' or 'dense'
473- if method in ["average" , "min" , "max" ]:
474- agg_op = {
475- "average" : agg_ops .mean_op ,
476- "min" : agg_ops .min_op ,
477- "max" : agg_ops .max_op ,
478- }[method ]
479- post_agg_rownum_col_ids = []
480- for i in range (len (columns )):
481- block , result_id = block .apply_window_op (
482- rownum_col_ids [i ],
483- agg_op ,
484- window_spec = windows .unbound (grouping_keys = (columns [i ], * grouping_cols )),
485- skip_reproject_unsafe = (i < (len (columns ) - 1 )),
471+ # Step 2: Apply aggregate to groups of like input values.
472+ # This step is skipped for method=='first' or 'dense'
473+ if method in ["average" , "min" , "max" ]:
474+ agg_op = {
475+ "average" : agg_ops .mean_op ,
476+ "min" : agg_ops .min_op ,
477+ "max" : agg_ops .max_op ,
478+ }[method ]
479+ result_expr = agg_expressions .WindowExpression (
480+ agg_expressions .UnaryAggregation (agg_op , result_expr ),
481+ windows .unbound (grouping_keys = (col , * grouping_cols )),
486482 )
487- post_agg_rownum_col_ids .append (result_id )
488- rownum_col_ids = post_agg_rownum_col_ids
489-
490- # Pandas masks all values where any grouping column is null
491- # Note: we use pd.NA instead of float('nan')
492- if grouping_cols :
493- predicate = functools .reduce (
494- ops .and_op .as_expr ,
495- [ops .notnull_op .as_expr (column_id ) for column_id in grouping_cols ],
496- )
497- block = block .project_exprs (
498- [
499- ops .where_op .as_expr (
500- ex .deref (col ),
501- predicate ,
502- ex .const (None ),
503- )
504- for col in rownum_col_ids
505- ],
506- labels = labels ,
507- )
508- rownum_col_ids = list (block .value_columns [- len (rownum_col_ids ) :])
509-
510- # Step 3: post processing: mask null values and cast to float
511- if method in ["min" , "max" , "first" , "dense" ]:
512- # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
513- return (
514- block .select_columns (rownum_col_ids )
515- .multi_apply_unary_op (ops .AsTypeOp (pd .Float64Dtype ()))
516- .with_column_labels (labels )
517- )
518- if na_option == "keep" :
519- # For na_option "keep", null inputs must produce null outputs
520- exprs = []
521- for i in range (len (columns )):
522- exprs .append (
523- ops .where_op .as_expr (
524- ex .const (pd .NA , dtype = pd .Float64Dtype ()),
525- nullity_col_ids [i ],
526- rownum_col_ids [i ],
527- )
483+ # Pandas masks all values where any grouping column is null
484+ # Note: we use pd.NA instead of float('nan')
485+ if grouping_cols :
486+ predicate = functools .reduce (
487+ ops .and_op .as_expr ,
488+ [ops .notnull_op .as_expr (column_id ) for column_id in grouping_cols ],
489+ )
490+ result_expr = ops .where_op .as_expr (
491+ result_expr ,
492+ predicate ,
493+ ex .const (None ),
528494 )
529- return block .project_exprs (exprs , labels = labels , drop = True )
530495
531- return block .select_columns (rownum_col_ids ).with_column_labels (labels )
496+ # Step 3: post processing: mask null values and cast to float
497+ if method in ["min" , "max" , "first" , "dense" ]:
498+ # Pandas rank always produces Float64, so must cast for aggregation types that produce ints
499+ result_expr = ops .AsTypeOp (pd .Float64Dtype ()).as_expr (result_expr )
500+ elif na_option == "keep" :
501+ # For na_option "keep", null inputs must produce null outputs
502+ result_expr = ops .where_op .as_expr (
503+ ex .const (pd .NA , dtype = pd .Float64Dtype ()),
504+ ops .isnull_op .as_expr (col ),
505+ result_expr ,
506+ )
507+ result_exprs .append (result_expr )
508+ return block .project_block_exprs (result_exprs , labels = labels , drop = True )
532509
533510
534511def dropna (
0 commit comments