Skip to content

Conversation

@Lunderberg
Copy link
Contributor

Update the relax.transform.EliminateCommonSubexpr pass to handle R.match_cast bindings, where the argument of the R.match_cast has also been de-duplicated.

Update the `relax.transform.EliminateCommonSubexpr` pass to handle
`R.match_cast` bindings, where the argument of the `R.match_cast` has
also been de-duplicated.
Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The redesign of the CSE pass seems generally sound and behaves more nicely than my previous approach of overriding the top-level VisitExpr and listing exceptions. I listed a couple of questions and concerns resulting from the substantial rewrites.

* within the output struct info. As a result, it would erroneously
* de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and
* `R.match_cast(A, R.Tensor([p,q]))`, even though they define
* different symbolic variables.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it still do the wrong thing even if you set map_free_vars to false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, though for a different reason. Even if map_free_vars is false, SEqualReducer::DefEqual can override it to true for some IR nodes. This is the function that is used for MatchCastNode::struct_info, in order to support StructuralEqual in cases where two match cast nodes are defining different variables, but should compare as structural equal.

Effectively, the MatchCastNode::SEqualReduce is implemented to allow TIR variable definitions, which is the opposite of what we need for the de-duplication.

<< ". The duplicate binding of this value to " << binding->var
<< " will be replaced with a trivial binding, "
<< "and occurrences of " << binding->var << " will be replaced with " << it->second;
output_binding = VarBinding(binding->var, it->second);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we still need a binding if we will be remapping the uses of the var? Not a big deal since CanonicalizeBindings and DeadCodeElimination would clean it up.

Copy link
Contributor Author

@Lunderberg Lunderberg Feb 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to handle cases where the variable remapping would not be allowed, but the variable binding is still valid. If we the de-duplication results in remapping a non-dataflow variable to a dataflow variable, that remapping is only valid within the dataflow block. Outside of that dataflow block, the variable remaps are reset, and any remaining usages would still need to be definition.

I've added the following unit test to test this behavior.

def test_no_replacement_across_dataflow_boundary():
    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                A = R.add(x, y)
                # B has the same value as A, and so instances of B can be replaced with A.
                B = R.add(x, y)
                C = R.multiply(A, B)

                # However, B is exposed for use outside of the
                # DataflowBlock, while A is not.  Therefore, any
                # additional uses of `B` must NOT be replaced with
                # A.
                R.output(B, C)

            # In addition, because `A` is only valid within the
            # dataflow block, the `R.add(x,y)` cannot be de-duplicated
            # as another usage of `A`.
            D = R.add(x, y)
            return (B, C, D)

    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
            with R.dataflow():
                A = R.add(x, y)
                B = A
                C = R.multiply(A, A)
                R.output(B, C)

            D = R.add(x, y)
            return (B, C, D)

    verify(Before, Expected)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, writing that test made me realize a limitation in the implementation. The non-dataflow var D = R.add(x,y) can't be de-duplicated as the dataflow var A, but it still could be de-duplicated as the non-dataflow var B.

I've updated the implementation to have std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_, tracking all variables that could be de-duplicated as each other. That way, when exiting a DataflowBlock, the expr_replacements_ could still track non-dataflow variables that could be used for de-duplication.

@I.ir_module
class Expected:
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
        with R.dataflow():
            A = R.add(x, y)
            B = A
            C = R.multiply(A, A)
            R.output(B, C)

        D = B
        return (B, C, B)

This way, such a function would only require CSE -> Canonicalize, whereas before it would require CSE -> Canonicalize -> CSE.

builder_->EmitNormalized(GetRef<MatchCast>(binding));
Expr VisitExpr_(const IfNode* op) override {
Expr cond = VisitExpr(op->cond);
Expr true_branch = VisitWithCleanScope(op->true_branch);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think visiting with a clean scope is the correct approach, since there are potential correct replacements that could come from the outer scope. I think the right approach would be to cache the current state and restore it after visiting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. I've updated to have propagate bindings from before a if/else branch into the body of either branch. This functionality has three new unit tests, to validate (1) bindings before a branch are de-duped inside the branch, (2) bindings within a branch are not de-duped after the branch, and (3) bindings within a branch are not de-duped into its sibling branch.

// copy of the mutator, to avoid replacing a child-scope
// expression with a parent-scope binding, or vice versa.
if (expr_replacements_.size() || var_remap_.size()) {
return VisitWithCleanScope(GetRef<Expr>(op));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure using a clean scope in all cases is necessarily, since an inner function could capture vars from the outer scope and there might be legitimate substitutions that are possible. That said, this may not necessarily be desirable behavior since it could result in bigger closures (capturing more vars than you might want). We would also have to make sure we don't capture any DataflowVars, since that's not permitted. Any thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. From a lowering standpoint, my main concern would be to avoid introducing a captured variable. If a user has an inner function without closure variables, it could be surprising if CSE prevents it from being hoisted out of the local function.

(Selfishly, I have a few plans to simplify LambdaLift when I have the time. Currently, LambdaLift handles lifting of inner functions regardless of whether there are closure variables, which is a bit tricky to track. I'd like to split that out into a HoistClosureVariablesToParams and HoistInnerFunctions passes. The first would provide explicit arguments for any closure variables, and the second would hoist the inner function out to the IRModule. If CSE is applied in-between those two passes, I'd like to avoid having it re-introduce closure variables.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think you're right that capturing more vars might be an unwelcome surprising behavior.

Comment on lines +91 to +95
"""CSE is only applied at variable bindings
To remain consistent with the behavior of the normalizer, tuples
are kept as-is, even if they contain repeated sub-tuples.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good observation, probably a good choice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. With how much discussion there was around the normalization with respect to R.call_tir, I figured it's best to match as much as possible to the normalization.

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for addressing my concerns.

@Lunderberg Lunderberg merged commit 864fd5c into apache:main Feb 23, 2024
@Lunderberg Lunderberg deleted the transform_check_match_cast_in_cse branch February 23, 2024 14:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants