2 changes: 1 addition & 1 deletion LLama.Examples/NewVersion/CodingAssistant.cs
@@ -31,7 +31,7 @@ public static async Task Run()
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
-var executor = new InstructExecutor(context, null!, InstructionPrefix, InstructionSuffix);
+var executor = new InstructExecutor(context, InstructionPrefix, InstructionSuffix, null);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The executor has been enabled. In this example, the LLM will follow your instructions." +
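The call-site update above reflects the constructor's new parameter order: the instruction prefix and suffix now precede the optional logger. Below is a minimal sketch of wiring a logger through the new signature; the model path is a placeholder, and ModelParams (LLama.Common) plus the AddConsole provider (Microsoft.Extensions.Logging.Console) are assumptions, not part of this diff.

using LLama;
using LLama.Common;
using Microsoft.Extensions.Logging;

// Placeholder path; ModelParams is assumed to supply both model and context parameters.
var parameters = new ModelParams("path/to/model.gguf");

// Any Microsoft.Extensions.Logging provider works; AddConsole is just one option.
using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
var logger = loggerFactory.CreateLogger<InstructExecutor>();

using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters, logger);

// The logger now comes last and is optional; passing it by name keeps the default prefix/suffix.
var executor = new InstructExecutor(context, logger: logger);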
2 changes: 1 addition & 1 deletion LLama.Web/Models/LLamaModel.cs
@@ -58,7 +58,7 @@ public Task<LLamaContext> CreateContext(string contextName)
if (_config.MaxInstances > -1 && ContextCount >= _config.MaxInstances)
throw new Exception($"Maximum model instances reached");

-context = _weights.CreateContext(_config);
+context = _weights.CreateContext(_config, _llamaLogger);
if (_contexts.TryAdd(contextName, context))
return Task.FromResult(context);

17 changes: 10 additions & 7 deletions LLama/LLamaEmbedder.cs
@@ -2,6 +2,7 @@
using System;
using LLama.Exceptions;
using LLama.Abstractions;
+using Microsoft.Extensions.Logging;

namespace LLama
{
@@ -22,9 +23,10 @@ public sealed class LLamaEmbedder
/// Create a new embedder (loading temporary weights)
/// </summary>
/// <param name="allParams"></param>
+/// <param name="logger"></param>
[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
-public LLamaEmbedder(ILLamaParams allParams)
-: this(allParams, allParams)
+public LLamaEmbedder(ILLamaParams allParams, ILogger? logger = null)
+: this(allParams, allParams, logger)
{
}

@@ -33,24 +35,26 @@ public LLamaEmbedder(ILLamaParams allParams)
/// </summary>
/// <param name="modelParams"></param>
/// <param name="contextParams"></param>
+/// <param name="logger"></param>
[Obsolete("Preload LLamaWeights and use the constructor which accepts them")]
-public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams)
+public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams, ILogger? logger = null)
{
using var weights = LLamaWeights.LoadFromFile(modelParams);

contextParams.EmbeddingMode = true;
-_ctx = weights.CreateContext(contextParams);
+_ctx = weights.CreateContext(contextParams, logger);
}

/// <summary>
/// Create a new embedder, using the given LLamaWeights
/// </summary>
/// <param name="weights"></param>
/// <param name="params"></param>
-public LLamaEmbedder(LLamaWeights weights, IContextParams @params)
+/// <param name="logger"></param>
+public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
@params.EmbeddingMode = true;
-_ctx = weights.CreateContext(@params);
+_ctx = weights.CreateContext(@params, logger);
}

/// <summary>
@@ -89,7 +93,6 @@ public float[] GetEmbeddings(string text)
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
{
-
var embed_inp_array = _ctx.Tokenize(text, addBos);

// TODO(Rinne): deal with log of prompt
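The non-obsolete embedder constructor now threads an optional ILogger through to the context it creates. A minimal usage sketch follows; the model path is a placeholder, ModelParams is assumed to implement IContextParams, and NullLogger comes from Microsoft.Extensions.Logging.Abstractions.

using LLama;
using LLama.Common;
using Microsoft.Extensions.Logging.Abstractions;

// Placeholder path; ModelParams is assumed to implement IContextParams.
var @params = new ModelParams("path/to/model.gguf");

using var weights = LLamaWeights.LoadFromFile(@params);

// The logger parameter is optional; NullLogger.Instance makes the "no logging" choice explicit.
var embedder = new LLamaEmbedder(weights, @params, NullLogger.Instance);

// EmbeddingMode is switched on by the constructor before the context is created.
float[] embedding = embedder.GetEmbeddings("Hello, world.");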
3 changes: 2 additions & 1 deletion LLama/LLamaExecutorBase.cs
@@ -75,8 +75,9 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// </summary>
/// <param name="context"></param>
/// <param name="logger"></param>
-protected StatefulExecutorBase(LLamaContext context)
+protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
{
+_logger = logger;
Context = context;
_pastTokensCount = 0;
_consumedTokensCount = 0;
12 changes: 8 additions & 4 deletions LLama/LLamaInstructExecutor.cs
@@ -17,7 +17,8 @@ namespace LLama
/// <summary>
/// The LLama executor for instruct mode.
/// </summary>
-public class InstructExecutor : StatefulExecutorBase
+public class InstructExecutor
+: StatefulExecutorBase
{
private bool _is_prompt_run = true;
private readonly string _instructionPrefix;
@@ -28,11 +29,14 @@ public class InstructExecutor : StatefulExecutorBase
///
/// </summary>
/// <param name="context"></param>
-/// <param name="logger"></param>
/// <param name="instructionPrefix"></param>
/// <param name="instructionSuffix"></param>
-public InstructExecutor(LLamaContext context, ILogger logger = null!, string instructionPrefix = "\n\n### Instruction:\n\n",
-string instructionSuffix = "\n\n### Response:\n\n") : base(context)
+/// <param name="logger"></param>
+public InstructExecutor(LLamaContext context,
+string instructionPrefix = "\n\n### Instruction:\n\n",
+string instructionSuffix = "\n\n### Response:\n\n",
+ILogger? logger = null)
+: base(context, logger)
{
_inp_pfx = Context.Tokenize(instructionPrefix, true);
_inp_sfx = Context.Tokenize(instructionSuffix, false);
3 changes: 2 additions & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -27,7 +27,8 @@ public class InteractiveExecutor : StatefulExecutorBase
/// </summary>
/// <param name="context"></param>
/// <param name="logger"></param>
-public InteractiveExecutor(LLamaContext context) : base(context)
+public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
+: base(context, logger)
{
_llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
}
11 changes: 6 additions & 5 deletions LLama/LLamaStatelessExecutor.cs
@@ -21,9 +21,9 @@ namespace LLama
public class StatelessExecutor
: ILLamaExecutor
{
+private readonly ILogger? _logger;
private readonly LLamaWeights _weights;
private readonly IContextParams _params;
-private readonly ILogger? _logger;

/// <summary>
/// The context used by the executor when running the inference.
@@ -36,24 +36,25 @@ public class StatelessExecutor
/// <param name="weights"></param>
/// <param name="params"></param>
/// <param name="logger"></param>
-public StatelessExecutor(LLamaWeights weights, IContextParams @params)
+public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
{
_weights = weights;
_params = @params;
+_logger = logger;

-Context = _weights.CreateContext(_params);
+Context = _weights.CreateContext(_params, logger);
Context.Dispose();
}

/// <inheritdoc />
public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-using var context = _weights.CreateContext(_params);
+using var context = _weights.CreateContext(_params, _logger);
Context = context;

if (!Context.NativeHandle.IsClosed)
Context.Dispose();
-Context = _weights.CreateContext(Context.Params);
+Context = _weights.CreateContext(Context.Params, _logger);

if (inferenceParams != null)
{
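Because the stateless executor disposes and recreates its context on every call, the logger is stored in _logger so each new context receives it. A rough usage sketch under the same assumptions as above (placeholder model path; ModelParams and the console logging provider are assumptions):

using System;
using LLama;
using LLama.Common;
using Microsoft.Extensions.Logging;

// Placeholder path; ModelParams is assumed to supply both model and context parameters.
var @params = new ModelParams("path/to/model.gguf");

using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
var logger = loggerFactory.CreateLogger<StatelessExecutor>();

using var weights = LLamaWeights.LoadFromFile(@params);

// The executor keeps the logger and forwards it to every context it creates.
var executor = new StatelessExecutor(weights, @params, logger);

await foreach (var token in executor.InferAsync("Q: What is 2 + 2? A:"))
    Console.Write(token);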
5 changes: 3 additions & 2 deletions LLama/LLamaWeights.cs
@@ -81,10 +81,11 @@ public void Dispose()
/// Create a llama_context using this model
/// </summary>
/// <param name="params"></param>
+/// <param name="logger"></param>
/// <returns></returns>
-public LLamaContext CreateContext(IContextParams @params)
+public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null)
{
-return new LLamaContext(this, @params);
+return new LLamaContext(this, @params, logger);
}
}
}
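CreateContext is the single point that the other changed call sites converge on: the optional logger is simply forwarded to the LLamaContext constructor. A minimal sketch, with the same placeholder path and the assumption that ModelParams implements IContextParams:

using LLama;
using LLama.Common;
using Microsoft.Extensions.Logging.Abstractions;

// Placeholder path; ModelParams is assumed to implement IContextParams.
var @params = new ModelParams("path/to/model.gguf");

using var weights = LLamaWeights.LoadFromFile(@params);

// Passing a logger (or omitting it, since it defaults to null) controls context-level logging.
using var context = weights.CreateContext(@params, NullLogger.Instance);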