Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion LLama/Native/LLamaContextParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;


/// <summary>
/// When to enable Flash Attention
/// </summary>
public LLamaAttentionType flash_attn_type;

/// <summary>
/// RoPE base frequency, 0 = from model
/// </summary>
Expand Down
23 changes: 23 additions & 0 deletions LLama/Native/LLamaFlashAttnType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace LLama.Native;

/// <summary>
///
/// </summary>
/// <remarks>llama_flash_attn_type</remarks>
public enum LLamaFlashAttnType
{
/// <summary>
///
/// </summary>
Auto = -1,

/// <summary>
///
/// </summary>
Disable = 0,

/// <summary>
///
/// </summary>
Enabled = 1,
}
5 changes: 5 additions & 0 deletions LLama/Native/LLamaFtype.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ public enum LLamaFtype
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,

/// <summary>
/// File type was not specified
/// </summary>
Expand Down
12 changes: 11 additions & 1 deletion LLama/Native/LLamaModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,17 @@ public bool check_tensors
readonly get => Convert.ToBoolean(_check_tensors);
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;
private sbyte _check_tensors;

/// <summary>
/// use extra buffer types (used for weight repacking)
/// </summary>
public bool use_extra_bufts
{
readonly get => Convert.ToBoolean(_use_extra_bufts);
set => _use_extra_bufts = Convert.ToSByte(value);
}
private sbyte _use_extra_bufts;

/// <summary>
/// Create a LLamaModelParams with default values
Expand Down
14 changes: 8 additions & 6 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ static SafeLLamaContextHandle()
/// Get the exact size needed to copy the state of a single sequence
/// </summary>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId);
private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId, uint llama_state_seq_flags);

/// <summary>
/// Copy the state of a single sequence into the specified buffer
Expand All @@ -303,9 +303,10 @@ static SafeLLamaContextHandle()
/// <param name="dst"></param>
/// <param name="size"></param>
/// <param name="seqId"></param>
/// <param name="llama_state_seq_flags"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId);
private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId, uint llama_state_seq_flags);

/// <summary>
/// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
Expand All @@ -314,12 +315,13 @@ static SafeLLamaContextHandle()
/// <param name="src"></param>
/// <param name="size"></param>
/// <param name="destSeqId"></param>
/// <param name="llama_state_seq_flags"></param>
/// <returns>
/// - Positive: Ok
/// - Zero: Failed to load
/// </returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId);
private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId, uint llama_state_seq_flags);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx);
Expand Down Expand Up @@ -680,7 +682,7 @@ public nuint GetStateSize()
/// <returns></returns>
public nuint GetStateSize(LLamaSeqId sequence)
{
return llama_state_seq_get_size(this, sequence);
return llama_state_seq_get_size(this, sequence, 0u);
}

/// <summary>
Expand Down Expand Up @@ -712,7 +714,7 @@ public unsafe nuint GetState(byte* dest, nuint size, LLamaSeqId sequence)
if (size < required)
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

return llama_state_seq_get_data(this, dest, size, sequence);
return llama_state_seq_get_data(this, dest, size, sequence, 0u);
}

/// <summary>
Expand All @@ -735,7 +737,7 @@ public unsafe nuint SetState(byte* src, nuint size)
/// <returns>Number of bytes read from the src pointer</returns>
public unsafe nuint SetState(byte* src, nuint size, LLamaSeqId sequence)
{
return llama_state_seq_set_data(this, src, size, sequence);
return llama_state_seq_set_data(this, src, size, sequence, 1u);
}
#endregion

Expand Down