# promote blocksparse from prototype, make it faster #1734
Might be good to consider getting the changes from #1690 in here; since you are making a major API change, it will save you a migration in the future.

Ah yes, that's a good idea. I'll open a subsequent PR and update all of the sparsity APIs.
This PR promotes block sparsity from prototype in torchao. Chiefly, it ports over the triton addmm blocksparse kernels from core and makes several performance improvements to them.

All of the numbers reported below are for an H100 with blocksize=64 and sparsity_level=0.9. The dense baseline is 134 tok/s.

1) Adds padding support to the triton kernel for dense matrices with dimension < 16, like those we run into during decoding. (214 -> 218 tok/s)
2) Changes the default [num_stages](triton-lang/triton#512) parameter from 1 to 4. This has a large effect on performance; the default kernel autotuning either does not modify this parameter or deems it unimportant for some reason. (218 -> 263 tok/s)
3) Adds an env var, BSR_AUTOTUNE, that users can set if they want kernel autotuning on top of the default parameters. (263 -> 266 tok/s) This seems to matter more for bs=n compute-bound workloads, where I see a reduction from 0.3855 s to 0.3745 s on bs=8192 prefill (roughly 3%).

So in total we are seeing a **1.985x** speedup 🚀

I've also updated the documentation to not reference prototype; planning on updating the diagram in a subsequent PR.

### Testing

I added a new test case for the padded inputs and moved the test file out of prototype.

```
python test/sparsity/test_sparse_api.py
```
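The kernels themselves live in triton, but the core idea behind block-sparse addmm can be sketched in plain Python: store only the nonzero blocks of the weight matrix and skip the zero blocks entirely during the multiply. The helper names below (`dense_to_blocks`, `block_sparse_mv`) are illustrative only, not torchao APIs.

```python
# Pure-Python sketch of the BSR (block sparse row) idea: keep only nonzero
# bs x bs blocks, and visit only those blocks during the matmul.
# Illustrative names, not the torchao/triton implementation.

def dense_to_blocks(A, bs):
    """Split an n x n matrix (list of lists) into bs x bs blocks,
    keeping only the blocks that contain at least one nonzero entry."""
    n = len(A)
    blocks = {}
    for bi in range(0, n, bs):
        for bj in range(0, n, bs):
            blk = [row[bj:bj + bs] for row in A[bi:bi + bs]]
            if any(v != 0 for r in blk for v in r):
                blocks[(bi // bs, bj // bs)] = blk
    return blocks

def block_sparse_mv(blocks, n, bs, x):
    """Compute y = A @ x, visiting only the stored (nonzero) blocks."""
    y = [0.0] * n
    for (bi, bj), blk in blocks.items():
        for r in range(bs):
            acc = 0.0
            for c in range(bs):
                acc += blk[r][c] * x[bj * bs + c]
            y[bi * bs + r] += acc
    return y

# At sparsity_level=0.9 roughly 90% of the blocks are zero, so ~90% of the
# inner-loop work disappears; the triton kernel applies the same skip to
# GPU tiles with blocksize=64.
A = [
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 5, 6],
    [0, 0, 7, 8],
]
blocks = dense_to_blocks(A, bs=2)
print(sorted(blocks))                              # only 2 of 4 blocks stored
print(block_sparse_mv(blocks, 4, 2, [1.0, 1.0, 1.0, 1.0]))
```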
### Benchmarking
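For reference, the headline speedup follows directly from the throughput figures quoted in the description (134 tok/s dense baseline, 266 tok/s with all three changes applied):

```python
# Per-step decode throughput from the PR description
# (tok/s on an H100, blocksize=64, sparsity_level=0.9).
dense_baseline = 134
steps = {
    "blocksparse + padding": 218,
    "num_stages=4": 263,
    "BSR_AUTOTUNE": 266,
}
speedup = steps["BSR_AUTOTUNE"] / dense_baseline
print(f"{speedup:.3f}x")  # 1.985x
```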