@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
123123
124124 assert (heads_base.size () == heads_swa.size ());
125125
126- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
126+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
127127 this , std::move (sbatch), std::move (heads_base), std::move (heads_swa), std::move (ubatches));
128128}
129129
130130llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full () {
131- return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this );
131+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this );
132132}
133133
134- bool llama_kv_cache_unified_iswa::update (llama_context & lctx) {
135- bool res = false ;
136-
137- res = res | kv_base->update (lctx);
138- res = res | kv_swa ->update (lctx);
139-
140- return res;
141- }
142-
143- void llama_kv_cache_unified_iswa::defrag_sched (float thold) {
144- kv_base->defrag_sched (thold);
145- kv_swa ->defrag_sched (thold);
134+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update (llama_context * lctx, bool optimize) {
135+ return std::make_unique<llama_kv_cache_unified_iswa_state>(this , lctx, optimize);
146136}
147137
148138bool llama_kv_cache_unified_iswa::get_can_shift () const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
174164llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (llama_memory_status status) : status(status) {}
175165
176166llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
177- llama_memory_status status,
178- llama_kv_cache_unified_iswa * kv) : status(status) {
179- state_base.reset (new llama_kv_cache_unified_state (status, kv->get_base ()));
180- state_swa .reset (new llama_kv_cache_unified_state (status, kv->get_swa ()));
167+ llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168+ state_base = kv->get_base ()->init_full ();
169+ state_swa = kv->get_swa ()->init_full ();
170+
171+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
172+ }
173+
174+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
175+ llama_kv_cache_unified_iswa * kv,
176+ llama_context * lctx,
177+ bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
178+ state_base = kv->get_base ()->init_update (lctx, optimize);
179+ state_swa = kv->get_swa ()->init_update (lctx, optimize);
180+
181+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
181182}
182183
183184llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
184- llama_memory_status status,
185185 llama_kv_cache_unified_iswa * kv,
186186 llama_sbatch sbatch,
187187 std::vector<uint32_t > heads_base,
188188 std::vector<uint32_t > heads_swa,
189189 std::vector<llama_ubatch> ubatches)
190- : status(status),
191- sbatch(std::move(sbatch)),
192- ubatches(std::move(ubatches)) {
193- // note: here we copy the ubatches. not sure if this is ideal
194- state_base.reset (new llama_kv_cache_unified_state (status, kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195- state_swa .reset (new llama_kv_cache_unified_state (status, kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
196- }
190+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
191+ sbatch(std::move(sbatch)),
192+ ubatches(std::move(ubatches)) {
193+ // note: here we copy the ubatches. not sure if this is ideal
194+ state_base.reset (new llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195+ state_swa .reset (new llama_kv_cache_unified_state (kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
196+
197+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
198+ }
197199
198200llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state () = default ;
199201
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
233235
234236const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch () const {
235237 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
238+
236239 return ubatches[i_next];
237240}
238241
239242const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base () const {
240243 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
241244
242- return state_base.get ();
245+ return static_cast < const llama_kv_cache_unified_state *>( state_base.get () );
243246}
244247
245248const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa () const {
246249 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
247250
248- return state_swa.get ();
251+ return static_cast < const llama_kv_cache_unified_state *>( state_swa.get () );
249252}
0 commit comments