diff --git a/LLama.Examples/NewVersion/CodingAssistant.cs b/LLama.Examples/NewVersion/CodingAssistant.cs index 9108e01db..727dd0103 100644 --- a/LLama.Examples/NewVersion/CodingAssistant.cs +++ b/LLama.Examples/NewVersion/CodingAssistant.cs @@ -31,7 +31,7 @@ public static async Task Run() }; using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); - var executor = new InstructExecutor(context, null!, InstructionPrefix, InstructionSuffix); + var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null); Console.ForegroundColor = ConsoleColor.Yellow; Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions." + diff --git a/LLama.Web/Models/LLamaModel.cs b/LLama.Web/Models/LLamaModel.cs index 61341d422..5aedc5f5e 100644 --- a/LLama.Web/Models/LLamaModel.cs +++ b/LLama.Web/Models/LLamaModel.cs @@ -58,7 +58,7 @@ public Task CreateContext(string contextName) if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances) throw new Exception($"Maximum model instances reached"); - context = _weights.CreateContext(_config); + context = _weights.CreateContext(_config, _llamaLogger); if (_contexts.TryAdd(contextName, context)) return Task.FromResult(context); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index fde901b1b..ee23cd393 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -2,6 +2,7 @@ using System; using LLama.Exceptions; using LLama.Abstractions; +using Microsoft.Extensions.Logging; namespace LLama { @@ -22,9 +23,10 @@ public sealed class LLamaEmbedder /// Create a new embedder (loading temporary weights) /// /// + /// [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] - public LLamaEmbedder(ILLamaParams allParams) - : this(allParams, allParams) + public LLamaEmbedder(ILLamaParams allParams, ILogger? logger = null) + : this(allParams, allParams, logger) { } @@ -33,13 +35,14 @@ public LLamaEmbedder(ILLamaParams allParams) /// /// /// + /// [Obsolete("Preload LLamaWeights and use the constructor which accepts them")] - public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) + public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams, ILogger? logger = null) { using var weights = LLamaWeights.LoadFromFile(modelParams); contextParams.EmbeddingMode = true; - _ctx = weights.CreateContext(contextParams); + _ctx = weights.CreateContext(contextParams, logger); } /// @@ -47,10 +50,11 @@ public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) /// /// /// - public LLamaEmbedder(LLamaWeights weights, IContextParams @params) + /// + public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { @params.EmbeddingMode = true; - _ctx = weights.CreateContext(@params); + _ctx = weights.CreateContext(@params, logger); } /// @@ -89,7 +93,6 @@ public float[] GetEmbeddings(string text) /// public float[] GetEmbeddings(string text, bool addBos) { - var embed_inp_array = _ctx.Tokenize(text, addBos); // TODO(Rinne): deal with log of prompt diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 0c8e46793..1a12c6b2d 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -75,8 +75,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// /// /// - protected StatefulExecutorBase(LLamaContext context) + protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) { + _logger = logger; Context = context; _pastTokensCount = 0; _consumedTokensCount = 0; diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index c7cb55fe2..9e4292ea6 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -17,7 +17,8 @@ namespace LLama /// /// The LLama executor for instruct mode. /// - public class InstructExecutor : StatefulExecutorBase + public class InstructExecutor + : StatefulExecutorBase { private bool _is_prompt_run = true; private readonly string _instructionPrefix; @@ -28,11 +29,14 @@ public class InstructExecutor : StatefulExecutorBase /// /// /// - /// /// /// - public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n", - string instructionSuffix = "\n\n### Response:\n\n") : base(context) + /// + public InstructExecutor(LLamaContext context, + string instructionPrefix = "\n\n### Instruction:\n\n", + string instructionSuffix = "\n\n### Response:\n\n", + ILogger? logger = null) + : base(context, logger) { _inp_pfx = Context.Tokenize(instructionPrefix, true); _inp_sfx = Context.Tokenize(instructionSuffix, false); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 8247ca108..d3d4a9e39 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -27,7 +27,8 @@ public class InteractiveExecutor : StatefulExecutorBase /// /// /// - public InteractiveExecutor(LLamaContext context) : base(context) + public InteractiveExecutor(LLamaContext context, ILogger? logger = null) + : base(context, logger) { _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index d1b73c2fb..80488b712 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -21,9 +21,9 @@ namespace LLama public class StatelessExecutor : ILLamaExecutor { - private readonly ILogger? _logger; private readonly LLamaWeights _weights; private readonly IContextParams _params; + private readonly ILogger? _logger; /// /// The context used by the executor when running the inference. @@ -36,24 +36,25 @@ public class StatelessExecutor /// /// /// - public StatelessExecutor(LLamaWeights weights, IContextParams @params) + public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { _weights = weights; _params = @params; + _logger = logger; - Context = _weights.CreateContext(_params); + Context = _weights.CreateContext(_params, logger); Context.Dispose(); } /// public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using var context = _weights.CreateContext(_params); + using var context = _weights.CreateContext(_params, _logger); Context = context; if (!Context.NativeHandle.IsClosed) Context.Dispose(); - Context = _weights.CreateContext(Context.Params); + Context = _weights.CreateContext(Context.Params, _logger); if (inferenceParams != null) { diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 5dc2024da..64878e2ab 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -81,10 +81,11 @@ public void Dispose() /// Create a llama_context using this model /// /// + /// /// - public LLamaContext CreateContext(IContextParams @params) + public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null) { - return new LLamaContext(this, @params); + return new LLamaContext(this, @params, logger); } } }