
import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
+from vllm.compilation.fusion import FusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
+from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                         PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..utils import multi_gpu_test
from .backend import TestBackend

+FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
    "Hello, my name is",
    "The president of the United States is",

class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size=16, intermediate_size=32):
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size)))
-        self.norm = RMSNorm(hidden_size, 1e-05)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
        # Initialize weights
        torch.nn.init.normal_(self.gate_proj, std=0.02)

@@ -79,32 +87,138 @@ def ops_in_model(self):
        return [torch.ops._C.fused_add_rms_norm.default]


+class TestQuantModel(torch.nn.Module):
+
+    def __init__(self,
+                 hidden_size=16,
+                 intermediate_size=32,
+                 vllm_config: VllmConfig = None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.vllm_config = vllm_config
+        self.gate_proj = torch.nn.Parameter(torch.empty(
+            (intermediate_size, hidden_size)),
+                                            requires_grad=False)
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        # Initialize weights
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=True,
+                                      use_per_token_if_dynamic=False)
+
+        self.scale = torch.rand(1, dtype=torch.float32)
+        # Create a weight that is compatible with torch._scaled_mm,
+        # which expects a column-major layout.
+        self.w = torch.rand(hidden_size,
+                            intermediate_size).to(dtype=FP8_DTYPE).t()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        """
+        Forward pass implementing the operations in the FX graph
+
+        Args:
+            hidden_states: Input tensor
+            residual: Residual tensor from previous layer
+
+        Returns:
+            Tuple containing the output tensor
+        """
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # Tensor parallel all-reduce
+        all_reduce = tensor_model_parallel_all_reduce(mm)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(all_reduce, residual)
+
+        # for static input quantization
+        # self.fp8_linear is initialized with use_per_token_if_dynamic=False
+        fp8_linear_result = self.fp8_linear.apply(norm_output,
+                                                  self.w,
+                                                  self.wscale,
+                                                  input_scale=self.scale.to(
+                                                      norm_output.device))
+
+        return fp8_linear_result, residual_output
+
+    def ops_in_model_before(self):
+        ops_to_remove = [torch.ops.vllm.all_reduce.default
+                         ]  # Always removed by SP
+        # The following are only removed if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_remove.extend([
+                torch.ops._C.fused_add_rms_norm.default,
+                torch.ops._C.static_scaled_fp8_quant.default,
+            ])
+        return ops_to_remove
+
+    def ops_in_model_after(self):
+        ops_to_add = [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+        # The following is only added if fusion happens
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            ops_to_add.append(
+                torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
+        return ops_to_add
+
+    def ops_in_model(self):
+        if self.vllm_config and self.vllm_config.compilation_config \
+            .pass_config.enable_fusion:
+            # If fusion happens, the fused op is the one
+            # we check for (de)functionalization
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+                    ]  # noqa: E501
+        else:
+            # If no fusion, the original ops are checked
+            return [
+                torch.ops._C.fused_add_rms_norm.default,
+                # TODO functionalization pass does not handle this yet
+                # torch.ops._C.static_scaled_fp8_quant.default,
+            ]
+
+
@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("enable_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
-def test_sequence_parallelism_pass(batch_size: int, seq_len: int,
-                                   hidden_size: int, dtype: torch.dtype):
+def test_sequence_parallelism_pass(test_model_cls: type[torch.nn.Module],
+                                   batch_size: int, seq_len: int,
+                                   hidden_size: int, dtype: torch.dtype,
+                                   enable_fusion: bool):
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
-                                    args=(num_processes, batch_size, seq_len,
-                                          hidden_size, dtype),
+                                    args=(num_processes, test_model_cls,
+                                          batch_size, seq_len, hidden_size,
+                                          dtype, enable_fusion),
                                    nprocs=nprocs)

    run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)


-def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
-                                            batch_size: int, seq_len: int,
-                                            hidden_size: int,
-                                            dtype: torch.dtype):
+def sequence_parallelism_pass_on_test_model(
+        local_rank: int, world_size: int,
+        test_model_cls: type[torch.nn.Module], batch_size: int, seq_len: int,
+        hidden_size: int, dtype: torch.dtype, enable_fusion: bool):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
@@ -127,26 +241,39 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
    # configure vllm config for SequenceParallelismPass
    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
-        enable_sequence_parallelism=True))
+        enable_sequence_parallelism=True,
+        enable_fusion=enable_fusion,
+        enable_noop=True))  # NoOp needed for fusion
    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

    # this is a fake model name to construct the model config
    # in the vllm_config, it's not really used.
-    model = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
-    vllm_config.model_config = ModelConfig(model=model,
+    model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
+    vllm_config.model_config = ModelConfig(model=model_name,
                                           task="auto",
-                                           tokenizer=model,
+                                           tokenizer=model_name,
                                           tokenizer_mode="auto",
                                           trust_remote_code=True,
                                           dtype=dtype,
                                           seed=42)

    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
-    backend_no_func = TestBackend(sequence_parallelism_pass)
+    noop_pass = NoOpEliminationPass(vllm_config)
    func_pass = FixFunctionalizationPass(vllm_config)
-    backend_func = TestBackend(sequence_parallelism_pass, func_pass)

-    model = TestModel(hidden_size, hidden_size * 2)
+    passes_for_backend = [noop_pass, sequence_parallelism_pass]
+
+    if enable_fusion:
+        fusion_pass = FusionPass.instance(vllm_config)
+        passes_for_backend.append(fusion_pass)
+
+    backend_no_func = TestBackend(*passes_for_backend)
+    backend_func = TestBackend(*passes_for_backend, func_pass)
+
+    model = test_model_cls(hidden_size,
+                           hidden_size * 2,
+                           vllm_config=vllm_config)
+
    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype)
    residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
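
Note on the behavior exercised above (an illustration only; it does not appear in the diff): the sequence parallelism pass rewrites the all-reduce of tensor-parallel partial sums into a reduce-scatter followed by an all-gather, which is exactly what ops_in_model_before() and ops_in_model_after() assert on. A minimal single-process sketch of that equivalence, with plain tensors standing in for the collectives (all names and shapes below are illustrative):

import torch

# Illustrative stand-in for world_size ranks' partial matmul outputs.
world_size, seq_len, hidden = 2, 8, 16
partials = [torch.randn(seq_len, hidden) for _ in range(world_size)]

# all_reduce: every rank would hold the full sum over ranks.
all_reduce_out = sum(partials)

# reduce_scatter: each rank keeps only the sum of its sequence shard...
shards = [
    sum(p.chunk(world_size, dim=0)[rank] for p in partials)
    for rank in range(world_size)
]
# ...and all_gather re-assembles the shards into the full tensor.
all_gather_out = torch.cat(shards, dim=0)

assert torch.allclose(all_reduce_out, all_gather_out, atol=1e-5)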