2 changes: 2 additions & 0 deletions LLama.Examples/Examples/QuantizeModel.cs
@@ -20,6 +20,8 @@ public static async Task Run()
{
Console.WriteLine("Quantization failed!");
}

await Task.CompletedTask;
}
}
}
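The added `await Task.CompletedTask;` keeps the `async` signature of a method whose body is otherwise synchronous; without any `await`, the compiler raises warning CS1998. A minimal sketch of the pattern (the `TryQuantize` helper is a hypothetical stand-in for the example's quantization call):

```csharp
using System;
using System.Threading.Tasks;

public static class QuantizeSketch
{
    // Hypothetical stand-in for the synchronous quantization call.
    private static bool TryQuantize() => false;

    public static async Task Run()
    {
        if (!TryQuantize())
        {
            Console.WriteLine("Quantization failed!");
        }

        // No-op await: completes synchronously but satisfies CS1998.
        await Task.CompletedTask;
    }
}
```

An alternative is to drop `async` and return `Task.CompletedTask` directly, which also avoids the state-machine allocation; the `PostProcess` changes further down take that route with `Task.FromResult`.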
2 changes: 1 addition & 1 deletion LLama/Batched/Conversation.cs
@@ -410,7 +410,7 @@ public void Remove(LLamaPos start, LLamaPos end)
}

/// <summary>
/// Removes <see cref="count"/> tokens starting from <see cref="start"/>
/// Removes <paramref name="count"/> tokens starting from <paramref name="start"/>
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
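The doc-comment fix above swaps `<see cref>` for `<paramref name>`: `cref` resolves against types and members, so pointing it at the parameters `count` and `start` produces broken or misleading links, while `paramref` is the element designed for parameters. Illustrative usage of both elements (the `count` overload's exact signature is not shown in the diff, so it is assumed here):

```csharp
/// <summary>
/// Removes <paramref name="count"/> tokens starting from <paramref name="start"/>.
/// See <see cref="Remove(LLamaPos, LLamaPos)"/> for the range-based overload.
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
public void Remove(LLamaPos start, int count) { /* ... */ }
```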
1 change: 0 additions & 1 deletion LLama/Common/FixedSizeQueue.cs
@@ -14,7 +14,6 @@ public class FixedSizeQueue<T>
private readonly T[] _buffer;
private int _start;
private int _count;
private T[]? _window;

// Minimum capacity for the temporary buffer used to expose a contiguous view.
private const int MinimumWindowSize = 4;
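The deleted `_window` field was a cached buffer for the queue's contiguous view; only the `MinimumWindowSize` constant survives, which suggests the view is now materialised where it is needed. A hedged sketch of on-demand materialisation, assuming the ring-buffer layout implied by `_buffer`, `_start`, and `_count` (the method name is illustrative):

```csharp
// Sketch only: copy the logically ordered contents out of the ring buffer,
// allocating per call instead of caching the result in a field.
private T[] MaterialiseWindow()
{
    var window = new T[Math.Max(_count, MinimumWindowSize)];
    var firstPart = Math.Min(_count, _buffer.Length - _start);
    Array.Copy(_buffer, _start, window, 0, firstPart);
    if (_count > firstPart) // the queue wraps around the end of the buffer
        Array.Copy(_buffer, 0, window, firstPart, _count - firstPart);
    return window;
}
```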
2 changes: 1 addition & 1 deletion LLama/LLamaExecutorBase.cs
@@ -317,7 +317,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};

AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? Array.Empty<string>());
AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? []);

await PreprocessInputs(text, args);

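Replacing `Array.Empty<string>()` with `[]` adopts C# 12 collection expressions; for an empty array target the compiler typically lowers `[]` to `Array.Empty<T>()` anyway, so the change is cosmetic rather than behavioural. A few target-typed forms for reference:

```csharp
using System.Collections.Generic;

string[] none = [];                // lowered to Array.Empty<string>()
IReadOnlyList<string> view = [];   // empty read-only view, no per-call allocation
List<int> items = [1, 2, 3];       // same as new List<int> { 1, 2, 3 }
int[] extended = [.. items, 4];    // spread: { 1, 2, 3, 4 }
```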
12 changes: 8 additions & 4 deletions LLama/LLamaInstructExecutor.cs
@@ -99,6 +99,7 @@ public override async Task SaveState(string filename)
await JsonSerializer.SerializeAsync(fs, state);
}
}

/// <inheritdoc />
public override async Task LoadState(string filename)
{
@@ -154,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
}

if (_pastTokensCount > 0 && args.WaitForInput)
{
return (true, new[] { "\n> " });
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, [ "\n> " ]));
}
}

@@ -180,7 +181,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return (false, Array.Empty<string>());
return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
}

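`PostProcess` loses its `async` modifier and wraps each synchronous result in `Task.FromResult`, eliminating the compiler-generated state machine for a method that never awaits. One nuance: in a non-async `Task`-returning method, an exception thrown in the body surfaces synchronously at the call site rather than faulting the returned task; that distinction does not matter here since callers await immediately. A sketch of the two shapes:

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;

class PostProcessShapes
{
    // Builds a state machine and triggers CS1998, since nothing is awaited.
    async Task<(bool, IReadOnlyList<string>)> ViaAsync()
        => (false, []);

    // Same awaitable result, no state machine.
    Task<(bool, IReadOnlyList<string>)> ViaFromResult()
        => Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
}
```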
/// <inheritdoc />
@@ -205,7 +206,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
_pastTokensCount = pastTokensCount;

if (result != DecodeResult.Ok)
{
throw new LLamaDecodeError(result);
}

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -250,6 +253,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In

return;
}

/// <summary>
/// The descriptor of the state of the instruct executor.
/// </summary>
53 changes: 31 additions & 22 deletions LLama/LLamaInteractExecutor.cs
@@ -21,7 +21,7 @@ namespace LLama
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;

// LLava
private int _EmbedImagePosition = -1;
private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
@@ -36,7 +36,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
}

/// <summary>
///
/// </summary>
@@ -46,7 +46,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
: base(context, clipModel, logger)
{
}
}

/// <inheritdoc />
public override ExecutorBaseState GetStateData()
@@ -67,6 +67,7 @@ public override ExecutorBaseState GetStateData()
};
return state;
}

/// <inheritdoc />
public override Task LoadState(ExecutorBaseState data)
{
@@ -88,23 +89,23 @@ public override Task LoadState(ExecutorBaseState data)

return Task.CompletedTask;
}

/// <inheritdoc />
public override async Task SaveState(string filename)
{
var state = (InteractiveExecutorState)GetStateData();
using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
await JsonSerializer.SerializeAsync(fs, state);
}
}

/// <inheritdoc />
public override async Task LoadState(string filename)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state!);
}
using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state!);
}

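`LoadState(string)` is rewritten from a braced `using` statement to a C# 8 using declaration, so `fs` is disposed when the enclosing method scope exits instead of at an explicit closing brace; the timing is effectively the same here, and one nesting level disappears. The two forms side by side, inside an async method (the `MyState` type and `path` variable are illustrative):

```csharp
// Statement form: disposal at the closing brace.
using (var fs = new FileStream(path, FileMode.Open, FileAccess.Read))
{
    var state = await JsonSerializer.DeserializeAsync<MyState>(fs);
}

// Declaration form: disposal when the enclosing scope exits.
using var fs2 = new FileStream(path, FileMode.Open, FileAccess.Read);
var state2 = await JsonSerializer.DeserializeAsync<MyState>(fs2);
```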
/// <summary>
Expand All @@ -122,7 +123,11 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, we should specially process it.
if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
if (text == null)
{
throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
}

if (!IsMultiModal)
{
_embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -159,8 +164,8 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}

/// <inheritdoc />
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
{
private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
{
// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt && IsMultiModal)
@@ -191,7 +196,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
args.RemainedTokens -= line_inp.Length;
args.RemainedTokens -= line_inp.Length;
}
}
return Task.CompletedTask;
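The `<image>` flow implied by this hunk: detect the tag, tokenize the text on either side of it, and record `_EmbedImagePosition` so the image embeddings can be injected between the two token runs at decode time. A hedged reconstruction of the split, inferred from the fields used rather than the file's exact code:

```csharp
// Sketch: split the prompt on the <image> tag and remember where the
// image embeddings must be inserted during decoding.
const string imageTag = "<image>";
var tagIndex = text.IndexOf(imageTag, StringComparison.Ordinal);

var before = text.Substring(0, tagIndex);
var after = text.Substring(tagIndex + imageTag.Length);

_embed_inps = Context.Tokenize(before, addBos, true).ToList();
_EmbedImagePosition = _embed_inps.Count;   // image embeds land at this offset
_embed_inps.AddRange(Context.Tokenize(after, false, true));
```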
@@ -203,20 +208,24 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
{
args.WaitForInput = true;
}

if (_pastTokensCount > 0 && args.WaitForInput)
return (true, Array.Empty<string>());
{
return Task.FromResult((true, (IReadOnlyList<string>)[]));
}
}

if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
{
return (true, Array.Empty<string>());
return Task.FromResult((true, (IReadOnlyList<string>)[]));
}

if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -225,7 +234,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
args.WaitForInput = true;
}

return (false, Array.Empty<string>());
return Task.FromResult((false, (IReadOnlyList<string>)[]));
}

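Across both executors, the `PostProcess` tuple reads as: the `bool` tells the inference loop to break (to wait for user input or because generation finished), and the string list is extra text to surface, such as the `"\n> "` prompt in the instruct executor; the fall-through return is `false` so generation continues. A hedged sketch of the consuming side, assuming the base class drives it roughly like this:

```csharp
// Assumed consumption inside the base InferAsync loop (not the exact code):
var (shouldBreak, extraOutputs) = await PostProcess(inferenceParams, args);
foreach (var output in extraOutputs)
    yield return output;   // surface e.g. "\n> " to the caller
if (shouldBreak)
    break;                 // pause until the next user input
```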
/// <inheritdoc />
@@ -258,18 +267,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
// Changes to support Multi-Modal LLMs.
//
(DecodeResult, int, int) header, end, result;
if (IsMultiModal && _EmbedImagePosition > 0)
if (IsMultiModal && _EmbedImagePosition > 0)
{
// Tokens previous to the images
header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = header.Item3;

if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

// Images
foreach( var image in _imageEmbedHandles )
foreach (var image in _imageEmbedHandles)
ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);

// Post-image Tokens
end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = end.Item3;
@@ -285,7 +294,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In

if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
}


if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
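The multimodal branch of `InferInternal` decodes in three phases, threading `_pastTokensCount` (the KV-cache position) through each: the tokens before the image marker, then the image embeddings, then the remainder. Condensed from the hunk above with comments added:

```csharp
// 1. Decode the tokens that precede the image marker.
header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition),
                                   LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = header.Item3;   // advance the cache position
if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

// 2. Evaluate each image embedding at the current position.
foreach (var image in _imageEmbedHandles)
    ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);

// 3. Decode the tokens that follow the image.
end = await Context.DecodeAsync(
    _embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition),
    LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = end.Item3;
```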
2 changes: 2 additions & 0 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -436,6 +436,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern uint llama_model_n_cls_out(SafeLlamaModelHandle model);

/// <summary>
@@ -444,6 +445,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
/// <param name="model"></param>
/// <param name="i"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern string? llama_model_cls_label(SafeLlamaModelHandle model, uint i);
#endregion
