-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr #16599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Transform] De-duplicate MatchCast nodes in EliminateCommonSubexpr #16599
Conversation
Update the `relax.transform.EliminateCommonSubexpr` pass to handle `R.match_cast` bindings, where the argument of the `R.match_cast` has also been de-duplicated.
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The redesign of the CSE pass seems generally sound and behaves more nicely than my previous approach of overriding the top-level VisitExpr and listing exceptions. I listed a couple of questions and concerns resulting from the substantial rewrites.
| * within the output struct info. As a result, it would erroneously | ||
| * de-duplicate `R.match_cast(A, R.Tensor([m,n]))` and | ||
| * `R.match_cast(A, R.Tensor([p,q]))`, even though they define | ||
| * different symbolic variables. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it still do the wrong thing even if you set map_free_vars to false?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe so, though for a different reason. Even if map_free_vars is false, SEqualReducer::DefEqual can override it to true for some IR nodes. This is the function that is used for MatchCastNode::struct_info, in order to support StructuralEqual in cases where two match cast nodes are defining different variables, but should compare as structural equal.
Effectively, the MatchCastNode::SEqualReduce is implemented to allow TIR variable definitions, which is the opposite of what we need for the de-duplication.
| << ". The duplicate binding of this value to " << binding->var | ||
| << " will be replaced with a trivial binding, " | ||
| << "and occurrences of " << binding->var << " will be replaced with " << it->second; | ||
| output_binding = VarBinding(binding->var, it->second); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we still need a binding if we will be remapping the uses of the var? Not a big deal since CanonicalizeBindings and DeadCodeElimination would clean it up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to handle cases where the variable remapping would not be allowed, but the variable binding is still valid. If we the de-duplication results in remapping a non-dataflow variable to a dataflow variable, that remapping is only valid within the dataflow block. Outside of that dataflow block, the variable remaps are reset, and any remaining usages would still need to be definition.
I've added the following unit test to test this behavior.
def test_no_replacement_across_dataflow_boundary():
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
with R.dataflow():
A = R.add(x, y)
# B has the same value as A, and so instances of B can be replaced with A.
B = R.add(x, y)
C = R.multiply(A, B)
# However, B is exposed for use outside of the
# DataflowBlock, while A is not. Therefore, any
# additional uses of `B` must NOT be replaced with
# A.
R.output(B, C)
# In addition, because `A` is only valid within the
# dataflow block, the `R.add(x,y)` cannot be de-duplicated
# as another usage of `A`.
D = R.add(x, y)
return (B, C, D)
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
with R.dataflow():
A = R.add(x, y)
B = A
C = R.multiply(A, A)
R.output(B, C)
D = R.add(x, y)
return (B, C, D)
verify(Before, Expected)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That said, writing that test made me realize a limitation in the implementation. The non-dataflow var D = R.add(x,y) can't be de-duplicated as the dataflow var A, but it still could be de-duplicated as the non-dataflow var B.
I've updated the implementation to have std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_, tracking all variables that could be de-duplicated as each other. That way, when exiting a DataflowBlock, the expr_replacements_ could still track non-dataflow variables that could be used for de-duplication.
@I.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
with R.dataflow():
A = R.add(x, y)
B = A
C = R.multiply(A, A)
R.output(B, C)
D = B
return (B, C, B)This way, such a function would only require CSE -> Canonicalize, whereas before it would require CSE -> Canonicalize -> CSE.
| builder_->EmitNormalized(GetRef<MatchCast>(binding)); | ||
| Expr VisitExpr_(const IfNode* op) override { | ||
| Expr cond = VisitExpr(op->cond); | ||
| Expr true_branch = VisitWithCleanScope(op->true_branch); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think visiting with a clean scope is the correct approach, since there are potential correct replacements that could come from the outer scope. I think the right approach would be to cache the current state and restore it after visiting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call. I've updated to have propagate bindings from before a if/else branch into the body of either branch. This functionality has three new unit tests, to validate (1) bindings before a branch are de-duped inside the branch, (2) bindings within a branch are not de-duped after the branch, and (3) bindings within a branch are not de-duped into its sibling branch.
| // copy of the mutator, to avoid replacing a child-scope | ||
| // expression with a parent-scope binding, or vice versa. | ||
| if (expr_replacements_.size() || var_remap_.size()) { | ||
| return VisitWithCleanScope(GetRef<Expr>(op)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure using a clean scope in all cases is necessarily, since an inner function could capture vars from the outer scope and there might be legitimate substitutions that are possible. That said, this may not necessarily be desirable behavior since it could result in bigger closures (capturing more vars than you might want). We would also have to make sure we don't capture any DataflowVars, since that's not permitted. Any thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm. From a lowering standpoint, my main concern would be to avoid introducing a captured variable. If a user has an inner function without closure variables, it could be surprising if CSE prevents it from being hoisted out of the local function.
(Selfishly, I have a few plans to simplify LambdaLift when I have the time. Currently, LambdaLift handles lifting of inner functions regardless of whether there are closure variables, which is a bit tricky to track. I'd like to split that out into a HoistClosureVariablesToParams and HoistInnerFunctions passes. The first would provide explicit arguments for any closure variables, and the second would hoist the inner function out to the IRModule. If CSE is applied in-between those two passes, I'd like to avoid having it re-introduce closure variables.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think you're right that capturing more vars might be an unwelcome surprising behavior.
| """CSE is only applied at variable bindings | ||
| To remain consistent with the behavior of the normalizer, tuples | ||
| are kept as-is, even if they contain repeated sub-tuples. | ||
| """ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good observation, probably a good choice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. With how much discussion there was around the normalization with respect to R.call_tir, I figured it's best to match as much as possible to the normalization.
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for addressing my concerns.
Update the
relax.transform.EliminateCommonSubexprpass to handleR.match_castbindings, where the argument of theR.match_casthas also been de-duplicated.