Commit 188d098

update: implement forward for HiddenAct
1 parent 9581226

File tree: 2 files changed (+18 −13)

backends/candle/src/layers/linear.rs (10 additions, 0 deletions)

@@ -11,6 +11,16 @@ pub enum HiddenAct {
     Swiglu,
 }

+impl HiddenAct {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Gelu => x.gelu(),
+            Self::Relu => x.relu(),
+            Self::Swiglu => candle_nn::ops::swiglu(x),
+        }
+    }
+}
+
 #[derive(Debug)]
 pub struct Linear {
     weight: Tensor,
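
For context, a minimal usage sketch (not part of the commit) of the new HiddenAct::forward method. It assumes candle_core's Tensor, Device, and Result are in scope along with the HiddenAct enum from linear.rs; note that for the Swiglu variant the input's last dimension should be even, since candle_nn::ops::swiglu splits it in half.

use candle_core::{Device, Result, Tensor};

// Hypothetical helper: apply whichever activation the config selected.
fn apply_activation(act: &HiddenAct, x: &Tensor) -> Result<Tensor> {
    // Dispatch through the enum method instead of matching at each call site.
    act.forward(x)
}

fn main() -> Result<()> {
    // Toy input with an even last dimension so the Swiglu split is valid.
    let x = Tensor::randn(0f32, 1f32, (2, 8), &Device::Cpu)?;
    let y = apply_activation(&HiddenAct::Swiglu, &x)?;
    println!("{:?}", y.shape());
    Ok(())
}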

backends/candle/src/models/nomic.rs (8 additions, 13 deletions)

@@ -108,8 +108,12 @@ impl NomicBertGatedMLP {
         let intermediate_size = config.n_inner;
         let activation = config.activation_function.clone();

-        let fc11_weight = vb.pp("fc11").get((intermediate_size, config.n_embd), "weight")?;
-        let fc12_weight = vb.pp("fc12").get((intermediate_size, config.n_embd), "weight")?;
+        let fc11_weight = vb
+            .pp("fc11")
+            .get((intermediate_size, config.n_embd), "weight")?;
+        let fc12_weight = vb
+            .pp("fc12")
+            .get((intermediate_size, config.n_embd), "weight")?;
         let fc1_weight = Tensor::cat(&[fc11_weight, fc12_weight], 0)?;

         let fc1_bias = if config.mlp_fc1_bias {
@@ -149,11 +153,7 @@ impl NomicBertGatedMLP {
         let y = hidden_states.narrow(2, 0, self.intermediate_size)?;
         let gate = hidden_states.narrow(2, self.intermediate_size, self.intermediate_size)?;

-        let activated_gate = match self.activation {
-            HiddenAct::Gelu => gate.gelu()?,
-            HiddenAct::Swiglu => gate.silu()?,
-            _ => panic!(),
-        };
+        let activated_gate = self.activation.forward(&gate)?;
         let y = y.mul(&activated_gate)?;

         self.fc2.forward(&y)
@@ -284,12 +284,7 @@ impl NomicExpertMLP {
         let expert_w2 = self.w2.narrow(0, expert_idx, 1)?.squeeze(0)?;

         let hidden_states = hidden_states.broadcast_matmul(&expert_w1)?;
-
-        let hidden_states = match self.activation {
-            HiddenAct::Gelu => hidden_states.gelu()?,
-            HiddenAct::Swiglu => hidden_states.silu()?,
-            _ => panic!(),
-        };
+        let hidden_states = self.activation.forward(&hidden_states)?;

         hidden_states.broadcast_matmul(&expert_w2)
     }
