@@ -167,6 +167,8 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
167167 llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168168 state_base = kv->get_base ()->init_full ();
169169 state_swa = kv->get_swa ()->init_full ();
170+
171+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
170172}
171173
172174llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
@@ -176,22 +178,7 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
176178 state_base = kv->get_base ()->init_update (lctx, optimize);
177179 state_swa = kv->get_swa ()->init_update (lctx, optimize);
178180
179- // TODO: this is very ugly - how to make it simpler?
180- // the llama_memory_status enum is not very well designed
181- if (state_base->get_status () != LLAMA_MEMORY_STATUS_SUCCESS && state_base->get_status () != LLAMA_MEMORY_STATUS_NO_UPDATE) {
182- status = state_base->get_status ();
183- return ;
184- }
185-
186- if (state_swa->get_status () != LLAMA_MEMORY_STATUS_SUCCESS && state_swa->get_status () != LLAMA_MEMORY_STATUS_NO_UPDATE) {
187- status = state_swa->get_status ();
188- return ;
189- }
190-
191- if (state_base->get_status () == LLAMA_MEMORY_STATUS_NO_UPDATE && state_swa->get_status () == LLAMA_MEMORY_STATUS_NO_UPDATE) {
192- status = LLAMA_MEMORY_STATUS_NO_UPDATE;
193- return ;
194- }
181+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
195182}
196183
197184llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
@@ -200,13 +187,15 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200187 std::vector<uint32_t > heads_base,
201188 std::vector<uint32_t > heads_swa,
202189 std::vector<llama_ubatch> ubatches)
203- : status(LLAMA_MEMORY_STATUS_SUCCESS),
204- sbatch(std::move(sbatch)),
205- ubatches(std::move(ubatches)) {
206- // note: here we copy the ubatches. not sure if this is ideal
207- state_base.reset (new llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
208- state_swa .reset (new llama_kv_cache_unified_state (kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
209- }
190+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
191+ sbatch(std::move(sbatch)),
192+ ubatches(std::move(ubatches)) {
193+ // note: here we copy the ubatches. not sure if this is ideal
194+ state_base.reset (new llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
195+ state_swa .reset (new llama_kv_cache_unified_state (kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
196+
197+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
198+ }
210199
211200llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state () = default ;
212201
@@ -246,6 +235,7 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
246235
// Returns a const reference to the ubatch at the current position (ubatches[i_next]).
// Precondition: status must be LLAMA_MEMORY_STATUS_SUCCESS; this is checked with assert.
// NOTE(review): no bounds check on i_next is visible here - presumably the caller keeps it within ubatches; confirm.
247236const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch () const {
248237 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
238+
249239 return ubatches[i_next];
250240}
251241
0 commit comments