Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions LLama/Native/LLamaKvCacheView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView v
}

/// <summary>
/// Allocate a new llama_kv_cache_view_free
/// Allocate a new KV cache view which can be used to inspect the KV cache
/// </summary>
/// <param name="ctx"></param>
/// <param name="maxSequences">The maximum number of sequences visible in this view per cell</param>
Expand Down Expand Up @@ -102,24 +102,6 @@ public void Update()
NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
}

/// <summary>
/// Count the number of used cells in the KV cache
/// </summary>
/// <returns></returns>
public int CountCells()
{
return NativeApi.llama_get_kv_cache_used_cells(_ctx);
}

/// <summary>
/// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
/// </summary>
/// <returns></returns>
public int CountTokens()
{
return NativeApi.llama_get_kv_cache_token_count(_ctx);
}

/// <summary>
/// Get the raw KV cache view
/// </summary>
Expand Down
44 changes: 44 additions & 0 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,35 @@ public void SetThreads(uint threads, uint threadsBatch)
}

#region KV Cache Management
/// <summary>
/// Get a new KV cache view that can be used to debug the KV cache
/// </summary>
/// <param name="maxSequences"></param>
/// <returns></returns>
public LLamaKvCacheViewSafeHandle KvCacheGetDebugView(int maxSequences = 4)
{
return LLamaKvCacheViewSafeHandle.Allocate(this, maxSequences);
}

/// <summary>
/// Count the number of used cells in the KV cache (i.e. have at least one sequence assigned to them)
/// </summary>
/// <returns></returns>
public int KvCacheCountCells()
{
return NativeApi.llama_get_kv_cache_used_cells(this);
}

/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
/// </summary>
/// <returns></returns>
public int KvCacheCountTokens()
{
return NativeApi.llama_get_kv_cache_token_count(this);
}

/// <summary>
/// Clear the KV cache
/// </summary>
Expand Down Expand Up @@ -344,6 +373,21 @@ public void KvCacheSequenceShift(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLama
{
NativeApi.llama_kv_cache_seq_shift(this, seq, p0, p1, delta);
}

/// <summary>
/// Integer division of the positions by factor of `d > 1`
/// If the KV cache is RoPEd, the KV data is updated accordingly
/// p0 &lt; 0 : [0, p1]
/// p1 &lt; 0 : [p0, inf)
/// </summary>
/// <param name="seq"></param>
/// <param name="p0"></param>
/// <param name="p1"></param>
/// <param name="divisor"></param>
public void KvCacheSequenceDivide(LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int divisor)
{
NativeApi.llama_kv_cache_seq_div(this, seq, p0, p1, divisor);
}
#endregion
}
}