From c4c4b0c5f642eea71fb970712582a768cc3abe74 Mon Sep 17 00:00:00 2001
From: Aleksei Smirnov
Date: Mon, 8 Jul 2024 16:43:28 +0300
Subject: [PATCH 1/3] Use DecodeAsync method in LLamaExecutors

---
 LLama/LLamaContext.cs           | 29 +++++++++++++++++++++++++++++
 LLama/LLamaInstructExecutor.cs  |  8 +++++---
 LLama/LLamaInteractExecutor.cs  | 17 +++++++++++------
 LLama/LLamaStatelessExecutor.cs |  4 +++-
 4 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ca2e45b89..6feeeaf36 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -564,6 +564,35 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
+
+        /// <summary>
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="id"></param>
+        /// <param name="batch"></param>
+        /// <param name="n_past"></param>
+        /// <returns></returns>
+        public (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
+        {
+            return NativeHandle.Decode(tokens, id, batch, ref n_past);
+        }
+
+        /// <summary>
+        /// </summary>
+        /// <param name="tokens"></param>
+        /// <param name="id"></param>
+        /// <param name="batch"></param>
+        /// <param name="n_past"></param>
+        /// <returns></returns>
+        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+        {
+            return Task.Run(() =>
+            {
+                var past = n_past;
+                var res = NativeHandle.Decode(tokens, id, batch, ref past);
+                return (res.Item1, res.Item2, past);
+            });
+        }
         #endregion
 
         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d6e24530f..d8a5c530d 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -177,7 +177,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             var batch = new LLamaBatch();
 
@@ -194,7 +194,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
 
                 TryReuseMatchingPrefix();
 
-                var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+                _pastTokensCount = pastTokensCount;
+
                 if (result != DecodeResult.Ok)
                     throw new LLamaDecodeError(result);
 
@@ -259,7 +261,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
             }
 
-            return Task.CompletedTask;
+            return;
        }

        /// <summary>
        /// The descriptor of the state of the instruct executor.
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index c4893e5b9..f4a4ca965 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -222,7 +222,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
         }
 
         /// <inheritdoc />
-        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            var batch = new LLamaBatch();

@@ -250,11 +250,13 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
 
                 // Changes to support Multi-Modal LLMs.
                 //
-                (DecodeResult, int) header, end, result;
+                (DecodeResult, int, int) header, end, result;
                 if (IsMultiModal && _EmbedImagePosition > 0)
                 {
                     // Tokens previous to the images
-                    header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = header.Item3;
+
                     if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
 
                     // Images
@@ -262,7 +264,8 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                         ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);
 
                     // Post-image Tokens
-                    end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = end.Item3;
 
                     _EmbedImagePosition = -1;
                     _imageEmbedHandles.Clear();
@@ -270,7 +273,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
                 else
                 {
-                    result = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = result.Item3;
+
                     if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
                 }
 
@@ -346,7 +351,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
            }

-            return Task.CompletedTask;
+            return;
        }

        /// <summary>
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 433d9cd16..ca868b77d 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -96,7 +96,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
 
             // Evaluate the prompt, in chunks smaller than the max batch size
             var n_past = 0;
-            var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
+            var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
+            n_past = past;
+
             if (r != DecodeResult.Ok)
                 throw new LLamaDecodeError(r);
 

From 6e45ca0af8462fd179fa2f8089fb1ef88e56eae5 Mon Sep 17 00:00:00 2001
From: Aleksei Smirnov
Date: Mon, 8 Jul 2024 17:23:57 +0300
Subject: [PATCH 2/3] Add doc comments explaining what the returned tuple elements mean.

---
 LLama/LLamaContext.cs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 6feeeaf36..35a023ec1 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -571,7 +571,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         /// <param name="id"></param>
         /// <param name="batch"></param>
         /// <param name="n_past"></param>
-        /// <returns></returns>
+        /// <returns>A tuple containing the decode result and the number of tokens that have not been decoded yet.</returns>
         public (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
         {
             return NativeHandle.Decode(tokens, id, batch, ref n_past);
@@ -583,7 +583,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         /// <param name="id"></param>
         /// <param name="batch"></param>
         /// <param name="n_past"></param>
-        /// <returns></returns>
+        /// <returns>A tuple containing the decode result, the number of tokens that have not been decoded yet, and the total number of tokens that have been decoded.</returns>
         public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
         {
             return Task.Run(() =>

From 7ffd5b19a3adf71f5ee27ab00a4df5e31a07c27f Mon Sep 17 00:00:00 2001
From: Aleksei Smirnov
Date: Mon, 8 Jul 2024 17:25:58 +0300
Subject: [PATCH 3/3] Remove unused sync API method, as the n_past style of calling Decode is legacy

---
 LLama/LLamaContext.cs | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 35a023ec1..93c3e74ab 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -565,18 +565,6 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
             return Task.Run(() => Decode(batch), cancellationToken);
         }
 
-        /// <summary>
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="id"></param>
-        /// <param name="batch"></param>
-        /// <param name="n_past"></param>
-        /// <returns>A tuple containing the decode result and the number of tokens that have not been decoded yet.</returns>
-        public (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
-        {
-            return NativeHandle.Decode(tokens, id, batch, ref n_past);
-        }
-
         /// <summary>
         /// </summary>
         /// <param name="tokens"></param>
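
Reviewer note (not part of the patches): a minimal sketch of how calling code
migrates from the removed ref-based Decode to the new tuple-returning
DecodeAsync. The wrapper class, the EvaluateAsync helper name, and the using
directives are assumptions based on LLamaSharp's namespace layout; only
Context.DecodeAsync, LLamaBatch, LLamaSeqId, DecodeResult, and LLamaDecodeError
come from the diffs above.

    using System.Collections.Generic;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Exceptions;
    using LLama.Native;

    public static class DecodeAsyncUsageSketch
    {
        // Pushes `tokens` through the context off the calling thread, then
        // folds the updated past-token count back into the caller's counter,
        // mirroring what the executors in PATCH 1/3 now do.
        public static async Task<int> EvaluateAsync(LLamaContext context, List<LLamaToken> tokens, int pastTokensCount)
        {
            var batch = new LLamaBatch();

            // Tuple elements, per the doc comments added in PATCH 2/3:
            // Item1 = decode result, Item2 = tokens not yet decoded,
            // Item3 = updated past-token count (replaces `ref int n_past`).
            var (result, _, updatedPast) = await context.DecodeAsync(tokens, LLamaSeqId.Zero, batch, pastTokensCount);

            if (result != DecodeResult.Ok)
                throw new LLamaDecodeError(result);

            return updatedPast;
        }
    }

Returning the updated count in the tuple rather than through a ref parameter
is what makes the async version possible at all: C# forbids ref parameters in
async methods, so the n_past-by-value signature is a prerequisite for awaiting
the Task.Run inside DecodeAsync.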