
Conversation

Contributor

@FranckQC FranckQC commented Jun 4, 2022

This PR addresses the issue described in #11423.
Here is some context:

The CSE pass had been designed to potentially allow comparisons (and commonings) of equivalent terms (like (x+y)+z and x+(y+z)), where the notion of being equivalent was customizable and no assumption was made about it. That meant that the implementation of the equivalence test function EquivalentTerms() - which at the time just called the syntactical equality test EqualTerms() - could later be replaced by a cleverer equality test.

However, having such a generic way of comparing elements meant that in the function SyntacticToSemanticComputations(), where we go from a hashtable of syntactical entities to what I called a vector of "semantical entities" (which are just canonical forms/representatives of equivalence classes of terms), the only way was to compare each pair.
That resulted in a quadratic behavior of this function, but there was no way around it: in order to merge equivalent entities into their equivalence class, we had to compare them.

This PR essentially does the following:

  • When computing the equivalence classes of terms (therefore transforming a ComputationTable (i.e. a hashtable) into a vector of equivalence classes): instead of comparing each pair of terms, it now relies on a normalization procedure to obtain a normal form for each of them.
    That transforms a small part of the algorithm from quadratic to n.log(n). However, it's difficult to see improvements in practice, in particular for average-sized programs, as that part was a "small" quadratic next to a "big" n.log(n) (finding things in a hash-table, copying it to a vector, etc.).
    It probably goes from a complexity of ~O(((n²-n)/2) + n.log(n)) to a complexity of ~O(3n + n.log(n)), so gains would only be expected for very large programs.

  • Gives the user the ability to turn the semantical comparison of terms ON/OFF. It is OFF by default (compilation is, unsurprisingly, quite a bit longer with it ON), which means that by default, the equivalence coincides with the (syntactical) equality of terms.
    As the pass was written with the possibility of doing these additional commonings (like (x+y)+z and x+(y+z)) in mind, it was a good time to plug that fully through to the Python user, who can now turn it ON if desired. But again, it is OFF by default, so no real change there.
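The pairwise vs. normal-form grouping described above can be sketched in Python (a toy illustration, not TVM's C++ implementation; `normal_form`, `group_pairwise`, and `group_by_normal_form` are hypothetical names, and the toy normal form simply sorts the operands of a sum):

```python
# Hedged sketch (not TVM's actual implementation): grouping semantically
# equivalent terms. The pairwise approach compares every pair (quadratic);
# the normal-form approach hashes each term's canonical form instead.

def normal_form(term):
    """Toy normalization: a sum is a tuple of operands; sort them."""
    return tuple(sorted(term))

def group_pairwise(terms, equivalent):
    """Quadratic: compare each new term against one representative per class."""
    classes = []  # list of (representative, members)
    for t in terms:
        for rep, members in classes:
            if equivalent(rep, t):
                members.append(t)
                break
        else:
            classes.append((t, [t]))
    return [members for _, members in classes]

def group_by_normal_form(terms):
    """Near-linear: one normalization per term, then a hash-table lookup."""
    classes = {}
    for t in terms:
        classes.setdefault(normal_form(t), []).append(t)
    return list(classes.values())

terms = [("x", "y", "z"), ("z", "y", "x"), ("x", "x")]
equiv = lambda a, b: normal_form(a) == normal_form(b)
# Both strategies produce the same classes; only the cost differs.
assert group_pairwise(terms, equiv) == group_by_normal_form(terms)
```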

To turn it ON, simply do:

```python
with tvm.transform.PassContext(config={'tir.enable_equiv_terms_in_cse_tir': True}):
    ...
```

before calling `build()`.

Many thanks!

FranckQC added 13 commits June 4, 2022 07:22
…n function, and using this normalization function to compare terms, avoiding O(n²) comparisons.
…nd the second for treating redundant expressions by decreasing order of their sizes). Instead, does only one sort, by their sizes, and when equal, by their frequencies. If the frequencies are the same too, uses the syntactical order for the deterministic aspect - as before
… it, otherwise even if we later sort it, the harm is done, as the canonical representative chosen might have been different
…The first one ensures that the canonical representatives chosen are always the same, and the second (done with a custom comparison function) that we always introduce orthogonal possibilities in the same order
…al test instead of printing the hashes and relying on the human to verify them
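The single-sort ordering mentioned in the commit messages above can be sketched as follows (a toy Python illustration with made-up data; the actual pass implements this ordering in C++):

```python
# Hedged sketch of the single-sort idea: candidates are ordered by
# decreasing size, then decreasing frequency, then a deterministic
# syntactic order as the final tie-break, replacing two separate sorts.

candidates = [
    # (expression-as-string, size, frequency) — toy data, not real pass state
    ("x + y", 3, 4),
    ("a * b", 3, 4),
    ("(x + y) * z", 5, 2),
]

ordered = sorted(
    candidates,
    key=lambda c: (-c[1], -c[2], c[0]),  # size desc, freq desc, syntactic asc
)
```

The syntactic tie-break makes the output order deterministic even when size and frequency are equal, which is what guarantees a reproducible lowering.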
@FranckQC FranckQC marked this pull request as draft June 4, 2022 13:42
Contributor Author

FranckQC commented Jun 5, 2022

The only failing item seems to be a flaky test, as it's unrelated to the changes introduced by this PR:

In the test file test_custom_datatypes.py, the function test_myfloat() gives:

UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.

I reported the issue there:
#11580

Everything else seems OK to me.
This PR is now ready to be reviewed :)

@FranckQC FranckQC marked this pull request as ready for review June 5, 2022 02:35
Contributor

@tkonolige tkonolige left a comment


@FranckQC thanks for your hard work on this PR!

@mbs-octoml could you also review?

@tqchen In #10544 you had some concerns about using arith::Analyzer in this pass. Do you still have those concerns with this PR? The analyzer is not used by default, and is only used once per expression.

Contributor Author

FranckQC commented Jun 8, 2022

Hopefully, this time everything should be resolved.
The last build and all tests from this afternoon were successful, so fingers crossed it will stay that way with the latest commit (I had forgotten the minor thing about the curly braces in the previous one from this afternoon, sorry).

To summarize, we should now have the best of both worlds with this particular implementation, since identify_equiv_terms is now known to the function SyntacticToSemanticComputations():

  • When identify_equiv_terms is false (the default), we go straight from the hashtable to the vector, without doing any unnecessary work (thank you @tkonolige for the very good point!).
  • When identify_equiv_terms is true (for people OK with a much longer compile time who want to common out as much as possible), it now does something better than before this PR, as it takes advantage of the normal-form function (which defines/implies the equivalence relation). Previously, we did not take advantage of that. And by the way, the "identify_equiv_terms == true" mode is now really usable for someone who wishes to use it (it wasn't fully plugged in before).
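A rough sketch of these two modes, assuming the table maps terms to use counts (illustrative only; `syntactic_to_semantic` is a hypothetical Python stand-in for the C++ function):

```python
# Hedged sketch of the two modes of SyntacticToSemanticComputations():
# with identify_equiv_terms off, the hash-table buckets are copied
# straight into a vector; with it on, terms are re-grouped by normal form.

def syntactic_to_semantic(table, identify_equiv_terms, normal_form):
    if not identify_equiv_terms:
        # Equivalence coincides with syntactic equality: just copy.
        return list(table.items())
    merged = {}
    for term, count in table.items():
        nf = normal_form(term)
        rep, total = merged.get(nf, (term, 0))  # keep first-seen representative
        merged[nf] = (rep, total + count)
    return list(merged.values())

# Toy terms are tuples of operands; the toy normal form sorts them.
table = {("x", "y"): 2, ("y", "x"): 1}
nf = lambda t: tuple(sorted(t))
assert syntactic_to_semantic(table, False, nf) == [(("x", "y"), 2), (("y", "x"), 1)]
assert syntactic_to_semantic(table, True, nf) == [(("x", "y"), 3)]
```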

Despite all of that, this pass still won't be cheap at compile time, for sure (especially for programs with a lot of things to common out in cascade). But I think it should be acceptable. When I wrote the pass, I focused more on trying not to miss commoning opportunities, and on the correctness of the pass, rather than on making it as cheap as possible at compile time. I guess it's often a tradeoff. I hope that's OK for most users.

Many thanks for having helped to improve the pass everyone, I appreciate it!

Franck

Contributor

@mbs-octoml mbs-octoml left a comment


LGTM

Normalization is certainly the most general approach, and has the benefit that you can see what's going on.

However, if I were going to do it, I'd build that into the hash function directly, to avoid the need to repeatedly construct new sub-terms on the off chance we have a table hit. You can use de Bruijn indices for the vars, encode op argument order only when non-commutative, and so on. Food for thought.
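One possible reading of this suggestion, as a toy Python sketch (hypothetical, not TVM code): hash variables by binding depth, and hash the arguments of commutative operators order-insensitively, so that equivalent terms collide without any new sub-terms being built.

```python
# Hedged sketch of folding normalization into the hash itself.
# Expressions are tuples: ('var', name) | ('lam', name, body) | (op, *args).

COMMUTATIVE = {"+", "*"}

def struct_hash(expr, env=()):
    kind = expr[0]
    if kind == "var":
        # de Bruijn-style: hash the binding depth, not the variable name.
        return hash(("var", env.index(expr[1])))
    if kind == "lam":
        return hash(("lam", struct_hash(expr[2], (expr[1],) + env)))
    child = [struct_hash(a, env) for a in expr[1:]]
    if kind in COMMUTATIVE:
        child.sort()  # argument order is ignored for commutative ops
    return hash((kind,) + tuple(child))

# (x + y) and (y + x) under the same binders hash identically:
e1 = ("lam", "x", ("lam", "y", ("+", ("var", "x"), ("var", "y"))))
e2 = ("lam", "x", ("lam", "y", ("+", ("var", "y"), ("var", "x"))))
assert struct_hash(e1) == struct_hash(e2)
```

As with any hash-based scheme, collisions are possible, so an equality check on the colliding terms would still be needed before actually merging them.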

Contributor

@AndrewZhaoLuo AndrewZhaoLuo left a comment


Looks like these folks have it covered B)

Contributor Author

FranckQC commented Jun 9, 2022

Let's get this one merged? :)

@tkonolige tkonolige merged commit d8678a6 into apache:main Jun 9, 2022
@tkonolige
Contributor

Thanks @FranckQC! And @mbs-octoml and @AndrewZhaoLuo for reviewing.

Kathryn-cat pushed a commit to Kathryn-cat/tvm that referenced this pull request Jun 10, 2022
…orm - avoids comparison of terms (apache#11574)


- When `tir.enable_equiv_terms_in_cse_tir` is set to ON, it uses a simple implementation of the normalization function with equivalences that uses `arith::Analyzer::Simplify`, as noted in apache#10544. Note that this is not a true normalization procedure, as it is incomplete (i.e., it is not guaranteed to converge to the normal form), but it is correct, and it works well with most properties: associativity of +, distributivity of * over +, etc.

- Clarifies and enhances the test base for the pass. In particular, it adds the tests that were written in apache#10544 but which did not make it through.

- Also adds the test ( https://github.com/AndrewZhaoLuo/TVM-Sandbox/blob/19284ddbd6bb28af61c0c2aa8bb334c5c53731a7/tir/test_inconsistent_tir_lowering.py#L1 ) demonstrating the (older) non-deterministic lowering and puts it into a proper test, as I found it useful for making sure this does not happen again. It has been copied from apache#10663 and only slightly adapted (in particular to do the comparison of hashes automatically instead of printing them and relying on a human to compare them).
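The "incomplete but correct" nature of such a simplification-based normalization can be illustrated with a toy rewriter (pure Python, not `arith::Analyzer`): rules are applied up to an iteration cap, so meaning is always preserved even if a true normal form is not always reached.

```python
# Hedged toy illustration of bounded normalization. Expressions are
# tuples (op, lhs, rhs) or atom strings. The single rule reassociates
# sums to the right: (a+b)+c -> a+(b+c). Each rule preserves meaning,
# so the result is correct even when the step cap cuts rewriting short.

RULES = [
    (lambda e: e[0] == "+" and isinstance(e[1], tuple) and e[1][0] == "+",
     lambda e: ("+", e[1][1], ("+", e[1][2], e[2]))),
]

def normalize(expr, max_steps=100):
    for _ in range(max_steps):
        changed = False
        for match, rewrite in RULES:
            if isinstance(expr, tuple) and match(expr):
                expr = rewrite(expr)
                changed = True
        if isinstance(expr, tuple):
            expr = expr[:1] + tuple(
                normalize(a, max_steps) if isinstance(a, tuple) else a
                for a in expr[1:]
            )
        if not changed:
            return expr
    return expr  # cap hit: still correct, possibly not fully normalized

# (x+y)+z and x+(y+z) reach the same form:
assert normalize(("+", ("+", "x", "y"), "z")) == normalize(("+", "x", ("+", "y", "z")))
```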