@@ -6,18 +6,19 @@
 import pytest
 import torch._dynamo
 
-from tests.compile.backend import TestBackend
+from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata)
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
                          set_current_vllm_config)
@@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
 
     # AttnFusionPass needs attention layers to be registered in config upon init
     # so we initialize it during compilation.
-    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
+    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
     backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
     llm2 = LLM(model,
                enforce_eager=True,
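
The ad-hoc lambda is replaced by `LazyInitPass`, which defers building `AttnFusionPass` until the backend first runs, i.e. after the attention layers have been registered in the config. The real wrapper lives in `tests/compile/backend.py`; the following is only a minimal sketch of what such a wrapper could look like, assuming nothing beyond what this diff shows (the `LazyInitPass(AttnFusionPass, vllm_config)` constructor and the `pass_` attribute asserted on later):

```python
# Hypothetical sketch, not the real LazyInitPass implementation:
# defer construction of the wrapped pass until the first call, so a pass
# that needs attention layers registered in the config can be built
# during compilation rather than at test setup time.
class LazyInitPassSketch:

    def __init__(self, pass_cls, vllm_config):
        self.pass_cls = pass_cls
        self.vllm_config = vllm_config
        self.pass_ = None  # the real pass, built lazily on first use

    def __call__(self, *args, **kwargs):
        if self.pass_ is None:
            self.pass_ = self.pass_cls(self.vllm_config)
        return self.pass_(*args, **kwargs)
```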
@@ -197,7 +198,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
+    def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
+            -> AttentionMetadata:
         """Initialize attention metadata."""
 
         # Create common attn metadata
@@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
 
         # Create test backend with fusion passes enabled
         noop_pass = NoOpEliminationPass(vllm_config)
-        attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
-                                                                    )
-        test_backend = TestBackend(noop_pass, attn_pass)
+        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
 
         # Compile model with fusion enabled
         model_compiled = torch.compile(model_fused,
@@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         test_backend.check_before_ops([QUANT_OPS[quant_key]],
                                       fully_replaced=True)
 
+        # access the underlying `AttnFusionPass` on the `LazyInitPass`
+        assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
+
         # Check attention ops in the graph before and after fusion
         attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
         attn_nodes_post = list(find_op_nodes(ATTN_OP,
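
Taken together, the hunks above assemble roughly as follows. This is a condensed, hedged sketch of the updated test flow, not the full test: the setup of `model_fused`, `vllm_config`, `quant_key`, and `attn_fusion_supported` happens earlier in the real test, and passing `TestBackend` via `backend=` to `torch.compile` is assumed from context rather than shown in this diff.

```python
# Condensed sketch of the updated test flow (names defined earlier in the
# real test; do not treat this as the verbatim test body).
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)  # built lazily at compile time
cleanup_pass = PostCleanupPass(vllm_config)            # newly added cleanup pass
test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

# Assumed: TestBackend is used as the torch.compile backend.
model_compiled = torch.compile(model_fused, backend=test_backend)

# After running model_compiled:
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# LazyInitPass exposes the underlying AttnFusionPass via `pass_`.
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
```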