-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul #16315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul #16315
Conversation
|
what is the usecase of this pass? Seems such reordering would increase the amount of compute |
|
This pass would be part of an optimization pipeline that could be used for batched LoRA models. In this usage, the While this individual pass does increase the total amount of compute, the overall pipeline should (dependent on model size, number of loras, and batch size) improve performance by changing from three matmuls (one large matmul of the base weights and two small matmuls with the LoRA components) to two matmuls (one large matmul of |
|
I feel for LoRA maybe the best approach we need some specialized kernels. cc @yzh119 to see if you have more suggestions based on your experience |
9994a18 to
17e1ac7
Compare
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the discussion in #16313, I don't think it would be harmful to include this pass as a tool for other transformations even if we don't plan to use it in the standard compilation immediately.
|
i agree as long as it is optional |
|
Sounds good, and thank you! I should have led with this being an optional pass to be used as needed, not intended as part of a default pass. |
17e1ac7 to
1581cf2
Compare
|
Rebased onto |
If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take` selecting along the output feature dimension, it can be rearranged to `R.take(R.matmul(x, weights), indices)`.
1581cf2 to
65ae3c9
Compare
If
R.matmul(x, R.take(weights, indices))occurs, withR.takeselecting along the output feature dimension, it can be rearranged toR.take(R.matmul(x, weights), indices).