 from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import (CompilationConfig, CompilationLevel, ModelConfig,
-                         PassConfig, VllmConfig)
+from vllm.config import (
+    CompilationConfig,
+    CompilationLevel,
+    ModelConfig,
+    PassConfig,
+    VllmConfig,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, QuantKey, ScaleDesc)
+    GroupShape,
+    QuantKey,
+    ScaleDesc,
+)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, cutlass_fp8_supported, maybe_create_device_identity)
+    Fp8LinearOp,
+    cutlass_fp8_supported,
+    maybe_create_device_identity,
+)
 from vllm.platforms import current_platform

 from ..utils import override_cutlass_fp8_supported

@@ -24,9 +35,15 @@


 class TestModel(torch.nn.Module):
-
-    def __init__(self, hidden_size: int, eps: float, static: bool,
-                 cuda_force_torch: bool, *args, **kwargs):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float,
+        static: bool,
+        cuda_force_torch: bool,
+        *args,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.cuda_force_torch = cuda_force_torch
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
@@ -57,30 +74,27 @@ def forward(self, x):
         x = resid = torch.relu(x)
         y = self.norm[0](x)

-        x2 = self.fp8_linear.apply(y,
-                                   self.w[0],
-                                   self.wscale[0],
-                                   input_scale=self.scale[0])
+        x2 = self.fp8_linear.apply(
+            y, self.w[0], self.wscale[0], input_scale=self.scale[0]
+        )
         # make sure resid is used for the replacement to work
         y2, resid = self.norm[1](x2, resid)

-        x3 = self.fp8_linear.apply(y2,
-                                   self.w[1],
-                                   self.wscale[1],
-                                   input_scale=self.scale[1])
+        x3 = self.fp8_linear.apply(
+            y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
+        )

         y3, resid = self.norm[2](x3, resid)  # use resid here

-        x4 = self.fp8_linear.apply(y3,
-                                   self.w[2],
-                                   self.wscale[2],
-                                   input_scale=self.scale[2])
+        x4 = self.fp8_linear.apply(
+            y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
+        )

         y4, resid = self.norm[3](x4, resid)  # use resid here
         return y4


-@pytest.mark.parametrize("dtype", [torch.float16])  #, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.float16])  # , torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@@ -89,13 +103,22 @@ def forward(self, x):
 @pytest.mark.parametrize("enable_quant_fp8", [True, False])
 # cuda_force_torch is used to test the torch code path on platforms where
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize("cuda_force_torch",
-                         [True, False] if cutlass_fp8_supported() else [True])
-@pytest.mark.skipif(not current_platform.is_cuda_alike(),
-                    reason="Only test on CUDA and ROCm")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
-                              enable_rms_norm, enable_quant_fp8,
-                              cuda_force_torch):
+@pytest.mark.parametrize(
+    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
+)
+@pytest.mark.skipif(
+    not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
+)
+def test_fusion_rmsnorm_quant(
+    dtype,
+    hidden_size,
+    num_tokens,
+    eps,
+    static,
+    enable_rms_norm,
+    enable_quant_fp8,
+    cuda_force_torch,
+):
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
     torch.manual_seed(1)
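
Note: the `static` parametrization above controls whether `Fp8LinearOp.apply` receives a precomputed `input_scale` (static per-tensor quantization) or derives the scale from the activations at runtime. A minimal standalone sketch of the static quantization step in plain PyTorch (illustrative only, not the vLLM kernel; `static_fp8_quant` is a hypothetical helper):

```python
import torch


def static_fp8_quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static scheme: the scale is precomputed (cf. input_scale=self.scale[i]);
    # a dynamic scheme would instead compute it from x.abs().max() on the fly.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)


x = torch.randn(257, 64, dtype=torch.float16)  # num_tokens x hidden_size, as in the test
scale = (x.abs().max() / torch.finfo(torch.float8_e4m3fn).max).float()
x_q = static_fp8_quant(x, scale)
assert x_q.dtype == torch.float8_e4m3fn
```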