Skip to content

Commit 6025979

Browse files
authored
Merge pull request #902 from martindevans/llama_embedder_2
LLamaEmbedder 2.0
2 parents df8cc71 + a3028de commit 6025979

File tree

15 files changed

+427
-166
lines changed

15 files changed

+427
-166
lines changed

LLama.Examples/ExampleRunner.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ public class ExampleRunner
1919
{ "Executor: Stateless mode chat", StatelessModeExecute.Run },
2020
{ "Save and Load: chat session", SaveAndLoadSession.Run },
2121
{ "Save and Load: state of model and executor", LoadAndSaveState.Run },
22-
{ "LLama Model: Get embeddings", () => Task.Run(GetEmbeddings.Run) },
23-
{ "LLama Model: Quantize", () => Task.Run(QuantizeModel.Run) },
22+
{ "LLama Model: Get embeddings", GetEmbeddings.Run },
23+
{ "LLama Model: Quantize", QuantizeModel.Run },
2424
{ "Grammar: Constrain response to json format", GrammarJsonResponse.Run },
2525
{ "Kernel Memory: Document Q&A", KernelMemory.Run },
2626
{ "Kernel Memory: Save and Load", KernelMemorySaveAndLoad.Run },

LLama.Examples/Examples/GetEmbeddings.cs

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
using LLama.Common;
2+
using LLama.Native;
23

