diff --git a/LLama.Examples/ExampleRunner.cs b/LLama.Examples/ExampleRunner.cs
index c073cd4cd..23f07c6a1 100644
--- a/LLama.Examples/ExampleRunner.cs
+++ b/LLama.Examples/ExampleRunner.cs
@@ -15,7 +15,7 @@ public class ExampleRunner
{ "Chat Session: Automatic conversation", TalkToYourself.Run },
{ "Chat Session: Chinese characters", ChatChineseGB2312.Run },
{ "Executor: Interactive mode chat", InteractiveModeExecute.Run },
- { "Executor: Llava Interactive mode chat", LlavaInteractiveModeExecute.Run },
+ { "Executor: Mtmd Interactive mode chat", MtmdInteractiveModeExecute.Run },
{ "Executor: Instruct mode chat", InstructModeExecute.Run },
{ "Executor: Stateless mode chat", StatelessModeExecute.Run },
{ "Save and Load: chat session", SaveAndLoadSession.Run },
@@ -33,7 +33,7 @@ public class ExampleRunner
{ "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
- { "Batched Executor: LLava", BatchedExecutorLLava.Run },
+ { "Batched Executor: Mtmd", BatchedExecutorMtmd.Run },
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Custom Sampling Pipeline", CustomSampler.Run },
diff --git a/LLama.Examples/Examples/BatchedExecutorLLava.cs b/LLama.Examples/Examples/BatchedExecutorLLava.cs
deleted file mode 100644
index a131e994e..000000000
--- a/LLama.Examples/Examples/BatchedExecutorLLava.cs
+++ /dev/null
@@ -1,91 +0,0 @@
-using System.Text;
-using LLama.Batched;
-using LLama.Common;
-using LLama.Native;
-using LLama.Sampling;
-using Spectre.Console;
-
-namespace LLama.Examples.Examples;
-
-/// <summary>
-/// Demonstrates using LLava (image embeddings) with the batched executor.
-/// </summary>
-public class BatchedExecutorLLava
-{
- /// <summary>
- /// How many tokens of response to generate
- /// </summary>
- public const int TokenCount = 64;
-
- public static async Task Run()
- {
- // Load model weights
- var parameters = new ModelParams(UserSettings.GetModelPath());
- using var model = await LLamaWeights.LoadFromFileAsync(parameters);
- using var llava = await LLavaWeights.LoadFromFileAsync(UserSettings.GetMMProjPath());
-
- // Decide on the prompt
- var prompt = model.Tokenize(AnsiConsole.Ask("Prompt (or ENTER for default):", "\nUSER: Provide a full description of the image.\nASSISTANT: "), true, false, Encoding.UTF8);
-
- // Get image and show it
- var image = UserSettings.GetImagePath();
- AnsiConsole.Write(new CanvasImage(image));
-
- // Create an executor with one conversation
- using var executor = new BatchedExecutor(model, parameters);
- using var conversation = executor.Create();
-
- // Embed the image
- SafeLlavaImageEmbedHandle embedding = null!;
- await AnsiConsole
- .Status()
- .StartAsync("[yellow]Embedding image with CLIP[/]", async _ =>
- {
- // ReSharper disable once AccessToDisposedClosure
- embedding = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync(image));
- });
-
- // Pass in the image and run inference until the entire image has been processed
- await AnsiConsole
- .Status()
- .StartAsync("[yellow]Processing image embedding with language model[/]", async _ =>
- {
- conversation.Prompt(embedding);
- while (executor.BatchedTokenCount > 0)
- await executor.Infer();
- });
-
- // Prompt with the text prompt
- conversation.Prompt(prompt);
-
- // Run inference loop
- var decoder = new StreamingTokenDecoder(executor.Context);
- var sampler = new DefaultSamplingPipeline();
- await AnsiConsole
- .Progress()
- .StartAsync(async ctx =>
- {
- var task = ctx.AddTask("Generating Response");
- task.MaxValue = TokenCount;
-
- // Run a normal inference loop
- for (var i = 0; i < TokenCount; i++)
- {
- task.Increment(1);
-
- await executor.Infer();
-
- var token = sampler.Sample(executor.Context.NativeHandle, conversation.GetSampleIndex());
- if (token.IsEndOfGeneration(executor.Context.Vocab))
- break;
-
- decoder.Add(token);
- conversation.Prompt(token);
- }
- });
-
- // Print final result
- var str = decoder.Read();
- AnsiConsole.MarkupInterpolated($"[green]{str}[/]");
- }
-}
\ No newline at end of file
diff --git a/LLama.Examples/Examples/BatchedExecutorMtmd.cs b/LLama.Examples/Examples/BatchedExecutorMtmd.cs
new file mode 100644
index 000000000..b62f8b120
--- /dev/null
+++ b/LLama.Examples/Examples/BatchedExecutorMtmd.cs
@@ -0,0 +1,126 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using LLama.Batched;
+using LLama.Common;
+using LLama.Exceptions;
+using LLama.Native;
+using LLama.Sampling;
+using Spectre.Console;
+
+namespace LLama.Examples.Examples;
+
+/// <summary>
+/// Demonstrates how to evaluate an image with MTMD helpers and continue generation by
+/// manually scheduling batches, similar to what the batched executor does internally.
+/// </summary>
+public class BatchedExecutorMtmd
+{
+ /// <summary>
+ /// Number of completion tokens to generate after sending the image prompt.
+ /// </summary>
+ public const int TokenCount = 10000;
+
+ public static async Task Run()
+ {
+ // Load the base LLM and its clip/mtmd sidecar weights so the executor has everything it needs.
+ var parameters = new ModelParams(UserSettings.GetModelPath());
+ using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+ var mtmdParams = MtmdContextParams.Default(); // reuse llama.cpp defaults for helper settings
+ mtmdParams.UseGpu = false;
+ var marker = mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+
+ using var mtmd = await SafeMtmdWeights.LoadFromFileAsync(UserSettings.GetMMProjPath(), model, mtmdParams); // multimodal helper weights
+
+ using var executor = new BatchedExecutor(model, parameters, mtmd); // drives batched token + chunk evaluation
+
+ // Prepend the media marker so the helper knows where to inject the encoded image tokens.
+ var defaultPrompt = "\nUSER: Provide a full description of the image.\nASSISTANT: ";
+ var promptSuffix = AnsiConsole.Ask("Prompt (or ENTER for default):", defaultPrompt);
+ var promptText = string.Concat(marker, promptSuffix);
+
+ var imagePath = UserSettings.GetImagePath();
+ AnsiConsole.Write(new CanvasImage(imagePath));
+
+ var vocab = executor.Context.NativeHandle.ModelHandle.Vocab;
+
+ // Simple low-temperature sampler keeps the demo deterministic-ish.
+ var sampler = new DefaultSamplingPipeline
+ {
+ Temperature = 0.1f
+ };
+
+ // Stream decoded text to the console as soon as tokens arrive.
+ var decoder = new StreamingTokenDecoder(executor.Context)
+ {
+ DecodeSpecialTokens = false
+ };
+
+ try
+ {
+ // Each conversation tracks its own KV cache sequence IDs.
+ var conversation = executor.Create();
+ // enqueue the image so MtmdHelper sees it
+ conversation.QueueMedia(imagePath);
+ // schedule multimodal prompt
+ conversation.Prompt(promptText, addBos: true, special: true);
+
+ Console.ForegroundColor = ConsoleColor.Yellow;
+ Console.WriteLine("Prompt queued with multimodal chunks. Generating response...\n");
+ Console.ResetColor();
+
+ var remaining = TokenCount;
+
+ // Run one decode/sampling/prompt cycle – mirrors the batched executor inner loop.
+ async Task<bool> ProcessNextAsync()
+ {
+ var decodeResult = await executor.Infer();
+ if (decodeResult == DecodeResult.NoKvSlot) // KV cache exhausted – surface to the user
+ {
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine("Insufficient KV cache space for multimodal evaluation.");
+ Console.ResetColor();
+ return false;
+ }
+
+ if (decodeResult != DecodeResult.Ok)
+ throw new RuntimeError($"Failed to evaluate batch: {decodeResult}.");
+
+ if (!conversation.RequiresSampling) // another conversation may still be queued
+ return true;
+
+ var token = conversation.Sample(sampler); // pull logits (or -1 for mtmd chunk) and sample
+ if (token.IsEndOfGeneration(vocab))
+ return false;
+
+ decoder.Add(token);
+ var delta = decoder.Read();
+ if (!string.IsNullOrEmpty(delta))
+ Console.Write(delta);
+
+ sampler.Accept(token); // keep sampler state in sync
+ conversation.Prompt(token); // feed the accepted token back into the batch
+ remaining--;
+ return remaining > 0;
+ }
+
+ while (remaining > 0 && await ProcessNextAsync()) // continue until EOS or budget is reached
+ {
+ }
+
+ Console.WriteLine();
+ }
+ catch (IOException ex)
+ {
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine($"Could not load media '{imagePath}': {ex.Message}");
+ Console.ResetColor();
+ }
+ catch (RuntimeError ex)
+ {
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine($"MTMD processing failed: {ex.Message}");
+ Console.ResetColor();
+ }
+ }
+}
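Stripped of the Spectre.Console UI, the example above reduces to a short call sequence. A minimal sketch using the APIs introduced in this PR (file paths are placeholders, and the 64-token budget is arbitrary):

```csharp
using System;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Assumed paths; substitute real model/mmproj/image files.
var parameters = new ModelParams("model.gguf");
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var mtmd = await SafeMtmdWeights.LoadFromFileAsync("mmproj.gguf", model, MtmdContextParams.Default());

using var executor = new BatchedExecutor(model, parameters, mtmd);
using var conversation = executor.Create();
conversation.QueueMedia("image.jpg");                          // stage the image
conversation.Prompt("Describe the image.", addBos: true, special: true);

var sampler = new DefaultSamplingPipeline();
var decoder = new StreamingTokenDecoder(executor.Context);
for (var i = 0; i < 64; i++)
{
    if (await executor.Infer() != DecodeResult.Ok)
        break;                                                 // KV exhaustion or decode failure
    if (!conversation.RequiresSampling)
        continue;                                              // chunk batches may still be queued
    var token = conversation.Sample(sampler);
    if (token.IsEndOfGeneration(executor.Context.Vocab))
        break;
    decoder.Add(token);
    sampler.Accept(token);
    conversation.Prompt(token);                                // feed the token back for the next step
}
Console.WriteLine(decoder.Read());
```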
diff --git a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
similarity index 59%
rename from LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
rename to LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
index 8cbf58dcd..ca0de3b77 100644
--- a/LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
+++ b/LLama.Examples/Examples/MtmdInteractiveModeExecute.cs
@@ -1,3 +1,5 @@
+using System.Collections.Generic;
+using System.IO;
using System.Text.RegularExpressions;
using LLama.Common;
using Spectre.Console;
@@ -6,27 +8,32 @@
namespace LLama.Examples.Examples
{
- // This example shows how to chat with LLaVA model with both image and text as input.
+ // This example shows how to chat with an MTMD model using both image and text as input.
// It uses the interactive executor to run inference.
- public class LlavaInteractiveModeExecute
+ public class MtmdInteractiveModeExecute
{
public static async Task Run()
{
string multiModalProj = UserSettings.GetMMProjPath();
string modelPath = UserSettings.GetModelPath();
string modelImage = UserSettings.GetImagePath();
- const int maxTokens = 1024;
+ const int maxTokens = 2048;
var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
var parameters = new ModelParams(modelPath);
+ var mtmdParameters = MtmdContextParams.Default();
+ mtmdParameters.UseGpu = false;
+
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
-
- // Llava Init
- using var clipModel = await LLavaWeights.LoadFromFileAsync(multiModalProj);
-
+
+ // Mtmd Init
+ using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(multiModalProj, model, mtmdParameters);
+
+ var mediaMarker = mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+
var ex = new InteractiveExecutor(context, clipModel);
Console.ForegroundColor = ConsoleColor.Yellow;
@@ -40,7 +47,7 @@ public static async Task Run()
Temperature = 0.1f
},
- AntiPrompts = new List<string> { "\nUSER:" },
+ AntiPrompts = new List<string> { "\nASSISTANT:" },
MaxTokens = maxTokens
};
@@ -48,30 +55,53 @@ public static async Task Run()
do
{
- // Evaluate if we have images
+ // Evaluate if we have media
//
- var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
- var imageCount = imageMatches.Count();
- var hasImages = imageCount > 0;
+ var mediaMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
+ var mediaCount = mediaMatches.Count();
+ var hasMedia = mediaCount > 0;
- if (hasImages)
+ if (hasMedia)
{
- var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
- var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
+ var mediaPathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
+ var mediaPaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
- List<byte[]> imageBytes;
+ var embeds = new List<SafeMtmdEmbed>();
+ var imageList = new List<byte[]>();
+ var imageExtensions = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
+ {
+ ".png",
+ ".jpg",
+ ".jpeg",
+ ".bmp",
+ ".gif",
+ ".webp"
+ };
+
try
{
- imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
+ foreach (var mediaPath in mediaPaths)
+ {
+ var extension = Path.GetExtension(mediaPath);
+ if (!string.IsNullOrEmpty(extension) && imageExtensions.Contains(extension))
+ {
+ // Keep the raw image data so the caller can reuse or inspect the images later.
+ imageList.Add(File.ReadAllBytes(mediaPath));
+ }
+
+ var embed = clipModel.LoadMedia(mediaPath);
+ embeds.Add(embed);
+ }
}
catch (IOException exception)
{
Console.ForegroundColor = ConsoleColor.Red;
Console.Write(
- $"Could not load your {(imageCount == 1 ? "image" : "images")}:");
+ $"Could not load your {(mediaCount == 1 ? "media" : "medias")}:");
Console.Write($"{exception.Message}");
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("Please try again.");
+ clipModel.ClearMedia();
break;
}
@@ -81,19 +111,17 @@ public static async Task Run()
// https://github.com/ggerganov/llama.cpp/discussions/3620
ex.Context.NativeHandle.MemorySequenceRemove( LLamaSeqId.Zero, -1, -1 );
- int index = 0;
- foreach (var path in imagePathsWithCurlyBraces)
+ // Replace placeholders with media markers (one marker per image)
+ foreach (var path in mediaPathsWithCurlyBraces)
{
- // First image replace to tag <image>, the rest are replaced with an empty string
- prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
+ prompt = prompt.Replace(path, mediaMarker, StringComparison.Ordinal);
}
-
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message.");
Console.WriteLine();
- foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)) ?? Array.Empty<CanvasImage>())
+ foreach (var consoleImage in imageList.Select(image => new CanvasImage(image.ToArray())))
{
consoleImage.MaxWidth = 50;
AnsiConsole.Write(consoleImage);
@@ -108,10 +136,9 @@ public static async Task Run()
// Initialize Images in executor
//
- foreach (var image in imagePaths)
- {
- ex.Images.Add(await File.ReadAllBytesAsync(image));
- }
+ ex.Embeds.Clear();
+ foreach (var embed in embeds)
+ ex.Embeds.Add(embed);
}
Console.ForegroundColor = Color.White;
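The placeholder handling above can be read as one small transformation: each `{path}` in the user's prompt becomes a loaded media embed plus one media marker. A hypothetical helper capturing the same logic (`ExpandMediaPlaceholders` is illustrative, not part of this PR):

```csharp
using System.Collections.Generic;
using System.Text.RegularExpressions;
using LLama.Native;

static (string Prompt, List<SafeMtmdEmbed> Embeds) ExpandMediaPlaceholders(
    string prompt, SafeMtmdWeights clipModel, string mediaMarker)
{
    var embeds = new List<SafeMtmdEmbed>();
    foreach (Match match in Regex.Matches(prompt, "{([^}]*)}"))
    {
        // LoadMedia may throw IOException; the example above clears media and retries.
        embeds.Add(clipModel.LoadMedia(match.Groups[1].Value));
        prompt = prompt.Replace(match.Value, mediaMarker);
    }
    return (prompt, embeds);
}
```

The resulting embeds go into `ex.Embeds` and the rewritten prompt is what `InferAsync` receives.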
diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index 330e77386..6ca8c7210 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -9,7 +9,7 @@
true
true
- 12
+ 13
1701;1702;8604;SKEXP0001;SKEXP0050;SKEXP0052;SKEXP0003
diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
index d501b189b..f705f1609 100644
--- a/LLama.Unittest/Constants.cs
+++ b/LLama.Unittest/Constants.cs
@@ -9,9 +9,9 @@ internal static class Constants
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";
- public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
- public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
- public static readonly string LLavaImage = "Models/extreme-ironing-taxi-610x427.jpg";
+ public static readonly string MtmdModelPath = "Models/gemma-3-4b-it-Q4_K_M.gguf";
+ public static readonly string MtmdMmpPath = "Models/gemma-mmproj-model-f16.gguf";
+ public static readonly string MtmdImage = "Models/extreme-ironing-taxi-610x427.jpg";
/// <summary>
/// Calculate GpuLayer Count to use in UnitTest
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 8f9f075d8..ca3ea8854 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -52,16 +52,16 @@
jina-reranker-v1-tiny-en-FP16.gguf
-
- https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf
+
+ https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_M.gguf
Models
- llava-v1.6-mistral-7b.Q3_K_XS.gguf
+ gemma-3-4b-it-Q4_K_M.gguf
-
- https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf
+
+ https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF/resolve/main/mmproj-model-f16.gguf
Models
- mmproj-model-f16.gguf
+ gemma-mmproj-model-f16.gguf
@@ -142,10 +142,10 @@
PreserveNewest
-
+
PreserveNewest
-
+
PreserveNewest
diff --git a/LLama.Unittest/MtmdExecutorTests.cs b/LLama.Unittest/MtmdExecutorTests.cs
new file mode 100644
index 000000000..75a96b261
--- /dev/null
+++ b/LLama.Unittest/MtmdExecutorTests.cs
@@ -0,0 +1,81 @@
+using System;
+using System.Collections.Generic;
+using System.Threading.Tasks;
+using LLama.Common;
+using LLama.Native;
+using Microsoft.Extensions.Logging.Abstractions;
+using Xunit;
+
+namespace LLama.Unittest;
+
+[Trait("Category", "NoCI")]
+public class MtmdExecutorTests : IDisposable
+{
+ private readonly LLamaWeights _weights;
+ private readonly MtmdContextParams _mtmdParams;
+ private readonly SafeMtmdWeights _mtmd;
+ private readonly ModelParams _modelParams;
+
+ public MtmdExecutorTests()
+ {
+ _modelParams = new ModelParams(Constants.MtmdModelPath)
+ {
+ ContextSize = 1024 * 8,
+ GpuLayerCount = Constants.CIGpuLayerCount,
+ };
+
+ _weights = LLamaWeights.LoadFromFile(_modelParams);
+
+ _mtmdParams = MtmdContextParams.Default();
+ _mtmdParams.NThreads = Math.Max(1, Constants.CIGpuLayerCount);
+ _mtmdParams.UseGpu = false;
+
+ _mtmd = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _weights, _mtmdParams);
+ }
+
+ public void Dispose()
+ {
+ _mtmd.Dispose();
+ _weights.Dispose();
+ }
+
+ [Fact]
+ public async Task InteractiveExecutor_EvaluateChunks_DoesNotRetokenize()
+ {
+ using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+ var executor = new InteractiveExecutor(context, _mtmd, NullLogger.Instance);
+ var marker = _mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? "";
+ var prompt = $"{marker}\nDescribe the image succinctly.";
+
+ executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+ await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+ {
+ Assert.True(false, "Prefill should not emit generated text");
+ }
+
+ var diagnostics = executor.GetDiagnostics();
+ Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+ Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+ Assert.Equal(0, diagnostics.PendingEmbedCount);
+ }
+
+ [Fact]
+ public async Task InstructExecutor_MtmdPromptAdvancesPastTokensOnce()
+ {
+ using var context = _weights.CreateContext(_modelParams, NullLogger.Instance);
+ var executor = new InstructExecutor(context, _mtmd, logger: NullLogger.Instance);
+ executor.Embeds.Add(_mtmd.LoadMedia(Constants.MtmdImage));
+
+ var prompt = $"{_mtmdParams.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""} Provide details.";
+
+ await foreach (var _ in executor.InferAsync(prompt, new InferenceParams { MaxTokens = 0 }))
+ {
+ }
+
+ var diagnostics = executor.GetDiagnostics();
+ Assert.Equal(diagnostics.EmbedCount, diagnostics.ConsumedCount);
+ Assert.Equal(diagnostics.ConsumedCount, diagnostics.PastCount);
+ Assert.Equal(0, diagnostics.PendingEmbedCount);
+ }
+}
diff --git a/LLama.Unittest/MtmdWeightsTests.cs b/LLama.Unittest/MtmdWeightsTests.cs
new file mode 100644
index 000000000..947bbd1ea
--- /dev/null
+++ b/LLama.Unittest/MtmdWeightsTests.cs
@@ -0,0 +1,140 @@
+using System;
+using System.IO;
+using LLama.Common;
+using LLama.Native;
+using Xunit;
+
+namespace LLama.Unittest
+{
+ // Test the same things as the llama model + image embeddings
+ //
+ public sealed class MtmdWeightTests
+ : IDisposable
+ {
+ private readonly LLamaWeights _llamaWeights;
+ private readonly SafeMtmdWeights _safeMtmdWeights;
+ private readonly LLamaContext _context;
+ private readonly MtmdContextParams _mtmdParams;
+ private readonly string _mediaMarker;
+
+ public MtmdWeightTests()
+ {
+ var @params = new ModelParams(Constants.MtmdModelPath)
+ {
+ // MTMD models require a large context
+ ContextSize = 1024 * 32,
+ GpuLayerCount = Constants.CIGpuLayerCount,
+ };
+ _llamaWeights = LLamaWeights.LoadFromFile(@params);
+
+ _mtmdParams = MtmdContextParams.Default();
+ _mtmdParams.NThreads = Constants.CIGpuLayerCount;
+ _mtmdParams.UseGpu = false; // keep tests portable across environments without GPU
+
+ _mediaMarker = _mtmdParams.MediaMarker ?? throw new InvalidOperationException("MTMD media marker unavailable.");
+
+ _safeMtmdWeights = SafeMtmdWeights.LoadFromFile(Constants.MtmdMmpPath, _llamaWeights, _mtmdParams);
+ _context = _llamaWeights.CreateContext(@params);
+ }
+
+ public void Dispose()
+ {
+ _context.Dispose();
+ _safeMtmdWeights.Dispose();
+ _llamaWeights.Dispose();
+ }
+
+ private SafeMtmdInputChunks TokenizeWithEmbed(Func<SafeMtmdEmbed> loadEmbed)
+ {
+ _safeMtmdWeights.ClearMedia();
+
+ var embed = loadEmbed();
+ Assert.NotNull(embed);
+
+ using (embed)
+ {
+ Assert.True(embed.Nx > 0);
+ Assert.True(embed.Ny > 0);
+ Assert.False(embed.IsAudio);
+ Assert.True(embed.GetDataSpan().Length > 0);
+
+ var status = _safeMtmdWeights.Tokenize(_mediaMarker, addSpecial: true, parseSpecial: true, out var chunks);
+ Assert.Equal(0, status);
+ Assert.NotNull(chunks);
+
+ return chunks!;
+ }
+ }
+
+ private void AssertChunksEvaluate(SafeMtmdInputChunks chunks)
+ {
+ long nPast = 0;
+ var eval = _safeMtmdWeights.EvaluateChunks(chunks, _context.NativeHandle, ref nPast, seqId: 0, nBatch: checked((int)_context.BatchSize), logitsLast: true);
+ Assert.Equal(0, eval);
+ Assert.True(nPast > 0);
+ }
+
+ [Fact,Trait("Category", "NoCI")]
+ public void EmbedImageAsFileName()
+ {
+ using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage));
+ AssertChunksEvaluate(chunks);
+ }
+
+ [Fact,Trait("Category", "NoCI")]
+ public void EmbedImageAsBinary()
+ {
+ var imageBytes = File.ReadAllBytes(Constants.MtmdImage);
+ using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(imageBytes));
+ AssertChunksEvaluate(chunks);
+ }
+
+ [Fact,Trait("Category", "NoCI")]
+ public void TokenizeProvidesChunkMetadata()
+ {
+ using var chunks = TokenizeWithEmbed(() => _safeMtmdWeights.LoadMedia(Constants.MtmdImage));
+
+ Assert.True(chunks.Size > 0);
+
+ ulong totalTokens = 0;
+ long totalPositions = 0;
+ var imageChunks = 0;
+
+ foreach (var chunk in chunks.Enumerate())
+ {
+ totalTokens += chunk.NTokens;
+ totalPositions += chunk.NPos;
+
+ if (chunk.Type == SafeMtmdInputChunk.SafeMtmdInputChunkType.Image)
+ {
+ imageChunks++;
+
+ var copy = chunk.Copy();
+ try
+ {
+ Assert.NotNull(copy);
+ if (copy != null)
+ {
+ Assert.Equal(chunk.NTokens, copy.NTokens);
+ Assert.Equal(chunk.NPos, copy.NPos);
+ }
+ }
+ finally
+ {
+ copy?.Dispose();
+ }
+ }
+ }
+
+ Assert.True(imageChunks > 0);
+ Assert.True(totalTokens > 0);
+ Assert.Equal(totalTokens, _safeMtmdWeights.CountTokens(chunks));
+ Assert.Equal(totalPositions, _safeMtmdWeights.CountPositions(chunks));
+ Assert.True(_safeMtmdWeights.SupportsVision);
+ Assert.False(_safeMtmdWeights.SupportsAudio);
+
+ var audioBitrate = _safeMtmdWeights.AudioBitrate;
+ Assert.True(audioBitrate <= 0);
+ }
+ }
+}
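The tests above exercise the raw helper pipeline directly. Under the same assumptions (paths as in the test constants), the end-to-end low-level flow is roughly this sketch:

```csharp
using System;
using LLama;
using LLama.Common;
using LLama.Native;

var @params = new ModelParams("Models/gemma-3-4b-it-Q4_K_M.gguf") { ContextSize = 1024 * 32 };
using var weights = LLamaWeights.LoadFromFile(@params);
using var context = weights.CreateContext(@params);

var mtmdParams = MtmdContextParams.Default();
mtmdParams.UseGpu = false;                                  // keep it portable, as the tests do
using var mtmd = SafeMtmdWeights.LoadFromFile("Models/gemma-mmproj-model-f16.gguf", weights, mtmdParams);

using var embed = mtmd.LoadMedia("Models/extreme-ironing-taxi-610x427.jpg");
var status = mtmd.Tokenize(mtmdParams.MediaMarker ?? "", addSpecial: true, parseSpecial: true, out var chunks);
if (status != 0 || chunks is null)
    throw new InvalidOperationException($"MTMD tokenization failed: {status}");

using (chunks)
{
    long nPast = 0;
    var eval = mtmd.EvaluateChunks(chunks, context.NativeHandle, ref nPast,
        seqId: 0, nBatch: checked((int)context.BatchSize), logitsLast: true);
    // On success nPast covers every text and image position in the prompt.
    Console.WriteLine($"eval={eval}, nPast={nPast}");
}
```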
diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
deleted file mode 100644
index f3e5798f2..000000000
--- a/LLama.Unittest/Native/SafeLlamaModelHandleTests.cs
+++ /dev/null
@@ -1,32 +0,0 @@
-using System.Runtime.InteropServices;
-using System.Text;
-using LLama.Common;
-using LLama.Extensions;
-using Xunit;
-
-namespace LLama.Unittest.Native;
-
-public class SafeLlamaModelHandleTests
-{
- private readonly LLamaWeights _model;
-
- public SafeLlamaModelHandleTests()
- {
- var @params = new ModelParams(Constants.GenerativeModelPath2)
- {
- ContextSize = 1,
- GpuLayerCount = Constants.CIGpuLayerCount
- };
- _model = LLamaWeights.LoadFromFile(@params);
- }
-
- // Note: This test is flakey, it appears to often (but not always) fail the first time it is run after downloading the model file, but then succeed every time after!
- //[SkippableFact]
- //public void MetadataValByKey_ReturnsCorrectly()
- //{
- // Skip.If(RuntimeInformation.IsOSPlatform(OSPlatform.OSX), "Skipping this test on macOS because for some reason the meta data is incorrect, but the rest of tests work well on mscOS [Check later!].");
- // const string key = "general.name";
- // var template = _model.NativeHandle.MetadataValueByKey(key);
- // var name = Encoding.UTF8.GetStringFromSpan(template!.Value.Span);
- //}
-}
diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
deleted file mode 100644
index 1ce53f395..000000000
--- a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
+++ /dev/null
@@ -1,42 +0,0 @@
-using System.Text;
-using System.Xml.Linq;
-using LLama.Common;
-using LLama.Extensions;
-using Microsoft.Extensions.Logging;
-
-
-namespace LLama.Unittest.Native;
-
-public class SafeLlamaModelHandleVocabularyTests: IDisposable
-{
- private readonly LLamaWeights _model;
-
- public SafeLlamaModelHandleVocabularyTests()
- {
- var @params = new ModelParams(Constants.RerankingModelPath)
- {
- ContextSize = 0,
- PoolingType = LLama.Native.LLamaPoolingType.Rank,
- GpuLayerCount = Constants.CIGpuLayerCount
- };
- _model = LLamaWeights.LoadFromFile(@params);
- }
-
- public void Dispose()
- {
- _model.Dispose();
- }
-
- [Fact]
- public void GetLLamaTokenString()
- {
- var bos = _model.Vocab.BOS;
- var eos = _model.Vocab.EOS;
-
- var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
- var eosStr = _model.Vocab.LLamaTokenToString(eos, true);
-
- Assert.Equal("", bosStr);
- Assert.Equal("", eosStr);
- }
-}
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index 9a2233287..92276e4a6 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using System.Threading;
+using LLama.Native;
namespace LLama.Abstractions
{
@@ -22,12 +23,12 @@ public interface ILLamaExecutor
/// <summary>
/// Multi-Modal Projections / Clip Model weights
/// </summary>
- public LLavaWeights? ClipModel { get; }
+ public SafeMtmdWeights? ClipModel { get; }
/// <summary>
- /// List of images: List of images in byte array format.
+ /// List of media: media embeddings queued for Multi-Modal models.
/// </summary>
- public List<byte[]> Images { get; }
+ public List<SafeMtmdEmbed> Embeds { get; }
/// <summary>
/// Asynchronously infers a response from the model.
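Because the interface now exposes `Embeds` rather than raw image bytes, callers can stay executor-agnostic. A hedged sketch of a generic caller (the helper name and prompt are illustrative, not part of the PR):

```csharp
using System.Text;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;

static async Task<string> DescribeImageAsync(ILLamaExecutor executor, SafeMtmdWeights clipModel, string imagePath)
{
    // The media marker tells the MTMD tokenizer where the image belongs.
    var marker = NativeApi.MtmdDefaultMarker() ?? "";
    executor.Embeds.Add(clipModel.LoadMedia(imagePath));

    var sb = new StringBuilder();
    await foreach (var piece in executor.InferAsync($"{marker}\nDescribe the image.",
                       new InferenceParams { MaxTokens = 64 }))
        sb.Append(piece);
    return sb.ToString();
}
```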
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index cdb1835e4..1d47bb0b8 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -17,6 +17,7 @@ public sealed class BatchedExecutor
{
private int _nextSequenceId;
private readonly List<IBatch> _batchQueue = [ ];
+ private string? _mtmdMarker;
/// <summary>
/// Set to 1 using interlocked exchange while inference is running
@@ -60,12 +61,20 @@ public sealed class BatchedExecutor
/// <param name="model">The model to use</param>
/// <param name="contextParams">Parameters to create a new context</param>
public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
+ : this(model, contextParams, null)
+ {
+ }
+
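+ /// <summary>
+ /// Create a new batched executor, optionally attaching multimodal (MTMD) weights.
+ /// </summary>
+ /// <param name="model">The model to use</param>
+ /// <param name="contextParams">Parameters to create a new context</param>
+ /// <param name="clipModel">Optional multimodal (MTMD) helper weights</param>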
+ public BatchedExecutor(LLamaWeights model, IContextParams contextParams, SafeMtmdWeights? clipModel)
{
Model = model;
Context = model.CreateContext(contextParams);
+ ClipModel = clipModel;
Epoch = 1;
}
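+ /// <summary>
+ /// Multimodal (MTMD) weights attached to this executor, or null for text-only use.
+ /// </summary>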
+ public SafeMtmdWeights? ClipModel { get; }
+
/// <summary>
/// Start a new <see cref="Conversation"/>
///
@@ -254,6 +263,23 @@ internal LLamaSeqId GetNextSequenceId()
return (end, Epoch + (uint)_batchQueue.Count * 2);
}
+ internal ulong QueueMtmdBatch(Conversation conversation, Conversation.MtmdChunkSequence sequence)
+ {
+ if (ClipModel is null)
+ throw new InvalidOperationException("This batched executor is not configured for multimodal inference.");
+
+ var batch = new MtmdChunkBatch(ClipModel, conversation, sequence);
+ _batchQueue.Add(batch);
+ return Epoch + (uint)_batchQueue.Count * 2;
+ }
+
+ internal string GetMtmdMarker()
+ {
+ if (ClipModel is null)
+ throw new InvalidOperationException("This batched executor is not configured for multimodal inference.");
+ return _mtmdMarker ??= NativeApi.MtmdDefaultMarker() ?? "";
+ }
+
#region batches
private interface IBatch
{
@@ -285,5 +311,44 @@ public Task DecodeAsync(LLamaContext ctx, CancellationToken token)
return ctx.DecodeAsync(Batch, token);
}
}
+
+ private class MtmdChunkBatch : IBatch
+ {
+ private readonly SafeMtmdWeights _clipModel;
+ private readonly Conversation _conversation;
+ private readonly Conversation.MtmdChunkSequence _sequence;
+
+ public MtmdChunkBatch(SafeMtmdWeights clipModel, Conversation conversation, Conversation.MtmdChunkSequence sequence)
+ {
+ _clipModel = clipModel;
+ _conversation = conversation;
+ _sequence = sequence;
+ }
+
+ public int ItemCount => Math.Max(1, _sequence.TotalTokens);
+
+ public Task DecodeAsync(LLamaContext ctx, CancellationToken token)
+ {
+ try
+ {
+ var nPast = _conversation.GetMtmdPast();
+ var status = _clipModel.EvaluateChunks(_sequence.Chunks, ctx.NativeHandle, ref nPast,
+ (int)_conversation.ConversationId, checked((int)ctx.BatchSize), logitsLast: true);
+ if (status != 0)
+ {
+ _conversation.OnMtmdEvaluationFailed(status);
+ return Task.FromResult(DecodeResult.DecodeFailed);
+ }
+
+ _conversation.OnMtmdEvaluationCompleted(nPast, _sequence);
+ return Task.FromResult(DecodeResult.Ok);
+ }
+ catch
+ {
+ _conversation.OnMtmdEvaluationFailed(-1);
+ return Task.FromResult(DecodeResult.DecodeFailed);
+ }
+ }
+ }
#endregion
-}
\ No newline at end of file
+}
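Worth noting for reviewers: `QueueMtmdBatch` returns the same style of epoch ticket as the token/embedding batch paths (`Epoch + queueLength * 2`), so chunk evaluation slots into the existing batch ordering rather than bypassing it. A loose sketch of mixing a multimodal and a text-only conversation on one executor, assuming the queue drains in FIFO order (`executor` is a `BatchedExecutor` constructed with MTMD weights, as above):

```csharp
using LLama.Batched;
using LLama.Sampling;

using var withImage = executor.Create();
withImage.QueueMedia("cat.jpg");                    // hypothetical image
withImage.Prompt("What animal is this?");           // queued as an MTMD chunk batch

using var textOnly = executor.Create();
textOnly.Prompt("Write a haiku about rivers.");     // queued as an ordinary token batch

var sampler = new DefaultSamplingPipeline();
// Each Infer() call processes the oldest queued batch (token or chunk),
// so both conversations become sampleable once the queue drains.
while (!withImage.RequiresSampling || !textOnly.RequiresSampling)
    await executor.Infer();

var a = withImage.Sample(sampler);
var b = textOnly.Sample(sampler);
```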
diff --git a/LLama/Batched/Conversation.cs b/LLama/Batched/Conversation.cs
index fcc94ae8f..807542b79 100644
--- a/LLama/Batched/Conversation.cs
+++ b/LLama/Batched/Conversation.cs
@@ -3,6 +3,7 @@
using System.Linq;
using System.Text.Json;
using CommunityToolkit.HighPerformance.Buffers;
+using LLama.Exceptions;
using LLama.Native;
namespace LLama.Batched;
@@ -21,6 +22,12 @@ public sealed class Conversation
/// Indicates if this conversation has been "forked" and may share logits with another conversation.
/// </summary>
private bool _forked;
+ private readonly List<SafeMtmdEmbed> _mtmdEmbeds = new();
+ private int? _mtmdLogitsIndex;
+ private MtmdChunkSequence? _pendingMtmdSequence;
+ private readonly List<LLamaToken> _embed_inps = new();
+ private readonly List<LLamaToken> _session_tokens = new();
+ private int _consumedTokensCount;
/// <summary>
/// Stores the indices to sample from. Contains <see cref="_batchSampleCount"/> valid items.
@@ -65,6 +72,46 @@ internal Conversation(BatchedExecutor batch, LLamaSeqId id)
Executor = batch;
}
+ internal sealed class MtmdChunkSequence : IDisposable
+ {
+ public SafeMtmdInputChunks Chunks { get; }
+ public List<LLamaToken> TextTokens { get; }
+ public int TotalPositions { get; }
+ public int TotalTokens => TextTokens.Count;
+
+ private MtmdChunkSequence(SafeMtmdInputChunks chunks, List<LLamaToken> textTokens, int totalPositions)
+ {
+ Chunks = chunks;
+ TextTokens = textTokens;
+ TotalPositions = totalPositions;
+ }
+
+ public static MtmdChunkSequence Create(SafeMtmdInputChunks chunks, SafeMtmdWeights clipModel)
+ {
+ var textTokens = new List<LLamaToken>();
+
+ foreach (var chunk in chunks.Enumerate())
+ {
+ using (chunk)
+ {
+ if (chunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+ continue;
+
+ foreach (var token in chunk.GetTextTokensSpan())
+ textTokens.Add((LLamaToken)unchecked((int)token));
+ }
+ }
+
+ var totalPositions = (int)clipModel.CountPositions(chunks);
+ return new MtmdChunkSequence(chunks, textTokens, totalPositions);
+ }
+
+ public void Dispose()
+ {
+ Chunks.Dispose();
+ }
+ }
+
/// <summary>
/// Finalizer for Conversation
/// </summary>
@@ -83,6 +130,11 @@ public void Dispose()
return;
_disposed = true;
+ _pendingMtmdSequence?.Dispose();
+ _pendingMtmdSequence = null;
+
+ DisposeQueuedMedia();
+
// Remove this conversation from the KV cache
Executor.Context.NativeHandle.MemorySequenceRemove(ConversationId, -1, -1);
@@ -206,6 +258,43 @@ private void AssertCanBePrompted()
if (RequiresInference)
throw new AlreadyPromptedConversationException();
+
+ _mtmdLogitsIndex = null;
+ }
+
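+ /// <summary>
+ /// Load a media file (image or audio) and queue it for the next multimodal prompt.
+ /// </summary>
+ /// <param name="path">Path to the media file on disk</param>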
+ public void QueueMedia(string path)
+ {
+ AssertCanBePrompted();
+
+ if (Executor.ClipModel is null)
+ throw new InvalidOperationException("This conversation is not configured for multimodal prompts.");
+
+ var embed = Executor.ClipModel.LoadMedia(path);
+ _mtmdEmbeds.Add(embed);
+ _mtmdLogitsIndex = null;
+ }
+
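+ /// <summary>
+ /// Queue an already-loaded media embedding for the next multimodal prompt.
+ /// </summary>
+ /// <param name="embed">Media embedding produced by the MTMD helper</param>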
+ public void QueueMedia(SafeMtmdEmbed embed)
+ {
+ AssertCanBePrompted();
+
+ if (Executor.ClipModel is null)
+ throw new InvalidOperationException("This conversation is not configured for multimodal prompts.");
+
+ _mtmdEmbeds.Add(embed);
+ _mtmdLogitsIndex = null;
+ }
+
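+ /// <summary>
+ /// Prompt this conversation with text; when media has been queued the prompt is
+ /// tokenized through the MTMD helper instead of the plain text tokenizer.
+ /// </summary>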
+ public void Prompt(string promptText, bool addBos = true, bool special = true)
+ {
+ if (Executor.ClipModel != null && _mtmdEmbeds.Count > 0)
+ {
+ PromptMultimodal(promptText, addBos);
+ return;
+ }
+
+ var tokens = Executor.Context.Tokenize(promptText, addBos, special);
+ Prompt(tokens);
}
///
@@ -246,6 +335,7 @@ public void Prompt(List tokens, bool allLogits = false)
public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();
+ _mtmdLogitsIndex = null;
// No point doing anything if there is no actual prompt!
if (tokens.Length == 0)
@@ -289,6 +379,59 @@ public void Prompt(ReadOnlySpan tokens, bool allLogits = false)
// Unset the forked flag. Since this conversation has just been prompted it's no longer
// sharing anything with any other conversations.
_forked = false;
+ _mtmdLogitsIndex = null;
+ }
+
+ private void PromptMultimodal(string text, bool addBos)
+ {
+ AssertCanBePrompted();
+
+ if (Executor.ClipModel is null)
+ throw new InvalidOperationException("This conversation is not configured for multimodal prompts.");
+ if (_mtmdEmbeds.Count == 0)
+ throw new InvalidOperationException("Queue media before prompting with multimodal input.");
+
+ var marker = Executor.GetMtmdMarker();
+ var prompt = text;
+
+ if (prompt.Contains(""))
+ prompt = prompt.Replace("", marker);
+
+ if (!prompt.Contains(marker))
+ {
+ var suffix = string.Concat(Enumerable.Repeat(marker, _mtmdEmbeds.Count));
+ prompt = string.Concat(prompt, suffix);
+ }
+
+ SafeMtmdInputChunks? chunks = null;
+ try
+ {
+ _mtmdLogitsIndex = null;
+ var status = Executor.ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+ if (status != 0 || chunks is null)
+ {
+ Executor.ClipModel.ClearMedia();
+ throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+ }
+
+ var sequence = MtmdChunkSequence.Create(chunks, Executor.ClipModel);
+ _pendingMtmdSequence = sequence;
+
+ var epoch = Executor.QueueMtmdBatch(this, sequence);
+ chunks = null;
+
+ if (_batchSampleIndices.Length == 0)
+ _batchSampleIndices = new int[4];
+
+ _batchSampleCount = 0;
+ _requiredEpoch = epoch;
+ _forked = false;
+ }
+ finally
+ {
+ DisposeQueuedMedia();
+ chunks?.Dispose();
+ }
}
///
@@ -305,32 +448,7 @@ public void Prompt(LLamaToken token)
Span<LLamaToken> span = [ token ];
Prompt(span);
}
-
- /// <summary>
- /// Prompt this conversation with an image embedding
- /// </summary>
- /// <param name="embedding"></param>
- public void Prompt(SafeLlavaImageEmbedHandle embedding)
- {
- AssertCanBePrompted();
-
- if (embedding.Model.EmbeddingDimensions != Executor.Model.EmbeddingSize)
- throw new ArgumentException($"Embedding dimension mismatch between image embedding ({embedding.Model.EmbeddingDimensions}) and model ({Executor.Model.EmbeddingSize})");
-
- for (var i = 0; i < embedding.Model.PatchCount; i++)
- {
- // Get a batch with space
- (var batch, _requiredEpoch) = Executor.GetEmbeddingBatch();
-
- batch.Add(
- (i, embedding),
- static (Span dest, (int index, SafeLlavaImageEmbedHandle embedding) tup) => tup.embedding.GetEmbedding(dest, tup.index),
- _end++,
- ConversationId,
- i == embedding.Model.PatchCount - 1
- );
- }
- }
+
/// <summary>
/// Prompt this conversation with embeddings
@@ -339,6 +457,7 @@ public void Prompt(SafeLlavaImageEmbedHandle embedding)
public void Prompt(ReadOnlySpan<float> embeddings)
{
AssertCanBePrompted();
+ _mtmdLogitsIndex = null;
var dim = Executor.Model.EmbeddingSize;
var count = embeddings.Length / dim;
@@ -385,6 +504,75 @@ public void Modify(ModifyKvCache modifier)
_requiredEpoch = 0;
}
+ internal long GetMtmdPast() => _end.Value;
+
+ internal void OnMtmdEvaluationCompleted(long newPast, MtmdChunkSequence sequence)
+ {
+ _pendingMtmdSequence?.Dispose();
+ _pendingMtmdSequence = null;
+
+ _end = (LLamaPos)checked((int)newPast);
+
+ if (_batchSampleIndices.Length == 0)
+ _batchSampleIndices = new int[4];
+
+ _batchSampleCount = 1;
+ _batchSampleIndices[0] = 0;
+ _mtmdLogitsIndex = -1;
+ _requiredEpoch = Executor.Epoch + 1;
+ _forked = false;
+
+ if (sequence.TextTokens.Count > 0)
+ {
+ _embed_inps.AddRange(sequence.TextTokens);
+ _session_tokens.AddRange(sequence.TextTokens);
+ }
+
+ var fillerToken = GetFillerToken(Executor.GetMtmdMarker());
+ var fillerCount = Math.Max(0, sequence.TotalPositions - sequence.TotalTokens);
+ for (var i = 0; i < fillerCount; i++)
+ _embed_inps.Add(fillerToken);
+
+ _consumedTokensCount = _embed_inps.Count;
+ sequence.Dispose();
+ }
+
+ internal void OnMtmdEvaluationFailed(int status)
+ {
+ _pendingMtmdSequence?.Dispose();
+ _pendingMtmdSequence = null;
+ _mtmdLogitsIndex = null;
+ _requiredEpoch = Executor.Epoch;
+ DisposeQueuedMedia();
+ }
+
+ internal int? MtmdLogitsIndex => _mtmdLogitsIndex;
+
+ private LLamaToken GetFillerToken(string marker)
+ {
+ var markerTokens = Executor.Context.Tokenize(marker, addBos: false, special: true);
+ if (markerTokens.Length > 0)
+ return markerTokens[markerTokens.Length - 1];
+
+ var eos = Executor.Context.Vocab.EOS;
+ if (eos.HasValue)
+ return eos.Value;
+
+ return default;
+ }
+
+ private void DisposeQueuedMedia()
+ {
+ if (_mtmdEmbeds.Count == 0)
+ return;
+
+ foreach (var embed in _mtmdEmbeds)
+ embed.Dispose();
+
+ _mtmdEmbeds.Clear();
+ Executor.ClipModel?.ClearMedia();
+ }
+
/// <summary>
/// Provides direct access to the KV cache of a <see cref="Conversation"/>.
/// See <see cref="KvAccessor"/> for how to use this.
@@ -629,4 +817,4 @@ internal State()
}
}
#endregion
-}
\ No newline at end of file
+}
diff --git a/LLama/Batched/ConversationExtensions.cs b/LLama/Batched/ConversationExtensions.cs
index eb0192061..3e25d3f43 100644
--- a/LLama/Batched/ConversationExtensions.cs
+++ b/LLama/Batched/ConversationExtensions.cs
@@ -18,7 +18,11 @@ public static class ConversationExtensions
///
public static LLamaToken Sample(this Conversation conversation, SafeLLamaSamplerChainHandle sampler, int offset = 0)
{
- return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset));
+ var ctx = conversation.Executor.Context.NativeHandle;
+ if (conversation.MtmdLogitsIndex == -1)
+ return sampler.Sample(ctx, -1);
+
+ return sampler.Sample(ctx, conversation.GetSampleIndex(offset));
}
///
@@ -30,7 +34,11 @@ public static LLamaToken Sample(this Conversation conversation, SafeLLamaSampler
///
public static LLamaToken Sample(this Conversation conversation, ISamplingPipeline sampler, int offset = 0)
{
- return sampler.Sample(conversation.Executor.Context.NativeHandle, conversation.GetSampleIndex(offset));
+ var ctx = conversation.Executor.Context.NativeHandle;
+ if (conversation.MtmdLogitsIndex == -1)
+ return sampler.Sample(ctx, -1);
+
+ return sampler.Sample(ctx, conversation.GetSampleIndex(offset));
}
///
@@ -82,4 +90,4 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep
return end.Value - count;
});
}
-}
\ No newline at end of file
+}
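The `-1` index is the llama.cpp convention for "last logits in the context"; after an MTMD chunk batch there is no per-conversation batch index, so `MtmdLogitsIndex == -1` reroutes sampling there. From a caller's perspective nothing changes. A sketch (`executor` and `conversation` as in the batched MTMD example):

```csharp
using LLama.Sampling;

var sampler = new DefaultSamplingPipeline();
await executor.Infer();                        // evaluates token batches and MTMD chunk batches alike
if (conversation.RequiresSampling)
{
    // Uses the batch sample index for token batches, or index -1 (last logits)
    // right after a multimodal chunk evaluation.
    var token = conversation.Sample(sampler);
    sampler.Accept(token);
    conversation.Prompt(token);
}
```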
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 36989006e..5678945fc 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -32,11 +32,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// </summary>
protected int _consumedTokensCount; // n_consume
/// <summary>
- ///
+ /// Number of tokens consumed from the session cache during the current run.
/// </summary>
protected int _n_session_consumed;
/// <summary>
- ///
+ /// Number of prompt tokens that match the loaded session cache prefix.
/// </summary>
protected int _n_matching_session_tokens;
/// <summary>
@@ -52,7 +52,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// </summary>
protected List<LLamaToken> _embed_inps = new();
/// <summary>
- ///
+ /// Tokens recovered from the session file and reused to warm up the KV cache.
/// </summary>
protected List<LLamaToken> _session_tokens = new();
/// <summary>
@@ -76,21 +76,28 @@ public bool IsMultiModal
}
/// <inheritdoc />
- public LLavaWeights? ClipModel { get; }
+ public SafeMtmdWeights? ClipModel { get; }
/// <inheritdoc />
- public List<byte[]> Images { get; }
+ public List<SafeMtmdEmbed> Embeds { get; }
+
+ /// <summary>
+ /// Pending multimodal chunks produced by the MTMD tokenizer.
+ /// </summary>
+ protected SafeMtmdInputChunks? MtmdChunks { get; set; }
+
+ private string? _mtmdMarker;
private readonly StreamingTokenDecoder _decoder;
/// <summary>
- ///
+ /// Initialize a stateful executor bound to a specific context.
/// </summary>
- /// <param name="context"></param>
- /// <param name="logger"></param>
+ /// <param name="context">LLama context used for all native interactions.</param>
+ /// <param name="logger">Optional logger for diagnostic output.</param>
protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
- Images = new List<byte[]>();
+ Embeds = new List<SafeMtmdEmbed>();
_logger = logger;
Context = context;
_pastTokensCount = 0;
@@ -101,22 +108,22 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
}
/// <summary>
- ///
+ /// Initialize a multimodal executor with the supplied MTMD weights.
/// </summary>
- /// <param name="context"></param>
- /// <param name="lLavaWeights"></param>
- /// <param name="logger"></param>
- public StatefulExecutorBase(LLamaContext context, LLavaWeights lLavaWeights, ILogger? logger = null) :
+ /// <param name="context">LLama context used for all native interactions.</param>
+ /// <param name="safeMtmdWeights">Multimodal weights to associate with this executor.</param>
+ /// <param name="logger">Optional logger for diagnostic output.</param>
+ public StatefulExecutorBase(LLamaContext context, SafeMtmdWeights safeMtmdWeights, ILogger? logger = null) :
this( context, logger )
{
- ClipModel = lLavaWeights;
+ ClipModel = safeMtmdWeights;
}
/// <summary>
- /// This API is currently not verified.
+ /// Attach a session cache file so the executor can reuse previous KV state if compatible.
/// </summary>
- /// <param name="filename"></param>
- /// <returns></returns>
+ /// <param name="filename">Path to the llama.cpp session file.</param>
+ /// <returns>The current executor instance for fluent configuration.</returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="RuntimeError"></exception>
public StatefulExecutorBase WithSessionFile(string filename)
@@ -173,9 +180,9 @@ public StatefulExecutorBase WithSessionFile(string filename)
}
/// <summary>
- /// This API has not been verified currently.
+ /// Persist the current session cache to disk.
/// </summary>
- /// <param name="filename"></param>
+ /// <param name="filename">Destination path for the llama.cpp session file.</param>
public void SaveSessionFile(string filename)
{
var session_token_array = _session_tokens.ToArray();
@@ -203,7 +210,7 @@ protected virtual void HandleRunOutOfContext(int tokensToKeep)
}
/// <summary>
- /// Try to reuse the matching prefix from the session file.
+ /// Try to reuse the matching prompt prefix from the loaded session cache before evaluating new tokens.
/// </summary>
protected virtual void TryReuseMatchingPrefix()
{
@@ -236,66 +243,254 @@ protected virtual void TryReuseMatchingPrefix()
}
/// <summary>
- /// Decide whether to continue the loop.
+ /// Dispose and clear any queued multimodal chunk collection.
/// </summary>
- /// <param name="args"></param>
- /// <returns></returns>
+ protected void DisposeMtmdChunks()
+ {
+ MtmdChunks?.Dispose();
+ MtmdChunks = null;
+ }
+
+ /// <summary>
+ /// Dispose and clear any pending multimodal embeddings.
+ /// </summary>
+ protected void DisposeEmbeds()
+ {
+ if (Embeds.Count == 0)
+ return;
+
+ foreach (var embed in Embeds)
+ embed.Dispose();
+
+ Embeds.Clear();
+ }
+
+ /// <summary>
+ /// Retrieve the marker token used to signal media segments to the tokenizer.
+ /// </summary>
+ protected string GetMtmdMarker()
+ {
+ if (_mtmdMarker is not null)
+ return _mtmdMarker;
+
+ _mtmdMarker = NativeApi.MtmdDefaultMarker() ?? "";
+ return _mtmdMarker;
+ }
+
+ /// <summary>
+ /// Ensure the token list fills all positional slots reported by the MTMD helper.
+ /// </summary>
+ protected static List<LLamaToken> BuildTokensWithFiller(List<LLamaToken> tokens, int totalPositions, LLamaToken fillerToken)
+ {
+ if (totalPositions <= tokens.Count)
+ return new List<LLamaToken>(tokens);
+
+ var result = new List<LLamaToken>(totalPositions);
+ result.AddRange(tokens);
+ result.AddRange(Enumerable.Repeat(fillerToken, totalPositions - tokens.Count));
+ return result;
+ }
+
+ /// <summary>
+ /// Resolve the fallback token inserted when the tokenizer emits fewer tokens than positions.
+ /// </summary>
+ protected LLamaToken GetFillerToken(string marker)
+ {
+ var markerTokens = Context.Tokenize(marker, false, true);
+ if (markerTokens.Length > 0)
+ return markerTokens[markerTokens.Length - 1];
+
+ var eos = Context.Vocab.EOS;
+ if (eos.HasValue)
+ return eos.Value;
+
+ return default;
+ }
+
+ /// <summary>
+ /// Prepare multimodal inputs by invoking the MTMD tokenizer and aligning filler tokens.
+ /// </summary>
+ protected Task PreprocessMtmd(string text, InferStateArgs args, bool addBos, bool replaceExisting)
+ {
+ if (ClipModel is null)
+ throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+
+ DisposeMtmdChunks();
+
+ var marker = GetMtmdMarker();
+ var prompt = text;
+
+ if (Embeds.Count > 0)
+ {
+ if (prompt.Contains(""))
+ prompt = prompt.Replace("", marker);
+
+ if (!prompt.Contains(marker))
+ {
+ var suffix = string.Concat(Enumerable.Repeat(marker, Embeds.Count));
+ prompt = string.Concat(prompt, suffix);
+ }
+ }
+
+ SafeMtmdInputChunks? chunks = null;
+ try
+ {
+ var status = ClipModel.Tokenize(prompt, addBos, parseSpecial: true, out chunks);
+ if (status != 0 || chunks is null)
+ {
+ ClipModel.ClearMedia();
+ throw new RuntimeError($"Failed to tokenize multimodal prompt. Status: {status}.");
+ }
+
+ MtmdChunks = chunks;
+
+ var tokens = new List<LLamaToken>();
+ foreach (var chunk in chunks.Enumerate())
+ {
+ using var scopedChunk = chunk;
+ if (scopedChunk.Type != SafeMtmdInputChunk.SafeMtmdInputChunkType.Text)
+ continue;
+
+ foreach (var token in scopedChunk.GetTextTokensSpan())
+ tokens.Add((LLamaToken)unchecked((int)token));
+ }
+
+ var totalPositions = (int)ClipModel.CountPositions(chunks);
+ var fillerToken = GetFillerToken(marker);
+
+ if (replaceExisting)
+ {
+ _embed_inps = BuildTokensWithFiller(tokens, totalPositions, fillerToken);
+ _consumedTokensCount = 0;
+ }
+ else
+ {
+ if (_embed_inps.Count == 0)
+ _embed_inps = new List<LLamaToken>();
+
+ _embed_inps.AddRange(tokens);
+ var fillerCount = totalPositions - tokens.Count;
+ if (fillerCount > 0)
+ _embed_inps.AddRange(Enumerable.Repeat(fillerToken, fillerCount));
+
+ args.RemainedTokens -= tokens.Count;
+ }
+ }
+ catch
+ {
+ chunks?.Dispose();
+ MtmdChunks = null;
+ throw;
+ }
+ finally
+ {
+ DisposeEmbeds();
+ }
+
+ return Task.CompletedTask;
+ }
+
+ /// <summary>
+ /// Apply bookkeeping after successfully evaluating multimodal chunks.
+ /// </summary>
+ protected void FinalizeMtmdEvaluation(long newNPast, int previousConsumed)
+ {
+ _pastTokensCount = checked((int)newNPast);
+ DisposeMtmdChunks();
+
+ if (!string.IsNullOrEmpty(_pathSession) && _embed_inps.Count > previousConsumed)
+ {
+ _session_tokens.AddRange(_embed_inps.Skip(previousConsumed));
+ _n_session_consumed = _session_tokens.Count;
+ }
+
+ _consumedTokensCount = _embed_inps.Count;
+ _embeds.Clear();
+ }
+
+ /// <summary>
+ /// Evaluate the queued MTMD chunks and update executor state.
+ /// </summary>
+ protected void EvaluateMtmdChunks(ref long nPast, int previousConsumed, string executorName)
+ {
+ if (ClipModel is null)
+ throw new InvalidOperationException("Multimodal execution requires a loaded mtmd clip model.");
+ if (MtmdChunks is null)
+ throw new InvalidOperationException("No MTMD chunks are queued for evaluation.");
+
+ var evalStatus = ClipModel.EvaluateChunks(MtmdChunks, Context.NativeHandle, ref nPast, seqId: 0,
+ nBatch: checked((int)Context.BatchSize), logitsLast: true);
+ if (evalStatus != 0)
+ {
+ _logger?.LogError("[{Executor}] Failed to evaluate multimodal chunks. Status: {Status}", executorName, evalStatus);
+ DisposeMtmdChunks();
+ throw new RuntimeError($"Failed to evaluate multimodal chunks. Status: {evalStatus}.");
+ }
+
+ FinalizeMtmdEvaluation(nPast, previousConsumed);
+ }
+
+ /// <summary>
+ /// Determine whether the inference loop should continue processing tokens.
+ /// </summary>
+ /// <param name="args">Mutable state associated with the current inference.</param>
+ /// <returns>true to continue generating; otherwise false.</returns>
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
/// <summary>
- /// Preprocess the inputs before the inference.
+ /// Prepare the executor for inference by tokenizing input and updating cached state.
/// </summary>
- /// <param name="text"></param>
- /// <param name="args"></param>
+ /// <param name="text">Prompt text to process.</param>
+ /// <param name="args">Mutable state associated with the current inference.</param>
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
/// <summary>
- /// Do some post processing after the inference.
+ /// Perform any post-processing on the generated tokens.
/// </summary>
- /// <param name="inferenceParams"></param>
- /// <param name="args"></param>
- /// <returns></returns>
+ /// <param name="inferenceParams">Parameters controlling sampling.</param>
+ /// <param name="args">Mutable state associated with the current inference.</param>
+ /// <returns>A tuple indicating whether generation should stop and any extra outputs to emit.</returns>
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
- /// The core inference logic.
+ /// Core inference loop that advances the model by one step.
/// </summary>
- /// <param name="inferenceParams"></param>
- /// <param name="args"></param>
+ /// <param name="inferenceParams">Parameters controlling sampling.</param>
+ /// <param name="args">Mutable state associated with the current inference.</param>
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
/// <summary>
- /// Save the current state to a file.
+ /// Save the executor state to a serialized snapshot file.
/// </summary>
- /// <param name="filename"></param>
+ /// <param name="filename">Destination file for the serialized state.</param>
public abstract Task SaveState(string filename);
/// <summary>
- /// Get the current state data.
+ /// Capture the executor state in a serializable object.
/// </summary>
- /// <returns></returns>
+ /// <returns>State snapshot suitable for persistence.</returns>
public abstract ExecutorBaseState GetStateData();
/// <summary>
- /// Load the state from data.
+ /// Restore executor state from a previously captured snapshot.
/// </summary>
- /// <param name="data"></param>
+ /// <param name="data">State snapshot created by <see cref="GetStateData"/>.</param>
public abstract Task LoadState(ExecutorBaseState data);
/// <summary>
- /// Load the state from a file.
+ /// Restore executor state from a serialized snapshot file.
/// </summary>
- /// <param name="filename"></param>
+ /// <param name="filename">Path to the snapshot produced by <see cref="SaveState"/>.</param>
public abstract Task LoadState(string filename);
/// <summary>
- /// Execute the inference.
+ /// Execute an asynchronous inference session.
/// </summary>
- /// <param name="text">The prompt. If null, generation will continue where it left off previously.</param>
- /// <param name="inferenceParams"></param>
- /// <param name="cancellationToken"></param>
- /// <returns></returns>
+ /// <param name="text">Optional prompt; when null generation resumes from prior state.</param>
+ /// <param name="inferenceParams">Sampling parameters to apply; defaults are used when null.</param>
+ /// <param name="cancellationToken">Cancellation token for cooperative cancellation.</param>
+ /// <returns>Stream of decoded text segments as they become available.</returns>
public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
@@ -370,12 +565,12 @@ public virtual async Task PrefillPromptAsync(string prompt)
}
/// <summary>
- /// State arguments that are used in single inference
+ /// Mutable state passed between inference callbacks during a single generation pass.
/// </summary>
protected class InferStateArgs
{
/// <summary>
- ///
+ /// Anti-prompts that terminate generation when encountered.
/// </summary>
public IList<string>? Antiprompts { get; set; }
///
@@ -383,20 +578,23 @@ protected class InferStateArgs
/// </summary>
public int RemainedTokens { get; set; }
/// <summary>
- ///
+ /// Indicates whether generated tokens should be returned to the caller.
/// </summary>
public bool ReturnValue { get; set; }
/// <summary>
- ///
+ /// Signals that the executor should pause and wait for additional user input.
/// </summary>
public bool WaitForInput { get; set; }
/// <summary>
- ///
+ /// Indicates whether the session cache should be persisted after inference completes.
/// </summary>
public bool NeedToSaveSession { get; set; }
}
#pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings
+ /// <summary>
+ /// Serializable snapshot of executor state used for persistence and restart.
+ /// </summary>
[JsonConverter(typeof(PolymorphicJSONConverter))]
public class ExecutorBaseState
{
@@ -434,5 +632,33 @@ public class ExecutorBaseState
public float? MirostatMu { get; set; }
}
#pragma warning restore
+
+ internal ExecutorDiagnostics GetDiagnostics()
+ {
+ return new ExecutorDiagnostics(
+ _embed_inps.Count,
+ _consumedTokensCount,
+ _pastTokensCount,
+ _embeds.Count);
+ }
+ }
+}
+
+namespace LLama
+{
+ internal readonly struct ExecutorDiagnostics
+ {
+ public ExecutorDiagnostics(int embedCount, int consumedCount, int pastCount, int pendingEmbeds)
+ {
+ EmbedCount = embedCount;
+ ConsumedCount = consumedCount;
+ PastCount = pastCount;
+ PendingEmbedCount = pendingEmbeds;
+ }
+
+ public int EmbedCount { get; }
+ public int ConsumedCount { get; }
+ public int PastCount { get; }
+ public int PendingEmbedCount { get; }
}
}
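A worked example of the filler alignment that `PreprocessMtmd` and `BuildTokensWithFiller` maintain: media chunks consume KV positions without contributing text tokens, so the bookkeeping list is padded until token count equals position count. The sketch below is written as if inside a derived executor; the numbers and `GetTextTokensFromChunks` helper are illustrative.

```csharp
using System.Collections.Generic;
using LLama.Native;

// Say the MTMD tokenizer produced 8 text tokens, but CountPositions(chunks)
// reports 72 because the image chunk occupies 64 KV slots.
List<LLamaToken> textTokens = GetTextTokensFromChunks();   // 8 tokens (hypothetical helper)
int totalPositions = 72;
LLamaToken filler = GetFillerToken(marker);                // last token of the media marker

var padded = BuildTokensWithFiller(textTokens, totalPositions, filler);
// padded.Count == 72: 8 real tokens followed by 64 filler entries. After
// EvaluateChunks advances nPast to 72, _consumedTokensCount/_pastTokensCount
// line up with the KV cache, so the media span is never re-tokenized or re-evaluated.
```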
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 331591fba..35f20b776 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
@@ -24,6 +25,7 @@ public class InstructExecutor
private readonly string _instructionPrefix;
private LLamaToken[] _inp_pfx;
private LLamaToken[] _inp_sfx;
+ private readonly string _instructionSuffix;
///
///
@@ -41,6 +43,20 @@ public InstructExecutor(LLamaContext context,
_inp_pfx = Context.Tokenize(instructionPrefix, true, true);
_inp_sfx = Context.Tokenize(instructionSuffix, false, true);
_instructionPrefix = instructionPrefix;
+ _instructionSuffix = instructionSuffix;
+ }
+
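+ /// <summary>
+ /// Create an instruct executor with multimodal (MTMD) weights attached.
+ /// </summary>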
+ public InstructExecutor(LLamaContext context,
+ SafeMtmdWeights clipModel,
+ string instructionPrefix = "\n\n### Instruction:\n\n",
+ string instructionSuffix = "\n\n### Response:\n\n",
+ ILogger? logger = null)
+ : base(context, clipModel, logger)
+ {
+ _inp_pfx = Context.Tokenize(instructionPrefix, true, true);
+ _inp_sfx = Context.Tokenize(instructionSuffix, false, true);
+ _instructionPrefix = instructionPrefix;
+ _instructionSuffix = instructionSuffix;
}
///
@@ -67,6 +83,7 @@ public override ExecutorBaseState GetStateData()
///
public override Task LoadState(ExecutorBaseState data)
{
+ DisposeMtmdChunks();
if(data is InstructExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
@@ -126,7 +143,14 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
// When running the first input (prompt) in interactive mode, we need to process it specially.
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
- _embed_inps = Context.Tokenize(text, true, true).ToList();
+ if (!IsMultiModal)
+ {
+ _embed_inps = Context.Tokenize(text, true, true).ToList();
+ }
+ else
+ {
+ return PreprocessMtmd(text, args, addBos: true, replaceExisting: true);
+ }
}
else
{
@@ -139,14 +163,25 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
text += "\n";
}
- _embed_inps.AddRange(_inp_pfx);
+ if (!IsMultiModal)
+ {
+ _embed_inps.AddRange(_inp_pfx);
- var line_inp = Context.Tokenize(text, false, true);
- _embed_inps.AddRange(line_inp);
+ var line_inp = Context.Tokenize(text, false, true);
+ _embed_inps.AddRange(line_inp);
- _embed_inps.AddRange(_inp_sfx);
+ _embed_inps.AddRange(_inp_sfx);
- args.RemainedTokens -= line_inp.Length;
+ args.RemainedTokens -= line_inp.Length;
+ }
+ else
+ {
+ var builder = new StringBuilder();
+ builder.Append(_instructionPrefix);
+ builder.Append(text);
+ builder.Append(_instructionSuffix);
+ return PreprocessMtmd(builder.ToString(), args, addBos: false, replaceExisting: false);
+ }
}
}
@@ -213,11 +248,25 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
_n_session_consumed = _session_tokens.Count;
}
}
+ else if (IsMultiModal && MtmdChunks is not null)
+ {
+ _is_prompt_run = false;
+ var nPast = (long)_pastTokensCount;
+ var previousConsumed = _consumedTokensCount;
+ EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InstructExecutor));
+ }
_embeds.Clear();
if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
{
+ if (inferenceParams.MaxTokens == 0)
+ {
+ _embeds.Clear();
+ args.WaitForInput = true;
+ args.ReturnValue = false;
+ return;
+ }
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
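The new `MaxTokens == 0` early-out gives callers an explicit prefill step: the multimodal chunks are evaluated and the KV cache warmed without sampling a single token (this is what the unit tests rely on). A sketch, assuming `context`, `mtmd`, and `marker` from the earlier setup:

```csharp
using System;
using LLama;
using LLama.Common;

var executor = new InstructExecutor(context, mtmd);
executor.Embeds.Add(mtmd.LoadMedia("image.jpg"));          // hypothetical path

// Prefill: evaluates the image + prompt, emits nothing, then waits for input.
await foreach (var _ in executor.InferAsync($"{marker} Describe the image.",
                   new InferenceParams { MaxTokens = 0 }))
{ }

// Follow-up turn generates from the warmed state.
await foreach (var piece in executor.InferAsync("Summarise it in one sentence.",
                   new InferenceParams { MaxTokens = 64 }))
    Console.Write(piece);
```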
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7c9558ee3..392be783c 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
+using LLama;
using LLama.Exceptions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -20,30 +21,26 @@ namespace LLama
///
public class InteractiveExecutor : StatefulExecutorBase
{
+ // Indicates whether the executor is currently evaluating the initial prompt or a follow-up turn.
private bool _is_prompt_run = true;
-
- // LLava
- private int _EmbedImagePosition = -1;
- private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
- private bool _imageInPrompt = false;
/// <summary>
- ///
+ /// Create an interactive executor for text-only inference.
/// </summary>
- /// <param name="context"></param>
- /// <param name="logger"></param>
+ /// <param name="context">LLama context to operate against.</param>
+ /// <param name="logger">Optional logger for diagnostic output.</param>
public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
}
/// <summary>
- ///
+ /// Create an interactive multimodal executor that can process text alongside media inputs.
/// </summary>
- /// <param name="context"></param>
- /// <param name="clipModel"></param>
- /// <param name="logger"></param>
- public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
+ /// <param name="context">LLama context to operate against.</param>
+ /// <param name="clipModel">Multimodal weights (MTMD) to attach to the executor.</param>
+ /// <param name="logger">Optional logger for diagnostic output.</param>
+ public InteractiveExecutor(LLamaContext context, SafeMtmdWeights clipModel, ILogger? logger = null)
: base(context, clipModel, logger)
{
}
@@ -70,6 +67,7 @@ public override ExecutorBaseState GetStateData()
///
public override Task LoadState(ExecutorBaseState data)
{
+ DisposeMtmdChunks();
if (data is InteractiveExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
@@ -108,15 +106,20 @@ public override async Task LoadState(string filename)
}
///
- /// Define whether to continue the loop to generate responses.
+ /// Decide whether generation should continue for the current iteration.
///
- ///
+ /// Mutable inference state.
+ /// true to keep generating; otherwise false.
protected override Task GetLoopCondition(InferStateArgs args)
{
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}
- ///
+ ///
+ /// Preprocess the incoming prompt or continuation text before inference.
+ ///
+ /// Prompt text or continuation provided by the caller.
+ /// Mutable inference state.
protected override Task PreprocessInputs(string? text, InferStateArgs args)
{
if (_is_prompt_run)
@@ -129,7 +132,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
else
{
- PreprocessLlava(text, args, true);
+ return PreprocessMtmd(text, args, addBos: true, replaceExisting: true);
}
}
else
@@ -150,7 +153,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
else
{
- PreprocessLlava(text, args, false);
+ return PreprocessMtmd(text, args, addBos: false, replaceExisting: false);
}
}
}
@@ -158,51 +161,12 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
return Task.CompletedTask;
}
- ///
- private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
- {
- // If the prompt contains the tag extract this.
- _imageInPrompt = text.Contains("");
- if (_imageInPrompt && IsMultiModal)
- {
- foreach (var image in Images)
- {
- _imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel!.NativeHandle, Context, image));
- }
-
- int imageIndex = text.IndexOf("");
- // Tokenize segment 1 (before tag)
- string preImagePrompt = text.Substring(0, imageIndex);
- var segment1 = Context.Tokenize(preImagePrompt, addBos, true);
- // Remember the position to add the image embeddings
- _EmbedImagePosition = segment1.Length;
- string postImagePrompt = text.Substring(imageIndex + 7);
- var segment2 = Context.Tokenize(postImagePrompt, false, true);
- _embed_inps.AddRange(segment1);
- _embed_inps.AddRange(segment2);
- }
- else
- {
- if (addBos)
- {
- _embed_inps = Context.Tokenize(text, true, true).ToList();
- }
- else
- {
- var line_inp = Context.Tokenize(text, false, true);
- _embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Length;
- }
- }
- return Task.CompletedTask;
- }
-
///
- /// Return whether to break the generation.
+ /// Decide whether generation should stop based on antiprompts, token limits, or end-of-generation markers.
///
- ///
- ///
- ///
+ /// Sampling parameters controlling generation.
+ /// Mutable inference state.
+ /// Tuple describing whether to stop and any additional outputs to emit.
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
@@ -253,51 +217,50 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
HandleRunOutOfContext(tokensToKeep);
}
- TryReuseMatchingPrefix();
-
- // Changes to support Multi-Modal LLMs.
- //
- (DecodeResult, int, int) header, end, result;
- if (IsMultiModal && _EmbedImagePosition > 0)
+ if (MtmdChunks is null)
{
- // Tokens previous to the images
- header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
- _pastTokensCount = header.Item3;
-
- if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-
- // Images
- foreach( var image in _imageEmbedHandles )
- ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-
- // Post-image Tokens
- end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
- _pastTokensCount = end.Item3;
+ TryReuseMatchingPrefix();
+ }
- _EmbedImagePosition = -1;
- _imageEmbedHandles.Clear();
- Images.Clear();
+ if (IsMultiModal && MtmdChunks is not null)
+ {
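+ // Multimodal path: evaluate the tokenized MTMD chunks (text + media) instead of decoding a plain token batch.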
+ var nPast = (long)_pastTokensCount;
+ var previousConsumed = _consumedTokensCount;
+ EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor));
}
else
{
- result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+ var result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = result.Item3;
if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
- }
-
- if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
- {
- _session_tokens.AddRange(_embeds);
- _n_session_consumed = _session_tokens.Count;
+ if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
+ {
+ _session_tokens.AddRange(_embeds);
+ _n_session_consumed = _session_tokens.Count;
+ }
}
}
-
+ else if (IsMultiModal && MtmdChunks is not null)
+ {
+ _is_prompt_run = false;
+ var nPast = (long)_pastTokensCount;
+ var previousConsumed = _consumedTokensCount;
+ EvaluateMtmdChunks(ref nPast, previousConsumed, nameof(InteractiveExecutor));
+ }
+
_embeds.Clear();
if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
{
+ if (inferenceParams.MaxTokens == 0)
+ {
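+ // As in the instruct executor, a zero MaxTokens budget only ingests the prompt; wait for more input without sampling.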
+ _embeds.Clear();
+ args.WaitForInput = true;
+ args.ReturnValue = false;
+ return;
+ }
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
@@ -344,10 +307,10 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
}
///
- /// The descriptor of the state of the interactive executor.
+ /// Serializable state specific to the interactive executor.
///
public class InteractiveExecutorState
- : ExecutorBaseState
+ : StatefulExecutorBase.ExecutorBaseState
{
///
/// Whether the executor is running for the first time (running the prompt).
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index f53de7069..e827585b7 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -3,7 +3,7 @@
netstandard2.0;net8.0
LLama
enable
- 12
+ 13
AnyCPU;x64;Arm64
True
@@ -17,7 +17,7 @@
https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
LLama, LLM, GPT, ChatGPT, NLP, AI, Chat Bot, SciSharp
- LLamaSharp is a cross-platform library to run 🦙LLaMA/LLaVA model (and others) in your local device.
+ LLamaSharp is a cross-platform library to run 🦙LLaMA/multimodal (MTMD) models (and others) on your local device.
Based on [llama.cpp](https://github.com/ggerganov/llama.cpp), inference with LLamaSharp is efficient on both CPU and GPU.
With the higher-level APIs and RAG support, it's convenient to deploy LLM (Large Language Model) in your application with LLamaSharp.
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 8f9b40cc3..94bc60830 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -28,10 +28,10 @@ public class StatelessExecutor
public bool IsMultiModal => false;
///
- public LLavaWeights? ClipModel => default;
+ public SafeMtmdWeights? ClipModel => default;
///
- public List<byte[]> Images { get; }
+ public List<SafeMtmdEmbed> Embeds { get; }
///
/// The context used by the executor when running the inference.
@@ -57,7 +57,7 @@ public class StatelessExecutor
///
public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
- Images = [ ];
+ Embeds = [ ];
_weights = weights;
_params = @params;
_logger = logger;
diff --git a/LLama/LLavaWeights.cs b/LLama/LLavaWeights.cs
deleted file mode 100644
index f2f9f6256..000000000
--- a/LLama/LLavaWeights.cs
+++ /dev/null
@@ -1,137 +0,0 @@
-
-using System;
-using System.Threading;
-using System.Threading.Tasks;
-using LLama.Native;
-
-namespace LLama;
-
-///
-/// A set of llava model weights (mmproj), loaded into memory.
-///
-public sealed class LLavaWeights
- : IDisposable
-{
- ///
- /// The native handle, which is used in the native APIs
- ///
- /// Be careful how you use this!
- public SafeLlavaModelHandle NativeHandle { get; }
-
- private LLavaWeights(SafeLlavaModelHandle weights)
- {
- NativeHandle = weights;
- }
-
- #region load
- ///
- /// Load weights into memory
- ///
- /// path to the "mmproj" model file
- ///
- public static LLavaWeights LoadFromFile(string mmProject)
- {
- var weights = SafeLlavaModelHandle.LoadFromFile(mmProject, 1);
- return new LLavaWeights(weights);
- }
-
- ///
- /// Load weights into memory
- ///
- /// path to the "mmproj" model file
- ///
- ///
- public static Task<LLavaWeights> LoadFromFileAsync(string mmProject, CancellationToken token = default)
- {
- return Task.Run(() => LoadFromFile(mmProject), token);
- }
- #endregion
-
- #region embed
- ///
- /// Create the Image Embeddings from the bytes of an image.
- ///
- ///
- /// Image bytes. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
- {
- return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
- }
-
- ///
- /// Create the Image Embeddings.
- ///
- /// Image in binary format (it supports jpeg format only)
- /// Number of threads to use
- /// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
- {
- return NativeHandle.CreateImageEmbeddings(image, threads);
- }
-
- ///
- /// Create the Image Embeddings from the bytes of an image.
- ///
- ///
- /// Path to the image file. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- ///
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image)
- {
- return NativeHandle.CreateImageEmbeddings(ctxLlama, image);
- }
-
- ///
- /// Create the Image Embeddings from the bytes of an image.
- ///
- /// Path to the image file. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- ///
- ///
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
- {
- return NativeHandle.CreateImageEmbeddings(image, threads);
- }
- #endregion
-
- ///
- /// Eval the image embeddings
- ///
- ///
- ///
- ///
- ///
- public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past)
- {
- return NativeHandle.EvalImageEmbed( ctxLlama, imageEmbed, ref n_past );
- }
-
- ///
- public void Dispose()
- {
- NativeHandle.Dispose();
- }
-
-}
\ No newline at end of file
diff --git a/LLama/Native/LLavaImageEmbed.cs b/LLama/Native/LLavaImageEmbed.cs
deleted file mode 100644
index 65eba230c..000000000
--- a/LLama/Native/LLavaImageEmbed.cs
+++ /dev/null
@@ -1,19 +0,0 @@
-namespace LLama.Native;
-
-///
-/// LLaVa Image embeddings
-///
-/// llava_image_embed
-[StructLayout(LayoutKind.Sequential)]
-public unsafe struct LLavaImageEmbed
-{
- ///
- /// The embeddings of the embedded image.
- ///
- public float* embed;
-
- ///
- /// The position of the image's tokens.
- ///
- public int n_image_pos;
-}
\ No newline at end of file
diff --git a/LLama/Native/Load/NativeLibraryConfig.cs b/LLama/Native/Load/NativeLibraryConfig.cs
index 652b1da48..7f250d1c1 100644
--- a/LLama/Native/Load/NativeLibraryConfig.cs
+++ b/LLama/Native/Load/NativeLibraryConfig.cs
@@ -299,15 +299,15 @@ public sealed partial class NativeLibraryConfig
public static NativeLibraryConfig LLama { get; }
///
- /// Configuration for LLava native library
+ /// Configuration for Mtmd native library
///
- public static NativeLibraryConfig LLava { get; }
+ public static NativeLibraryConfig Mtmd { get; }
static NativeLibraryConfig()
{
LLama = new(NativeLibraryName.LLama);
- LLava = new(NativeLibraryName.LLava);
- All = new(LLama, LLava);
+ Mtmd = new(NativeLibraryName.Mtmd);
+ All = new(LLama, Mtmd);
}
#if NETSTANDARD2_0
@@ -413,9 +413,9 @@ public void ForEach(Action action)
/// When this method is called, all the other configurations will be ignored.
///
/// The full path to the llama library to load.
- /// The full path to the llava library to load.
+ /// The full path to the mtmd library to load.
/// Thrown if `LibraryHasLoaded` is true.
- public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llavaPath)
+ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? mtmdPath)
{
foreach(var config in _configs)
{
@@ -423,9 +423,9 @@ public NativeLibraryConfigContainer WithLibrary(string? llamaPath, string? llava
{
config.WithLibrary(llamaPath);
}
- if(config.NativeLibraryName == NativeLibraryName.LLava && llavaPath is not null)
+ if(config.NativeLibraryName == NativeLibraryName.Mtmd && mtmdPath is not null)
{
- config.WithLibrary(llavaPath);
+ config.WithLibrary(mtmdPath);
}
}
@@ -594,7 +594,7 @@ public NativeLibraryConfigContainer WithLogCallback(ILogger? logger)
/// You can still modify the configuration after calling this method, but only before any call into NativeApi.
///
/// Whether the dry run was successful.
- public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedLLavaNativeLibrary)
+ public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibrary? loadedMtmdNativeLibrary)
{
bool success = true;
foreach(var config in _configs)
@@ -604,16 +604,16 @@ public bool DryRun(out INativeLibrary? loadedLLamaNativeLibrary, out INativeLibr
{
loadedLLamaNativeLibrary = loadedLibrary;
}
- else if(config.NativeLibraryName == NativeLibraryName.LLava)
+ else if(config.NativeLibraryName == NativeLibraryName.Mtmd)
{
- loadedLLavaNativeLibrary = loadedLibrary;
+ loadedMtmdNativeLibrary = loadedLibrary;
}
else
{
throw new Exception("Unknown native library config during the dry run.");
}
}
- loadedLLamaNativeLibrary = loadedLLavaNativeLibrary = null;
+ loadedLLamaNativeLibrary = loadedMtmdNativeLibrary = null;
return success;
}
}
@@ -628,9 +628,9 @@ public enum NativeLibraryName
///
LLama,
///
- /// The native library compiled from the LLaVA example of llama.cpp.
+ /// The native library compiled from the MTMD library of llama.cpp.
///
- LLava
+ Mtmd
}
internal static class LibraryNameExtensions
@@ -641,8 +641,8 @@ public static string GetLibraryName(this NativeLibraryName name)
{
case NativeLibraryName.LLama:
return NativeApi.libraryName;
- case NativeLibraryName.LLava:
- return NativeApi.llavaLibraryName;
+ case NativeLibraryName.Mtmd:
+ return NativeApi.mtmdLibraryName;
default:
throw new ArgumentOutOfRangeException(nameof(name), name, null);
}
diff --git a/LLama/Native/Load/NativeLibraryUtils.cs b/LLama/Native/Load/NativeLibraryUtils.cs
index 9f6457cd1..84ababc60 100644
--- a/LLama/Native/Load/NativeLibraryUtils.cs
+++ b/LLama/Native/Load/NativeLibraryUtils.cs
@@ -9,7 +9,7 @@ namespace LLama.Native
internal static class NativeLibraryUtils
{
///
- /// Try to load libllama/llava_shared, using CPU feature detection to try and load a more specialised DLL if possible
+ /// Try to load libllama/mtmd, using CPU feature detection to try and load a more specialised DLL if possible
///
/// The library handle to unload later, or IntPtr.Zero if no library was loaded
internal static IntPtr TryLoadLibrary(NativeLibraryConfig config, out INativeLibrary? loadedLibrary)
diff --git a/LLama/Native/MtmdContextParams.cs b/LLama/Native/MtmdContextParams.cs
new file mode 100644
index 000000000..5b282d802
--- /dev/null
+++ b/LLama/Native/MtmdContextParams.cs
@@ -0,0 +1,157 @@
+using System;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native;
+
+///
+/// Managed representation of the native mtmd_context_params structure used to configure multimodal helpers.
+///
+public class MtmdContextParams
+{
+ ///
+ /// Whether GPU acceleration should be requested when available.
+ ///
+ public bool UseGpu { get; set; }
+
+ ///
+ /// Whether timing information should be emitted by the native helper.
+ ///
+ public bool PrintTimings { get; set; }
+
+ ///
+ /// Number of worker threads to dedicate to preprocessing and tokenization.
+ ///
+ public int NThreads { get; set; }
+
+ ///
+ /// Verbosity level forwarded to llama.cpp logging (matches ggml_log_level).
+ ///
+ public int Verbosity { get; set; }
+
+ ///
+ /// Marker token inserted into the text stream to reference an image embedding.
+ ///
+ public string? ImageMarker { get; set; }
+
+ ///
+ /// Marker token inserted into the text stream to reference a generic media embedding.
+ ///
+ public string? MediaMarker { get; set; }
+
+ ///
+ /// Create a managed copy of the native defaults returned by mtmd_context_params_default.
+ ///
+ public static MtmdContextParams Default()
+ {
+ var native = NativeApi.mtmd_context_params_default();
+ return new MtmdContextParams
+ {
+ UseGpu = native.use_gpu,
+ PrintTimings = native.print_timings,
+ NThreads = native.n_threads,
+ Verbosity = native.verbosity,
+ ImageMarker = PtrToString(native.image_marker),
+ MediaMarker = PtrToString(native.media_marker)
+ };
+ }
+
+ private static string? PtrToString(IntPtr ptr)
+ {
+ if (ptr == IntPtr.Zero)
+ return null;
+
+#if NETSTANDARD2_0
+ unsafe
+ {
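+ // netstandard2.0 lacks Marshal.PtrToStringUTF8, so locate the NUL terminator and decode the bytes manually.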
+ var length = 0;
+ var current = (byte*)ptr;
+ while (current[length] != 0)
+ length++;
+
+ if (length == 0)
+ return string.Empty;
+
+ var buffer = new byte[length];
+ Marshal.Copy(ptr, buffer, 0, length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+#else
+ return Marshal.PtrToStringUTF8(ptr);
+#endif
+ }
+
+ ///
+ /// Convert the managed representation to a native structure, pinning strings for the duration of the scope.
+ ///
+ internal NativeScope ToNativeScope() => new(this);
+
+ internal readonly struct NativeScope : IDisposable
+ {
+ public NativeApi.mtmd_context_params Value { get; }
+
+ private readonly PinnedUtf8String? _imageMarker;
+ private readonly PinnedUtf8String? _mediaMarker;
+
+ public NativeScope(MtmdContextParams managed)
+ {
+ _imageMarker = PinnedUtf8String.Create(managed.ImageMarker);
+ _mediaMarker = PinnedUtf8String.Create(managed.MediaMarker);
+
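+ // Start from the native defaults so any fields this wrapper does not model keep llama.cpp's values.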
+ var native = NativeApi.mtmd_context_params_default();
+ native.use_gpu = managed.UseGpu;
+ native.print_timings = managed.PrintTimings;
+ native.n_threads = managed.NThreads;
+ native.verbosity = managed.Verbosity;
+
+ if (_imageMarker is not null)
+ native.image_marker = _imageMarker.Pointer;
+ if (_mediaMarker is not null)
+ native.media_marker = _mediaMarker.Pointer;
+
+ Value = native;
+ }
+
+ public void Dispose()
+ {
+ _imageMarker?.Dispose();
+ _mediaMarker?.Dispose();
+ }
+ }
+}
+
+///
+/// Helper that pins a managed string as UTF-8 for the lifetime of the instance.
+///
+internal sealed class PinnedUtf8String : IDisposable
+{
+ private readonly byte[]? _buffer;
+ private readonly GCHandle _handle;
+
+ private PinnedUtf8String(string value)
+ {
+ var bytes = Encoding.UTF8.GetBytes(value);
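+ // Allocate one extra byte so the pinned buffer ends with the NUL terminator native code expects.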
+ _buffer = new byte[bytes.Length + 1];
+ Buffer.BlockCopy(bytes, 0, _buffer, 0, bytes.Length);
+ _handle = GCHandle.Alloc(_buffer, GCHandleType.Pinned);
+ }
+
+ public static PinnedUtf8String? Create(string? value) => value is null ? null : new PinnedUtf8String(value);
+
+ public IntPtr Pointer
+ {
+ get
+ {
+ if (_buffer is null || !_handle.IsAllocated)
+ return IntPtr.Zero;
+
+ return _handle.AddrOfPinnedObject();
+ }
+ }
+
+ public void Dispose()
+ {
+ if (_buffer is not null && _handle.IsAllocated)
+ _handle.Free();
+ }
+}
diff --git a/LLama/Native/MtmdImageEmbed.cs b/LLama/Native/MtmdImageEmbed.cs
new file mode 100644
index 000000000..7341b8563
--- /dev/null
+++ b/LLama/Native/MtmdImageEmbed.cs
@@ -0,0 +1,20 @@
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+///
+/// Representation of the native image-embed structure (layout-compatible with llama.cpp's llava_image_embed) used to return image embeddings.
+///
+[StructLayout(LayoutKind.Sequential)]
+public unsafe struct MtmdImageEmbed
+{
+ ///
+ /// Pointer to the embedding buffer for the decoded image.
+ ///
+ public float* embed;
+
+ ///
+ /// Number of sequence positions consumed by the image tokens associated with the embedding.
+ ///
+ public int n_image_pos;
+}
diff --git a/LLama/Native/NativeApi.LLava.cs b/LLama/Native/NativeApi.LLava.cs
deleted file mode 100644
index 692e3f0ad..000000000
--- a/LLama/Native/NativeApi.LLava.cs
+++ /dev/null
@@ -1,63 +0,0 @@
-using System;
-
-namespace LLama.Native;
-
-public static partial class NativeApi
-{
- ///
- /// Sanity check for clip <-> llava embed size match
- ///
- /// LLama Context
- /// Llava Model
- /// True if validate successfully
- [DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)]
- [return: MarshalAs(UnmanagedType.U1)]
- public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip);
-
- ///
- /// Build an image embed from image file bytes
- ///
- /// SafeHandle to the Clip Model
- /// Number of threads
- /// Binary image in jpeg format
- /// Bytes length of the image
- /// SafeHandle to the Embeddings
- [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_bytes",
- CallingConvention = CallingConvention.Cdecl)]
- public static extern
- SafeLlavaImageEmbedHandle llava_image_embed_make_with_bytes(SafeLlavaModelHandle ctx_clip, int n_threads,
- byte[] image_bytes, int image_bytes_length);
-
- ///
- /// Build an image embed from a path to an image filename
- ///
- /// SafeHandle to the Clip Model
- /// Number of threads
- /// Image filename (jpeg) to generate embeddings
- /// SafeHandle to the embeddings
- [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_make_with_filename", CallingConvention = CallingConvention.Cdecl)]
- public static extern
- SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHandle ctx_clip, int n_threads,
- [MarshalAs(UnmanagedType.LPStr)] string image_path);
-
- ///
- /// Free an embedding made with llava_image_embed_make_*
- ///
- /// Embeddings to release
- [DllImport(llavaLibraryName, EntryPoint = "llava_image_embed_free", CallingConvention = CallingConvention.Cdecl)]
- public static extern void llava_image_embed_free(IntPtr embed);
-
- ///
- /// Write the image represented by embed into the llama context with batch size n_batch, starting at context
- /// pos n_past. on completion, n_past points to the next position in the context after the image embed.
- ///
- /// Llama Context
- /// Embedding handle
- ///
- ///
- /// True on success
- [DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)]
- [return: MarshalAs(UnmanagedType.U1)]
- public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past);
-
-}
\ No newline at end of file
diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs
index 4555ed0d2..57bb2d146 100644
--- a/LLama/Native/NativeApi.Load.cs
+++ b/LLama/Native/NativeApi.Load.cs
@@ -16,7 +16,7 @@ static NativeApi()
// Set flag to indicate that this point has been passed. No native library config can be done after this point.
NativeLibraryConfig.LLama.LibraryHasLoaded = true;
- NativeLibraryConfig.LLava.LibraryHasLoaded = true;
+ NativeLibraryConfig.Mtmd.LibraryHasLoaded = true;
// Immediately make a call which requires loading the llama DLL. This method call
// can't fail unless the DLL hasn't been loaded.
@@ -45,7 +45,7 @@ static NativeApi()
#if NET5_0_OR_GREATER
private static IntPtr _loadedLlamaHandle;
- private static IntPtr _loadedLlavaSharedHandle;
+ private static IntPtr _loadedMtmdHandle;
#endif
private static void SetDllImportResolver()
@@ -72,15 +72,15 @@ private static void SetDllImportResolver()
return _loadedLlamaHandle;
}
- if (name == "llava_shared")
+ if (name == "mtmd")
{
- // If we've already loaded llava return the handle that was loaded last time.
- if (_loadedLlavaSharedHandle != IntPtr.Zero)
- return _loadedLlavaSharedHandle;
+ // If we've already loaded Mtmd return the handle that was loaded last time.
+ if (_loadedMtmdHandle != IntPtr.Zero)
+ return _loadedMtmdHandle;
// Try to load a preferred library, based on CPU feature detection
- _loadedLlavaSharedHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.LLava, out _loadedLLavaLibrary);
- return _loadedLlavaSharedHandle;
+ _loadedMtmdHandle = NativeLibraryUtils.TryLoadLibrary(NativeLibraryConfig.Mtmd, out _loadedMtmdLibrary);
+ return _loadedMtmdHandle;
}
// Return null pointer to indicate that nothing was loaded.
@@ -100,17 +100,17 @@ private static void SetDllImportResolver()
return name switch
{
NativeLibraryName.LLama => _loadedLLamaLibrary,
- NativeLibraryName.LLava => _loadedLLavaLibrary,
+ NativeLibraryName.Mtmd => _loadedMtmdLibrary,
_ => throw new ArgumentException($"Library name {name} is not found.")
};
}
internal const string libraryName = "llama";
- internal const string llavaLibraryName = "llava_shared";
+ internal const string mtmdLibraryName = "mtmd";
internal const string ggmlLibraryName = "ggml";
internal const string ggmlBaseLibraryName = "ggml-base";
private static INativeLibrary? _loadedLLamaLibrary = null;
- private static INativeLibrary? _loadedLLavaLibrary = null;
+ private static INativeLibrary? _loadedMtmdLibrary = null;
}
}
diff --git a/LLama/Native/NativeApi.Mtmd.cs b/LLama/Native/NativeApi.Mtmd.cs
new file mode 100644
index 000000000..bfd6193c2
--- /dev/null
+++ b/LLama/Native/NativeApi.Mtmd.cs
@@ -0,0 +1,312 @@
+using System;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native;
+
+///
+/// P/Invoke surface for MTMD (multimodal) helpers exposed by llama.cpp.
+///
+public static partial class NativeApi
+{
+ ///
+ /// Convert a UTF-8 encoded native string pointer into a managed string.
+ /// Returns null when the pointer is zero.
+ ///
+ public static string? PtrToStringUtf8(IntPtr ptr)
+ {
+ if (ptr == IntPtr.Zero)
+ return null;
+
+#if NETSTANDARD2_0
+ unsafe
+ {
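+ // netstandard2.0 fallback: Marshal.PtrToStringUTF8 is unavailable there, so find the NUL terminator by hand and decode as UTF-8.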
+ var current = (byte*)ptr;
+ var length = 0;
+ while (current[length] != 0)
+ length++;
+
+ if (length == 0)
+ return string.Empty;
+
+ var buffer = new byte[length];
+ Marshal.Copy(ptr, buffer, 0, length);
+ return Encoding.UTF8.GetString(buffer);
+ }
+#else
+ return Marshal.PtrToStringUTF8(ptr);
+#endif
+ }
+
+ ///
+ /// Native context parameters returned by mtmd_context_params_default.
+ ///
+ [StructLayout(LayoutKind.Sequential)]
+ internal struct mtmd_context_params
+ {
+ [MarshalAs(UnmanagedType.I1)] public bool use_gpu;
+ [MarshalAs(UnmanagedType.I1)] public bool print_timings;
+ public int n_threads;
+ public int verbosity;
+ public IntPtr image_marker;
+ public IntPtr media_marker;
+ }
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_default_marker", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_default_marker();
+
+ ///
+ /// Retrieve the default multimodal marker text.
+ ///
+ public static string? MtmdDefaultMarker()
+ => PtrToStringUtf8(mtmd_default_marker());
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_context_params_default", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern mtmd_context_params mtmd_context_params_default();
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_non_causal", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ internal static extern bool mtmd_decode_use_non_causal(IntPtr ctx);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_decode_use_mrope", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ internal static extern bool mtmd_decode_use_mrope(IntPtr ctx);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_vision", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ internal static extern bool mtmd_support_vision(IntPtr ctx);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_support_audio", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ internal static extern bool mtmd_support_audio(IntPtr ctx);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_audio_bitrate", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_get_audio_bitrate(IntPtr ctx);
+
+ // bitmap ------------------------------------------------------------
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_bitmap_init(uint nx, uint ny, IntPtr data);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_init_from_audio", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_bitmap_init_from_audio(ulong n_samples, IntPtr data);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_nx", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern uint mtmd_bitmap_get_nx(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_ny", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern uint mtmd_bitmap_get_ny(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_data", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_bitmap_get_data(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_n_bytes", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_bitmap_get_n_bytes(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_is_audio", CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.I1)]
+ internal static extern bool mtmd_bitmap_is_audio(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_free", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void mtmd_bitmap_free(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_get_id", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_bitmap_get_id(IntPtr bitmap);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_bitmap_set_id", CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe void mtmd_bitmap_set_id_native(IntPtr bitmap, byte* id);
+
+ ///
+ /// Assign an identifier to a bitmap using a UTF-8 encoded string.
+ ///
+ internal static unsafe void mtmd_bitmap_set_id(IntPtr bitmap, string? id)
+ {
+ if (bitmap == IntPtr.Zero)
+ throw new ArgumentNullException(nameof(bitmap));
+
+ if (id is null)
+ {
+ mtmd_bitmap_set_id_native(bitmap, null);
+ return;
+ }
+
+ using var pinned = PinnedUtf8String.Create(id) ?? throw new ArgumentNullException(nameof(id));
+ mtmd_bitmap_set_id_native(bitmap, (byte*)pinned.Pointer);
+ }
+
+ // input_chunks ------------------------------------------------------
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_init", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunks_init();
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_size", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_input_chunks_size(IntPtr chunks);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_get", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunks_get(IntPtr chunks, UIntPtr idx);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunks_free", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void mtmd_input_chunks_free(IntPtr chunks);
+
+ // input_chunk -------------------------------------------------------
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_type", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_input_chunk_get_type(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_text", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunk_get_tokens_text(IntPtr chunk, out UIntPtr n_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_tokens_image", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunk_get_tokens_image(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_tokens", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_input_chunk_get_n_tokens(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_id", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunk_get_id(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_get_n_pos", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern long mtmd_input_chunk_get_n_pos(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_copy", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_input_chunk_copy(IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_input_chunk_free", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void mtmd_input_chunk_free(IntPtr chunk);
+
+ // image_tokens ------------------------------------------------------
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_tokens", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_image_tokens_get_n_tokens(IntPtr image_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_nx", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_image_tokens_get_nx(IntPtr image_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_ny", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_image_tokens_get_ny(IntPtr image_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_id", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_image_tokens_get_id(IntPtr image_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_image_tokens_get_n_pos", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern long mtmd_image_tokens_get_n_pos(IntPtr image_tokens);
+
+ // tokenize ----------------------------------------------------------
+
+ ///
+ /// Native text structure consumed by mtmd_tokenize.
+ ///
+ internal unsafe struct mtmd_input_text_native
+ {
+ public byte* text;
+ [MarshalAs(UnmanagedType.I1)] public bool add_special;
+ [MarshalAs(UnmanagedType.I1)] public bool parse_special;
+ }
+
+ ///
+ /// Utility scope that pins managed text while invoking the native tokenizer.
+ ///
+ internal readonly unsafe ref struct MtmdInputTextScope
+ {
+ public readonly mtmd_input_text_native Value;
+ private readonly PinnedUtf8String _text;
+
+ public MtmdInputTextScope(string text, bool addSpecial, bool parseSpecial)
+ {
+ _text = PinnedUtf8String.Create(text) ?? throw new ArgumentNullException(nameof(text));
+ Value = new mtmd_input_text_native
+ {
+ text = (byte*)_text.Pointer,
+ add_special = addSpecial,
+ parse_special = parseSpecial
+ };
+ }
+
+ public void Dispose() => _text.Dispose();
+ }
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_tokenize", CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe int mtmd_tokenize_native(
+ IntPtr ctx,
+ IntPtr output,
+ mtmd_input_text_native* text,
+ IntPtr[] bitmaps,
+ UIntPtr n_bitmaps);
+
+ internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, in mtmd_input_text_native text, IntPtr[] bitmaps, UIntPtr n_bitmaps)
+ {
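+ // Copy the readonly 'in' argument to a local so its address can be passed to the native call.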
+ var temp = text;
+ return mtmd_tokenize_native(ctx, output, &temp, bitmaps, n_bitmaps);
+ }
+
+ internal static unsafe int mtmd_tokenize(IntPtr ctx, IntPtr output, string text, bool addSpecial, bool parseSpecial, IntPtr[] bitmaps, UIntPtr n_bitmaps)
+ {
+ using var scope = new MtmdInputTextScope(text, addSpecial, parseSpecial);
+ return mtmd_tokenize_native(ctx, output, &scope.Value, bitmaps, n_bitmaps);
+ }
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_encode(IntPtr ctx, IntPtr image_tokens);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_encode_chunk", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_encode_chunk(IntPtr ctx, IntPtr chunk);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_get_output_embd", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_get_output_embd(IntPtr ctx);
+
+ // helper ------------------------------------------------------------
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_test_create_input_chunks", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_test_create_input_chunks();
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_file", CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe IntPtr mtmd_helper_bitmap_init_from_file_native(IntPtr ctx, byte* fname);
+
+ internal static unsafe IntPtr mtmd_helper_bitmap_init_from_file(IntPtr ctx, string fname)
+ {
+ using var pinned = PinnedUtf8String.Create(fname) ?? throw new ArgumentNullException(nameof(fname));
+ return mtmd_helper_bitmap_init_from_file_native(ctx, (byte*)pinned.Pointer);
+ }
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_bitmap_init_from_buf", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern IntPtr mtmd_helper_bitmap_init_from_buf(IntPtr ctx, IntPtr buf, UIntPtr len);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_tokens", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern UIntPtr mtmd_helper_get_n_tokens(IntPtr chunks);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_get_n_pos", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern long mtmd_helper_get_n_pos(IntPtr chunks);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunks", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_helper_eval_chunks(
+ IntPtr ctx,
+ IntPtr lctx,
+ IntPtr chunks,
+ long n_past,
+ int seq_id,
+ int n_batch,
+ [MarshalAs(UnmanagedType.I1)] bool logits_last,
+ ref long new_n_past);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_eval_chunk_single", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_helper_eval_chunk_single(
+ IntPtr ctx,
+ IntPtr lctx,
+ IntPtr chunk,
+ long n_past,
+ int seq_id,
+ int n_batch,
+ [MarshalAs(UnmanagedType.I1)] bool logits_last,
+ ref long new_n_past);
+
+ [DllImport(mtmdLibraryName, EntryPoint = "mtmd_helper_decode_image_chunk", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern int mtmd_helper_decode_image_chunk(
+ IntPtr ctx,
+ IntPtr lctx,
+ IntPtr chunk,
+ IntPtr encoded_embd,
+ long n_past,
+ int seq_id,
+ int n_batch,
+ ref long new_n_past);
+}
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index db9e928bd..0ea46a600 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -1,4 +1,5 @@
using System;
+using System.Text;
#pragma warning disable IDE1006 // Naming Styles
@@ -323,21 +324,114 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
///
/// Returns the split_path length.
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_path")]
+ private static extern unsafe int llama_split_path_native(byte* split_path, nuint maxlen, byte* path_prefix, int split_no, int split_count);
+
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_split_prefix")]
+ private static extern unsafe int llama_split_prefix_native(byte* split_prefix, nuint maxlen, byte* split_path, int split_no, int split_count);
+
+ private static byte[] EncodeNullTerminatedUtf8(string value, string paramName)
+ {
+ if (value is null)
+ throw new ArgumentNullException(paramName);
+
+ var bytes = Encoding.UTF8.GetBytes(value);
+ var buffer = new byte[bytes.Length + 1];
+ Buffer.BlockCopy(bytes, 0, buffer, 0, bytes.Length);
+ return buffer;
+ }
///
- /// Extract the path prefix from the split_path if and only if the split_no and split_count match.
- /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
+ /// Build the fully-qualified path for a specific split file in a GGUF shard set.
///
- ///
- ///
- ///
- ///
- ///
- /// Returns the split_prefix length.
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+ /// Writable buffer that receives the UTF-8 encoded path.
+ /// Base path (e.g. "/models/ggml-model-q4_0").
+ /// Zero-based split index.
+ /// Total number of splits.
+ /// Number of bytes written to splitPathBuffer.
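+ /// (Per llama.cpp's documented example for the inverse operation, prefix "/models/ggml-model-q4_0" with splitNo 2 and splitCount 4 corresponds to "/models/ggml-model-q4_0-00002-of-00004.gguf".)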
+ public static int llama_split_path(Span<byte> splitPathBuffer, string pathPrefix, int splitNo, int splitCount)
+ {
+ if (splitPathBuffer.Length == 0)
+ throw new ArgumentException("Buffer must not be empty.", nameof(splitPathBuffer));
+
+ var pathPrefixBytes = EncodeNullTerminatedUtf8(pathPrefix, nameof(pathPrefix));
+
+ unsafe
+ {
+ fixed (byte* splitPtr = splitPathBuffer)
+ fixed (byte* prefixPtr = pathPrefixBytes)
+ {
+ return llama_split_path_native(splitPtr, (nuint)splitPathBuffer.Length, prefixPtr, splitNo, splitCount);
+ }
+ }
+ }
+
+ ///
+ /// Build the fully-qualified path for a specific split file in a GGUF shard set.
+ ///
+ /// Base path (e.g. "/models/ggml-model-q4_0").
+ /// Zero-based split index.
+ /// Total number of splits.
+ /// Maximum number of bytes to allocate for the resulting UTF-8 string.
+ /// UTF-8 decoded split path.
+ public static string llama_split_path(string pathPrefix, int splitNo, int splitCount, int maxLength = 1024)
+ {
+ if (maxLength <= 0)
+ throw new ArgumentOutOfRangeException(nameof(maxLength));
+
+ var buffer = new byte[maxLength];
+ var written = llama_split_path((Span<byte>)buffer, pathPrefix, splitNo, splitCount);
+ if (written <= 0)
+ throw new InvalidOperationException("Failed to build split path using llama_split_path.");
+
+ return Encoding.UTF8.GetString(buffer, 0, written);
+ }
+
+ ///
+ /// Extract the shard prefix from a GGUF split path when the split metadata matches.
+ ///
+ /// Writable buffer that receives the UTF-8 encoded prefix.
+ /// Full path to a shard file.
+ /// Zero-based split index.
+ /// Total number of splits.
+ /// Number of bytes written to splitPrefixBuffer.
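+ /// Example (from the original doc comment): splitPath "/models/ggml-model-q4_0-00002-of-00004.gguf" with splitNo 2 and splitCount 4 yields "/models/ggml-model-q4_0".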
+ public static int llama_split_prefix(Span<byte> splitPrefixBuffer, string splitPath, int splitNo, int splitCount)
+ {
+ if (splitPrefixBuffer.Length == 0)
+ throw new ArgumentException("Buffer must not be empty.", nameof(splitPrefixBuffer));
+
+ var splitPathBytes = EncodeNullTerminatedUtf8(splitPath, nameof(splitPath));
+
+ unsafe
+ {
+ fixed (byte* prefixPtr = splitPrefixBuffer)
+ fixed (byte* pathPtr = splitPathBytes)
+ {
+ return llama_split_prefix_native(prefixPtr, (nuint)splitPrefixBuffer.Length, pathPtr, splitNo, splitCount);
+ }
+ }
+ }
+
+ ///
+ /// Extract the shard prefix from a GGUF split path when the split metadata matches.
+ ///
+ /// Full path to a shard file.
+ /// Zero-based split index.
+ /// Total number of splits.
+ /// Maximum number of bytes to allocate for the resulting UTF-8 string.
+ /// UTF-8 decoded split prefix.
+ public static string llama_split_prefix(string splitPath, int splitNo, int splitCount, int maxLength = 1024)
+ {
+ if (maxLength <= 0)
+ throw new ArgumentOutOfRangeException(nameof(maxLength));
+
+ var buffer = new byte[maxLength];
+ var written = llama_split_prefix((Span<byte>)buffer, splitPath, splitNo, splitCount);
+ if (written <= 0)
+ throw new InvalidOperationException("Failed to extract split prefix using llama_split_prefix.");
+
+ return Encoding.UTF8.GetString(buffer, 0, written);
+ }
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
diff --git a/LLama/Native/SafeLlavaImageEmbedHandle.cs b/LLama/Native/SafeLlavaImageEmbedHandle.cs
deleted file mode 100644
index 102c4b93f..000000000
--- a/LLama/Native/SafeLlavaImageEmbedHandle.cs
+++ /dev/null
@@ -1,162 +0,0 @@
-using System;
-using System.IO;
-
-
-namespace LLama.Native
-{
- ///
- /// A Reference to a llava Image Embed handle
- ///
- public sealed class SafeLlavaImageEmbedHandle
- : SafeLLamaHandleBase
- {
- ///
- /// Get the model used to create this image embedding
- ///
- public SafeLlavaModelHandle Model { get; private set; } = null!;
-
- ///
- /// Get the number of dimensions in an embedding
- ///
- public int EmbeddingDimensions => Model.EmbeddingDimensions;
-
- ///
- /// Get the number of "patches" in an image embedding
- ///
- public int PatchCount => Model.PatchCount;
-
- #region embed
- ///
- /// Create an image embed from an image file
- ///
- ///
- ///
- /// Path to the image file. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- ///
- public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, LLamaContext ctx, string image)
- {
- if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
- throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");
-
- return CreateFromFileName(clip, image, (int)ctx.BatchThreads);
- }
-
- ///
- /// Create an image embed from an image file
- ///
- ///
- /// Path to the image file. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- ///
- ///
- public static SafeLlavaImageEmbedHandle CreateFromFileName(SafeLlavaModelHandle clip, string image, int threads = -1)
- {
- if (threads <= 0)
- threads = Environment.ProcessorCount / 2;
-
- // Try to open the image file, this will check:
- // - File exists (automatically throws FileNotFoundException)
- // - File is readable (explicit check)
- // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
- using (var fs = new FileStream(image, FileMode.Open))
- if (!fs.CanRead)
- throw new InvalidOperationException($"Llava image file '{image}' is not readable");
-
- var embed = NativeApi.llava_image_embed_make_with_filename(clip, threads, image);
- embed.Model = clip;
- return embed;
- }
-
- ///
- /// Create an image embed from the bytes of an image.
- ///
- ///
- ///
- /// Image bytes. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, LLamaContext ctx, byte[] image)
- {
- if (!NativeApi.llava_validate_embed_size(ctx.NativeHandle, clip))
- throw new InvalidOperationException($"Cannot create image embed. Embedding dim of the multimodal projector ({clip.EmbeddingDimensions}) is not equal to embedding dim of model ({ctx.EmbeddingSize})");
-
- return CreateFromMemory(clip, image, (int)ctx.BatchThreads);
- }
-
- ///
- /// Create an image embed from the bytes of an image.
- ///
- ///
- /// Image bytes. Supported formats:
- ///
- /// - JPG
- /// - PNG
- /// - BMP
- /// - TGA
- ///
- ///
- ///
- ///
- public static SafeLlavaImageEmbedHandle CreateFromMemory(SafeLlavaModelHandle clip, byte[] image, int threads = -1)
- {
- if (threads <= 0)
- threads = Environment.ProcessorCount / 2;
-
- var embed = NativeApi.llava_image_embed_make_with_bytes(clip, threads, image, image.Length);
- embed.Model = clip;
- return embed;
- }
- #endregion
-
- ///
- protected override bool ReleaseHandle()
- {
- NativeApi.llava_image_embed_free(DangerousGetHandle());
- SetHandle(IntPtr.Zero);
- return true;
- }
-
- ///
- /// Copy the embeddings data to the destination span
- ///
- ///
- ///
- public void GetEmbedding(Span<float> dest, int index)
- {
- if (index < 0)
- throw new ArgumentOutOfRangeException(nameof(index), "index must be >= 0");
- if (index >= Model.PatchCount)
- throw new ArgumentOutOfRangeException(nameof(index), "index must be < Model.PatchCount");
-
- unsafe
- {
- var embed = (LLavaImageEmbed*)DangerousGetHandle();
- new Span<float>(
- embed->embed + Model.EmbeddingDimensions * index,
- Model.EmbeddingDimensions
- ).CopyTo(dest);
- }
- }
- }
-}
diff --git a/LLama/Native/SafeLlavaModelHandle.cs b/LLama/Native/SafeLlavaModelHandle.cs
deleted file mode 100644
index 5b3a910e9..000000000
--- a/LLama/Native/SafeLlavaModelHandle.cs
+++ /dev/null
@@ -1,137 +0,0 @@
-using System;
-using System.IO;
-using LLama.Exceptions;
-
-
-namespace LLama.Native
-{
- ///
- /// A reference to a set of llava model weights.
- ///
- public sealed class SafeLlavaModelHandle
- : SafeLLamaHandleBase
- {
- ///
- /// Get the number of dimensions in an embedding
- ///
- public int EmbeddingDimensions => clip_n_mmproj_embd(this);
-
- ///
- /// Get the number of "patches" in an image embedding
- ///
- public int PatchCount => clip_n_patches(this);
-
- ///
- protected override bool ReleaseHandle()
- {
- clip_free(DangerousGetHandle());
- SetHandle(IntPtr.Zero);
- return true;
- }
-
- ///
- /// Load a model from the given file path into memory
- ///
- /// MMP File (Multi-Modal Projections)
- /// Verbosity level
- /// SafeHandle of the Clip Model
- ///
- ///
- public static SafeLlavaModelHandle LoadFromFile(string modelPath, int verbosity )
- {
- // Try to open the model file, this will check:
- // - File exists (automatically throws FileNotFoundException)
- // - File is readable (explicit check)
- // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
- using (var fs = new FileStream(modelPath, FileMode.Open))
- if (!fs.CanRead)
- throw new InvalidOperationException($"Llava MMP Model file '{modelPath}' is not readable");
-
- var handle = clip_model_load(modelPath, verbosity);
- if (handle.IsInvalid)
- throw new LoadWeightsFailedException(modelPath);
-
- return handle;
- }
-
- ///
- /// Create the Image Embeddings.
- ///
- /// LLama Context
- /// Image filename (it supports jpeg format only)
- /// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, string image)
- {
- return SafeLlavaImageEmbedHandle.CreateFromFileName(this, ctxLlama, image);
- }
-
- ///
- /// Create the Image Embeddings.
- ///
- /// Image in binary format (it supports jpeg format only)
- /// Number of threads to use
- /// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(string image, int threads = -1)
- {
- return SafeLlavaImageEmbedHandle.CreateFromFileName(this, image, threads);
- }
-
- ///
- /// Create the Image Embeddings.
- ///
- /// LLama Context
- /// Image in binary format (it supports jpeg format only)
- /// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(LLamaContext ctxLlama, byte[] image)
- {
- return SafeLlavaImageEmbedHandle.CreateFromMemory(this, ctxLlama, image );
- }
-
- ///
- /// Create the Image Embeddings.
- ///
- /// Image in binary format (it supports jpeg format only)
- /// Number of threads to use
- /// return the SafeHandle of these embeddings
- public SafeLlavaImageEmbedHandle CreateImageEmbeddings(byte[] image, int threads = -1)
- {
- return SafeLlavaImageEmbedHandle.CreateFromMemory(this, image, threads);
- }
-
- ///
- /// Evaluates the image embeddings.
- ///
- /// Llama Context
- /// The current embeddings to evaluate
- ///
- /// True on success
- public bool EvalImageEmbed(LLamaContext ctxLlama, SafeLlavaImageEmbedHandle imageEmbed, ref int n_past)
- {
- return NativeApi.llava_eval_image_embed(ctxLlama.NativeHandle, imageEmbed, (int)ctxLlama.BatchSize, ref n_past );
- }
-
- #region native API
- ///
- /// Load MULTI MODAL PROJECTIONS model / Clip Model
- ///
- /// Model path/file
- /// Verbosity level
- /// SafeLlavaModelHandle
- [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_model_load", CallingConvention = CallingConvention.Cdecl)]
- private static extern SafeLlavaModelHandle clip_model_load(string mmProj, int verbosity);
-
- ///
- /// Frees MULTI MODAL PROJECTIONS model / Clip Model
- ///
- /// Internal Pointer to the model
- [DllImport(NativeApi.llavaLibraryName, EntryPoint = "clip_free", CallingConvention = CallingConvention.Cdecl)]
- private static extern void clip_free(IntPtr ctx);
-
- [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern int clip_n_mmproj_embd(SafeLlavaModelHandle ctx);
-
- [DllImport(NativeApi.llavaLibraryName, CallingConvention = CallingConvention.Cdecl)]
- private static extern int clip_n_patches(SafeLlavaModelHandle ctx);
- #endregion
- }
-}
diff --git a/LLama/Native/SafeMtmdEmbed.cs b/LLama/Native/SafeMtmdEmbed.cs
new file mode 100644
index 000000000..c651db102
--- /dev/null
+++ b/LLama/Native/SafeMtmdEmbed.cs
@@ -0,0 +1,247 @@
+using System;
+using System.IO;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native
+{
+ ///
+ /// Managed wrapper around mtmd_bitmap* resources. Instances own the native pointer
+ /// and ensure proper cleanup when disposed.
+ ///
+ public sealed class SafeMtmdEmbed : IDisposable
+ {
+ ///
+ /// Raw pointer to the native bitmap structure. Internal so other wrappers can interop.
+ ///
+ internal IntPtr NativePtr { get; private set; }
+
+ private bool _disposed;
+
+ private SafeMtmdEmbed(IntPtr ptr)
+ {
+ NativePtr = ptr != IntPtr.Zero
+ ? ptr
+ : throw new InvalidOperationException("Failed to create MTMD bitmap.");
+ }
+
+ ///
+ /// Create an embedding from raw RGB bytes.
+ ///
+ /// Width of the bitmap in pixels.
+ /// Height of the bitmap in pixels.
+ /// Packed RGB data (3 bytes per pixel).
+ /// Managed wrapper when initialization succeeds; otherwise null.
+ /// The RGB buffer is null.
+ public static SafeMtmdEmbed? FromRgbBytes(uint nx, uint ny, byte[] rgbData)
+ {
+ if (rgbData == null)
+ throw new ArgumentNullException(nameof(rgbData));
+
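+ // Pin the managed array so its address remains stable while mtmd_bitmap_init reads the pixel data.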
+ var handle = GCHandle.Alloc(rgbData, GCHandleType.Pinned);
+ try
+ {
+ var native = NativeApi.mtmd_bitmap_init(nx, ny, handle.AddrOfPinnedObject());
+ return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
+ }
+ finally
+ {
+ if (handle.IsAllocated)
+ handle.Free();
+ }
+ }
+
+ ///
+ /// Create an embedding from PCM audio samples.
+ ///
+ /// Array of mono PCM samples in float format.
+ /// Managed wrapper when initialization succeeds; otherwise null.
+ /// The audio buffer is null.
+ public static SafeMtmdEmbed? FromAudioSamples(float[] samples)
+ {
+ if (samples == null)
+ throw new ArgumentNullException(nameof(samples));
+
+ var handle = GCHandle.Alloc(samples, GCHandleType.Pinned);
+ try
+ {
+ var native = NativeApi.mtmd_bitmap_init_from_audio((ulong)samples.Length, handle.AddrOfPinnedObject());
+ return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
+ }
+ finally
+ {
+ if (handle.IsAllocated)
+ handle.Free();
+ }
+ }
+
+ ///
+ /// Create an embedding by decoding a media file using libmtmd helpers.
+ ///
+ /// Model context that provides the decoder configuration.
+ /// Path to the media file on disk.
+ /// Managed wrapper when decoding succeeds; otherwise null.
+ /// The context is null.
+ /// The path is null or whitespace.
+ /// The supplied file does not exist.
+ public static SafeMtmdEmbed? FromMediaFile(SafeMtmdModelHandle mtmdContext, string path)
+ {
+ if (mtmdContext == null)
+ throw new ArgumentNullException(nameof(mtmdContext));
+ if (string.IsNullOrWhiteSpace(path))
+ throw new ArgumentException("Value cannot be null or whitespace.", nameof(path));
+
+ var fullPath = Path.GetFullPath(path);
+ if (!File.Exists(fullPath))
+ throw new FileNotFoundException("Media file not found.", fullPath);
+
+ bool added = false;
+ var ctxPtr = IntPtr.Zero;
+ try
+ {
+ // Hold a strong reference to the native context while the helper decodes the media file.
+ mtmdContext.DangerousAddRef(ref added);
+ ctxPtr = mtmdContext.DangerousGetHandle();
+ var native = NativeApi.mtmd_helper_bitmap_init_from_file(ctxPtr, fullPath);
+ return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
+ }
+ finally
+ {
+ if (added)
+ mtmdContext.DangerousRelease();
+ }
+ }
+
+ /// <summary>
+ /// Create an embedding from an in-memory media buffer (image/audio/video).
+ /// </summary>
+ /// <param name="mtmdContext">Model context that provides the decoder configuration.</param>
+ /// <param name="data">Binary buffer containing the encoded media.</param>
+ /// <returns>Managed wrapper when decoding succeeds; otherwise null.</returns>
+ /// <exception cref="ArgumentNullException">The context is null.</exception>
+ /// <exception cref="ArgumentException">The buffer is empty.</exception>
+ public static unsafe SafeMtmdEmbed? FromMediaBuffer(SafeMtmdModelHandle mtmdContext, ReadOnlySpan<byte> data)
+ {
+ if (mtmdContext == null)
+ throw new ArgumentNullException(nameof(mtmdContext));
+ if (data.IsEmpty)
+ throw new ArgumentException("Buffer must not be empty.", nameof(data));
+
+ bool added = false;
+ var ctxPtr = IntPtr.Zero;
+ try
+ {
+ // Keep the context alive while the native helper processes the buffer.
+ mtmdContext.DangerousAddRef(ref added);
+ ctxPtr = mtmdContext.DangerousGetHandle();
+
+ fixed (byte* bufferPtr = data)
+ {
+ var native = NativeApi.mtmd_helper_bitmap_init_from_buf(ctxPtr, new IntPtr(bufferPtr), (UIntPtr)data.Length);
+ return native == IntPtr.Zero ? null : new SafeMtmdEmbed(native);
+ }
+ }
+ finally
+ {
+ if (added)
+ mtmdContext.DangerousRelease();
+ }
+ }
+
+ /// <summary>
+ /// Width of the bitmap in pixels (or number of samples for audio embeddings).
+ /// </summary>
+ public uint Nx
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_bitmap_get_nx(NativePtr);
+ }
+ }
+
+ /// <summary>
+ /// Height of the bitmap in pixels. For audio embeddings this is typically 1.
+ /// </summary>
+ public uint Ny
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_bitmap_get_ny(NativePtr);
+ }
+ }
+
+ /// <summary>
+ /// Indicates whether the embedding stores audio data instead of image pixels.
+ /// </summary>
+ public bool IsAudio
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_bitmap_is_audio(NativePtr);
+ }
+ }
+
+ /// <summary>
+ /// Optional identifier assigned to this embedding.
+ /// </summary>
+ public string? Id
+ {
+ get
+ {
+ EnsureNotDisposed();
+ var ptr = NativeApi.mtmd_bitmap_get_id(NativePtr);
+ return NativeApi.PtrToStringUtf8(ptr);
+ }
+ set
+ {
+ EnsureNotDisposed();
+ NativeApi.mtmd_bitmap_set_id(NativePtr, value);
+ }
+ }
+
+ /// <summary>
+ /// Zero-copy access to the underlying bitmap bytes. The span remains valid while this wrapper is alive.
+ /// </summary>
+ /// <returns>Read-only span exposing the native data buffer.</returns>
+ /// <exception cref="ObjectDisposedException">The embedding has been disposed.</exception>
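+ /// <example>
+ /// Sketch of copying the native bytes into managed memory, assuming embed is a live instance:
+ /// <code>
+ /// var bytes = embed.GetDataSpan().ToArray();
+ /// </code>
+ /// </example>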
+ public unsafe ReadOnlySpan<byte> GetDataSpan()
+ {
+ EnsureNotDisposed();
+
+ var dataPtr = (byte*)NativeApi.mtmd_bitmap_get_data(NativePtr);
+ var length = checked((int)NativeApi.mtmd_bitmap_get_n_bytes(NativePtr).ToUInt64());
+ return dataPtr == null || length == 0 ? ReadOnlySpan<byte>.Empty : new ReadOnlySpan<byte>(dataPtr, length);
+ }
+
+ /// <summary>
+ /// Release the underlying native bitmap.
+ /// </summary>
+ public void Dispose()
+ {
+ if (_disposed)
+ return;
+
+ if (NativePtr != IntPtr.Zero)
+ {
+ NativeApi.mtmd_bitmap_free(NativePtr);
+ NativePtr = IntPtr.Zero;
+ }
+
+ _disposed = true;
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Finalizer to ensure native resources are reclaimed when Dispose is not invoked.
+ /// </summary>
+ ~SafeMtmdEmbed() => Dispose();
+
+ private void EnsureNotDisposed()
+ {
+ if (_disposed || NativePtr == IntPtr.Zero)
+ throw new ObjectDisposedException(nameof(SafeMtmdEmbed));
+ }
+ }
+}
diff --git a/LLama/Native/SafeMtmdInputChunk.cs b/LLama/Native/SafeMtmdInputChunk.cs
new file mode 100644
index 000000000..59d1897ef
--- /dev/null
+++ b/LLama/Native/SafeMtmdInputChunk.cs
@@ -0,0 +1,150 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
+
+/// <summary>
+/// Managed wrapper around a single mtmd_input_chunk. Instances can either own the
+/// underlying native pointer (when created via <see cref="Copy"/>) or act as non-owning views
+/// produced by the tokenizer.
+/// </summary>
+public sealed class SafeMtmdInputChunk : IDisposable
+{
+ /// <summary>
+ /// Chunk modality returned by the native tokenizer.
+ /// </summary>
+ public enum SafeMtmdInputChunkType
+ {
+ Text = 0,
+ Image = 1,
+ Audio = 2
+ }
+
+ /// <summary>
+ /// Raw pointer to the native chunk structure.
+ /// </summary>
+ public IntPtr NativePtr { get; private set; }
+
+ private bool _ownsPtr;
+ private bool _disposed;
+
+ private SafeMtmdInputChunk(IntPtr ptr, bool owns)
+ {
+ NativePtr = ptr;
+ _ownsPtr = owns;
+ }
+
+ /// <summary>
+ /// Wrap an existing chunk pointer without taking ownership.
+ /// </summary>
+ /// <param name="ptr">Pointer returned by the native tokenizer.</param>
+ /// <returns>Managed wrapper, or null when the pointer is null.</returns>
+ public static SafeMtmdInputChunk? Wrap(IntPtr ptr)
+ => ptr == IntPtr.Zero ? null : new SafeMtmdInputChunk(ptr, false);
+
+ /// <summary>
+ /// Create an owning copy of the current chunk. The caller becomes responsible for disposal.
+ /// </summary>
+ /// <returns>Owning managed wrapper, or null if the native copy failed.</returns>
+ /// <exception cref="ObjectDisposedException">Thrown when the current wrapper has been disposed.</exception>
+ public SafeMtmdInputChunk? Copy()
+ {
+ EnsureNotDisposed();
+
+ var p = NativeApi.mtmd_input_chunk_copy(NativePtr);
+ return p == IntPtr.Zero ? null : new SafeMtmdInputChunk(p, true);
+ }
+
+ /// <summary>
+ /// Chunk modality reported by the native helper.
+ /// </summary>
+ public SafeMtmdInputChunkType Type
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return (SafeMtmdInputChunkType)NativeApi.mtmd_input_chunk_get_type(NativePtr);
+ }
+ }
+
+ /// <summary>
+ /// Number of tokens contained in this chunk.
+ /// </summary>
+ public ulong NTokens
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_input_chunk_get_n_tokens(NativePtr).ToUInt64();
+ }
+ }
+
+ /// <summary>
+ /// Identifier assigned by the tokenizer (if any).
+ /// </summary>
+ public string Id
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return Marshal.PtrToStringAnsi(NativeApi.mtmd_input_chunk_get_id(NativePtr)) ?? string.Empty;
+ }
+ }
+
+ /// <summary>
+ /// Number of positional slots consumed by this chunk.
+ /// </summary>
+ public long NPos
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_input_chunk_get_n_pos(NativePtr);
+ }
+ }
+
+ /// <summary>
+ /// Zero-copy view over the chunk's token buffer. The span remains valid only while the native chunk is alive.
+ /// </summary>
+ /// <returns>Read-only span exposing the chunk's tokens.</returns>
+ /// <exception cref="ObjectDisposedException">Thrown when the wrapper has been disposed.</exception>
+ public unsafe ReadOnlySpan<uint> GetTextTokensSpan()
+ {
+ EnsureNotDisposed();
+
+ UIntPtr n;
+ var p = (uint*)NativeApi.mtmd_input_chunk_get_tokens_text(NativePtr, out n);
+ return p == null ? ReadOnlySpan<uint>.Empty : new ReadOnlySpan<uint>(p, checked((int)n.ToUInt64()));
+ }
+
+ /// <summary>
+ /// Release the underlying native resources if this instance owns them.
+ /// </summary>
+ public void Dispose()
+ {
+ if (_disposed)
+ return;
+
+ if (_ownsPtr && NativePtr != IntPtr.Zero)
+ {
+ NativeApi.mtmd_input_chunk_free(NativePtr);
+ }
+
+ NativePtr = IntPtr.Zero;
+ _ownsPtr = false;
+ _disposed = true;
+
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Finalizer to ensure native memory is reclaimed when Dispose is not called by owners.
+ /// </summary>
+ ~SafeMtmdInputChunk() => Dispose();
+
+ private void EnsureNotDisposed()
+ {
+ if (_disposed || NativePtr == IntPtr.Zero)
+ throw new ObjectDisposedException(nameof(SafeMtmdInputChunk));
+ }
+}
diff --git a/LLama/Native/SafeMtmdInputChunks.cs b/LLama/Native/SafeMtmdInputChunks.cs
new file mode 100644
index 000000000..2081cd0a6
--- /dev/null
+++ b/LLama/Native/SafeMtmdInputChunks.cs
@@ -0,0 +1,103 @@
+using System;
+using System.Collections.Generic;
+
+namespace LLama.Native;
+
+/// <summary>
+/// Managed lifetime wrapper around a native mtmd_input_chunks collection returned by the tokenizer.
+/// </summary>
+public sealed class SafeMtmdInputChunks : IDisposable
+{
+ /// <summary>
+ /// Raw pointer to the native chunk collection. Internal to allow other wrappers to interop safely.
+ /// </summary>
+ internal IntPtr NativePtr { get; private set; }
+
+ private bool _disposed;
+
+ internal SafeMtmdInputChunks(IntPtr ptr)
+ {
+ NativePtr = ptr;
+ }
+
+ /// <summary>
+ /// Releases the native chunk collection and suppresses finalization.
+ /// </summary>
+ public void Dispose()
+ {
+ if (_disposed)
+ return;
+
+ if (NativePtr != IntPtr.Zero)
+ {
+ NativeApi.mtmd_input_chunks_free(NativePtr);
+ NativePtr = IntPtr.Zero;
+ }
+
+ _disposed = true;
+ GC.SuppressFinalize(this);
+ }
+
+ /// <summary>
+ /// Finalizer to ensure native memory is reclaimed if Dispose is not called.
+ /// </summary>
+ ~SafeMtmdInputChunks()
+ {
+ Dispose();
+ }
+
+ /// <summary>
+ /// Number of chunks currently held by the native collection.
+ /// </summary>
+ public ulong Size
+ {
+ get
+ {
+ EnsureNotDisposed();
+ return NativeApi.mtmd_input_chunks_size(NativePtr).ToUInt64();
+ }
+ }
+
+ /// <summary>
+ /// Get a raw pointer to a chunk. The returned <see cref="IntPtr"/> is the mtmd_input_chunk*.
+ /// Use <see cref="SafeMtmdInputChunk.Wrap"/> to create a managed wrapper if desired.
+ /// </summary>
+ /// <param name="index">Zero-based index of the chunk to retrieve.</param>
+ /// <returns>Pointer to the requested chunk.</returns>
+ /// <exception cref="ObjectDisposedException">The collection has already been disposed.</exception>
+ /// <exception cref="IndexOutOfRangeException">The requested index is outside of the valid range.</exception>
+ public IntPtr GetChunkPtr(ulong index)
+ {
+ EnsureNotDisposed();
+
+ if (index >= Size) throw new IndexOutOfRangeException();
+ return NativeApi.mtmd_input_chunks_get(NativePtr, (UIntPtr)index);
+ }
+
+ /// <summary>
+ /// Enumerate the contained chunks as non-owning wrappers. Ownership stays with this collection,
+ /// so callers only dispose the owning copies they create via <see cref="SafeMtmdInputChunk.Copy"/>.
+ /// </summary>
+ /// <returns>Enumeration of chunk wrappers backed by the native collection.</returns>
+ /// <exception cref="ObjectDisposedException">The collection has already been disposed.</exception>
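+ /// <example>
+ /// Iteration sketch; the wrappers stay owned by this collection:
+ /// <code>
+ /// foreach (var chunk in chunks.Enumerate())
+ /// Console.WriteLine($"{chunk.Type}: {chunk.NTokens} tokens");
+ /// </code>
+ /// </example>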
+ public IEnumerable<SafeMtmdInputChunk> Enumerate()
+ {
+ EnsureNotDisposed();
+
+ for (ulong i = 0; i < Size; i++)
+ {
+ var chunk = SafeMtmdInputChunk.Wrap(GetChunkPtr(i));
+ if (chunk != null)
+ {
+ // Yield a lightweight wrapper; ownership remains with the native collection.
+ yield return chunk;
+ }
+ }
+ }
+
+ private void EnsureNotDisposed()
+ {
+ if (_disposed || NativePtr == IntPtr.Zero)
+ throw new ObjectDisposedException(nameof(SafeMtmdInputChunks));
+ }
+}
diff --git a/LLama/Native/SafeMtmdModelHandle.cs b/LLama/Native/SafeMtmdModelHandle.cs
new file mode 100644
index 000000000..86abf8c6c
--- /dev/null
+++ b/LLama/Native/SafeMtmdModelHandle.cs
@@ -0,0 +1,341 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using LLama.Exceptions;
+
+namespace LLama.Native
+{
+ /// <summary>
+ /// Safe handle for the multimodal (MTMD) model weights. This wrapper manages the
+ /// low-level native operations.
+ /// </summary>
+ public sealed class SafeMtmdModelHandle : SafeLLamaHandleBase
+ {
+ // Pending media embeddings queued for the next call to Tokenize.
+ private readonly List<SafeMtmdEmbed> _pendingMedia = new();
+
+ /// <inheritdoc />
+ protected override bool ReleaseHandle()
+ {
+ mtmd_free(DangerousGetHandle());
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+
+ /// <summary>
+ /// Load a multimodal projection model from disk and bind it to the supplied text model.
+ /// </summary>
+ /// <param name="modelPath">Path to the mmproj (multi-modal projector) file.</param>
+ /// <param name="textModel">Text model that provides tokenizer weights for the multimodal helper.</param>
+ /// <param name="mtmdCtxParams">Optional context parameters; defaults are used when null.</param>
+ /// <returns>Safe handle for the MTMD model.</returns>
+ /// <exception cref="InvalidOperationException">The file exists but is not readable by the current process.</exception>
+ /// <exception cref="LoadWeightsFailedException">The native loader failed to initialize the MTMD model.</exception>
+ public static SafeMtmdModelHandle LoadFromFile(string modelPath, LLamaWeights textModel, MtmdContextParams mtmdCtxParams)
+ {
+ // Try to open the model file, this will check:
+ // - File exists (automatically throws FileNotFoundException)
+ // - File is readable (explicit check)
+ // This provides better error messages than llama.cpp, which would throw an access violation exception in both cases.
+ using (var fs = new FileStream(modelPath, FileMode.Open))
+ if (!fs.CanRead)
+ throw new InvalidOperationException($"Mtmd Model file '{modelPath}' is not readable");
+
+ using var pathUtf8 = PinnedUtf8String.Create(modelPath) ?? throw new ArgumentNullException(nameof(modelPath));
+
+ unsafe
+ {
+ SafeMtmdModelHandle handle;
+ if (mtmdCtxParams is null)
+ {
+ var nativeParams = NativeApi.mtmd_context_params_default();
+ handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParams);
+ }
+ else
+ {
+ using var nativeParamsScope = mtmdCtxParams.ToNativeScope();
+ handle = mtmd_init_from_file((byte*)pathUtf8.Pointer, textModel.NativeHandle, nativeParamsScope.Value);
+ }
+
+ if (handle.IsInvalid)
+ throw new LoadWeightsFailedException(modelPath);
+
+ return handle;
+ }
+ }
+
+ /// <summary>
+ /// Load media from disk and queue it for the next tokenize call.
+ /// </summary>
+ /// <param name="path">Absolute or relative path to the media asset.</param>
+ /// <returns>Safe handle to the media embedding.</returns>
+ /// <exception cref="ObjectDisposedException">The model handle has been disposed.</exception>
+ /// <exception cref="RuntimeError">The native loader failed to ingest the file.</exception>
+ public SafeMtmdEmbed LoadMediaFromFile(string path)
+ {
+ EnsureNotDisposed();
+
+ var embed = SafeMtmdEmbed.FromMediaFile(this, path)
+ ?? throw new RuntimeError($"Failed to load media '{path}'.");
+ _pendingMedia.Add(embed);
+ return embed;
+ }
+
+ /// <summary>
+ /// Load media from an in-memory buffer and queue it for the next tokenize call.
+ /// </summary>
+ /// <param name="buffer">Binary buffer containing the encoded media data.</param>
+ /// <returns>Safe handle to the media embedding.</returns>
+ /// <exception cref="ObjectDisposedException">The model handle has been disposed.</exception>
+ /// <exception cref="RuntimeError">The native loader failed to ingest the buffer contents.</exception>
+ public SafeMtmdEmbed LoadMediaFromBuffer(ReadOnlySpan<byte> buffer)
+ {
+ EnsureNotDisposed();
+
+ var embed = SafeMtmdEmbed.FromMediaBuffer(this, buffer)
+ ?? throw new RuntimeError("Failed to load media from buffer.");
+ _pendingMedia.Add(embed);
+ return embed;
+ }
+
+ /// <summary>
+ /// Disposes and clears any media buffers currently queued for tokenization.
+ /// </summary>
+ public void ClearMedia()
+ {
+ foreach (var media in _pendingMedia)
+ media.Dispose();
+ _pendingMedia.Clear();
+ }
+
+ /// <summary>
+ /// Tokenize a prompt alongside the pending media buffers. Pending media is cleared after the call,
+ /// whether or not tokenization succeeds.
+ /// </summary>
+ /// <param name="text">Prompt text to tokenize.</param>
+ /// <param name="addSpecial">Whether to add special tokens (e.g. BOS) automatically.</param>
+ /// <param name="parseSpecial">Whether special tokens in the text should be parsed rather than treated as plain text.</param>
+ /// <param name="chunks">Receives the native chunk collection when tokenization succeeds.</param>
+ /// <returns>Zero on success; otherwise the native mtmd tokenize error code.</returns>
+ /// <exception cref="ObjectDisposedException">The model handle has been disposed.</exception>
+ /// <exception cref="RuntimeError">The native chunk collection could not be allocated.</exception>
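+ /// <example>
+ /// Illustrative flow; mtmd is a loaded handle and the prompt contains one media marker per queued item:
+ /// <code>
+ /// mtmd.LoadMediaFromFile("photo.jpg");
+ /// var rc = mtmd.Tokenize(promptWithMarker, addSpecial: true, parseSpecial: true, out var chunks);
+ /// if (rc == 0)
+ /// using (chunks) { /* evaluate via EvaluateChunks */ }
+ /// </code>
+ /// </example>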
+ public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks)
+ {
+ EnsureNotDisposed();
+
+ chunks = null;
+ // Allocate the chunk container before invoking the native tokenizer.
+ var output = NativeApi.mtmd_input_chunks_init();
+ if (output == IntPtr.Zero)
+ throw new RuntimeError("Failed to allocate mtmd_input_chunks.");
+
+ // Collect native pointers to the queued media embeddings.
+ var bitmapHandles = new IntPtr[_pendingMedia.Count];
+ for (var i = 0; i < _pendingMedia.Count; i++)
+ bitmapHandles[i] = _pendingMedia[i].NativePtr;
+
+ var result = NativeApi.mtmd_tokenize(DangerousGetHandle(), output, text, addSpecial, parseSpecial, bitmapHandles, (UIntPtr)bitmapHandles.Length);
+
+ if (result == 0)
+ {
+ chunks = new SafeMtmdInputChunks(output);
+ }
+ else
+ {
+ NativeApi.mtmd_input_chunks_free(output);
+ }
+
+ ClearMedia();
+
+ return result;
+ }
+
+ /// <summary>
+ /// Evaluate a batch of chunks using the helper (mirrors the mtmd-helper eval logic).
+ /// </summary>
+ /// <param name="chunks">Chunk collection produced by <see cref="Tokenize"/>.</param>
+ /// <param name="llamaContext">Context handle that receives the evaluated tokens.</param>
+ /// <param name="nPast">Number of past tokens; updated when evaluation succeeds.</param>
+ /// <param name="seqId">Sequence identifier used for KV cache management.</param>
+ /// <param name="nBatch">Maximum number of tokens to evaluate in a single batch.</param>
+ /// <param name="logitsLast">Whether to request logits for the last token only.</param>
+ /// <returns>Zero on success; otherwise the native helper error code.</returns>
+ /// <exception cref="ArgumentNullException">Thrown when required handles are null.</exception>
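+ /// <example>
+ /// Decode-loop sketch; mtmd, chunks, ctx and the batch size are placeholders:
+ /// <code>
+ /// long nPast = 0;
+ /// var rc = mtmd.EvaluateChunks(chunks, ctx, ref nPast, seqId: 0, nBatch: 512, logitsLast: true);
+ /// if (rc != 0)
+ /// throw new RuntimeError($"mtmd helper eval failed ({rc})");
+ /// </code>
+ /// </example>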
+ public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+ {
+ EnsureNotDisposed();
+
+ if (chunks == null)
+ throw new ArgumentNullException(nameof(chunks));
+ if (llamaContext == null)
+ throw new ArgumentNullException(nameof(llamaContext));
+
+ var newNPast = nPast;
+ var result = NativeApi.mtmd_helper_eval_chunks(
+ DangerousGetHandle(),
+ llamaContext.DangerousGetHandle(),
+ chunks.NativePtr,
+ nPast,
+ seqId,
+ nBatch,
+ logitsLast,
+ ref newNPast);
+
+ if (result == 0)
+ nPast = newNPast;
+
+ return result;
+ }
+
+ /// <summary>
+ /// Evaluate a single chunk using the native helper.
+ /// </summary>
+ /// <param name="chunkPtr">Pointer to the chunk to evaluate.</param>
+ /// <param name="llamaContext">Context handle that receives the evaluated tokens.</param>
+ /// <param name="nPast">Number of past tokens; updated when evaluation succeeds.</param>
+ /// <param name="seqId">Sequence identifier used for KV cache management.</param>
+ /// <param name="nBatch">Maximum number of tokens to evaluate in a single batch.</param>
+ /// <param name="logitsLast">Whether to request logits for the last token only.</param>
+ /// <returns>Zero on success; otherwise the native helper error code.</returns>
+ /// <exception cref="ArgumentNullException">Thrown when required handles are null.</exception>
+ public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+ {
+ EnsureNotDisposed();
+
+ if (chunkPtr == IntPtr.Zero)
+ throw new ArgumentNullException(nameof(chunkPtr));
+ if (llamaContext == null)
+ throw new ArgumentNullException(nameof(llamaContext));
+
+ var newNPast = nPast;
+ var result = NativeApi.mtmd_helper_eval_chunk_single(
+ DangerousGetHandle(),
+ llamaContext.DangerousGetHandle(),
+ chunkPtr,
+ nPast,
+ seqId,
+ nBatch,
+ logitsLast,
+ ref newNPast);
+
+ if (result == 0)
+ nPast = newNPast;
+
+ return result;
+ }
+
+ /// <summary>
+ /// Decode a prepared image chunk whose embedding is already computed.
+ /// </summary>
+ /// <param name="chunkPtr">Pointer to the chunk whose embedding should be decoded.</param>
+ /// <param name="llamaContext">Context handle used for decoding.</param>
+ /// <param name="encodedEmbeddings">Pointer to the pre-computed embedding data.</param>
+ /// <param name="nPast">Number of past tokens; updated when evaluation succeeds.</param>
+ /// <param name="seqId">Sequence identifier used for KV cache management.</param>
+ /// <param name="nBatch">Maximum number of tokens to evaluate in a single batch.</param>
+ /// <returns>Zero on success; otherwise the native helper error code.</returns>
+ /// <exception cref="ArgumentNullException">Thrown when required handles are null.</exception>
+ public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch)
+ {
+ EnsureNotDisposed();
+
+ if (chunkPtr == IntPtr.Zero)
+ throw new ArgumentNullException(nameof(chunkPtr));
+
+ var newNPast = nPast;
+ var result = NativeApi.mtmd_helper_decode_image_chunk(
+ DangerousGetHandle(),
+ llamaContext?.DangerousGetHandle() ?? throw new ArgumentNullException(nameof(llamaContext)),
+ chunkPtr,
+ encodedEmbeddings,
+ nPast,
+ seqId,
+ nBatch,
+ ref newNPast);
+
+ if (result == 0)
+ nPast = newNPast;
+
+ return result;
+ }
+
+ /// <summary>
+ /// Get the number of tokens contained in the provided chunk collection.
+ /// </summary>
+ /// <param name="chunks">Chunk collection produced by <see cref="Tokenize"/>.</param>
+ /// <returns>Total token count.</returns>
+ public ulong CountTokens(SafeMtmdInputChunks chunks)
+ {
+ if (chunks == null)
+ throw new ArgumentNullException(nameof(chunks));
+ return NativeApi.mtmd_helper_get_n_tokens(chunks.NativePtr).ToUInt64();
+ }
+
+ /// <summary>
+ /// Get the number of positions contained in the provided chunk collection.
+ /// </summary>
+ /// <param name="chunks">Chunk collection produced by <see cref="Tokenize"/>.</param>
+ /// <returns>Total number of positional slots consumed.</returns>
+ public long CountPositions(SafeMtmdInputChunks chunks)
+ {
+ if (chunks == null)
+ throw new ArgumentNullException(nameof(chunks));
+ return NativeApi.mtmd_helper_get_n_pos(chunks.NativePtr);
+ }
+
+ #region native API
+
+ // mtmd_init_from_file(const char * mmproj_fname, const struct llama_model * text_model, const struct mtmd_context_params ctx_params);
+ // The llama_model layout is opaque; expose it via SafeLlamaModelHandle to match the managed wrapper.
+ [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_init_from_file", CallingConvention = CallingConvention.Cdecl)]
+ private static extern unsafe SafeMtmdModelHandle mtmd_init_from_file(
+ byte* mmproj_fname,
+ SafeLlamaModelHandle text_model,
+ NativeApi.mtmd_context_params @ctx_params);
+
+ [DllImport(NativeApi.mtmdLibraryName, EntryPoint = "mtmd_free", CallingConvention = CallingConvention.Cdecl)]
+ internal static extern void mtmd_free(IntPtr ctx);
+
+ #endregion
+
+ /// <summary>
+ /// Finalizer to ensure native resources are released if Dispose was not called.
+ /// </summary>
+ ~SafeMtmdModelHandle()
+ {
+ Dispose();
+ }
+
+ /// <summary>
+ /// Indicates whether the model decodes using the non-causal path.
+ /// </summary>
+ public bool DecodeUseNonCausal() => NativeApi.mtmd_decode_use_non_causal(handle);
+
+ /// <summary>
+ /// Indicates whether the model decodes using multi-scale RoPE.
+ /// </summary>
+ public bool DecodeUseMRope() => NativeApi.mtmd_decode_use_mrope(handle);
+
+ /// <summary>
+ /// Indicates whether the model supports vision inputs.
+ /// </summary>
+ public bool SupportVision() => NativeApi.mtmd_support_vision(handle);
+
+ /// <summary>
+ /// Indicates whether the model supports audio inputs.
+ /// </summary>
+ public bool SupportAudio() => NativeApi.mtmd_support_audio(handle);
+
+ /// <summary>
+ /// Gets the audio bitrate advertised by the model.
+ /// </summary>
+ public int GetAudioBitrate() => NativeApi.mtmd_get_audio_bitrate(handle);
+
+ private void EnsureNotDisposed()
+ {
+ if (IsInvalid || IsClosed)
+ throw new ObjectDisposedException(nameof(SafeMtmdModelHandle));
+ }
+ }
+}
diff --git a/LLama/Properties/InternalsVisibleTo.cs b/LLama/Properties/InternalsVisibleTo.cs
new file mode 100644
index 000000000..b0a1ac4be
--- /dev/null
+++ b/LLama/Properties/InternalsVisibleTo.cs
@@ -0,0 +1,3 @@
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("LLama.Unittest")]
diff --git a/LLama/SafeMtmdWeights.cs b/LLama/SafeMtmdWeights.cs
new file mode 100644
index 000000000..e490049b4
--- /dev/null
+++ b/LLama/SafeMtmdWeights.cs
@@ -0,0 +1,80 @@
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Native;
+
+namespace LLama;
+
+/// <summary>
+/// Lightweight wrapper around the MTMD native context and its helpers.
+/// </summary>
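+/// <example>
+/// End-to-end sketch; every path below is a placeholder:
+/// <code>
+/// var parameters = new ModelParams("model.gguf");
+/// using var textModel = LLamaWeights.LoadFromFile(parameters);
+/// using var mtmd = SafeMtmdWeights.LoadFromFile("mmproj.gguf", textModel, MtmdContextParams.Default());
+/// mtmd.LoadMedia("photo.jpg");
+/// </code>
+/// </example>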
+public sealed class SafeMtmdWeights : IDisposable
+{
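+ /// <summary>
+ /// Underlying native MTMD model handle.
+ /// </summary>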
+ public SafeMtmdModelHandle NativeHandle { get; }
+
+ private SafeMtmdWeights(SafeMtmdModelHandle handle)
+ {
+ NativeHandle = handle ?? throw new ArgumentNullException(nameof(handle));
+ }
+
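+ /// <summary>
+ /// Load the multimodal projection file and bind it to the supplied text model.
+ /// </summary>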
+ public static SafeMtmdWeights LoadFromFile(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams)
+ {
+ if (mmProject == null) throw new ArgumentNullException(nameof(mmProject));
+ if (textModel == null) throw new ArgumentNullException(nameof(textModel));
+ if (mtmdCtxParams == null) throw new ArgumentNullException(nameof(mtmdCtxParams));
+
+ var handle = SafeMtmdModelHandle.LoadFromFile(mmProject, textModel, mtmdCtxParams);
+ return new SafeMtmdWeights(handle);
+ }
+
+ public static Task LoadFromFileAsync(string mmProject, LLamaWeights textModel, MtmdContextParams mtmdCtxParams, CancellationToken token = default)
+ {
+ return Task.Run(() => LoadFromFile(mmProject, textModel, mtmdCtxParams), token);
+ }
+
+ /// <summary>
+ /// Load media from disk and keep it pending for the next tokenize call.
+ /// </summary>
+ public SafeMtmdEmbed LoadMedia(string path) => NativeHandle.LoadMediaFromFile(path);
+
+ /// <summary>
+ /// Load media from an in-memory buffer and keep it pending for the next tokenize call.
+ /// </summary>
+ public SafeMtmdEmbed LoadMedia(ReadOnlySpan data) => NativeHandle.LoadMediaFromBuffer(data);
+
+ /// <summary>
+ /// Clear any pending media buffers before or after tokenization.
+ /// </summary>
+ public void ClearMedia() => NativeHandle.ClearMedia();
+
+ /// <summary>
+ /// Tokenize text (with optional special tokens) against the pending media buffers.
+ /// </summary>
+ public int Tokenize(string text, bool addSpecial, bool parseSpecial, out SafeMtmdInputChunks? chunks)
+ => NativeHandle.Tokenize(text, addSpecial, parseSpecial, out chunks);
+
+ /// <summary>
+ /// Evaluate a chunk batch using the helper that performs mtmd encode + llama decode.
+ /// </summary>
+ public int EvaluateChunks(SafeMtmdInputChunks chunks, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+ => NativeHandle.EvaluateChunks(chunks, llamaContext, ref nPast, seqId, nBatch, logitsLast);
+
+ public int EvaluateChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, ref long nPast, int seqId, int nBatch, bool logitsLast)
+ => NativeHandle.EvaluateChunk(chunkPtr, llamaContext, ref nPast, seqId, nBatch, logitsLast);
+
+ public int DecodeImageChunk(IntPtr chunkPtr, SafeLLamaContextHandle llamaContext, IntPtr encodedEmbeddings, ref long nPast, int seqId, int nBatch)
+ => NativeHandle.DecodeImageChunk(chunkPtr, llamaContext, encodedEmbeddings, ref nPast, seqId, nBatch);
+
+ public ulong CountTokens(SafeMtmdInputChunks chunks) => NativeHandle.CountTokens(chunks);
+
+ public long CountPositions(SafeMtmdInputChunks chunks) => NativeHandle.CountPositions(chunks);
+
+ public bool SupportsVision => NativeHandle.SupportVision();
+ public bool SupportsAudio => NativeHandle.SupportAudio();
+ public bool UsesNonCausalAttention => NativeHandle.DecodeUseNonCausal();
+ public bool UsesMRope => NativeHandle.DecodeUseMRope();
+ public int AudioBitrate => NativeHandle.GetAudioBitrate();
+
+ public void Dispose() => NativeHandle.Dispose();
+}
diff --git a/docs/Examples/LLavaInteractiveModeExecute.md b/docs/Examples/LLavaInteractiveModeExecute.md
deleted file mode 100644
index 2bfbbea1d..000000000
--- a/docs/Examples/LLavaInteractiveModeExecute.md
+++ /dev/null
@@ -1,129 +0,0 @@
-# LLaVA - basic
-
-```cs
-using System.Text.RegularExpressions;
-using LLama.Common;
-using Spectre.Console;
-using LLama.Native;
-
-namespace LLama.Examples.Examples
-{
- // This example shows how to chat with LLaVA model with both image and text as input.
- // It uses the interactive executor to inference.
- public class LlavaInteractiveModeExecute
- {
- public static async Task Run()
- {
- string multiModalProj = UserSettings.GetMMProjPath();
- string modelPath = UserSettings.GetModelPath();
- string modelImage = UserSettings.GetImagePath();
- const int maxTokens = 1024;
-
- var prompt = $"{{{modelImage}}}\nUSER:\nProvide a full description of the image.\nASSISTANT:\n";
-
- var parameters = new ModelParams(modelPath);
-
- using var model = LLamaWeights.LoadFromFile(parameters);
- using var context = model.CreateContext(parameters);
-
- // Llava Init
- using var clipModel = LLavaWeights.LoadFromFile(multiModalProj);
-
- var ex = new InteractiveExecutor(context, clipModel );
-
- Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine("The executor has been enabled. In this example, the prompt is printed, the maximum tokens is set to {0} and the context size is {1}.", maxTokens, parameters.ContextSize );
- Console.WriteLine("To send an image, enter its filename in curly braces, like this {c:/image.jpg}.");
-
- var inferenceParams = new InferenceParams() { Temperature = 0.1f, AntiPrompts = new List<string> { "\nUSER:" }, MaxTokens = maxTokens };
-
- do
- {
-
- // Evaluate if we have images
- //
- var imageMatches = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
- var imageCount = imageMatches.Count();
- var hasImages = imageCount > 0;
-
- if (hasImages)
- {
- var imagePathsWithCurlyBraces = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Value);
- var imagePaths = Regex.Matches(prompt, "{([^}]*)}").Select(m => m.Groups[1].Value).ToList();
-
- List<byte[]> imageBytes;
- try
- {
- imageBytes = imagePaths.Select(File.ReadAllBytes).ToList();
- }
- catch (IOException exception)
- {
- Console.ForegroundColor = ConsoleColor.Red;
- Console.Write(
- $"Could not load your {(imageCount == 1 ? "image" : "images")}:");
- Console.Write($"{exception.Message}");
- Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine("Please try again.");
- break;
- }
-
- // Each prompt with images we clear cache
- // When the prompt contains images we clear KV_CACHE to restart conversation
- // See:
- // https://github.com/ggerganov/llama.cpp/discussions/3620
- ex.Context.NativeHandle.KvCacheRemove( LLamaSeqId.Zero, -1, -1 );
-
- int index = 0;
- foreach (var path in imagePathsWithCurlyBraces)
- {
- // First image replace to tag <image>, the rest of the images delete the tag
- prompt = prompt.Replace(path, index++ == 0 ? "<image>" : "");
- }
-
-
- Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine($"Here are the images, that are sent to the chat model in addition to your message.");
- Console.WriteLine();
-
- foreach (var consoleImage in imageBytes?.Select(bytes => new CanvasImage(bytes)))
- {
- consoleImage.MaxWidth = 50;
- AnsiConsole.Write(consoleImage);
- }
-
- Console.WriteLine();
- Console.ForegroundColor = ConsoleColor.Yellow;
- Console.WriteLine($"The images were scaled down for the console only, the model gets full versions.");
- Console.WriteLine($"Write /exit or press Ctrl+c to return to main menu.");
- Console.WriteLine();
-
-
- // Initialize Images in executor
- //
- foreach (var image in imagePaths)
- {
- ex.Images.Add(await File.ReadAllBytesAsync(image));
- }
- }
-
- Console.ForegroundColor = Color.White;
- await foreach (var text in ex.InferAsync(prompt, inferenceParams))
- {
- Console.Write(text);
- }
- Console.Write(" ");
- Console.ForegroundColor = ConsoleColor.Green;
- prompt = Console.ReadLine();
- Console.WriteLine();
-
- // let the user finish with exit
- //
- if (prompt != null && prompt.Equals("/exit", StringComparison.OrdinalIgnoreCase))
- break;
-
- }
- while(true);
- }
- }
-}
-```
\ No newline at end of file
diff --git a/docs/Examples/MtmdInteractiveModeExecute.md b/docs/Examples/MtmdInteractiveModeExecute.md
new file mode 100644
index 000000000..378c93a1b
--- /dev/null
+++ b/docs/Examples/MtmdInteractiveModeExecute.md
@@ -0,0 +1,41 @@
+# MTMD interactive mode
+
+`MtmdInteractiveModeExecute` shows how to pair a multimodal projection with a text model so the chat loop can reason over images supplied at runtime. The sample lives in `LLama.Examples/Examples/MtmdInteractiveModeExecute.cs` and reuses the interactive executor provided by LLamaSharp.
+
+## Workflow
+- Resolve the model, multimodal projection, and sample image paths via `UserSettings`.
+- Create `ModelParams` for the text model and capture the MTMD defaults with `MtmdContextParams.Default()`.
+- Load the base model and context, then initialize `SafeMtmdWeights` with the multimodal projection file.
+- Resolve the media marker (`mtmdParameters.MediaMarker ?? NativeApi.MtmdDefaultMarker() ?? ""`) for prompt rewriting, then construct an `InteractiveExecutor` over the context and multimodal weights, as shown below.
+
+```cs
+var mtmdParameters = MtmdContextParams.Default();
+
+using var model = await LLamaWeights.LoadFromFileAsync(parameters);
+using var context = model.CreateContext(parameters);
+
+// Mtmd Init
+using var clipModel = await SafeMtmdWeights.LoadFromFileAsync(
+ multiModalProj,
+ model,
+ mtmdParameters);
+
+var mediaMarker = mtmdParameters.MediaMarker
+ ?? NativeApi.MtmdDefaultMarker()
+ ?? "";
+
+var ex = new InteractiveExecutor(context, clipModel);
+```
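+
+If you drive the pipeline without the interactive executor, the same weights expose `Tokenize` and `EvaluateChunks` directly. The snippet below is a hedged sketch rather than part of the sample; the sequence id and batch size are placeholders.
+
+```cs
+clipModel.LoadMedia(modelImage);
+if (clipModel.Tokenize($"Describe this image: {mediaMarker}", addSpecial: true, parseSpecial: true, out var chunks) == 0)
+{
+ using (chunks)
+ {
+ long nPast = 0;
+ clipModel.EvaluateChunks(chunks, context.NativeHandle, ref nPast, seqId: 0, nBatch: 512, logitsLast: true);
+ }
+}
+```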
+
+## Handling user input
+- Prompts can include image paths wrapped in braces (for example `{c:/image.jpg}`); the loop searches for those markers with regular expressions.
+- Every referenced file is loaded through `SafeMtmdWeights.LoadMedia`, producing `SafeMtmdEmbed` instances that are queued for the next tokenization call.
+- When the user provides images, the executor clears its KV cache (`MemorySequenceRemove`) before replacing each brace-wrapped path in the prompt with the multimodal marker.
+- The embeds collected for the current turn are copied into `ex.Embeds`, so the executor submits both the text prompt and the pending media to the helper before generation; a condensed sketch follows.
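+
+The loop below is a simplified sketch of that handling, not the exact sample code; `clipModel`, `prompt`, and `mediaMarker` mirror the names used above.
+
+```cs
+foreach (Match match in Regex.Matches(prompt, "{([^}]*)}"))
+{
+ var path = match.Groups[1].Value;
+ clipModel.LoadMedia(path); // queues a SafeMtmdEmbed for the next tokenize call
+ prompt = prompt.Replace(match.Value, mediaMarker);
+}
+```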
+
+## Running the sample
+1. Ensure the model and projection paths returned by `UserSettings` exist locally.
+2. Start the example (for instance from the examples host application) and observe the initial description printed to the console.
+3. Type text normally, or reference new images by including their path inside braces. Type `/exit` to end the conversation.
+
+This walkthrough mirrors the logic in the sample so you can adapt it for your own multimodal workflows.
diff --git a/mkdocs.yml b/mkdocs.yml
index 09cb3b96b..fbffdbba7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -38,7 +38,7 @@ nav:
- Interactive executor - basic: Examples/InteractiveModeExecute.md
- Kernel memory integration - basic: Examples/KernelMemory.md
- Kernel-memory - save & load: Examples/KernelMemorySaveAndLoad.md
- - LLaVA - basic: Examples/LLavaInteractiveModeExecute.md
+ - MTMD interactive: Examples/MtmdInteractiveModeExecute.md
- ChatSession - load & save: Examples/LoadAndSaveSession.md
- Executor - save/load state: Examples/LoadAndSaveState.md
- Quantization: Examples/QuantizeModel.md
@@ -254,4 +254,4 @@ markdown_extensions:
custom_checkbox: true
- pymdownx.tilde
- pymdownx.tabbed:
- alternate_style: true
\ No newline at end of file
+ alternate_style: true