
Commit 095277c

Simplify matcher utils by using RMSNorm.forward_static
Signed-off-by: Luka Govedič <[email protected]>
1 parent c3264d8

2 files changed: 8 additions, 33 deletions


vllm/compilation/matcher_utils.py

Lines changed: 6 additions & 32 deletions
@@ -65,8 +65,6 @@ def inputs(self) -> list[torch.Tensor]:
 class MatcherRMSNorm(MatcherCustomOp):
     def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         if enabled is None:
-            # TODO either pass config to enabled or set it globally
-            # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()
 
         super().__init__(enabled)
@@ -83,7 +81,6 @@ def forward_custom(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         result = torch.empty_like(input)
         _, result = auto_functionalized(
@@ -100,28 +97,15 @@ def forward_native(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        x = input.to(torch.float32)
-        if residual is not None:
-            x = x + residual
-            residual = x.to(self.model_dtype)
-
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-
-        x = x * torch.rsqrt(variance + self.epsilon)
-        x = x.to(self.model_dtype)
-        if weight is not None:
-            x = x * weight
-
-        return x if residual is None else (x, residual)
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight
+        )
 
 
 class MatcherFusedAddRMSNorm(MatcherCustomOp):
     def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         if enabled is None:
-            # TODO either pass config to enabled or set it globally
-            # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()
 
         super().__init__(enabled)
@@ -157,19 +141,9 @@ def forward_native(
         weight: torch.Tensor,
         residual: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        x = input.to(torch.float32)
-        if residual is not None:
-            x = x + residual
-            residual = x.to(self.model_dtype)
-
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-
-        x = x * torch.rsqrt(variance + self.epsilon)
-        x = x.to(self.model_dtype)
-        if weight is not None:
-            x = x * weight
-
-        return x if residual is None else (x, residual)
+        return RMSNorm.forward_static(
+            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
+        )
 
 
 class MatcherQuant:
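
Note (not part of the commit): the inline math deleted from both matchers above is the same reference RMSNorm computation that RMSNorm.forward_static now provides in one place. A minimal sketch of that shared computation follows, for reference only; the standalone helper name is illustrative.

import torch

def rms_norm_reference(x, epsilon, model_dtype, weight=None, residual=None):
    # Mirrors the deleted forward_native bodies: optional fused residual add,
    # RMS normalization in float32, then cast back to the model dtype.
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual
        residual = x.to(model_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    x = x.to(model_dtype)
    if weight is not None:
        x = x * weight
    return x if residual is None else (x, residual)

With the commit applied, MatcherRMSNorm delegates via RMSNorm.forward_static(input, self.epsilon, input.size(-1), self.model_dtype, weight), and MatcherFusedAddRMSNorm additionally passes residual, so neither matcher duplicates this logic anymore.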

vllm/model_executor/layers/layernorm.py

Lines changed: 2 additions & 1 deletion
@@ -187,12 +187,12 @@ def forward_static(
         x: torch.Tensor,
         variance_epsilon: float,
         hidden_size: int,
+        orig_dtype: torch.dtype,
         weight: Optional[torch.Tensor] = None,
         residual: Optional[torch.Tensor] = None,
         variance_size_override: Optional[int] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             # residual promoted f16->f32 automatically,
@@ -239,6 +239,7 @@ def forward_native(
             x,
             self.variance_epsilon,
             self.hidden_size,
+            x.dtype,
             self.weight.data if self.has_weight else None,
             residual,
             self.variance_size_override,
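
The layernorm.py change makes the output dtype an explicit argument: forward_static previously recovered orig_dtype from x.dtype, but the matchers need results in self.model_dtype even when the traced input is in a different dtype, so the caller now supplies it; RMSNorm.forward_native keeps the old behaviour by passing x.dtype. A hedged usage sketch under that reading of the diff, with made-up example tensors and epsilon:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

hidden = torch.randn(2, 4096, dtype=torch.bfloat16)
weight = torch.ones(4096, dtype=torch.bfloat16)

# orig_dtype is the fourth positional argument after this commit; it sets the
# dtype of the returned tensor(s) instead of being inferred from the input.
out = RMSNorm.forward_static(
    hidden,            # x
    1e-6,              # variance_epsilon (example value)
    hidden.size(-1),   # hidden_size
    torch.bfloat16,    # orig_dtype: dtype to cast the normalized output back to
    weight,            # optional scale; residual and variance_size_override remain optional
)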
