From 33358124db7f692b6f73070caffa1da03e368934 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Mon, 4 Dec 2023 01:31:42 +0000 Subject: [PATCH 1/4] Initial pass at a new sampling pipeline --- LLama/Native/LLamaTokenDataArray.cs | 10 +- LLama/Sampling/ISamplingPipeline.cs | 99 +++++++++++++++++ LLama/Sampling/Logits/ILogitProcessor.cs | 34 ++++++ LLama/Sampling/Logits/LogitBias.cs | 39 +++++++ LLama/Sampling/Logits/SaveLoad.cs | 100 ++++++++++++++++++ LLama/Sampling/Selection/GreedySelection.cs | 27 +++++ LLama/Sampling/Selection/ITokenSelector.cs | 25 +++++ .../Sampling/Selection/Mirostat2Selection.cs | 65 ++++++++++++ LLama/Sampling/Selection/MirostatSelection.cs | 76 +++++++++++++ LLama/Sampling/Selection/StandardSelection.cs | 27 +++++ LLama/Sampling/Tokens/GrammarSampling.cs | 59 +++++++++++ LLama/Sampling/Tokens/ITokenDataProcessor.cs | 34 ++++++ .../Sampling/Tokens/LocallyTypicalSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/MinPSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/RepetitionPenalty.cs | 77 ++++++++++++++ LLama/Sampling/Tokens/TailFreeSampling.cs | 42 ++++++++ LLama/Sampling/Tokens/TemperatureSampling.cs | 38 +++++++ LLama/Sampling/Tokens/TopKSampling.cs | 38 +++++++ LLama/Sampling/Tokens/TopPSampling.cs | 42 ++++++++ 19 files changed, 912 insertions(+), 4 deletions(-) create mode 100644 LLama/Sampling/ISamplingPipeline.cs create mode 100644 LLama/Sampling/Logits/ILogitProcessor.cs create mode 100644 LLama/Sampling/Logits/LogitBias.cs create mode 100644 LLama/Sampling/Logits/SaveLoad.cs create mode 100644 LLama/Sampling/Selection/GreedySelection.cs create mode 100644 LLama/Sampling/Selection/ITokenSelector.cs create mode 100644 LLama/Sampling/Selection/Mirostat2Selection.cs create mode 100644 LLama/Sampling/Selection/MirostatSelection.cs create mode 100644 LLama/Sampling/Selection/StandardSelection.cs create mode 100644 LLama/Sampling/Tokens/GrammarSampling.cs create mode 100644 LLama/Sampling/Tokens/ITokenDataProcessor.cs create mode 100644 LLama/Sampling/Tokens/LocallyTypicalSampling.cs create mode 100644 LLama/Sampling/Tokens/MinPSampling.cs create mode 100644 LLama/Sampling/Tokens/RepetitionPenalty.cs create mode 100644 LLama/Sampling/Tokens/TailFreeSampling.cs create mode 100644 LLama/Sampling/Tokens/TemperatureSampling.cs create mode 100644 LLama/Sampling/Tokens/TopKSampling.cs create mode 100644 LLama/Sampling/Tokens/TopPSampling.cs diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 4bc154f4c..897cf8b87 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -145,15 +145,17 @@ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_ke /// /// /// - public void RepetitionPenalty(SafeLLamaContextHandle context, Memory last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + public void RepetitionPenalty(SafeLLamaContextHandle context, ReadOnlySpan last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) { unsafe { using (LLamaTokenDataArrayNative.Create(this, out var st)) - using (var last_tokens_handle = last_tokens.Pin()) { - NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); - sorted = st.sorted; + fixed (int* last_tokens_handle = last_tokens) + { + NativeApi.llama_sample_repetition_penalties(context, ref st, last_tokens_handle, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); + 
sorted = st.sorted; + } } } } diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs new file mode 100644 index 000000000..489f2c5ae --- /dev/null +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using LLama.Native; +using LLama.Sampling.Logits; +using LLama.Sampling.Selection; +using LLama.Sampling.Tokens; + +namespace LLama.Sampling; + +/// +/// Convert a span of logits into a single sampled token +/// +public interface ISamplingPipeline + : IDisposable +{ + /// + /// Sample a single token from the given logits + /// + /// + /// + /// + /// + int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Reset all internal state of the sampling pipeline + /// + void Reset(); +} + +/// +/// Simple implementation of `ISamplingPipeline`, applies processors in order every time +/// +public sealed class BasicSamplingPipeline + : ISamplingPipeline +{ + /// + /// Logit processors to apply in this pipeline + /// + public IList LogitProcessors { get; } = new List(); + + /// + /// Token data processors to apply in this pipeline + /// + public IList TokenDataProcessors { get; } = new List(); + + /// + /// The selector to choose the final token + /// + public ITokenSelector Selector { get; set; } = new StandardSelection(); + + /// + public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + // Modify raw logits + foreach (var logitProcessor in LogitProcessors) + logitProcessor.ProcessLogits(ctx, logits, lastTokens); + + // Convert logits into token candidates + var candidates_p = LLamaTokenDataArray.Create(logits); + + // Process token candidates + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens); + + // Select a token + var token = Selector.Select(ctx, candidates_p, lastTokens); + + // Tell processors what was selected + foreach (var logitProcessor in LogitProcessors) + logitProcessor.AcceptToken(ctx, token); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.AcceptToken(ctx, token); + + return token; + } + + /// + public void Reset() + { + foreach (var logitProcessor in LogitProcessors) + logitProcessor.Reset(); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.Reset(); + + Selector.Reset(); + } + + /// + public void Dispose() + { + foreach (var logitProcessor in LogitProcessors) + logitProcessor.Dispose(); + foreach (var tokenDataProcessor in TokenDataProcessors) + tokenDataProcessor.Dispose(); + + Selector.Dispose(); + } +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/ILogitProcessor.cs b/LLama/Sampling/Logits/ILogitProcessor.cs new file mode 100644 index 000000000..769684992 --- /dev/null +++ b/LLama/Sampling/Logits/ILogitProcessor.cs @@ -0,0 +1,34 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +using llama_token = Int32; + +/// +/// Processes raw logits before sampling, applying penalties to certain tokens +/// +public interface ILogitProcessor + : IDisposable +{ + /// + /// Process raw logits, indexed by llama_token + /// + /// The context this is operating in + /// The token data array to process + /// The most recent tokens output + /// LLamaTokenDataArray, created from logits + void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Inform this process when a token is accepted by the model + /// + /// + /// + void 
AcceptToken(SafeLLamaContextHandle ctx, int token); + + /// + /// Reset all internal sampling state + /// + void Reset(); +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/LogitBias.cs b/LLama/Sampling/Logits/LogitBias.cs new file mode 100644 index 000000000..fc8215083 --- /dev/null +++ b/LLama/Sampling/Logits/LogitBias.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +/// +/// Add a bias directly to logit values +/// +public sealed class LogitBias + : ILogitProcessor +{ + /// + /// Biases to apply, token -> bias + /// + public IDictionary Biases { get; } = new Dictionary(); + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var kvp in Biases) + logits[kvp.Key] += kvp.Value; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Logits/SaveLoad.cs b/LLama/Sampling/Logits/SaveLoad.cs new file mode 100644 index 000000000..6f80aec48 --- /dev/null +++ b/LLama/Sampling/Logits/SaveLoad.cs @@ -0,0 +1,100 @@ +using System; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling.Logits; + +/// +/// Save certain logit values +/// +public sealed class SaveLogitValues + : ILogitProcessor +{ + private readonly Dictionary _saved = new(); + + /// + /// Logits to save + /// + public ISet Logits { get; } = new HashSet(); + + /// + /// Saved logit values + /// + public IReadOnlyDictionary Values => _saved; + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + _saved.Clear(); + foreach (var logit in Logits) + _saved[logit] = logits[logit]; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + _saved.Clear(); + } + + /// + public void Dispose() + { + } + + /// + /// Get a logit processor that overwrite the logit values with the values saved here + /// + /// + public ILogitProcessor GetWriter() + { + return new LoadLogitValues(_saved); + } +} + +/// +/// Overwrite certain logit values +/// +public sealed class LoadLogitValues + : ILogitProcessor +{ + /// + /// Logits to overwrite, token -> logit + /// + public IDictionary Values { get; } + + /// + /// Create a new LoadLogitValues + /// + /// Source for values to overwrite + public LoadLogitValues(Dictionary? values = null) + { + Values = values ?? 
new Dictionary(); + } + + /// + public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var logit in Values) + logits[logit.Key] = logit.Value; + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/GreedySelection.cs b/LLama/Sampling/Selection/GreedySelection.cs new file mode 100644 index 000000000..30b724569 --- /dev/null +++ b/LLama/Sampling/Selection/GreedySelection.cs @@ -0,0 +1,27 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select the most likely token +/// +public sealed class GreedySelection + : ITokenSelector +{ + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenGreedy(ctx); + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/ITokenSelector.cs b/LLama/Sampling/Selection/ITokenSelector.cs new file mode 100644 index 000000000..c8915a92b --- /dev/null +++ b/LLama/Sampling/Selection/ITokenSelector.cs @@ -0,0 +1,25 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a single token from a set of possibilities +/// +public interface ITokenSelector + : IDisposable +{ + /// + /// Select a single token + /// + /// + /// + /// + /// + int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); + + /// + /// Reset the state + /// + void Reset(); +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/Mirostat2Selection.cs b/LLama/Sampling/Selection/Mirostat2Selection.cs new file mode 100644 index 000000000..cdc802c16 --- /dev/null +++ b/LLama/Sampling/Selection/Mirostat2Selection.cs @@ -0,0 +1,65 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a token using Mirostat sampling. +/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. +/// +public sealed class Mirostat2Selection + : ITokenSelector +{ + private float _mu; + + /// + /// Current value of Mu, updated based on the difference between target surprise and actual surprise + /// + public float Mu + { + get => _mu; + set => _mu = value; + } + + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// + public float Tau { get; set; } + + /// + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// + public float Eta { get; set; } + + /// + /// Create a new Mirostat 2.0 sampler + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
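As a rough sketch of how Tau, Eta and Mu interact: after each token is sampled, Mirostat measures that token's surprise and nudges Mu towards the target Tau. The helper below is purely illustrative (the class and method names are invented here, and the real update happens inside llama.cpp's native sampler):

    using System;

    static class MirostatMath
    {
        // After a token with probability `p` is sampled, its surprise (-log2 p) is compared
        // with the target surprise tau, and mu is moved by the learning rate eta.
        public static float UpdateMu(float mu, float p, float tau, float eta)
        {
            var observedSurprise = (float)(-Math.Log(p, 2));
            return mu - eta * (observedSurprise - tau);
        }
    }

Reset() initialising Mu to 2 * Tau, as the selectors below do, is the conventional starting value for this update rule.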
+ /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + public Mirostat2Selection(float tau, float eta) + { + Tau = tau; + Eta = eta; + } + + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu); + } + + /// + public void Reset() + { + _mu = 2 * Tau; + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/MirostatSelection.cs b/LLama/Sampling/Selection/MirostatSelection.cs new file mode 100644 index 000000000..5ec34a135 --- /dev/null +++ b/LLama/Sampling/Selection/MirostatSelection.cs @@ -0,0 +1,76 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select a token using Mirostat sampling. +/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. +/// +public sealed class MirostatSelection + : ITokenSelector +{ + private float _mu; + + /// + /// Current value of Mu, updated based on the difference between target surprise and actual surprise + /// + public float Mu + { + get => _mu; + set => _mu = value; + } + + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// + public float Tau { get; set; } + + /// + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// + public float Eta { get; set; } + + /// + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn + /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects + /// the performance of the algorithm. + /// + public int M { get; set; } + + /// + /// Create a new Mirostat 2.0 sampler + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. + /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. + /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn + /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects + /// the performance of the algorithm. 
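A brief usage sketch of how these selectors plug into the pipeline from the first part of this patch follows; the tau, eta and temperature values are illustrative only, and note that the pipeline class is renamed ConfigurableSamplingPipeline (and later replaced by BaseSamplingPipeline/DefaultSamplingPipeline) in subsequent commits of this PR:

    using LLama.Sampling;
    using LLama.Sampling.Selection;
    using LLama.Sampling.Tokens;

    static class MirostatPipelineExample
    {
        public static ISamplingPipeline Build()
        {
            // Token selection is delegated to Mirostat 1.0; tau/eta are example values.
            var pipeline = new BasicSamplingPipeline
            {
                Selector = new MirostatSelection(tau: 5.0f, eta: 0.1f),
            };

            // Candidate processors still run before selection, e.g. a temperature step.
            pipeline.TokenDataProcessors.Add(new TemperatureSampling { Temperature = 0.8f });

            return pipeline;
        }
    }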
+ public MirostatSelection(float tau, float eta, int m = 100) + { + Tau = tau; + Eta = eta; + M = m; + } + + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu); + } + + /// + public void Reset() + { + _mu = 2 * Tau; + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Selection/StandardSelection.cs b/LLama/Sampling/Selection/StandardSelection.cs new file mode 100644 index 000000000..3e3bd0865 --- /dev/null +++ b/LLama/Sampling/Selection/StandardSelection.cs @@ -0,0 +1,27 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Selection; + +/// +/// Select from all possible tokens according to their probability +/// +public sealed class StandardSelection + : ITokenSelector +{ + /// + public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + return candidates.SampleToken(ctx); + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/GrammarSampling.cs b/LLama/Sampling/Tokens/GrammarSampling.cs new file mode 100644 index 000000000..b823a7f92 --- /dev/null +++ b/LLama/Sampling/Tokens/GrammarSampling.cs @@ -0,0 +1,59 @@ +using System; +using LLama.Grammars; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Apply a grammar to prevent sampling tokens which do not match the grammar +/// +public sealed class GrammarSampling + : ITokenDataProcessor +{ + private SafeLLamaGrammarHandle? _handle; + + /// + /// Grammar to use for sampling + /// + public Grammar? Grammar { get; set; } + + /// + /// Create a new + /// + /// + public GrammarSampling(Grammar grammar) + { + Grammar = grammar; + } + + /// + public void Reset() + { + _handle?.Dispose(); + _handle = null; + } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + // Create a new grammar instance if necessary + _handle ??= Grammar?.CreateInstance(); + + // Apply it + if (_handle != null) + tokens.ApplyGrammar(ctx, _handle); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + _handle?.AcceptToken(ctx, token); + } + + /// + public void Dispose() + { + _handle?.Dispose(); + _handle = null; + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/ITokenDataProcessor.cs b/LLama/Sampling/Tokens/ITokenDataProcessor.cs new file mode 100644 index 000000000..e6679cc29 --- /dev/null +++ b/LLama/Sampling/Tokens/ITokenDataProcessor.cs @@ -0,0 +1,34 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +using llama_token = Int32; + +/// +/// Processes token logits before sampling, applying penalties to certain tokens +/// +public interface ITokenDataProcessor + : IDisposable +{ + /// + /// Process token logits in a LLamaTokenDataArray + /// + /// The context this is operating in + /// The token data array to process + /// The most recent tokens output + /// LLamaTokenDataArray, created from logits + void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens); + + /// + /// Inform this process when a token is accepted by the model + /// + /// + /// + void AcceptToken(SafeLLamaContextHandle ctx, int token); + + /// + /// Reset all internal sampling state + /// + void Reset(); +} \ No newline at end of file diff --git 
a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs new file mode 100644 index 000000000..3f602c9a7 --- /dev/null +++ b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. +/// +public sealed class LocallyTypicalSampling + : ITokenDataProcessor +{ + /// + /// P value for locally typical sampling + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.LocallyTypical(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/MinPSampling.cs b/LLama/Sampling/Tokens/MinPSampling.cs new file mode 100644 index 000000000..c3adf0262 --- /dev/null +++ b/LLama/Sampling/Tokens/MinPSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 +/// +public sealed class MinPSampling + : ITokenDataProcessor +{ + /// + /// All tokens with probability greater than this will be kept + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.MinP(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/RepetitionPenalty.cs b/LLama/Sampling/Tokens/RepetitionPenalty.cs new file mode 100644 index 000000000..3cfdbcd46 --- /dev/null +++ b/LLama/Sampling/Tokens/RepetitionPenalty.cs @@ -0,0 +1,77 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. +/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. +/// +public sealed class RepetitionPenalty + : ITokenDataProcessor +{ + private float _alphaFreq; + private float _alphaPresence; + + /// + /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 + /// + public float RepeatPenalty { get; set; } = 1.1f; + + /// + /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+ /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+ ///
+ public float AlphaFrequency + { + get => _alphaFreq; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaFreq = value; + } + } + + /// + /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ /// text so far, increasing the model's likelihood to talk about new topics.
+ ///
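Taken together, the repeat, frequency and presence penalties adjust each candidate's logit roughly as sketched below. The method and parameter names are invented for illustration; the actual computation is performed by llama.cpp's llama_sample_repetition_penalties, which this processor reaches through LLamaTokenDataArray.RepetitionPenalty:

    static class PenaltyMath
    {
        // `count` is how often the candidate token occurs in the recent `lastTokens` window.
        public static float Penalize(float logit, int count, float repeatPenalty,
                                     float alphaFrequency, float alphaPresence)
        {
            if (count == 0)
                return logit;

            // CTRL-style repetition penalty with the "negative logit fix": positive logits
            // are divided by the penalty, non-positive logits are multiplied by it.
            logit = logit > 0 ? logit / repeatPenalty : logit * repeatPenalty;

            // OpenAI-style terms: frequency scales with the count, presence is a flat offset.
            return logit - count * alphaFrequency - alphaPresence;
        }
    }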
+ public float AlphaPresence + { + get => _alphaPresence; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaPresence = value; + } + } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TailFreeSampling.cs b/LLama/Sampling/Tokens/TailFreeSampling.cs new file mode 100644 index 000000000..8e9fb2b51 --- /dev/null +++ b/LLama/Sampling/Tokens/TailFreeSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. +/// +public sealed class TailFreeSampling + : ITokenDataProcessor +{ + /// + /// Z value for tail free sampling + /// + public float Z { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TailFree(ctx, Z, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TemperatureSampling.cs b/LLama/Sampling/Tokens/TemperatureSampling.cs new file mode 100644 index 000000000..0186f275f --- /dev/null +++ b/LLama/Sampling/Tokens/TemperatureSampling.cs @@ -0,0 +1,38 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Sample with temperature. +/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual +/// +public sealed class TemperatureSampling + : ITokenDataProcessor +{ + /// + /// Temperature value to apply + /// + public float Temperature { get; set; } = 0.5f; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.Temperature(ctx, Temperature); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopKSampling.cs b/LLama/Sampling/Tokens/TopKSampling.cs new file mode 100644 index 000000000..3f797c85f --- /dev/null +++ b/LLama/Sampling/Tokens/TopKSampling.cs @@ -0,0 +1,38 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Sample with TopK, removing all by the K most likely tokens. 
+/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +/// +public sealed class TopKSampling + : ITokenDataProcessor +{ + /// + /// Number of tokens to keep + /// + public int Count { get; set; } + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TopK(ctx, Count); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopPSampling.cs b/LLama/Sampling/Tokens/TopPSampling.cs new file mode 100644 index 000000000..577ce3bc3 --- /dev/null +++ b/LLama/Sampling/Tokens/TopPSampling.cs @@ -0,0 +1,42 @@ +using System; +using LLama.Native; + +namespace LLama.Sampling.Tokens; + +/// +/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +/// +public sealed class TopPSampling + : ITokenDataProcessor +{ + /// + /// P valies for TopP + /// + public float P { get; set; } + + /// + /// Minimum number of tokens to keep + /// + public ulong MinKeep { get; set; } = 1; + + /// + public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) + { + tokens.TopP(ctx, P, MinKeep); + } + + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + } + + /// + public void Reset() + { + } + + /// + public void Dispose() + { + } +} \ No newline at end of file From b34f72a883a8851cafb6fb6e3ebca9fa2c0e3a29 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 01:02:27 +0000 Subject: [PATCH 2/4] - Added `SamplingPipeline` to inference params which overrides all other options with an entirely custom pipeline. - Added a `Sample` method to `LLamaContext` which uses a custom pipeline - Modified all executors to use the custom pipeline if it exists --- LLama.Web/Common/InferenceOptions.cs | 10 ++++-- LLama/Abstractions/IInferenceParams.cs | 6 ++++ LLama/Common/InferenceParams.cs | 4 +++ LLama/LLamaContext.cs | 12 +++++++ LLama/LLamaInstructExecutor.cs | 26 ++++++++++------ LLama/LLamaInteractExecutor.cs | 28 +++++++++++------ LLama/LLamaStatelessExecutor.cs | 29 +++++++++++------ LLama/Sampling/ISamplingPipeline.cs | 43 +++++++++++++++++++++++--- 8 files changed, 123 insertions(+), 35 deletions(-) diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index 89d94ade3..c604dc0d1 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -1,6 +1,9 @@ -using LLama.Common; +#nullable enable + +using LLama.Common; using LLama.Abstractions; using LLama.Native; +using LLama.Sampling; namespace LLama.Web.Common { @@ -64,6 +67,9 @@ public class InferenceOptions /// /// A grammar to constrain possible tokens /// - public SafeLLamaGrammarHandle Grammar { get; set; } = null; + public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? 
SamplingPipeline { get; set; } } } diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index d87faf0eb..e1e894143 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using LLama.Common; using LLama.Native; +using LLama.Sampling; namespace LLama.Abstractions { @@ -108,5 +109,10 @@ public interface IInferenceParams /// Grammar to constrain possible tokens /// SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + /// Set a custom sampling pipeline to use. If this is set All other sampling parameters are ignored! + /// + ISamplingPipeline? SamplingPipeline { get; set; } } } \ No newline at end of file diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index d7bd19d96..c1f395505 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using LLama.Native; +using LLama.Sampling; namespace LLama.Common { @@ -76,6 +77,9 @@ public record InferenceParams /// public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + public ISamplingPipeline? SamplingPipeline { get; set; } } /// diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 3a3e51af4..2902dc8f9 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -10,6 +10,7 @@ using System.Runtime.InteropServices; using LLama.Extensions; using LLama.Abstractions; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -212,6 +213,17 @@ public void LoadState(State state) } } + /// + /// Sample a single token from this context, using the given sampling pipeline + /// + /// The pipeline to use to process the logits and to select a token + /// The tokens recently returned from the model + /// The selected token + public llama_token Sample(ISamplingPipeline pipeline, ReadOnlySpan lastTokens) + { + return pipeline.Sample(NativeHandle, NativeHandle.GetLogits(), lastTokens); + } + /// /// Perform the sampling. Please don't use it unless you fully know what it does. 
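A minimal sketch of how the overload added above is meant to be used by callers that keep their own token history (the surrounding class and method are invented here; lastTokens is whatever recent-token window the caller maintains, and llama_token is an alias for int in this codebase):

    using LLama;
    using LLama.Sampling;

    static class ContextSamplingExample
    {
        public static int SampleNext(LLamaContext context, ISamplingPipeline pipeline, int[] lastTokens)
        {
            // The pipeline receives the logits of the last decoded position and returns one token id.
            return context.Sample(pipeline, lastTokens);
        }
    }

The executors modified later in this commit call the pipeline's Sample method directly via inferenceParams.SamplingPipeline rather than going through this helper.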
/// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index d81630aa9..3ed668903 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -210,16 +210,24 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 4d28274b4..9cecf4378 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -189,16 +189,24 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In SaveSessionFile(_pathSession); } - var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - var mu = MirostatMu; - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); - MirostatMu = mu; + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), _last_n_tokens.ToArray()); + } + else + { + var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + + var mu = MirostatMu; + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + MirostatMu = mu; + } _last_n_tokens.Enqueue(id); diff --git a/LLama/LLamaStatelessExecutor.cs 
b/LLama/LLamaStatelessExecutor.cs index 9c41af7c0..831aceb26 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -7,6 +7,7 @@ using System.Threading; using System.Threading.Tasks; using LLama.Native; +using LLama.Sampling; using Microsoft.Extensions.Logging; namespace LLama @@ -85,16 +86,24 @@ public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) { - // Penalize the generated tokens by various penalties - var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, - inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); - - // Sample a single token - var id = Context.Sample( - tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, - inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, - inferenceParams.MinP - ); + llama_token id; + if (inferenceParams.SamplingPipeline is not null) + { + id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens); + } + else + { + // Penalize the generated tokens by various penalties + var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, + inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + + // Sample a single token + id = Context.Sample( + tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, + inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar, + inferenceParams.MinP + ); + } // Decode this token into text decoder.Add(id); diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 489f2c5ae..4540e9fc8 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -1,5 +1,7 @@ using System; +using System.Buffers; using System.Collections.Generic; +using System.Runtime.InteropServices; using LLama.Native; using LLama.Sampling.Logits; using LLama.Sampling.Selection; @@ -16,9 +18,9 @@ public interface ISamplingPipeline /// /// Sample a single token from the given logits /// - /// - /// - /// + /// The context being sampled from + /// The logits produced by the model + /// A span of tokens recently returned by the model /// int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); @@ -28,10 +30,43 @@ public interface ISamplingPipeline void Reset(); } +/// +/// Extensions methods for ISamplingPipeline +/// +public static class ISamplingPipelineExtensions +{ + /// + /// Sample a single token from the given logits + /// + /// + /// The context being sampled from + /// The logits produced by the model + /// A list of tokens recently returned by the model + /// + public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle ctx, Span logits, List lastTokens) + { +#if NET5_0_OR_GREATER + var span = CollectionsMarshal.AsSpan(lastTokens); + return pipeline.Sample(ctx, logits, span); +#else + var copy = ArrayPool.Shared.Rent(lastTokens.Count); + try + { + lastTokens.CopyTo(copy); 
+ return pipeline.Sample(ctx, logits, copy.AsSpan(0, copy.Length)); + } + finally + { + ArrayPool.Shared.Return(copy); + } +#endif + } +} + /// /// Simple implementation of `ISamplingPipeline`, applies processors in order every time /// -public sealed class BasicSamplingPipeline +public sealed class ConfigurableSamplingPipeline : ISamplingPipeline { /// From 3afc007499866f5b47f98993d46a6fcb5b4f8fd2 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 01:17:24 +0000 Subject: [PATCH 3/4] - Added "protected" logits, instead of the awkward save/load mechanism - Added an example usage to one of the tests --- LLama.Unittest/StatelessExecutorTest.cs | 37 ++++++++- LLama/Sampling/ISamplingPipeline.cs | 33 +++++++- LLama/Sampling/Logits/SaveLoad.cs | 100 ------------------------ 3 files changed, 66 insertions(+), 104 deletions(-) delete mode 100644 LLama/Sampling/Logits/SaveLoad.cs diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 195cc4a28..d847e787d 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,5 +1,9 @@ using System.Diagnostics; using LLama.Common; +using LLama.Sampling; +using LLama.Sampling.Logits; +using LLama.Sampling.Selection; +using LLama.Sampling.Tokens; using Xunit.Abstractions; namespace LLama.Unittest @@ -30,10 +34,41 @@ public void Dispose() [Fact] public async Task Stateless() { + // Create a custom pipeline that mimics the default pipeline + var pipeline = new ConfigurableSamplingPipeline() + { + ProtectedLogits = + { + _weights.NewlineToken, + _weights.BeginningOfSentenceToken, + _weights.EndOfSentenceToken + }, + LogitProcessors = + { + new LogitBias + { + Biases = + { + { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing! + { 42, 0f }, + } + } + }, + TokenDataProcessors = + { + new TailFreeSampling { Z = 1 }, + new LocallyTypicalSampling { P = 1 }, + new TopPSampling { P = 0.95f }, + new MinPSampling { P = 0.05f }, + new TemperatureSampling { Temperature = 0.8f }, + }, + Selector = new StandardSelection(), + }; + var executor = new StatelessExecutor(_weights, _params); const string question = "Question. what is a cat?\nAnswer: "; - var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } }; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." 
}, SamplingPipeline = pipeline}; var timer = new Stopwatch(); timer.Start(); diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 4540e9fc8..3b829ed43 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -74,6 +74,11 @@ public sealed class ConfigurableSamplingPipeline /// public IList LogitProcessors { get; } = new List(); + /// + /// Logits values which will not be changed by the logit processors + /// + public IList ProtectedLogits { get; } = new List(); + /// /// Token data processors to apply in this pipeline /// @@ -87,9 +92,31 @@ public sealed class ConfigurableSamplingPipeline /// public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) { - // Modify raw logits - foreach (var logitProcessor in LogitProcessors) - logitProcessor.ProcessLogits(ctx, logits, lastTokens); + var savedLogitsCount = ProtectedLogits.Count; + var savedLogitValues = ArrayPool.Shared.Rent(savedLogitsCount); + var savedLogitIndices = ArrayPool.Shared.Rent(savedLogitsCount); + try + { + // Save the values of protected logits + for (var i = 0; i < ProtectedLogits.Count; i++) + { + savedLogitValues[i] = logits[ProtectedLogits[i]]; + savedLogitIndices[i] = ProtectedLogits[i]; + } + + // Modify raw logits + foreach (var logitProcessor in LogitProcessors) + logitProcessor.ProcessLogits(ctx, logits, lastTokens); + + // Restore the values of protected logits + for (var i = 0; i < savedLogitsCount; i++) + logits[savedLogitIndices[i]] = savedLogitValues[i]; + } + finally + { + ArrayPool.Shared.Return(savedLogitValues); + ArrayPool.Shared.Return(savedLogitIndices); + } // Convert logits into token candidates var candidates_p = LLamaTokenDataArray.Create(logits); diff --git a/LLama/Sampling/Logits/SaveLoad.cs b/LLama/Sampling/Logits/SaveLoad.cs deleted file mode 100644 index 6f80aec48..000000000 --- a/LLama/Sampling/Logits/SaveLoad.cs +++ /dev/null @@ -1,100 +0,0 @@ -using System; -using System.Collections.Generic; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -/// -/// Save certain logit values -/// -public sealed class SaveLogitValues - : ILogitProcessor -{ - private readonly Dictionary _saved = new(); - - /// - /// Logits to save - /// - public ISet Logits { get; } = new HashSet(); - - /// - /// Saved logit values - /// - public IReadOnlyDictionary Values => _saved; - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - _saved.Clear(); - foreach (var logit in Logits) - _saved[logit] = logits[logit]; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - _saved.Clear(); - } - - /// - public void Dispose() - { - } - - /// - /// Get a logit processor that overwrite the logit values with the values saved here - /// - /// - public ILogitProcessor GetWriter() - { - return new LoadLogitValues(_saved); - } -} - -/// -/// Overwrite certain logit values -/// -public sealed class LoadLogitValues - : ILogitProcessor -{ - /// - /// Logits to overwrite, token -> logit - /// - public IDictionary Values { get; } - - /// - /// Create a new LoadLogitValues - /// - /// Source for values to overwrite - public LoadLogitValues(Dictionary? values = null) - { - Values = values ?? 
new Dictionary(); - } - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - foreach (var logit in Values) - logits[logit.Key] = logit.Value; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file From 835958398cc6c5948036f269796810f20bf6657a Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 8 Dec 2023 16:25:13 +0000 Subject: [PATCH 4/4] - Removed the object wrappers and configurable pipeline, they can be better written in code. - Added BaseSamplingPipeline which provides a base impl of `ISamplingPipeline` - Added `DefaultSamplingPipeline` which mimics normal llama.cpp sampling --- LLama.Unittest/GrammarParserTest.cs | 3 +- LLama.Unittest/StatelessExecutorTest.cs | 35 +--- LLama/Native/LLamaTokenDataArray.cs | 29 +++- LLama/Sampling/BaseSamplingPipeline.cs | 128 +++++++++++++++ LLama/Sampling/DefaultSamplingPipeline.cs | 149 ++++++++++++++++++ LLama/Sampling/ISamplingPipeline.cs | 102 +----------- LLama/Sampling/Logits/ILogitProcessor.cs | 34 ---- LLama/Sampling/Logits/LogitBias.cs | 39 ----- LLama/Sampling/Selection/GreedySelection.cs | 27 ---- LLama/Sampling/Selection/ITokenSelector.cs | 25 --- .../Sampling/Selection/Mirostat2Selection.cs | 65 -------- LLama/Sampling/Selection/MirostatSelection.cs | 76 --------- LLama/Sampling/Selection/StandardSelection.cs | 27 ---- LLama/Sampling/Tokens/GrammarSampling.cs | 59 ------- LLama/Sampling/Tokens/ITokenDataProcessor.cs | 34 ---- .../Sampling/Tokens/LocallyTypicalSampling.cs | 42 ----- LLama/Sampling/Tokens/MinPSampling.cs | 42 ----- LLama/Sampling/Tokens/RepetitionPenalty.cs | 77 --------- LLama/Sampling/Tokens/TailFreeSampling.cs | 42 ----- LLama/Sampling/Tokens/TemperatureSampling.cs | 38 ----- LLama/Sampling/Tokens/TopKSampling.cs | 38 ----- LLama/Sampling/Tokens/TopPSampling.cs | 42 ----- 22 files changed, 309 insertions(+), 844 deletions(-) create mode 100644 LLama/Sampling/BaseSamplingPipeline.cs create mode 100644 LLama/Sampling/DefaultSamplingPipeline.cs delete mode 100644 LLama/Sampling/Logits/ILogitProcessor.cs delete mode 100644 LLama/Sampling/Logits/LogitBias.cs delete mode 100644 LLama/Sampling/Selection/GreedySelection.cs delete mode 100644 LLama/Sampling/Selection/ITokenSelector.cs delete mode 100644 LLama/Sampling/Selection/Mirostat2Selection.cs delete mode 100644 LLama/Sampling/Selection/MirostatSelection.cs delete mode 100644 LLama/Sampling/Selection/StandardSelection.cs delete mode 100644 LLama/Sampling/Tokens/GrammarSampling.cs delete mode 100644 LLama/Sampling/Tokens/ITokenDataProcessor.cs delete mode 100644 LLama/Sampling/Tokens/LocallyTypicalSampling.cs delete mode 100644 LLama/Sampling/Tokens/MinPSampling.cs delete mode 100644 LLama/Sampling/Tokens/RepetitionPenalty.cs delete mode 100644 LLama/Sampling/Tokens/TailFreeSampling.cs delete mode 100644 LLama/Sampling/Tokens/TemperatureSampling.cs delete mode 100644 LLama/Sampling/Tokens/TopKSampling.cs delete mode 100644 LLama/Sampling/Tokens/TopPSampling.cs diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 9ad77531c..389563aae 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -1,5 +1,4 @@ -using System.Text; -using LLama.Exceptions; +using LLama.Exceptions; using LLama.Native; using LLama.Grammars; diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs 
index d847e787d..72e9acf87 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,9 +1,6 @@ using System.Diagnostics; using LLama.Common; using LLama.Sampling; -using LLama.Sampling.Logits; -using LLama.Sampling.Selection; -using LLama.Sampling.Tokens; using Xunit.Abstractions; namespace LLama.Unittest @@ -35,40 +32,12 @@ public void Dispose() public async Task Stateless() { // Create a custom pipeline that mimics the default pipeline - var pipeline = new ConfigurableSamplingPipeline() - { - ProtectedLogits = - { - _weights.NewlineToken, - _weights.BeginningOfSentenceToken, - _weights.EndOfSentenceToken - }, - LogitProcessors = - { - new LogitBias - { - Biases = - { - { _weights.NewlineToken, 1000 }, // This is an insane bias, but because newline is a protected logit it will do nothing! - { 42, 0f }, - } - } - }, - TokenDataProcessors = - { - new TailFreeSampling { Z = 1 }, - new LocallyTypicalSampling { P = 1 }, - new TopPSampling { P = 0.95f }, - new MinPSampling { P = 0.05f }, - new TemperatureSampling { Temperature = 0.8f }, - }, - Selector = new StandardSelection(), - }; + var pipeline = new DefaultSamplingPipeline(); var executor = new StatelessExecutor(_weights, _params); const string question = "Question. what is a cat?\nAnswer: "; - var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline}; + var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline }; var timer = new Stopwatch(); timer.Start(); diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 897cf8b87..5059a5f39 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -46,14 +46,41 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits) return new LLamaTokenDataArray(candidates); } + /// + /// Overwrite the logit values for all given tokens + /// + /// tuples of token and logit value to overwrite + public void OverwriteLogits(ReadOnlySpan<(llama_token token, float logit)> values) + { + if (values.Length == 0) + return; + + var dataSpan = data.Span; + foreach (var (token, value) in values) + { + for (var i = 0; i < data.Length; i++) + { + if (dataSpan[i].id == token) + { + dataSpan[i].logit = value; + break; + } + } + } + sorted = false; + } + #region sampling /// /// Apply grammar rules to candidate tokens /// /// /// - public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) + public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle? grammar) { + if (grammar == null) + return; + using (LLamaTokenDataArrayNative.Create(this, out var st)) { NativeApi.llama_sample_grammar(ctx, ref st, grammar); diff --git a/LLama/Sampling/BaseSamplingPipeline.cs b/LLama/Sampling/BaseSamplingPipeline.cs new file mode 100644 index 000000000..4c0f7689f --- /dev/null +++ b/LLama/Sampling/BaseSamplingPipeline.cs @@ -0,0 +1,128 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// Base class for implementing custom sampling pipelines. This provides a helpful framework for implementing `ISamplingPipeline`. +/// +public abstract class BaseSamplingPipeline + : ISamplingPipeline +{ + private int _savedLogitsCount; + private (int index, float logit)[]? 
_savedLogits; + + /// + public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + var protectedLogits = GetProtectedTokens(ctx); + _savedLogitsCount = protectedLogits.Count; + _savedLogits = ArrayPool<(int, float)>.Shared.Rent(_savedLogitsCount); + try + { + // Save the values of protected logits + for (var i = 0; i < protectedLogits.Count; i++) + { + var index = protectedLogits[i]; + var value = logits[index]; + _savedLogits[i] = (index, value); + } + + // Process raw logits + ProcessLogits(ctx, logits, lastTokens); + + // Automatically restore saved logit values after processing + RestoreProtectedTokens(logits); + + // Convert logits into token candidates + var candidates = LLamaTokenDataArray.Create(logits); + + // Process token data array + ProcessTokenDataArray(ctx, candidates, lastTokens); + + // Choose the final value + return ChooseToken(ctx, candidates); + } + finally + { + ArrayPool<(int, float)>.Shared.Return(_savedLogits); + _savedLogits = null; + _savedLogitsCount = 0; + } + } + + #region protected tokens + /// + /// Get all of the "protected" tokens that cannot be changed by ProcessLogits + /// + /// + protected abstract IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx); + + /// + /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits + /// + /// + protected void RestoreProtectedTokens(Span logits) + { + if (_savedLogits == null) + return; + + // The array may be bigger than necessary, get a span of the valid bit + var saved = _savedLogits.AsSpan(0, _savedLogitsCount); + + // Restore the values of protected logits + for (var i = 0; i < saved.Length; i++) + logits[saved[i].index] = saved[i].logit; + } + + /// + /// Restore the value of the "protected" tokens which were saved before the call to ProcessLogits + /// + /// + protected void RestoreProtectedTokens(LLamaTokenDataArray candidates) + { + if (_savedLogits == null || _savedLogits.Length == 0) + return; + + candidates.OverwriteLogits(_savedLogits.AsSpan(0, _savedLogitsCount)); + } + #endregion + + /// + /// Process the raw logit values + /// + /// The context being sampled from + /// The logits produced by the model + /// A list of tokens recently returned by the model + protected abstract void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); + + /// + /// Process the LLamaTokenDataArray and select a single token + /// + /// The context being sampled from + /// The LLamaTokenDataArray data produced by the model + /// A list of tokens recently returned by the model + /// + protected abstract int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); + + /// + /// Choose the final token from the candidates + /// + /// + /// + /// + protected abstract int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates); + + /// + public virtual void Reset() + { + } + + /// + public virtual void Dispose() + { + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/LLama/Sampling/DefaultSamplingPipeline.cs b/LLama/Sampling/DefaultSamplingPipeline.cs new file mode 100644 index 000000000..e6db2efe3 --- /dev/null +++ b/LLama/Sampling/DefaultSamplingPipeline.cs @@ -0,0 +1,149 @@ +using System; +using System.Collections.Generic; +using LLama.Extensions; +using LLama.Native; + +namespace LLama.Sampling; + +/// +/// An implementation of ISamplePipeline which mimics the default llama.cpp sampling +/// +public sealed class 
DefaultSamplingPipeline + : BaseSamplingPipeline +{ + /// + /// Bias values to add to certain logits + /// + public Dictionary LogitBias { get; } = new(); + + /// + /// Grammar to constrain valid tokens + /// + public SafeLLamaGrammarHandle? Grammar { get; set; } + + /// + /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 + /// + public float RepeatPenalty { get; set; } = 1.1f; + + /// + /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text
+ /// so far, decreasing the model's likelihood to repeat the same line verbatim.
+ ///
+ public float AlphaFrequency + { + get => _alphaFreq; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaFreq = value; + } + } + private float _alphaFreq = 0.1f; + + /// + /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
+ /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the
+ /// text so far, increasing the model's likelihood to talk about new topics.
+ ///
+ public float AlphaPresence + { + get => _alphaPresence; + set + { + if (value < -2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); + if (value > 2) + throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); + _alphaPresence = value; + } + } + private float _alphaPresence = 0.1f; + + /// + /// Temperature to apply (higher temperature is more "creative") + /// + public float Temperature { get; set; } = 0.75f; + + /// + /// Number of tokens to keep in TopK sampling + /// + public int TopK { get; set; } + + /// + /// Z value for tail free sampling + /// + public float TailFreeZ { get; set; } + + /// + /// P value for locally typical sampling + /// + public float TypicalP { get; set; } + + /// + /// P value for TopP sampling + /// + public float TopP { get; set; } = 1f; + + /// + /// P value for MinP sampling + /// + public float MinP { get; set; } + + /// + /// Whether the newline value should be protected from being modified by logit bias and repeat penalty + /// + public bool PenalizeNewline { get; set; } = false; + + private readonly int[] _newlineToken = new int[1]; + + /// + protected override IReadOnlyList GetProtectedTokens(SafeLLamaContextHandle ctx) + { + if (PenalizeNewline) + return Array.Empty(); + + _newlineToken[0] = NativeApi.llama_token_nl(ctx.ModelHandle); + return _newlineToken; + } + + /// + protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) + { + foreach (var (key, value) in LogitBias) + logits[key] += value; + } + + /// + protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) + { + // Apply penalties to candidates + candidates.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); + + // Restore protected tokens, so they are not affected by repetition penalties + RestoreProtectedTokens(candidates); + + // Apply the normal llama.cpp pipeline + candidates.ApplyGrammar(ctx, Grammar); + candidates.TopK(ctx, TopK); + candidates.TailFree(ctx, TailFreeZ); + candidates.LocallyTypical(ctx, TypicalP); + candidates.TopP(ctx, TopP); + candidates.MinP(ctx, MinP); + candidates.Temperature(ctx, Temperature); + var id = candidates.SampleToken(ctx); + + Grammar?.AcceptToken(ctx, id); + return id; + } + + /// + protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) + { + return candidates.SampleToken(ctx); + } +} \ No newline at end of file diff --git a/LLama/Sampling/ISamplingPipeline.cs b/LLama/Sampling/ISamplingPipeline.cs index 3b829ed43..f39bf9963 100644 --- a/LLama/Sampling/ISamplingPipeline.cs +++ b/LLama/Sampling/ISamplingPipeline.cs @@ -3,14 +3,11 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using LLama.Native; -using LLama.Sampling.Logits; -using LLama.Sampling.Selection; -using LLama.Sampling.Tokens; namespace LLama.Sampling; /// -/// Convert a span of logits into a single sampled token +/// Convert a span of logits into a single sampled token. This interface can be implemented to completely customise the sampling process. 
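To illustrate what implementing this interface can look like after this commit, here is a minimal custom pipeline built on the new BaseSamplingPipeline. The class name and the temperature value are invented for the example; only the overridden members and the LLamaTokenDataArray calls come from this patch:

    using System;
    using System.Collections.Generic;
    using LLama.Native;
    using LLama.Sampling;

    // A sketch: temperature scaling followed by greedy selection, with no protected tokens.
    public sealed class GreedyTemperaturePipeline
        : BaseSamplingPipeline
    {
        public float Temperature { get; set; } = 0.8f;

        protected override IReadOnlyList<int> GetProtectedTokens(SafeLLamaContextHandle ctx)
        {
            return Array.Empty<int>(); // nothing is shielded from ProcessLogits
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
            // No raw-logit changes in this sketch; a logit-bias style tweak would go here.
        }

        protected override int ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<int> lastTokens)
        {
            candidates.Temperature(ctx, Temperature); // scale the candidate logits
            return ChooseToken(ctx, candidates);
        }

        protected override int ChooseToken(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
        {
            return candidates.SampleTokenGreedy(ctx); // pick the single most likely token
        }
    }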
/// public interface ISamplingPipeline : IDisposable @@ -61,101 +58,4 @@ public static int Sample(this ISamplingPipeline pipeline, SafeLLamaContextHandle } #endif } -} - -/// -/// Simple implementation of `ISamplingPipeline`, applies processors in order every time -/// -public sealed class ConfigurableSamplingPipeline - : ISamplingPipeline -{ - /// - /// Logit processors to apply in this pipeline - /// - public IList LogitProcessors { get; } = new List(); - - /// - /// Logits values which will not be changed by the logit processors - /// - public IList ProtectedLogits { get; } = new List(); - - /// - /// Token data processors to apply in this pipeline - /// - public IList TokenDataProcessors { get; } = new List(); - - /// - /// The selector to choose the final token - /// - public ITokenSelector Selector { get; set; } = new StandardSelection(); - - /// - public int Sample(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - var savedLogitsCount = ProtectedLogits.Count; - var savedLogitValues = ArrayPool.Shared.Rent(savedLogitsCount); - var savedLogitIndices = ArrayPool.Shared.Rent(savedLogitsCount); - try - { - // Save the values of protected logits - for (var i = 0; i < ProtectedLogits.Count; i++) - { - savedLogitValues[i] = logits[ProtectedLogits[i]]; - savedLogitIndices[i] = ProtectedLogits[i]; - } - - // Modify raw logits - foreach (var logitProcessor in LogitProcessors) - logitProcessor.ProcessLogits(ctx, logits, lastTokens); - - // Restore the values of protected logits - for (var i = 0; i < savedLogitsCount; i++) - logits[savedLogitIndices[i]] = savedLogitValues[i]; - } - finally - { - ArrayPool.Shared.Return(savedLogitValues); - ArrayPool.Shared.Return(savedLogitIndices); - } - - // Convert logits into token candidates - var candidates_p = LLamaTokenDataArray.Create(logits); - - // Process token candidates - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.ProcessTokens(ctx, candidates_p, lastTokens); - - // Select a token - var token = Selector.Select(ctx, candidates_p, lastTokens); - - // Tell processors what was selected - foreach (var logitProcessor in LogitProcessors) - logitProcessor.AcceptToken(ctx, token); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.AcceptToken(ctx, token); - - return token; - } - - /// - public void Reset() - { - foreach (var logitProcessor in LogitProcessors) - logitProcessor.Reset(); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.Reset(); - - Selector.Reset(); - } - - /// - public void Dispose() - { - foreach (var logitProcessor in LogitProcessors) - logitProcessor.Dispose(); - foreach (var tokenDataProcessor in TokenDataProcessors) - tokenDataProcessor.Dispose(); - - Selector.Dispose(); - } } \ No newline at end of file diff --git a/LLama/Sampling/Logits/ILogitProcessor.cs b/LLama/Sampling/Logits/ILogitProcessor.cs deleted file mode 100644 index 769684992..000000000 --- a/LLama/Sampling/Logits/ILogitProcessor.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -using llama_token = Int32; - -/// -/// Processes raw logits before sampling, applying penalties to certain tokens -/// -public interface ILogitProcessor - : IDisposable -{ - /// - /// Process raw logits, indexed by llama_token - /// - /// The context this is operating in - /// The token data array to process - /// The most recent tokens output - /// LLamaTokenDataArray, created from logits - void 
ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens); - - /// - /// Inform this process when a token is accepted by the model - /// - /// - /// - void AcceptToken(SafeLLamaContextHandle ctx, int token); - - /// - /// Reset all internal sampling state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Logits/LogitBias.cs b/LLama/Sampling/Logits/LogitBias.cs deleted file mode 100644 index fc8215083..000000000 --- a/LLama/Sampling/Logits/LogitBias.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System; -using System.Collections.Generic; -using LLama.Native; - -namespace LLama.Sampling.Logits; - -/// -/// Add a bias directly to logit values -/// -public sealed class LogitBias - : ILogitProcessor -{ - /// - /// Biases to apply, token -> bias - /// - public IDictionary Biases { get; } = new Dictionary(); - - /// - public void ProcessLogits(SafeLLamaContextHandle ctx, Span logits, ReadOnlySpan lastTokens) - { - foreach (var kvp in Biases) - logits[kvp.Key] += kvp.Value; - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/GreedySelection.cs b/LLama/Sampling/Selection/GreedySelection.cs deleted file mode 100644 index 30b724569..000000000 --- a/LLama/Sampling/Selection/GreedySelection.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select the most likely token -/// -public sealed class GreedySelection - : ITokenSelector -{ - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenGreedy(ctx); - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/ITokenSelector.cs b/LLama/Sampling/Selection/ITokenSelector.cs deleted file mode 100644 index c8915a92b..000000000 --- a/LLama/Sampling/Selection/ITokenSelector.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a single token from a set of possibilities -/// -public interface ITokenSelector - : IDisposable -{ - /// - /// Select a single token - /// - /// - /// - /// - /// - int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens); - - /// - /// Reset the state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/Mirostat2Selection.cs b/LLama/Sampling/Selection/Mirostat2Selection.cs deleted file mode 100644 index cdc802c16..000000000 --- a/LLama/Sampling/Selection/Mirostat2Selection.cs +++ /dev/null @@ -1,65 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a token using Mirostat sampling. -/// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. -/// -public sealed class Mirostat2Selection - : ITokenSelector -{ - private float _mu; - - /// - /// Current value of Mu, updated based on the difference between target surprise and actual surprise - /// - public float Mu - { - get => _mu; - set => _mu = value; - } - - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. 
- /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// - public float Tau { get; set; } - - /// - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// - public float Eta { get; set; } - - /// - /// Create a new Mirostat 2.0 sampler - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. - /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - public Mirostat2Selection(float tau, float eta) - { - Tau = tau; - Eta = eta; - } - - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenMirostat2(ctx, Tau, Eta, ref _mu); - } - - /// - public void Reset() - { - _mu = 2 * Tau; - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/MirostatSelection.cs b/LLama/Sampling/Selection/MirostatSelection.cs deleted file mode 100644 index 5ec34a135..000000000 --- a/LLama/Sampling/Selection/MirostatSelection.cs +++ /dev/null @@ -1,76 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select a token using Mirostat sampling. -/// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. -/// -public sealed class MirostatSelection - : ITokenSelector -{ - private float _mu; - - /// - /// Current value of Mu, updated based on the difference between target surprise and actual surprise - /// - public float Mu - { - get => _mu; - set => _mu = value; - } - - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. - /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// - public float Tau { get; set; } - - /// - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// - public float Eta { get; set; } - - /// - /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn - /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects - /// the performance of the algorithm. - /// - public int M { get; set; } - - /// - /// Create a new Mirostat 2.0 sampler - /// - /// The target cross-entropy (or surprise) value you want to achieve for the generated text. 
- /// A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text - /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. - /// A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn - /// helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects - /// the performance of the algorithm. - public MirostatSelection(float tau, float eta, int m = 100) - { - Tau = tau; - Eta = eta; - M = m; - } - - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleTokenMirostat(ctx, Tau, Eta, M, ref _mu); - } - - /// - public void Reset() - { - _mu = 2 * Tau; - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Selection/StandardSelection.cs b/LLama/Sampling/Selection/StandardSelection.cs deleted file mode 100644 index 3e3bd0865..000000000 --- a/LLama/Sampling/Selection/StandardSelection.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Selection; - -/// -/// Select from all possible tokens according to their probability -/// -public sealed class StandardSelection - : ITokenSelector -{ - /// - public int Select(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan lastTokens) - { - return candidates.SampleToken(ctx); - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/GrammarSampling.cs b/LLama/Sampling/Tokens/GrammarSampling.cs deleted file mode 100644 index b823a7f92..000000000 --- a/LLama/Sampling/Tokens/GrammarSampling.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using LLama.Grammars; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Apply a grammar to prevent sampling tokens which do not match the grammar -/// -public sealed class GrammarSampling - : ITokenDataProcessor -{ - private SafeLLamaGrammarHandle? _handle; - - /// - /// Grammar to use for sampling - /// - public Grammar? 
Grammar { get; set; } - - /// - /// Create a new - /// - /// - public GrammarSampling(Grammar grammar) - { - Grammar = grammar; - } - - /// - public void Reset() - { - _handle?.Dispose(); - _handle = null; - } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - // Create a new grammar instance if necessary - _handle ??= Grammar?.CreateInstance(); - - // Apply it - if (_handle != null) - tokens.ApplyGrammar(ctx, _handle); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - _handle?.AcceptToken(ctx, token); - } - - /// - public void Dispose() - { - _handle?.Dispose(); - _handle = null; - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/ITokenDataProcessor.cs b/LLama/Sampling/Tokens/ITokenDataProcessor.cs deleted file mode 100644 index e6679cc29..000000000 --- a/LLama/Sampling/Tokens/ITokenDataProcessor.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -using llama_token = Int32; - -/// -/// Processes token logits before sampling, applying penalties to certain tokens -/// -public interface ITokenDataProcessor - : IDisposable -{ - /// - /// Process token logits in a LLamaTokenDataArray - /// - /// The context this is operating in - /// The token data array to process - /// The most recent tokens output - /// LLamaTokenDataArray, created from logits - void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens); - - /// - /// Inform this process when a token is accepted by the model - /// - /// - /// - void AcceptToken(SafeLLamaContextHandle ctx, int token); - - /// - /// Reset all internal sampling state - /// - void Reset(); -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs b/LLama/Sampling/Tokens/LocallyTypicalSampling.cs deleted file mode 100644 index 3f602c9a7..000000000 --- a/LLama/Sampling/Tokens/LocallyTypicalSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
-/// -public sealed class LocallyTypicalSampling - : ITokenDataProcessor -{ - /// - /// P value for locally typical sampling - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.LocallyTypical(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/MinPSampling.cs b/LLama/Sampling/Tokens/MinPSampling.cs deleted file mode 100644 index c3adf0262..000000000 --- a/LLama/Sampling/Tokens/MinPSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 -/// -public sealed class MinPSampling - : ITokenDataProcessor -{ - /// - /// All tokens with probability greater than this will be kept - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.MinP(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/RepetitionPenalty.cs b/LLama/Sampling/Tokens/RepetitionPenalty.cs deleted file mode 100644 index 3cfdbcd46..000000000 --- a/LLama/Sampling/Tokens/RepetitionPenalty.cs +++ /dev/null @@ -1,77 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. -/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. -/// -public sealed class RepetitionPenalty - : ITokenDataProcessor -{ - private float _alphaFreq; - private float _alphaPresence; - - /// - /// Repetition penalty, as described in https://arxiv.org/abs/1909.05858 - /// - public float RepeatPenalty { get; set; } = 1.1f; - - /// - /// Frequency penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
- /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text - /// so far, decreasing the model's likelihood to repeat the same line verbatim. - ///
- public float AlphaFrequency - { - get => _alphaFreq; - set - { - if (value < -2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); - if (value > 2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); - _alphaFreq = value; - } - } - - /// - /// Presence penalty as described by OpenAI: https://platform.openai.com/docs/api-reference/chat/create
- /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the - /// text so far, increasing the model's likelihood to talk about new topics. - ///
- public float AlphaPresence - { - get => _alphaPresence; - set - { - if (value < -2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be greater than -2"); - if (value > 2) - throw new ArgumentOutOfRangeException(nameof(value), "AlphaFrequency must be less than 2"); - _alphaPresence = value; - } - } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.RepetitionPenalty(ctx, lastTokens, RepeatPenalty, AlphaFrequency, AlphaPresence); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TailFreeSampling.cs b/LLama/Sampling/Tokens/TailFreeSampling.cs deleted file mode 100644 index 8e9fb2b51..000000000 --- a/LLama/Sampling/Tokens/TailFreeSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. -/// -public sealed class TailFreeSampling - : ITokenDataProcessor -{ - /// - /// Z value for tail free sampling - /// - public float Z { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TailFree(ctx, Z, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TemperatureSampling.cs b/LLama/Sampling/Tokens/TemperatureSampling.cs deleted file mode 100644 index 0186f275f..000000000 --- a/LLama/Sampling/Tokens/TemperatureSampling.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Sample with temperature. -/// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual -/// -public sealed class TemperatureSampling - : ITokenDataProcessor -{ - /// - /// Temperature value to apply - /// - public float Temperature { get; set; } = 0.5f; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.Temperature(ctx, Temperature); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopKSampling.cs b/LLama/Sampling/Tokens/TopKSampling.cs deleted file mode 100644 index 3f797c85f..000000000 --- a/LLama/Sampling/Tokens/TopKSampling.cs +++ /dev/null @@ -1,38 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Sample with TopK, removing all by the K most likely tokens. 
-/// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -/// -public sealed class TopKSampling - : ITokenDataProcessor -{ - /// - /// Number of tokens to keep - /// - public int Count { get; set; } - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TopK(ctx, Count); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file diff --git a/LLama/Sampling/Tokens/TopPSampling.cs b/LLama/Sampling/Tokens/TopPSampling.cs deleted file mode 100644 index 577ce3bc3..000000000 --- a/LLama/Sampling/Tokens/TopPSampling.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System; -using LLama.Native; - -namespace LLama.Sampling.Tokens; - -/// -/// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -/// -public sealed class TopPSampling - : ITokenDataProcessor -{ - /// - /// P valies for TopP - /// - public float P { get; set; } - - /// - /// Minimum number of tokens to keep - /// - public ulong MinKeep { get; set; } = 1; - - /// - public void ProcessTokens(SafeLLamaContextHandle ctx, LLamaTokenDataArray tokens, ReadOnlySpan lastTokens) - { - tokens.TopP(ctx, P, MinKeep); - } - - /// - public void AcceptToken(SafeLLamaContextHandle ctx, int token) - { - } - - /// - public void Reset() - { - } - - /// - public void Dispose() - { - } -} \ No newline at end of file
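
For reviewers who want to try the new pipeline, a minimal usage sketch follows. It is illustrative only and not part of the patch: the class name DefaultSamplingPipeline is assumed (the added members above do not show the enclosing class declaration), and only members visible in this section (Temperature, TopK, TopP, the ISamplingPipeline Sample method and Reset) are used.

    // Hypothetical usage sketch -- not part of this patch.
    // Assumes the pipeline class added above is named DefaultSamplingPipeline.
    using System;
    using LLama.Native;
    using LLama.Sampling;

    static class SamplingExample
    {
        public static int SampleNext(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<int> lastTokens)
        {
            var pipeline = new DefaultSamplingPipeline
            {
                Temperature = 0.75f,   // higher = more "creative"
                TopK = 40,             // keep the 40 most likely tokens
                TopP = 0.95f,          // nucleus sampling threshold
            };

            // Applies logit bias and repetition/frequency/presence penalties, runs the
            // llama.cpp candidate filters (top-k, tail-free, typical, top-p, min-p,
            // temperature) and samples one token id.
            return pipeline.Sample(ctx, logits, lastTokens);
        }
    }

When one pipeline instance is reused across unrelated generations, calling pipeline.Reset() first clears any internal sampling state.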