-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
[V1] v1 engine + full CUDA graph support for PLaMo2 #23998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a significant and well-executed effort to enable v1 engine and full CUDA graph support for the Plamo2 model architecture. The changes are comprehensive, including updates to the model implementation, test configurations, and documentation. The refactoring of Plamo2MambaMixer
to use CustomOp
and the v1-style state management is particularly well done. I've identified one critical bug in the handling of state_indices_tensor
that could lead to an out-of-bounds error in mixed-batch scenarios. Addressing this issue should make the implementation robust.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great work - I have a few tiny comments.
Do you have any lm_eval results comparing V0 to V1? Just so we are confident re: correctness? Nevermind, I see you included it in the PR description.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - @nopperl can you please fix the DCO issue?
e857f4b
to
8060211
Compare
8060211
to
bdfbd8b
Compare
Signed-off-by: Hemmi Shinichi <[email protected]>
Signed-off-by: Hemmi Shinichi <[email protected]>
Signed-off-by: Hemmi Shinichi <[email protected]>
Signed-off-by: nopperl <[email protected]>
Signed-off-by: nopperl <[email protected]>
Signed-off-by: nopperl <[email protected]>
Signed-off-by: nopperl <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]> Signed-off-by: nopperl <[email protected]>
07e358d
to
ad8a800
Compare
@nopperl I tried to run Command
I also printed the kv_cache_spec before the error
|
@cyang49 the issue with PLaMo2.1 is fixed now! |
Signed-off-by: Hemmi Shinichi <[email protected]> Signed-off-by: nopperl <[email protected]> Co-authored-by: Hemmi Shinichi <[email protected]> Co-authored-by: Thomas Parnell <[email protected]>
Signed-off-by: Hemmi Shinichi <[email protected]> Signed-off-by: nopperl <[email protected]> Co-authored-by: Hemmi Shinichi <[email protected]> Co-authored-by: Thomas Parnell <[email protected]>
Purpose
This PR follows the great work by @heheda12345 and @tdoublep to finally enable the v1 engine and full CUDA graphs (decode only) for the Plamo2 model architecture.
It also incorporates other recent improvements to
MambaMixer2
(such as #21075 and #18218).Additionally, this PR fixes #22999 and consequently re-enables Plamo2 in the hybrid model unit tests. I also enabled Plamo2 for the full CUDA graph unit tests.
Note: To support (piecewise) CUDA graphs and
torch.compile
, thePlamo2MambaMixer
is added toCompilationConfig._attention_ops
by default. I think it's OK to do since the diff is minimal.This PR is potentially in conflict with #21467 (cc @cyang49).
Fixes #23956.
Incorporates #23520 to fix #22999.
Benchmarks
Latency
Latency improves quite a bit.
This PR (v1 engine,
CUDAGraphMode.FULL_AND_PIECEWISE
):main
(v0 engine, piecewise):Throughput
Throughput does not significantly change.
This PR (v1 engine,
CUDAGraphMode.FULL_AND_PIECEWISE
):main
(v0 engine, piecewise):Output quality
No output degradation.
This PR (v1 engine):
main
(v0 engine, piecewise):Test Result
pytest -s -v tests/models/language/generation/test_hybrid.py
passed.Essential Elements of an Effective PR Description Checklist
supported_models.md
andexamples
for a new model.