Skip to content

Commit 20ab469

Browse files
feat: Support mixed scalar-analytic expressions (#2239)
1 parent 956a5b0 commit 20ab469

File tree

7 files changed

+405
-94
lines changed

7 files changed

+405
-94
lines changed

bigframes/core/agg_expressions.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import functools
2020
import itertools
2121
import typing
22-
from typing import Callable, Mapping, TypeVar
22+
from typing import Callable, Mapping, Tuple, TypeVar
2323

2424
from bigframes import dtypes
2525
from bigframes.core import expression, window_spec
@@ -63,6 +63,10 @@ def inputs(
6363
) -> typing.Tuple[expression.Expression, ...]:
6464
...
6565

@property
def children(self) -> Tuple[expression.Expression, ...]:
    """Return this node's child expressions.

    This is an alias for ``self.inputs``: the same tuple is returned.
    """
    return self.inputs
69+
6670
@property
6771
def free_variables(self) -> typing.Tuple[str, ...]:
6872
return tuple(
@@ -73,6 +77,10 @@ def free_variables(self) -> typing.Tuple[str, ...]:
7377
def is_const(self) -> bool:
7478
return all(child.is_const for child in self.inputs)
7579

@functools.cached_property
def is_scalar_expr(self) -> bool:
    """Report whether this expression is a scalar expression.

    Always ``False`` for this expression type.
    """
    return False
83+
@abc.abstractmethod
def replace_args(self: TExpression, *arg) -> TExpression:
    """Abstract hook that concrete expression types must implement.

    Per the annotations, the result has the same type as ``self``.
    NOTE(review): presumably returns a copy of this expression with its
    arguments replaced by ``*arg`` -- confirm against the implementations.
    """
    ...
@@ -176,8 +184,13 @@ def output_type(self) -> dtypes.ExpressionType:
176184
def inputs(
177185
self,
178186
) -> typing.Tuple[expression.Expression, ...]:
187+
# TODO: Maybe make the window spec itself an expression?
179188
return (self.analytic_expr, *self.window.expressions)
180189

@property
def children(self) -> Tuple[expression.Expression, ...]:
    """Return this node's child expressions.

    This is an alias for ``self.inputs``: the same tuple is returned.
    """
    return self.inputs
193+
181194
@property
182195
def free_variables(self) -> typing.Tuple[str, ...]:
183196
return tuple(
@@ -188,12 +201,16 @@ def free_variables(self) -> typing.Tuple[str, ...]:
188201
def is_const(self) -> bool:
189202
return all(child.is_const for child in self.inputs)
190203

@functools.cached_property
def is_scalar_expr(self) -> bool:
    """Report whether this expression is a scalar expression.

    Always ``False`` for this expression type.
    """
    return False
207+
def transform_children(
    self: WindowExpression,
    t: Callable[[expression.Expression], expression.Expression],
) -> WindowExpression:
    """Build a new ``WindowExpression`` with ``t`` applied to its children.

    ``t`` is applied directly to the analytic expression itself (not to that
    expression's own children), and the window spec is rebuilt through
    ``self.window.transform_exprs(t)``.

    Args:
        t: Mapping from an expression to its replacement expression.

    Returns:
        A new ``WindowExpression`` built from the transformed parts.
    """
    return WindowExpression(
        # NOTE(review): the ignore is presumably needed because ``t`` is typed
        # as returning a general Expression -- confirm.
        t(self.analytic_expr),  # type: ignore
        self.window.transform_exprs(t),
    )
199216

bigframes/core/array_value.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,24 @@
1616
from dataclasses import dataclass
1717
import datetime
1818
import functools
19+
import itertools
1920
import typing
2021
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
2122

2223
import google.cloud.bigquery
2324
import pandas
2425
import pyarrow as pa
2526

26-
from bigframes.core import agg_expressions, bq_data
27+
from bigframes.core import (
28+
agg_expressions,
29+
bq_data,
30+
expression_factoring,
31+
join_def,
32+
local_data,
33+
)
2734
import bigframes.core.expression as ex
2835
import bigframes.core.guid
2936
import bigframes.core.identifiers as ids
30-
import bigframes.core.join_def as join_def
31-
import bigframes.core.local_data as local_data
3237
import bigframes.core.nodes as nodes
3338
from bigframes.core.ordering import OrderingExpression
3439
import bigframes.core.ordering as orderings
@@ -261,6 +266,23 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
261266
col_ids,
262267
)
263268

def compute_general_expression(self, assignments: Sequence[ex.Expression]):
    """Compute expressions on top of the current node tree.

    Each expression is paired with a freshly generated unique column id,
    split into fragments by ``expression_factoring.fragmentize_expression``,
    and the combined fragments are pushed into ``self.node`` with
    ``expression_factoring.push_into_tree``.

    Args:
        assignments: The expressions to compute, in output order.

    Returns:
        A ``(ArrayValue, target_ids)`` pair: the array value wrapping the new
        root node, and the tuple of column ids assigned to ``assignments``,
        in the same order.
    """
    named_exprs = [
        expression_factoring.NamedExpression(expr, ids.ColumnId.unique())
        for expr in assignments
    ]
    # TODO: Push this to rewrite later to go from block expression to planning form
    # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
    fragments = tuple(
        itertools.chain.from_iterable(
            expression_factoring.fragmentize_expression(expr)
            for expr in named_exprs
        )
    )
    target_ids = tuple(named_expr.name for named_expr in named_exprs)
    new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
    return (ArrayValue(new_root), target_ids)
285+
264286
def project_to_id(self, expression: ex.Expression):
265287
array_val, ids = self.compute_values(
266288
[expression],

bigframes/core/block_transforms.py

Lines changed: 65 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -399,15 +399,18 @@ def pct_change(block: blocks.Block, periods: int = 1) -> blocks.Block:
399399
window_spec = windows.unbound()
400400

401401
original_columns = block.value_columns
402-
block, shift_columns = block.multi_apply_window_op(
403-
original_columns, agg_ops.ShiftOp(periods), window_spec=window_spec
404-
)
405402
exprs = []
406-
for original_col, shifted_col in zip(original_columns, shift_columns):
407-
change_expr = ops.sub_op.as_expr(original_col, shifted_col)
408-
pct_change_expr = ops.div_op.as_expr(change_expr, shifted_col)
403+
for original_col in original_columns:
404+
shift_expr = agg_expressions.WindowExpression(
405+
agg_expressions.UnaryAggregation(
406+
agg_ops.ShiftOp(periods), ex.deref(original_col)
407+
),
408+
window_spec,
409+
)
410+
change_expr = ops.sub_op.as_expr(original_col, shift_expr)
411+
pct_change_expr = ops.div_op.as_expr(change_expr, shift_expr)
409412
exprs.append(pct_change_expr)
410-
return block.project_exprs(exprs, labels=column_labels, drop=True)
413+
return block.project_block_exprs(exprs, labels=column_labels, drop=True)
411414

412415

413416
def rank(
@@ -428,16 +431,11 @@ def rank(
428431

429432
columns = columns or tuple(col for col in block.value_columns)
430433
labels = [block.col_id_to_label[id] for id in columns]
431-
# Step 1: Calculate row numbers for each row
432-
# Identify null values to be treated according to na_option param
433-
rownum_col_ids = []
434-
nullity_col_ids = []
434+
435+
result_exprs = []
435436
for col in columns:
436-
block, nullity_col_id = block.apply_unary_op(
437-
col,
438-
ops.isnull_op,
439-
)
440-
nullity_col_ids.append(nullity_col_id)
437+
# Step 1: Calculate row numbers for each row
438+
# Identify null values to be treated according to na_option param
441439
window_ordering = (
442440
ordering.OrderingExpression(
443441
ex.deref(col),
@@ -448,87 +446,66 @@ def rank(
448446
),
449447
)
450448
# Count_op ignores nulls, so if na_option is "top" or "bottom", we instead count the nullity columns, where nulls have been mapped to bools
451-
block, rownum_id = block.apply_window_op(
452-
col if na_option == "keep" else nullity_col_id,
453-
agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op,
454-
window_spec=windows.unbound(
455-
grouping_keys=grouping_cols, ordering=window_ordering
456-
)
449+
target_expr = (
450+
ex.deref(col) if na_option == "keep" else ops.isnull_op.as_expr(col)
451+
)
452+
window_op = agg_ops.dense_rank_op if method == "dense" else agg_ops.count_op
453+
window_spec = (
454+
windows.unbound(grouping_keys=grouping_cols, ordering=window_ordering)
457455
if method == "dense"
458456
else windows.rows(
459457
end=0, ordering=window_ordering, grouping_keys=grouping_cols
460-
),
461-
skip_reproject_unsafe=(col != columns[-1]),
458+
)
459+
)
460+
result_expr: ex.Expression = agg_expressions.WindowExpression(
461+
agg_expressions.UnaryAggregation(window_op, target_expr), window_spec
462462
)
463463
if pct:
464-
block, max_id = block.apply_window_op(
465-
rownum_id, agg_ops.max_op, windows.unbound(grouping_keys=grouping_cols)
464+
result_expr = ops.div_op.as_expr(
465+
result_expr,
466+
agg_expressions.WindowExpression(
467+
agg_expressions.UnaryAggregation(agg_ops.max_op, result_expr),
468+
windows.unbound(grouping_keys=grouping_cols),
469+
),
466470
)
467-
block, rownum_id = block.project_expr(ops.div_op.as_expr(rownum_id, max_id))
468-
469-
rownum_col_ids.append(rownum_id)
470-
471-
# Step 2: Apply aggregate to groups of like input values.
472-
# This step is skipped for method=='first' or 'dense'
473-
if method in ["average", "min", "max"]:
474-
agg_op = {
475-
"average": agg_ops.mean_op,
476-
"min": agg_ops.min_op,
477-
"max": agg_ops.max_op,
478-
}[method]
479-
post_agg_rownum_col_ids = []
480-
for i in range(len(columns)):
481-
block, result_id = block.apply_window_op(
482-
rownum_col_ids[i],
483-
agg_op,
484-
window_spec=windows.unbound(grouping_keys=(columns[i], *grouping_cols)),
485-
skip_reproject_unsafe=(i < (len(columns) - 1)),
471+
# Step 2: Apply aggregate to groups of like input values.
472+
# This step is skipped for method=='first' or 'dense'
473+
if method in ["average", "min", "max"]:
474+
agg_op = {
475+
"average": agg_ops.mean_op,
476+
"min": agg_ops.min_op,
477+
"max": agg_ops.max_op,
478+
}[method]
479+
result_expr = agg_expressions.WindowExpression(
480+
agg_expressions.UnaryAggregation(agg_op, result_expr),
481+
windows.unbound(grouping_keys=(col, *grouping_cols)),
486482
)
487-
post_agg_rownum_col_ids.append(result_id)
488-
rownum_col_ids = post_agg_rownum_col_ids
489-
490-
# Pandas masks all values where any grouping column is null
491-
# Note: we use pd.NA instead of float('nan')
492-
if grouping_cols:
493-
predicate = functools.reduce(
494-
ops.and_op.as_expr,
495-
[ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
496-
)
497-
block = block.project_exprs(
498-
[
499-
ops.where_op.as_expr(
500-
ex.deref(col),
501-
predicate,
502-
ex.const(None),
503-
)
504-
for col in rownum_col_ids
505-
],
506-
labels=labels,
507-
)
508-
rownum_col_ids = list(block.value_columns[-len(rownum_col_ids) :])
509-
510-
# Step 3: post processing: mask null values and cast to float
511-
if method in ["min", "max", "first", "dense"]:
512-
# Pandas rank always produces Float64, so must cast for aggregation types that produce ints
513-
return (
514-
block.select_columns(rownum_col_ids)
515-
.multi_apply_unary_op(ops.AsTypeOp(pd.Float64Dtype()))
516-
.with_column_labels(labels)
517-
)
518-
if na_option == "keep":
519-
# For na_option "keep", null inputs must produce null outputs
520-
exprs = []
521-
for i in range(len(columns)):
522-
exprs.append(
523-
ops.where_op.as_expr(
524-
ex.const(pd.NA, dtype=pd.Float64Dtype()),
525-
nullity_col_ids[i],
526-
rownum_col_ids[i],
527-
)
483+
# Pandas masks all values where any grouping column is null
484+
# Note: we use pd.NA instead of float('nan')
485+
if grouping_cols:
486+
predicate = functools.reduce(
487+
ops.and_op.as_expr,
488+
[ops.notnull_op.as_expr(column_id) for column_id in grouping_cols],
489+
)
490+
result_expr = ops.where_op.as_expr(
491+
result_expr,
492+
predicate,
493+
ex.const(None),
528494
)
529-
return block.project_exprs(exprs, labels=labels, drop=True)
530495

531-
return block.select_columns(rownum_col_ids).with_column_labels(labels)
496+
# Step 3: post processing: mask null values and cast to float
497+
if method in ["min", "max", "first", "dense"]:
498+
# Pandas rank always produces Float64, so must cast for aggregation types that produce ints
499+
result_expr = ops.AsTypeOp(pd.Float64Dtype()).as_expr(result_expr)
500+
elif na_option == "keep":
501+
# For na_option "keep", null inputs must produce null outputs
502+
result_expr = ops.where_op.as_expr(
503+
ex.const(pd.NA, dtype=pd.Float64Dtype()),
504+
ops.isnull_op.as_expr(col),
505+
result_expr,
506+
)
507+
result_exprs.append(result_expr)
508+
return block.project_block_exprs(result_exprs, labels=labels, drop=True)
532509

533510

534511
def dropna(

bigframes/core/blocks.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,27 @@ def project_exprs(
11541154
index_labels=self._index_labels,
11551155
)
11561156

# This is a new experimental version of the project_exprs that supports mixing analytic and scalar expressions
def project_block_exprs(
    self,
    exprs: Sequence[ex.Expression],
    labels: Union[Sequence[Label], pd.Index],
    drop=False,
) -> Block:
    """Project ``exprs`` as new value columns and return the resulting block.

    The expressions are computed through
    ``self.expr.compute_general_expression``, and the resulting node tree is
    validated before the new block is built.

    Args:
        exprs: Expressions to compute.
        labels: Column labels for the computed expressions.
        drop: When true, the block's existing value columns are dropped and
            ``labels`` become the only column labels. Otherwise ``labels``
            are appended to the existing column labels.

    Returns:
        A new ``Block`` with the same index columns and index labels.
    """
    new_array, _ = self.expr.compute_general_expression(exprs)
    if drop:
        new_array = new_array.drop_columns(self.value_columns)

    new_array.node.validate_tree()

    if drop:
        new_labels = labels
    else:
        new_labels = self.column_labels.append(pd.Index(labels))
    return Block(
        new_array,
        index_columns=self.index_columns,
        column_labels=new_labels,
        index_labels=self._index_labels,
    )
1177+
11571178
def apply_window_op(
11581179
self,
11591180
column: str,

0 commit comments

Comments
 (0)