-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Unity] Infer struct info for relax.op.split on dynamic-sized index #16355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity] Infer struct info for relax.op.split on dynamic-sized index #16355
Conversation
81084c4 to
b15daba
Compare
b15daba to
ed7c05d
Compare
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a clear improvement and the changes to the code are straightforward.
| # Spliting a dynamic axis at specific indices is allowed. The | ||
| # algebraic form here isn't the cleanest, primarily because the | ||
| # test case doesn't know that `n` is a shape variable. When | ||
| # occurring in a relax function, `n` would be marked with | ||
| # `analyzer_.MarkGlobalNonNegValue`, which would make the shapes | ||
| # simplify to `[(2,16), (3,16), (n,16)]`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any way to test that directly? Maybe the BlockBuilder can expose its analyzer for testing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also note the typo in the first word :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and I've updated to fix the typo, and to test the simplified expressions directly. Rather than exposing the block builder's analyzer, this test case now calls bb.begin_scope. That way, failures resulting from changes to the BlockBuilderImpl::BeginScope implementation would also be caught.
| auto it = shape_var_map.find(shape_var); | ||
| if (it == shape_var_map.end()) { | ||
| shape_var_map.Set(shape_var, shape_expr); | ||
| analyzer_.MarkGlobalNonNegValue(shape_var); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be worth commenting that indicating this results in much simpler symbolic shapes in many cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, and I've added a comment on it.
Prior to this commit, the
relax.op.splitdid not provide a known shape if the split was performed over a dynamic-size axis. This commit updates the shape inference to provide correct shapes in this case. A test case is also added forCombineParallelMatmulto show the intended usage of this feature, to ensure that the split outputs from a combined matmul still have correct shape information.