2424from vllm .platforms import current_platform
2525
2626from .inductor_pass import enable_fake_mode
27+ from .matcher_utils import MatcherQuant , MatcherRMSNorm
2728from .vllm_inductor_pass import VllmInductorPass , VllmPatternMatcherPass
2829
2930logger = init_logger (__name__ )
@@ -99,6 +100,9 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
99100 assert key in FUSED_OPS , f"unsupported fused rmsnorm+quant op for { key } "
100101 self .FUSED_OP = FUSED_OPS [key ]
101102
103+ self .rmsnorm_matcher = MatcherRMSNorm (epsilon )
104+ self .quant_matcher = MatcherQuant (key .quant )
105+
102106
103107class RMSNormStaticQuantPattern (RMSNormQuantPattern ):
104108 def __init__ (self , epsilon : float , quant_dtype : torch .dtype , symmetric = True ):
@@ -113,25 +117,8 @@ def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
113117 def register (self , pm_pass : PatternMatcherPass ):
114118 # Cannot use methods, as the self argument affects tracing
115119 def pattern (input : torch .Tensor , weight : torch .Tensor , scale : torch .Tensor ):
116- result_rms = torch .empty_like (input )
117- # TODO: why does empty_like produce a permute but
118- # empty via shape doesn't?
119- result = torch .empty (
120- input .shape , device = input .device , dtype = self .quant_dtype
121- )
122- at1 = auto_functionalized (
123- RMS_OP ,
124- result = result_rms ,
125- input = input ,
126- weight = weight ,
127- epsilon = self .epsilon ,
128- )
129- at2 = auto_functionalized (
130- self .QUANT_OP , result = result , input = at1 [1 ], scale = scale
131- )
132-
133- # result
134- return at2 [1 ]
120+ result_rms = self .rmsnorm_matcher (input , weight )
121+ return self .quant_matcher (result_rms , scale )
135122
136123 def replacement (input : torch .Tensor , weight : torch .Tensor , scale : torch .Tensor ):
137124 result = torch .empty_like (input , dtype = self .quant_dtype )
@@ -173,22 +160,10 @@ def pattern(
173160 weight : torch .Tensor ,
174161 scale : torch .Tensor ,
175162 ):
176- result = torch .empty (
177- input .shape , device = input .device , dtype = self .quant_dtype
178- )
179- at = auto_functionalized (
180- RMS_ADD_OP ,
181- input = input ,
182- residual = residual ,
183- weight = weight ,
184- epsilon = self .epsilon ,
185- )
186- at1 = auto_functionalized (
187- self .QUANT_OP , result = result , input = at [1 ], scale = scale
188- )
163+ result_rms , residual = self .rmsnorm_matcher (input , weight , residual )
164+ result = self .quant_matcher (result_rms , scale )
189165
190- # result, residual
191- return at1 [1 ], at [2 ]
166+ return result , residual
192167
193168 def replacement (
194169 input : torch .Tensor ,
@@ -242,27 +217,14 @@ def __init__(
242217 super ().__init__ (epsilon , key )
243218
244219 def register (self , pm_pass : PatternMatcherPass ):
245- def pattern (input : torch .Tensor , weight : torch .Tensor , scale : torch .Tensor ):
246- result_rms = torch .empty_like (input )
247- result = torch .empty (
248- input .shape , device = input .device , dtype = self .quant_dtype
249- )
250- at1 = auto_functionalized (
251- RMS_OP ,
252- result = result_rms ,
253- input = input ,
254- weight = weight ,
255- epsilon = self .epsilon ,
256- )
257- at2 = auto_functionalized (
258- self .QUANT_OP , result = result , input = at1 [1 ], scale = scale , scale_ub = None
259- )
260-
220+ def pattern (input : torch .Tensor , weight : torch .Tensor ):
221+ result_rms = self .rmsnorm_matcher (input , weight )
261222 # result, scale
262- return at2 [ 1 ], at2 [ 2 ]
223+ return self . quant_matcher ( result_rms )
263224
264- def replacement (input : torch .Tensor , weight : torch .Tensor , scale : torch . Tensor ):
225+ def replacement (input : torch .Tensor , weight : torch .Tensor ):
265226 result = torch .empty_like (input , dtype = self .quant_dtype )
227+ scale = self .quant_matcher .make_scale (input )
266228 at = auto_functionalized (
267229 self .FUSED_OP ,
268230 result = result ,
@@ -280,7 +242,6 @@ def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
280242 inputs = [
281243 empty_bf16 (5 , 4 ), # input
282244 empty_bf16 (1 , 5 ), # weight
283- empty_fp32 (1 , 1 ), # scale
284245 ]
285246
286247 pm .register_replacement (
@@ -308,36 +269,17 @@ def __init__(
308269 super ().__init__ (epsilon , key )
309270
310271 def register (self , pm_pass : PatternMatcherPass ):
311- def pattern (
312- input : torch .Tensor ,
313- residual : torch .Tensor ,
314- weight : torch .Tensor ,
315- scale : torch .Tensor ,
316- ):
317- result = torch .empty (
318- input .shape , device = input .device , dtype = self .quant_dtype
319- )
320- at = auto_functionalized (
321- RMS_ADD_OP ,
322- input = input ,
323- residual = residual ,
324- weight = weight ,
325- epsilon = self .epsilon ,
326- )
327- at1 = auto_functionalized (
328- self .QUANT_OP , result = result , input = at [1 ], scale = scale , scale_ub = None
329- )
272+ def pattern (input : torch .Tensor , residual : torch .Tensor , weight : torch .Tensor ):
273+ result_rms , residual = self .rmsnorm_matcher (input , weight , residual )
274+ result , scale = self .quant_matcher (result_rms )
330275
331- # result, residual, scale
332- return at1 [1 ], at [2 ], at1 [2 ]
276+ return result , residual , scale
333277
334278 def replacement (
335- input : torch .Tensor ,
336- residual : torch .Tensor ,
337- weight : torch .Tensor ,
338- scale : torch .Tensor ,
279+ input : torch .Tensor , residual : torch .Tensor , weight : torch .Tensor
339280 ):
340281 result = torch .empty_like (input , dtype = self .quant_dtype )
282+ scale = self .quant_matcher .make_scale (input )
341283 at = auto_functionalized (
342284 self .FUSED_OP ,
343285 result = result ,
@@ -356,7 +298,6 @@ def replacement(
356298 empty_bf16 (5 , 4 ), # input
357299 empty_bf16 (5 , 4 ), # residual
358300 empty_bf16 (1 , 5 ), # weight
359- empty_fp32 (1 , 1 ), # scale
360301 ]
361302
362303 pm .register_replacement (
0 commit comments