Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 167 additions & 81 deletions LLama/Native/LLamaKvCacheView.cs
Original file line number Diff line number Diff line change
@@ -1,72 +1,58 @@
using System;
using System;
using System.Runtime.InteropServices;

namespace LLama.Native;

/// <summary>
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
/// A safe handle for a LLamaKvCacheView
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaKvCacheViewCell
public sealed class LLamaKvCacheViewSafeHandle
: SafeLLamaHandleBase
{
private readonly SafeLLamaContextHandle _ctx;
private NativeLLamaKvCacheView _view;

/// <summary>
/// The position for this cell. Takes KV cache shifts into account.
/// May be negative if the cell is not populated.
/// Number of KV cache cells. This will be the same as the context size.
/// </summary>
public LLamaPos pos;
}

/// <summary>
/// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaKvCacheView
{
// Number of KV cache cells. This will be the same as the context size.
int n_cells;

// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences.
int n_seq_max;

// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids then you'll have 3 tokens.
int token_count;

// Number of populated cache cells.
int used_cells;
public int CellCount => GetNativeView().n_cells;

// Maximum contiguous empty slots in the cache.
int max_contiguous;

// Index to the start of the max_contiguous slot range. Can be negative
// when cache is full.
int max_contiguous_idx;

// Information for an individual cell.
LLamaKvCacheViewCell* cells;
/// <summary>
/// Get the total number of tokens in the KV cache.
///
/// For example, if there are two populated
/// cells, the first with 1 sequence id in it and the second with 2 sequence
/// ids then you'll have 3 tokens.
/// </summary>
public int TokenCount => GetNativeView().token_count;

/// <summary>
/// Maximum number of sequences visible for a cell. There may be more sequences than this
/// in reality, this is simply the maximum number this view can see.
/// </summary>
public int MaxSequenceCount => GetNativeView().n_seq_max;

/// <summary>
/// Number of populated cache cells
/// </summary>
public int UsedCellCount => GetNativeView().used_cells;

// The sequences for each cell. There will be n_seq_max items per cell.
LLamaSeqId* cells_sequences;
}
/// <summary>
/// Maximum contiguous empty slots in the cache.
/// </summary>
public int MaxContiguous => GetNativeView().max_contiguous;

/// <summary>
/// A safe handle for a LLamaKvCacheView
/// </summary>
public class LLamaKvCacheViewSafeHandle
: SafeLLamaHandleBase
{
private readonly SafeLLamaContextHandle _ctx;
private LLamaKvCacheView _view;
/// <summary>
/// Index to the start of the MaxContiguous slot range. Can be negative when cache is full.
/// </summary>
public int MaxContiguousIdx => GetNativeView().max_contiguous;

/// <summary>
/// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
/// </summary>
/// <param name="ctx"></param>
/// <param name="view"></param>
public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
private LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, NativeLLamaKvCacheView view)
: base((IntPtr)1, true)
{
_ctx = ctx;
Expand All @@ -81,76 +67,176 @@ public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView v
/// <returns></returns>
public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
{
var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
return new LLamaKvCacheViewSafeHandle(ctx, result);
// Allocate the view
var view = llama_kv_cache_view_init(ctx, maxSequences);
var handle = new LLamaKvCacheViewSafeHandle(ctx, view);

// Update the view so it has valid data after allocation.
handle.Update();

return handle;
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
NativeApi.llama_kv_cache_view_free(ref _view);
llama_kv_cache_view_free(ref _view);
SetHandle(IntPtr.Zero);

return true;
}

/// <summary>
/// Update this view
/// Read the current KV cache state into this view.
/// </summary>
public void Update()
{
NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
llama_kv_cache_view_update(_ctx, ref _view);
}

/// <summary>
/// Get the raw KV cache view
/// </summary>
/// <returns></returns>
public ref LLamaKvCacheView GetView()
private ref NativeLLamaKvCacheView GetNativeView()
{
if (IsClosed)
throw new ObjectDisposedException("Cannot access LLamaKvCacheViewSafeHandle after is has been disposed");

return ref _view;
}
}

