Move the KV cache counting helpers (CountCells / CountTokens) off LLamaKvCacheViewSafeHandle
and onto SafeLLamaContextHandle, and add KvCacheGetDebugView and KvCacheSequenceDivide wrappers.

NOTE(review): line breaks, indentation and XML doc tags were reconstructed from a
whitespace-collapsed copy of this patch. Confirm context and removed lines (including the
param names in the first hunk) against the target files before applying.
NOTE(review): doc text on added lines of SafeLLamaContextHandle.cs was edited, so the blob ids
on that file's "index" line may no longer describe the resulting content.

diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index 4cccd13c5..2b4087720 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -74,7 +74,7 @@ public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView v
         }
 
         /// <summary>
-        /// Allocate a new llama_kv_cache_view_free
+        /// Allocate a new KV cache view which can be used to inspect the KV cache
         /// </summary>
         /// <param name="ctx"></param>
         /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
@@ -102,24 +102,6 @@ public void Update()
             NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
         }
 
-        /// <summary>
-        /// Count the number of used cells in the KV cache
-        /// </summary>
-        /// <returns></returns>
-        public int CountCells()
-        {
-            return NativeApi.llama_get_kv_cache_used_cells(_ctx);
-        }
-
-        /// <summary>
-        /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
-        /// </summary>
-        /// <returns></returns>
-        public int CountTokens()
-        {
-            return NativeApi.llama_get_kv_cache_token_count(_ctx);
-        }
-
         /// <summary>
         /// Get the raw KV cache view
         /// </summary>
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index d90d46d5b..91e82c85d 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -289,6 +289,35 @@ public void SetThreads(uint threads, uint threadsBatch)
         }
 
         #region KV Cache Management
+        /// <summary>
+        /// Allocate a new KV cache view for this context, which can be used to debug the KV cache
+        /// </summary>
+        /// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
+        /// <returns>The view handle produced by LLamaKvCacheViewSafeHandle.Allocate for this context</returns>
+        public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
+        {
+            return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences);
+        }
+
+        /// <summary>
+        /// Count the number of used cells in the KV cache (i.e. cells that have at least one sequence assigned to them)
+        /// </summary>
+        /// <returns>The value reported by the native <c>llama_get_kv_cache_used_cells</c> call for this context</returns>
+        public int KvCacheCountCells()
+        {
+            return NativeApi.llama_get_kv_cache_used_cells(this);
+        }
+
+        /// <summary>
+        /// Returns the number of tokens in the KV cache (slow, use only for debug)
+        /// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
+        /// </summary>
+        /// <returns>The value reported by the native <c>llama_get_kv_cache_token_count</c> call for this context</returns>
+        public int KvCacheCountTokens()
+        {
+            return NativeApi.llama_get_kv_cache_token_count(this);
+        }
+
         /// <summary>
         /// Clear the KV cache
         /// </summary>
@@ -344,6 +373,21 @@ public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLama
         {
             NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
         }
+
+        /// <summary>
+        /// Integer division of the positions by <paramref name="divisor"/> (a factor expected to be > 1)
+        /// If the KV cache is RoPEd, the KV data is updated accordingly
+        /// p0 &lt; 0 : [0, p1]
+        /// p1 &lt; 0 : [p0, inf)
+        /// </summary>
+        /// <param name="seq">The sequence id forwarded to the native <c>llama_kv_cache_seq_div</c> call</param>
+        /// <param name="p0">Start of the position range; a negative value selects [0, p1]</param>
+        /// <param name="p1">End of the position range; a negative value selects [p0, inf)</param>
+        /// <param name="divisor">The factor the positions are divided by; documented as greater than 1</param>
+        public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
+        {
+            NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
+        }
         #endregion
     }
 }