-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TIR] Ramp and Broadcast lanes fixed to int32 dtype #16795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Ramp and Broadcast lanes fixed to int32 dtype #16795
Conversation
bcc0e1d to
0675ed4
Compare
|
In this particular case (deep equality), i think type do matter, so it would be great instead to fix the cases that would depend on the relaxed behavior. I know we had some i64/i32 issues, and general rule of thumb now is to try to be explicit as much as possible and that helps to reduce errors |
Oh okay, thanks for the feedback @tqchen. The cases we started seeing was that some expressions were not getting simplified properly after I realized that the simplification was not happening because lanes between broadcast and RampNode in my case had different types, so a couple other solutions I thought would apply here is to either fix the RampNode constructor to stick to some fixed dtype for lanes (something like int32/int16, since But if dtype does matter, then should we update the |
|
ah oK, i think in this case we should try to come up with a rule for lanes. In this particular case i think it was related to SVE changes, so also cc @ekalda I think having a fixed dtype probably makes sense then we handle cast for related cases (keep i32 and i64 iterators in mind |
|
Yeah I think fixing the dtype is a good idea, it would hopefully avoid this kind of problems in the future as well. Out of interest, what were the mismatching dtypes of the two compared |
Thanks @ekalda. I'll update the PR to fix the dtypes in RampNode (and perhaps the broadcast node as well). The dtypes in my case were The RampNode seems to get the int64 lanes because the all the iterators in our case is by default int64, but the broadcast seems to be inserted during the evaluation of AddNode in op.cc here |
When Ramp and Broadcast nodes are created with fixed length lanes, they're fixed to int32 dtype since DLDataType always supports only uint16 lanes.
0675ed4 to
888bbf0
Compare
|
I've updated the PR to fix the dtype of lanes as we discussed. Let me know if this looks good. |
ekalda
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thank you for the fix @quic-sanirudh! Hopefully we won't see these kind of silent simplification failures in the future now 🤞
|
Thanks for the review @ekalda |
|
@quic-sanirudh can you add a test case about i64 iterator type with Ramp? I know fixing to i32 likely will resolve issues when most programs are i32 iterator, we should make sure it is i64 compatible |
| tvm.tir.Ramp(x, 1, 4), | ||
| ), | ||
| TestCase( | ||
| tvm.tir.Broadcast(x, tvm.tir.IntImm(dtype="int64", value=4)) + tvm.tir.Ramp(0, 1, 4), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually meant that the value being broadcasted and ramp base being i64. but lanes remains i32, this would be a more common case in our setting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah okay, sorry for the confusion. I'll update the test cases to check that as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen I've updated the test case, let me know if this looks good.
|
@tvm-bot rerun |
* [TIR] Ramp and Broadcast lanes fixed to int32 dtype When Ramp and Broadcast nodes are created with fixed length lanes, they're fixed to int32 dtype since DLDataType always supports only uint16 lanes. * Add test cases for int64 type lanes * Update test case with int64 iterators
When Ramp and Broadcast nodes are created with fixed length lanes,
they're fixed to int32 dtype since DLDataType always supports only
uint16 lanes.