@@ -6931,15 +6931,14 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

+        // NOTE: not sure what the difference is between the sequence length and the batch size in the paper.
         // {n_embd, batch}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            // NOTE: the conv_state is transposed to ease shifting it.
-            // if you figured out a way to shift it without transposing it like this, go ahead and fix this.
-            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
             ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
@@ -6948,33 +6947,32 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);

-            // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
-            struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
+            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
             // assuming it's contiguous
-            // FIXME: handle batches of more than 1 token
-            struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
+            // {d_inner, batch}
+            struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
+            struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

             cur = x;

             // conv
             {
                 // shift conv state left
-                conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
+                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);

                 // update last column
-                conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
+                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
+                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));

                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));

                 // rearrange and sum
-                conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
-                // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
-                conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
-
-                // --> {1, d_inner}
+                // no need to rearrange the conv_state, since it's already in the right shape
+                // => {1, d_inner}
                 x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
+                // => {d_inner, 1}
                 x = ggml_transpose(ctx0, x);

                 // bias
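
Note on the corrected shape comment: ggml_mul_mat(a, b) contracts over ne[0] of both operands and returns a tensor of shape {a->ne[1], b->ne[1]}, so with the ssm_in weight first and the activations second the result is {2*d_inner, batch}, and x and z can then be carved out of it with 2D views along dim 0. The following is a minimal standalone sketch of that convention and of the view-based split; it is not part of this commit, the sizes (n_embd, d_inner, n_batch) are made up for illustration, and it assumes a plain CPU build of ggml.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    const int n_embd  = 8;
    const int d_inner = 6;
    const int n_batch = 2;

    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // weight {n_embd, 2*d_inner} and activations {n_embd, batch}, as in the layer above
    struct ggml_tensor * ssm_in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 2*d_inner);
    struct ggml_tensor * cur    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_batch);

    // both operands share ne[0] (= n_embd); the result is {2*d_inner, batch}
    struct ggml_tensor * xz = ggml_mul_mat(ctx, ssm_in, cur);

    // split along dim 0: the first d_inner elements of each column are x, the next d_inner are z
    struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0);
    struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1],
            ggml_element_size(xz)*d_inner);

    // shapes are known at graph-build time, no compute is needed to inspect them
    printf("xz: %lld x %lld\n", (long long) xz->ne[0], (long long) xz->ne[1]);
    printf("x : %lld x %lld\n", (long long) x->ne[0],  (long long) x->ne[1]);
    printf("z : %lld x %lld\n", (long long) z->ne[0],  (long long) z->ne[1]);

    ggml_free(ctx);
    return 0;
}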
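
The other change is the conv state layout: it is now viewed as {d_conv, d_inner}, one row of the last d_conv inputs per channel, so the "shift left and append" step can be done with ggml_view_2d/ggml_set_2d directly instead of transposing the state as the removed comment described. Below is a standalone sketch of that shift-and-append pattern on a tiny state; it is not part of this commit, the sizes are made up, and the new column is built directly as {1, d_inner} instead of transposing x as the code above does.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    const int d_conv  = 4;
    const int d_inner = 3;

    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    // conv state as {d_conv, d_inner}: row i holds the last d_conv inputs of channel i
    struct ggml_tensor * conv_state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
    // the incoming value for each channel, already shaped as a column {1, d_inner}
    struct ggml_tensor * x_col      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, d_inner);

    for (int i = 0; i < d_conv*d_inner; ++i) ggml_set_f32_1d(conv_state, i, (float) i);
    for (int i = 0; i < d_inner;        ++i) ggml_set_f32_1d(x_col, i, 100.0f + i);

    const size_t es = ggml_element_size(conv_state);

    // shift each row left by one: copy elements 1..d_conv-1 onto positions 0..d_conv-2
    struct ggml_tensor * shifted = ggml_set_2d(ctx, conv_state,
            ggml_view_2d(ctx, conv_state, d_conv - 1, d_inner, conv_state->nb[1], es*1),
            conv_state->nb[1], 0);

    // append the new column at position d_conv-1 of every row
    struct ggml_tensor * updated = ggml_set_2d(ctx, shifted, x_col,
            conv_state->nb[1], es*(d_conv - 1));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, updated);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // each row should now read: old values shifted left, with 100+i in the last slot
    for (int r = 0; r < d_inner; ++r) {
        for (int c = 0; c < d_conv; ++c) {
            printf("%6.1f", ggml_get_f32_1d(updated, r*d_conv + c));
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}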