@@ -1862,9 +1862,6 @@ static bool llama_kv_cache_init(
     if (model.arch == LLM_ARCH_MAMBA) {
         // only one slot is needed for Mamba
        n_ctx = 1;
-        // it's probably best to keep as much precision as possible for the states
-        ktype = GGML_TYPE_F32;
-        vtype = GGML_TYPE_F32;
     }
 
     cache.has_shift = false;
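For context: attention models need one K/V entry per past token, so their cache grows with n_ctx, while Mamba's recurrent state has a fixed size, which is why a single slot is enough. A rough sketch of the two footprints (the helper names and the F32 assumption are illustrative, not llama.cpp API):

```c
#include <stddef.h>

// Attention KV cache: one K and one V row per token, so it grows with n_ctx.
size_t attn_kv_bytes(size_t n_ctx, size_t n_layer, size_t n_embd) {
    return n_ctx * n_layer * 2 * n_embd * sizeof(float);
}

// Mamba state: a fixed-size conv state and ssm state per layer, so a single
// cache slot (n_ctx = 1) suffices no matter how long the sequence gets.
size_t mamba_state_bytes(size_t n_layer, size_t d_conv, size_t d_state, size_t d_inner) {
    return n_layer * ((d_conv - 1) + d_state) * d_inner * sizeof(float);
}
```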
@@ -4179,7 +4176,7 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k;
+                    const int64_t d_conv = hparams.n_embd_head_k + 1;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
@@ -6917,28 +6914,27 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        const bool use_conv = batch.n_tokens > 1;
-        GGML_ASSERT(use_conv == false); // TODO: implement
+        const int32_t n_tok = batch.n_tokens;
 
         // hopefully the compiler does constant folding
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
-        // {n_embd, batch}
+        // {n_embd, n_tok}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
             // norm
@@ -6947,33 +6943,43 @@ struct llm_build_context {
69476943 LLM_NORM_RMS, cb, il);
69486944 cb (cur, " attn_norm" , il);
69496945
6950- // {n_embd, 2*d_inner} * {n_embd, batch } = {2*d_inner, batch }
6946+ // {n_embd, 2*d_inner} * {n_embd, n_tok } => {2*d_inner, n_tok }
69516947 struct ggml_tensor * xz = ggml_mul_mat (ctx0, model.layers [il].ssm_in , cur);
69526948 // split the above in two
6953- // assuming it's contiguous
6954- // {d_inner, batch}
6949+ // => {d_inner, n_tok}
69556950 struct ggml_tensor * x = ggml_view_2d (ctx0, xz, d_inner, xz->ne [1 ], xz->nb [1 ], 0 );
69566951 struct ggml_tensor * z = ggml_view_2d (ctx0, xz, d_inner, xz->ne [1 ], xz->nb [1 ], ggml_element_size (xz)*d_inner);
69576952
6958- cur = x;
6959-
69606953 // conv
69616954 {
6962- // shift conv state left
6963- conv_state = ggml_set_2d (ctx0, conv_state, ggml_view_2d (ctx0, conv_state, (d_conv - 1 ), d_inner, conv_state->nb [1 ], ggml_element_size (conv_state)*1 ), conv_state->nb [1 ], 0 );
6964-
6965- // update last column
6966- // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
6967- conv_state = ggml_set_2d (ctx0, conv_state, ggml_cont (ctx0, ggml_transpose (ctx0, x)), conv_state->nb [1 ], ggml_element_size (conv_state)*(d_conv - 1 ));
6968-
6969- ggml_build_forward_expand (gf, ggml_cpy (ctx0, conv_state, ggml_view_tensor (ctx0, kv_self.k_l [il])));
6970-
6971- // rearrange and sum
6972- // no need to rearrange the conv_state, since it's already in the right shape
6973- // => {1, d_inner}
6974- x = ggml_sum_rows (ctx0, ggml_mul (ctx0, conv_state, model.layers [il].ssm_conv1d ));
6975- // => {d_inner, 1}
6976- x = ggml_transpose (ctx0, x);
6955+ // concat last (d_conv - 1) columns of conv_state, and x
6956+
+                // The following tensor is intentionally oversized, to avoid an assertion error when making an overlapping view.
+                // TODO: in ggml_new_tensor_impl, handle overlapping data ranges in the data size calculation
+                // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner},
+                // which is about (d_conv-1) times smaller than its current size.
+                struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
+                const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+
+                conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
+                // unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
+                conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
+
+                // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
+                ggml_build_forward_expand(gf,
+                    ggml_cpy(ctx0,
+                        ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
+                        ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // prepare convolution for all tokens in the batch with a self-overlapping view
+                // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
+                conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
+
+                // perform convolution
+                // => {1, d_inner, n_tok}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
+                // => {d_inner, n_tok, 1}
+                x = ggml_permute(ctx0, x, 2, 0, 1, 3);
 
                 // bias
                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
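What the self-overlapping view computes: for each channel, a causal depthwise 1-D convolution over the last d_conv inputs, with the tail of conv_x doubling as the next conv_state. A scalar reference of the same computation, as a sketch with assumed array layouts (this is not the ggml code above):

```c
// Scalar sketch of the convolution done above with ggml_mul + ggml_sum_rows
// over a self-overlapping view. Names and layouts are assumptions for
// illustration; the real code operates on ggml tensors.
void conv_step(int d_conv, int d_inner, int n_tok,
               float *conv_state,   // [d_inner][d_conv - 1], history per channel
               const float *x,      // [n_tok][d_inner], post-projection inputs
               const float *weight, // [d_inner][d_conv], ssm_conv1d weights
               float *out) {        // [n_tok][d_inner], pre-bias result
    for (int i = 0; i < d_inner; i++) {
        // one conv_x row: previous (d_conv - 1) inputs, then this channel's new inputs
        float row[d_conv - 1 + n_tok]; // C99 VLA; assumes n_tok is modest
        for (int c = 0; c < d_conv - 1; c++) {
            row[c] = conv_state[i*(d_conv - 1) + c];
        }
        for (int t = 0; t < n_tok; t++) {
            row[d_conv - 1 + t] = x[t*d_inner + i];
        }
        // each token's window starts one element further into the row,
        // which is exactly what the overlapping view expresses with strides
        for (int t = 0; t < n_tok; t++) {
            float sum = 0.0f;
            for (int c = 0; c < d_conv; c++) {
                sum += row[t + c] * weight[i*d_conv + c];
            }
            out[t*d_inner + i] = sum;
        }
        // the last (d_conv - 1) inputs become the next conv_state
        for (int c = 0; c < d_conv - 1; c++) {
            conv_state[i*(d_conv - 1) + c] = row[n_tok + c];
        }
    }
}
```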
@@ -6983,23 +6989,24 @@ struct llm_build_context {
 
             // ssm
             {
-                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
-                // FIXME: handle batches of more than 1 token
-                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
-                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
-                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
-                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                // split
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
                 dt = ggml_soft_plus(ctx0, dt);
 
+                // FIXME: support batches with more than 1 token
                 // => {d_state, d_inner}
-                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
 
                 // => {d_state, d_inner}
-                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
 
                 // => {d_state, d_inner}
                 cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
@@ -7014,7 +7021,7 @@ struct llm_build_context {
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
-                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }
 
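For a single token (the n_tok == 1 case the FIXME refers to), the ops above amount to the usual selective-SSM update. A scalar sketch with assumed layouts, soft_plus and SiLU written out inline; note the ssm_state update itself happens between the hunks shown here, and the final ssm_out projection is omitted:

```c
#include <math.h>

// Scalar sketch of the selective state-space update built above, for one
// token. Array layouts and names are assumptions for illustration.
void ssm_step(int d_inner, int d_state,
              const float *x,   // [d_inner]  post-conv activations
              const float *z,   // [d_inner]  gate branch
              const float *dt,  // [d_inner]  after ssm_dt matmul + ssm_dt_b
              const float *A,   // [d_inner][d_state]  ssm_a
              const float *B,   // [d_state]  input-dependent, from x_db
              const float *C,   // [d_state]  input-dependent, from x_db
              const float *D,   // [d_inner]  ssm_d
              float *ssm_state, // [d_inner][d_state]  carried across tokens
              float *y) {       // [d_inner]
    for (int i = 0; i < d_inner; i++) {
        const float dt_i = logf(1.0f + expf(dt[i])); // ggml_soft_plus
        float acc = 0.0f;
        for (int s = 0; s < d_state; s++) {
            const float dA = expf(dt_i * A[i*d_state + s]); // discretized A
            const float dB = dt_i * B[s];                   // discretized B
            const float st = ssm_state[i*d_state + s]*dA + dB*x[i];
            ssm_state[i*d_state + s] = st;
            acc += st * C[s];                               // y = state . C
        }
        acc += D[i] * x[i];                                 // skip connection
        y[i] = acc * (z[i] / (1.0f + expf(-z[i])));         // gate with SiLU(z)
    }
}
```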
@@ -10722,8 +10729,15 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    const ggml_type type_k = params.type_k;
-    const ggml_type type_v = params.type_v;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba (mis)uses the KV cache to store its states
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
+        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+    }
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
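Why the override is needed rather than optional: the state updates above rely on ggml_set and ggml_mul over individual columns, which only works for types storing one value per block. A quick check using the same ggml_blck_size query as the asserts in this hunk (assuming a standard ggml build):

```c
#include "ggml.h"
#include <stdio.h>

int main(void) {
    // F32 stores one value per block, so single columns of conv_state can be
    // rewritten in place; quantized types pack many values per block and can't.
    printf("blck_size(F32)  = %d\n", ggml_blck_size(GGML_TYPE_F32));  // 1
    printf("blck_size(Q4_0) = %d\n", ggml_blck_size(GGML_TYPE_Q4_0)); // 32
    return 0;
}
```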