[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends #19767
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,13 +40,12 @@ | |
| @pytest.mark.parametrize( | ||
| "model, quant_key", | ||
| [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) | ||
| @pytest.mark.parametrize( | ||
| "use_triton_fa", [True, False] if current_platform.is_rocm() else [False]) | ||
| @pytest.mark.parametrize("use_triton_fa", [True, False]) | ||
| @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") | ||
| @pytest.mark.skipif(not current_platform.is_cuda_alike(), | ||
| reason="Only test CUDA and ROCm") | ||
| def test_attention_fusion(example_prompts, monkeypatch, model: str, | ||
| quant_key: QuantKey, use_triton_fa: bool): | ||
| @pytest.mark.skipif(not current_platform.is_rocm(), | ||
| reason="V0 attn quant fusion only on ROCm") | ||
| def test_attention_fusion_v0(example_prompts, monkeypatch, model: str, | ||
| quant_key: QuantKey, use_triton_fa: bool): | ||
| # Clean Dynamo cache to avoid reusing other test cases | ||
| # (for some reason the reset at the end is not enough) | ||
| torch._dynamo.reset() | ||
|
|
@@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, | |
| backend="tests.compile.test_fusion_attn.backend_unfused", | ||
| custom_ops=["+quant_fp8"], | ||
| ) | ||
| vllm_config = VllmConfig(compilation_config=compile_config) | ||
| vllm_config = VllmConfig(compilation_config=compile_config, | ||
| model_config=ModelConfig( | ||
| model=model, | ||
| dtype=torch.bfloat16, | ||
| )) | ||
| backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) | ||
|
|
||
| llm = LLM(model, | ||
| enforce_eager=True, | ||
| compilation_config=compile_config, | ||
| gpu_memory_utilization=0.9, | ||
| gpu_memory_utilization=0.5, | ||
| max_model_len=2048) | ||
|
|
||
| sampling_params = SamplingParams(temperature=0.0, | ||
|
|
@@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, | |
| backend="tests.compile.test_fusion_attn.backend", | ||
| custom_ops=["+quant_fp8"], | ||
| ) | ||
| vllm_config = VllmConfig(compilation_config=compile_config) | ||
| vllm_config = VllmConfig(compilation_config=compile_config, | ||
| model_config=ModelConfig( | ||
| model=model, | ||
| dtype=torch.bfloat16, | ||
| )) | ||
|
|
||
| # AttnFusionPass needs attention layers to be registered in config upon init | ||
| # so we initialize it during compilation. | ||
|
|
@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, | |
| llm2 = LLM(model, | ||
| enforce_eager=True, | ||
| compilation_config=compile_config, | ||
| gpu_memory_utilization=0.9, | ||
| gpu_memory_utilization=0.5, | ||
| max_model_len=2048) | ||
|
|
||
| # check support | ||
|
|
@@ -171,6 +178,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, | |
| cache_config=vllm_config.cache_config, | ||
| prefix="model.layers.0.self_attn.attn", | ||
| ) | ||
| self.attn._k_scale = self.attn._k_scale.to(device) | ||
| self.attn._v_scale = self.attn._v_scale.to(device) | ||
|
|
||
| self.block_size = 16 | ||
|
|
||
|
|
@@ -188,7 +197,7 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, | |
| device=self.device, | ||
| ) | ||
|
|
||
| def build_attn_metadata(self, batch_size: int): | ||
| def build_attn_metadata(self, batch_size: int, use_hnd: bool): | ||
| """Initialize attention metadata.""" | ||
|
|
||
| # Create common attn metadata | ||
|
|
@@ -205,18 +214,26 @@ def build_attn_metadata(self, batch_size: int): | |
| num_blocks = batch_size * max_blocks | ||
|
|
||
| # Create dummy KV cache for FlashInfer TRTLLM | ||
| # - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] | ||
| # - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] | ||
| # Create kv_cache in HND layout and permute to NHD layout | ||
| # (later will be permuted back to HND layout in forward pass) | ||
| # - NHD: [num_blocks, block_size, num_kv_heads, head_size] | ||
| # - HND: [num_blocks, num_kv_heads, block_size, head_size] | ||
| kv_cache = torch.zeros(num_blocks, | ||
| 2, | ||
| self.num_kv_heads, | ||
| self.block_size, | ||
| self.head_size, | ||
| dtype=self.kv_cache_dtype, | ||
| device=self.device) | ||
| kv_cache = kv_cache.permute(0, 1, 3, 2, 4) | ||
| if current_platform.is_rocm(): | ||
| # k/v as 1st dimension | ||
| if use_hnd: | ||
| kv_cache = kv_cache.permute(1, 0, 2, 3, 4) | ||
| else: | ||
| kv_cache = kv_cache.permute(1, 0, 3, 2, 4) | ||
| else: | ||
| # k/v as 2nd dimension | ||
| # Create kv_cache in HND layout and permute to NHD layout | ||
| # (later will be permuted back to HND layout in forward pass) | ||
| kv_cache = kv_cache.permute(0, 1, 3, 2, 4) | ||
| self.attn.kv_cache = [kv_cache] | ||
|
|
||
| # Build attn metadata | ||
|
|
@@ -296,28 +313,51 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): | |
| out_dtype=attn_output.dtype) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) | ||
| if current_platform.is_cuda(): | ||
| MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", | ||
| TestAttentionFp8StaticQuantPatternModel), | ||
| ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", | ||
| TestAttentionNvfp4QuantPatternModel)] | ||
| HEADS = [(64, 8), (40, 8)] | ||
| elif current_platform.is_rocm(): | ||
| MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV", | ||
| TestAttentionFp8StaticQuantPatternModel)] | ||
| HEADS = [(32, 8), (40, 8)] | ||
| else: | ||
| MODELS = [] | ||
| HEADS = [] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
| @pytest.mark.parametrize("batch_size", [7, 256, 533]) | ||
| @pytest.mark.parametrize("dtype", [torch.bfloat16]) | ||
| @pytest.mark.parametrize("model_name, model_class", | ||
| [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", | ||
| TestAttentionFp8StaticQuantPatternModel), | ||
| ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", | ||
| TestAttentionNvfp4QuantPatternModel)]) | ||
| @pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) | ||
| @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") | ||
| @pytest.mark.parametrize("batch_size", | ||
| [7, 256, 533] if current_platform.is_cuda() else [8]) | ||
| @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
| @pytest.mark.parametrize("model_name, model_class", MODELS) | ||
| @pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if | ||
| current_platform.is_cuda() else [_Backend.ROCM_FLASH]) | ||
|
Collaborator
Don't we want to test the Triton backend as well?

Collaborator (Author)
It is tested with the split_attention parameter. The two approaches share the same backend class.

Collaborator
I see, this actually dispatches to the Triton backend. We should clean up the attention backend selection logic on ROCm.

Collaborator
Also, I just remembered: we could test the Triton backend on CUDA as well; it would run in CI automatically, which would be nice. Could you add the Triton backend to the list of CUDA backends?

Collaborator
We can add the Triton backend here as a follow-up.
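For reference, a minimal sketch of what that follow-up could look like, i.e. exercising the Triton attention backend on CUDA as well. This is not part of the PR; the `_Backend.TRITON_ATTN_VLLM_V1` member and the import paths are assumptions and may differ between vLLM versions.

```python
# Hypothetical follow-up sketch (not part of this PR): extend the per-platform
# backend list so the Triton attention backend is also exercised on CUDA in CI.
# Assumes _Backend exposes TRITON_ATTN_VLLM_V1; the name/import path may vary.
from vllm.platforms import current_platform
from vllm.platforms.interface import _Backend

if current_platform.is_cuda():
    BACKENDS = [_Backend.FLASHINFER, _Backend.TRITON_ATTN_VLLM_V1]
elif current_platform.is_rocm():
    # On ROCm the Triton split prefill/decode path is reached through the same
    # backend class by setting VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1, which is
    # what the split_attention test parameter toggles.
    BACKENDS = [_Backend.ROCM_FLASH]
else:
    BACKENDS = []
```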
||
| @pytest.mark.parametrize( | ||
| "split_attention", | ||
| [False, True] if current_platform.is_rocm() else [False]) | ||
| @pytest.mark.skipif(not current_platform.is_cuda_alike(), | ||
| reason="Only test ROCm or CUDA") | ||
| @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") | ||
| @pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), | ||
| reason="Only test on SM100(Blackwell)") | ||
| @pytest.mark.skipif(current_platform.is_cuda() | ||
| and not current_platform.is_device_capability((10, 0)), | ||
| reason="On CUDA only test on SM100(Blackwell)") | ||
ProExpertProg marked this conversation as resolved.
|
||
| @pytest.mark.skipif(not current_platform.is_cuda_alike(), | ||
| reason="Only test ROCm or CUDA") | ||
| def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | ||
| head_size: int, batch_size: int, | ||
| dtype: torch.dtype, model_name: str, | ||
| model_class: type[AttentionQuantPatternModel], | ||
| backend: _Backend, monkeypatch, dist_init): | ||
| backend: _Backend, split_attention: bool, | ||
| monkeypatch, dist_init): | ||
| """Test AttentionStaticQuantPattern fusion pass""" | ||
|
|
||
| monkeypatch.setenv("VLLM_USE_V1", "1") | ||
| if split_attention: | ||
| monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") | ||
|
|
||
| device = torch.device("cuda:0") | ||
| torch.manual_seed(42) | ||
|
|
@@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | |
| model_config=ModelConfig( | ||
| model=model_name, | ||
| max_model_len=2048, | ||
| dtype=dtype, | ||
| ), | ||
| scheduler_config=SchedulerConfig(max_num_seqs=1024), | ||
| compilation_config=CompilationConfig( | ||
|
|
@@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | |
|
|
||
| forward_ctx = get_forward_context() | ||
| forward_ctx.attn_metadata = model_unfused.build_attn_metadata( | ||
| batch_size) | ||
| batch_size, use_hnd=split_attention) | ||
|
|
||
| # Run model directly without compilation and fusion | ||
| result_unfused = model_unfused(q, k, v) | ||
|
|
@@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | |
| model_fused = model_fused.to(device) | ||
|
|
||
| forward_ctx = get_forward_context() | ||
| forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) | ||
| forward_ctx.attn_metadata = model_fused.build_attn_metadata( | ||
| batch_size, use_hnd=split_attention) | ||
|
|
||
| # Create test backend with fusion passes enabled | ||
| noop_pass = NoOpEliminationPass(vllm_config) | ||
|
|
@@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | |
| assert model_compiled.attn._o_scale_float is None | ||
| result_fused_1 = model_compiled(q, k, v) | ||
|
|
||
| # After the 1st round of the forward pass, output quant scale should be | ||
| # loaded into the attn layer's _o_scale_float, the 2nd round should | ||
| # reuse the loaded _o_scale_float | ||
| assert model_compiled.attn._o_scale_float is not None | ||
| result_fused_2 = model_compiled(q, k, v) | ||
| assert model_compiled.attn._o_scale_float is not None | ||
| if backend == _Backend.FLASHINFER: | ||
| # With the Flashinfer backend after the 1st round of the forward | ||
| # pass, output quant scale should be loaded into the attn layer's | ||
| # _o_scale_float, the 2nd round should reuse the loaded | ||
| # _o_scale_float | ||
| assert model_compiled.attn._o_scale_float is not None | ||
| result_fused_2 = model_compiled(q, k, v) | ||
| assert model_compiled.attn._o_scale_float is not None | ||
|
|
||
| torch.testing.assert_close(result_unfused, | ||
| result_fused_2, | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
|
|
||
| # Check attn fusion support | ||
| quant_key = model_class.quant_key | ||
|
|
@@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, | |
| assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ | ||
| "Attention should have output_block_scale after FP4 fusion" # noqa: E501 | ||
|
|
||
| # Check that results are closed | ||
| # Check that results are close | ||
| torch.testing.assert_close(result_unfused, | ||
| result_fused_1, | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
| torch.testing.assert_close(result_unfused, | ||
| result_fused_2, | ||
| atol=1e-2, | ||
| rtol=1e-2) | ||
Collaborator
Why is this necessary? Where would ROCm actually do this? Because I think it might break the Blackwell FI?

Collaborator (Author)
Using on-CPU tensors in the reshape_and_cache kernel causes a crash. In production the default device is set to CUDA before the tensors are created, but not in the test.

Collaborator
Why not just set the default device at the start of the test then?
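For context, a minimal sketch of the alternative raised in the last comment: setting the default device at the start of the test so that tensors created inside the attention layer (such as `_k_scale`/`_v_scale`) are allocated on the GPU from the start, instead of being moved afterwards. The helper below is illustrative only and assumes PyTorch >= 2.0 for `torch.set_default_device`.

```python
import torch


def _use_cuda_default_device() -> torch.device:
    """Illustrative helper (not from this PR): make newly created tensors
    default to the GPU so that scales built inside the Attention layer are
    already on-device when the reshape_and_cache kernel runs."""
    device = torch.device("cuda:0")
    torch.set_default_device(device)  # requires PyTorch >= 2.0
    torch.manual_seed(42)
    return device
```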