File tree Expand file tree Collapse file tree 2 files changed +14
-5
lines changed Expand file tree Collapse file tree 2 files changed +14
-5
lines changed Original file line number Diff line number Diff line change @@ -554,16 +554,18 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
554554
555555 bool found = true ;
556556 for (uint32_t i = 0 ; i < n_tokens; i++) {
557- // TODO: improve to accept cells that are masked by the SWA
558- // if (!cells.is_empty(head_cur + i)) {
559-
560557 const llama_seq_id seq_id = ubatch.seq_id [i][0 ];
561558
559+ // can we use this cell? either:
560+ // - the cell is empty
561+ // - the cell is occupied only by the same sequence, and the sequence is not masked
562562 const bool can_use =
563563 cells.is_empty (head_cur + i) ||
564564 (
565- cells.seq_has (head_cur + i, seq_id) && // TODO: seq_has_only
566- is_masked_swa (cells.pos_get (head_cur + i), ubatch.seq_pos_min [seq_id])
565+ cells.pos_get (head_cur + i) <= ubatch.pos [i] && // causal mask
566+ cells.seq_has (head_cur + i, seq_id) && // sequence mask
567+ cells.seq_count (head_cur + i) == 1 &&
568+ is_masked_swa (cells.pos_get (head_cur + i), ubatch.seq_pos_min [seq_id]) // SWA mask
567569 );
568570
569571 if (!can_use) {
Original file line number Diff line number Diff line change @@ -155,6 +155,13 @@ class llama_kv_cells_unified {
155155 return false ;
156156 }
157157
158+ int seq_count (uint32_t i) const {
159+ assert (i < pos.size ());
160+ assert (pos[i] != -1 );
161+
162+ return seq[i].count ();
163+ }
164+
158165 bool seq_has (uint32_t i, llama_seq_id seq_id) const {
159166 assert (i < pos.size ());
160167 assert (seq_id >= 0 );
You can’t perform that action at this time.
0 commit comments