2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -19,7 +19,7 @@ public static async Task Run()
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorGuidance.cs
@@ -19,7 +19,7 @@ public static async Task Run()
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorRewind.cs
@@ -20,7 +20,7 @@ public static async Task Run()
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -18,7 +18,7 @@ public static async Task Run()
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/ChatChineseGB2312.cs
@@ -31,7 +31,7 @@ public static async Task Run()
GpuLayerCount = 5,
Encoding = Encoding.UTF8
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionStripRoleName.cs
@@ -15,11 +15,11 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/ChatSessionWithHistory.cs
@@ -13,7 +13,7 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionWithRestart.cs
@@ -13,11 +13,11 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();
ChatSession prototypeSession =
await ChatSession.InitializeSessionFromHistoryAsync(executor, chatHistory);
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/ChatSessionWithRoleName.cs
@@ -13,11 +13,11 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
var chatHistoryJson = await File.ReadAllTextAsync("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/CodingAssistant.cs
@@ -29,7 +29,7 @@ public static async Task Run()
{
ContextSize = 4096
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);

4 changes: 2 additions & 2 deletions LLama.Examples/Examples/GrammarJsonResponse.cs
@@ -9,15 +9,15 @@ public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var gbnf = File.ReadAllText("Assets/json.gbnf").Trim();
var gbnf = (await File.ReadAllTextAsync("Assets/json.gbnf")).Trim();
var grammar = Grammar.Parse(gbnf, "root");

var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);

Console.ForegroundColor = ConsoleColor.Yellow;
4 changes: 2 additions & 2 deletions LLama.Examples/Examples/InstructModeExecute.cs
@@ -9,14 +9,14 @@ public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var prompt = File.ReadAllText("Assets/dan.txt").Trim();
var prompt = (await File.ReadAllTextAsync("Assets/dan.txt")).Trim();

var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var executor = new InstructExecutor(context);

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/InteractiveModeExecute.cs
@@ -16,7 +16,7 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/LlavaInteractiveModeExecute.cs
@@ -20,7 +20,7 @@ public static async Task Run()

var parameters = new ModelParams(modelPath);

using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);

// Llava Init
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/LoadAndSaveSession.cs
@@ -15,7 +15,7 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/LoadAndSaveState.cs
@@ -16,7 +16,7 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);

2 changes: 1 addition & 1 deletion LLama.Examples/Examples/SemanticKernelChat.cs
@@ -16,7 +16,7 @@ public static async Task Run()

// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);

var chatGPT = new LLamaSharpChatCompletion(ex);
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/SemanticKernelMemory.cs
@@ -23,7 +23,7 @@ public static async Task Run()
Embeddings = true
};

using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var embedding = new LLamaEmbedder(model, parameters);

Console.WriteLine("====================================================");
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/SemanticKernelPrompt.cs
@@ -19,7 +19,7 @@ public static async Task Run()

// Load weights into memory
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);

var builder = Kernel.CreateBuilder();
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/StatelessModeExecute.cs
@@ -14,7 +14,7 @@ public static async Task Run()
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
var ex = new StatelessExecutor(model, parameters);

Console.ForegroundColor = ConsoleColor.Yellow;
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/TalkToYourself.cs
@@ -12,7 +12,7 @@ public static async Task Run()

// Load weights into memory
var @params = new ModelParams(modelPath);
using var weights = LLamaWeights.LoadFromFile(@params);
using var weights = await LLamaWeights.LoadFromFileAsync(@params);

// Create 2 contexts sharing the same weights
using var aliceCtx = weights.CreateContext(@params);
101 changes: 101 additions & 0 deletions LLama/LLamaWeights.cs
@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;
@@ -84,6 +87,104 @@ public static LLamaWeights LoadFromFile(IModelParams @params)
return new LLamaWeights(weights);
}

/// <summary>
/// Load weights into memory
/// </summary>
/// <param name="params">Parameters to use to load the model</param>
/// <param name="token">A cancellation token that can interrupt model loading</param>
/// <param name="progressReporter">Receives progress updates as the model loads (0 to 1)</param>
/// <returns>The loaded model weights.</returns>
/// <exception cref="LoadWeightsFailedException">Thrown if the weights failed to load for any reason, e.g. an invalid file format or a cancelled load.</exception>
/// <exception cref="OperationCanceledException">Thrown if the cancellation token is cancelled.</exception>
public static async Task<LLamaWeights> LoadFromFileAsync(IModelParams @params, CancellationToken token = default, IProgress<float>? progressReporter = null)
{
// Don't touch the @params object inside the task; it might be changed
// externally! Save a copy of everything that we need later.
var modelPath = @params.ModelPath;
var loraBase = @params.LoraBase;
var loraAdapters = @params.LoraAdapters.ToArray();

// Determine the range to report for model loading. llama.cpp reports 0-1, but we'll remap that into a
// slightly smaller range to allow some space for reporting LoRA loading too.
var modelLoadProgressRange = 1f;
if (loraAdapters.Length > 0)
modelLoadProgressRange = 0.9f;

using (@params.ToLlamaModelParams(out var lparams))
{
#if !NETSTANDARD2_0
// Overwrite the progress callback with one which polls the cancellation token and updates the progress object
if (token.CanBeCanceled || progressReporter != null)
{
var internalCallback = lparams.progress_callback;
lparams.progress_callback = (progress, ctx) =>
{
// Update the progress reporter (remapping the value into the smaller range).
progressReporter?.Report(Math.Clamp(progress, 0, 1) * modelLoadProgressRange);

// If the user set a callback in the model params, call that and see if we should cancel
if (internalCallback != null && !internalCallback(progress, ctx))
return false;

// Check the cancellation token
if (token.IsCancellationRequested)
return false;

return true;
};
}
#endif

var model = await Task.Run(() =>
{
try
{
// Load the model
var weights = SafeLlamaModelHandle.LoadFromFile(modelPath, lparams);

// Apply the LoRA adapters
for (var i = 0; i < loraAdapters.Length; i++)
{
// Interrupt applying LoRAs if the token is cancelled
if (token.IsCancellationRequested)
{
weights.Dispose();
token.ThrowIfCancellationRequested();
}

// Don't apply invalid adapters
var adapter = loraAdapters[i];
if (string.IsNullOrEmpty(adapter.Path))
continue;
if (adapter.Scale <= 0)
continue;

weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, loraBase);

// Report progress. Model loading reported progress from 0 -> 0.9; use
// the last 0.1 to represent all of the LoRA adapters being applied.
progressReporter?.Report(0.9f + (0.1f / loraAdapters.Length) * (i + 1));
}

// Update progress reporter to indicate completion
progressReporter?.Report(1);

return new LLamaWeights(weights);
}
catch (LoadWeightsFailedException)
{
// Convert a LoadWeightsFailedException into a cancellation exception if possible.
token.ThrowIfCancellationRequested();

// OK, the weights failed to load for some reason other than cancellation.
throw;
}
}, token);

return model;
}
}

/// <inheritdoc />
public void Dispose()
{
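A minimal usage sketch of the new LoadFromFileAsync API (illustrative only; the model path, the five minute timeout, and the console reporting are assumptions, not part of this diff):

var parameters = new ModelParams("models/demo.gguf"); // hypothetical path
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5));
var progress = new Progress<float>(p => Console.WriteLine($"Loading: {p:P0}"));

// Progress runs from 0 to 1. When LoRA adapters are configured, the model load
// is remapped to 0 -> 0.9 and adapter application fills the remaining 0.1.
// Cancelling the token surfaces as an OperationCanceledException.
using var model = await LLamaWeights.LoadFromFileAsync(parameters, cts.Token, progress);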
2 changes: 2 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -8,6 +8,8 @@ namespace LLama.Native
/// </summary>
/// <param name="progress"></param>
/// <param name="ctx"></param>
/// <returns>If the provided progress_callback returns true, model loading continues.
/// If it returns false, model loading is immediately aborted.</returns>
/// <remarks>llama_progress_callback</remarks>
public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);

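A sketch of this contract in use, assuming a spot where an LLamaModelParams value is filled in by hand on a non-netstandard2.0 target (where progress_callback is the managed delegate type; the field is made nullable in the next file, so leaving it unset is also valid):

var abort = false; // hypothetical flag, flipped elsewhere to stop the load

// Returning true lets llama.cpp keep loading; returning false aborts the load
// immediately, which callers observe as a LoadWeightsFailedException.
lparams.progress_callback = (progress, ctx) =>
{
    Console.WriteLine($"load progress: {progress:P0}");
    return !abort;
};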
2 changes: 1 addition & 1 deletion LLama/Native/LLamaModelParams.cs
@@ -38,7 +38,7 @@ public unsafe struct LLamaModelParams
// as NET Framework 4.8 does not play nice with the LlamaProgressCallback type
public IntPtr progress_callback;
#else
public LlamaProgressCallback progress_callback;
public LlamaProgressCallback? progress_callback;
#endif

/// <summary>
7 changes: 5 additions & 2 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -120,8 +120,11 @@ public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
if (!fs.CanRead)
throw new InvalidOperationException($"Model file '{modelPath}' is not readable");

return llama_load_model_from_file(modelPath, lparams)
?? throw new LoadWeightsFailedException(modelPath);
var handle = llama_load_model_from_file(modelPath, lparams);
if (handle.IsInvalid)
throw new LoadWeightsFailedException(modelPath);

return handle;
}

#region native API
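A note on the SafeLlamaModelHandle change above: when a P/Invoke return type is a SafeHandle, the interop marshaller always constructs the handle object, so a failed llama_load_model_from_file comes back wrapping IntPtr.Zero rather than as a null reference. The old null-coalescing check could therefore never fire; testing IsInvalid is what actually detects a failed load.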