diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 76f5d6c7..4aaee0cd 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -64,7 +64,12 @@ public struct LLamaContextParams /// Attention type to use for embeddings /// public LLamaAttentionType attention_type; - + + /// + /// When to enable Flash Attention + /// + public LLamaAttentionType flash_attn_type; + /// /// RoPE base frequency, 0 = from model /// diff --git a/LLama/Native/LLamaFlashAttnType.cs b/LLama/Native/LLamaFlashAttnType.cs new file mode 100644 index 00000000..116fbb29 --- /dev/null +++ b/LLama/Native/LLamaFlashAttnType.cs @@ -0,0 +1,23 @@ +namespace LLama.Native; + +/// +/// +/// +/// llama_flash_attn_type +public enum LLamaFlashAttnType +{ + /// + /// + /// + Auto = -1, + + /// + /// + /// + Disable = 0, + + /// + /// + /// + Enabled = 1, +} \ No newline at end of file diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs index 705f8032..70db7986 100644 --- a/LLama/Native/LLamaFtype.cs +++ b/LLama/Native/LLamaFtype.cs @@ -202,6 +202,11 @@ public enum LLamaFtype /// LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, + /// + /// except 1d tensors + /// + LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, + /// /// File type was not specified /// diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs index acb02485..20a9c99d 100644 --- a/LLama/Native/LLamaModelParams.cs +++ b/LLama/Native/LLamaModelParams.cs @@ -99,7 +99,17 @@ public bool check_tensors readonly get => Convert.ToBoolean(_check_tensors); set => _check_tensors = Convert.ToSByte(value); } - private sbyte _check_tensors; + private sbyte _check_tensors; + + /// + /// use extra buffer types (used for weight repacking) + /// + public bool use_extra_bufts + { + readonly get => Convert.ToBoolean(_use_extra_bufts); + set => _use_extra_bufts = Convert.ToSByte(value); + } + private sbyte _use_extra_bufts; /// /// Create a LLamaModelParams with default values diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index e26619b2..915b08be 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -294,7 +294,7 @@ static SafeLLamaContextHandle() /// Get the exact size needed to copy the state of a single sequence /// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId); + private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId, uint llama_state_seq_flags); /// /// Copy the state of a single sequence into the specified buffer @@ -303,9 +303,10 @@ static SafeLLamaContextHandle() /// /// /// + /// /// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId); + private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId, uint llama_state_seq_flags); /// /// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence @@ -314,12 +315,13 @@ static SafeLLamaContextHandle() /// /// /// + /// /// /// - Positive: Ok /// - Zero: Failed to load /// [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId); + private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId, uint llama_state_seq_flags); [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx); @@ -680,7 +682,7 @@ public nuint GetStateSize() /// public nuint GetStateSize(LLamaSeqId sequence) { - return llama_state_seq_get_size(this, sequence); + return llama_state_seq_get_size(this, sequence, 0u); } /// @@ -712,7 +714,7 @@ public unsafe nuint GetState(byte* dest, nuint size, LLamaSeqId sequence) if (size < required) throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); - return llama_state_seq_get_data(this, dest, size, sequence); + return llama_state_seq_get_data(this, dest, size, sequence, 0u); } /// @@ -735,7 +737,7 @@ public unsafe nuint SetState(byte* src, nuint size) /// Number of bytes read from the src pointer public unsafe nuint SetState(byte* src, nuint size, LLamaSeqId sequence) { - return llama_state_seq_set_data(this, src, size, sequence); + return llama_state_seq_set_data(this, src, size, sequence, 1u); } #endregion