34
namespace LLama.Examples.Examples
45
{
56
public class GetEmbeddings
67
{
7-
public static void Run()
8+
public static async Task Run()
89
{
910
string modelPath = UserSettings.GetModelPath();
1011

1112
Console.ForegroundColor = ConsoleColor.DarkGray;
12-
var @params = new ModelParams(modelPath) { Embeddings = true };
13+
var @params = new ModelParams(modelPath)
14+
{
15+
// Embedding models can return one embedding per token, or all of them can be combined ("pooled") into
16+
// one single embedding. Setting PoolingType to "Mean" will combine all of the embeddings using mean average.
17+
PoolingType = LLamaPoolingType.Mean,
18+
};
1319
using var weights = LLamaWeights.LoadFromFile(@params);
1420
var embedder = new LLamaEmbedder(weights, @params);
1521

1622
Console.ForegroundColor = ConsoleColor.Yellow;
1723
Console.WriteLine(
1824
"""
1925
This example displays embeddings from a text prompt.
20-
Embeddings are numerical codes that represent information like words, images, or concepts.
21-
These codes capture important relationships between those objects,
26+
Embeddings are vectors that represent information like words, images, or concepts.
27+
These vectors capture important relationships between those objects,
2228
like how similar words are in meaning or how close images are visually.
2329
This allows machine learning models to efficiently understand and process complex data.
2430
Embeddings of a text in LLM is sometimes useful, for example, to train other MLP models.
25-
"""); // NOTE: this description was AI generated
31+
""");
2632

2733
while (true)
2834
{
@@ -32,8 +38,13 @@ This allows machine learning models to efficiently understand and process comple
3238
var text = Console.ReadLine();
3339
Console.ForegroundColor = ConsoleColor.White;
3440

35-
float[] embeddings = embedder.GetEmbeddings(text).Result;
36-
Console.WriteLine($"Embeddings contain {embeddings.Length:N0} floating point values:");
41+
// Get embeddings for the text
42+
var embeddings = await embedder.GetEmbeddings(text);
43+
44+
// This should have returned one single embedding vector, because PoolingType was set to Mean above.
45+
var embedding = embeddings.Single();
46+
47+
Console.WriteLine($"Embeddings contain {embedding.Length:N0} floating point values:");
3748
Console.ForegroundColor = ConsoleColor.DarkGray;
3849
Console.WriteLine(string.Join(", ", embeddings.Take(20)) + ", ...");
3950
Console.WriteLine();

LLama.Examples/Examples/QuantizeModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
namespace LLama.Examples.Examples
1+
namespace LLama.Examples.Examples
22
{
33
public class QuantizeModel
44
{
5-
public static void Run()
5+
public static async Task Run()
66
{
77
string inputPath = UserSettings.GetModelPath();
88

LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using LLama;
22
using LLama.Common;
3+
using LLama.Native;
34
using Microsoft.KernelMemory;
45
using Microsoft.KernelMemory.AI;
56

@@ -35,7 +36,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
3536
GpuLayerCount = config.GpuLayerCount ?? 20,
3637
Embeddings = true,
3738
MainGpu = config.MainGpu,
38-
SplitMode = config.SplitMode
39+
SplitMode = config.SplitMode,
40+
PoolingType = LLamaPoolingType.Mean,
3941
};
4042
_weights = LLamaWeights.LoadFromFile(@params);
4143
_embedder = new LLamaEmbedder(_weights, @params);
@@ -59,7 +61,8 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
5961
GpuLayerCount = config.GpuLayerCount ?? 20,
6062
Embeddings = true,
6163
MainGpu = config.MainGpu,
62-
SplitMode = config.SplitMode
64+
SplitMode = config.SplitMode,
65+
PoolingType = LLamaPoolingType.Mean,
6366
};
6467
_weights = weights;
6568
_embedder = new LLamaEmbedder(_weights, @params);
@@ -92,7 +95,7 @@ public void Dispose()
9295
/// <summary>
/// Generate a single embedding vector for the given text.
/// </summary>
/// <param name="text">The text to embed.</param>
/// <param name="cancellationToken">Token used to cancel the embedding operation.</param>
/// <returns>A single pooled <see cref="Embedding"/> for the whole input text.</returns>
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
    var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);

    // This class configures the embedder with PoolingType.Mean, so exactly one
    // pooled embedding is expected. Single() (rather than First()) asserts that
    // invariant instead of silently discarding extra vectors, and matches the
    // usage in the unit tests.
    return new Embedding(embeddings.Single());
}
97100

98101
/// <inheritdoc/>

LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
namespace LLamaSharp.SemanticKernel.TextEmbedding;
66

7-
public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
7+
public sealed class LLamaSharpEmbeddingGeneration
8+
: ITextEmbeddingGenerationService
89
{
910
private readonly LLamaEmbedder _embedder;
1011

@@ -23,7 +24,7 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<st
2324
var result = new List<ReadOnlyMemory<float>>();
2425

2526
foreach (var item in data)
26-
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
27+
result.Add((await _embedder.GetEmbeddings(item, cancellationToken)).First());
2728

2829
return result;
2930
}

LLama.Unittest/LLamaEmbedderTests.cs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using LLama.Common;
2+
using LLama.Extensions;
3+
using LLama.Native;
24
using Xunit.Abstractions;
35

46
namespace LLama.Unittest;
@@ -24,19 +26,19 @@ private async Task CompareEmbeddings(string modelPath)
2426
{
2527
ContextSize = 8,
2628
Threads = 4,
27-
Embeddings = true,
2829
GpuLayerCount = Constants.CIGpuLayerCount,
30+
PoolingType = LLamaPoolingType.Mean,
2931
};
3032
using var weights = LLamaWeights.LoadFromFile(@params);
3133
using var embedder = new LLamaEmbedder(weights, @params);
3234

33-
var cat = await embedder.GetEmbeddings("The cat is cute");
35+
var cat = (await embedder.GetEmbeddings("The cat is cute")).Single().EuclideanNormalization();
3436
Assert.DoesNotContain(float.NaN, cat);
3537

36-
var kitten = await embedder.GetEmbeddings("The kitten is kawaii");
38+
var kitten = (await embedder.GetEmbeddings("The kitten is cute")).Single().EuclideanNormalization();
3739
Assert.DoesNotContain(float.NaN, kitten);
3840

39-
var spoon = await embedder.GetEmbeddings("The spoon is not real");
41+
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
4042
Assert.DoesNotContain(float.NaN, spoon);
4143

4244
_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
@@ -64,4 +66,33 @@ public async Task EmbedCompareGenerateModel()
6466
{
6567
await CompareEmbeddings(Constants.GenerativeModelPath);
6668
}
69+
70+
// Generates non-pooled (one-embedding-per-token) embeddings from the given
// model and checks that none of the returned values are NaN.
private async Task NonPooledEmbeddings(string modelPath)
{
    var @params = new ModelParams(modelPath)
    {
        // Small context and thread counts keep the test cheap on CI hardware.
        ContextSize = 8,
        Threads = 4,
        GpuLayerCount = Constants.CIGpuLayerCount,
        // PoolingType.None => the embedder returns one embedding per input
        // token instead of a single combined ("pooled") vector.
        PoolingType = LLamaPoolingType.None,
    };
    using var weights = LLamaWeights.LoadFromFile(@params);
    using var embedder = new LLamaEmbedder(weights, @params);

    var kitten = await embedder.GetEmbeddings("the kitten is kawaii");
    // Sanity check: every per-token embedding must contain only real numbers.
    foreach (var embd in kitten)
        Assert.DoesNotContain(float.NaN, embd);
}
86+
87+
// Non-pooled embedding checks against the dedicated embedding model.
[Fact]
public async Task EmbeddingModelNonPooledEmbeddings()
{
    await NonPooledEmbeddings(Constants.EmbeddingModelPath);
}
92+
93+
// Non-pooled embedding checks against a generative (non-embedding) model.
[Fact]
public async Task GenerativeModelNonPooledEmbeddings()
{
    await NonPooledEmbeddings(Constants.GenerativeModelPath);
}
6798
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
using System;
using System.Numerics.Tensors;

namespace LLama.Extensions;

/// <summary>
/// Extensions to span which apply <b>in-place</b> normalization
/// </summary>
public static class SpanNormalizationExtensions
{
    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the array by the max absolute value in the array
    /// </summary>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] MaxAbsoluteNormalization(this float[] vector)
    {
        vector.AsSpan().MaxAbsoluteNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> multiply every element by 32760 and divide every element in the span by the max absolute value in the span
    /// </summary>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> MaxAbsoluteNormalization(this Span<float> vector)
    {
        // NOTE(review): if the vector is all zeros MaxMagnitude is 0 and this
        // produces infinities/NaN — confirm callers never pass a zero vector.
        var factor = 32760 / TensorPrimitives.MaxMagnitude(vector);
        TensorPrimitives.Multiply(vector, factor, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the array by the sum of absolute values in the array
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] TaxicabNormalization(this float[] vector)
    {
        vector.AsSpan().TaxicabNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element in the span by the sum of absolute values in the span
    /// </summary>
    /// <remarks>Also known as "Manhattan normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> TaxicabNormalization(this Span<float> vector)
    {
        var sumAbs = TensorPrimitives.SumOfMagnitudes(vector);
        TensorPrimitives.Divide(vector, sumAbs, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same array</returns>
    public static float[] EuclideanNormalization(this float[] vector)
    {
        vector.AsSpan().EuclideanNormalization();
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> divide every element by the euclidean length of the vector
    /// </summary>
    /// <remarks>Also known as "L2 normalization".</remarks>
    /// <param name="vector"></param>
    /// <returns>The same span</returns>
    public static Span<float> EuclideanNormalization(this Span<float> vector)
    {
        // NOTE(review): a zero vector has norm 0 and normalizes to NaN.
        var norm = TensorPrimitives.Norm(vector);
        TensorPrimitives.Divide(vector, norm, vector);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector"></param>
    /// <param name="p"></param>
    /// <returns>The same array</returns>
    public static float[] PNormalization(this float[] vector, int p)
    {
        vector.AsSpan().PNormalization(p);
        return vector;
    }

    /// <summary>
    /// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
    /// <list type="bullet">
    /// <item>For p = 1, this is taxicab normalization</item>
    /// <item>For p = 2, this is euclidean normalization</item>
    /// <item>As p => infinity, this approaches infinity norm or maximum norm</item>
    /// </list>
    /// </summary>
    /// <param name="vector"></param>
    /// <param name="p"></param>
    /// <returns>The same span</returns>
    public static Span<float> PNormalization(this Span<float> vector, int p)
    {
        // Fast path: p = 2 is the euclidean norm, computed by TensorPrimitives.
        if (p == 2)
            return vector.EuclideanNormalization();

        // The p-norm is (sum(|x_i|^p))^(1/p). The absolute value is essential:
        // without it, odd p with negative elements yields a wrong (possibly zero
        // or negative) sum — e.g. [-1, 1] with p=3 would divide by zero.
        var sum = 0.0;
        for (var i = 0; i < vector.Length; i++)
            sum += Math.Pow(Math.Abs(vector[i]), p);
        var divisor = (float)Math.Pow(sum, 1.0 / p);

        TensorPrimitives.Divide(vector, divisor, vector);

        return vector;
    }
}

LLama/LLamaContext.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,28 @@ public bool ShouldAddBosToken()
379379
}
380380

381381
#region eval overloads
382+
/// <summary>
/// Validate the batch size, then pass the batch to <see cref="NativeHandle"/>'s
/// Encode and return the result.
/// </summary>
/// <param name="batch">Batch of tokens to encode; must not exceed <see cref="BatchSize"/> tokens.</param>
/// <returns>0 (success) for an empty batch, otherwise the native encode result cast to <see cref="EncodeResult"/>.</returns>
/// <exception cref="ArgumentException">Thrown when the batch contains more tokens than the configured batch size.</exception>
public EncodeResult Encode(LLamaBatch batch)
{
    // An empty batch is a no-op: report success without calling into native code.
    if (batch.TokenCount == 0)
        return 0;
    if (batch.TokenCount > BatchSize)
        throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

    return (EncodeResult)NativeHandle.Encode(batch);
}
394+
395+
/// <summary>
/// Run <see cref="Encode"/> on a thread pool thread.
/// </summary>
/// <param name="batch">Batch of tokens to encode; must not exceed <see cref="BatchSize"/> tokens.</param>
/// <param name="cancellationToken">
/// Cancels the task if it has not started yet; once the encode call is running
/// it cannot be interrupted (standard <see cref="Task.Run(Action, CancellationToken)"/> semantics).
/// </param>
public Task<EncodeResult> EncodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
    return Task.Run(() => Encode(batch), cancellationToken);
}
403+
382404
/// <summary>
383405
/// </summary>
384406
/// <param name="batch"></param>

0 commit comments

Comments
 (0)