59 changes: 59 additions & 0 deletions LLama.Examples/Examples/CustomGenerationControl.cs
@@ -0,0 +1,59 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Control;
using LLama.Examples.Extensions;

namespace LLama.Examples.Examples
{
public class CustomGenerationControl
{
public class NumberGenerationControl: IGenerationControl
{
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
{
return lastOutputText.Any(x => char.IsDigit(x) && (x == '4' || x == '5'));
}

public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
{
return false;
}
}
public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath)
{
ContextSize = 1024,
Seed = 1337,
GpuLayerCount = 5
};
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("This is an example to show how to customize the generation control of the executors. Here we implement a control mode in which" +
" the generation will stop once there's a number 4 or 5 is generated. Please try different questions to lead the model to generate answers with and without numbers." +
" No anti-prompt is used in this example.");
Console.ForegroundColor = ConsoleColor.White;

var inferenceParams = new InferenceParams() { Temperature = 0.6f, MaxTokens = 60, GenerationControl = new NumberGenerationControl() };

while (true)
{
Console.Write("\nQuestion: ");
Console.ForegroundColor = ConsoleColor.Green;
var prompt = Console.ReadLine();
Console.ForegroundColor = ConsoleColor.White;
Console.Write("Answer: ");
prompt = $"Question: {prompt?.Trim()} Answer: ";
await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner())
{
Console.Write(text);
}
}
}
}
}
1 change: 1 addition & 0 deletions LLama.Examples/Examples/Runner.cs
@@ -13,6 +13,7 @@ public class Runner
{ "Interactive mode chat by using executor.", InteractiveModeExecute.Run },
{ "Instruct mode chat by using executor.", InstructModeExecute.Run },
{ "Stateless mode chat by using executor.", StatelessModeExecute.Run },
{ "Customize the generation control of executor.", CustomGenerationControl.Run },
{ "Load and save chat session.", SaveAndLoadSession.Run },
{ "Load and save state of model and executor.", LoadAndSaveState.Run },
{ "Get embeddings from LLama model.", () => Task.Run(GetEmbeddings.Run) },
1 change: 1 addition & 0 deletions LLama.Unittest/TokenTests.cs
@@ -1,5 +1,6 @@
using System.Text;
using LLama.Common;
using LLama.Control;
using LLama.Extensions;

namespace LLama.Unittest;
4 changes: 4 additions & 0 deletions LLama.Web/Common/InferenceOptions.cs
@@ -4,6 +4,7 @@
using LLama.Abstractions;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;

namespace LLama.Web.Common
{
@@ -71,5 +72,8 @@ public class InferenceOptions

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }

/// <inheritdoc />
public IGenerationControl? GenerationControl { get; set; }
}
}
8 changes: 7 additions & 1 deletion LLama/Abstractions/IInferenceParams.cs
@@ -1,5 +1,6 @@
using System.Collections.Generic;
using LLama.Common;
using LLama.Control;
using LLama.Native;
using LLama.Sampling;

@@ -114,5 +115,10 @@ public interface IInferenceParams
/// Set a custom sampling pipeline to use. <b>If this is set, all other sampling parameters are ignored!</b>
/// </summary>
ISamplingPipeline? SamplingPipeline { get; set; }
}

/// <summary>
/// Set a custom generation control to use. <b>If this is set, antiprompts will be ignored!</b>
/// </summary>
IGenerationControl? GenerationControl { get; set; }
}
}
4 changes: 4 additions & 0 deletions LLama/Common/InferenceParams.cs
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;

namespace LLama.Common
{
@@ -80,6 +81,9 @@ public record InferenceParams

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }

/// <inheritdoc />
public IGenerationControl? GenerationControl { get; set; }
}

/// <summary>
@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;

namespace LLama
namespace LLama.Control
{
/// <summary>
/// AntipromptProcessor keeps track of past tokens looking for any set Anti-Prompts
42 changes: 42 additions & 0 deletions LLama/Control/DefaultGenerationControl.cs
@@ -0,0 +1,42 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// The default generation control in LLamaSharp, using antiprompts. This class should not be inherited.
/// <b>Note that this class is stateful: the previous outputs fed to it affect its decisions.</b>
/// If you use it in a session, please don't reuse it for another session unless you intend to do so.
/// </summary>
public sealed class DefaultGenerationControl : IGenerationControl
{
private AntipromptProcessor _antipromptProcessor;

/// <summary>
/// <inheritdoc/>
/// </summary>
public DefaultGenerationControl()
{
_antipromptProcessor = new AntipromptProcessor();
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText)
{
_antipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts);
return _antipromptProcessor.Add(lastOutputText);
}

/// <summary>
/// <inheritdoc/>
/// </summary>
public bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId)
{
return context.IsEOS(lastOutputId);
}
}
}
35 changes: 35 additions & 0 deletions LLama/Control/IGenerationControl.cs
@@ -0,0 +1,35 @@
using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Control
{
/// <summary>
/// Control the text generation of LLama Executors.
/// </summary>
public interface IGenerationControl
{
/// <summary>
/// Use the last output text to determine if the generation should stop.
/// This method is called after the overload that takes an output id.
/// Even if this returns true, the text is still returned; generation stops afterwards.
/// </summary>
/// <param name="context">The LLamaContext used in the current generation.</param>
/// <param name="inferenceParams">The inference params used in the current generation.</param>
/// <param name="lastOutputText">The last output text generated by the model.</param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, string lastOutputText);
Member

Just a thought: is it possible to decouple this interface from IInferenceParams somehow? It feels a bit odd that GenerationControl overrides parts of IInferenceParams and is also configured by it.

The obvious way to do that would be to remove the inferenceParams parameter here, though that would introduce problems in DefaultGenerationControl.

Collaborator (author)

Yeah, I also thought of that, and I could come up with the following alternatives:

  1. Change the API of ILLamaExecutor to InferAsync(string text, IInferenceParams? inferenceParams = null, IGenerationControl? control = null, CancellationToken token = default); or a similar shape. It will introduce some breaking changes, though.
  2. Use a generic type: change ILLamaExecutor to ILLamaExecutor<TControl>, or change the InferAsync method.
  3. Use chained calls. For example, from the user's point of view, they could write var executor = new StatelessExecutor().WithControl().

Each of them has advantages and disadvantages. The most important question, I think, is whether we should make it class-level or method-level. I can't think of a case where a user must use different control strategies in different calls to the same executor, though that flexibility does exist. As a compromise, I'd like to suggest the following proposal:

  1. Use option 3 above, the chained call for ILLamaExecutor, with WithGenerationControl as an extension method (a rough sketch follows below).
  2. Define a static method to execute the generation, just like what I did in [WIP] refactor: init some experimental refactoring. #362. In this static method we could introduce any breaking change, like option 1 above. Since .NET 7 we can define static methods in an interface, but older versions don't support that, so with this approach some parts of the design may look a little awkward.

BTW, there's one related thing: the sampling pipeline. Currently InferenceParams contains many sampling parameters, while specifying SamplingPipeline overrides them entirely. That's fine on the master branch because we don't want to introduce substantial breaking changes now, but I'm wondering if we should refactor it before v1.0.0.
For example, we could add many kinds of sampling pipelines as we discussed here, along with a class named CustomSamplingPipeline, which accepts all the sampling-related parameters currently in InferenceParams.
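For illustration, a minimal sketch of what the chained-call shape could look like. The IControllableExecutor interface and the WithGenerationControl method are hypothetical names used only for this discussion; neither exists in this PR.

using LLama.Control;

namespace LLama.Extensions
{
    // Hypothetical interface for this sketch only: an executor whose generation control
    // can be swapped after construction. Not part of this PR.
    public interface IControllableExecutor
    {
        IGenerationControl GenerationControl { get; set; }
    }

    /// <summary>
    /// Sketch of the chained-call configuration (option 3 / proposal 1 above).
    /// </summary>
    public static class ExecutorControlExtensions
    {
        // Sets the control and returns the executor so calls can be chained, e.g.
        // var executor = new StatelessExecutor(model, parameters).WithGenerationControl(new DefaultGenerationControl());
        // (assuming StatelessExecutor implemented IControllableExecutor).
        public static TExecutor WithGenerationControl<TExecutor>(this TExecutor executor, IGenerationControl control)
            where TExecutor : IControllableExecutor
        {
            executor.GenerationControl = control;
            return executor;
        }
    }
}

This keeps the control at class level, matching the compromise proposed above.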

Member

I was thinking more along the lines of removing the IInferenceParams inferenceParams from the ShouldStopGeneration methods.
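For concreteness, a decoupled version might look roughly like the sketch below. It is not part of this PR; DefaultGenerationControl would then need to receive its antiprompts another way, for example through its constructor.

namespace LLama.Control
{
    // Sketch of the decoupled interface suggested above: the stop decision no longer
    // depends on IInferenceParams at all.
    public interface IGenerationControl
    {
        bool ShouldStopGeneration(LLamaContext context, string lastOutputText);

        bool ShouldStopGeneration(LLamaContext context, int lastOutputId);
    }
}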

Member

> That's fine on the master branch because we don't want to introduce substantial breaking changes now, but I'm wondering if we should refactor it before v1.0.0.

Yeah this is my thinking. It feels like we're moving in a direction where we can completely get rid of the inference params in the future (which is part of why I want to avoid depending on it in this new interface, if possible).


/// <summary>
/// Use the last output token to determine if the generation should stop.
/// This method is called before the overload that takes output text.
/// Even if this returns true, the token is still returned; generation stops afterwards.
/// </summary>
/// <param name="context">The LLamaContext used in the current generation.</param>
/// <param name="inferenceParams">The inference params used in the current generation.</param>
/// <param name="lastOutputId">The last output token generated by the model.</param>
/// <returns></returns>
bool ShouldStopGeneration(LLamaContext context, IInferenceParams inferenceParams, int lastOutputId);
}
}
27 changes: 27 additions & 0 deletions LLama/Extensions/GenerationControlExtensions.cs
@@ -0,0 +1,27 @@
using LLama.Abstractions;
using LLama.Control;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LLama.Extensions
{
/// <summary>
/// Extension methods for generation control
/// </summary>
public static class GenerationControlExtensions
{
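/// <summary>
/// Check a sequence of output token ids against the control, stopping at the first id for which the control requests a stop.
/// </summary>
/// <param name="control">The generation control to query.</param>
/// <param name="context">The LLamaContext used in the current generation.</param>
/// <param name="inferenceParams">The inference params used in the current generation.</param>
/// <param name="lastOutputIds">The last output token ids generated by the model.</param>
/// <returns>True if generation should stop.</returns>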
public static bool ShouldStopGeneration(this IGenerationControl control, LLamaContext context, IInferenceParams inferenceParams, IEnumerable<int> lastOutputIds)
{
foreach (var id in lastOutputIds)
{
if (control.ShouldStopGeneration(context, inferenceParams, id))
{
return true;
}
}
return false;
}
}
}
9 changes: 9 additions & 0 deletions LLama/LLamaContext.cs
@@ -87,6 +87,15 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger
}

/// <summary>
/// Return whether a token marks the end of the sequence (EOS).
/// </summary>
/// <param name="token">The token to check.</param>
/// <returns>True if the token is the EOS token.</returns>
public bool IsEOS(int token)
{
return NativeApi.llama_token_eos(this.NativeHandle.ModelHandle) == token;
}

/// <summary>
/// Set the seed for the RNG
/// </summary>
/// <param name="seed"></param>
9 changes: 8 additions & 1 deletion LLama/LLamaExecutorBase.cs
@@ -60,6 +60,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// The last tokens generated by the model.
/// </summary>
protected FixedSizeQueue<llama_token> _last_n_tokens;

/// <summary>
/// The last output text generated by the model.
/// </summary>
protected string _lastOutputText = string.Empty;

/// <summary>
/// The context used by the executor.
/// </summary>
@@ -299,9 +304,11 @@ public virtual async IAsyncEnumerable<string> InferAsync(string text, IInference
if (args.ReturnValue)
{
_decoder.AddRange(_embeds);
yield return _decoder.Read();
_lastOutputText = _decoder.Read();
yield return _lastOutputText;
}

// TODO(Rinne): Refactor the logic here.
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
{
14 changes: 13 additions & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -10,6 +10,7 @@
using System.Threading.Tasks;
using LLama.Extensions;
using Microsoft.Extensions.Logging;
using LLama.Control;

namespace LLama
{
@@ -21,6 +22,7 @@ public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
private readonly llama_token _llama_token_newline;
private IGenerationControl _control;

/// <summary>
///
@@ -31,6 +33,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
_llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
_control = new DefaultGenerationControl();
}

/// <inheritdoc />
@@ -134,8 +137,17 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
var control = inferenceParams.GenerationControl ?? _control;
// Get stop signal by ids
if (control.ShouldStopGeneration(Context, inferenceParams, _embeds))
{
args.WaitForInput = true;
}
// Get stop signal by text
else if (control.ShouldStopGeneration(Context, inferenceParams, _lastOutputText))
{
args.WaitForInput = true;
}

if (_pastTokensCount > 0 && args.WaitForInput)
return (true, Array.Empty<string>());
16 changes: 11 additions & 5 deletions LLama/LLamaStatelessExecutor.cs
@@ -8,6 +8,7 @@
using System.Threading.Tasks;
using LLama.Native;
using LLama.Sampling;
using LLama.Control;
using Microsoft.Extensions.Logging;

namespace LLama
@@ -66,8 +67,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");

// Create decoders for the token stream
IGenerationControl control = inferenceParams.GenerationControl ?? new DefaultGenerationControl();
var decoder = new StreamingTokenDecoder(Context);
var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);

// Keep track of the last N tokens emitted
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
@@ -113,12 +114,17 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
var decoded = decoder.Read();
yield return decoded;

// Check if any of the antiprompts have been generated
if (antiprocessor.Add(decoded))
break;

lastTokens.Add(id);
tokens.Clear();

// Check if we should stop generation based on the token id
if (control.ShouldStopGeneration(Context, inferenceParams, id))
break;
// Check if we should stop generation based on the decoded text
if (control.ShouldStopGeneration(Context, inferenceParams, decoded))
break;

// prepare for the next loop
tokens.Add(id);

// when we run out of context