 import torch
 
 import vllm.envs as envs
-from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
+TEST_FP8 = current_platform.supports_fp8()
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, hidden_size: int = 128):
+        super().__init__()
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if TEST_FP8:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if TEST_FP8:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.silu_and_mul_quant.default]
+        else:
+            return [torch.ops._C.silu_and_mul.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestFusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, intermediate_size=32):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+
+        self.gate_proj = torch.nn.Parameter(
+            torch.empty((intermediate_size, hidden_size), dtype=dtype))
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        self.norm.weight = torch.nn.Parameter(
+            torch.ones(intermediate_size, dtype=dtype))
+
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        if TEST_FP8:
+            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
+
+            self.scale = torch.rand(1, dtype=torch.float32)
+            self.w = torch.rand(hidden_size,
+                                intermediate_size).to(dtype=FP8_DTYPE).t()
+            self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(mm, residual)
+
+        if TEST_FP8:
+            # scaled_mm with static input quantization
+            fp8_linear_result = self.fp8_linear.apply(
+                norm_output,
+                self.w,
+                self.wscale,
+                input_scale=self.scale.to(norm_output.device),
+            )
+
+            return fp8_linear_result, residual_output
+
+        else:
+            return norm_output, residual_output
+
+    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                    dtype=dtype)
+        residual = torch.randn((batch_size * seq_len, hidden_size),
+                               dtype=dtype)
+        return (hidden_states, residual)
 
-RMS_OP = torch.ops._C.rms_norm.default
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
+        else:
+            return [torch.ops._C.fused_add_rms_norm.default]
 
-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
+    def ops_not_in_model(self):
+        return []
 
-SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 
-SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+class TestRotaryEmbedding(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 rotary_dim=None,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.rotary_dim = rotary_dim or head_dim
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, q, k):
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+        return q_rotated, k_rotated
+
+    def example_inputs(self, num_tokens=32, head_dim=64):
+        dtype = torch.float16
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        q = torch.randn(num_tokens, head_dim, dtype=dtype)
+        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        return (positions, q, k)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 num_heads=4,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.hidden_size = head_dim * num_heads
+
+        self.qkv_proj = torch.nn.Linear(self.hidden_size,
+                                        self.hidden_size * 3,
+                                        bias=False,
+                                        dtype=torch.float16)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, hidden_states):
+        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
+        # -> slice_scatter -> split_with_sizes
+
+        qkv = self.qkv_proj(hidden_states)
+        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
+        q, k, v = torch.split(qkv, split_sizes, dim=-1)
+
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+
+        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
+        return qkv_updated
+
+    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
+        dtype = torch.float16
+        hidden_size = head_dim * num_heads
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        return (positions, hidden_states)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return [torch.ops.aten.slice_scatter.default]
+
+
+MODELS = [
+    TestSiluMul,
+    TestFusedAddRMSNorm,
+    TestRotaryEmbedding,
+    TestRotaryEmbeddingSliceScatter,
 ]
 
 
-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
+@pytest.mark.parametrize("model_class", MODELS)
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
-                               do_fusion: bool):
+def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
     torch.set_default_device("cuda")
 
     vllm_config = VllmConfig()
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     cleanup_pass = PostCleanupPass(vllm_config)
     act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
 
-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
+    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+              if do_fusion else [noop_pass, cleanup_pass])
     func_pass = FixFunctionalizationPass(vllm_config)
+
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
 
-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to instantiate
-    # 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [
-        SILU_MUL_OP
-    ]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
+    model = model_class()
+    torch.compile(model, backend=backend_func)(*model.example_inputs())
+    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+    # check if the functionalization pass is applied
+    for op in model.ops_in_model(do_fusion):
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
+        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
+                is None)  # noqa: E501
 
     # make sure the ops were all de-functionalized
     found = dict()
     for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
+        for op in model.ops_in_model(do_fusion):
+            if is_func(node, op):
+                found[op] = True
+        for op in model.ops_not_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in ops)
+    assert all(found[op] for op in model.ops_in_model(do_fusion))
+    assert all(not found.get(op) for op in model.ops_not_in_model())
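For reference, each module added above exposes `example_inputs()` plus `ops_in_model()`/`ops_not_in_model()`, and the rewritten test simply compiles it once with and once without `FixFunctionalizationPass`, then walks the FX graphs captured by `TestBackend`. A rough standalone sketch of that flow (hypothetical, not part of this commit; it assumes `NoOpEliminationPass` takes `vllm_config` the same way `PostCleanupPass` does in the visible hunk):

```python
# Hypothetical sketch, not part of the diff: exercise one test module by hand,
# mirroring what test_fix_functionalization does for every entry in MODELS.
import torch

from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import VllmConfig

from .backend import TestBackend  # local test helper used by the diff above

torch.set_default_device("cuda")
vllm_config = VllmConfig()
noop_pass = NoOpEliminationPass(vllm_config)  # assumed constructor signature
func_pass = FixFunctionalizationPass(vllm_config)

# Compile the same module against two backends: with and without the
# functionalization-fixing pass appended to the pass list.
backend_func = TestBackend(noop_pass, func_pass)
backend_no_func = TestBackend(noop_pass)

model = TestSiluMul()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())

# backend_*.graph_post_pass holds the FX graph after all passes ran; the test
# asserts the auto_functionalized wrappers are gone only in backend_func.
```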