@@ -65,8 +65,6 @@ def inputs(self) -> list[torch.Tensor]:
6565class  MatcherRMSNorm (MatcherCustomOp ):
6666    def  __init__ (self , epsilon : float , enabled : Optional [bool ] =  None ):
6767        if  enabled  is  None :
68-             # TODO either pass config to enabled or set it globally 
69-             #  (global during pass init seems reasonable) 
7068            enabled  =  RMSNorm .enabled ()
7169
7270        super ().__init__ (enabled )
@@ -83,7 +81,6 @@ def forward_custom(
8381        self ,
8482        input : torch .Tensor ,
8583        weight : torch .Tensor ,
86-         residual : Optional [torch .Tensor ] =  None ,
8784    ) ->  torch .Tensor :
8885        result  =  torch .empty_like (input )
8986        _ , result  =  auto_functionalized (
@@ -100,28 +97,15 @@ def forward_native(
10097        self ,
10198        input : torch .Tensor ,
10299        weight : torch .Tensor ,
103-         residual : Optional [torch .Tensor ] =  None ,
104100    ) ->  torch .Tensor :
105-         x  =  input .to (torch .float32 )
106-         if  residual  is  not None :
107-             x  =  x  +  residual 
108-             residual  =  x .to (self .model_dtype )
109- 
110-         variance  =  x .pow (2 ).mean (dim = - 1 , keepdim = True )
111- 
112-         x  =  x  *  torch .rsqrt (variance  +  self .epsilon )
113-         x  =  x .to (self .model_dtype )
114-         if  weight  is  not None :
115-             x  =  x  *  weight 
116- 
117-         return  x  if  residual  is  None  else  (x , residual )
101+         return  RMSNorm .forward_static (
102+             input , self .epsilon , input .size (- 1 ), self .model_dtype , weight 
103+         )
118104
119105
120106class  MatcherFusedAddRMSNorm (MatcherCustomOp ):
121107    def  __init__ (self , epsilon : float , enabled : Optional [bool ] =  None ):
122108        if  enabled  is  None :
123-             # TODO either pass config to enabled or set it globally 
124-             #  (global during pass init seems reasonable) 
125109            enabled  =  RMSNorm .enabled ()
126110
127111        super ().__init__ (enabled )
@@ -157,19 +141,9 @@ def forward_native(
157141        weight : torch .Tensor ,
158142        residual : torch .Tensor ,
159143    ) ->  tuple [torch .Tensor , torch .Tensor ]:
160-         x  =  input .to (torch .float32 )
161-         if  residual  is  not None :
162-             x  =  x  +  residual 
163-             residual  =  x .to (self .model_dtype )
164- 
165-         variance  =  x .pow (2 ).mean (dim = - 1 , keepdim = True )
166- 
167-         x  =  x  *  torch .rsqrt (variance  +  self .epsilon )
168-         x  =  x .to (self .model_dtype )
169-         if  weight  is  not None :
170-             x  =  x  *  weight 
171- 
172-         return  x  if  residual  is  None  else  (x , residual )
144+         return  RMSNorm .forward_static (
145+             input , self .epsilon , input .size (- 1 ), self .model_dtype , weight , residual 
146+         )
173147
174148
175149class  MatcherQuant :
0 commit comments