From e89a936434c72c46390619bdbfd47f126a25a5cb Mon Sep 17 00:00:00 2001
From: Stephen Toub
Date: Fri, 16 May 2025 09:35:32 -0400
Subject: [PATCH] Update to stable Microsoft.Extensions.AI.Abstractions

Also take advantage of some more recently added capabilities.
---
 LLama/Extensions/LLamaExecutorExtensions.cs | 59 +++++++++++----------
 LLama/LLamaSharp.csproj                     |  2 +-
 2 files changed, 32 insertions(+), 29 deletions(-)

diff --git a/LLama/Extensions/LLamaExecutorExtensions.cs b/LLama/Extensions/LLamaExecutorExtensions.cs
index 26ed9710c..61729ce16 100644
--- a/LLama/Extensions/LLamaExecutorExtensions.cs
+++ b/LLama/Extensions/LLamaExecutorExtensions.cs
@@ -71,7 +71,12 @@ public async Task<ChatResponse> GetResponseAsync(
             text.Append(token);
         }
 
-        return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
+        var message = new ChatMessage(ChatRole.Assistant, text.ToString())
+        {
+            MessageId = Guid.NewGuid().ToString("N"),
+        };
+
+        return new(message)
         {
             CreatedAt = DateTime.UtcNow,
         };
@@ -83,11 +88,13 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
     {
         var result = _executor.InferAsync(CreatePrompt(messages), CreateInferenceParams(options), cancellationToken);
 
+        string messageId = Guid.NewGuid().ToString("N");
         await foreach (var token in _outputTransform.TransformAsync(result))
         {
             yield return new(ChatRole.Assistant, token)
             {
                 CreatedAt = DateTime.UtcNow,
+                MessageId = messageId,
             };
         }
     }
@@ -124,37 +131,33 @@ private string CreatePrompt(IEnumerable<ChatMessage> messages)
     }
 
     /// <summary>Convert the chat options to inference parameters.</summary>
-    private static InferenceParams? CreateInferenceParams(ChatOptions? options)
+    private InferenceParams CreateInferenceParams(ChatOptions? options)
     {
-        List<string> antiPrompts = new(s_antiPrompts);
-        if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
-        {
-            antiPrompts.AddRange(anti);
-        }
+        InferenceParams ip = options?.RawRepresentationFactory?.Invoke(this) as InferenceParams ?? new();
 
-        return new()
+        ip.AntiPrompts = [.. s_antiPrompts, .. ip.AntiPrompts];
+        ip.MaxTokens = options?.MaxOutputTokens ?? 256; // arbitrary upper limit
+        ip.SamplingPipeline = new DefaultSamplingPipeline()
         {
-            AntiPrompts = antiPrompts,
-            TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
-            MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
-            SamplingPipeline = new DefaultSamplingPipeline()
-            {
-                FrequencyPenalty = options?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty,
-                PresencePenalty = options?.PresencePenalty ?? s_defaultPipeline.PresencePenalty,
-                PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
-                PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
-                RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
-                PenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.PenaltyCount,
-                Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
-                MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
-                MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
-                Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(),
-                Temperature = options?.Temperature ?? s_defaultPipeline.Temperature,
-                TopP = options?.TopP ?? s_defaultPipeline.TopP,
-                TopK = options?.TopK ?? s_defaultPipeline.TopK,
-                TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
-            },
+            FrequencyPenalty = options?.FrequencyPenalty ?? (ip.SamplingPipeline as DefaultSamplingPipeline)?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty,
+            PresencePenalty = options?.PresencePenalty ?? (ip.SamplingPipeline as DefaultSamplingPipeline)?.PresencePenalty ?? s_defaultPipeline.PresencePenalty,
+            PreventEOS = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PreventEOS ?? s_defaultPipeline.PreventEOS,
+            PenalizeNewline = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PenalizeNewline ?? s_defaultPipeline.PenalizeNewline,
+            RepeatPenalty = (ip.SamplingPipeline as DefaultSamplingPipeline)?.RepeatPenalty ?? s_defaultPipeline.RepeatPenalty,
+            PenaltyCount = (ip.SamplingPipeline as DefaultSamplingPipeline)?.PenaltyCount ?? s_defaultPipeline.PenaltyCount,
+            Grammar = (ip.SamplingPipeline as DefaultSamplingPipeline)?.Grammar ?? s_defaultPipeline.Grammar,
+            GrammarOptimization = (ip.SamplingPipeline as DefaultSamplingPipeline)?.GrammarOptimization ?? s_defaultPipeline.GrammarOptimization,
+            LogitBias = (ip.SamplingPipeline as DefaultSamplingPipeline)?.LogitBias ?? s_defaultPipeline.LogitBias,
+            MinKeep = (ip.SamplingPipeline as DefaultSamplingPipeline)?.MinKeep ?? s_defaultPipeline.MinKeep,
+            MinP = (ip.SamplingPipeline as DefaultSamplingPipeline)?.MinP ?? s_defaultPipeline.MinP,
+            Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(),
+            Temperature = options?.Temperature ?? s_defaultPipeline.Temperature,
+            TopP = options?.TopP ?? s_defaultPipeline.TopP,
+            TopK = options?.TopK ?? s_defaultPipeline.TopK,
+            TypicalP = (ip.SamplingPipeline as DefaultSamplingPipeline)?.TypicalP ?? s_defaultPipeline.TypicalP,
        };
+
+        return ip;
     }
 
     /// <summary>A default transform that appends "Assistant: " to the end.</summary>
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index f400640d0..10476a121 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -51,7 +51,7 @@
-    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="…" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="…" />
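
Reviewer note: the substantive change in CreateInferenceParams is that per-request LLama-specific settings now arrive as a real InferenceParams object via ChatOptions.RawRepresentationFactory, replacing the old string-keyed AdditionalProperties lookups. The sketch below shows how a caller might exercise that path. It is illustrative only and not part of the patch: the model path is a placeholder, and it assumes the executor is exposed as an IChatClient through an extension such as AsChatClient (the public entry point over the implementation in this file; the exact name and signature may differ).

    using LLama;
    using LLama.Common;
    using LLama.Sampling;
    using Microsoft.Extensions.AI;

    // Placeholder model path; AsChatClient is assumed wiring, see note above.
    var parameters = new ModelParams("path/to/model.gguf");
    using var weights = await LLamaWeights.LoadFromFileAsync(parameters);
    using var context = weights.CreateContext(parameters);
    var executor = new InteractiveExecutor(context);
    IChatClient client = executor.AsChatClient();

    // After this patch, LLama-specific knobs travel as an InferenceParams
    // instance: the factory's result seeds CreateInferenceParams, while
    // standard ChatOptions values (MaxOutputTokens, Temperature, TopP, ...)
    // still take precedence wherever they are set.
    var options = new ChatOptions
    {
        MaxOutputTokens = 512,
        Temperature = 0.6f,
        RawRepresentationFactory = _ => new InferenceParams
        {
            SamplingPipeline = new DefaultSamplingPipeline
            {
                RepeatPenalty = 1.15f,
                PreventEOS = false,
            },
        },
    };

    var response = await client.GetResponseAsync("Tell me a joke.", options);
    Console.WriteLine(response.Text);

Note that pipeline settings with no ChatOptions counterpart (RepeatPenalty, PreventEOS, and the rest) are carried over from the factory-supplied pipeline, falling back to the static defaults only when the factory provides nothing.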
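The other addition is that both response paths now stamp a MessageId: one fresh GUID per response, shared by every update of a streamed reply. A minimal sketch of what that enables for consumers, reusing the hypothetical client and options from above:

    using System.Text;

    // Every ChatResponseUpdate from one streamed reply carries the same
    // MessageId, so partial text can be accumulated per message.
    var buffers = new Dictionary<string, StringBuilder>();
    await foreach (var update in client.GetStreamingResponseAsync("Summarize this patch.", options))
    {
        if (update.MessageId is { } id)
        {
            if (!buffers.TryGetValue(id, out var sb))
                buffers[id] = sb = new StringBuilder();
            sb.Append(update.Text);
        }
    }

The "N" GUID format merely drops the dashes; any opaque unique string would work, which is why the non-streaming path sets the same property on its single final message.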