66#include  < bitset> 
77#include  < cassert> 
88#include  < vector> 
9+ #include  < set> 
910
1011//  meta information about KV cells that can be part of multiple sequences at the same time
1112//  TODO: add unit tests
@@ -18,8 +19,13 @@ class llama_kv_cells_unified {
1819            seq[i].reset ();
1920        }
2021
21-         used      = 0 ;
2222        has_shift = false ;
23+ 
24+         used.clear ();
25+ 
26+         for  (uint32_t  s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
27+             seq_pos[s].clear ();
28+         }
2329    }
2430
2531    void  reset_shift () {
@@ -50,7 +56,25 @@ class llama_kv_cells_unified {
5056    }
5157
5258    uint32_t  get_used () const  {
53-         return  used;
59+         return  used.size ();
60+     }
61+ 
62+     //  the index of the first cell that is used
63+     //  return 0 if no cells are used
64+     uint32_t  used_min () const  {
65+         return  used.empty () ? 0  : *used.begin ();
66+     }
67+ 
68+     //  the index of the last cell that is used + 1
69+     //  return 0 if no cells are used
70+     uint32_t  used_max_p1 () const  {
71+ #if  0 
72+         if (!seq_pos[0].empty()) printf("kv_cells: min[0] = %5d, max[0] = %5d\n", *seq_pos[0].begin(), *seq_pos[0].rbegin());
73+         if (!seq_pos[1].empty()) printf("kv_cells: min[1] = %5d, max[1] = %5d\n", *seq_pos[1].begin(), *seq_pos[1].rbegin());
74+         if (!seq_pos[2].empty()) printf("kv_cells: min[2] = %5d, max[2] = %5d\n", *seq_pos[2].begin(), *seq_pos[2].rbegin());
75+ #endif 
76+ 
77+         return  used.empty () ? 0  : *used.rbegin () + 1 ;
5478    }
5579
5680    bool  get_has_shift () const  {
@@ -69,6 +93,9 @@ class llama_kv_cells_unified {
6993        pos  [isrc] = -1 ;
7094        shift[isrc] =  0 ;
7195        seq  [isrc].reset ();
96+ 
97+         used.erase  (isrc);
98+         used.insert (idst);
7299    }
73100
74101    //  copy the state of cells [i, i + n) (used for save/restore the state of the cells)
@@ -95,16 +122,24 @@ class llama_kv_cells_unified {
95122
96123        for  (uint32_t  j = 0 ; j < other.pos .size (); ++j) {
97124            if  (pos[i + j] == -1  && other.pos [j] != -1 ) {
98-                 used++ ;
125+                 used. insert (i + j) ;
99126            }
100127
101128            if  (pos[i + j] != -1  && other.pos [j] == -1 ) {
102-                 used--;
129+                 used.erase (i + j);
130+             }
131+ 
132+             if  (pos[i + j] != -1 ) {
133+                 seq_pos_rm (i + j);
103134            }
104135
105136            pos[i + j] = other.pos [j];
106137            seq[i + j] = other.seq [j];
107138
139+             if  (pos[i + j] != -1 ) {
140+                 seq_pos_add (i + j);
141+             }
142+ 
108143            assert (shift[i + j] == 0 );
109144        }
110145    }
@@ -118,11 +153,12 @@ class llama_kv_cells_unified {
118153        assert (seq_id >= 0 );
119154
120155        seq[i].reset (seq_id);
156+         seq_pos[seq_id].erase (pos[i]);
121157
122158        if  (seq[i].none ()) {
123159            pos[i] = -1 ;
124160
125-             used-- ;
161+             used. erase (i) ;
126162
127163            return  true ;
128164        }
@@ -135,17 +171,22 @@ class llama_kv_cells_unified {
135171        assert (i < pos.size ());
136172
137173        if  (seq[i].test (seq_id)) {
174+             seq_pos_rm (i);
138175            seq[i].reset ();
176+ 
139177            seq[i].set (seq_id);
178+             seq_pos[seq_id].insert (pos[i]);
140179
141180            return  false ;
142181        }
143182
144183        if  (seq[i].any ()) {
184+             seq_pos_rm (i);
145185            seq[i].reset ();
186+ 
146187            pos[i] = -1 ;
147188
148-             used-- ;
189+             used. erase (i) ;
149190
150191            return  true ;
151192        }
@@ -169,6 +210,29 @@ class llama_kv_cells_unified {
169210        assert (!seq[i].test (seq_id));
170211
171212        seq[i].set (seq_id);
213+         seq_pos[seq_id].insert (pos[i]);
214+     }
215+ 
216+     llama_pos seq_pos_min (llama_seq_id seq_id) const  {
217+         assert (seq_id >= 0 );
218+         assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
219+ 
220+         if  (seq_pos[seq_id].empty ()) {
221+             return  -1 ;
222+         }
223+ 
224+         return  *seq_pos[seq_id].begin ();
225+     }
226+ 
227+     llama_pos seq_pos_max (llama_seq_id seq_id) const  {
228+         assert (seq_id >= 0 );
229+         assert (seq_id < LLAMA_MAX_PARALLEL_SEQUENCES);
230+ 
231+         if  (seq_pos[seq_id].empty ()) {
232+             return  -1 ;
233+         }
234+ 
235+         return  *seq_pos[seq_id].rbegin ();
172236    }
173237
174238    //  note: call only if the cell is not empty
@@ -202,7 +266,8 @@ class llama_kv_cells_unified {
202266        assert (pos[i] == -1 );
203267
204268        pos[i] = p;
205-         used++;
269+ 
270+         used.insert (i);
206271    }
207272
208273    //  pos[i] = pos[i] + d
@@ -218,10 +283,12 @@ class llama_kv_cells_unified {
218283        has_shift = true ;
219284
220285        if  (pos[i] < 0 ) {
221-             pos[i] = -1 ;
286+             seq_pos_rm (i);
287+ 
222288            seq[i].reset ();
289+             pos[i] = -1 ;
223290
224-             used-- ;
291+             used. erase (i) ;
225292
226293            return  true ;
227294        }
@@ -245,10 +312,11 @@ class llama_kv_cells_unified {
245312    }
246313
247314private: 
248-     uint32_t  used = 0 ; //  used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
249- 
250315    bool  has_shift = false ;
251316
317+     //  set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
318+     std::set<uint32_t > used;
319+ 
252320    std::vector<llama_pos> pos;
253321
254322    //  this array accumulates any applied shifts to the pos array since the last reset_shift() call
@@ -268,6 +336,32 @@ class llama_kv_cells_unified {
268336    // 
269337    std::vector<llama_pos> shift;
270338
271-     std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
272- };
339+     using  bits_t  = std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>;
273340
341+     //  the bitset seq[i] tells us which sequences are currently occupying the i-th cell
342+     std::vector<bits_t > seq;
343+ 
344+     //  the set seq_pos[s] tells us which positions are currently occupied by the s-th sequence
345+     //  this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
346+     std::set<llama_pos> seq_pos[LLAMA_MAX_PARALLEL_SEQUENCES];
347+ 
348+     //  helper functions for updating `seq_pos`, once cell at a time:
349+ 
350+     //  remove cell i
351+     void  seq_pos_rm (uint32_t  i) {
352+         for  (int  s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
353+             if  (seq[i].test (s)) {
354+                 seq_pos[s].erase (pos[i]);
355+             }
356+         }
357+     }
358+ 
359+     //  add cell i
360+     void  seq_pos_add (uint32_t  i) {
361+         for  (int  s = 0 ; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
362+             if  (seq[i].test (s)) {
363+                 seq_pos[s].insert (pos[i]);
364+             }
365+         }
366+     }
367+ };
0 commit comments