Skip to content

Commit 20bd6f4

Browse files
dhiaEddineRhaiem, younesbelkada, ilyasch2, JingweiZuo
authored
[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)
Signed-off-by: dhia.rhaiem <[email protected]> Co-authored-by: younesbelkada <[email protected]> Co-authored-by: Ilyas Chahed <[email protected]> Co-authored-by: Jingwei Zuo <[email protected]>
1 parent 1f07954 commit 20bd6f4

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def forward_native(
7777
input_dtype = x.dtype
7878
x = x * nn.functional.silu(gate.to(torch.float32))
7979
if not self.use_rms_norm:
80-
return x
80+
return x.to(input_dtype)
8181

8282
if self.n_groups == 1:
8383
if self.tp_size > 1:
@@ -117,9 +117,11 @@ def forward_cuda(
117117
x: torch.Tensor,
118118
gate: torch.Tensor,
119119
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
120-
120+
input_dtype = x.dtype
121121
if not self.use_rms_norm:
122-
return x * nn.functional.silu(gate.to(torch.float32))
122+
# Keep gate in float32 for numerical stability during silu
123+
return x * nn.functional.silu(gate.to(
124+
torch.float32)).to(input_dtype)
123125

124126
if self.tp_size > 1 or self.n_groups != 1:
125127
return self.forward_native(x, gate)

vllm/model_executor/models/falcon_h1.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,6 @@ def forward(
453453
attn_metadata = get_forward_context().attn_metadata
454454
mamba2_metadata = prepare_mamba2_metadata(
455455
chunk_size=self.config.mamba_chunk_size,
456-
input_ids=input_ids,
457456
attn_metadata=attn_metadata,
458457
)
459458
if get_pp_group().is_first_rank:

0 commit comments

Comments (0)