From 3d290cefd3a6e520d6729a4f024969a7e0688ca9 Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 21:07:35 +0900 Subject: [PATCH 1/8] build(deps): candle-moe --- Cargo.lock | 13 +++++++++++++ Cargo.toml | 9 +++++---- backends/candle/Cargo.toml | 3 ++- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1e02ee05..66531470 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -595,6 +595,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "candle-moe" +version = "0.0.1" +source = "git+https://github.com/kozistr/candle-moe?rev=063e5dc2b905e83e009781c6bce35eb8b144265f#063e5dc2b905e83e009781c6bce35eb8b144265f" +dependencies = [ + "anyhow", + "bindgen_cuda", + "candle-core", + "cudarc", + "half", +] + [[package]] name = "candle-nn" version = "0.8.4" @@ -4502,6 +4514,7 @@ dependencies = [ "candle-flash-attn", "candle-flash-attn-v1", "candle-layer-norm", + "candle-moe", "candle-nn", "candle-rotary", "candle-transformers", diff --git a/Cargo.toml b/Cargo.toml index dbdb14e0..ce105980 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,20 +42,21 @@ serde_json = "1.0" thiserror = "1.0" rand = "0.9" serial_test = "2.0.0" -cudarc = { version = "0.13" , features =["cuda-12020"], default-features = false} +cudarc = { version = "0.13" , features = ["cuda-12020"], default-features = false } intel-mkl-src = { version = "0.8"} candle = { version = "0.8", package = "candle-core" } -candle-nn = { version = "0.8" } +candle-nn = { version = "0.8" } candle-transformers = { version = "0.8" } candle-flash-attn = { version = "0.8" } -candle-cublaslt= { version = "0.0.1" } +candle-cublaslt = { version = "0.0.1" } candle-layer-norm = { version = "0.0.1" } candle-rotary = { version = "0.0.1" } candle-flash-attn-v1 = { version = "0.0.1" } +candle-moe = { git = "https://github.com/kozistr/candle-moe", rev = "063e5dc2b905e83e009781c6bce35eb8b144265f" } half = { version = "2.3.1", features = ["num-traits"] } [patch.crates-io] -cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9"} +cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9" } candle = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-core" } candle-nn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-nn" } candle-transformers = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-transformers" } diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml index a2c3e469..58659d63 100644 --- a/backends/candle/Cargo.toml +++ b/backends/candle/Cargo.toml @@ -17,6 +17,7 @@ candle-flash-attn-v1 = { workspace = true, optional = true } candle-cublaslt = { workspace = true, optional = true } candle-layer-norm = { workspace = true, optional = true } candle-rotary = { workspace = true, optional = true } +candle-moe = { workspace = true, optional = true } nohash-hasher = { workspace = true } text-embeddings-backend-core = { path = "../core" } tracing = { workspace = true } @@ -41,6 +42,6 @@ anyhow = { version = "1", features = ["backtrace"] } accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"] metal = ["candle/metal", "candle-nn/metal"] mkl = ["dep:intel-mkl-src", "candle/_mkl"] -cuda = ["candle/_cuda", "candle-nn/_cuda", 
"dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"] +cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-moe"] flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"] flash-attn = ["dep:candle-flash-attn", "cuda"] From ed96f76fd7de13938ddd84de25b1ca15752756c7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 21:07:57 +0900 Subject: [PATCH 2/8] feature: fused moe kernel --- backends/candle/src/models/flash_nomic.rs | 116 +++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 32cd31b6..d73ae972 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,8 +1,9 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; -use crate::models::nomic::{NomicBertEmbeddings, NomicMLP}; +use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP, NomicBertMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_moe::{apply_topk_softmax_inplace, FusedMoeForward}; use candle_nn::VarBuilder; use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -100,6 +101,119 @@ impl NomicAttention { } } +struct NomicFusedMoELayer { + layer: Linear, + gate_weight: Tensor, + up_weight: Tensor, + bias: Tensor, + fused_moe: FusedMoeForward, + top_k: usize, + + span: tracing::Span, +} + +impl NomicFusedMoELayer { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + let hidden_size = config.n_embd; + let ffn_hidden_size = config.n_inner; + let moe_num_experts = config.num_experts.unwrap(); + let top_k = config.moe_top_k.unwrap(); + let activation = config.activation_function.clone(); + + let layer_weight = vb + .pp("router.layer") + .get((moe_num_experts, hidden_size), "weight")?; + let layer = Linear::new(layer_weight, None, None); + + let gate_weight = vb + .pp("experts.mlp.w1") + .get((moe_num_experts * ffn_hidden_size, hidden_size), "w1")? + .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? + .permute((0, 2, 1))?; + let up_weight = vb + .pp("experts.mlp.w2") + .get((moe_num_experts * ffn_hidden_size, hidden_size), "w2")? + .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? 
+ .permute((0, 2, 1))?; + let bias = vb.pp("experts.bias").get((hidden_size,), "bias")?; + + let fused_moe = FusedMoeForward::new(moe_num_experts, top_k, candle_moe::Activation::Silu); + + Ok(Self { + layer, + gate_weight, + up_weight, + bias, + fused_moe, + top_k, + span: tracing::span!(tracing::Level::TRACE, "fused_moe"), + }) + } + + fn forward_router(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> { + let device = hidden_states.device(); + + let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?; + let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?; + + let (seq_len, _) = weights.shape().dims2()?; + + let topk_weight = Tensor::zeros((seq_len, self.top_k), DType::F32, device)?; + let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; + let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; + + apply_topk_softmax_inplace(weights, &topk_weight, &topk_indices, &token_expert_indices)?; + + Ok((topk_weight, topk_indices)) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let (scores, indices) = self.forward_router(hidden_states)?; + + let out = self.fused_moe.forward( + &hidden_states, + &self.gate_weight, + &self.up_weight, + None, + &scores, + &indices, + 1_u32, + )?; + + out.broadcast_add(&self.bias) + } +} + +pub enum NomicMLP { + MoE(NomicFusedMoELayer), + GatedMLP(NomicBertGatedMLP), + Mlp(NomicBertMLP), +} + +impl NomicMLP { + pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result { + let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); + + if use_moe { + Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) + } else if config.activation_function == HiddenAct::Gelu { + Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) + } else { + Ok(Self::GatedMLP(NomicBertGatedMLP::load(vb, config)?)) + } + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + match self { + Self::MoE(layer) => layer.forward(hidden_states), + Self::GatedMLP(layer) => layer.forward(hidden_states), + Self::Mlp(layer) => layer.forward(hidden_states), + } + } +} + struct NomicBertBlock { attention: NomicAttention, mlp: NomicMLP, From e5053507ae116dbb8ac8d8e68af409f45c1a9baa Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 14:22:50 +0000 Subject: [PATCH 3/8] update: nomic --- Cargo.lock | 2 +- Cargo.toml | 2 +- backends/candle/src/models/flash_nomic.rs | 17 ++++++++++++----- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 66531470..bb19c884 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -598,7 +598,7 @@ dependencies = [ [[package]] name = "candle-moe" version = "0.0.1" -source = "git+https://github.com/kozistr/candle-moe?rev=063e5dc2b905e83e009781c6bce35eb8b144265f#063e5dc2b905e83e009781c6bce35eb8b144265f" +source = "git+https://github.com/kozistr/candle-moe?rev=990ac1f42248dd441c51c9b5bcb73c5b77c03f99#990ac1f42248dd441c51c9b5bcb73c5b77c03f99" dependencies = [ "anyhow", "bindgen_cuda", diff --git a/Cargo.toml b/Cargo.toml index ce105980..c661b9d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ candle-cublaslt = { version = "0.0.1" } candle-layer-norm = { version = "0.0.1" } candle-rotary = { version = "0.0.1" } candle-flash-attn-v1 = { version = "0.0.1" } -candle-moe = { git = "https://github.com/kozistr/candle-moe", rev = "063e5dc2b905e83e009781c6bce35eb8b144265f" } +candle-moe = { git = 
"https://github.com/kozistr/candle-moe", rev = "990ac1f42248dd441c51c9b5bcb73c5b77c03f99" } half = { version = "2.3.1", features = ["num-traits"] } [patch.crates-io] diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index d73ae972..74cd76df 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, LayerNorm, Linear}; use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP, NomicBertMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; @@ -101,7 +101,7 @@ impl NomicAttention { } } -struct NomicFusedMoELayer { +pub struct NomicFusedMoELayer { layer: Linear, gate_weight: Tensor, up_weight: Tensor, @@ -137,7 +137,14 @@ impl NomicFusedMoELayer { .permute((0, 2, 1))?; let bias = vb.pp("experts.bias").get((hidden_size,), "bias")?; - let fused_moe = FusedMoeForward::new(moe_num_experts, top_k, candle_moe::Activation::Silu); + let moe_act = match activation { + HiddenAct::Silu => candle_moe::Activation::Silu, + HiddenAct::Gelu => candle_moe::Activation::Gelu, + HiddenAct::Relu => candle_moe::Activation::Relu, + _ => candle::bail!("not supported activation type"), + }; + + let fused_moe = FusedMoeForward::new(moe_num_experts, top_k, moe_act); Ok(Self { layer, @@ -162,7 +169,7 @@ impl NomicFusedMoELayer { let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; - apply_topk_softmax_inplace(weights, &topk_weight, &topk_indices, &token_expert_indices)?; + apply_topk_softmax_inplace(&weights, &topk_weight, &topk_indices, &token_expert_indices)?; Ok((topk_weight, topk_indices)) } @@ -197,7 +204,7 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) + Ok(Self::MoE(NomicFusedMoELayer::load(vb, config)?)) } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { From dd9be801958f402f31a9952f80b3701fc5a07441 Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 14:28:42 +0000 Subject: [PATCH 4/8] fix: weight name --- backends/candle/src/models/flash_nomic.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 74cd76df..5c4db4f6 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -126,16 +126,16 @@ impl NomicFusedMoELayer { let layer = Linear::new(layer_weight, None, None); let gate_weight = vb - .pp("experts.mlp.w1") + .pp("experts.mlp") .get((moe_num_experts * ffn_hidden_size, hidden_size), "w1")? .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? .permute((0, 2, 1))?; let up_weight = vb - .pp("experts.mlp.w2") + .pp("experts.mlp") .get((moe_num_experts * ffn_hidden_size, hidden_size), "w2")? .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? 
.permute((0, 2, 1))?; - let bias = vb.pp("experts.bias").get((hidden_size,), "bias")?; + let bias = vb.pp("experts").get((hidden_size,), "bias")?; let moe_act = match activation { HiddenAct::Silu => candle_moe::Activation::Silu, From 7aa38f719f9b213b4d5bd944dcce6d5f843a3dff Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 15:12:10 +0000 Subject: [PATCH 5/8] update: fused moe --- backends/candle/src/models/flash_nomic.rs | 34 +++++++++++++---------- backends/candle/src/models/nomic.rs | 16 ++++++++--- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index 5c4db4f6..c027cf55 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -108,33 +108,32 @@ pub struct NomicFusedMoELayer { bias: Tensor, fused_moe: FusedMoeForward, top_k: usize, + idx: usize, span: tracing::Span, } impl NomicFusedMoELayer { - pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + pub fn load(vb: VarBuilder, config: &NomicConfig, idx: usize) -> Result { let hidden_size = config.n_embd; let ffn_hidden_size = config.n_inner; - let moe_num_experts = config.num_experts.unwrap(); + let num_experts = config.num_experts.unwrap(); let top_k = config.moe_top_k.unwrap(); let activation = config.activation_function.clone(); let layer_weight = vb .pp("router.layer") - .get((moe_num_experts, hidden_size), "weight")?; + .get((num_experts, hidden_size), "weight")?; let layer = Linear::new(layer_weight, None, None); let gate_weight = vb .pp("experts.mlp") - .get((moe_num_experts * ffn_hidden_size, hidden_size), "w1")? - .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? - .permute((0, 2, 1))?; + .get((num_experts * ffn_hidden_size, hidden_size), "w1")? + .reshape((num_experts, hidden_size, ffn_hidden_size))?; let up_weight = vb .pp("experts.mlp") - .get((moe_num_experts * ffn_hidden_size, hidden_size), "w2")? - .reshape((moe_num_experts, ffn_hidden_size, hidden_size))? - .permute((0, 2, 1))?; + .get((num_experts * ffn_hidden_size, hidden_size), "w2")? 
+ .reshape((num_experts, hidden_size, ffn_hidden_size))?; let bias = vb.pp("experts").get((hidden_size,), "bias")?; let moe_act = match activation { @@ -144,7 +143,7 @@ impl NomicFusedMoELayer { _ => candle::bail!("not supported activation type"), }; - let fused_moe = FusedMoeForward::new(moe_num_experts, top_k, moe_act); + let fused_moe = FusedMoeForward::new(num_experts, top_k, moe_act); Ok(Self { layer, @@ -153,7 +152,8 @@ impl NomicFusedMoELayer { bias, fused_moe, top_k, - span: tracing::span!(tracing::Level::TRACE, "fused_moe"), + idx, + span: tracing::span!(tracing::Level::TRACE, "moe"), }) } @@ -180,7 +180,7 @@ impl NomicFusedMoELayer { let (scores, indices) = self.forward_router(hidden_states)?; let out = self.fused_moe.forward( - &hidden_states, + hidden_states, &self.gate_weight, &self.up_weight, None, @@ -189,7 +189,13 @@ impl NomicFusedMoELayer { 1_u32, )?; - out.broadcast_add(&self.bias) + let out = out.broadcast_add(&self.bias)?; + + if self.idx == 1 { + println!("MoE: {:}", out); + } + + Ok(out) } } @@ -204,7 +210,7 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicFusedMoELayer::load(vb, config)?)) + Ok(Self::MoE(NomicFusedMoELayer::load(vb, config, index)?)) } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 8748db38..3425537e 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -366,18 +366,20 @@ impl NomicExperts { pub struct NomicMoELayer { router: NomicRouter, experts: NomicExperts, + idx: usize, span: tracing::Span, } impl NomicMoELayer { - pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + pub fn load(vb: VarBuilder, config: &NomicConfig, idx: usize) -> Result { let router = NomicRouter::load(vb.pp("router"), config)?; let experts = NomicExperts::load(vb.pp("experts"), config)?; Ok(Self { router, experts, + idx, span: tracing::span!(tracing::Level::TRACE, "moe"), }) } @@ -387,8 +389,14 @@ impl NomicMoELayer { let (top_weights, top_experts) = self.router.forward(hidden_states)?; - self.experts - .forward(hidden_states, &top_weights, &top_experts) + let out = self.experts + .forward(hidden_states, &top_weights, &top_experts)?; + + if self.idx == 1 { + println!("MoE: {:}", out); + } + + Ok(out) } } @@ -403,7 +411,7 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) + Ok(Self::MoE(NomicMoELayer::load(vb, config, index)?)) } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { From d00a1a1027c512785be6002dd139a0df8c273aab Mon Sep 17 00:00:00 2001 From: kozistr Date: Mon, 8 Sep 2025 15:50:21 +0000 Subject: [PATCH 6/8] fix: experts mlp --- backends/candle/src/models/flash_nomic.rs | 22 +++++++++------------- backends/candle/src/models/nomic.rs | 16 ++++------------ 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index c027cf55..af3e2cfe 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -108,13 +108,12 @@ pub struct NomicFusedMoELayer { bias: Tensor, fused_moe: FusedMoeForward, top_k: usize, - idx: usize, span: tracing::Span, } impl NomicFusedMoELayer { 
- pub fn load(vb: VarBuilder, config: &NomicConfig, idx: usize) -> Result { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { let hidden_size = config.n_embd; let ffn_hidden_size = config.n_inner; let num_experts = config.num_experts.unwrap(); @@ -129,11 +128,15 @@ impl NomicFusedMoELayer { let gate_weight = vb .pp("experts.mlp") .get((num_experts * ffn_hidden_size, hidden_size), "w1")? - .reshape((num_experts, hidden_size, ffn_hidden_size))?; + .reshape((num_experts, ffn_hidden_size, hidden_size))? + .permute((0, 2, 1))? + .contiguous()?; let up_weight = vb .pp("experts.mlp") .get((num_experts * ffn_hidden_size, hidden_size), "w2")? - .reshape((num_experts, hidden_size, ffn_hidden_size))?; + .reshape((num_experts, ffn_hidden_size, hidden_size))? + .permute((0, 2, 1))? + .contiguous()?; let bias = vb.pp("experts").get((hidden_size,), "bias")?; let moe_act = match activation { @@ -152,7 +155,6 @@ impl NomicFusedMoELayer { bias, fused_moe, top_k, - idx, span: tracing::span!(tracing::Level::TRACE, "moe"), }) } @@ -189,13 +191,7 @@ impl NomicFusedMoELayer { 1_u32, )?; - let out = out.broadcast_add(&self.bias)?; - - if self.idx == 1 { - println!("MoE: {:}", out); - } - - Ok(out) + out.broadcast_add(&self.bias) } } @@ -210,7 +206,7 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicFusedMoELayer::load(vb, config, index)?)) + Ok(Self::MoE(NomicFusedMoELayer::load(vb, config)?)) } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 3425537e..8748db38 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -366,20 +366,18 @@ impl NomicExperts { pub struct NomicMoELayer { router: NomicRouter, experts: NomicExperts, - idx: usize, span: tracing::Span, } impl NomicMoELayer { - pub fn load(vb: VarBuilder, config: &NomicConfig, idx: usize) -> Result { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { let router = NomicRouter::load(vb.pp("router"), config)?; let experts = NomicExperts::load(vb.pp("experts"), config)?; Ok(Self { router, experts, - idx, span: tracing::span!(tracing::Level::TRACE, "moe"), }) } @@ -389,14 +387,8 @@ impl NomicMoELayer { let (top_weights, top_experts) = self.router.forward(hidden_states)?; - let out = self.experts - .forward(hidden_states, &top_weights, &top_experts)?; - - if self.idx == 1 { - println!("MoE: {:}", out); - } - - Ok(out) + self.experts + .forward(hidden_states, &top_weights, &top_experts) } } @@ -411,7 +403,7 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicMoELayer::load(vb, config, index)?)) + Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { From da9064faea0715361fdf80ad5e92b594bcc0dea1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 9 Sep 2025 22:06:31 +0900 Subject: [PATCH 7/8] refactor: fused moe to nomic --- backends/candle/src/models/flash_nomic.rs | 125 +-------------- backends/candle/src/models/nomic.rs | 178 +++++++++++++++++++++- 2 files changed, 178 insertions(+), 125 deletions(-) diff --git a/backends/candle/src/models/flash_nomic.rs b/backends/candle/src/models/flash_nomic.rs index af3e2cfe..93c8528f 100644 --- 
a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,9 +1,8 @@ use crate::flash_attn::flash_attn_varlen; use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, LayerNorm, Linear}; -use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP, NomicBertMLP}; +use crate::models::nomic::{NomicBertEmbeddings, NomicMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_moe::{apply_topk_softmax_inplace, FusedMoeForward}; use candle_nn::VarBuilder; use candle_rotary::apply_rotary_inplace; use text_embeddings_backend_core::{Batch, ModelType, Pool}; @@ -101,128 +100,6 @@ impl NomicAttention { } } -pub struct NomicFusedMoELayer { - layer: Linear, - gate_weight: Tensor, - up_weight: Tensor, - bias: Tensor, - fused_moe: FusedMoeForward, - top_k: usize, - - span: tracing::Span, -} - -impl NomicFusedMoELayer { - pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { - let hidden_size = config.n_embd; - let ffn_hidden_size = config.n_inner; - let num_experts = config.num_experts.unwrap(); - let top_k = config.moe_top_k.unwrap(); - let activation = config.activation_function.clone(); - - let layer_weight = vb - .pp("router.layer") - .get((num_experts, hidden_size), "weight")?; - let layer = Linear::new(layer_weight, None, None); - - let gate_weight = vb - .pp("experts.mlp") - .get((num_experts * ffn_hidden_size, hidden_size), "w1")? - .reshape((num_experts, ffn_hidden_size, hidden_size))? - .permute((0, 2, 1))? - .contiguous()?; - let up_weight = vb - .pp("experts.mlp") - .get((num_experts * ffn_hidden_size, hidden_size), "w2")? - .reshape((num_experts, ffn_hidden_size, hidden_size))? - .permute((0, 2, 1))? - .contiguous()?; - let bias = vb.pp("experts").get((hidden_size,), "bias")?; - - let moe_act = match activation { - HiddenAct::Silu => candle_moe::Activation::Silu, - HiddenAct::Gelu => candle_moe::Activation::Gelu, - HiddenAct::Relu => candle_moe::Activation::Relu, - _ => candle::bail!("not supported activation type"), - }; - - let fused_moe = FusedMoeForward::new(num_experts, top_k, moe_act); - - Ok(Self { - layer, - gate_weight, - up_weight, - bias, - fused_moe, - top_k, - span: tracing::span!(tracing::Level::TRACE, "moe"), - }) - } - - fn forward_router(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> { - let device = hidden_states.device(); - - let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?; - let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?; - - let (seq_len, _) = weights.shape().dims2()?; - - let topk_weight = Tensor::zeros((seq_len, self.top_k), DType::F32, device)?; - let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; - let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; - - apply_topk_softmax_inplace(&weights, &topk_weight, &topk_indices, &token_expert_indices)?; - - Ok((topk_weight, topk_indices)) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let _enter = self.span.enter(); - - let (scores, indices) = self.forward_router(hidden_states)?; - - let out = self.fused_moe.forward( - hidden_states, - &self.gate_weight, - &self.up_weight, - None, - &scores, - &indices, - 1_u32, - )?; - - out.broadcast_add(&self.bias) - } -} - -pub enum NomicMLP { - MoE(NomicFusedMoELayer), - GatedMLP(NomicBertGatedMLP), - Mlp(NomicBertMLP), -} - -impl NomicMLP { - pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result { - let use_moe = 
matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); - - if use_moe { - Ok(Self::MoE(NomicFusedMoELayer::load(vb, config)?)) - } else if config.activation_function == HiddenAct::Gelu { - Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) - } else { - Ok(Self::GatedMLP(NomicBertGatedMLP::load(vb, config)?)) - } - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - match self { - Self::MoE(layer) => layer.forward(hidden_states), - Self::GatedMLP(layer) => layer.forward(hidden_states), - Self::Mlp(layer) => layer.forward(hidden_states), - } - } -} - struct NomicBertBlock { attention: NomicAttention, mlp: NomicMLP, diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index 8748db38..ab01350d 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -239,6 +239,50 @@ impl NomicRouter { } } +#[cfg(feature = "cuda")] +pub struct NomicFusedRouter { + layer: Linear, + top_k: usize, + + span: tracing::Span, +} + +#[cfg(feature = "cuda")] +impl NomicFusedRouter { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + let num_experts = config.num_experts.unwrap(); + let top_k = config.moe_top_k.unwrap(); + + let layer_weight = vb.pp("layer").get((num_experts, config.n_embd), "weight")?; + let layer = Linear::new(layer_weight, None, None); + + Ok(Self { + layer, + top_k, + span: tracing::span!(tracing::Level::TRACE, "router"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> { + use candle_moe::apply_topk_softmax_inplace; + + let _enter = self.span.enter(); + + let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?; + let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?; + + let (seq_len, _) = weights.shape().dims2()?; + + let topk_weight = Tensor::zeros((seq_len, self.top_k), DType::F32, device)?; + let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; + let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; + + apply_topk_softmax_inplace(&weights, &topk_weight, &topk_indices, &token_expert_indices)?; + + Ok((topk_weight, topk_indices)) + } +} + pub struct NomicExpertMLP { w1: Tensor, w2: Tensor, @@ -363,6 +407,96 @@ impl NomicExperts { } } +#[cfg(feature = "cuda")] +pub struct NomicFusedExperts { + num_experts: usize, + mlp: NomicExpertMLP, + bias: Tensor, + + span: tracing::Span, +} + +#[cfg(feature = "cuda")] +impl NomicFusedExperts { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + let hidden_size = config.n_embd; + let ffn_hidden_size = config.n_inner; + let num_experts = config.num_experts.unwrap(); + let top_k = config.moe_top_k.unwrap(); + let activation = config.activation_function.clone(); + + let gate_weight = vb + .pp("mlp") + .get((num_experts * ffn_hidden_size, hidden_size), "w1")? + .reshape((num_experts, ffn_hidden_size, hidden_size))? + .permute((0, 2, 1))? + .contiguous()?; + let up_weight = vb + .pp("mlp") + .get((num_experts * ffn_hidden_size, hidden_size), "w2")? + .reshape((num_experts, ffn_hidden_size, hidden_size))? + .permute((0, 2, 1))? 
+ .contiguous()?; + + let bias = vb.get((config.n_embd,), "bias")?; + + use candle_moe::{Activation, FusedMoeForward}; + + let moe_act = match activation { + HiddenAct::Silu => candle_moe::Activation::Silu, + HiddenAct::Gelu => candle_moe::Activation::Gelu, + HiddenAct::Relu => candle_moe::Activation::Relu, + _ => candle::bail!("not supported activation type"), + }; + + let fused_moe = FusedMoeForward::new(num_experts, top_k, moe_act); + + Ok(Self { + gate_weight, + up_weight, + bias, + fused_moe, + span: tracing::span!(tracing::Level::TRACE, "experts"), + }) + } + + pub fn forward( + &self, + hidden_states: &Tensor, + top_weights: &Tensor, + top_experts: &Tensor, + ) -> Result { + let _enter = self.span.enter(); + + let dims = hidden_states.dims(); + let ndim = dims.len(); + + let (bs, seq_len, hidden_size) = match ndim { + 3 => (dims[0], dims[1], dims[2]), + 2 => (1, dims[0], dims[1]), + _ => unreachable!(), + }; + + let hidden_states = hidden_states.reshape(((), hidden_size))?; + + let mut out = self.fused_moe.forward( + hidden_states, + &self.gate_weight, + &self.up_weight, + None, // Nomic MoE doesn't need down projection + &top_weights, + &top_experts, + 1_u32, // Nomic MoE + )?; + + if ndim == 3 { + out = out.reshape((bs, seq_len, hidden_size))?; + } + + out.broadcast_add(&self.bias) + } +} + pub struct NomicMoELayer { router: NomicRouter, experts: NomicExperts, @@ -392,8 +526,41 @@ impl NomicMoELayer { } } +#[cfg(feature = "cuda")] +pub struct NomicFusedMoELayer { + router: NomicFusedRouter, + experts: NomicFusedExperts, + + span: tracing::Span, +} + +#[cfg(feature = "cuda")] +impl NomicFusedMoELayer { + pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result { + let router = NomicFusedRouter::load(vb.pp("router"), config)?; + let experts = NomicFusedExperts::load(vb.pp("experts"), config)?; + + Ok(Self { + router, + experts, + span: tracing::span!(tracing::Level::TRACE, "moe"), + }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + + let (top_weights, top_experts) = self.router.forward(hidden_states)?; + + self.experts + .forward(hidden_states, &top_weights, &top_experts) + } +} + pub enum NomicMLP { MoE(NomicMoELayer), + #[cfg(feature = "cuda")] + FusedMoE(NomicFusedMoELayer), GatedMLP(NomicBertGatedMLP), Mlp(NomicBertMLP), } @@ -403,7 +570,14 @@ impl NomicMLP { let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1); if use_moe { - Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) + #[cfg(feature = "cuda")] + { + Ok(Self::FusedMoE(NomicFusedMoELayer::load(vb, config)?)) + } + #[cfg(not(feature = "cuda"))] + { + Ok(Self::MoE(NomicMoELayer::load(vb, config)?)) + } } else if config.activation_function == HiddenAct::Gelu { Ok(Self::Mlp(NomicBertMLP::load(vb, config)?)) } else { @@ -414,6 +588,8 @@ impl NomicMLP { pub fn forward(&self, hidden_states: &Tensor) -> Result { match self { Self::MoE(layer) => layer.forward(hidden_states), + #[cfg(feature = "cuda")] + Self::FusedMoE(layer) => layer.forward(hidden_states), Self::GatedMLP(layer) => layer.forward(hidden_states), Self::Mlp(layer) => layer.forward(hidden_states), } From 4eeb1b7af9ec18c09bf6455d9e2711717810a953 Mon Sep 17 00:00:00 2001 From: kozistr Date: Tue, 9 Sep 2025 14:03:00 +0000 Subject: [PATCH 8/8] update: fused moe --- backends/candle/src/models/flash_nomic.rs | 2 +- backends/candle/src/models/nomic.rs | 26 ++++++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/backends/candle/src/models/flash_nomic.rs 
b/backends/candle/src/models/flash_nomic.rs index 93c8528f..32cd31b6 100644 --- a/backends/candle/src/models/flash_nomic.rs +++ b/backends/candle/src/models/flash_nomic.rs @@ -1,5 +1,5 @@ use crate::flash_attn::flash_attn_varlen; -use crate::layers::{get_cos_sin, get_inv_freqs, HiddenAct, LayerNorm, Linear}; +use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear}; use crate::models::nomic::{NomicBertEmbeddings, NomicMLP}; use crate::models::{Model, NomicConfig}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs index ab01350d..4956eede 100644 --- a/backends/candle/src/models/nomic.rs +++ b/backends/candle/src/models/nomic.rs @@ -3,6 +3,8 @@ use crate::layers::{ }; use crate::models::Model; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +#[cfg(feature = "cuda")] +use candle_moe; use candle_nn::{Embedding, VarBuilder}; use candle_transformers::models::deepseek2::{BincountOp, NonZeroOp, TopKLastDimOp, TopKOutput}; use serde::Deserialize; @@ -264,10 +266,10 @@ impl NomicFusedRouter { } pub fn forward(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> { - use candle_moe::apply_topk_softmax_inplace; - let _enter = self.span.enter(); + let device = hidden_states.device(); + let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?; let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?; @@ -277,7 +279,12 @@ impl NomicFusedRouter { let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?; - apply_topk_softmax_inplace(&weights, &topk_weight, &topk_indices, &token_expert_indices)?; + candle_moe::apply_topk_softmax_inplace( + &weights, + &topk_weight, + &topk_indices, + &token_expert_indices, + )?; Ok((topk_weight, topk_indices)) } @@ -409,9 +416,10 @@ impl NomicExperts { #[cfg(feature = "cuda")] pub struct NomicFusedExperts { - num_experts: usize, - mlp: NomicExpertMLP, + gate_weight: Tensor, + up_weight: Tensor, bias: Tensor, + fused_moe: candle_moe::FusedMoeForward, span: tracing::Span, } @@ -440,8 +448,6 @@ impl NomicFusedExperts { let bias = vb.get((config.n_embd,), "bias")?; - use candle_moe::{Activation, FusedMoeForward}; - let moe_act = match activation { HiddenAct::Silu => candle_moe::Activation::Silu, HiddenAct::Gelu => candle_moe::Activation::Gelu, @@ -449,7 +455,7 @@ impl NomicFusedExperts { _ => candle::bail!("not supported activation type"), }; - let fused_moe = FusedMoeForward::new(num_experts, top_k, moe_act); + let fused_moe = candle_moe::FusedMoeForward::new(num_experts, top_k, moe_act); Ok(Self { gate_weight, @@ -480,10 +486,10 @@ impl NomicFusedExperts { let hidden_states = hidden_states.reshape(((), hidden_size))?; let mut out = self.fused_moe.forward( - hidden_states, + &hidden_states, &self.gate_weight, &self.up_weight, - None, // Nomic MoE doesn't need down projection + None, &top_weights, &top_experts, 1_u32, // Nomic MoE
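// ---------------------------------------------------------------------------
// Illustrative reference, not part of the patches above: a host-side sketch of
// what the fused CUDA path is expected to compute for Nomic-style MoE, i.e.
// router logits -> softmax -> top-k, then per expert act(x · W1ᵀ) · W2,
// weighted by the routing scores and shifted by the shared `experts.bias`.
// Output from a check like this can be compared against `NomicFusedMoELayer`.
// Assumptions (not taken from the diffs): the standalone crate names
// `candle_core`/`candle_nn` (the workspace renames `candle-core` to `candle`),
// SiLU as the expert activation, and the helper name `reference_moe`, which is
// made up for this sketch. `w1`/`w2` are kept in the (num_experts, ffn, hidden)
// layout obtained right after the reshape in the diffs; the fused kernel
// consumes the permuted (num_experts, hidden, ffn) layout instead.
use candle_core::{DType, IndexOp, Result, Tensor};
use candle_nn::ops::softmax_last_dim;

fn reference_moe(
    x: &Tensor,      // (tokens, hidden)
    router: &Tensor, // (num_experts, hidden), `router.layer.weight`
    w1: &Tensor,     // (num_experts, ffn, hidden), reshaped `experts.mlp.w1`
    w2: &Tensor,     // (num_experts, ffn, hidden), reshaped `experts.mlp.w2`
    bias: &Tensor,   // (hidden,), `experts.bias`
    top_k: usize,
) -> Result<Tensor> {
    // Routing probabilities in f32, matching the dtype the router casts to.
    let probs = softmax_last_dim(&x.matmul(&router.t()?)?.to_dtype(DType::F32)?)?;
    let probs = probs.to_vec2::<f32>()?;
    let (tokens, hidden) = x.dims2()?;

    let mut rows = Vec::with_capacity(tokens);
    for (t, row) in probs.iter().enumerate() {
        // Plain host-side top-k; the CUDA path does this with
        // candle_moe::apply_topk_softmax_inplace into preallocated buffers.
        let mut order: Vec<usize> = (0..row.len()).collect();
        order.sort_by(|a, b| row[*b].total_cmp(&row[*a]));

        let xt = x.narrow(0, t, 1)?; // (1, hidden)
        let mut acc = Tensor::zeros((1, hidden), x.dtype(), x.device())?;
        for &e in order.iter().take(top_k) {
            let h = xt.matmul(&w1.i(e)?.t()?)?.silu()?; // (1, ffn)
            let o = h.matmul(&w2.i(e)?)?;               // (1, hidden)
            acc = (acc + o.affine(row[e] as f64, 0.0)?)?;
        }
        rows.push(acc);
    }

    // Nomic MoE adds one shared bias after mixing the experts.
    Tensor::cat(&rows, 0)?.broadcast_add(bias)
}
// Note: in the patch series the fused path only exists behind the `cuda`
// feature (candle-moe is pulled in via `dep:candle-moe` there), so the
// unfused `NomicMoELayer` remains the behaviour on other backends.
// ---------------------------------------------------------------------------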