diff --git a/LLama.KernelMemory/BuilderExtensions.cs b/LLama.KernelMemory/BuilderExtensions.cs
index 6ab04a8bc..0aae8e69d 100644
--- a/LLama.KernelMemory/BuilderExtensions.cs
+++ b/LLama.KernelMemory/BuilderExtensions.cs
@@ -77,7 +77,6 @@ public static IKernelMemoryBuilder WithLLamaSharpDefaults(this IKernelMemoryBuil
             SplitMode = config.SplitMode,
             BatchSize = 512,
             UBatchSize = 512,
-            FlashAttention = true,
             UseMemorymap = true
         };
diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
index 0635015df..b5c110194 100644
--- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
+++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
@@ -40,7 +40,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
             SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
             BatchSize = 512,
             UBatchSize = 512,
-            FlashAttention = true,
             UseMemorymap = true,
             PoolingType = LLamaPoolingType.Mean,
         };
@@ -68,7 +67,6 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
             SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
             BatchSize = 512,
             UBatchSize = 512,
-            FlashAttention = true,
             UseMemorymap = true,
             PoolingType = LLamaPoolingType.Mean,
         };
diff --git a/LLama.KernelMemory/LlamaSharpTextGenerator.cs b/LLama.KernelMemory/LlamaSharpTextGenerator.cs
index 5c965b266..166d4ad38 100644
--- a/LLama.KernelMemory/LlamaSharpTextGenerator.cs
+++ b/LLama.KernelMemory/LlamaSharpTextGenerator.cs
@@ -38,7 +38,6 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config)
             SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
             BatchSize = 512,
             UBatchSize = 512,
-            FlashAttention = true,
             UseMemorymap = true
         };
         _weights = LLamaWeights.LoadFromFile(@params);
@@ -66,7 +65,6 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St
             SplitMode = config?.SplitMode ?? LLama.Native.GPUSplitMode.Layer,
             BatchSize = 512,
             UBatchSize = 512,
-            FlashAttention = true,
             UseMemorymap = true
         };
         _executor = executor ?? new StatelessExecutor(_weights, @params);
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index e28b55ce0..b04ee5382 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -13,7 +13,7 @@ public LLamaContextTests()
     {
         var @params = new ModelParams(Constants.GenerativeModelPath2)
         {
-            ContextSize = 128,
+            ContextSize = 512,
             BatchSize = 8,
             UBatchSize = 8,
             SeqMax = 1,
@@ -33,7 +33,7 @@ public void Dispose()
     [Fact]
     public void CheckProperties()
     {
-        Assert.Equal(128u, _context.ContextSize);
+        Assert.Equal(512u, _context.ContextSize);
         Assert.Equal(960, _context.EmbeddingSize);
         Assert.Equal(49152, _context.Vocab.Count);
     }
diff --git a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
index 1d16b0481..871b6b8cd 100644
--- a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
+++ b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
@@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests()
     {
         var @params = new ModelParams(Constants.GenerativeModelPath2)
         {
-            ContextSize = 128,
+            ContextSize = 512,
             GpuLayerCount = Constants.CIGpuLayerCount,
         };
@@ -55,7 +55,7 @@ public void Dispose()
     [Fact]
     public void CheckProperties()
     {
-        Assert.Equal(128u, _context.ContextSize);
+        Assert.Equal(512u, _context.ContextSize);
         Assert.Equal(960, _context.EmbeddingSize);
         Assert.Equal(49152, _context.Vocab.Count);
     }
diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs
index 615a7c79e..297641df3 100644
--- a/LLama.Unittest/SamplingTests.cs
+++ b/LLama.Unittest/SamplingTests.cs
@@ -104,7 +104,7 @@ public void BatchedSampling()
             }
         }
-        // Add " repeat" and test whether next tokens will be "this phrase forever.".
+        // Add " repeat" and test whether next tokens will be "this phrase forever."
        for (int i = 0; i < 4; i++)
        {
            for (int b = 0; b < batch_count; b++)
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index f80759c8a..e376258bb 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -103,11 +103,6 @@ public interface IContextParams
    ///
    bool NoKqvOffload { get; }

-   ///
-   /// Whether to use flash attention
-   ///
-   bool FlashAttention { get; }
-
    ///
    /// defragment the KV cache if holes/size > defrag_threshold, Set to <= 0 to disable (default)
    ///
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 89737faa7..532dc1a22 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,3 +1,4 @@
+using System;
 using LLama.Abstractions;
 using System.Text;
 using System.Text.Json.Serialization;
@@ -97,10 +98,7 @@ public record ModelParams
    public bool NoKqvOffload { get; set; }

    ///
-
-   public bool FlashAttention { get; set; }
-
-   ///
+   [Obsolete]
    public float? DefragThreshold { get; set; }

    ///
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 85e40f7ad..882bf7fd3 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -49,7 +49,6 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
            result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
            result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
            result.offload_kqv = !@params.NoKqvOffload;
-           result.flash_attention = @params.FlashAttention;
            result.llama_pooling_type = @params.PoolingType;
            result.attention_type = @params.AttentionType;
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 629d10447..be5d09da3 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
-    11dd5a44eb180e
+    86587da
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 76f5d6c77..6dea4de47 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -64,6 +64,11 @@ public struct LLamaContextParams
    /// Attention type to use for embeddings
    ///
    public LLamaAttentionType attention_type;
+
+   ///
+   /// when to enable Flash Attention
+   ///
+   public LLamaFlashAttentionType llama_flash_attn_type;

    ///
    /// RoPE base frequency, 0 = from model
diff --git a/LLama/Native/LLamaFlashAttentionType.cs b/LLama/Native/LLamaFlashAttentionType.cs
new file mode 100644
index 000000000..7138dea93
--- /dev/null
+++ b/LLama/Native/LLamaFlashAttentionType.cs
@@ -0,0 +1,19 @@
+namespace LLama.Native;
+///
+/// flash_attn_type
+///
+public enum LLamaFlashAttentionType
+{
+    ///
+    /// attention type auto
+    ///
+    LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,
+    ///
+    /// attention disabled
+    ///
+    LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,
+    ///
+    /// attention enabled
+    ///
+    LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 705f8032e..813bad1ae 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -201,7 +201,12 @@ public enum LLamaFtype
    /// except 1d tensors
    ///
    LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,
-
+
+   ///
+   /// except 1d tensors
+   ///
+   LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,
+
    ///
    /// File type was not specified
    ///
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index acb024852..4826c96b7 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -100,7 +100,16 @@ public bool check_tensors
        set => _check_tensors = Convert.ToSByte(value);
    }
    private sbyte _check_tensors;
-
+
+   ///
+   /// use extra buffer types (used for weight repacking)
+   ///
+   public bool use_extra_bufts
+   {
+       readonly get => Convert.ToBoolean(_use_extra_bufts);
+       set => _use_extra_bufts = Convert.ToSByte(value);
+   }
+   private sbyte _use_extra_bufts;
    ///
    /// Create a LLamaModelParams with default values
    ///
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index db9e928bd..0a5ad6003 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -99,7 +99,8 @@ public static void llama_empty_call()
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.U1)]
-   public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
+   public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session,
+       LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);
@@ -111,25 +112,29 @@
    ///
    /// Save session file
    ///
    ///
    ///
    ///
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    [return: MarshalAs(UnmanagedType.U1)]
-   public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
+   public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session,
+       LLamaToken[] tokens, ulong n_token_count);

    ///
    /// Saves the specified sequence as a file on specified filepath. Can later be loaded via
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);
+   public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath,
+       LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);

    ///
    /// Loads a sequence saved as a file via into the specified sequence
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);
+   public static extern unsafe nuint llama_state_seq_load_file(SafeLLamaContextHandle ctx, string filepath,
+       LLamaSeqId dest_seq_id, LLamaToken* tokens_out, nuint n_token_capacity, out nuint n_token_count_out);

    ///
    /// Set whether to use causal attention or not. If set to true, the model will only attend to the past tokens
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool causalAttn);
+   public static extern void llama_set_causal_attn(SafeLLamaContextHandle ctx,
+       [MarshalAs(UnmanagedType.U1)] bool causalAttn);

    ///
    /// Set whether the context outputs embeddings or not
@@ -137,13 +142,15 @@
    ///
    /// If true, embeddings will be returned but logits will not
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx, [MarshalAs(UnmanagedType.U1)] bool embeddings);
+   public static extern void llama_set_embeddings(SafeLLamaContextHandle ctx,
+       [MarshalAs(UnmanagedType.U1)] bool embeddings);

    ///
    /// Set abort callback
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx, IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData);
+   public static extern void llama_set_abort_callback(SafeLlamaModelHandle ctx,
+       IntPtr /* ggml_abort_callback */ abortCallback, IntPtr abortCallbackData);

    ///
    /// Get the n_seq_max for this context
    ///
    ///
    ///
@@ -175,12 +182,15 @@
    /// A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
    /// The size of the allocated buffer
    /// The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
-   public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
+   public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
+       [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length)
    {
        return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);

-       [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
-       static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
+       [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,
+           EntryPoint = "llama_chat_apply_template")]
+       static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg,
+           [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
    }

    ///
@@ -215,7 +225,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
    /// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
    /// If true, special tokens are rendered in the output
    /// The length written, or if the buffer is too small a negative that indicates the length required
-   public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span<byte> buffer, int lstrip, bool special)
+   public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken,
+       Span<byte> buffer, int lstrip, bool special)
    {
        // Handle invalid tokens
        if ((int)llamaToken < 0)
@@ -225,12 +236,14 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
    {
        fixed (byte* bufferPtr = buffer)
        {
-           return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
+           return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip,
+               special);
        }
    }

    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
-   static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
+   static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken,
+       byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
    }

    ///
@@ -247,7 +260,9 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
    /// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special);
+   internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len,
+       LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special,
+       [MarshalAs(UnmanagedType.U1)] bool parse_special);

    ///
    /// Convert the provided tokens into text (inverse of llama_tokenize()).
@@ -261,7 +276,8 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
    /// unparse_special If true, special tokens are rendered in the output.
    /// Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
+   internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens,
+       byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);

    ///
    /// Register a callback to receive llama log messages
@@ -272,7 +288,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
    {
        NativeLogConfig.llama_log_set(logCallback);
    }
-
+
    ///
    /// Allocates a batch of tokens on the heap
    /// Each token can be assigned up to n_seq_max sequence ids
    ///
@@ -311,7 +327,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
    ///
    ///
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
+   public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len,
+       int n_embd, int il_start, int il_end);

    ///
    /// Build a split GGUF final path for this chunk.
@@ -324,7 +341,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
    ///
    ///
    ///
    ///
    ///
    /// Returns the split_path length.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
+   public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no,
+       int split_count);

    ///
    /// Extract the path prefix from the split_path if and only if the split_no and split_count match.
@@ -337,7 +355,8 @@
    ///
    ///
    ///
    ///
    ///
    /// Returns the split_prefix length.
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-   public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+   public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
+       int split_count);

    //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    //todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +399,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
    /// Name of the buffer type
    [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
+
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);
+
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
+       int seq_id, uint flags);
+
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   ///
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
+       uint flags);
    }
-}
+}
\ No newline at end of file
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b26..f48e818b7 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -341,6 +341,47 @@ static SafeLLamaContextHandle()
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern int llama_set_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter, float scale);

+   ///
+   /// Get metadata value as a string by key name
+   ///
+   ///
+   ///
+   ///
+   ///
+   /// The length of the value string (on success) -1 otherwise
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);
+
+   ///
+   /// Get the number of metadata key value pairs
+   ///
+   ///
+   /// The count of meta key value pairs
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   private static extern int llama_adapter_meta_count(IntPtr adapter);
+
+   ///
+   /// Get metadata key name by index
+   ///
+   ///
+   ///
+   ///
+   ///
+   /// The length of string i.e meta key (on success) -1 otherwise
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
+   ///
+   /// Get metadata key value by index
+   ///
+   ///
+   ///
+   ///
+   ///
+   /// The length of value string (on success) -1 otherwise
+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern int llama_rm_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter);
diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs
index bad1a1974..a113e1694 100644
--- a/LLama/Native/SafeLLamaSamplerHandle.cs
+++ b/LLama/Native/SafeLLamaSamplerHandle.cs
@@ -616,7 +616,7 @@ static extern unsafe IntPtr llama_sampler_init_logit_bias(
    // This is a tricky method to work with!
    // It can't return a handle, because that would create a second handle to these resources.
-   // Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
+   // Instead , It returns the raw pointer, and that can be looked up in the _samplers dictionary.
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
    // ReSharper restore InconsistentNaming
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d335a1209..196bb1763 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -80,7 +80,12 @@ public sealed class SafeLlamaModelHandle
    /// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
    ///
    public bool IsRecurrent => llama_model_is_recurrent(this);
-
+
+   ///
+   /// Returns true if the model is diffusion based (like LLaDA , Dream etc )
+   ///
+   public bool IsDiffusion => llama_model_is_diffusion(this);
+
    ///
    /// Get a description of this model
    ///
@@ -424,6 +429,10 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
    [return: MarshalAs(UnmanagedType.U1)]
    private static extern bool llama_model_is_recurrent(SafeLlamaModelHandle model);

+   [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+   [return: MarshalAs(UnmanagedType.U1)]
+   private static extern bool llama_model_is_diffusion(SafeLlamaModelHandle model);
+
    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern unsafe LLamaVocabNative* llama_model_get_vocab(SafeLlamaModelHandle model);
diff --git a/llama.cpp b/llama.cpp
index 11dd5a44e..86587da03 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 11dd5a44eb180e1d69fac24d3852b5222d66fb7f
+Subproject commit 86587da03bd78df8f4e7d8b111a0c1d2494d6ed0
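
Note on the flash-attention change: the managed FlashAttention boolean is removed from IContextParams/ModelParams, and the native struct now carries llama.cpp's mode enum as LLamaContextParams.llama_flash_attn_type (AUTO = -1, DISABLED = 0, ENABLED = 1). A minimal sketch at the native-struct level, assuming a hypothetical model path; setting the field directly is only an illustration, since this diff does not show a managed setting that maps onto it:

// Illustration only: the managed FlashAttention flag is gone; the flash-attention mode now
// lives on the native struct as an enum. The model path below is a placeholder.
using LLama.Common;
using LLama.Extensions;
using LLama.Native;

var @params = new ModelParams("path/to/model.gguf");            // hypothetical path
@params.ToLlamaContextParams(out LLamaContextParams native);    // extension shown in this diff

// AUTO (-1) lets llama.cpp decide, DISABLED (0) switches it off, ENABLED (1) forces it on.
native.llama_flash_attn_type = LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO;
// The struct would then be passed on to context creation as usual (not shown here).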
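
Note on the new llama_state_seq_*_ext bindings added to NativeApi.cs: they take a raw context pointer, a byte buffer, and a flags value. A hedged sketch of a per-sequence save/restore round-trip, assuming flags = 0 is acceptable (the diff does not document the flag values) and that the caller already has the raw llama_context pointer:

using System;
using LLama.Native;

static class SequenceStateExample
{
    // Copy the serialized state of one sequence out of a context. `ctx` stands in for however
    // the raw pointer is obtained in real code; `flags = 0` is an assumption.
    public static byte[] Save(IntPtr ctx, int seqId, uint flags = 0)
    {
        var size = NativeApi.llama_state_seq_get_size_ext(ctx, seqId, flags);
        var buffer = new byte[(ulong)size];
        var written = NativeApi.llama_state_seq_get_data_ext(ctx, buffer, size, seqId, flags);
        Array.Resize(ref buffer, (int)(ulong)written);
        return buffer;
    }

    // Load previously captured bytes into a (possibly different) sequence id.
    public static void Restore(IntPtr ctx, int destSeqId, byte[] state, uint flags = 0)
    {
        NativeApi.llama_state_seq_set_data_ext(ctx, state, (UIntPtr)(uint)state.Length, destSeqId, flags);
    }
}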