|
3 | 3 | using LLama.Exceptions; |
4 | 4 | using LLama.Abstractions; |
5 | 5 | using Microsoft.Extensions.Logging; |
| 6 | +using System.Threading; |
| 7 | +using System.Threading.Tasks; |
6 | 8 |
|
7 | 9 | namespace LLama |
8 | 10 | { |
@@ -40,50 +42,61 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg |
40 | 42 | /// Get the embeddings of the text. |
41 | 43 | /// </summary> |
42 | 44 | /// <param name="text"></param> |
43 | | - /// <param name="threads">unused</param> |
44 | | - /// <param name="addBos">Add bos to the text.</param> |
45 | | - /// <param name="encoding">unused</param> |
| 45 | + /// <param name="cancellationToken"></param> |
46 | 46 | /// <returns></returns> |
47 | 47 | /// <exception cref="RuntimeError"></exception> |
48 | | - [Obsolete("'threads' and 'encoding' parameters are no longer used")] |
49 | | - // ReSharper disable once MethodOverloadWithOptionalParameter |
50 | | - public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") |
| 48 | + public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default) |
51 | 49 | { |
52 | | - return GetEmbeddings(text, addBos); |
53 | | - } |
54 | | - |
55 | | - /// <summary> |
56 | | - /// Get the embeddings of the text. |
57 | | - /// </summary> |
58 | | - /// <param name="text"></param> |
59 | | - /// <returns></returns> |
60 | | - /// <exception cref="RuntimeError"></exception> |
61 | | - public float[] GetEmbeddings(string text) |
62 | | - { |
63 | | - return GetEmbeddings(text, true); |
| 50 | + return GetEmbeddings(text, true, cancellationToken); |
64 | 51 | } |
65 | 52 |
|
66 | 53 | /// <summary> |
67 | 54 | /// Get the embeddings of the text. |
68 | 55 | /// </summary> |
69 | 56 | /// <param name="text"></param> |
70 | 57 | /// <param name="addBos">Add bos to the text.</param> |
| 58 | + /// <param name="cancellationToken"></param> |
71 | 59 | /// <returns></returns> |
72 | 60 | /// <exception cref="RuntimeError"></exception> |
73 | | - public float[] GetEmbeddings(string text, bool addBos) |
| 61 | + public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default) |
74 | 62 | { |
75 | | - var embed_inp_array = Context.Tokenize(text, addBos); |
| 63 | + var tokens = Context.Tokenize(text, addBos); |
| 64 | + if (tokens.Length > Context.ContextSize) |
| 65 | + throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text)); |
| 66 | + |
| 67 | + // Evaluate prompt in batch-size chunks |
| 68 | + var n_past = 0; |
| 69 | + var batch = new LLamaBatch(); |
| 70 | + var batchSize = (int)Context.Params.BatchSize; |
| 71 | + for (var i = 0; i < tokens.Length; i += batchSize) |
| 72 | + { |
| 73 | + var n_eval = tokens.Length - i; |
| 74 | + if (n_eval > batchSize) |
| 75 | + n_eval = batchSize; |
| 76 | + |
| 77 | + batch.Clear(); |
| 78 | + for (var j = 0; j < n_eval; j++) |
| 79 | + batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false); |
| 80 | + |
| 81 | + var returnCode = await Context.DecodeAsync(batch, cancellationToken); |
| 82 | + if (returnCode != 0) |
| 83 | + throw new LLamaDecodeError(returnCode); |
| 84 | + } |
76 | 85 |
|
77 | | - // TODO(Rinne): deal with log of prompt |
| 86 | + var embeddings = GetEmbeddingsArray(); |
78 | 87 |
|
79 | | - if (embed_inp_array.Length > 0) |
80 | | - Context.Eval(embed_inp_array.AsSpan(), 0); |
| 88 | + // Remove everything we just evaluated from the context cache |
| 89 | + Context.NativeHandle.KvCacheClear(); |
81 | 90 |
|
82 | | - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); |
83 | | - if (embeddings == null) |
84 | | - return Array.Empty<float>(); |
| 91 | + return embeddings; |
85 | 92 |
|
86 | | - return embeddings.ToArray(); |
| 93 | + float[] GetEmbeddingsArray() |
| 94 | + { |
| 95 | + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); |
| 96 | + if (embeddings == null) |
| 97 | + return Array.Empty<float>(); |
| 98 | + return embeddings.ToArray(); |
| 99 | + } |
87 | 100 | } |
88 | 101 |
|
89 | 102 | /// <summary> |
|
0 commit comments