- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10.8k
[compile] Enable sequence parallelism matching w/o custom ops enabled #27126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: angelayi <[email protected]>
c1efc65    to
    ed10d76      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking this on! Could you just add me as a co-author on one of the commits?
| """Base helper for RMSNorm and RMSNorm + Quantization functionalization.""" | ||
| def get_first_out_wrapper(fn): | ||
| @functools.wraps(fn) | ||
| def wrapper(*args): | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work? I thought that during tracing the pattern matching tracer will think that args is a single parameter
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes! updated the test to assert the number of all_reduce/all_gather ops in the graph!
Signed-off-by: angelayi <[email protected]> Co-authored-by: Luka Govedič <[email protected]>
ed10d76    to
    5d66118      
    Compare
  
    There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cascade812 could you take a look at this please?
| 
 Sure! | 
Co-authored-by: Luka Govedič <[email protected]> Signed-off-by: angelayi <[email protected]>
Purpose
Based on #24604, modified sequence-parallelism pass to do custom op matching w/o needing to enable the custom op
Test Plan
pytest -sv tests/compile/test_sequence_parallelism.pyPerformance numbers
I did some benchmarking with the command on H100 w/o flashinfer
while varying
"pass_config": {"enable_async_tp": true, "enable_sequence_parallelism": true}vs."pass_config": {"enable_async_tp": false, "enable_sequence_parallelism": false}"custom_ops":["+quant_fp8", "+rms_norm"]vs."custom_ops":[]