diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 76f5d6c7..4aaee0cd 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -64,7 +64,12 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
///
public LLamaAttentionType attention_type;
-
+
+ /// <summary>
+ /// When to enable Flash Attention
+ /// </summary>
+ public LLamaFlashAttnType flash_attn_type;
+
///
/// RoPE base frequency, 0 = from model
///
diff --git a/LLama/Native/LLamaFlashAttnType.cs b/LLama/Native/LLamaFlashAttnType.cs
new file mode 100644
index 00000000..116fbb29
--- /dev/null
+++ b/LLama/Native/LLamaFlashAttnType.cs
@@ -0,0 +1,23 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Flash Attention mode selector
+/// </summary>
+/// <remarks>llama_flash_attn_type</remarks>
+public enum LLamaFlashAttnType
+{
+ /// <summary>
+ /// Value -1. NOTE(review): presumably lets the native library decide whether to use Flash Attention - confirm against llama.h
+ /// </summary>
+ Auto = -1,
+
+ /// <summary>
+ /// Value 0. NOTE(review): presumably turns Flash Attention off - confirm against llama.h (named "Disable" while its sibling is "Enabled")
+ /// </summary>
+ Disable = 0,
+
+ /// <summary>
+ /// Value 1. NOTE(review): presumably turns Flash Attention on - confirm against llama.h
+ /// </summary>
+ Enabled = 1,
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 705f8032..70db7986 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -202,6 +202,11 @@ public enum LLamaFtype
///
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,
+ /// <summary>
+ /// except 1d tensors
+ /// </summary>
+ LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,
+
///
/// File type was not specified
///
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index acb02485..20a9c99d 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -99,7 +99,17 @@ public bool check_tensors
readonly get => Convert.ToBoolean(_check_tensors);
set => _check_tensors = Convert.ToSByte(value);
}
- private sbyte _check_tensors;
+ private sbyte _check_tensors;
+
+ /// <summary>
+ /// Use extra buffer types (used for weight repacking). Exposed as a bool; stored in the sbyte field _use_extra_bufts.
+ /// </summary>
+ public bool use_extra_bufts
+ {
+ readonly get => Convert.ToBoolean(_use_extra_bufts);
+ set => _use_extra_bufts = Convert.ToSByte(value);
+ }
+ private sbyte _use_extra_bufts;
///
/// Create a LLamaModelParams with default values
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b2..915b08be 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -294,7 +294,7 @@ static SafeLLamaContextHandle()
/// Get the exact size needed to copy the state of a single sequence
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId);
+ private static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seqId, uint llama_state_seq_flags); // NOTE(review): confirm the native export with this name accepts a flags argument (check llama.h)
///
/// Copy the state of a single sequence into the specified buffer
@@ -303,9 +303,10 @@ static SafeLLamaContextHandle()
///
///
///
+ /// <param name="llama_state_seq_flags">Flags forwarded unchanged to the native call. NOTE(review): flag values are defined by llama.cpp - confirm against llama.h</param>
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId);
+ private static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, nuint size, LLamaSeqId seqId, uint llama_state_seq_flags);
///
/// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -314,12 +315,13 @@ static SafeLLamaContextHandle()
///
///
///
+ /// <param name="llama_state_seq_flags">Flags forwarded unchanged to the native call. NOTE(review): presumably must match the flags used when the data was produced - confirm against llama.h</param>
///
/// - Positive: Ok
/// - Zero: Failed to load
///
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId);
+ private static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, nuint size, LLamaSeqId destSeqId, uint llama_state_seq_flags);
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern LLamaPerfContextTimings llama_perf_context(SafeLLamaContextHandle ctx);
@@ -680,7 +682,7 @@ public nuint GetStateSize()
///
public nuint GetStateSize(LLamaSeqId sequence)
{
- return llama_state_seq_get_size(this, sequence);
+ return llama_state_seq_get_size(this, sequence, 0u); // flags argument is always 0 here; the public signature exposes no flags
}
///
@@ -712,7 +714,7 @@ public unsafe nuint GetState(byte* dest, nuint size, LLamaSeqId sequence)
if (size < required)
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
- return llama_state_seq_get_data(this, dest, size, sequence);
+ return llama_state_seq_get_data(this, dest, size, sequence, 0u);
}
///
@@ -735,7 +737,7 @@ public unsafe nuint SetState(byte* src, nuint size)
/// Number of bytes read from the src pointer
public unsafe nuint SetState(byte* src, nuint size, LLamaSeqId sequence)
{
- return llama_state_seq_set_data(this, src, size, sequence);
+ return llama_state_seq_set_data(this, src, size, sequence, 0u); // flags must be 0, the same value GetStateSize/GetState pass, so saved sequence data is read back with matching flags
}
#endregion