@@ -108,8 +108,12 @@ impl NomicBertGatedMLP {
108108 let intermediate_size = config. n_inner ;
109109 let activation = config. activation_function . clone ( ) ;
110110
111- let fc11_weight = vb. pp ( "fc11" ) . get ( ( intermediate_size, config. n_embd ) , "weight" ) ?;
112- let fc12_weight = vb. pp ( "fc12" ) . get ( ( intermediate_size, config. n_embd ) , "weight" ) ?;
111+ let fc11_weight = vb
112+ . pp ( "fc11" )
113+ . get ( ( intermediate_size, config. n_embd ) , "weight" ) ?;
114+ let fc12_weight = vb
115+ . pp ( "fc12" )
116+ . get ( ( intermediate_size, config. n_embd ) , "weight" ) ?;
113117 let fc1_weight = Tensor :: cat ( & [ fc11_weight, fc12_weight] , 0 ) ?;
114118
115119 let fc1_bias = if config. mlp_fc1_bias {
@@ -149,11 +153,7 @@ impl NomicBertGatedMLP {
149153 let y = hidden_states. narrow ( 2 , 0 , self . intermediate_size ) ?;
150154 let gate = hidden_states. narrow ( 2 , self . intermediate_size , self . intermediate_size ) ?;
151155
152- let activated_gate = match self . activation {
153- HiddenAct :: Gelu => gate. gelu ( ) ?,
154- HiddenAct :: Swiglu => gate. silu ( ) ?,
155- _ => panic ! ( ) ,
156- } ;
156+ let activated_gate = self . activation . forward ( & gate) ?;
157157 let y = y. mul ( & activated_gate) ?;
158158
159159 self . fc2 . forward ( & y)
@@ -284,12 +284,7 @@ impl NomicExpertMLP {
284284 let expert_w2 = self . w2 . narrow ( 0 , expert_idx, 1 ) ?. squeeze ( 0 ) ?;
285285
286286 let hidden_states = hidden_states. broadcast_matmul ( & expert_w1) ?;
287-
288- let hidden_states = match self . activation {
289- HiddenAct :: Gelu => hidden_states. gelu ( ) ?,
290- HiddenAct :: Swiglu => hidden_states. silu ( ) ?,
291- _ => panic ! ( ) ,
292- } ;
287+ let hidden_states = self . activation . forward ( & hidden_states) ?;
293288
294289 hidden_states. broadcast_matmul ( & expert_w2)
295290 }
0 commit comments