public static partial class NativeApi
{
/// <summary>
/// Get the cell at the given index
/// </summary>
/// <param name="index">The index of the cell [0, CellCount)</param>
/// <returns>Data about the cell at the given index</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
public LLamaPos GetCell(int index)
{
var view = GetNativeView();

if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
if (index >= view.n_cells)
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");

unsafe
{
return view.cells[index].pos;
}
}

/// <summary>
/// Get all of the sequences assigned to the cell at the given index. This will contain <see cref="MaxSequenceCount"/> entries
/// sequences even if the cell actually has more than that many sequences, allocate a new view with a larger maxSequences parameter
/// if necessary. Invalid sequences will be negative values.
/// </summary>
/// <param name="index">The index of the cell [0, CellCount)</param>
/// <returns>A span containing the sequences assigned to this cell</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if index is out of range (0 &lt;= index &lt; CellCount)</exception>
public Span<LLamaSeqId> GetCellSequences(int index)
{
var view = GetNativeView();

if (index < 0)
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be >= 0");
if (index >= view.n_cells)
throw new ArgumentOutOfRangeException(nameof(index), "Cell index must be < CellCount");

unsafe
{
return new Span<LLamaSeqId>(&view.cells_sequences[index * view.n_seq_max], view.n_seq_max);
}
}

#region native API
/// <summary>
/// Create an empty KV cache view. (use only for debugging purposes)
/// </summary>
/// <param name="ctx"></param>
/// <param name="n_seq_max"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max);
/// <summary>
/// Free a KV cache view. (use only for debugging purposes)
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view);
/// <summary>
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
/// </summary>
/// <param name="ctx"></param>
/// <param name="view"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);

[StructLayout(LayoutKind.Sequential)]
private struct NativeLLamaKvCacheViewCell
{
/// <summary>
/// The position for this cell. Takes KV cache shifts into account.
/// May be negative if the cell is not populated.
/// </summary>
public LLamaPos pos;
}

/// <summary>
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
/// An updateable view of the KV cache (llama_kv_cache_view)
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);
[StructLayout(LayoutKind.Sequential)]
private unsafe struct NativeLLamaKvCacheView
{
/// <summary>
/// Number of KV cache cells. This will be the same as the context size.
/// </summary>
public int n_cells;

/// <summary>
/// Maximum number of sequences that can exist in a cell. It's not an error
/// if there are more sequences in a cell than this value, however they will
/// not be visible in the view cells_sequences.
/// </summary>
public int n_seq_max;

/// <summary>
/// Number of tokens in the cache. For example, if there are two populated
/// cells, the first with 1 sequence id in it and the second with 2 sequence
/// ids then you'll have 3 tokens.
/// </summary>
public int token_count;

/// <summary>
/// Number of populated cache cells.
/// </summary>
public int used_cells;

/// <summary>
/// Maximum contiguous empty slots in the cache.
/// </summary>
public int max_contiguous;

/// <summary>
/// Index to the start of the max_contiguous slot range. Can be negative
/// when cache is full.
/// </summary>
public int max_contiguous_idx;

/// <summary>
/// Information for an individual cell.
/// </summary>
public NativeLLamaKvCacheViewCell* cells;

/// <summary>
/// The sequences for each cell. There will be n_seq_max items per cell.
/// </summary>
public LLamaSeqId* cells_sequences;
}
#endregion
}
19 changes: 18 additions & 1 deletion LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Runtime.InteropServices;

#pragma warning disable IDE1006 // Naming Styles
Expand Down Expand Up @@ -258,6 +258,23 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
{
NativeLogConfig.llama_log_set(logCallback);
}

/// <summary>
/// Returns the number of tokens in the KV cache (slow, use only for debug)
/// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx);

/// <summary>
/// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_get_kv_cache_used_cells(SafeLLamaContextHandle ctx);

/// <summary>
/// Clear the KV cache
Expand Down