File tree: 2 files changed, +5 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def forward_native(
7777 input_dtype = x .dtype
7878 x = x * nn .functional .silu (gate .to (torch .float32 ))
7979 if not self .use_rms_norm :
80- return x
80+ return x . to ( input_dtype )
8181
8282 if self .n_groups == 1 :
8383 if self .tp_size > 1 :
@@ -117,9 +117,11 @@ def forward_cuda(
117117 x : torch .Tensor ,
118118 gate : torch .Tensor ,
119119 ) -> Union [torch .Tensor , tuple [torch .Tensor , torch .Tensor ]]:
120-
120+ input_dtype = x . dtype
121121 if not self .use_rms_norm :
122- return x * nn .functional .silu (gate .to (torch .float32 ))
122+ # Keep gate in float32 for numerical stability during silu
123+ return x * nn .functional .silu (gate .to (
124+ torch .float32 )).to (input_dtype )
123125
124126 if self .tp_size > 1 or self .n_groups != 1 :
125127 return self .forward_native (x , gate )
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,6 @@ def forward(
453453 attn_metadata = get_forward_context ().attn_metadata
454454 mamba2_metadata = prepare_mamba2_metadata (
455455 chunk_size = self .config .mamba_chunk_size ,
456- input_ids = input_ids ,
457456 attn_metadata = attn_metadata ,
458457 )
459458 if get_pp_group ().is_first_rank :
You can’t perform that action at this time.
0 commit comments