[TIR] CSE pass : Restrict the equivalence to be decided by a normal form - avoids comparison of terms #11574
Conversation
…n function, and using this normalization function to compare terms, avoiding O(n²) comparisons.
…emantic extension for CSE Pass apache#10544
…nd the second for treating redundant expressions by decreasing order of their sizes). Instead, does only one sort, by their sizes, and when equal, by their frequencies. If the frequencies are the same too, uses the syntactical order for the deterministic aspect - as before
…quivalence relation, which is customizable
… it, otherwise even if we later sort it the harm is done, as the canonical representative chosen might have been different
…The first one ensures that the canonical representatives chosen are always the same, and the second (done with a custom comparison function) ensures that we always introduce orthogonal possibilities in the same order
…world (Python side)
…al test instead of printing the hashes and relying on the human to verify them
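The single-sort strategy mentioned in the commit messages above (ordering candidates by size, then by frequency, then by syntactic form as a deterministic tie-breaker) can be sketched roughly as follows. The function name and the triple representation are illustrative, not the pass's actual data structures.

```python
# Sketch of the single-sort strategy from the commits above: candidates for
# commoning are ordered by term size (largest first), then by use count,
# then by their syntactic form as a deterministic tie-breaker, so the result
# never depends on the iteration order of a hash table.
# All names here are illustrative, not TVM's actual identifiers.

def commoning_order(candidates):
    """Sort (term, size, frequency) triples once, deterministically."""
    # Negative keys give decreasing size and frequency; the string form
    # of the term breaks any remaining ties.
    return sorted(candidates, key=lambda c: (-c[1], -c[2], str(c[0])))

cands = [("x + y", 3, 2), ("x * y", 3, 5), ("(x + y) * z", 5, 2)]
ordered = commoning_order(cands)
```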
The only failing item seems to be a flaky test, as it's unrelated to the changes introduced by this PR: in the test file test_custom_datatypes.py, the function test_myfloat() gives: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead. I reported the issue there: Everything else seems OK to me.
tkonolige
left a comment
@FranckQC thanks for your hard work on this PR!
@mbs-octoml could you also review?
@tqchen In #10544 you had some concerns about using arith::Analyzer in this pass. Do you still have those concerns with this PR? The analyzer is not used by default and is only being used once for each expression.
…ify_equiv_terms is false, as suggested by the user https://github.com/tkonolige
Hopefully, this time everything should be resolved. To summarize, we should now have the best of both worlds with this particular implementation.
Despite all of that, this pass still won't be cheap at compile time, for sure (especially for programs with a lot of things to common out in cascade). But I think it should be acceptable. When I wrote the pass, I focused more on trying not to miss opportunities for commonings, and on the correctness of the pass, rather than on making the pass as cheap as possible at compile time. I guess it's often a tradeoff. I hope that's OK for most users. Many thanks for having helped to improve the pass, everyone, I appreciate it! Franck
mbs-octoml
left a comment
LGTM
Normalization is certainly the most general approach and has the benefit that you can see what's going on.
However, if I were going to do it, I'd build that into the hash function directly to avoid the need to repeatedly construct new sub-terms on the off chance we have a table hit. You can use de Bruijn indices for the vars, encode op argument order only when non-commutative, and so on. Food for thought.
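A toy sketch of that suggestion, assuming a hypothetical tuple-based term representation (not TVM's TIR): bound variables hash by their de Bruijn index, and commutative operators hash their operands order-insensitively, so alpha-equivalent and commuted terms collide in the table without any sub-term being rebuilt.

```python
# Toy illustration of folding normalization into the hash itself instead of
# constructing normalized sub-terms. Bound variables are hashed by de Bruijn
# index (distance to their binder), and the operands of commutative operators
# are hashed order-insensitively. The term representation is hypothetical.

COMMUTATIVE = {"add", "mul"}

def norm_hash(term, binders=()):
    """term is ('var', name) | ('lambda', name, body) | (op, lhs, rhs)."""
    kind = term[0]
    if kind == "var":
        # de Bruijn index, so the variable's spelling never affects the hash
        return hash(("var", binders[::-1].index(term[1])))
    if kind == "lambda":
        return hash(("lambda", norm_hash(term[2], binders + (term[1],))))
    h1 = norm_hash(term[1], binders)
    h2 = norm_hash(term[2], binders)
    if kind in COMMUTATIVE:
        # encode argument order only when the operator is non-commutative
        h1, h2 = sorted((h1, h2))
    return hash((kind, h1, h2))

# Alpha-equivalent, commuted terms get the same hash:
t1 = ("lambda", "x", ("lambda", "y", ("add", ("var", "x"), ("var", "y"))))
t2 = ("lambda", "a", ("lambda", "b", ("add", ("var", "b"), ("var", "a"))))
```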
AndrewZhaoLuo
left a comment
Looks like these folks have it covered B)
spelling
Let's get this one merged? :)
Thanks @FranckQC! And @mbs-octoml and @AndrewZhaoLuo for reviewing.
This PR addresses the issue described in #11423 .
Here is some context:
The CSE pass had been designed to potentially allow comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where the notion of being equivalent was customizable, and no assumption was made about it. That means that the implementation of the equivalence test function `EquivalentTerms()` - which was at that moment just calling the syntactical equality test `EqualTerms()` - could be replaced later by a cleverer equality test.
However, having such a generic way of comparing elements meant that in the function `SyntacticToSemanticComputations()`, where we were going from a hashtable of syntactical entities to what I called a vector of "semantical entities" (which are just canonical forms/representatives of classes of equivalence of terms), the only way was to compare each pair. That resulted in a quadratic behavior of this function, but there was no way around it: in order to merge equivalent entities into their class of equivalence, we had to compare them.
This PR essentially does the following:
When computing the classes of equivalence of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of classes of equivalence): instead of comparing each pair of terms, it relies on a normalization procedure to obtain a normal form for each of them.
That transforms a small part of the algorithm from quadratic to n.log(n). However, it's difficult to see improvements in practice, in particular for average-sized programs, as that part was a "small" quadratic next to a "big" n.log(n) (finding things in a hash table, copying it to a vector, etc.).
It was probably going from a complexity of ~O(((n²-n)/2) + n.log(n)) to a complexity of ~O(3n + n.log(n)), so potential gains would only be expected for very large programs.
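The shape of that change can be illustrated with a minimal sketch: a toy `normalize` (handling only reassociation and commutation of `+`) maps each term to a normal form once, and a dictionary keyed on normal forms replaces the pairwise comparisons. The names and the tuple-based term encoding are hypothetical; the real pass normalizes full TIR expressions.

```python
# Sketch of the change described above: instead of comparing every pair of
# terms to merge equivalent ones (O(n^2) comparisons), each term is mapped
# to a normal form once, and equivalent terms land in the same bucket of a
# dictionary. The toy normalizer only handles '+', purely for illustration.

def normalize(expr):
    """Flatten nested ('+', a, b) tuples and sort the leaves, so
    (x+y)+z and x+(y+z) share one normal form."""
    if isinstance(expr, tuple) and expr[0] == "+":
        return tuple(sorted(normalize(expr[1]) + normalize(expr[2])))
    return (expr,)

def group_by_equivalence(terms):
    """One pass over the terms; each bucket is a class of equivalence."""
    classes = {}
    for t in terms:
        classes.setdefault(normalize(t), []).append(t)
    return list(classes.values())

terms = [
    ("+", ("+", "x", "y"), "z"),  # (x+y)+z
    ("+", "x", ("+", "y", "z")),  # x+(y+z): same class as above
    ("+", "x", "y"),              # x+y: its own class
]
groups = group_by_equivalence(terms)
```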
Completely gives the user the possibility to turn the semantical comparison of terms ON or OFF. It is turned OFF by default (as compilation takes noticeably longer with it ON, unsurprisingly), which means that by default the equivalence coincides with the (syntactical) equality of terms.
As the pass was written with the possibility to do these additional commonings (like (x+y)+z and x+(y+z)), it was a good time to fully plug that in, up to the Python user, who can now turn it ON if they want to. But again, it is OFF by default, so no real change there.
To run it ON, simply do:
`with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir': True}):` before calling `build()`.
When this boolean is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify`, as noted by @yuanfz98 on his PR [TIR] Semantic extension for CSE Pass #10544. Note that this is not a real normalization procedure, as it is incomplete (i.e., it is not guaranteed to converge to the normal form), but it is correct, and it works well with most properties: associativity of +, distributivity of * on +, etc.
Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in [TIR] Semantic extension for CSE Pass #10544 but which did not make it through.
Also adds the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) from @AndrewZhaoLuo demonstrating the (older) non-deterministic lowering and puts it into a proper test, as I found it useful for making sure that this does not happen again. It has been copied from [TIR] CSE-TIR Pass - More deterministic behavior #10663 and only slightly adapted (in particular for doing the comparison of hashes automatically instead of printing them and relying on a human to compare them).
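The automated hash comparison described above can be sketched like this; `lower_program` is a hypothetical stand-in for whatever produces the lowered module's text, and the hashing here is a generic digest rather than TVM's structural hash.

```python
# Sketch of turning a "print the hashes and eyeball them" check into an
# automatic assertion: lower the same program twice and require identical
# fingerprints. lower_program() is a hypothetical callable returning the
# textual form of the lowered module.

import hashlib

def fingerprint(text):
    """Stable digest of a lowered module's textual form."""
    return hashlib.sha256(text.encode()).hexdigest()

def check_deterministic_lowering(lower_program):
    first = fingerprint(lower_program())
    second = fingerprint(lower_program())
    assert first == second, "lowering is not deterministic"
```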
Many thanks!