17 changes: 17 additions & 0 deletions LLama/LLamaContext.cs
@@ -564,6 +564,23 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
     {
         return Task.Run(() => Decode(batch), cancellationToken);
     }
+
+    /// <summary>Asynchronously decode a list of tokens for a single sequence,
+    /// submitting them to the model through the given batch.</summary>
+    /// <param name="tokens">The tokens to decode.</param>
+    /// <param name="id">The sequence the tokens belong to.</param>
+    /// <param name="batch">The batch used to submit the tokens to the model.</param>
+    /// <param name="n_past">The number of tokens already decoded for this sequence.</param>
+    /// <returns>A tuple, containing the decode result, the number of tokens that have <b>not</b> been decoded yet and the total number of tokens that have been decoded.</returns>
+    public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+    {
+        return Task.Run(() =>
+        {
+            var past = n_past;
+            var res = NativeHandle.Decode(tokens, id, batch, ref past);
+            return (res.Item1, res.Item2, past);
+        });
+    }
     #endregion

     /// <inheritdoc />
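This overload exists because a `ref` parameter cannot be captured by the lambda passed to `Task.Run` (nor carried across an `await`), so the synchronous `NativeHandle.Decode(..., ref n_past)` cannot simply be wrapped. The token count goes in as a plain argument and comes back as the third tuple element. A minimal caller-side sketch — `context` and `tokens` are assumed to exist in the surrounding code, and the variable names are hypothetical:

```csharp
// Sketch: threading the past-token count through the new async overload.
var batch = new LLamaBatch();
var nPast = 0;

var (result, notDecoded, decoded) = await context.DecodeAsync(tokens, LLamaSeqId.Zero, batch, nPast);
nPast = decoded; // write-back replaces the old 'ref' semantics

if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);
```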
8 changes: 5 additions & 3 deletions LLama/LLamaInstructExecutor.cs
@@ -177,7 +177,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
     }

     /// <inheritdoc />
-    protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+    protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
     {
         var batch = new LLamaBatch();

@@ -194,7 +194,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta

         TryReuseMatchingPrefix();

-        var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+        var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+        _pastTokensCount = pastTokensCount;
+
         if (result != DecodeResult.Ok)
             throw new LLamaDecodeError(result);

@@ -259,7 +261,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
             }
         }

-        return Task.CompletedTask;
+        return;
     }
     /// <summary>
     /// The descriptor of the state of the instruct executor.
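With `InferInternal` marked `async` and the decode awaited, the native call runs on the thread pool instead of blocking the thread that iterates the results. A hedged consumption sketch, assuming the usual `InferAsync`/`IAsyncEnumerable<string>` surface of LLamaSharp executors:

```csharp
// Sketch: the awaiting thread (e.g. a UI thread) stays free while tokens are decoded.
var executor = new InstructExecutor(context);
await foreach (var text in executor.InferAsync(prompt, inferenceParams))
    Console.Write(text);
```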
17 changes: 11 additions & 6 deletions LLama/LLamaInteractExecutor.cs
@@ -222,7 +222,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
     }

     /// <inheritdoc />
-    protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+    protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
     {
         var batch = new LLamaBatch();

@@ -250,27 +250,32 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta

         // Changes to support Multi-Modal LLMs.
         //
-        (DecodeResult, int) header, end, result;
+        (DecodeResult, int, int) header, end, result;
         if (IsMultiModal && _EmbedImagePosition > 0)
         {
             // Tokens previous to the images
-            header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+            header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+            _pastTokensCount = header.Item3;
+
             if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

             // Images
             foreach( var image in _imageEmbedHandles )
                 ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

             // Post-image Tokens
-            end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+            end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+            _pastTokensCount = end.Item3;

             _EmbedImagePosition = -1;
             _imageEmbedHandles.Clear();
             Images.Clear();
         }
         else
         {
-            result = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+            result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+            _pastTokensCount = result.Item3;
+
             if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
         }

@@ -346,7 +351,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
             }
         }

-        return Task.CompletedTask;
+        return;
     }

     /// <summary>
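For the multimodal path, the decode is deliberately split into three phases so the image embedding lands at the correct position in the context: text before the image placeholder, the image embeds themselves, then the remaining text. A condensed sketch of that flow (illustrative only; it reuses the fields from the diff and omits the surrounding state management):

```csharp
// 1) Decode the text tokens that precede the image placeholder.
var header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = header.Item3;

// 2) Evaluate the image embeds; EvalImageEmbed is synchronous and still advances the count by ref.
foreach (var image in _imageEmbedHandles)
    ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

// 3) Decode the text tokens that follow the image.
var end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = end.Item3;
```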
4 changes: 3 additions & 1 deletion LLama/LLamaStatelessExecutor.cs
@@ -96,7 +96,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams

         // Evaluate the prompt, in chunks smaller than the max batch size
         var n_past = 0;
-        var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
+        var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
+        n_past = past;
+
         if (r != DecodeResult.Ok)
             throw new LLamaDecodeError(r);