From 42be9b136da302785c640329d8e00e9568e7556f Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Tue, 2 Jan 2024 20:47:21 +0000
Subject: [PATCH 1/3] Switched from using raw integers to a `LLamaToken` struct

---
 LLama.Examples/Examples/BatchedDecoding.cs  |  3 +-
 LLama.Unittest/LLamaContextTests.cs         |  7 ++--
 LLama.Web/Common/InferenceOptions.cs        |  2 +-
 LLama/Abstractions/IInferenceParams.cs      |  2 +-
 LLama/Common/InferenceParams.cs             |  4 +--
 LLama/Extensions/IReadOnlyListExtensions.cs |  4 +--
 LLama/LLamaContext.cs                       | 36 +++++++++----------
 LLama/LLamaExecutorBase.cs                  | 21 ++++++------
 LLama/LLamaInstructExecutor.cs              | 13 ++++---
 LLama/LLamaInteractExecutor.cs              |  7 ++--
 LLama/LLamaStatelessExecutor.cs             |  6 ++--
 LLama/LLamaWeights.cs                       |  6 ++--
 LLama/Native/LLamaBatchSafeHandle.cs        | 18 +++++-----
 LLama/Native/LLamaBeamView.cs               |  8 ++---
 LLama/Native/LLamaNativeBatch.cs            |  7 ++--
 LLama/Native/LLamaToken.cs                  | 38 +++++++++++++++++++++
 LLama/Native/LLamaTokenDataArray.cs         | 16 ++++-----
 LLama/Native/NativeApi.Grammar.cs           |  4 +--
 LLama/Native/NativeApi.Sampling.cs          | 15 ++++----
 LLama/Native/NativeApi.cs                   | 30 ++++++++--------
 LLama/Native/SafeLLamaContextHandle.cs      | 14 ++++----
 LLama/Native/SafeLLamaGrammarHandle.cs      |  2 +-
 LLama/Native/SafeLlamaModelHandle.cs        | 20 +++++------
 LLama/Native/SamplingApi.cs                 | 10 +++---
 LLama/Sampling/BaseSamplingPipeline.cs      | 14 ++++----
 LLama/Sampling/DefaultSamplingPipeline.cs   | 12 +++---
 LLama/Sampling/ISamplingPipeline.cs         |  8 ++---
 LLama/StreamingTokenDecoder.cs              | 19 +++++++++++
 28 files changed, 187 insertions(+), 159 deletions(-)
 create mode 100644 LLama/Native/LLamaToken.cs

diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 9042a60b5..306e74c13 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -1,6 +1,5 @@
 using System.Diagnostics;
 using System.Text;
-using LLama.Abstractions;
 using LLama.Common;
 using LLama.Native;
 
@@ -94,7 +93,7 @@ public static async Task Run()
         var n_cur = batch.NativeBatch.n_tokens;
         var n_decode = 0;
 
-        var streams = new List<int>[n_parallel];
+        var streams = new List<LLamaToken>[n_parallel];
         for (var i = 0; i < n_parallel; i++)
             streams[i] = new();
 
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index 2edf3a623..7bc2f7ff4 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -1,4 +1,5 @@
 using LLama.Common;
+using LLama.Native;
 
 namespace LLama.Unittest
 {
@@ -37,7 +38,7 @@ public void Tokenize()
         {
             var tokens = _context.Tokenize("The quick brown fox", true);
 
-            Assert.Equal(new[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal([ 1, 450, 4996, 17354, 1701, 29916 ], tokens);
         }
 
         [Fact]
@@ -45,7 +46,7 @@ public void TokenizeWithoutBOS()
         {
             var tokens = _context.Tokenize("The quick brown fox", false);
 
-            Assert.Equal(new[] { 450, 4996, 17354, 1701, 29916 }, tokens);
+            Assert.Equal([450, 4996, 17354, 1701, 29916], tokens);
         }
 
         [Fact]
@@ -53,7 +54,7 @@ public void TokenizeEmpty()
         {
             var tokens = _context.Tokenize("", false);
 
-            Assert.Equal(Array.Empty<int>(), tokens);
+            Assert.Equal(Array.Empty<LLamaToken>(), tokens);
         }
     }
 }
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index c604dc0d1..30c5ccebc 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -17,7 +17,7 @@ public class InferenceOptions
     public int MaxTokens { get; set; } = -1;
 
     ///
-    public Dictionary<int, float>? LogitBias { get; set; } = null;
+    public Dictionary<LLamaToken, float>? 
LogitBias { get; set; } = null; /// public IReadOnlyList AntiPrompts { get; set; } = Array.Empty(); diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index e1e894143..fd8d4189a 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -24,7 +24,7 @@ public interface IInferenceParams /// /// logit bias for specific tokens /// - public Dictionary? LogitBias { get; set; } + public Dictionary? LogitBias { get; set; } /// /// Sequences where the model will stop generating further tokens. diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index 0e6020ad4..c0a8357e7 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -6,8 +6,6 @@ namespace LLama.Common { - using llama_token = Int32; - /// /// The paramters used for inference. /// @@ -28,7 +26,7 @@ public record InferenceParams /// /// logit bias for specific tokens /// - public Dictionary? LogitBias { get; set; } = null; + public Dictionary? LogitBias { get; set; } = null; /// /// Sequences where the model will stop generating further tokens. diff --git a/LLama/Extensions/IReadOnlyListExtensions.cs b/LLama/Extensions/IReadOnlyListExtensions.cs index 7a3473b71..f23c54d65 100644 --- a/LLama/Extensions/IReadOnlyListExtensions.cs +++ b/LLama/Extensions/IReadOnlyListExtensions.cs @@ -38,7 +38,7 @@ internal static class IReadOnlyListExtensions /// [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, TQueries? queries, SafeLlamaModelHandle model, Encoding encoding) - where TTokens : IReadOnlyList + where TTokens : IReadOnlyList where TQueries : IReadOnlyList { if (queries == null || queries.Count == 0 || tokens.Count == 0) @@ -79,7 +79,7 @@ internal static bool TokensEndsWithAnyString(this TTokens tok /// [Obsolete("Use an Antiprompt processor instead")] internal static bool TokensEndsWithAnyString(this TTokens tokens, IList? queries, SafeLlamaModelHandle model, Encoding encoding) - where TTokens : IReadOnlyList + where TTokens : IReadOnlyList { if (queries == null || queries.Count == 0 || tokens.Count == 0) return false; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index abd8f879f..cd755501c 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -15,8 +15,6 @@ namespace LLama { - using llama_token = Int32; - /// /// A llama_context, which holds all the context required to interact with a model /// @@ -93,7 +91,7 @@ public void SetSeed(uint seed) /// Whether to add a bos to the text. /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// - public llama_token[] Tokenize(string text, bool addBos = true, bool special = false) + public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false) { return NativeHandle.Tokenize(text, addBos, special, Encoding); } @@ -104,7 +102,7 @@ public llama_token[] Tokenize(string text, bool addBos = true, bool special = fa /// /// [Obsolete("Use a `StreamingTokenDecoder` instead")] - public string DeTokenize(IReadOnlyList tokens) + public string DeTokenize(IReadOnlyList tokens) { // Do **not** use this method as an example of how to correctly use the StreamingTokenDecoder! // It should be kept around for the entire time you are decoding one stream of tokens. 
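[Note on the hunk above: the comment warns against treating `DeTokenize` as a usage reference. For contrast, a minimal sketch of the intended `StreamingTokenDecoder` pattern, assuming an existing `LLamaContext` named `context`; `TryGetNextToken` is a hypothetical stand-in for whatever produces the token stream:]

    // Keep one decoder alive for the whole stream, rather than decoding each
    // token in isolation: partial UTF-8 sequences that span token boundaries
    // are buffered inside the decoder instead of being lost.
    var decoder = new StreamingTokenDecoder(context);
    while (TryGetNextToken(out LLamaToken token))
    {
        decoder.Add(token);            // accumulate the bytes for this token
        Console.Write(decoder.Read()); // Read() drains decoded chars and clears the buffer
    }
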
@@ -219,7 +217,7 @@ public void LoadState(State state) /// The pipeline to use to process the logits and to select a token /// The tokens recently returned from the model /// The selected token - public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) + public LLamaToken Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) { return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); } @@ -240,11 +238,11 @@ public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan /// /// /// - public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat, - float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP, - SafeLLamaGrammarHandle? grammar, float minP) + public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat, + float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP, + SafeLLamaGrammarHandle? grammar, float minP) { - llama_token id; + LLamaToken id; if (grammar != null) { @@ -301,7 +299,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu /// /// /// - public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dictionary? logitBias = null, + public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dictionary? logitBias = null, int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, bool penalizeNL = true) { @@ -311,12 +309,12 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic if (logitBias is not null) { foreach (var (key, value) in logitBias) - logits[key] += value; + logits[key.Value] += value; } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); - var nl_logit = logits[nl_token]; + var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); + var nl_logit = logits[nl_token.Value]; // Convert logits into token candidates var candidates_p = LLamaTokenDataArray.Create(logits); @@ -353,7 +351,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic /// The updated `pastTokensCount`. /// [Obsolete("use llama_decode() instead")] - public int Eval(llama_token[] tokens, int pastTokensCount) + public int Eval(LLamaToken[] tokens, int pastTokensCount) { return Eval(tokens.AsSpan(), pastTokensCount); } @@ -366,7 +364,7 @@ public int Eval(llama_token[] tokens, int pastTokensCount) /// The updated `pastTokensCount`. /// [Obsolete("use llama_decode() instead")] - public int Eval(List tokens, int pastTokensCount) + public int Eval(List tokens, int pastTokensCount) { #if NET5_0_OR_GREATER var span = CollectionsMarshal.AsSpan(tokens); @@ -376,7 +374,7 @@ public int Eval(List tokens, int pastTokensCount) // the list. Instead rent an array and copy the data into it. This avoids an allocation, but can't // avoid the copying. - var rented = System.Buffers.ArrayPool.Shared.Rent(tokens.Count); + var rented = System.Buffers.ArrayPool.Shared.Rent(tokens.Count); try { tokens.CopyTo(rented, 0); @@ -384,7 +382,7 @@ public int Eval(List tokens, int pastTokensCount) } finally { - System.Buffers.ArrayPool.Shared.Return(rented); + System.Buffers.ArrayPool.Shared.Return(rented); } #endif } @@ -397,7 +395,7 @@ public int Eval(List tokens, int pastTokensCount) /// The updated `pastTokensCount`. 
/// [Obsolete("use llama_decode() instead")] - public int Eval(ReadOnlyMemory tokens, int pastTokensCount) + public int Eval(ReadOnlyMemory tokens, int pastTokensCount) { return Eval(tokens.Span, pastTokensCount); } @@ -410,7 +408,7 @@ public int Eval(ReadOnlyMemory tokens, int pastTokensCount) /// The updated `pastTokensCount`. /// [Obsolete("use llama_decode() instead")] - public int Eval(ReadOnlySpan tokens, int pastTokensCount) + public int Eval(ReadOnlySpan tokens, int pastTokensCount) { var total = tokens.Length; for(var i = 0; i < total; i += (int)Params.BatchSize) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index e7b768be1..4713166ea 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -14,7 +14,6 @@ namespace LLama { - using llama_token = Int32; /// /// The base class for stateful LLama executors. /// @@ -47,19 +46,19 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// /// A container of the tokens to be processed and after processed. /// - protected List _embeds = new(); // embd + protected List _embeds = new(); // embd /// /// A container for the tokens of input. /// - protected List _embed_inps = new(); + protected List _embed_inps = new(); /// /// /// - protected List _session_tokens = new(); + protected List _session_tokens = new(); /// /// The last tokens generated by the model. /// - protected FixedSizeQueue _last_n_tokens; + protected FixedSizeQueue _last_n_tokens; /// /// The context used by the executor. /// @@ -84,7 +83,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) _pastTokensCount = 0; _consumedTokensCount = 0; _n_session_consumed = 0; - _last_n_tokens = new FixedSizeQueue(Context.ContextSize); + _last_n_tokens = new FixedSizeQueue(Context.ContextSize); _decoder = new StreamingTokenDecoder(context); } @@ -105,7 +104,7 @@ public StatefulExecutorBase WithSessionFile(string filename) if (File.Exists(filename)) { _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}"); - var session_tokens = new llama_token[Context.ContextSize]; + var session_tokens = new LLamaToken[Context.ContextSize]; if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) { _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}"); @@ -361,16 +360,16 @@ public class ExecutorBaseState public string? SessionFilePath { get; set; } [JsonPropertyName("embd")] - public List Embeds { get; set; } + public List Embeds { get; set; } [JsonPropertyName("embd_inps")] - public List EmbedInps { get; set; } + public List EmbedInps { get; set; } [JsonPropertyName("session_tokens")] - public List SessionTokens { get; set; } + public List SessionTokens { get; set; } [JsonPropertyName("last_n_tokens")] - public llama_token[] LastTokens { get; set; } + public LLamaToken[] LastTokens { get; set; } [JsonPropertyName("last_tokens_maximum_count")] public int LastTokensCapacity { get; set; } diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 3ed668903..b763145eb 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -13,7 +13,6 @@ namespace LLama { - using llama_token = Int32; /// /// The LLama executor for instruct mode. 
/// @@ -22,8 +21,8 @@ public class InstructExecutor { private bool _is_prompt_run = true; private readonly string _instructionPrefix; - private llama_token[] _inp_pfx; - private llama_token[] _inp_sfx; + private LLamaToken[] _inp_pfx; + private LLamaToken[] _inp_sfx; /// /// @@ -75,7 +74,7 @@ public override Task LoadState(ExecutorBaseState data) _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; _embeds = state.Embeds; - _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); + _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _inp_pfx = state.InputPrefixTokens; _inp_sfx = state.InputSuffixTokens; _n_matching_session_tokens = state.MatchingSessionTokensCount; @@ -210,7 +209,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta SaveSessionFile(_pathSession); } - llama_token id; + LLamaToken id; if (inferenceParams.SamplingPipeline is not null) { id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); @@ -266,12 +265,12 @@ public class InstructExecutorState : ExecutorBaseState /// Instruction prefix tokens. /// [JsonPropertyName("inp_pfx")] - public llama_token[] InputPrefixTokens { get; set; } + public LLamaToken[] InputPrefixTokens { get; set; } /// /// Instruction suffix tokens. /// [JsonPropertyName("inp_sfx")] - public llama_token[] InputSuffixTokens { get; set; } + public LLamaToken[] InputSuffixTokens { get; set; } } } } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 9cecf4378..11973a273 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -13,14 +13,13 @@ namespace LLama { - using llama_token = Int32; /// /// The LLama executor for interactive mode. /// public class InteractiveExecutor : StatefulExecutorBase { private bool _is_prompt_run = true; - private readonly llama_token _llama_token_newline; + private readonly LLamaToken _llama_token_newline; /// /// @@ -63,7 +62,7 @@ public override Task LoadState(ExecutorBaseState data) _is_prompt_run = state.IsPromptRun; _consumedTokensCount = state.ConsumedTokensCount; _embeds = state.Embeds; - _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); + _last_n_tokens = new FixedSizeQueue(state.LastTokensCapacity, state.LastTokens); _n_matching_session_tokens = state.MatchingSessionTokensCount; _pastTokensCount = state.PastTokensCount; _pathSession = state.SessionFilePath; @@ -189,7 +188,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In SaveSessionFile(_pathSession); } - llama_token id; + LLamaToken id; if (inferenceParams.SamplingPipeline is not null) { id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 7922db2b2..217e4f305 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -12,8 +12,6 @@ namespace LLama { - using llama_token = Int32; - /// /// This executor infer the input as one-time job. Previous inputs won't impact on the /// response to current input. @@ -71,7 +69,7 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams // Keep track of the last N tokens emitted var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? 
_weights.ContextSize : inferenceParams.RepeatLastTokensCount); - var lastTokens = new List(repeat_last_n); + var lastTokens = new List(repeat_last_n); for (var i = 0; i < repeat_last_n; i++) lastTokens.Add(0); @@ -89,7 +87,7 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) { - llama_token id; + LLamaToken id; if (inferenceParams.SamplingPipeline is not null) { id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 5cb482add..df0eadcad 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -42,17 +42,17 @@ public sealed class LLamaWeights /// /// Get the newline token for this model /// - public int NewlineToken => NativeApi.llama_token_nl(NativeHandle); + public LLamaToken NewlineToken => NativeApi.llama_token_nl(NativeHandle); /// /// Get the "end of sentence" token for this model /// - public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); + public LLamaToken EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); /// /// Get the "beginning of sentence" token for this model /// - public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); + public LLamaToken BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); /// /// Dimension of embedding vectors diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs index 30e703946..e67a61b44 100644 --- a/LLama/Native/LLamaBatchSafeHandle.cs +++ b/LLama/Native/LLamaBatchSafeHandle.cs @@ -2,8 +2,6 @@ namespace LLama.Native; -using llama_token = Int32; - /// /// Input data for llama_decode. A llama_batch object can contain input about one or many sequences. /// @@ -20,16 +18,16 @@ public sealed class LLamaBatchSafeHandle /// /// the token ids of the input (used when embd is NULL) /// - public Span Token + public Span Token { get { unsafe { if (_embd != 0) - return new Span(null, 0); + return new Span(null, 0); else - return new Span(NativeBatch.token, NativeBatch.n_tokens); + return new Span(NativeBatch.token, NativeBatch.n_tokens); } } } @@ -37,7 +35,7 @@ public Span Token /// /// token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) /// - public Span Embed + public Span Embed { get { @@ -47,9 +45,9 @@ public Span Embed // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token if (_embd != 0) - return new Span(NativeBatch.embd, NativeBatch.n_tokens * _embd); + return new Span(NativeBatch.embd, NativeBatch.n_tokens * _embd); else - return new Span(null, 0); + return new Span(null, 0); } } } @@ -133,11 +131,11 @@ protected override bool ReleaseHandle() /// /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 /// - public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan sequences, bool logits) + public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) { unsafe { - NativeBatch.token[NativeBatch.n_tokens] = token; + NativeBatch.token[NativeBatch.n_tokens] = token.Value; NativeBatch.pos[NativeBatch.n_tokens] = pos; NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length; diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs index e6bc504ee..e832eb620 100644 --- a/LLama/Native/LLamaBeamView.cs +++ b/LLama/Native/LLamaBeamView.cs @@ -3,15 +3,13 @@ namespace LLama.Native; -using llama_token = Int32; - /// /// Information about a single beam in a beam search /// [StructLayout(LayoutKind.Sequential)] public struct LLamaBeamView { - private unsafe llama_token* tokens; + private unsafe LLamaToken* tokens; private nint n_tokens; /// @@ -27,7 +25,7 @@ public struct LLamaBeamView /// /// Tokens in this beam /// - public readonly Span Tokens + public readonly Span Tokens { get { @@ -35,7 +33,7 @@ public readonly Span Tokens { if (n_tokens > int.MaxValue) throw new InvalidOperationException("More than 2147483647 tokens is not supported"); - return new Span(tokens, (int)n_tokens); + return new Span(tokens, (int)n_tokens); } } } diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs index d46f8b99e..978e955c3 100644 --- a/LLama/Native/LLamaNativeBatch.cs +++ b/LLama/Native/LLamaNativeBatch.cs @@ -1,10 +1,7 @@ -using System; -using System.Runtime.InteropServices; +using System.Runtime.InteropServices; namespace LLama.Native; -using llama_token = Int32; - /// /// Input data for llama_decode /// A llama_batch object can contain input about one or many sequences @@ -21,7 +18,7 @@ public unsafe struct LLamaNativeBatch /// /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created /// - public llama_token* token; + public LLamaToken* token; /// /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs new file mode 100644 index 000000000..e0c18fa0b --- /dev/null +++ b/LLama/Native/LLamaToken.cs @@ -0,0 +1,38 @@ +using System.Runtime.InteropServices; + +namespace LLama.Native; + +/// +/// A single token +/// +[StructLayout(LayoutKind.Sequential)] +public readonly record struct LLamaToken +{ + /// + /// The raw value + /// + public readonly int Value; + + /// + /// Create a new LLamaToken + /// + /// + private LLamaToken(int value) + { + Value = value; + } + + /// + /// Convert a LLamaToken into an integer (extract the raw value) + /// + /// + /// + public static explicit operator int(LLamaToken pos) => pos.Value; + + /// + /// Convert an integer into a LLamaToken + /// + /// + /// + public static implicit operator LLamaToken(int value) => new(value); +} \ No 
newline at end of file diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 12f588a6c..1c02625f8 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -2,8 +2,6 @@ using System.Buffers; using System.Runtime.InteropServices; -using llama_token = System.Int32; - namespace LLama.Native { /// @@ -50,7 +48,7 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits) /// Overwrite the logit values for all given tokens /// /// tuples of token and logit value to overwrite - public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values) + public void OverwriteLogits(ReadOnlySpan<(LLamaToken token, float logit)> values) { if (values.Length == 0) return; @@ -172,13 +170,13 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke /// /// /// - public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) { unsafe { using (LLamaTokenDataArrayNative.Create(this, out var st)) { - fixed (int* last_tokens_handle = last_tokens) + fixed (LLamaToken* last_tokens_handle = last_tokens) { NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); sorted = st.sorted; @@ -220,7 +218,7 @@ public void Softmax(SafeLLamaContextHandle context) /// /// /// - public int SampleToken(SafeLLamaContextHandle context) + public LLamaToken SampleToken(SafeLLamaContextHandle context) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { @@ -235,7 +233,7 @@ public int SampleToken(SafeLLamaContextHandle context) /// /// /// - public int SampleTokenGreedy(SafeLLamaContextHandle context) + public LLamaToken SampleTokenGreedy(SafeLLamaContextHandle context) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { @@ -254,7 +252,7 @@ public int SampleTokenGreedy(SafeLLamaContextHandle context) /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// - public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) + public LLamaToken SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { @@ -272,7 +270,7 @@ public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
/// - public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) + public LLamaToken SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) { using (LLamaTokenDataArrayNative.Create(this, out var st)) { diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs index 4d47872b5..48e56b726 100644 --- a/LLama/Native/NativeApi.Grammar.cs +++ b/LLama/Native/NativeApi.Grammar.cs @@ -3,8 +3,6 @@ namespace LLama.Native { - using llama_token = Int32; - public static partial class NativeApi { /// @@ -48,6 +46,6 @@ public static partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, llama_token token); + public static extern void llama_grammar_accept_token(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar, LLamaToken token); } } diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 53a6dd233..7128441ea 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -1,10 +1,7 @@ -using System; -using System.Runtime.InteropServices; +using System.Runtime.InteropServices; namespace LLama.Native { - using llama_token = Int32; - public static partial class NativeApi { /// @@ -21,7 +18,7 @@ public static partial class NativeApi [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, - llama_token* last_tokens, ulong last_tokens_size, + LLamaToken* last_tokens, ulong last_tokens_size, float penalty_repeat, float penalty_freq, float penalty_present); @@ -115,7 +112,7 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); + public static extern LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, int m, ref float mu); /// /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. @@ -127,7 +124,7 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); + public static extern LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float tau, float eta, ref float mu); /// /// Selects the token with the highest probability. 
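[Note: the mirostat wrappers above keep `mu` as caller-owned state. A rough sketch of driving the typed API after this change, assuming `ctx` is a valid `SafeLLamaContextHandle` and `logits` is the model's output logit span; the `tau`/`eta` values are illustrative only, not recommendations:]

    var tau = 5.0f;
    var eta = 0.10f;
    var mu = 2 * tau;                                    // per the docs above: mu starts at twice the target cross-entropy
    var candidates = LLamaTokenDataArray.Create(logits); // one LLamaTokenData per vocabulary entry
    LLamaToken id = candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
    // mu must be carried over to the next sampling call, so keep it outside the loop.
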
@@ -136,7 +133,7 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont /// Pointer to LLamaTokenDataArray /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); + public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); /// /// Randomly selects a token from the candidates based on their probabilities. @@ -145,6 +142,6 @@ public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaCont /// Pointer to LLamaTokenDataArray /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); + public static extern LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates); } } diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 2a34820d6..191cef3ca 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -7,8 +7,6 @@ namespace LLama.Native { - using llama_token = Int32; - /// /// Callback from llama.cpp with log messages /// @@ -172,7 +170,7 @@ public static partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); + public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -183,7 +181,7 @@ public static partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens, ulong n_token_count); + public static extern bool llama_save_session_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count); /// /// Run the llama inference to obtain the logits and probabilities for the next token. @@ -197,7 +195,7 @@ public static partial class NativeApi /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [Obsolete("use llama_decode() instead")] - public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); + public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, LLamaToken* tokens, int n_tokens, int n_past); /// /// Convert the provided text into tokens. @@ -212,7 +210,7 @@ public static partial class NativeApi /// Returns the number of tokens on success, no more than n_max_tokens. 
/// Returns a negative number on failure - the number of tokens that would have been returned /// - public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) + public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, LLamaToken[] tokens, int n_max_tokens, bool add_bos, bool special) { unsafe { @@ -233,7 +231,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi // Do the actual tokenization fixed (byte* arrayPtr = array) - fixed (llama_token* tokensPtr = tokens) + fixed (LLamaToken* tokensPtr = tokens) return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); } finally @@ -244,13 +242,13 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi } [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token); + public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, LLamaToken token); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token); + public static extern float llama_token_get_score(SafeLlamaModelHandle model, LLamaToken token); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, llama_token token); + public static extern LLamaTokenType llama_token_get_type(SafeLlamaModelHandle model, LLamaToken token); /// /// Get the size of the context window for the model for this context @@ -303,21 +301,21 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_bos(SafeLlamaModelHandle model); + public static extern LLamaToken llama_token_bos(SafeLlamaModelHandle model); /// /// Get the "End of sentence" token /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_eos(SafeLlamaModelHandle model); + public static extern LLamaToken llama_token_eos(SafeLlamaModelHandle model); /// /// Get the "new line" token /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_nl(SafeLlamaModelHandle model); + public static extern LLamaToken llama_token_nl(SafeLlamaModelHandle model); /// /// Returns -1 if unknown, 1 for true or 0 for false. 
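[Note: because `llama_token_bos`/`llama_token_eos`/`llama_token_nl` above now return `LLamaToken` rather than a raw `int`, end-of-generation checks compare tokens to tokens. A small sketch, assuming a loaded `SafeLlamaModelHandle` named `model` and a freshly sampled token `id`:]

    LLamaToken eos = NativeApi.llama_token_eos(model);
    if (id == eos)  // LLamaToken is a readonly record struct, so == is value equality
        return;     // stop generating: the model emitted "end of sentence"
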
@@ -508,7 +506,7 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, /// /// buffer to write string into /// The length written, or if the buffer is too small a negative that indicates the length required - public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span buffer) + public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span buffer) { unsafe { @@ -519,7 +517,7 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token l } [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")] - static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length); + static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length); } /// @@ -536,7 +534,7 @@ public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token l /// Returns a negative number on failure - the number of tokens that would have been returned /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); + public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, bool add_bos, bool special); /// /// Register a callback to receive llama log messages diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 98b510783..47027ae0d 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -137,19 +137,19 @@ public Span GetLogitsIth(int i) /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// /// - public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) + public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { ThrowIfDisposed(); if (string.IsNullOrEmpty(text) && !add_bos) - return Array.Empty(); + return Array.Empty(); // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't // possibly be more than this. var count = encoding.GetByteCount(text) + (add_bos ? 
1 : 0); // "Rent" an array to write results into (avoiding an allocation of a large array) - var temporaryArray = ArrayPool.Shared.Rent(count); + var temporaryArray = ArrayPool.Shared.Rent(count); try { // Do the actual conversion @@ -161,14 +161,14 @@ public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding } // Copy the results from the rented into an array which is exactly the right size - var result = new int[n]; + var result = new LLamaToken[n]; Array.ConstrainedCopy(temporaryArray, 0, result, 0, n); return result; } finally { - ArrayPool.Shared.Return(temporaryArray); + ArrayPool.Shared.Return(temporaryArray); } } @@ -191,11 +191,11 @@ public int TokenToSpan(int token, Span dest) /// the number of tokens to use from previous eval calls /// Returns true on success [Obsolete("use llama_decode() instead")] - public bool Eval(ReadOnlySpan tokens, int n_past) + public bool Eval(ReadOnlySpan tokens, int n_past) { unsafe { - fixed (int* pinned = tokens) + fixed (LLamaToken* pinned = tokens) { // the entire `eval` system needs replacing with the new batch system! var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past); diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs index 49096d44f..6277e5941 100644 --- a/LLama/Native/SafeLLamaGrammarHandle.cs +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -119,7 +119,7 @@ public SafeLLamaGrammarHandle Clone() /// /// /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) + public void AcceptToken(SafeLLamaContextHandle ctx, LLamaToken token) { NativeApi.llama_grammar_accept_token(ctx, this, token); } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 291cfbc20..30ab5cd09 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -98,12 +98,12 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null /// /// Convert a single llama token into bytes /// - /// Token to decode + /// Token to decode /// A span to attempt to write into. If this is too small nothing will be written /// The size of this token. **nothing will be written** if this is larger than `dest` - public int TokenToSpan(int llama_token, Span dest) + public int TokenToSpan(LLamaToken token, Span dest) { - var length = NativeApi.llama_token_to_piece(this, llama_token, dest); + var length = NativeApi.llama_token_to_piece(this, token, dest); return Math.Abs(length); } @@ -118,12 +118,10 @@ public int TokenToSpan(int llama_token, Span dest) /// filled with as many characters as possible, starting from the _last_ token. /// [Obsolete("Use a StreamingTokenDecoder instead")] - internal Span TokensToSpan(IReadOnlyList tokens, Span dest, Encoding encoding) + internal Span TokensToSpan(IReadOnlyList tokens, Span dest, Encoding encoding) { var decoder = new StreamingTokenDecoder(encoding, this); - - foreach (var token in tokens) - decoder.Add(token); + decoder.AddRange(tokens); var str = decoder.Read(); @@ -147,7 +145,7 @@ internal Span TokensToSpan(IReadOnlyList tokens, Span dest, Enc /// /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. 
/// - public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) + public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); @@ -166,11 +164,11 @@ public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding fixed (byte* bytesPtr = &bytes[0]) { // Tokenize once with no output, to get the token count. Output will be negative (indicating that there was insufficient space) - var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special); + var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (LLamaToken*)IntPtr.Zero, 0, add_bos, special); // Tokenize again, this time outputting into an array of exactly the right size - var tokens = new int[count]; - fixed (int* tokensPtr = &tokens[0]) + var tokens = new LLamaToken[count]; + fixed (LLamaToken* tokensPtr = &tokens[0]) { NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); return tokens; diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index 41709def3..c211d0258 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -4,8 +4,6 @@ namespace LLama.Native { - using llama_token = Int32; - /// /// Direct translation of the llama.cpp sampling API /// @@ -110,7 +108,7 @@ public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTok /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")] - public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) + public static LLamaToken llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) { return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu); } @@ -125,7 +123,7 @@ public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
/// [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")] - public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) + public static LLamaToken llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) { return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu); } @@ -137,7 +135,7 @@ public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle /// Pointer to LLamaTokenDataArray /// [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")] - public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + public static LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { return candidates.SampleTokenGreedy(ctx); } @@ -149,7 +147,7 @@ public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, /// Pointer to LLamaTokenDataArray /// [Obsolete("use LLamaTokenDataArray SampleToken() method")] - public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + public static LLamaToken llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { return candidates.SampleToken(ctx); } diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index 4c0f7689f..436c35d35 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -15,7 +15,7 @@ public abstract class BaseSamplingPipeline private (int index, float logit)[]? _savedLogits; /// - public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { var protectedLogits = GetProtectedTokens(ctx); _savedLogitsCount = protectedLogits.Count; @@ -26,8 +26,8 @@ public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan logits, ReadOnlySpan /// - protected abstract IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx); + protected abstract IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx); /// /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits @@ -96,7 +96,7 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) /// The context being sampled from /// The logits produced by the model /// A list of tokens recently returned by the model - protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); /// /// Process the LLamaTokenDataArray and select a single token @@ -105,7 +105,7 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) /// The LLamaTokenDataArray data produced by the model /// A list of tokens recently returned by the model /// - protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); + protected abstract LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); /// /// Choose the final token from the candidates @@ -113,7 +113,7 @@ protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) /// /// /// - protected abstract int ChooseToken(SafeLLamaContextHandle ctx, 
LLamaTokenDataArray candidates); + protected abstract LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); /// public virtual void Reset() diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs index e6db2efe3..b0fb5c596 100644 --- a/LLama/Sampling/DefaultSamplingPipeline.cs +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -99,27 +99,27 @@ public float AlphaPresence /// public bool PenalizeNewline { get; set; } = false; - private readonly int[] _newlineToken = new int[1]; + private readonly LLamaToken[] _newlineToken = new LLamaToken[1]; /// - protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) + protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) { if (PenalizeNewline) - return Array.Empty(); + return Array.Empty(); _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); return _newlineToken; } /// - protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { foreach (var (key, value) in LogitBias) logits[key] += value; } /// - protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) { // Apply penalties to candidates candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); @@ -142,7 +142,7 @@ protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTo } /// - protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + protected override LLamaToken ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { return candidates.SampleToken(ctx); } diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index f39bf9963..be1398790 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -19,7 +19,7 @@ public interface ISamplingPipeline /// The logits produced by the model /// A span of tokens recently returned by the model /// - int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); /// /// Reset all internal state of the sampling pipeline @@ -40,13 +40,13 @@ public static class ISamplingPipelineExtensions /// The logits produced by the model /// A list of tokens recently returned by the model /// - public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) + public static LLamaToken Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) { #if NET5_0_OR_GREATER var span = CollectionsMarshal.AsSpan(lastTokens); return pipeline.Sample(ctx, logits, span); #else - var copy = ArrayPool.Shared.Rent(lastTokens.Count); + var copy = ArrayPool.Shared.Rent(lastTokens.Count); try { lastTokens.CopyTo(copy); @@ -54,7 +54,7 @@ public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle } finally { - ArrayPool.Shared.Return(copy); + ArrayPool.Shared.Return(copy); } #endif } diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs index f82f8c37f..22ead9af0 100644 --- 
a/LLama/StreamingTokenDecoder.cs +++ b/LLama/StreamingTokenDecoder.cs @@ -129,6 +129,15 @@ static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle } } + /// + /// Add a single token to the decoder + /// + /// + public void Add(LLamaToken token) + { + Add((int)token); + } + /// /// Add all tokens in the given enumerable /// @@ -139,6 +148,16 @@ public void AddRange(IEnumerable tokens) Add(item); } + /// + /// Add all tokens in the given enumerable + /// + /// + public void AddRange(IEnumerable tokens) + { + foreach (var item in tokens) + Add((int)item); + } + /// /// Read all decoded characters and clear the buffer /// From 2eb52b1630a624ae52acb2a10613f01e982f2112 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 20:57:37 +0000 Subject: [PATCH 2/3] made casts to/from int explicit, fixed places affected --- LLama/LLamaContext.cs | 4 ++-- LLama/LLamaStatelessExecutor.cs | 2 +- LLama/Native/LLamaBatchSafeHandle.cs | 2 +- LLama/Native/LLamaToken.cs | 4 ++-- LLama/Native/LLamaTokenData.cs | 4 ++-- LLama/Native/LLamaTokenDataArray.cs | 2 +- LLama/Native/SafeLLamaContextHandle.cs | 2 +- LLama/Sampling/BaseSamplingPipeline.cs | 12 ++++++------ LLama/StreamingTokenDecoder.cs | 8 ++++---- 9 files changed, 20 insertions(+), 20 deletions(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index cd755501c..dd3d081a4 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -309,12 +309,12 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dict if (logitBias is not null) { foreach (var (key, value) in logitBias) - logits[key.Value] += value; + logits[(int)key] += value; } // Save the newline logit value var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); - var nl_logit = logits[nl_token.Value]; + var nl_logit = logits[(int)nl_token]; // Convert logits into token candidates var candidates_p = LLamaTokenDataArray.Create(logits); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 217e4f305..77c9dbe4f 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -71,7 +71,7 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? 
_weights.ContextSize : inferenceParams.RepeatLastTokensCount); var lastTokens = new List(repeat_last_n); for (var i = 0; i < repeat_last_n; i++) - lastTokens.Add(0); + lastTokens.Add((LLamaToken)0); // Tokenize the prompt var tokens = Context.Tokenize(prompt).ToList(); diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs index e67a61b44..4198ad02a 100644 --- a/LLama/Native/LLamaBatchSafeHandle.cs +++ b/LLama/Native/LLamaBatchSafeHandle.cs @@ -135,7 +135,7 @@ public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan /// The raw value /// - public readonly int Value; + private readonly int Value; /// /// Create a new LLamaToken @@ -34,5 +34,5 @@ private LLamaToken(int value) /// /// /// - public static implicit operator LLamaToken(int value) => new(value); + public static explicit operator LLamaToken(int value) => new(value); } \ No newline at end of file diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs index 45edd4542..d6dd500bc 100644 --- a/LLama/Native/LLamaTokenData.cs +++ b/LLama/Native/LLamaTokenData.cs @@ -11,7 +11,7 @@ public struct LLamaTokenData /// /// token id /// - public int id; + public LLamaToken id; /// /// log-odds of the token @@ -29,7 +29,7 @@ public struct LLamaTokenData /// /// /// - public LLamaTokenData(int id, float logit, float p) + public LLamaTokenData(LLamaToken id, float logit, float p) { this.id = id; this.logit = logit; diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 1c02625f8..98dd91b6e 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -39,7 +39,7 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits) { var candidates = new LLamaTokenData[logits.Length]; for (var token_id = 0; token_id < logits.Length; token_id++) - candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + candidates[token_id] = new LLamaTokenData((LLamaToken)token_id, logits[token_id], 0.0f); return new LLamaTokenDataArray(candidates); } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 47027ae0d..0ceaf66c1 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -178,7 +178,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e /// Token to decode /// A span to attempt to write into. If this is too small nothing will be written /// The size of this token. **nothing will be written** if this is larger than `dest` - public int TokenToSpan(int token, Span dest) + public int TokenToSpan(LLamaToken token, Span dest) { return ThrowIfDisposed().TokenToSpan(token, dest); } diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs index 436c35d35..a41aa67e1 100644 --- a/LLama/Sampling/BaseSamplingPipeline.cs +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -12,22 +12,22 @@ public abstract class BaseSamplingPipeline : ISamplingPipeline { private int _savedLogitsCount; - private (int index, float logit)[]? _savedLogits; + private (LLamaToken index, float logit)[]? 
_savedLogits; /// public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { var protectedLogits = GetProtectedTokens(ctx); _savedLogitsCount = protectedLogits.Count; - _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); + _savedLogits = ArrayPool<(LLamaToken, float)>.Shared.Rent(_savedLogitsCount); try { // Save the values of protected logits for (var i = 0; i < protectedLogits.Count; i++) { var index = protectedLogits[i]; - var value = logits[index.Value]; - _savedLogits[i] = (index.Value, value); + var value = logits[(int)index]; + _savedLogits[i] = (index, value); } // Process raw logits @@ -47,7 +47,7 @@ public LLamaToken Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnl } finally { - ArrayPool<(int, float)>.Shared.Return(_savedLogits); + ArrayPool<(LLamaToken, float)>.Shared.Return(_savedLogits); _savedLogits = null; _savedLogitsCount = 0; } @@ -74,7 +74,7 @@ protected void RestoreProtectedTokens(Span logits) // Restore the values of protected logits for (var i = 0; i < saved.Length; i++) - logits[saved[i].index] = saved[i].logit; + logits[(int)saved[i].index] = saved[i].logit; } /// diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs index 22ead9af0..a66d5c9e9 100644 --- a/LLama/StreamingTokenDecoder.cs +++ b/LLama/StreamingTokenDecoder.cs @@ -69,7 +69,7 @@ public StreamingTokenDecoder(Encoding encoding, SafeLlamaModelHandle weights) /// Add a single token to the decoder /// /// - public void Add(int token) + public void Add(LLamaToken token) { var charsArr = ArrayPool.Shared.Rent(16); var bytesArr = ArrayPool.Shared.Rent(16); @@ -108,7 +108,7 @@ public void Add(int token) // Converts a single token into bytes, using the `bytes` array as temporary storage. // If the `bytes` array is too small it will get a larger one from the ArrayPool. 
- static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle model) + static Span TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaModelHandle model) { // Try to get bytes var l = model.TokenToSpan(token, bytes); @@ -133,9 +133,9 @@ static Span TokenToBytes(ref byte[] bytes, int token, SafeLlamaModelHandle /// Add a single token to the decoder /// /// - public void Add(LLamaToken token) + public void Add(int token) { - Add((int)token); + Add((LLamaToken)token); } /// From 82727c441405941f3c45aa7d00113828b6cbb161 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 21:12:38 +0000 Subject: [PATCH 3/3] Removed collection expressions from test --- LLama.Unittest/LLamaContextTests.cs | 4 ++-- LLama/Native/LLamaToken.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index 7bc2f7ff4..7f1c94960 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -38,7 +38,7 @@ public void Tokenize() { var tokens = _context.Tokenize("The quick brown fox", true); - Assert.Equal([ 1, 450, 4996, 17354, 1701, 29916 ], tokens); + Assert.Equal(new LLamaToken[] { 1, 450, 4996, 17354, 1701, 29916 }, tokens); } [Fact] @@ -46,7 +46,7 @@ public void TokenizeWithoutBOS() { var tokens = _context.Tokenize("The quick brown fox", false); - Assert.Equal([450, 4996, 17354, 1701, 29916], tokens); + Assert.Equal(new LLamaToken[] { 450, 4996, 17354, 1701, 29916 }, tokens); } [Fact] diff --git a/LLama/Native/LLamaToken.cs b/LLama/Native/LLamaToken.cs index ce448f08c..0bc485856 100644 --- a/LLama/Native/LLamaToken.cs +++ b/LLama/Native/LLamaToken.cs @@ -34,5 +34,5 @@ private LLamaToken(int value) /// /// /// - public static explicit operator LLamaToken(int value) => new(value); + public static implicit operator LLamaToken(int value) => new(value); } \ No newline at end of file
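
[Net effect of the three patches on call sites: `int` converts to `LLamaToken` implicitly, so integer literals still work where tokens are expected, while `LLamaToken` converts back to `int` only explicitly, which makes raw-buffer indexing visible at the call site. A short sketch, assuming `logits` is the span returned by `GetLogits()`:]

    LLamaToken token = 450;                           // implicit conversion: int -> LLamaToken
    var tokens = new LLamaToken[] { 1, 450, 29916 };  // same shape as the updated unit tests
    var logit = logits[(int)token];                   // explicit cast required: LLamaToken -> int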