
Commit 3523c51

Merge pull request #474 from martindevans/embeddings_generator_decode
Swapped `GetEmbeddings` to `llama_decode`
2 parents 3b08874 + c9c8cd0 commit 3523c51
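
In short, `LLamaEmbedder.GetEmbeddings` is now asynchronous and evaluates the prompt with `llama_decode` instead of the old `Context.Eval` path. A rough before/after call-site sketch (the `embedder` and `cancellationToken` variables are assumed, not part of this page; the full diffs follow below):

// Before this commit: synchronous, evaluated via Context.Eval
//     float[] vec = embedder.GetEmbeddings("some text");

// After this commit: asynchronous, evaluated via llama_decode, cancellable
float[] vec = await embedder.GetEmbeddings("some text", cancellationToken);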

4 files changed: +68 additions, -53 deletions

LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs

Lines changed: 6 additions & 13 deletions

@@ -1,14 +1,7 @@
 using LLama;
-using LLama.Abstractions;
 using LLama.Common;
 using Microsoft.KernelMemory;
 using Microsoft.KernelMemory.AI;
-using Microsoft.SemanticKernel.AI.Embeddings;
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;

 namespace LLamaSharp.KernelMemory
 {
@@ -80,24 +73,24 @@ public void Dispose()
         }

         /// <inheritdoc/>
-        public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
+        public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
         {
             IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();

             foreach (var d in data)
             {
-                var embeddings = _embedder.GetEmbeddings(d);
+                var embeddings = await _embedder.GetEmbeddings(d, cancellationToken);
                 results.Add(new ReadOnlyMemory<float>(embeddings));
             }

-            return Task.FromResult(results);
+            return results;
         }

         /// <inheritdoc/>
-        public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
+        public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
         {
-            var embeddings = _embedder.GetEmbeddings(text);
-            return Task.FromResult(new Embedding(embeddings));
+            var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
+            return new Embedding(embeddings);
         }

         /// <inheritdoc/>
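
A rough usage sketch of the updated KernelMemory generator; `generator` is assumed to be an already-constructed LLamaSharpTextEmbeddingGenerator (its construction is outside this diff), and the point is simply that both methods now await the embedder and honour the cancellation token:

// `generator` is an assumed, already-constructed LLamaSharpTextEmbeddingGenerator.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

Embedding single = await generator.GenerateEmbeddingAsync("hello world", cts.Token);
IList<ReadOnlyMemory<float>> many = await generator.GenerateEmbeddingsAsync(
    new[] { "cat", "kitten", "spoon" }, cts.Token);
Console.WriteLine($"Generated 1 + {many.Count} embeddings");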

LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs

Lines changed: 7 additions & 3 deletions

@@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding;

 public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
 {
-    private LLamaEmbedder _embedder;
+    private readonly LLamaEmbedder _embedder;

     private readonly Dictionary<string, object?> _attributes = new();

@@ -20,7 +20,11 @@ public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder)
     /// <inheritdoc/>
     public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
     {
-        var embeddings = data.Select(text => new ReadOnlyMemory<float>(_embedder.GetEmbeddings(text))).ToList();
-        return await Task.FromResult(embeddings);
+        var result = new List<ReadOnlyMemory<float>>();
+
+        foreach (var item in data)
+            result.Add(await _embedder.GetEmbeddings(item, cancellationToken));
+
+        return result;
     }
 }
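
Similarly, a minimal sketch of calling the Semantic Kernel service after this change; only the GenerateEmbeddingsAsync signature comes from the diff above, while the `service` instance is assumed:

// `service` is an assumed, already-constructed LLamaSharpEmbeddingGeneration
// (exposed through ITextEmbeddingGenerationService).
using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(1));

IList<ReadOnlyMemory<float>> vectors =
    await service.GenerateEmbeddingsAsync(new[] { "cat", "kitten" }, cancellationToken: cts.Token);
Console.WriteLine($"Got {vectors.Count} vectors of dimension {vectors[0].Length}");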

LLama.Unittest/LLamaEmbedderTests.cs

Lines changed: 15 additions & 10 deletions

@@ -1,14 +1,17 @@
 using LLama.Common;
+using Xunit.Abstractions;

 namespace LLama.Unittest;

 public sealed class LLamaEmbedderTests
     : IDisposable
 {
+    private readonly ITestOutputHelper _testOutputHelper;
     private readonly LLamaEmbedder _embedder;

-    public LLamaEmbedderTests()
+    public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
     {
+        _testOutputHelper = testOutputHelper;
         var @params = new ModelParams(Constants.ModelPath)
         {
             EmbeddingMode = true,
@@ -41,21 +44,23 @@ private static float Dot(float[] a, float[] b)
     }

     [Fact]
-    public void EmbedCompare()
+    public async Task EmbedCompare()
     {
-        var cat = _embedder.GetEmbeddings("cat");
-        var kitten = _embedder.GetEmbeddings("kitten");
-        var spoon = _embedder.GetEmbeddings("spoon");
+        var cat = await _embedder.GetEmbeddings("cat");
+        var kitten = await _embedder.GetEmbeddings("kitten");
+        var spoon = await _embedder.GetEmbeddings("spoon");

         Normalize(cat);
         Normalize(kitten);
         Normalize(spoon);

-        var close = Dot(cat, kitten);
-        var far = Dot(cat, spoon);
+        var close = 1 - Dot(cat, kitten);
+        var far = 1 - Dot(cat, spoon);

-        // This comparison seems backwards, but remember that with a
-        // dot product 1.0 means **identical** and 0.0 means **completely opposite**!
-        Assert.True(close > far);
+        Assert.True(close < far);
+
+        _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
+        _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
+        _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
     }
 }
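
A side note on the flipped assertion above: once the vectors are normalized to unit length, their dot product equals the cosine similarity, so `1 - Dot(a, b)` is a cosine distance where smaller means more similar, hence `close < far`. A standalone sketch of the same arithmetic with made-up vectors (not real embeddings), using helpers equivalent in spirit to the test's Normalize and Dot:

using System;
using System.Linq;

// Made-up 3-dimensional vectors, purely to illustrate the arithmetic.
var cat    = new[] { 1.0f, 2.0f, 0.5f };
var kitten = new[] { 0.9f, 2.1f, 0.4f };
var spoon  = new[] { -1.0f, 0.1f, 3.0f };

Normalize(cat);
Normalize(kitten);
Normalize(spoon);

// After normalization, Dot is cosine similarity in [-1, 1],
// so 1 - Dot is a distance: smaller means more similar.
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);
Console.WriteLine($"close = {close:F3}, far = {far:F3}, close < far: {close < far}");

static void Normalize(float[] v)
{
    var norm = MathF.Sqrt(v.Sum(x => x * x));
    for (var i = 0; i < v.Length; i++)
        v[i] /= norm;
}

static float Dot(float[] a, float[] b)
{
    return a.Zip(b, (x, y) => x * y).Sum();
}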

LLama/LLamaEmbedder.cs

Lines changed: 40 additions & 27 deletions

@@ -3,6 +3,8 @@
 using LLama.Exceptions;
 using LLama.Abstractions;
 using Microsoft.Extensions.Logging;
+using System.Threading;
+using System.Threading.Tasks;

 namespace LLama
 {
@@ -40,50 +42,61 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
         /// Get the embeddings of the text.
         /// </summary>
         /// <param name="text"></param>
-        /// <param name="threads">unused</param>
-        /// <param name="addBos">Add bos to the text.</param>
-        /// <param name="encoding">unused</param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <exception cref="RuntimeError"></exception>
-        [Obsolete("'threads' and 'encoding' parameters are no longer used")]
-        // ReSharper disable once MethodOverloadWithOptionalParameter
-        public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
+        public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
         {
-            return GetEmbeddings(text, addBos);
-        }
-
-        /// <summary>
-        /// Get the embeddings of the text.
-        /// </summary>
-        /// <param name="text"></param>
-        /// <returns></returns>
-        /// <exception cref="RuntimeError"></exception>
-        public float[] GetEmbeddings(string text)
-        {
-            return GetEmbeddings(text, true);
+            return GetEmbeddings(text, true, cancellationToken);
        }

         /// <summary>
         /// Get the embeddings of the text.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="addBos">Add bos to the text.</param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <exception cref="RuntimeError"></exception>
-        public float[] GetEmbeddings(string text, bool addBos)
+        public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
         {
-            var embed_inp_array = Context.Tokenize(text, addBos);
+            var tokens = Context.Tokenize(text, addBos);
+            if (tokens.Length > Context.ContextSize)
+                throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));
+
+            // Evaluate prompt in batch-size chunks
+            var n_past = 0;
+            var batch = new LLamaBatch();
+            var batchSize = (int)Context.Params.BatchSize;
+            for (var i = 0; i < tokens.Length; i += batchSize)
+            {
+                var n_eval = tokens.Length - i;
+                if (n_eval > batchSize)
+                    n_eval = batchSize;
+
+                batch.Clear();
+                for (var j = 0; j < n_eval; j++)
+                    batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false);
+
+                var returnCode = await Context.DecodeAsync(batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
+            }

-            // TODO(Rinne): deal with log of prompt
+            var embeddings = GetEmbeddingsArray();

-            if (embed_inp_array.Length > 0)
-                Context.Eval(embed_inp_array.AsSpan(), 0);
+            // Remove everything we just evaluated from the context cache
+            Context.NativeHandle.KvCacheClear();

-            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-            if (embeddings == null)
-                return Array.Empty<float>();
+            return embeddings;

-            return embeddings.ToArray();
+            float[] GetEmbeddingsArray()
+            {
+                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
+                if (embeddings == null)
+                    return Array.Empty<float>();
+                return embeddings.ToArray();
+            }
         }

         /// <summary>
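
Finally, a minimal end-to-end sketch of the new embedder API. The model path is a placeholder, `EmbeddingMode = true` mirrors the unit-test setup above, and `LLamaWeights.LoadFromFile` is the usual LLamaSharp loading pattern rather than something introduced by this diff:

using System;
using System.Threading;
using LLama;
using LLama.Common;

// Placeholder model path; EmbeddingMode = true mirrors the unit-test setup.
var @params = new ModelParams("path/to/model.gguf")
{
    EmbeddingMode = true,
};

using var weights = LLamaWeights.LoadFromFile(@params);
using var embedder = new LLamaEmbedder(weights, @params);

// GetEmbeddings is now async: the prompt is tokenized, decoded in batch-sized
// chunks via llama_decode, and the KV cache is cleared afterwards.
float[] embedding = await embedder.GetEmbeddings("The quick brown fox", CancellationToken.None);
Console.WriteLine($"Embedding dimension: {embedding.Length}");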
