
Commit 937ed23

Merge pull request #1102 from stephentoub/updatemeai
Update Microsoft.Extensions.AI to 9.3.0-preview.1.25114.11
2 parents 6f1862b + f4e2a7d commit 937ed23

File tree: 4 files changed, +39 -27 lines


LLama.Unittest/LLamaEmbedderTests.cs (4 additions, 4 deletions)

@@ -43,10 +43,10 @@ private async Task CompareEmbeddings(string modelPath)
         Assert.DoesNotContain(float.NaN, spoon);
 
         var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
-        Assert.NotNull(generator.Metadata);
-        Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName);
-        Assert.NotNull(generator.Metadata.ModelId);
-        Assert.NotEmpty(generator.Metadata.ModelId);
+        Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
+        Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
+        Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.ModelId);
+        Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.ModelId!);
         Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
         Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
         Assert.Null(generator.GetService<string>());
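The same pattern applies outside the tests: metadata that used to hang off the Metadata property is now resolved through GetService. A minimal consumer-side sketch, assuming a local GGUF file at "model.gguf" (the path, parameter values, and setup calls are illustrative and not part of this diff):

using System;
using LLama;
using LLama.Common;
using Microsoft.Extensions.AI;

// Illustrative setup; any LLamaEmbedder instance works here.
var parameters = new ModelParams("model.gguf") { Embeddings = true };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

// Post-update: metadata comes from GetService rather than a Metadata property.
var metadata = generator.GetService<EmbeddingGeneratorMetadata>();
Console.WriteLine($"{metadata?.ProviderName} / {metadata?.ModelId} ({metadata?.Dimensions} dimensions)");

// Embedding generation itself is unchanged by this PR.
var embeddings = await generator.GenerateAsync(["cat", "dog"]);
Console.WriteLine(embeddings[0].Vector.Length);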

LLama/Extensions/LLamaExecutorExtensions.cs (10 additions, 11 deletions)

@@ -36,6 +36,7 @@ private sealed class LLamaExecutorChatClient(
         IHistoryTransform? historyTransform = null,
         ITextStreamTransform? outputTransform = null) : IChatClient
     {
+        private static readonly ChatClientMetadata s_metadata = new(nameof(LLamaExecutorChatClient));
         private static readonly InferenceParams s_defaultParams = new();
         private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
         private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
@@ -47,21 +48,19 @@ private sealed class LLamaExecutorChatClient(
         private readonly ITextStreamTransform _outputTransform = outputTransform ??
             new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);
 
-        /// <inheritdoc/>
-        public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));
-
         /// <inheritdoc/>
         public void Dispose() { }
 
         /// <inheritdoc/>
-        public object? GetService(Type serviceType, object? key = null) =>
-            key is not null ? null :
+        public object? GetService(Type serviceType, object? serviceKey = null) =>
+            serviceKey is not null ? null :
+            serviceType == typeof(ChatClientMetadata) ? s_metadata :
             serviceType?.IsInstanceOfType(_executor) is true ? _executor :
             serviceType?.IsInstanceOfType(this) is true ? this :
             null;
 
         /// <inheritdoc/>
-        public async Task<ChatCompletion> CompleteAsync(
+        public async Task<ChatResponse> GetResponseAsync(
             IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
         {
             var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
@@ -79,7 +78,7 @@ public async Task<ChatCompletion> CompleteAsync(
         }
 
         /// <inheritdoc/>
-        public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+        public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
             IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
         {
             var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
@@ -142,8 +141,8 @@ private string CreatePrompt(IList<ChatMessage> messages)
                 MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
                 SamplingPipeline = new DefaultSamplingPipeline()
                 {
-                    FrequencyPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.FrequencyPenalty), out float af) is true ? af : s_defaultPipeline.FrequencyPenalty,
-                    PresencePenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PresencePenalty), out float ap) is true ? ap : s_defaultPipeline.PresencePenalty,
+                    FrequencyPenalty = options?.FrequencyPenalty ?? s_defaultPipeline.FrequencyPenalty,
+                    PresencePenalty = options?.PresencePenalty ?? s_defaultPipeline.PresencePenalty,
                     PreventEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PreventEOS), out bool eos) is true ? eos : s_defaultPipeline.PreventEOS,
                     PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
                     RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
@@ -152,8 +151,8 @@ private string CreatePrompt(IList<ChatMessage> messages)
                     MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
                     MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
                     Seed = options?.Seed is long seed ? (uint)seed : (uint)(t_random ??= new()).Next(),
-                    Temperature = options?.Temperature ?? 0,
-                    TopP = options?.TopP ?? 0,
+                    Temperature = options?.Temperature ?? s_defaultPipeline.Temperature,
+                    TopP = options?.TopP ?? s_defaultPipeline.TopP,
                     TopK = options?.TopK ?? s_defaultPipeline.TopK,
                     TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
                 },
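For callers, the net effect of this file is twofold: the Microsoft.Extensions.AI 9.3 renames (CompleteAsync becomes GetResponseAsync, CompleteStreamingAsync becomes GetStreamingResponseAsync, ChatCompletion becomes ChatResponse) and the switch to reading FrequencyPenalty/PresencePenalty from ChatOptions directly instead of AdditionalProperties. A hedged sketch of an updated call site; the model path, option values, and the AsChatClient wrapping call are assumptions for illustration, not part of this diff:

using System;
using System.Collections.Generic;
using LLama;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

// Illustrative setup; AsChatClient is assumed to be the public entry point
// into the LLamaExecutorChatClient adapter shown above.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
var executor = new StatelessExecutor(weights, parameters);
IChatClient client = executor.AsChatClient();

var options = new ChatOptions
{
    // Now first-class ChatOptions properties (previously read from AdditionalProperties):
    FrequencyPenalty = 0.2f,
    PresencePenalty = 0.1f,
    Temperature = 0.7f,
    // Pipeline-specific knobs still travel through AdditionalProperties:
    AdditionalProperties = new() { [nameof(DefaultSamplingPipeline.RepeatPenalty)] = 1.1f },
};

IList<ChatMessage> messages = [new(ChatRole.User, "Briefly explain embeddings.")];

// Renamed from CompleteAsync; returns ChatResponse instead of ChatCompletion.
ChatResponse response = await client.GetResponseAsync(messages, options);
Console.WriteLine(response);

// Renamed from CompleteStreamingAsync; yields ChatResponseUpdate items.
await foreach (ChatResponseUpdate update in client.GetStreamingResponseAsync(messages, options))
    Console.Write(update);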

LLama/LLamaEmbedder.EmbeddingGenerator.cs (24 additions, 11 deletions)

@@ -14,18 +14,31 @@ public partial class LLamaEmbedder
     private EmbeddingGeneratorMetadata? _metadata;
 
     /// <inheritdoc />
-    EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata =>
-        _metadata ??= new(
-            nameof(LLamaEmbedder),
-            modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
-            dimensions: EmbeddingSize);
+    object? IEmbeddingGenerator<string, Embedding<float>>.GetService(Type serviceType, object? serviceKey)
+    {
+        if (serviceKey is null)
+        {
+            if (serviceType == typeof(EmbeddingGeneratorMetadata))
+            {
+                return _metadata ??= new(
+                    nameof(LLamaEmbedder),
+                    modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
+                    dimensions: EmbeddingSize);
+            }
 
-    /// <inheritdoc />
-    object? IEmbeddingGenerator<string, Embedding<float>>.GetService(Type serviceType, object? key) =>
-        key is not null ? null :
-        serviceType?.IsInstanceOfType(Context) is true ? Context :
-        serviceType?.IsInstanceOfType(this) is true ? this :
-        null;
+            if (serviceType?.IsInstanceOfType(Context) is true)
+            {
+                return Context;
+            }
+
+            if (serviceType?.IsInstanceOfType(this) is true)
+            {
+                return this;
+            }
+        }
+
+        return null;
+    }
 
     /// <inheritdoc />
     async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
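The rewritten GetService consolidates what used to be two members (the explicit Metadata property and the old GetService) into a single keyed-service-aware lookup. A small sketch of what each branch resolves for a caller; the helper method and its name are hypothetical, and the embedder is assumed to be an existing LLamaEmbedder as in the earlier sketch:

using System;
using LLama;
using Microsoft.Extensions.AI;

static class EmbedderServiceLookupSample
{
    // Hypothetical helper: walks the branches of the new GetService implementation.
    public static void Describe(LLamaEmbedder embedder)
    {
        IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

        // serviceType == typeof(EmbeddingGeneratorMetadata): lazily built and cached in _metadata.
        var metadata = generator.GetService<EmbeddingGeneratorMetadata>();
        Console.WriteLine($"{metadata?.ProviderName}: {metadata?.ModelId}, {metadata?.Dimensions} dims");

        // serviceType.IsInstanceOfType(Context): returns the underlying LLamaContext.
        var context = generator.GetService<LLamaContext>();

        // serviceType.IsInstanceOfType(this): returns the embedder itself.
        var self = generator.GetService<LLamaEmbedder>();

        // Any non-null serviceKey short-circuits to null; keyed services are not supported.
        var keyed = generator.GetService(typeof(LLamaContext), "some-key");
        Console.WriteLine(keyed is null); // True
    }
}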

LLama/LLamaSharp.csproj (1 addition, 1 deletion)

@@ -50,7 +50,7 @@
 
   <ItemGroup>
     <PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="9.0.0" />
-    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.1.0-preview.1.25064.3" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25114.11" />
     <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="9.0.0" />
     <PackageReference Include="System.Numerics.Tensors" Version="9.0.0" />
   </ItemGroup>
