Merged

Changes from all commits (48 commits)
f93fcf4
Enable attention output FP8 fusion for V1 attention backends
gshtras Jun 17, 2025
9417465
reformat
gshtras Jun 17, 2025
f7ac2b8
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jun 17, 2025
845e40c
No longer disabling noop and fusion in V1 graph mode
gshtras Jun 18, 2025
f3eb6cb
Restrict triton attention fusion to per tensor scaling
gshtras Jun 18, 2025
a7e64d7
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jun 18, 2025
01e6130
Fix the fusion pattern to account for the output tensor to be initial…
gshtras Jun 19, 2025
61f7551
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jun 25, 2025
f9e195a
V1 unit tests
gshtras Jun 30, 2025
0c94a0d
Using the updated test backend
gshtras Jun 30, 2025
e528023
TEMP fixed tests (except output)
ProExpertProg Jun 27, 2025
a2c2bc2
Cleanup and attempt to use an inverse scale
gshtras Jul 1, 2025
4646027
Another inverted scale for the V1 split attn prefill path
gshtras Jul 1, 2025
97759d3
Using empty tensors with matching dtype
gshtras Jul 1, 2025
3e1f552
Merge remote-tracking branch 'upstream/main' into attention_fusion_v1
gshtras Jul 8, 2025
4ebb561
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jul 14, 2025
15ba786
Fix the new test. Rename parameters
gshtras Jul 14, 2025
76e7a01
Remove deprecated parameter
gshtras Jul 14, 2025
8310c81
Add quant custom op in the test
gshtras Jul 14, 2025
a394548
Rename parameter names at call sites
gshtras Jul 14, 2025
fb8c9f6
I will always press save after pressing reformat
gshtras Jul 14, 2025
2b467a4
Remove memory restrictions
gshtras Jul 14, 2025
96d201f
linter
gshtras Jul 14, 2025
b52e5f0
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jul 16, 2025
62a1451
Skip attn fusion test on CUDA. Fusion is not applicable to the CUDA F…
gshtras Jul 16, 2025
ff5e1b8
Try to force triton attention on cuda in the test
gshtras Jul 16, 2025
16078b7
Format
gshtras Jul 16, 2025
6df13c2
Remove debug leftovers
gshtras Jul 28, 2025
7430120
Try out the process per test decorator
gshtras Jul 28, 2025
3418837
Syntax
gshtras Jul 28, 2025
d148c55
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Jul 29, 2025
3bde9ad
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Aug 6, 2025
0e8b47f
Remove the print
gshtras Aug 6, 2025
ea3e55a
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Aug 6, 2025
becbd2f
Trying to figure out why the test passes on ROCm, but OOMs on CUDA
gshtras Aug 6, 2025
61c34ae
Spawn in itself doesn't seem to work on CUDA
gshtras Aug 7, 2025
8b42595
Trying to fit into the memory constraints of the CI CUDA machines
gshtras Aug 8, 2025
3f309e8
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 2, 2025
c99f236
Fixed pattern and unit tests
gshtras Sep 3, 2025
c03d469
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 3, 2025
6c46e87
Add dtype to the v0 test
gshtras Sep 4, 2025
4422fcf
Add float16 to test. Gate it on cuda-alike platforms. Refactor common…
gshtras Sep 5, 2025
d33bc75
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 5, 2025
f14f017
Less explicit model variable names
gshtras Sep 8, 2025
998421c
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 8, 2025
8a1333e
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 9, 2025
747d55c
Merge remote-tracking branch 'origin/main' into attention_fusion_v1
gshtras Sep 9, 2025
cb94169
cleanup test checks (#663)
ProExpertProg Sep 10, 2025
129 changes: 87 additions & 42 deletions tests/compile/test_fusion_attn.py
@@ -40,13 +40,12 @@
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
@@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))

llm = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
gpu_memory_utilization=0.5,
max_model_len=2048)

sampling_params = SamplingParams(temperature=0.0,
@@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(compilation_config=compile_config)
vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))

# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
llm2 = LLM(model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.9,
gpu_memory_utilization=0.5,
max_model_len=2048)

# check support
@@ -171,6 +178,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn",
)
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
Comment on lines +181 to +182

Collaborator: Why is this necessary? Where would ROCm actually do this? Because I think it might break the Blackwell FI?

Collaborator (author): Using on-cpu tensors in the reshape_and_cache kernel causes a crash. In production the default device is set to CUDA before the tensors are created, but not in the test.

Collaborator: Why not just set the default device at the start of the test then?
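A minimal sketch of the alternative raised above, assuming a plain pytest test on a CUDA-capable machine: set the default torch device once at the start of the test so tensors created afterwards (such as the attention layer's _k_scale/_v_scale) land on the GPU, as they do in production. The test name and structure below are illustrative, not part of this PR.

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Needs a GPU")
def test_scales_follow_default_device():
    # Make "cuda:0" the default device for tensor factories, mirroring the
    # production path where the device is configured before the layer and its
    # k/v scale tensors are created.
    torch.set_default_device("cuda:0")
    try:
        k_scale = torch.ones(1)  # would land on the CPU without the call above
        assert k_scale.device.type == "cuda"
    finally:
        # Restore the usual default so other tests are unaffected.
        torch.set_default_device("cpu")
```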


self.block_size = 16

@@ -188,7 +197,7 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
device=self.device,
)

def build_attn_metadata(self, batch_size: int):
def build_attn_metadata(self, batch_size: int, use_hnd: bool):
"""Initialize attention metadata."""

# Create common attn metadata
@@ -205,18 +214,26 @@ def build_attn_metadata(self, batch_size: int):
num_blocks = batch_size * max_blocks

# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
if current_platform.is_rocm():
# k/v as 1st dimension
if use_hnd:
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
else:
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
else:
# k/v as 2nd dimension
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache]

# Build attn metadata
@@ -296,28 +313,51 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
out_dtype=attn_output.dtype)


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)]
HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
TestAttentionFp8StaticQuantPatternModel)]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
HEADS = []


@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("model_name, model_class",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.parametrize("batch_size",
[7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
current_platform.is_cuda() else [_Backend.ROCM_FLASH])
Collaborator: Don't we want to test the Triton backend as well?

Collaborator (author): It is tested with the split_attention parameter. The two approaches share the same backend class.

Collaborator: I see, this actually dispatches to the Triton backend. We should clean up the attention backend selection logic on ROCm.

Collaborator: Also, I just remembered: we could test the Triton backend on CUDA as well; it would run in CI automatically, which would be nice. Could you add the Triton backend to the list of CUDA backends?

Collaborator: We can add the Triton backend here as a follow-up.
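As a rough sketch of that follow-up (not part of this diff): the CUDA parametrization could carry more than one backend. The Triton enum member name below is an assumption — substitute whatever _Backend value selects the Triton attention backend in the actual tree — and _Backend/current_platform are the names already imported in test_fusion_attn.py.

```python
import pytest

# Hypothetical backend lists; the Triton member name is a placeholder.
CUDA_BACKENDS = [_Backend.FLASHINFER, _Backend.TRITON_ATTN_VLLM_V1]
ROCM_BACKENDS = [_Backend.ROCM_FLASH]


@pytest.mark.parametrize(
    "backend",
    CUDA_BACKENDS if current_platform.is_cuda() else ROCM_BACKENDS)
def test_attention_quant_pattern_sketch(backend):
    ...
```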

@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)),
reason="Only test on SM100(Blackwell)")
@pytest.mark.skipif(current_platform.is_cuda()
and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, monkeypatch, dist_init):
backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""

monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")

device = torch.device("cuda:0")
torch.manual_seed(42)
@@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_config=ModelConfig(
model=model_name,
max_model_len=2048,
dtype=dtype,
),
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
@@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size)
batch_size, use_hnd=split_attention)

# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
@@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_fused = model_fused.to(device)

forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)

# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
@@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v)

# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
if backend == _Backend.FLASHINFER:
# With the Flashinfer backend after the 1st round of the forward
# pass, output quant scale should be loaded into the attn layer's
# _o_scale_float, the 2nd round should reuse the loaded
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None

torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)

# Check attn fusion support
quant_key = model_class.quant_key
@@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501

# Check that results are closed
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,
atol=1e-2,
rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
18 changes: 16 additions & 2 deletions vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -15,6 +15,8 @@

from .prefix_prefill import context_attention_fwd

float8_info = torch.finfo(current_platform.fp8_dtype())


@triton.jit
def cdiv_fn(x, y):
@@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
scale, # float32
k_scale, # float32
v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int
@@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool
):
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)

@@ -204,6 +209,9 @@ def kernel_paged_attention_2d(

# epilogue
acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)

output_offset = (cur_batch_in_all_start_index * output_stride_0 +
query_head_idx * output_stride_1)
@@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
output_scale=None,
# Optional tensor for sinks
sinks=None,
):
@@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks,
)

@@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
@@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale,
v_scale=v_scale,
fp8_out_scale=output_scale,
)
else:
kernel_paged_attention_2d[(
Expand All @@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
out_scale_inv=1.0 /
output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded,
@@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
)
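For reference, the new out_scale_inv / USE_FP8 epilogue added to both Triton kernels amounts to the following, written here as a plain-PyTorch sketch rather than the kernel code itself. The caller passes 1.0 / output_scale so the kernel can multiply instead of divide per element; torch.float8_e4m3fn below stands in for whatever current_platform.fp8_dtype() returns on the target platform (an assumption, since ROCm uses a different FP8 variant).

```python
import torch


def fp8_output_epilogue(acc: torch.Tensor,
                        output_scale: torch.Tensor,
                        fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    """Quantize the attention accumulator to FP8 with a pre-inverted scale."""
    info = torch.finfo(fp8_dtype)
    out_scale_inv = 1.0 / output_scale   # the PR computes this inverse on the host
    acc = acc * out_scale_inv            # scale values into the FP8 range
    acc = acc.clamp(info.min, info.max)  # clamp before the narrowing cast
    return acc.to(fp8_dtype)             # the kernel's store performs this cast
```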
14 changes: 13 additions & 1 deletion vllm/attention/ops/prefix_prefill.py
@@ -15,6 +15,7 @@

# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())


# Here's an example autotuner config for this kernel. This config does provide
@@ -43,6 +44,7 @@ def _fwd_kernel(Q,
sm_scale,
k_scale,
v_scale,
out_scale_inv,
B_Start_Loc,
B_Seqlen,
x: tl.constexpr,
@@ -82,8 +84,11 @@
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):

cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
@@ -284,6 +289,9 @@ def _fwd_kernel(Q,
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(out_ptrs,
acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
@@ -743,6 +751,7 @@ def context_attention_fwd(q,
sliding_window=None,
sm_scale=None,
skip_decode=False,
fp8_out_scale=None,
sinks=None):

q_dtype_is_f32 = q.dtype is torch.float32
@@ -793,6 +802,7 @@ def context_attention_fwd(q,

if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
@@ -870,6 +880,7 @@ def context_attention_fwd(q,
sm_scale,
k_scale,
v_scale,
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
b_start_loc,
b_seq_len,
k_cache.shape[4],
@@ -905,6 +916,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128,
BLOCK_N=64,
num_unroll_cache=4,