                     multi_gpu_test)
 from .backend import TestBackend
 
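+# Pick the platform's fp8 dtype (e.g. float8_e4m3fn on CUDA,
+# float8_e4m3fnuz on ROCm) so the fp8 test models run on either backend.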
+FP8_DTYPE = current_platform.fp8_dtype()
+
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
 
 class TestMMRSModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.gate_proj = torch.nn.Parameter(torch.empty(
             (self.hidden_size * 2, hidden_size)),
                                             requires_grad=False)
@@ -64,9 +67,10 @@ def ops_in_model_after(self):
 
 class TestAGMMModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.weight = torch.nn.Parameter(torch.empty(
             (hidden_size, hidden_size)),
                                          requires_grad=False)
@@ -91,8 +95,125 @@ def ops_in_model_after(self):
         return [torch.ops.symm_mem.fused_all_gather_matmul.default]
 
 
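+# Shared base for the fp8 test models below: holds the fp8 weight and a
+# per-column scale_b used by both _scaled_mm and cutlass_scaled_mm.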
+class _BaseScaledMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, dtype=torch.float16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dtype = dtype
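+        # The scaled matmuls expect the second operand in column-major
+        # layout, hence the transpose of a contiguous buffer.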
+        self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
+            .contiguous().transpose(0, 1)
+
+        # Initialize scale_b for _scaled_mm.
+        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
+
+
+class TestScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the scaled_mm + reduce scatter in the
+        FX graph.
+        """
+        fp8_input = input.to(FP8_DTYPE)
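+        # Per-token (row-wise) scale_a: one scale per row of the input.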
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(fp8_input,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + scaled_mm in the FX graph.
+        """
141+ # Reshape input
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(all_gather,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        return scaled_mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
+class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the cutlass_scaled_mm + reduce scatter
+        in the FX graph.
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
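+        # cutlass_scaled_mm writes into a preallocated output buffer
+        # rather than returning a new tensor.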
+        mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=input.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
+                                       self.scale_b, None)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + cutlass_scaled_mm
+        in the FX graph.
+        """
+        # Cast the input to fp8 before gathering.
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+
+        mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=all_gather.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
+                                       scale_a, self.scale_b, None)
+        return mm_out
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
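+# The four scaled-mm models above exercise the fp8 async-TP fusions; the
+# original two cover the unscaled matmul patterns.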
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("test_model", [
+    TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
+    TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
+])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
@@ -101,6 +222,14 @@ def ops_in_model_after(self):
                     reason="Only test on CUDA")
 def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
                                hidden_size: int, dtype: torch.dtype):
+    if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
+                      TestCutlassScaledMMRSModel,
+                      TestAGCutlassScaledMMModel) and dtype == torch.float16:
+        pytest.skip(
+            "Only bf16 high precision output types are supported for "
+            "per-token (row-wise) scaling")
+
     num_processes = 2
 
     def run_torch_spawn(fn, nprocs):
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     async_tp_pass = AsyncTPPass(vllm_config)
     backend = TestBackend(async_tp_pass)
 
-    model = test_model_cls(hidden_size)
+    model = test_model_cls(hidden_size, dtype)
 
     hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                 dtype=dtype,
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
 
 
 @create_new_process_for_each_test()
-@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
+@pytest.mark.parametrize("model_id", [
+    "meta-llama/Llama-3.2-1B-Instruct",
309+ "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
310+ ])
178311@pytest .mark .parametrize ("tp_size" , [2 ])
179312@pytest .mark .parametrize ("async_tp_enabled" , [True ])
180313@pytest .mark .parametrize ("distributed_backend" , ["mp" ])