3 changes: 1 addition & 2 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -1,6 +1,5 @@
using System.Diagnostics;
using System.Text;
- using LLama.Abstractions;
using LLama.Common;
using LLama.Native;

@@ -94,7 +93,7 @@ public static async Task Run()
var n_cur = batch.NativeBatch.n_tokens;
var n_decode = 0;

- var streams = new List<int>[n_parallel];
+ var streams = new List<LLamaToken>[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new();

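The batched decoding example now keeps one list of `LLamaToken` per parallel sequence instead of raw ints. A minimal sketch of turning such streams back into text, assuming only members visible in this diff (the obsolete `DeTokenize` overload in `LLamaContext.cs` below; a long-lived `StreamingTokenDecoder` is the recommended replacement):

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Native;

static class StreamSketch
{
    // Hedged sketch, not code from the PR: convert each per-sequence stream of
    // LLamaToken back into text via the obsolete DeTokenize helper shown below.
    public static string[] DecodeStreams(LLamaContext context, List<LLamaToken>[] streams)
    {
        var results = new string[streams.Length];
        for (var i = 0; i < streams.Length; i++)
            results[i] = context.DeTokenize(streams[i]);
        return results;
    }
}
```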
7 changes: 4 additions & 3 deletions LLama.Unittest/LLamaContextTests.cs
@@ -1,4 +1,5 @@
using LLama.Common;
+ using LLama.Native;

namespace LLama.Unittest
{
@@ -37,23 +38,23 @@ public void Tokenize()
{
var tokens = _context.Tokenize("The quick brown fox", true);

- Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
+ Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
public void TokenizeWithoutBOS()
{
var tokens = _context.Tokenize("The quick brown fox", false);

- Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
+ Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens);
}

[Fact]
public void TokenizeEmpty()
{
var tokens = _context.Tokenize("", false);

- Assert.Equal(Array.Empty<int>(), tokens);
+ Assert.Equal(Array.Empty<LLamaToken>(), tokens);
}
}
}
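These assertions only compile if `LLamaToken` can be created from integer literals. The real definition ships elsewhere in this PR (in `LLama.Native`); a hypothetical sketch of a token value type with the two conversions the diff relies on (implicit from `int`, explicit back to `int`) might look like this:

```csharp
// Hypothetical sketch only; the actual LLamaToken in LLama.Native may differ.
public readonly record struct LLamaToken(int Value)
{
    // Lets `new LLamaToken[] { 1, 450, ... }` compile from int literals.
    public static implicit operator LLamaToken(int value) => new(value);

    // Lets callers index raw arrays/spans, as in `logits[(int)token]`.
    public static explicit operator int(LLamaToken token) => token.Value;

    public override string ToString() => Value.ToString();
}
```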
2 changes: 1 addition & 1 deletion LLama.Web/Common/InferenceOptions.cs
@@ -17,7 +17,7 @@ public class InferenceOptions
public int MaxTokens { get; set; } = -1;

/// <inheritdoc />
- public Dictionary<int, float>? LogitBias { get; set; } = null;
+ public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
2 changes: 1 addition & 1 deletion LLama/Abstractions/IInferenceParams.cs
@@ -24,7 +24,7 @@ public interface IInferenceParams
/// <summary>
/// logit bias for specific tokens
/// </summary>
- public Dictionary<int, float>? LogitBias { get; set; }
+ public Dictionary<LLamaToken, float>? LogitBias { get; set; }

/// <summary>
/// Sequences where the model will stop generating further tokens.
4 changes: 1 addition & 3 deletions LLama/Common/InferenceParams.cs
@@ -6,8 +6,6 @@

namespace LLama.Common
{
- using llama_token = Int32;
-
/// <summary>
/// The paramters used for inference.
/// </summary>
@@ -28,7 +26,7 @@ public record InferenceParams
/// <summary>
/// logit bias for specific tokens
/// </summary>
- public Dictionary<llama_token, float>? LogitBias { get; set; } = null;
+ public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
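`InferenceOptions`, `IInferenceParams`, and `InferenceParams` all change the `LogitBias` key type from `int` to `LLamaToken`. Thanks to the implicit conversion sketched above, callers can keep writing integer token ids; a hedged usage sketch (the token id and bias value are purely illustrative):

```csharp
using System.Collections.Generic;
using LLama.Common;
using LLama.Native;

var inferenceParams = new InferenceParams
{
    MaxTokens = 64,
    LogitBias = new Dictionary<LLamaToken, float>
    {
        // Illustrative id only; real token ids depend on the model's vocabulary.
        [13] = -10.0f,
    },
};
```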
4 changes: 2 additions & 2 deletions LLama/Extensions/IReadOnlyListExtensions.cs
@@ -38,7 +38,7 @@ internal static class IReadOnlyListExtensions
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding)
- where TTokens : IReadOnlyList<int>
+ where TTokens : IReadOnlyList<LLamaToken>
where TQueries : IReadOnlyList<string>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
@@ -79,7 +79,7 @@ internal static bool TokensEndsWithAnyString<TTokens, TQueries>(this TTokens tok
/// <returns></returns>
[Obsolete("Use an Antiprompt processor instead")]
internal static bool TokensEndsWithAnyString<TTokens>(this TTokens tokens, IList<string>? queries, SafeLlamaModelHandle model, Encoding encoding)
- where TTokens : IReadOnlyList<int>
+ where TTokens : IReadOnlyList<LLamaToken>
{
if (queries == null || queries.Count == 0 || tokens.Count == 0)
return false;
36 changes: 17 additions & 19 deletions LLama/LLamaContext.cs
@@ -15,8 +15,6 @@

namespace LLama
{
- using llama_token = Int32;
-
/// <summary>
/// A llama_context, which holds all the context required to interact with a model
/// </summary>
@@ -93,7 +91,7 @@ public void SetSeed(uint seed)
/// <param name="addBos">Whether to add a bos to the text.</param>
/// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
/// <returns></returns>
- public llama_token[] Tokenize(string text, bool addBos = true, bool special = false)
+ public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
{
return NativeHandle.Tokenize(text, addBos, special, Encoding);
}
@@ -104,7 +102,7 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa
/// <param name="tokens"></param>
/// <returns></returns>
[Obsolete("Use a `StreamingTokenDecoder` instead")]
- public string DeTokenize(IReadOnlyList<llama_token> tokens)
+ public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
{
// Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder!
// It should be kept around for the entire time you are decoding one stream of tokens.
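`Tokenize` now returns `LLamaToken[]` rather than `int[]`, and anything that still needs a raw index casts explicitly, mirroring the `logits[(int)key]` change further down this file. A hedged sketch, given an existing `LLamaContext`:

```csharp
using System;
using LLama;
using LLama.Native;

static void InspectPrompt(LLamaContext context)
{
    // Sketch only: the Tokenize signature is taken from the diff above.
    LLamaToken[] tokens = context.Tokenize("The quick brown fox", addBos: true, special: false);
    foreach (var token in tokens)
        Console.WriteLine($"token id {(int)token}");
}
```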
@@ -219,7 +217,7 @@ public void LoadState(State state)
/// <param name="pipeline">The pipeline to use to process the logits and to select a token</param>
/// <param name="lastTokens">The tokens recently returned from the model</param>
/// <returns>The selected token</returns>
- public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token> lastTokens)
+ public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan<LLamaToken> lastTokens)
{
return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens);
}
@@ -240,11 +238,11 @@ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan<llama_token>
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
- public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
- float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
- SafeLLamaGrammarHandle? grammar, float minP)
+ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+ float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+ SafeLLamaGrammarHandle? grammar, float minP)
{
- llama_token id;
+ LLamaToken id;

if (grammar != null)
{
@@ -301,7 +299,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
- public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dictionary<llama_token, float>? logitBias = null,
+ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
@@ -311,12 +309,12 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
if (logitBias is not null)
{
foreach (var (key, value) in logitBias)
- logits[key] += value;
+ logits[(int)key] += value;
}

// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
- var nl_logit = logits[nl_token];
+ var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
+ var nl_logit = logits[(int)nl_token];

// Convert logits into token candidates
var candidates_p = LLamaTokenDataArray.Create(logits);
@@ -353,7 +351,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<llama_token> lastTokens, Dic
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
- public int Eval(llama_token[] tokens, int pastTokensCount)
+ public int Eval(LLamaToken[] tokens, int pastTokensCount)
{
return Eval(tokens.AsSpan(), pastTokensCount);
}
Expand All @@ -366,7 +364,7 @@ public int Eval(llama_token[] tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
- public int Eval(List<llama_token> tokens, int pastTokensCount)
+ public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
@@ -376,15 +374,15 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
// the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't
// avoid the copying.

- var rented = System.Buffers.ArrayPool<llama_token>.Shared.Rent(tokens.Count);
+ var rented = System.Buffers.ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
try
{
tokens.CopyTo(rented, 0);
return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
- System.Buffers.ArrayPool<llama_token>.Shared.Return(rented);
+ System.Buffers.ArrayPool<LLamaToken>.Shared.Return(rented);
}
#endif
}
@@ -397,7 +395,7 @@ public int Eval(List<llama_token> tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
- public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
+ public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}
@@ -410,7 +408,7 @@ public int Eval(ReadOnlyMemory<llama_token> tokens, int pastTokensCount)
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
- public int Eval(ReadOnlySpan<llama_token> tokens, int pastTokensCount)
+ public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
for(var i = 0; i < total; i += (int)Params.BatchSize)
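Taken together, `LLamaContext` now tokenizes, penalizes, samples, and evaluates in terms of `LLamaToken`. A hedged sketch of one sampling step using only the signatures visible above; the parameter values are illustrative defaults, and the namespace of `MirostatType` is assumed:

```csharp
using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLama.Native;

static class SamplingSketch
{
    // Hedged sketch, not code from the PR.
    public static LLamaToken SampleNext(LLamaContext context,
                                        List<LLamaToken> lastTokens,
                                        Dictionary<LLamaToken, float>? logitBias)
    {
        float? mu = null;

        // Repetition penalties and logit bias, as in the ApplyPenalty diff above.
        LLamaTokenDataArray candidates = context.ApplyPenalty(
            lastTokens, logitBias,
            repeatLastTokensCount: 64, repeatPenalty: 1.1f);

        // Pick the next token; mirostat is disabled and the values are illustrative.
        return context.Sample(
            candidates, ref mu,
            temperature: 0.8f, mirostat: MirostatType.Disable,
            mirostatTau: 5.0f, mirostatEta: 0.1f,
            topK: 40, topP: 0.95f, tfsZ: 1.0f, typicalP: 1.0f,
            grammar: null, minP: 0.05f);
    }
}
```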
21 changes: 10 additions & 11 deletions LLama/LLamaExecutorBase.cs
@@ -14,7 +14,6 @@

namespace LLama
{
- using llama_token = Int32;
/// <summary>
/// The base class for stateful LLama executors.
/// </summary>
@@ -47,19 +46,19 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// <summary>
/// A container of the tokens to be processed and after processed.
/// </summary>
- protected List<llama_token> _embeds = new(); // embd
+ protected List<LLamaToken> _embeds = new(); // embd
/// <summary>
/// A container for the tokens of input.
/// </summary>
- protected List<llama_token> _embed_inps = new();
+ protected List<LLamaToken> _embed_inps = new();
/// <summary>
///
/// </summary>
- protected List<llama_token> _session_tokens = new();
+ protected List<LLamaToken> _session_tokens = new();
/// <summary>
/// The last tokens generated by the model.
/// </summary>
- protected FixedSizeQueue<llama_token> _last_n_tokens;
+ protected FixedSizeQueue<LLamaToken> _last_n_tokens;
/// <summary>
/// The context used by the executor.
/// </summary>
@@ -84,7 +83,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
- _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
+ _last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}

@@ -105,7 +104,7 @@ public StatefulExecutorBase WithSessionFile(string filename)
if (File.Exists(filename))
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
- var session_tokens = new llama_token[Context.ContextSize];
+ var session_tokens = new LLamaToken[Context.ContextSize];
if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
@@ -361,16 +360,16 @@ public class ExecutorBaseState
public string? SessionFilePath { get; set; }

[JsonPropertyName("embd")]
- public List<llama_token> Embeds { get; set; }
+ public List<LLamaToken> Embeds { get; set; }

[JsonPropertyName("embd_inps")]
- public List<llama_token> EmbedInps { get; set; }
+ public List<LLamaToken> EmbedInps { get; set; }

[JsonPropertyName("session_tokens")]
- public List<llama_token> SessionTokens { get; set; }
+ public List<LLamaToken> SessionTokens { get; set; }

[JsonPropertyName("last_n_tokens")]
- public llama_token[] LastTokens { get; set; }
+ public LLamaToken[] LastTokens { get; set; }

[JsonPropertyName("last_tokens_maximum_count")]
public int LastTokensCapacity { get; set; }
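The executor base class now stores its embeds, session cache, and last-token queue as `LLamaToken`. A hedged sketch of wiring an executor to a session file with `WithSessionFile`; the model-loading calls follow the library's usual pattern but are assumptions here, and the paths are illustrative:

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

// WithSessionFile is shown in the diff above; it loads any existing
// LLamaToken session cache from disk before inference starts.
var executor = new InteractiveExecutor(context)
    .WithSessionFile("session.bin");
```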
13 changes: 6 additions & 7 deletions LLama/LLamaInstructExecutor.cs
@@ -13,7 +13,6 @@

namespace LLama
{
- using llama_token = Int32;
/// <summary>
/// The LLama executor for instruct mode.
/// </summary>
Expand All @@ -22,8 +21,8 @@ public class InstructExecutor
{
private bool _is_prompt_run = true;
private readonly string _instructionPrefix;
- private llama_token[] _inp_pfx;
- private llama_token[] _inp_sfx;
+ private LLamaToken[] _inp_pfx;
+ private LLamaToken[] _inp_sfx;

/// <summary>
///
@@ -75,7 +74,7 @@ public override Task LoadState(ExecutorBaseState data)
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
- _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
+ _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_inp_pfx = state.InputPrefixTokens;
_inp_sfx = state.InputSuffixTokens;
_n_matching_session_tokens = state.MatchingSessionTokensCount;
@@ -210,7 +209,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
SaveSessionFile(_pathSession);
}

- llama_token id;
+ LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
@@ -266,12 +265,12 @@ public class InstructExecutorState : ExecutorBaseState
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
- public llama_token[] InputPrefixTokens { get; set; }
+ public LLamaToken[] InputPrefixTokens { get; set; }
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
- public llama_token[] InputSuffixTokens { get; set; }
+ public LLamaToken[] InputSuffixTokens { get; set; }
}
}
}
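With the instruction prefix and suffix now held as `LLamaToken[]`, driving the instruct executor is unchanged from a caller's point of view. A hedged sketch; it assumes the constructor's default instruction prefix/suffix, reuses the `context` from the previous sketch, and the prompt and parameter values are illustrative:

```csharp
using System;
using LLama;
using LLama.Common;

var executor = new InstructExecutor(context);

var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    AntiPrompts = new[] { "### Instruction:" },
};

await foreach (var piece in executor.InferAsync("Summarise what a token id is.", inferenceParams))
    Console.Write(piece);
```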
7 changes: 3 additions & 4 deletions LLama/LLamaInteractExecutor.cs
@@ -13,14 +13,13 @@

namespace LLama
{
- using llama_token = Int32;
/// <summary>
/// The LLama executor for interactive mode.
/// </summary>
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
- private readonly llama_token _llama_token_newline;
+ private readonly LLamaToken _llama_token_newline;

/// <summary>
///
@@ -63,7 +62,7 @@ public override Task LoadState(ExecutorBaseState data)
_is_prompt_run = state.IsPromptRun;
_consumedTokensCount = state.ConsumedTokensCount;
_embeds = state.Embeds;
- _last_n_tokens = new FixedSizeQueue<llama_token>(state.LastTokensCapacity, state.LastTokens);
+ _last_n_tokens = new FixedSizeQueue<LLamaToken>(state.LastTokensCapacity, state.LastTokens);
_n_matching_session_tokens = state.MatchingSessionTokensCount;
_pastTokensCount = state.PastTokensCount;
_pathSession = state.SessionFilePath;
@@ -189,7 +188,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
SaveSessionFile(_pathSession);
}

- llama_token id;
+ LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray());
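Finally, the interactive executor's newline token becomes a `LLamaToken` field, so special-token checks compare token values rather than raw ints. A hedged sketch; equality on `LLamaToken` and the `llama_token_eos` helper are assumptions here, while `llama_token_nl` appears explicitly in the `LLamaContext.cs` diff above:

```csharp
using LLama;
using LLama.Native;

static class StopTokenSketch
{
    // Hedged sketch, not code from the PR.
    public static bool IsStopToken(LLamaContext context, LLamaToken sampled)
    {
        var eos = NativeApi.llama_token_eos(context.NativeHandle.ModelHandle);
        var newline = NativeApi.llama_token_nl(context.NativeHandle.ModelHandle);

        return sampled == eos || sampled == newline;
    }
}
```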