3 changes: 2 additions & 1 deletion LLama.Unittest/BeamTests.cs
@@ -40,7 +40,8 @@ public void BasicBeam()

var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
context.Eval(initial_tokens.AsSpan(), 0);
//context.Eval(initial_tokens.AsSpan(), 0);
throw new NotImplementedException("Replace Eval");

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
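For reference, a minimal sketch (not part of this commit, which deliberately leaves a `NotImplementedException`) of how the commented-out `Eval` call could be replaced with the batch API introduced later in this PR. It assumes `Tokenize` returns an array and that the short test prompt fits inside a single batch:

```csharp
// Hypothetical replacement for the removed Eval call in BasicBeam().
// Assumes all prompt tokens fit in one batch (reasonable for a short test prompt).
var batch = new LLamaBatch();
for (var i = 0; i < initial_tokens.Length; i++)
{
    // Only request logits for the final prompt token.
    batch.Add(initial_tokens[i], i, LLamaSeqId.Zero, i == initial_tokens.Length - 1);
}

var decodeResult = context.Decode(batch);
if (decodeResult != DecodeResult.Ok)
    throw new LLamaDecodeError(decodeResult);
```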
61 changes: 0 additions & 61 deletions LLama/LLamaContext.cs
@@ -370,67 +370,6 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use Decode() instead")]
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
return Eval(span, pastTokensCount);
#else
// on netstandard2.0 we can't use CollectionsMarshal to get directly at the internal memory of
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
}
#endif
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use Decode() instead")]
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += (int)Params.BatchSize)
{
var n_eval = total - i;
if (n_eval > Params.BatchSize)
{
n_eval = (int)Params.BatchSize;
}

if (!NativeHandle.Eval(tokens.Slice(i, n_eval), pastTokensCount))
{
_logger?.LogError("[LLamaContext] Failed to eval.");
throw new RuntimeError("Failed to eval.");
}

pastTokensCount += n_eval;
}
return pastTokensCount;
}
#endregion

/// <inheritdoc />
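For callers outside this repository that used the removed `Eval` overloads, here is a hedged migration sketch built only on the `LLamaBatch` and `LLamaContext.Decode` surface shown in this PR. `EvalChunked` is an illustrative name, not an API added by this change, and it assumes `Context.Params.BatchSize` is readable by the caller:

```csharp
using System;
using System.Collections.Generic;
using LLama;
using LLama.Exceptions;
using LLama.Native;

public static class EvalMigrationExample
{
    /// <summary>
    /// Illustrative stand-in for the removed LLamaContext.Eval: feeds tokens to the
    /// model in chunks no larger than the configured batch size and returns the
    /// updated past-token count.
    /// </summary>
    public static int EvalChunked(LLamaContext context, IReadOnlyList<LLamaToken> tokens, int pastTokensCount)
    {
        var batchSize = checked((int)context.Params.BatchSize);
        var batch = new LLamaBatch();

        for (var i = 0; i < tokens.Count; i += batchSize)
        {
            var count = Math.Min(batchSize, tokens.Count - i);

            batch.Clear();
            for (var j = 0; j < count; j++)
            {
                // Request logits only for the very last token of the whole sequence.
                batch.Add(tokens[i + j], pastTokensCount++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
            }

            var result = context.Decode(batch);
            if (result != DecodeResult.Ok)
                throw new LLamaDecodeError(result);
        }

        return pastTokensCount;
    }
}
```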
12 changes: 9 additions & 3 deletions LLama/LLamaInstructExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

@@ -178,6 +179,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
/// <inheritdoc />
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

if (_embeds.Count > 0)
{
_is_prompt_run = false;
@@ -187,7 +190,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (result != DecodeResult.Ok)
throw new LLamaDecodeError(result);

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -212,12 +218,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
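The key detail in both executor changes is that sampling now reads logits explicitly for the final token of the batch that was just decoded, via `GetLogitsIth(batch.TokenCount - 1)`, instead of the previous `GetLogits()` call. A condensed view of the new decode-then-sample flow, using the identifiers from the diff above:

```csharp
// Decode the pending tokens in batch-size chunks, tracking the past-token count.
var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);

// Sample from the logits of the last token in the batch that was just decoded.
var id = inferenceParams.SamplingPipeline.Sample(
    Context.NativeHandle,
    Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1),
    _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
```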
12 changes: 9 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -8,6 +8,7 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Extensions;
using Microsoft.Extensions.Logging;

@@ -157,6 +158,8 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
/// <inheritdoc />
protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

if (_embeds.Count > 0)
{
_is_prompt_run = false;
@@ -166,7 +169,10 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
if (result != DecodeResult.Ok)
throw new LLamaDecodeError(result);

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -191,12 +197,12 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(batch.TokenCount - 1), _last_n_tokens.ToArray());
inferenceParams.SamplingPipeline.Accept(Context.NativeHandle, id);
}
else
{
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(batch.TokenCount - 1, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
18 changes: 3 additions & 15 deletions LLama/LLamaStatelessExecutor.cs
@@ -81,21 +81,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams

// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

_batch.Clear();
for (var j = 0; j < n_eval; j++)
_batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);

var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
if (r != DecodeResult.Ok)
throw new LLamaDecodeError(r);

// Begin loop, evaluating one token at a time
var mu = (float?)null;
16 changes: 1 addition & 15 deletions LLama/Native/NativeApi.cs
@@ -141,20 +141,6 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);

/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// tokens + n_tokens is the provided batch of new tokens to process
/// n_past is the number of tokens to use from previous eval calls
/// </summary>
/// <param name="ctx"></param>
/// <param name="tokens"></param>
/// <param name="n_tokens"></param>
/// <param name="n_past"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[Obsolete("use llama_decode() instead")]
public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token);

@@ -181,7 +167,7 @@ public static void llama_empty_call()
public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);

/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
67 changes: 44 additions & 23 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using LLama.Exceptions;

namespace LLama.Native
@@ -28,6 +30,11 @@ public sealed class SafeLLamaContextHandle
/// </summary>
public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize;

/// <summary>
/// Get the maximum batch size for this context
/// </summary>
public uint BatchSize => NativeApi.llama_n_batch(this);

/// <summary>
/// Get the model which this context is using
/// </summary>
@@ -108,7 +115,7 @@ static SafeLLamaContextHandle()
#endregion

/// <summary>
/// Token logits obtained from the last call to llama_eval()
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
@@ -170,26 +177,6 @@ public uint TokenToSpan(LLamaToken token, Span<byte> dest)
#endregion

#region infer
/// <summary>
/// Run the llama inference to obtain the logits and probabilities for the next token.
/// </summary>
/// <param name="tokens">The provided batch of new tokens to process</param>
/// <param name="n_past">the number of tokens to use from previous eval calls</param>
/// <returns>Returns true on success</returns>
[Obsolete("use llama_decode() instead")]
public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
{
unsafe
{
fixed (LLamaToken* pinned = tokens)
{
// the entire `eval` system needs replacing with the new batch system!
var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past);
return ret == 0;
}
}
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -198,10 +185,44 @@ public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatch batch)
public DecodeResult Decode(LLamaBatch batch)
{
using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb);
return (DecodeResult)NativeApi.llama_decode(this, nb);
}

/// <summary>
/// Decode a set of tokens in batch-size chunks.
/// </summary>
/// <param name="tokens"></param>
/// <param name="id"></param>
/// <param name="batch"></param>
/// <param name="n_past"></param>
/// <returns>A tuple, containing the decode result and the number of tokens that have <b>not</b> been decoded yet.</returns>
internal (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
{
var batchSize = checked((int)BatchSize);

// Evaluate the prompt, in chunks smaller than the max batch size
var n_left = tokens.Count;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

batch.Clear();
for (var j = 0; j < n_eval; j++)
batch.Add(tokens[i + j], n_past++, id, (i + j) == tokens.Count - 1);

var returnCode = Decode(batch);
if (returnCode != DecodeResult.Ok)
return (returnCode, n_left);

n_left -= n_eval;
}

return (DecodeResult.Ok, 0);
}
#endregion

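A short usage sketch of the new chunked `Decode` helper. It is internal, so this simply mirrors how `LLamaStatelessExecutor` calls it above; the `prompt`/`tokens` variable names are illustrative, and `Tokenize` is assumed to return an array of tokens:

```csharp
// Feed a tokenized prompt through the model in batch-size chunks.
// On failure, the second tuple element reports how many tokens were still undecoded.
var batch = new LLamaBatch();
var n_past = 0;
var tokens = new List<LLamaToken>(Context.Tokenize(prompt));

var (result, remaining) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, batch, ref n_past);
if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);

// On success n_past equals tokens.Count and the logits for the final prompt token
// are available via GetLogitsIth(batch.TokenCount - 1).
```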