Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions bigframes/core/rewrite/select_pullup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
if not isinstance(child, nodes.SelectionNode):
return node

# case where selection must be kept in place to prevent ambiguity
if set(child.child.ids) & set(node.defined_variables):
return node

# Schema-preserving nodes
if isinstance(
node,
Expand Down Expand Up @@ -157,9 +161,78 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
return node


def pull_up_selects_under_join(node: nodes.JoinNode) -> nodes.JoinNode:
# Can in theory pull up selects here, but it is a bit dangerous, in particular for self-joins, when there are more transforms to do.
# TODO: Safely pull up selects above join
def pull_up_selects_under_join(node: nodes.JoinNode) -> nodes.BigFrameNode:
    """Lift ``SelectionNode`` inputs of a join so they sit above the join.

    When one or both join inputs are a ``SelectionNode``, the join is rebuilt
    directly over the selection's child, the join conditions are remapped
    from the selection's output ids back to the ids they were selected from,
    and a single ``SelectionNode`` over the new join restores the original
    output columns (left columns first, then right).

    The rewrite is skipped whenever the ids of the would-be join inputs
    overlap. If both inputs are selections and their children's ids overlap,
    no one-sided pull-up is attempted either.

    Args:
        node: The join to rewrite.

    Returns:
        A ``SelectionNode`` wrapping the rebuilt join, or ``node`` itself
        when nothing was pulled up.
    """
    left = node.left_child
    right = node.right_child

    if isinstance(left, nodes.SelectionNode) and isinstance(
        right, nodes.SelectionNode
    ):
        # Both inputs are selections: join their underlying children.
        if not set(left.child.ids).isdisjoint(right.child.ids):
            return node
        left_sources = {
            out_id: in_ref.id for in_ref, out_id in left.input_output_pairs
        }
        right_sources = {
            out_id: in_ref.id for in_ref, out_id in right.input_output_pairs
        }
        rewritten_conditions = tuple(
            (
                l_expr.remap_column_refs(left_sources),
                r_expr.remap_column_refs(right_sources),
            )
            for l_expr, r_expr in node.conditions
        )
        rebuilt_join = nodes.JoinNode(
            left.child,
            right.child,
            conditions=rewritten_conditions,
            type=node.type,
            propogate_order=node.propogate_order,
        )
        return nodes.SelectionNode(
            rebuilt_join,
            (*left.input_output_pairs, *right.input_output_pairs),
        )

    if isinstance(left, nodes.SelectionNode):
        # Only the left input is a selection; the right input is kept as-is
        # and its columns pass through the new selection unchanged.
        if not set(left.child.ids).isdisjoint(right.ids):
            return node
        left_sources = {
            out_id: in_ref.id for in_ref, out_id in left.input_output_pairs
        }
        rewritten_conditions = tuple(
            (l_expr.remap_column_refs(left_sources), r_expr)
            for l_expr, r_expr in node.conditions
        )
        rebuilt_join = nodes.JoinNode(
            left.child,
            right,
            conditions=rewritten_conditions,
            type=node.type,
            propogate_order=node.propogate_order,
        )
        right_passthrough = (
            nodes.AliasedRef.identity(col_id) for col_id in right.ids
        )
        return nodes.SelectionNode(
            rebuilt_join,
            (*left.input_output_pairs, *right_passthrough),
        )

    if isinstance(right, nodes.SelectionNode):
        # Only the right input is a selection; mirror image of the case above.
        if not set(right.child.ids).isdisjoint(left.ids):
            return node
        right_sources = {
            out_id: in_ref.id for in_ref, out_id in right.input_output_pairs
        }
        rewritten_conditions = tuple(
            (l_expr, r_expr.remap_column_refs(right_sources))
            for l_expr, r_expr in node.conditions
        )
        rebuilt_join = nodes.JoinNode(
            left,
            right.child,
            conditions=rewritten_conditions,
            type=node.type,
            propogate_order=node.propogate_order,
        )
        left_passthrough = (
            nodes.AliasedRef.identity(col_id) for col_id in left.ids
        )
        return nodes.SelectionNode(
            rebuilt_join,
            (*left_passthrough, *right.input_output_pairs),
        )

    # Neither input is a selection: nothing to pull up.
    return node


Expand Down